import asyncio
import logging
import struct
import uuid
from collections import deque
from typing import Optional, Callable, Awaitable

import webrtcvad

from services.streaming_asr_client import StreamingASRClient
from services.streaming_tts_client import StreamingTTSClient
from services.voice_utils import (
    StreamTagFilter,
    clean_markdown,
    stream_v3_agent,
    TTSSentenceSplitter,
)
from utils.settings import VOICE_LITE_SILENCE_TIMEOUT

logger = logging.getLogger('app')


class VoiceLiteSession:
    """Voice Lite session: ASR -> Agent -> TTS pipeline (cheaper alternative).

    Incoming PCM audio is gated by webrtcvad: the ASR connection is created
    lazily when speech starts, a small pre-speech buffer is replayed to it,
    and a finish signal is sent after sustained silence. Recognized text is
    debounced for stability, then streamed through the agent and synthesized
    sentence-by-sentence via TTS. All output reaches the caller through the
    async callbacks supplied to ``__init__``.
    """

    def __init__(
        self,
        bot_id: str,
        session_id: Optional[str] = None,
        user_identifier: Optional[str] = None,
        sample_rate: int = 24000,
        on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None,
        on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None,
        on_agent_stream: Optional[Callable[[str], Awaitable[None]]] = None,
        on_llm_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_status: Optional[Callable[[str], Awaitable[None]]] = None,
        on_error: Optional[Callable[[str], Awaitable[None]]] = None,
    ):
        self.bot_id = bot_id
        self.session_id = session_id or str(uuid.uuid4())
        self.user_identifier = user_identifier or ""
        # Sample rate of audio exchanged with the client (16000 triggers resampling).
        self._client_sample_rate = sample_rate
        self._bot_config: dict = {}
        self._speaker: str = ""

        # Callbacks (all optional; invoked with await when set)
        self._on_audio = on_audio
        self._on_asr_text = on_asr_text
        self._on_agent_result = on_agent_result
        self._on_agent_stream = on_agent_stream
        self._on_llm_text = on_llm_text
        self._on_status = on_status
        self._on_error = on_error

        self._running = False
        self._asr_client: Optional[StreamingASRClient] = None
        self._asr_receive_task: Optional[asyncio.Task] = None
        self._agent_task: Optional[asyncio.Task] = None

        # Silence timeout tracking (text-stability debounce for ASR results)
        self._last_asr_time: float = 0
        self._silence_timer_task: Optional[asyncio.Task] = None
        self._current_asr_text: str = ""
        self._last_text_change_time: float = 0
        self._last_changed_text: str = ""
        self._last_asr_emit_time: float = 0
        self._utterance_lock = asyncio.Lock()

        # VAD (Voice Activity Detection) via webrtcvad
        self._vad = webrtcvad.Vad(2)  # aggressiveness 0-3 (2 = balanced)
        self._vad_speaking = False    # Whether user is currently speaking
        self._vad_silence_start: float = 0  # When silence started
        self._vad_finish_task: Optional[asyncio.Task] = None
        # Ring buffer of audio captured before VAD triggers; maxlen enforces the cap
        # that was previously maintained manually with pop(0).
        self._pre_buffer: deque = deque(maxlen=self.VAD_PRE_BUFFER_SIZE)
        self._vad_voice_streak: int = 0    # Consecutive voiced chunks count
        self._vad_silence_streak: int = 0  # Consecutive silent chunks count

    async def start(self) -> None:
        """Fetch bot config, mark session as running."""
        from utils.fastapi_utils import fetch_bot_config_from_db
        self._bot_config = await fetch_bot_config_from_db(self.bot_id, self.user_identifier)
        self._speaker = self._bot_config.get("voice_speaker", "")
        self._running = True
        await self._emit_status("ready")

    async def stop(self) -> None:
        """Gracefully stop the session: cancel all tasks and close the ASR link."""
        self._running = False
        if self._vad_finish_task and not self._vad_finish_task.done():
            self._vad_finish_task.cancel()
        if self._silence_timer_task and not self._silence_timer_task.done():
            self._silence_timer_task.cancel()
        if self._agent_task and not self._agent_task.done():
            from utils.cancel_manager import trigger_cancel
            trigger_cancel(self.session_id)
            self._agent_task.cancel()
        if self._asr_receive_task and not self._asr_receive_task.done():
            self._asr_receive_task.cancel()
        if self._asr_client:
            # Best-effort finish; the connection may already be gone.
            try:
                await self._asr_client.send_finish()
            except Exception:
                pass
            await self._asr_client.close()

    # VAD configuration
    VAD_SILENCE_DURATION = 3.0   # Seconds of silence before sending finish
    VAD_PRE_BUFFER_SIZE = 5      # Number of audio chunks to buffer before VAD triggers
    VAD_SOURCE_RATE = 24000      # Input audio sample rate
    VAD_TARGET_RATE = 16000      # webrtcvad supported sample rate
    VAD_FRAME_DURATION_MS = 30   # Frame duration for webrtcvad (10, 20, or 30 ms)
    VAD_SPEECH_CHUNKS = 3        # Consecutive voiced chunks required to start speech
    VAD_SILENCE_CHUNKS = 5       # Consecutive silent chunks required to confirm silence

    _audio_chunk_count = 0

    @staticmethod
    def _resample_24k_to_16k(pcm_data: bytes) -> bytes:
        """Downsample 16-bit PCM from 24kHz to 16kHz (ratio 3:2).

        Takes every 2 out of 3 samples (simple decimation, no anti-alias filter).
        """
        n_samples = len(pcm_data) // 2
        if n_samples == 0:
            return b''
        samples = struct.unpack(f'<{n_samples}h', pcm_data[:n_samples * 2])
        # Pick samples at indices 0, 1.5, 3, 4.5, ... -> floor(i * 3/2) for output index i
        out_len = (n_samples * 2) // 3
        resampled = []
        for i in range(out_len):
            src_idx = (i * 3) // 2
            if src_idx < n_samples:
                resampled.append(samples[src_idx])
        return struct.pack(f'<{len(resampled)}h', *resampled)

    @staticmethod
    def _resample_16k_to_24k(pcm_data: bytes) -> bytes:
        """Upsample 16-bit PCM from 16kHz to 24kHz (ratio 2:3).

        For every 2 input samples, produces 3 output samples using linear interpolation.
        """
        n_samples = len(pcm_data) // 2
        if n_samples == 0:
            return b''
        samples = struct.unpack(f'<{n_samples}h', pcm_data[:n_samples * 2])
        out_len = (n_samples * 3) // 2
        resampled = []
        for i in range(out_len):
            src_pos = (i * 2) / 3
            src_idx = int(src_pos)
            frac = src_pos - src_idx
            if src_idx + 1 < n_samples:
                # Linear interpolation between adjacent source samples.
                val = int(samples[src_idx] * (1 - frac) + samples[src_idx + 1] * frac)
            elif src_idx < n_samples:
                val = samples[src_idx]
            else:
                break
            # Clamp to the int16 range before packing.
            resampled.append(max(-32768, min(32767, val)))
        return struct.pack(f'<{len(resampled)}h', *resampled)

    def _resample_input(self, audio_data: bytes) -> bytes:
        """Resample incoming audio to 24kHz if needed."""
        if self._client_sample_rate == 16000:
            return self._resample_16k_to_24k(audio_data)
        return audio_data

    def _resample_output(self, audio_data: bytes) -> bytes:
        """Resample outgoing audio from 24kHz to client sample rate if needed."""
        if self._client_sample_rate == 16000:
            return self._resample_24k_to_16k(audio_data)
        return audio_data

    def _webrtcvad_detect(self, pcm_data: bytes) -> bool:
        """Run webrtcvad on audio data.

        Returns True if voice is detected in any frame.
        """
        resampled = self._resample_24k_to_16k(pcm_data)
        # webrtcvad requires fixed-size frames of 10/20/30 ms at a supported rate.
        frame_size = (self.VAD_TARGET_RATE * self.VAD_FRAME_DURATION_MS // 1000) * 2  # bytes per frame
        if len(resampled) < frame_size:
            return False
        # Check frames; return True if any frame has voice
        voice_frames = 0
        for offset in range(0, len(resampled) - frame_size + 1, frame_size):
            frame = resampled[offset:offset + frame_size]
            try:
                if self._vad.is_speech(frame, self.VAD_TARGET_RATE):
                    voice_frames += 1
            except Exception:
                # Malformed frame — treat as non-speech rather than failing the pipeline.
                pass
        # Consider voice detected if at least one frame has speech
        return voice_frames > 0

    async def handle_audio(self, audio_data: bytes) -> None:
        """Forward user audio to ASR with VAD gating. Lazy-connect on speech start."""
        if not self._running:
            return
        # Resample to 24kHz if client sends lower sample rate
        audio_data = self._resample_input(audio_data)
        self._audio_chunk_count += 1
        has_voice = self._webrtcvad_detect(audio_data)
        now = asyncio.get_running_loop().time()

        # Update consecutive streaks
        if has_voice:
            self._vad_voice_streak += 1
            self._vad_silence_streak = 0
        else:
            self._vad_silence_streak += 1
            self._vad_voice_streak = 0

        if has_voice:
            # Cancel any pending finish
            if self._vad_finish_task and not self._vad_finish_task.done():
                self._vad_finish_task.cancel()
                self._vad_finish_task = None
            if not self._vad_speaking and self._vad_voice_streak >= self.VAD_SPEECH_CHUNKS:
                # Speech just started — connect ASR
                self._vad_speaking = True
                logger.info("[VoiceLite] VAD: speech started (webrtcvad), connecting ASR...")
                try:
                    await self._connect_asr()
                    # Send buffered pre-speech audio
                    for buffered in self._pre_buffer:
                        await self._asr_client.send_audio(buffered)
                    self._pre_buffer.clear()
                except Exception as e:
                    logger.error(f"[VoiceLite] VAD: ASR connect failed: {e}", exc_info=True)
                    self._vad_speaking = False
                    return
            # Send current chunk
            if self._asr_client:
                try:
                    await self._asr_client.send_audio(audio_data)
                except Exception:
                    pass
            self._vad_silence_start = 0
        else:
            if self._vad_speaking:
                # Brief silence while speaking — keep sending for ASR context
                if self._asr_client:
                    try:
                        await self._asr_client.send_audio(audio_data)
                    except Exception:
                        pass
                if self._vad_silence_start == 0:
                    self._vad_silence_start = now
                # Require both consecutive silent chunks AND time threshold
                if (self._vad_silence_streak >= self.VAD_SILENCE_CHUNKS
                        and (now - self._vad_silence_start) >= self.VAD_SILENCE_DURATION):
                    if not self._vad_finish_task or self._vad_finish_task.done():
                        self._vad_finish_task = asyncio.create_task(self._vad_send_finish())
            else:
                # Not speaking — buffer recent audio for pre-speech context
                # (deque maxlen discards the oldest chunk automatically).
                self._pre_buffer.append(audio_data)

    async def _vad_send_finish(self) -> None:
        """Send finish signal to ASR after silence detected."""
        logger.info("[VoiceLite] VAD: silence detected, sending finish to ASR")
        self._vad_speaking = False
        self._vad_silence_start = 0
        self._vad_voice_streak = 0
        self._vad_silence_streak = 0
        if self._asr_client:
            try:
                await self._asr_client.send_finish()
            except Exception as e:
                logger.warning(f"[VoiceLite] VAD: send_finish failed: {e}")

    async def handle_text(self, text: str) -> None:
        """Handle direct text input - bypass ASR and go straight to agent."""
        if not self._running:
            return
        await self._interrupt_current()
        self._agent_task = asyncio.create_task(self._process_utterance(text))

    async def _connect_asr(self) -> None:
        """Create and connect a new ASR client, start receive loop."""
        # Tear down any previous session first.
        if self._asr_client:
            try:
                await self._asr_client.close()
            except Exception:
                pass
        if self._asr_receive_task and not self._asr_receive_task.done():
            self._asr_receive_task.cancel()
        self._asr_client = StreamingASRClient(uid=self.user_identifier or "voice_lite")
        await self._asr_client.connect()
        logger.info("[VoiceLite] ASR client connected")
        # Start receive loop for this ASR session
        self._asr_receive_task = asyncio.create_task(self._asr_receive_loop())

    async def _asr_receive_loop(self) -> None:
        """Receive ASR results from the current ASR session."""
        try:
            _listening_emitted = False
            async for text, is_final in self._asr_client.receive_results():
                if not self._running:
                    return
                if not _listening_emitted:
                    await self._emit_status("listening")
                    _listening_emitted = True
                now = asyncio.get_running_loop().time()
                self._last_asr_time = now
                self._current_asr_text = text
                # Track text changes for stability detection
                if text != self._last_changed_text:
                    self._last_changed_text = text
                    self._last_text_change_time = now
                # Reset stability timer on every result
                self._reset_silence_timer()
        except asyncio.CancelledError:
            pass
        except Exception as e:
            if self._running:
                logger.warning(f"[VoiceLite] ASR session ended: {e}")
        finally:
            logger.info("[VoiceLite] ASR session done")
            # Clean up ASR client after session ends
            if self._asr_client:
                try:
                    await self._asr_client.close()
                except Exception:
                    pass
                self._asr_client = None

    def _reset_silence_timer(self) -> None:
        """Reset the silence timeout timer."""
        if self._silence_timer_task and not self._silence_timer_task.done():
            self._silence_timer_task.cancel()
        self._silence_timer_task = asyncio.create_task(self._silence_timeout())

    async def _silence_timeout(self) -> None:
        """Wait for silence timeout, then check if text has been stable."""
        try:
            await asyncio.sleep(VOICE_LITE_SILENCE_TIMEOUT)
            if not self._running:
                return
            # Check if text has been stable (unchanged) for the timeout period
            now = asyncio.get_running_loop().time()
            if (self._current_asr_text
                    and (now - self._last_text_change_time) >= VOICE_LITE_SILENCE_TIMEOUT):
                logger.info(f"[VoiceLite] Text stable for {VOICE_LITE_SILENCE_TIMEOUT}s, processing: '{self._current_asr_text}'")
                await self._on_utterance_complete(self._current_asr_text)
        except asyncio.CancelledError:
            pass

    async def _on_utterance_complete(self, text: str) -> None:
        """Called when a complete utterance is detected."""
        if not text.strip():
            return
        async with self._utterance_lock:
            # Cancel silence timer
            if self._silence_timer_task and not self._silence_timer_task.done():
                self._silence_timer_task.cancel()
            # Interrupt any in-progress agent+TTS
            await self._interrupt_current()
            # Send final ASR text to frontend
            if self._on_asr_text:
                await self._on_asr_text(text)
            self._current_asr_text = ""
            self._last_changed_text = ""
            self._agent_task = asyncio.create_task(self._process_utterance(text))

    async def _interrupt_current(self) -> None:
        """Cancel current agent+TTS task if running."""
        if self._agent_task and not self._agent_task.done():
            logger.info("[VoiceLite] Interrupting previous agent task")
            from utils.cancel_manager import trigger_cancel
            trigger_cancel(self.session_id)
            self._agent_task.cancel()
            # Swallow both the expected CancelledError and any error the task
            # raised on its way down — interruption is best-effort.
            try:
                await self._agent_task
            except (asyncio.CancelledError, Exception):
                pass
            self._agent_task = None

    async def _speak_sentences(self, tts_client: StreamingTTSClient, sentences, speaking: bool) -> bool:
        """Clean and synthesize a batch of sentences.

        Emits the "speaking" status exactly once (on the first non-empty
        sentence when *speaking* is still False). Returns the updated
        *speaking* flag so the caller can thread it through multiple batches.
        """
        for sentence in sentences:
            sentence = clean_markdown(sentence)
            if sentence:
                if not speaking:
                    await self._emit_status("speaking")
                    speaking = True
                await self._send_tts(tts_client, sentence)
        return speaking

    async def _process_utterance(self, text: str) -> None:
        """Process a complete utterance: agent -> TTS pipeline."""
        try:
            logger.info(f"[VoiceLite] Processing utterance: '{text}'")
            await self._emit_status("thinking")
            accumulated_text = []
            tag_filter = StreamTagFilter()
            splitter = TTSSentenceSplitter()
            tts_client = StreamingTTSClient(speaker=self._speaker)
            speaking = False
            async for chunk in stream_v3_agent(
                user_text=text,
                bot_id=self.bot_id,
                bot_config=self._bot_config,
                session_id=self.session_id,
                user_identifier=self.user_identifier,
            ):
                accumulated_text.append(chunk)
                if self._on_agent_stream:
                    await self._on_agent_stream(chunk)
                passthrough = tag_filter.feed(chunk)
                if not passthrough:
                    # Nothing to pass through; if the answer section just ended,
                    # drain whatever the splitter is still holding.
                    if tag_filter.answer_ended:
                        speaking = await self._speak_sentences(tts_client, splitter.flush(), speaking)
                    continue
                # Feed raw passthrough to splitter (preserve newlines for splitting),
                # apply clean_markdown on output sentences
                speaking = await self._speak_sentences(tts_client, splitter.feed(passthrough), speaking)
            # Handle remaining text
            speaking = await self._speak_sentences(tts_client, splitter.flush(), speaking)
            # Log full agent result (not sent to frontend, already streamed)
            full_result = "".join(accumulated_text)
            logger.info(f"[VoiceLite] Agent done ({len(full_result)} chars)")
            # Notify frontend that agent text stream is complete
            if self._on_agent_result:
                await self._on_agent_result(full_result)
            await self._emit_status("idle")
        except asyncio.CancelledError:
            logger.info("[VoiceLite] Agent task cancelled (user interrupted)")
            raise
        except Exception as e:
            logger.error(f"[VoiceLite] Error processing utterance: {e}", exc_info=True)
            await self._emit_error(f"Processing failed: {str(e)}")

    async def _send_tts(self, tts_client: StreamingTTSClient, sentence: str) -> None:
        """Synthesize a sentence and emit audio chunks."""
        logger.info(f"[VoiceLite] TTS sentence: '{sentence[:80]}'")
        if self._client_sample_rate != 24000:
            # Client needs non-24kHz: use raw int16 pipeline to allow resampling
            async for audio_chunk in tts_client.synthesize_raw(sentence):
                if self._on_audio:
                    await self._on_audio(self._resample_output(audio_chunk))
        else:
            async for audio_chunk in tts_client.synthesize(sentence):
                if self._on_audio:
                    await self._on_audio(audio_chunk)

    async def _emit_status(self, status: str) -> None:
        if self._on_status:
            await self._on_status(status)

    async def _emit_error(self, message: str) -> None:
        if self._on_error:
            await self._on_error(message)