import asyncio import json import logging import re import uuid from typing import Optional, Callable, Awaitable, AsyncGenerator from services.realtime_voice_client import RealtimeDialogClient logger = logging.getLogger('app') class _StreamTagFilter: """ Filters streaming text based on tag blocks. Only passes through content inside [ANSWER] blocks. If no tags are found at all, passes through everything (fallback). Skips content inside [TOOL_CALL], [TOOL_RESPONSE], [THINK], [SOURCE], etc. """ SKIP_TAGS = {"TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE"} def __init__(self): self.state = "idle" # idle, answer, skip self.found_any_tag = False self._pending = "" # buffer for partial tag like "[TOO..." def feed(self, chunk: str) -> str: """Feed a chunk, return text that should be passed to TTS.""" self._pending += chunk output = [] while self._pending: if self.state in ("idle", "answer"): bracket_pos = self._pending.find("[") if bracket_pos == -1: if self.state == "answer" or not self.found_any_tag: output.append(self._pending) self._pending = "" else: before = self._pending[:bracket_pos] if before and (self.state == "answer" or not self.found_any_tag): output.append(before) close_pos = self._pending.find("]", bracket_pos) if close_pos == -1: # Incomplete tag — wait for next chunk self._pending = self._pending[bracket_pos:] break tag_name = self._pending[bracket_pos + 1:close_pos] self._pending = self._pending[close_pos + 1:] self.found_any_tag = True if tag_name == "ANSWER": self.state = "answer" else: self.state = "skip" elif self.state == "skip": bracket_pos = self._pending.find("[") if bracket_pos == -1: self._pending = "" else: close_pos = self._pending.find("]", bracket_pos) if close_pos == -1: self._pending = self._pending[bracket_pos:] break tag_name = self._pending[bracket_pos + 1:close_pos] self._pending = self._pending[close_pos + 1:] if tag_name == "ANSWER": self.state = "answer" else: self.state = "skip" return "".join(output) class 
class VoiceSession:
    """Manages a single voice dialogue session lifecycle.

    The session connects a ``RealtimeDialogClient`` (speech in / speech out)
    with the streaming v3 agent: user audio is forwarded to the realtime
    service, the recognised text is sent to the agent, and the agent's answer
    is streamed back sentence by sentence as TTS text.

    All ``on_*`` callbacks are optional coroutines; a missing callback is
    simply skipped.
    """

    # Sentence-ending punctuation pattern for splitting TTS.
    # \uff01 \uff1f \uff1b are the fullwidth forms of "!", "?" and ";".
    _SENTENCE_END_RE = re.compile(r'[。\uff01\uff1f\uff1b\n.!?;]')

    # Markdown syntax to strip before TTS. Mirrors the individual
    # substitutions performed by ``_clean_markdown``.
    _MD_CLEAN_RE = re.compile(
        r'#{1,6}\s*|\*{1,3}|_{1,3}|~~|`{1,3}|^>\s*|^\s*[-*+]\s+|^\s*\d+\.\s+'
        r'|\[([^\]]*)\]\([^)]*\)|!\[([^\]]*)\]\([^)]*\)',
        re.MULTILINE,
    )

    def __init__(
        self,
        bot_id: str,
        session_id: Optional[str] = None,
        user_identifier: Optional[str] = None,
        on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None,
        on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None,
        on_agent_stream: Optional[Callable[[str], Awaitable[None]]] = None,
        on_llm_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_status: Optional[Callable[[str], Awaitable[None]]] = None,
        on_error: Optional[Callable[[str], Awaitable[None]]] = None,
    ):
        """Create a session; no I/O happens until ``start()`` is awaited.

        Args:
            bot_id: Identifier of the bot whose configuration is loaded.
            session_id: Session identifier; a random UUID4 string if omitted.
            user_identifier: Optional user identifier, stored as "" if omitted.
            on_audio: Receives TTS audio bytes to forward to the user.
            on_asr_text: Receives the final recognised user text.
            on_agent_result: Receives the full agent output once streaming ends.
            on_agent_stream: Receives each raw agent chunk as it arrives.
            on_llm_text: Stored but not invoked by this class.
            on_status: Receives status strings ("ready", "listening",
                "thinking", "idle").
            on_error: Receives human-readable error messages.
        """
        self.bot_id = bot_id
        self.session_id = session_id or str(uuid.uuid4())
        self.user_identifier = user_identifier or ""

        self.realtime_client: Optional[RealtimeDialogClient] = None
        self._bot_config: dict = {}

        # Callbacks
        self._on_audio = on_audio
        self._on_asr_text = on_asr_text
        self._on_agent_result = on_agent_result
        self._on_agent_stream = on_agent_stream
        self._on_llm_text = on_llm_text
        self._on_status = on_status
        self._on_error = on_error

        self._running = False
        self._is_user_querying = False
        self._current_asr_text = ""
        # When True, discard TTS audio from SERVER_ACK (comfort speech period)
        self._is_sending_chat_tts_text = False
        self._receive_task: Optional[asyncio.Task] = None
        self._agent_task: Optional[asyncio.Task] = None

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def start(self) -> None:
        """Fetch bot config, connect to Volcengine and start receiving responses"""
        from utils.fastapi_utils import fetch_bot_config_from_db

        self._bot_config = await fetch_bot_config_from_db(self.bot_id, self.user_identifier)

        self.realtime_client = RealtimeDialogClient(
            session_id=self.session_id,
            speaker=self._bot_config.get("voice_speaker"),
            system_role=self._bot_config.get("voice_system_role"),
            speaking_style=self._bot_config.get("voice_speaking_style"),
            bot_name=self._bot_config.get("name", ""),
        )
        await self.realtime_client.connect()

        self._running = True
        self._receive_task = asyncio.create_task(self._receive_loop())
        await self._emit_status("ready")

    async def stop(self) -> None:
        """Gracefully stop the session.

        Safe to call on a session that was never started (or whose ``start()``
        failed before the client was created): in that case only the
        background tasks, if any, are cancelled.
        """
        self._running = False
        client = self.realtime_client
        try:
            if client is not None:
                await client.finish_session()
                await asyncio.sleep(0.5)
                await client.finish_connection()
        except Exception as e:
            logger.warning(f"Error during session cleanup: {e}")
        finally:
            if self._agent_task and not self._agent_task.done():
                from utils.cancel_manager import trigger_cancel
                trigger_cancel(self.session_id)
                self._agent_task.cancel()
            if self._receive_task and not self._receive_task.done():
                self._receive_task.cancel()
            if client is not None:
                await client.close()

    # ------------------------------------------------------------------
    # Inbound user data
    # ------------------------------------------------------------------

    async def handle_audio(self, audio_data: bytes) -> None:
        """Forward user audio to Volcengine"""
        client = self.realtime_client
        if self._running and client is not None and client.ws:
            await client.send_audio(audio_data)

    async def handle_text(self, text: str) -> None:
        """Handle text input - send as text query"""
        client = self.realtime_client
        if self._running and client is not None and client.ws:
            await client.chat_text_query(text)

    # ------------------------------------------------------------------
    # Realtime service responses
    # ------------------------------------------------------------------

    async def _receive_loop(self) -> None:
        """Continuously receive and dispatch Volcengine responses"""
        try:
            while self._running:
                response = await self.realtime_client.receive_response()
                if not response:
                    continue
                await self._handle_response(response)
        except asyncio.CancelledError:
            logger.info(f"Voice session receive loop cancelled: {self.session_id}")
        except Exception as e:
            logger.error(f"Voice session receive loop error: {e}")
            await self._emit_error(f"Connection error: {str(e)}")
        finally:
            self._running = False

    async def _handle_response(self, response: dict) -> None:
        """Dispatch one decoded response from the realtime service."""
        msg_type = response.get('message_type', '')
        event = response.get('event')
        payload_msg = response.get('payload_msg', {})

        if msg_type == 'SERVER_ACK' and isinstance(payload_msg, bytes):
            # TTS audio data — discard during comfort speech period
            if self._is_sending_chat_tts_text:
                return
            if self._on_audio:
                await self._on_audio(payload_msg)

        elif msg_type == 'SERVER_FULL_RESPONSE':
            logger.info(f"[Voice] event={event}, payload_msg={payload_msg if not isinstance(payload_msg, bytes) else f'<{len(payload_msg)} bytes>'}")
            await self._handle_server_event(event, payload_msg)

        elif msg_type == 'SERVER_ERROR':
            error_msg = str(payload_msg) if payload_msg else "Unknown server error"
            logger.error(f"[Voice] Server error: {error_msg}")
            await self._emit_error(error_msg)

    async def _handle_server_event(self, event, payload_msg) -> None:
        """Handle the ``event`` code of a SERVER_FULL_RESPONSE message."""
        if event == 450:
            # User started speaking — set querying flag, reset ASR accumulator
            self._is_user_querying = True
            self._current_asr_text = ""
            await self._emit_status("listening")

        elif event == 451:
            # Streaming ASR result — remember the latest recognised text
            self._on_asr_partial(payload_msg)

        elif event == 459:
            # ASR completed — use accumulated text from event 451
            await self._on_asr_completed()

        elif event == 350:
            # TTS segment completed
            tts_type = ""
            if isinstance(payload_msg, dict):
                tts_type = payload_msg.get("tts_type", "")
            logger.info(f"[Voice] TTS segment done, type={tts_type}, is_sending={self._is_sending_chat_tts_text}")
            # When comfort TTS or RAG TTS finishes, stop discarding audio
            if self._is_sending_chat_tts_text and tts_type == "chat_tts_text":
                self._is_sending_chat_tts_text = False
                logger.info("[Voice] Comfort/RAG TTS done, resuming audio forwarding")

        elif event == 359:
            # TTS fully completed (all segments done). No status is emitted
            # here; "idle" is emitted by the agent task when it finishes.
            logger.info("[Voice] TTS fully completed")

        elif event in (152, 153):
            logger.info(f"[Voice] Session finished event: {event}")
            self._running = False

    def _on_asr_partial(self, payload_msg) -> None:
        """Store the text of the first entry in a streaming ASR payload."""
        if not isinstance(payload_msg, dict):
            return
        results = payload_msg.get("results", [])
        if not (isinstance(results, list) and results):
            return
        first = results[0]
        text = first.get("text", "") if isinstance(first, dict) else ""
        if text:
            self._current_asr_text = text
            logger.debug(f"[Voice] ASR streaming (451): '{text}'")

    async def _on_asr_completed(self) -> None:
        """Start an agent call for the recognised text (event 459)."""
        self._is_user_querying = False
        asr_text = self._current_asr_text

        # Filter out ASR during thinking/speaking (TTS echo protection)
        if self._is_sending_chat_tts_text:
            logger.info(f"[Voice] Discarding ASR during thinking/speaking: '{asr_text}'")
            return

        logger.info(f"[Voice] ASR result: '{asr_text}'")
        if self._on_asr_text and asr_text:
            await self._on_asr_text(asr_text)
        await self._emit_status("thinking")

        # Cancel previous agent task if still running
        if self._agent_task and not self._agent_task.done():
            logger.info("[Voice] Interrupting previous agent task")
            from utils.cancel_manager import trigger_cancel
            trigger_cancel(self.session_id)
            self._agent_task.cancel()
            self._agent_task = None

        # Trigger comfort TTS + agent call. No await may happen between
        # cancelling the old task and registering the new one: the old task's
        # cancellation handler relies on seeing the new task in _agent_task.
        self._is_sending_chat_tts_text = True
        self._agent_task = asyncio.create_task(self._on_asr_text_received(asr_text))

    # ------------------------------------------------------------------
    # Agent output -> TTS
    # ------------------------------------------------------------------

    @staticmethod
    def _clean_markdown(text: str) -> str:
        """Strip Markdown formatting characters for TTS readability."""
        # Replace links/images with their display text
        text = re.sub(r'!\[([^\]]*)\]\([^)]*\)', r'\1', text)
        text = re.sub(r'\[([^\]]*)\]\([^)]*\)', r'\1', text)
        # Remove headings, bold, italic, strikethrough, code marks, blockquote
        text = re.sub(r'#{1,6}\s*', '', text)
        text = re.sub(r'\*{1,3}|_{1,3}|~~|`{1,3}', '', text)
        text = re.sub(r'^>\s*', '', text, flags=re.MULTILINE)
        # Remove list markers
        text = re.sub(r'^\s*[-*+]\s+', '', text, flags=re.MULTILINE)
        text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
        # Remove horizontal rules
        text = re.sub(r'^[\s]*[-*_]{3,}[\s]*$', '', text, flags=re.MULTILINE)
        # Collapse extra whitespace
        text = re.sub(r'\n{2,}', '\n', text)
        return text.strip()

    def _superseded_by_newer_agent_task(self) -> bool:
        """Return True if a different, still-pending agent task is registered."""
        successor = self._agent_task
        return (
            successor is not None
            and successor is not asyncio.current_task()
            and not successor.done()
        )

    async def _send_tts_text(self, raw: str, label: str, start: bool) -> bool:
        """Clean ``raw`` and send it as one TTS text segment.

        Args:
            raw: Text taken from the sentence buffer.
            label: Word used in the log line ("sentence" or "remaining").
            start: True if this is the first segment of the TTS stream.

        Returns:
            True if a segment was sent, False if nothing speakable remained
            after cleaning.
        """
        cleaned = self._clean_markdown(raw.strip())
        if not cleaned:
            return False
        logger.info(f"[Voice] Sending TTS {label}: '{cleaned[:80]}'")
        await self.realtime_client.chat_tts_text(
            content=cleaned,
            start=start,
            end=False,
        )
        return True

    async def _on_asr_text_received(self, text: str) -> None:
        """Called when ASR text is received — stream agent output, send TTS sentence by sentence"""
        if not text.strip():
            # Nothing to ask the agent; leave the "thinking" state again.
            self._is_sending_chat_tts_text = False
            await self._emit_status("idle")
            return

        try:
            logger.info(f"[Voice] Calling v3 agent with text: '{text}'")

            accumulated_text = []  # full agent output for on_agent_result callback
            sentence_buf = ""      # buffer for accumulating until sentence boundary
            tts_started = False    # whether we've sent the first TTS chunk
            tag_filter = _StreamTagFilter()

            async for chunk in self._stream_v3_agent(text):
                accumulated_text.append(chunk)

                if self._on_agent_stream:
                    await self._on_agent_stream(chunk)

                # Filter out [TOOL_CALL], [TOOL_RESPONSE], [THINK] etc., only keep [ANSWER] content
                passthrough = tag_filter.feed(chunk)
                if not passthrough:
                    continue

                sentence_buf += passthrough

                # Send every complete sentence (punctuation included) to TTS
                while (match := self._SENTENCE_END_RE.search(sentence_buf)):
                    end_pos = match.end()
                    sentence = sentence_buf[:end_pos]
                    sentence_buf = sentence_buf[end_pos:]
                    if await self._send_tts_text(sentence, "sentence", start=not tts_started):
                        tts_started = True

            # Handle remaining text in buffer (last sentence without ending punctuation)
            if await self._send_tts_text(sentence_buf, "remaining", start=not tts_started):
                tts_started = True

            # Send TTS end signal
            if tts_started:
                await self.realtime_client.chat_tts_text(
                    content="",
                    start=False,
                    end=True,
                )
            else:
                logger.warning("[Voice] Agent returned no usable text for TTS")
                self._is_sending_chat_tts_text = False

            # Emit full agent result
            full_result = "".join(accumulated_text)
            logger.info(f"[Voice] Agent result ({len(full_result)} chars): {full_result[:200]}")
            if self._on_agent_result and full_result:
                await self._on_agent_result(full_result)

            await self._emit_status("idle")

        except asyncio.CancelledError:
            logger.info("[Voice] Agent task cancelled (user interrupted)")
            # When this task was interrupted by a new query, the flag already
            # belongs to the replacement task and must stay set.
            if not self._superseded_by_newer_agent_task():
                self._is_sending_chat_tts_text = False
            raise
        except Exception as e:
            logger.error(f"[Voice] Error in ASR text callback: {e}", exc_info=True)
            self._is_sending_chat_tts_text = False
            await self._emit_error(f"Agent call failed: {str(e)}")

    @staticmethod
    def _extract_delta_content(data) -> str:
        """Return ``choices[0].delta.content`` from a decoded SSE payload.

        Returns "" when the payload does not have that shape or the content
        is not a string.
        """
        if not isinstance(data, dict):
            return ""
        choices = data.get("choices")
        if not isinstance(choices, list) or not choices:
            return ""
        first = choices[0]
        if not isinstance(first, dict):
            return ""
        delta = first.get("delta")
        if not isinstance(delta, dict):
            return ""
        content = delta.get("content")
        return content if isinstance(content, str) else ""

    async def _stream_v3_agent(self, user_text: str) -> AsyncGenerator[str, None]:
        """Call v3 agent API in streaming mode, yield text chunks as they arrive.

        Errors other than cancellation are logged and end the stream early
        instead of propagating to the consumer.
        """
        try:
            from utils.api_models import ChatRequestV3, Message
            from utils.fastapi_utils import (
                process_messages,
                create_project_directory,
            )
            from agent.agent_config import AgentConfig
            from routes.chat import enhanced_generate_stream_response

            bot_config = self._bot_config
            language = bot_config.get("language", "zh")

            messages_obj = [Message(role="user", content=user_text)]

            request = ChatRequestV3(
                messages=messages_obj,
                bot_id=self.bot_id,
                stream=True,
                session_id=self.session_id,
                user_identifier=self.user_identifier,
            )

            project_dir = create_project_directory(
                bot_config.get("dataset_ids", []),
                self.bot_id,
                bot_config.get("skills", []),
            )

            processed_messages = process_messages(messages_obj, language)

            config = await AgentConfig.from_v3_request(
                request,
                bot_config,
                project_dir,
                processed_messages,
                language,
            )
            config.stream = True

            async for sse_line in enhanced_generate_stream_response(config):
                if not sse_line or not sse_line.startswith("data: "):
                    continue
                data_str = sse_line.strip().removeprefix("data: ")
                if data_str == "[DONE]":
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    # Skip lines that are not valid JSON.
                    continue
                content = self._extract_delta_content(data)
                if content:
                    yield content

        except asyncio.CancelledError:
            logger.info("[Voice] v3 agent call cancelled")
            raise
        except Exception as e:
            logger.error(f"[Voice] Error calling v3 agent: {e}", exc_info=True)

    @staticmethod
    def _extract_answer(agent_result: str) -> str:
        """Extract the answer portion from agent result, stripping tags like [ANSWER], [THINK] etc.

        Lines after a line starting with ``[ANSWER]`` are collected until a
        line starting with any other ``[`` tag. If nothing is collected, the
        stripped input is returned unchanged.
        """
        answer_lines = []
        in_answer = False
        for line in agent_result.split('\n'):
            stripped = line.strip()
            if stripped.startswith('[ANSWER]'):
                in_answer = True
                rest = stripped[len('[ANSWER]'):].strip()
                if rest:
                    answer_lines.append(rest)
            elif stripped.startswith('['):
                in_answer = False
            elif in_answer:
                answer_lines.append(line)

        if answer_lines:
            return '\n'.join(answer_lines).strip()
        return agent_result.strip()

    # ------------------------------------------------------------------
    # Callback helpers
    # ------------------------------------------------------------------

    async def _emit_status(self, status: str) -> None:
        """Forward ``status`` to the ``on_status`` callback, if one is set."""
        if self._on_status:
            await self._on_status(status)

    async def _emit_error(self, message: str) -> None:
        """Forward ``message`` to the ``on_error`` callback, if one is set."""
        if self._on_error:
            await self._on_error(message)