import asyncio
import logging
import uuid
from typing import Optional, Callable, Awaitable

from services.realtime_voice_client import RealtimeDialogClient
from services.voice_utils import StreamTagFilter, clean_markdown, stream_v3_agent, SENTENCE_END_RE

logger = logging.getLogger('app')


class VoiceSession:
    """Manages a single voice dialogue session lifecycle.

    The session owns one ``RealtimeDialogClient`` connection, a background
    task that dispatches server responses, and (at most) one agent task that
    streams the agent's answer and feeds it sentence by sentence to TTS.
    All ``on_*`` callbacks are optional coroutines; a missing callback is
    simply skipped.
    """

    # Seconds to wait for event 359 after closing a TTS session mid-answer.
    _TTS_COMPLETE_TIMEOUT = 10
    # Pause between finish_session() and finish_connection() during stop().
    _FINISH_GRACE_SECONDS = 0.5

    def __init__(
        self,
        bot_id: str,
        session_id: Optional[str] = None,
        user_identifier: Optional[str] = None,
        on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None,
        on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None,
        on_agent_stream: Optional[Callable[[str], Awaitable[None]]] = None,
        on_llm_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_status: Optional[Callable[[str], Awaitable[None]]] = None,
        on_error: Optional[Callable[[str], Awaitable[None]]] = None,
    ):
        """Store identifiers and callbacks; no I/O happens until ``start()``.

        Args:
            bot_id: Bot whose configuration is loaded in ``start()``.
            session_id: Session identifier; a random UUID4 string when omitted.
            user_identifier: Caller identity; stored as "" when omitted.
            on_audio: Receives TTS audio bytes forwarded from the server.
            on_asr_text: Receives the final recognized user text.
            on_agent_result: Receives the full concatenated agent output.
            on_agent_stream: Receives every raw agent chunk as it arrives.
            on_llm_text: Stored but not invoked anywhere in this class.
            on_status: Receives status strings ("ready", "listening",
                "thinking", "speaking", "idle").
            on_error: Receives human-readable error messages.
        """
        self.bot_id = bot_id
        self.session_id = session_id or str(uuid.uuid4())
        self.user_identifier = user_identifier or ""

        self.realtime_client: Optional[RealtimeDialogClient] = None
        self._bot_config: dict = {}

        # Callbacks
        self._on_audio = on_audio
        self._on_asr_text = on_asr_text
        self._on_agent_result = on_agent_result
        self._on_agent_stream = on_agent_stream
        self._on_llm_text = on_llm_text
        self._on_status = on_status
        self._on_error = on_error

        self._running = False
        # Written on events 450/459; not read anywhere in this class.
        self._is_user_querying = False
        self._current_asr_text = ""
        # When True, SERVER_ACK audio and ASR results (event 459) are discarded.
        # Set when an agent task is launched, cleared on event 350 with
        # tts_type == "chat_tts_text" or when the agent task ends without TTS.
        self._is_sending_chat_tts_text = False
        # Set to True when event 350 fires for chat_tts_text. The agent task
        # treats it as "the current TTS session must be closed and restarted
        # before sending more text".
        # NOTE(review): the exact meaning of event 350 should be confirmed
        # against the Volcengine protocol documentation.
        self._tts_segment_done = False
        # Signaled when event 359 fires; awaited before restarting TTS.
        self._tts_complete_event: asyncio.Event = asyncio.Event()
        self._receive_task: Optional[asyncio.Task] = None
        self._agent_task: Optional[asyncio.Task] = None

    async def start(self) -> None:
        """Fetch bot config, connect to Volcengine and start receiving responses"""
        # Imported lazily, as in the original code.
        from utils.fastapi_utils import fetch_bot_config_from_db

        self._bot_config = await fetch_bot_config_from_db(self.bot_id, self.user_identifier)

        self.realtime_client = RealtimeDialogClient(
            session_id=self.session_id,
            speaker=self._bot_config.get("voice_speaker"),
            system_role=self._bot_config.get("voice_system_role"),
            speaking_style=self._bot_config.get("voice_speaking_style"),
            bot_name=self._bot_config.get("name", ""),
        )

        await self.realtime_client.connect()
        self._running = True
        self._receive_task = asyncio.create_task(self._receive_loop())
        await self._emit_status("ready")

    async def stop(self) -> None:
        """Gracefully stop the session.

        Safe to call even when ``start()`` never created the client: in that
        case only the background tasks (if any) are cancelled.
        """
        self._running = False
        client = self.realtime_client
        try:
            if client is not None:
                await client.finish_session()
                await asyncio.sleep(self._FINISH_GRACE_SECONDS)
                await client.finish_connection()
        except Exception as e:
            # Best-effort shutdown: the connection may already be gone.
            logger.warning(f"Error during session cleanup: {e}")
        finally:
            self._cancel_agent_task()
            if self._receive_task and not self._receive_task.done():
                self._receive_task.cancel()
            if client is not None:
                await client.close()

    async def handle_audio(self, audio_data: bytes) -> None:
        """Forward user audio to Volcengine"""
        client = self.realtime_client
        if self._running and client is not None and client.ws:
            await client.send_audio(audio_data)

    async def handle_text(self, text: str) -> None:
        """Handle text input - send as text query"""
        client = self.realtime_client
        if self._running and client is not None and client.ws:
            await client.chat_text_query(text)

    def _cancel_agent_task(self) -> None:
        """Cancel the agent task if it is still running (no-op otherwise)."""
        task = self._agent_task
        if task is None or task.done():
            return
        from utils.cancel_manager import trigger_cancel

        trigger_cancel(self.session_id)
        task.cancel()

    def _clear_tts_flag(self) -> None:
        """Reset ``_is_sending_chat_tts_text`` unless a newer agent task owns it.

        When event 459 interrupts a running agent task it sets the flag for
        the replacement task *before* the cancelled task gets to run its
        cleanup. Without this check the old task's cleanup would clear the
        flag that now belongs to the new task.
        """
        owner = self._agent_task
        if owner is not None and owner is not asyncio.current_task():
            return
        self._is_sending_chat_tts_text = False

    async def _receive_loop(self) -> None:
        """Continuously receive and dispatch Volcengine responses"""
        try:
            while self._running:
                response = await self.realtime_client.receive_response()
                if not response:
                    continue
                await self._handle_response(response)
        except asyncio.CancelledError:
            logger.info(f"Voice session receive loop cancelled: {self.session_id}")
        except Exception as e:
            logger.error(f"Voice session receive loop error: {e}", exc_info=True)
            await self._emit_error(f"Connection error: {str(e)}")
        finally:
            self._running = False

    async def _handle_response(self, response: dict) -> None:
        """Dispatch one server response by ``message_type`` and ``event``."""
        msg_type = response.get('message_type', '')
        event = response.get('event')
        payload_msg = response.get('payload_msg', {})

        if msg_type == 'SERVER_ACK' and isinstance(payload_msg, bytes):
            # TTS audio data — discard while the chat-TTS flag is set
            if self._is_sending_chat_tts_text:
                return
            if self._on_audio:
                await self._on_audio(payload_msg)

        elif msg_type == 'SERVER_FULL_RESPONSE':
            # Never dump raw audio bytes into the log; show their size instead.
            shown = f'<{len(payload_msg)} bytes>' if isinstance(payload_msg, bytes) else payload_msg
            logger.info(f"[Voice] event={event}, payload_msg={shown}")

            if event == 450:
                # User started speaking — set querying flag, reset ASR accumulator
                self._is_user_querying = True
                self._current_asr_text = ""
                await self._emit_status("listening")
            elif event == 451:
                self._handle_asr_partial(payload_msg)
            elif event == 459:
                await self._handle_asr_completed()
            elif event == 350:
                self._handle_tts_segment_event(payload_msg)
            elif event == 359:
                # TTS fully completed (all segments done)
                logger.info("[Voice] TTS fully completed")
                self._tts_complete_event.set()
            elif event in (152, 153):
                logger.info(f"[Voice] Session finished event: {event}")
                self._running = False

        elif msg_type == 'SERVER_ERROR':
            error_msg = str(payload_msg) if payload_msg else "Unknown server error"
            logger.error(f"[Voice] Server error: {error_msg}")
            await self._emit_error(error_msg)

    def _handle_asr_partial(self, payload_msg) -> None:
        """Event 451: remember the latest non-empty streaming ASR text."""
        if not isinstance(payload_msg, dict):
            return
        results = payload_msg.get("results", [])
        if not results or not isinstance(results, list):
            return
        first = results[0]
        if not isinstance(first, dict):
            return
        text = first.get("text", "")
        if text:
            self._current_asr_text = text
            logger.debug(f"[Voice] ASR streaming (451): '{text}'")

    async def _handle_asr_completed(self) -> None:
        """Event 459: publish the recognized text and launch the agent task."""
        self._is_user_querying = False
        asr_text = self._current_asr_text

        # Filter out ASR during thinking/speaking (TTS echo protection)
        if self._is_sending_chat_tts_text:
            logger.info(f"[Voice] Discarding ASR during thinking/speaking: '{asr_text}'")
            return

        logger.info(f"[Voice] ASR result: '{asr_text}'")
        if self._on_asr_text and asr_text:
            await self._on_asr_text(asr_text)
        await self._emit_status("thinking")

        # Cancel previous agent task if still running
        if self._agent_task and not self._agent_task.done():
            logger.info("[Voice] Interrupting previous agent task")
            self._cancel_agent_task()
            self._agent_task = None

        # Set the flag before the task starts so audio arriving in between
        # is already discarded.
        self._is_sending_chat_tts_text = True
        self._agent_task = asyncio.create_task(self._on_asr_text_received(asr_text))

    def _handle_tts_segment_event(self, payload_msg) -> None:
        """Event 350: update the chat-TTS flags from the payload's ``tts_type``."""
        tts_type = ""
        if isinstance(payload_msg, dict):
            tts_type = payload_msg.get("tts_type", "")
        logger.info(f"[Voice] TTS segment done, type={tts_type}, is_sending={self._is_sending_chat_tts_text}")
        if tts_type != "chat_tts_text":
            return
        # Stop discarding audio once the server reports chat_tts_text
        if self._is_sending_chat_tts_text:
            self._is_sending_chat_tts_text = False
            logger.info("[Voice] Comfort/RAG TTS done, resuming audio forwarding")
        # Mark TTS segment as done so next send uses start=True
        self._tts_segment_done = True

    async def _send_tts_text(
        self,
        raw_text: str,
        tts_started: bool,
        *,
        close_note: str = "",
        restart_note: str = "",
        label: str = "sentence",
    ) -> bool:
        """Send one piece of answer text to TTS, restarting the session if needed.

        Args:
            raw_text: Text to speak; stripped and passed through
                ``clean_markdown`` first. Nothing is sent if that leaves it empty.
            tts_started: Whether a TTS session is already open for this answer.
            close_note: Suffix for the "closing session" log line.
            restart_note: Suffix for the "starting new session" log line.
            label: Word used in the "Sending TTS ..." log line.

        Returns:
            The new ``tts_started`` value: ``True`` if text was sent,
            otherwise the value passed in.
        """
        speech = raw_text.strip()
        if speech:
            speech = clean_markdown(speech)
        if not speech:
            return tts_started

        # If the previous TTS segment completed (e.g. gap during tool call),
        # close the old session, wait for delivery to finish, then restart.
        if tts_started and self._tts_segment_done:
            logger.info(f"[Voice] TTS segment done, closing session and waiting for delivery{close_note}")
            # Clear BEFORE sending the end signal: the receive loop runs
            # concurrently, and a 359 that arrives while the send is awaited
            # must not be wiped out.
            self._tts_complete_event.clear()
            await self.realtime_client.chat_tts_text(content="", start=False, end=True)
            try:
                await asyncio.wait_for(
                    self._tts_complete_event.wait(),
                    timeout=self._TTS_COMPLETE_TIMEOUT,
                )
            except asyncio.TimeoutError:
                logger.warning("[Voice] Timeout waiting for TTS complete, proceeding anyway")
            tts_started = False
            self._tts_segment_done = False
            logger.info(f"[Voice] TTS delivery done, starting new session{restart_note}")

        logger.info(f"[Voice] Sending TTS {label}: '{speech[:80]}'")
        await self.realtime_client.chat_tts_text(
            content=speech,
            start=not tts_started,
            end=False,
        )
        if not tts_started:
            await self._emit_status("speaking")
        self._tts_segment_done = False
        return True

    async def _on_asr_text_received(self, text: str) -> None:
        """Called when ASR text is received — stream agent output, send TTS sentence by sentence"""
        if not text.strip():
            self._clear_tts_flag()
            # Event 459 already announced "thinking"; without this the status
            # would never leave that state for an empty recognition result.
            await self._emit_status("idle")
            return

        try:
            logger.info(f"[Voice] Calling v3 agent with text: '{text}'")

            accumulated_text = []  # full agent output for on_agent_result callback
            sentence_buf = ""  # buffer for accumulating until sentence boundary
            tts_started = False  # whether we've sent the first TTS chunk
            tag_filter = StreamTagFilter()

            async for chunk in stream_v3_agent(
                user_text=text,
                bot_id=self.bot_id,
                bot_config=self._bot_config,
                session_id=self.session_id,
                user_identifier=self.user_identifier,
            ):
                accumulated_text.append(chunk)

                if self._on_agent_stream:
                    await self._on_agent_stream(chunk)

                # Filter out [TOOL_CALL], [TOOL_RESPONSE], [THINK] etc., only keep [ANSWER] content
                passthrough = tag_filter.feed(chunk)

                if not passthrough:
                    # ANSWER block ended (e.g. hit [TOOL_CALL]), flush sentence_buf immediately
                    if tag_filter.answer_ended and sentence_buf:
                        pending, sentence_buf = sentence_buf, ""
                        tts_started = await self._send_tts_text(
                            pending,
                            tts_started,
                            close_note=" (answer ended)",
                            restart_note=" (answer ended)",
                            label="sentence (answer ended)",
                        )
                    continue

                sentence_buf += passthrough

                # Check for sentence boundaries and send complete sentences to TTS
                while True:
                    match = SENTENCE_END_RE.search(sentence_buf)
                    if not match:
                        break
                    # Split at sentence boundary (include the punctuation)
                    end_pos = match.end()
                    sentence = sentence_buf[:end_pos]
                    sentence_buf = sentence_buf[end_pos:]
                    tts_started = await self._send_tts_text(sentence, tts_started)

            # Handle remaining text in buffer (last sentence without ending punctuation)
            tts_started = await self._send_tts_text(
                sentence_buf,
                tts_started,
                close_note=" (remaining)",
                restart_note=" for remaining",
                label="remaining",
            )

            # Send TTS end signal
            if tts_started:
                await self.realtime_client.chat_tts_text(
                    content="",
                    start=False,
                    end=True,
                )
            else:
                logger.warning("[Voice] Agent returned no usable text for TTS")
                # No TTS was sent, so no event 350 will clear the flag for us.
                self._clear_tts_flag()

            # Emit full agent result
            full_result = "".join(accumulated_text)
            logger.info(f"[Voice] Agent result ({len(full_result)} chars): {full_result[:200]}")
            if self._on_agent_result and full_result:
                await self._on_agent_result(full_result)

            await self._emit_status("idle")

        except asyncio.CancelledError:
            logger.info("[Voice] Agent task cancelled (user interrupted)")
            self._clear_tts_flag()
            raise
        except Exception as e:
            logger.error(f"[Voice] Error in ASR text callback: {e}", exc_info=True)
            self._clear_tts_flag()
            await self._emit_error(f"Agent call failed: {str(e)}")

    @staticmethod
    def _extract_answer(agent_result: str) -> str:
        """Extract the answer portion from agent result, stripping tags like [ANSWER], [THINK] etc.

        Lines following an ``[ANSWER]`` tag are collected until a line that
        starts with any other ``[`` tag. If no answer lines are found, the
        whole input is returned stripped.
        """
        answer_lines = []
        in_answer = False

        for line in agent_result.split('\n'):
            stripped = line.strip()
            if stripped.startswith('[ANSWER]'):
                in_answer = True
                rest = stripped[len('[ANSWER]'):].strip()
                if rest:
                    answer_lines.append(rest)
            elif stripped.startswith('['):
                # Any other tag closes the current answer section.
                in_answer = False
            elif in_answer:
                answer_lines.append(line)

        if answer_lines:
            return '\n'.join(answer_lines).strip()

        return agent_result.strip()

    async def _emit_status(self, status: str) -> None:
        """Invoke the status callback, if one was provided."""
        if self._on_status:
            await self._on_status(status)

    async def _emit_error(self, message: str) -> None:
        """Invoke the error callback, if one was provided."""
        if self._on_error:
            await self._on_error(message)