diff --git a/agent/deep_assistant.py b/agent/deep_assistant.py
index 7dfb662..df64c5a 100644
--- a/agent/deep_assistant.py
+++ b/agent/deep_assistant.py
@@ -287,10 +287,12 @@ async def init_agent(config: AgentConfig):
         middleware=middleware,
         checkpointer=checkpointer,
         shell_env={
-            "ASSISTANT_ID": config.bot_id,
-            "USER_IDENTIFIER": config.user_identifier,
-            "TRACE_ID": config.trace_id,
-            **(config.shell_env or {}),
+            k: v for k, v in {
+                "ASSISTANT_ID": config.bot_id,
+                "USER_IDENTIFIER": config.user_identifier,
+                "TRACE_ID": config.trace_id,
+                **(config.shell_env or {}),
+            }.items() if v is not None
         }
     )
 
diff --git a/services/voice_session_manager.py b/services/voice_session_manager.py
index 0d9109e..f3a35a4 100644
--- a/services/voice_session_manager.py
+++ b/services/voice_session_manager.py
@@ -1,14 +1,83 @@
 import asyncio
 import json
 import logging
+import re
 import uuid
-from typing import Optional, Callable, Awaitable
+from typing import Optional, Callable, Awaitable, AsyncGenerator
 
 from services.realtime_voice_client import RealtimeDialogClient
 
 logger = logging.getLogger('app')
 
 
+class _StreamTagFilter:
+    """
+    Filters streaming text based on tag blocks.
+    Only passes through content inside [ANSWER] blocks.
+    If no tags are found at all, passes through everything (fallback).
+    Skips content inside [TOOL_CALL], [TOOL_RESPONSE], [THINK], [SOURCE], etc.
+    """
+
+    SKIP_TAGS = {"TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE"}
+
+    def __init__(self):
+        self.state = "idle"  # idle, answer, skip
+        self.found_any_tag = False
+        self._pending = ""  # buffer for partial tag like "[TOO..."
+
+    def feed(self, chunk: str) -> str:
+        """Feed a chunk, return text that should be passed to TTS."""
+        self._pending += chunk
+        output = []
+
+        while self._pending:
+            if self.state in ("idle", "answer"):
+                bracket_pos = self._pending.find("[")
+                if bracket_pos == -1:
+                    if self.state == "answer" or not self.found_any_tag:
+                        output.append(self._pending)
+                    self._pending = ""
+                else:
+                    before = self._pending[:bracket_pos]
+                    if before and (self.state == "answer" or not self.found_any_tag):
+                        output.append(before)
+
+                    close_pos = self._pending.find("]", bracket_pos)
+                    if close_pos == -1:
+                        # Incomplete tag — wait for next chunk
+                        self._pending = self._pending[bracket_pos:]
+                        break
+
+                    tag_name = self._pending[bracket_pos + 1:close_pos]
+                    self._pending = self._pending[close_pos + 1:]
+                    self.found_any_tag = True
+
+                    if tag_name == "ANSWER":
+                        self.state = "answer"
+                    else:
+                        self.state = "skip"
+
+            elif self.state == "skip":
+                bracket_pos = self._pending.find("[")
+                if bracket_pos == -1:
+                    self._pending = ""
+                else:
+                    close_pos = self._pending.find("]", bracket_pos)
+                    if close_pos == -1:
+                        self._pending = self._pending[bracket_pos:]
+                        break
+
+                    tag_name = self._pending[bracket_pos + 1:close_pos]
+                    self._pending = self._pending[close_pos + 1:]
+
+                    if tag_name == "ANSWER":
+                        self.state = "answer"
+                    else:
+                        self.state = "skip"
+
+        return "".join(output)
+
+
 class VoiceSession:
     """Manages a single voice dialogue session lifecycle"""
 
@@ -197,51 +266,84 @@ class VoiceSession:
         logger.error(f"[Voice] Server error: {error_msg}")
         await self._emit_error(error_msg)
 
+    # Sentence-ending punctuation pattern for splitting TTS
+    _SENTENCE_END_RE = re.compile(r'[。!?;\n.!?;]')
+
     async def _on_asr_text_received(self, text: str) -> None:
-        """Called when ASR text is received — send comfort TTS, call agent, inject RAG"""
+        """Called when ASR text is received — stream agent output, send TTS sentence by sentence"""
         if not text.strip():
             self._is_sending_chat_tts_text = False
             return
 
         try:
-            # 1. Send comfort TTS (real Chinese text, not "...")
-            # logger.info(f"[Voice] Sending comfort TTS...")
-            # await self.realtime_client.chat_tts_text(
-            #     content="请稍等,让我查一下。",
-            #     start=True,
-            #     end=False,
-            # )
-            # await self.realtime_client.chat_tts_text(
-            #     content="",
-            #     start=False,
-            #     end=True,
-            # )
-
-            # 2. Call v3 agent (this may take a while)
             logger.info(f"[Voice] Calling v3 agent with text: '{text}'")
-            agent_result = await self._call_v3_agent(text)
-            logger.info(f"[Voice] Agent result ({len(agent_result)} chars): {agent_result[:200]}")
-            if self._on_agent_result and agent_result:
-                await self._on_agent_result(agent_result)
+            accumulated_text = []  # full agent output for on_agent_result callback
+            sentence_buf = ""  # buffer for accumulating until sentence boundary
+            tts_started = False  # whether we've sent the first TTS chunk
+            tag_filter = _StreamTagFilter()
 
-            # 3. Send agent result directly as TTS (bypass LLM)
-            if agent_result:
-                clean_result = self._extract_answer(agent_result)
-                logger.info(f"[Voice] Sending agent result as TTS ({len(clean_result)} chars)")
+            async for chunk in self._stream_v3_agent(text):
+                accumulated_text.append(chunk)
+
+                if self._on_agent_stream:
+                    await self._on_agent_stream(chunk)
+
+                # Filter out [TOOL_CALL], [TOOL_RESPONSE], [THINK] etc., only keep [ANSWER] content
+                passthrough = tag_filter.feed(chunk)
+
+                if not passthrough:
+                    continue
+
+                sentence_buf += passthrough
+
+                # Check for sentence boundaries and send complete sentences to TTS
+                while True:
+                    match = self._SENTENCE_END_RE.search(sentence_buf)
+                    if not match:
+                        break
+                    # Split at sentence boundary (include the punctuation)
+                    end_pos = match.end()
+                    sentence = sentence_buf[:end_pos].strip()
+                    sentence_buf = sentence_buf[end_pos:]
+
+                    if sentence:
+                        logger.info(f"[Voice] Sending TTS sentence: '{sentence[:80]}'")
+                        await self.realtime_client.chat_tts_text(
+                            content=sentence,
+                            start=not tts_started,
+                            end=False,
+                        )
+                        tts_started = True
+
+            # Handle remaining text in buffer (last sentence without ending punctuation)
+            remaining = sentence_buf.strip()
+            if remaining:
+                logger.info(f"[Voice] Sending TTS remaining: '{remaining[:80]}'")
                 await self.realtime_client.chat_tts_text(
-                    content=clean_result,
-                    start=True,
+                    content=remaining,
+                    start=not tts_started,
                     end=False,
                 )
+                tts_started = True
+
+            # Send TTS end signal
+            if tts_started:
                 await self.realtime_client.chat_tts_text(
                     content="",
                     start=False,
                     end=True,
                 )
             else:
-                logger.warning(f"[Voice] Agent returned empty result")
+                logger.warning(f"[Voice] Agent returned no usable text for TTS")
 
             self._is_sending_chat_tts_text = False
+
+            # Emit full agent result
+            full_result = "".join(accumulated_text)
+            logger.info(f"[Voice] Agent result ({len(full_result)} chars): {full_result[:200]}")
+            if self._on_agent_result and full_result:
+                await self._on_agent_result(full_result)
+
             await self._emit_status("idle")
 
         except asyncio.CancelledError:
@@ -253,8 +355,8 @@ class VoiceSession:
             self._is_sending_chat_tts_text = False
             await self._emit_error(f"Agent call failed: {str(e)}")
 
-    async def _call_v3_agent(self, user_text: str) -> str:
-        """Call v3 agent API in streaming mode, accumulate text and return full result"""
+    async def _stream_v3_agent(self, user_text: str) -> AsyncGenerator[str, None]:
+        """Call v3 agent API in streaming mode, yield text chunks as they arrive"""
         try:
             from utils.api_models import ChatRequestV3, Message
             from utils.fastapi_utils import (
@@ -294,8 +396,6 @@ class VoiceSession:
             )
             config.stream = True
 
-            # Consume the async generator, parse SSE chunks to accumulate content
-            accumulated_text = []
             async for sse_line in enhanced_generate_stream_response(config):
                 if not sse_line or not sse_line.startswith("data: "):
                     continue
@@ -309,20 +409,15 @@ class VoiceSession:
                        delta = choices[0].get("delta", {})
                        content = delta.get("content", "")
                        if content:
-                            accumulated_text.append(content)
-                            if self._on_agent_stream:
-                                await self._on_agent_stream(content)
+                            yield content
                 except (json.JSONDecodeError, KeyError):
                     continue
 
-            return "".join(accumulated_text)
-
         except asyncio.CancelledError:
             logger.info(f"[Voice] v3 agent call cancelled")
             raise
         except Exception as e:
             logger.error(f"[Voice] Error calling v3 agent: {e}", exc_info=True)
-            return ""
 
     @staticmethod
     def _extract_answer(agent_result: str) -> str: