449 lines
18 KiB
Python
449 lines
18 KiB
Python
import asyncio
|
|
import logging
|
|
import struct
|
|
import uuid
|
|
from typing import Optional, Callable, Awaitable
|
|
|
|
import webrtcvad
|
|
|
|
from services.streaming_asr_client import StreamingASRClient
|
|
from services.streaming_tts_client import StreamingTTSClient
|
|
from services.voice_utils import (
|
|
StreamTagFilter,
|
|
clean_markdown,
|
|
stream_v3_agent,
|
|
TTSSentenceSplitter,
|
|
)
|
|
from utils.settings import VOICE_LITE_SILENCE_TIMEOUT
|
|
|
|
# Shared application logger; every "[VoiceLite]" message in this module goes through it.
logger: logging.Logger = logging.getLogger("app")
|
|
|
|
|
|
class VoiceLiteSession:
    """Voice Lite session: ASR -> Agent -> TTS pipeline (cheaper alternative).

    Incoming PCM audio is gated by webrtcvad. A streaming ASR session is opened
    when speech starts, and a finish signal is sent after sustained silence.
    ASR text that stays unchanged for ``VOICE_LITE_SILENCE_TIMEOUT`` seconds is
    treated as a complete utterance and sent to the agent. The agent's streamed
    answer is split into sentences, cleaned of markdown and synthesized via TTS.
    All ``on_*`` callbacks are optional async callables.
    """

    # VAD configuration
    VAD_SILENCE_DURATION = 1.5  # Seconds of silence before sending finish
    VAD_PRE_BUFFER_SIZE = 5  # Number of audio chunks to buffer before VAD triggers
    VAD_SOURCE_RATE = 24000  # Input audio sample rate
    VAD_TARGET_RATE = 16000  # webrtcvad supported sample rate
    VAD_FRAME_DURATION_MS = 30  # Frame duration for webrtcvad (10, 20, or 30 ms)
    VAD_SPEECH_CHUNKS = 3  # Consecutive voiced chunks required to start speech
    VAD_SILENCE_CHUNKS = 5  # Consecutive silent chunks required to confirm silence

    def __init__(
        self,
        bot_id: str,
        session_id: Optional[str] = None,
        user_identifier: Optional[str] = None,
        on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None,
        on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None,
        on_agent_stream: Optional[Callable[[str], Awaitable[None]]] = None,
        on_llm_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_status: Optional[Callable[[str], Awaitable[None]]] = None,
        on_error: Optional[Callable[[str], Awaitable[None]]] = None,
    ):
        """Create an idle session; call ``start()`` before feeding audio or text.

        Args:
            bot_id: Bot whose config is loaded in ``start()``.
            session_id: Conversation id; a random UUID4 string when omitted.
            user_identifier: Passed to the config lookup, the agent and the ASR uid.
            on_audio: Receives synthesized TTS audio chunks.
            on_asr_text: Receives the final text of each detected utterance.
            on_agent_result: Receives the full agent answer once streaming ends.
            on_agent_stream: Receives each raw agent output chunk.
            on_llm_text: Stored but not invoked anywhere in this class.
            on_status: Receives status strings ("ready", "listening",
                "thinking", "speaking", "idle").
            on_error: Receives error messages from utterance processing.
        """
        self.bot_id = bot_id
        self.session_id = session_id or str(uuid.uuid4())
        self.user_identifier = user_identifier or ""

        self._bot_config: dict = {}
        self._speaker: str = ""

        # Callbacks
        self._on_audio = on_audio
        self._on_asr_text = on_asr_text
        self._on_agent_result = on_agent_result
        self._on_agent_stream = on_agent_stream
        self._on_llm_text = on_llm_text
        self._on_status = on_status
        self._on_error = on_error

        self._running = False
        self._asr_client: Optional[StreamingASRClient] = None
        self._asr_receive_task: Optional[asyncio.Task] = None
        self._agent_task: Optional[asyncio.Task] = None

        # Silence timeout tracking
        self._last_asr_time: float = 0
        self._silence_timer_task: Optional[asyncio.Task] = None
        self._current_asr_text: str = ""
        self._last_text_change_time: float = 0
        self._last_changed_text: str = ""
        self._last_asr_emit_time: float = 0
        # Serializes "interrupt current agent task -> start a new one" so two
        # utterances can never leave two agent tasks running.
        self._utterance_lock = asyncio.Lock()

        # VAD (Voice Activity Detection) via webrtcvad
        self._vad = webrtcvad.Vad(2)  # aggressiveness 0-3 (2 = balanced)
        self._vad_speaking = False  # Whether user is currently speaking
        self._vad_silence_start: float = 0  # When silence started
        self._vad_finish_task: Optional[asyncio.Task] = None
        self._pre_buffer: list = []  # Buffer audio before VAD triggers
        self._vad_voice_streak: int = 0  # Consecutive voiced chunks count
        self._vad_silence_streak: int = 0  # Consecutive silent chunks count
        self._audio_chunk_count: int = 0  # Audio chunks received (per session)

    async def start(self) -> None:
        """Fetch bot config, mark session as running and emit "ready"."""
        from utils.fastapi_utils import fetch_bot_config_from_db

        # Guard against a missing config so the .get() below cannot fail on None.
        self._bot_config = await fetch_bot_config_from_db(self.bot_id, self.user_identifier) or {}
        self._speaker = self._bot_config.get("voice_speaker", "")

        self._running = True
        await self._emit_status("ready")

    async def stop(self) -> None:
        """Gracefully stop the session.

        Cancels the VAD-finish, silence-timer, agent and ASR-receive tasks, then
        sends a best-effort finish to the ASR client and closes it. ASR errors
        during shutdown are logged, not raised.
        """
        self._running = False

        for task in (self._vad_finish_task, self._silence_timer_task):
            if task and not task.done():
                task.cancel()

        if self._agent_task and not self._agent_task.done():
            from utils.cancel_manager import trigger_cancel

            trigger_cancel(self.session_id)
            self._agent_task.cancel()

        if self._asr_receive_task and not self._asr_receive_task.done():
            self._asr_receive_task.cancel()

        # Detach the client before awaiting: the cancelled receive loop also
        # cleans up in its ``finally``, and that may run during the awaits below.
        client, self._asr_client = self._asr_client, None
        if client is None:
            return
        try:
            await client.send_finish()
        except Exception as e:
            logger.debug(f"[VoiceLite] send_finish during stop failed: {e}")
        try:
            await client.close()
        except Exception as e:
            logger.warning(f"[VoiceLite] ASR close during stop failed: {e}")

    @staticmethod
    def _resample_24k_to_16k(pcm_data: bytes) -> bytes:
        """Downsample 16-bit little-endian PCM from 24kHz to 16kHz (ratio 3:2).

        Simple decimation: output sample ``i`` is input sample ``floor(i * 3 / 2)``
        (no anti-alias filter; adequate for VAD). A trailing odd byte is ignored.
        """
        n_samples = len(pcm_data) // 2
        if n_samples == 0:
            return b''
        samples = struct.unpack(f'<{n_samples}h', pcm_data[:n_samples * 2])
        out_len = (n_samples * 2) // 3
        # For every i < out_len, (i * 3) // 2 < n_samples, so indexing is in bounds.
        picked = [samples[(i * 3) // 2] for i in range(out_len)]
        return struct.pack(f'<{out_len}h', *picked)

    def _webrtcvad_detect(self, pcm_data: bytes) -> bool:
        """Return True if webrtcvad flags speech in any full frame of ``pcm_data``.

        The chunk is downsampled to ``VAD_TARGET_RATE`` and cut into
        ``VAD_FRAME_DURATION_MS`` frames. A trailing partial frame is ignored, and
        a chunk shorter than one frame is reported as silence.
        """
        resampled = self._resample_24k_to_16k(pcm_data)
        frame_bytes = (self.VAD_TARGET_RATE * self.VAD_FRAME_DURATION_MS // 1000) * 2
        for offset in range(0, len(resampled) - frame_bytes + 1, frame_bytes):
            frame = resampled[offset:offset + frame_bytes]
            try:
                if self._vad.is_speech(frame, self.VAD_TARGET_RATE):
                    return True
            except Exception as e:
                # A rejected frame counts as silence; keep checking the rest.
                logger.debug(f"[VoiceLite] VAD frame rejected: {e}")
        return False

    async def handle_audio(self, audio_data: bytes) -> None:
        """Forward user audio to ASR with VAD gating. Lazy-connect on speech start.

        While idle, chunks go into a small pre-roll buffer. Once
        ``VAD_SPEECH_CHUNKS`` consecutive voiced chunks are seen, a fresh ASR
        session is opened and the buffer is sent first, including the voiced
        onset chunks. While speaking, every chunk is forwarded. A finish is
        scheduled after ``VAD_SILENCE_CHUNKS`` consecutive silent chunks that
        also span at least ``VAD_SILENCE_DURATION`` seconds.

        Args:
            audio_data: 16-bit little-endian PCM. The VAD path treats it as a
                single 24 kHz channel (see ``_resample_24k_to_16k``).
        """
        if not self._running:
            return

        self._audio_chunk_count += 1
        has_voice = self._webrtcvad_detect(audio_data)
        now = asyncio.get_running_loop().time()

        # Update consecutive streaks
        if has_voice:
            self._vad_voice_streak += 1
            self._vad_silence_streak = 0
        else:
            self._vad_silence_streak += 1
            self._vad_voice_streak = 0

        if not self._vad_speaking:
            await self._handle_idle_audio(audio_data, has_voice)
            return

        if has_voice:
            # Speech continues: cancel a pending finish *before* the send below
            # yields, so it cannot start in the meantime.
            self._cancel_pending_finish()
            self._vad_silence_start = 0
            await self._send_audio_to_asr(audio_data)
            return

        # Brief silence while speaking — keep sending for ASR context
        await self._send_audio_to_asr(audio_data)
        if not self._vad_speaking:
            # A previously scheduled finish ran while we were sending.
            return

        if self._vad_silence_start == 0:
            self._vad_silence_start = now

        # Require both consecutive silent chunks AND time threshold
        if (self._vad_silence_streak >= self.VAD_SILENCE_CHUNKS
                and (now - self._vad_silence_start) >= self.VAD_SILENCE_DURATION):
            if not self._vad_finish_task or self._vad_finish_task.done():
                self._vad_finish_task = asyncio.create_task(self._vad_send_finish())

    async def _handle_idle_audio(self, audio_data: bytes, has_voice: bool) -> None:
        """Buffer a chunk while not speaking; open ASR once speech onset is confirmed."""
        # Buffer every idle chunk, voiced ones included. Otherwise the first
        # voiced chunks before the streak threshold would be lost (clipped onset).
        self._pre_buffer.append(audio_data)
        limit = max(self.VAD_PRE_BUFFER_SIZE, self.VAD_SPEECH_CHUNKS)
        excess = len(self._pre_buffer) - limit
        if excess > 0:
            del self._pre_buffer[:excess]

        if not has_voice or self._vad_voice_streak < self.VAD_SPEECH_CHUNKS:
            return

        # Speech just started — connect ASR
        self._vad_speaking = True
        self._vad_silence_start = 0
        logger.info("[VoiceLite] VAD: speech started (webrtcvad), connecting ASR...")
        try:
            await self._connect_asr()
            client = self._asr_client
            # Send buffered pre-speech audio (ends with the current chunk)
            for buffered in list(self._pre_buffer):
                await client.send_audio(buffered)
            self._pre_buffer.clear()
        except Exception as e:
            # Leave the buffer intact; the next voiced chunk retries the connect.
            logger.error(f"[VoiceLite] VAD: ASR connect failed: {e}", exc_info=True)
            self._vad_speaking = False

    async def _send_audio_to_asr(self, chunk: bytes) -> None:
        """Forward one chunk to the current ASR client, if any (best effort)."""
        client = self._asr_client
        if client is None:
            return
        try:
            await client.send_audio(chunk)
        except Exception as e:
            # The ASR stream may already be closed; a new session is opened on
            # the next detected speech onset.
            logger.debug(f"[VoiceLite] ASR send_audio failed: {e}")

    def _cancel_pending_finish(self) -> None:
        """Cancel the scheduled VAD finish task, if it is still running or pending."""
        task = self._vad_finish_task
        if task is not None and not task.done():
            task.cancel()
        self._vad_finish_task = None

    async def _vad_send_finish(self) -> None:
        """Reset VAD state and send the finish signal to ASR after silence."""
        logger.info("[VoiceLite] VAD: silence detected, sending finish to ASR")
        self._vad_speaking = False
        self._vad_silence_start = 0
        self._vad_voice_streak = 0
        self._vad_silence_streak = 0
        client = self._asr_client
        if client is None:
            return
        try:
            await client.send_finish()
        except Exception as e:
            logger.warning(f"[VoiceLite] VAD: send_finish failed: {e}")

    async def handle_text(self, text: str) -> None:
        """Handle direct text input - bypass ASR and go straight to agent.

        Takes the same lock as ``_on_utterance_complete`` so a typed message and
        a spoken utterance cannot both start agent tasks concurrently.
        """
        if not self._running:
            return
        async with self._utterance_lock:
            await self._interrupt_current()
            self._agent_task = asyncio.create_task(self._process_utterance(text))

    async def _connect_asr(self) -> None:
        """Replace any previous ASR client with a newly connected one and start its receive loop.

        The new client is published in ``self._asr_client`` only after
        ``connect()`` succeeds. On failure the exception propagates and
        ``self._asr_client`` is left as None.
        """
        old_client = self._asr_client
        old_task = self._asr_receive_task
        # Detach first so the old receive loop's cleanup cannot touch the new client.
        self._asr_client = None
        if old_client is not None:
            try:
                await old_client.close()
            except Exception as e:
                logger.debug(f"[VoiceLite] Closing previous ASR client failed: {e}")
        if old_task is not None and not old_task.done():
            old_task.cancel()

        client = StreamingASRClient(uid=self.user_identifier or "voice_lite")
        try:
            await client.connect()
        except Exception:
            try:
                await client.close()
            except Exception:
                pass  # best-effort release; the connect error is what matters
            raise
        self._asr_client = client
        logger.info("[VoiceLite] ASR client connected")

        # Start receive loop bound to *this* client
        self._asr_receive_task = asyncio.create_task(self._asr_receive_loop(client))

    async def _asr_receive_loop(self, client: Optional["StreamingASRClient"] = None) -> None:
        """Consume results from one ASR session and feed the stability timer.

        Args:
            client: The ASR client this loop owns. It defaults to the current
                ``self._asr_client``. On exit the loop closes this client, and it
                clears ``self._asr_client`` only if that still points at it, so a
                newer session is never torn down by an older loop.
        """
        if client is None:
            client = self._asr_client
        if client is None:
            return
        listening_emitted = False
        try:
            async for text, _is_final in client.receive_results():
                if not self._running:
                    return

                if not listening_emitted:
                    await self._emit_status("listening")
                    listening_emitted = True

                now = asyncio.get_running_loop().time()
                self._last_asr_time = now
                self._current_asr_text = text

                # Track text changes for stability detection
                if text != self._last_changed_text:
                    self._last_changed_text = text
                    self._last_text_change_time = now

                # Reset stability timer on every result
                self._reset_silence_timer()

        except asyncio.CancelledError:
            pass
        except Exception as e:
            if self._running:
                logger.warning(f"[VoiceLite] ASR session ended: {e}")
        finally:
            logger.info("[VoiceLite] ASR session done")
            try:
                await client.close()
            except Exception:
                pass  # already closed or broken; nothing left to release
            if self._asr_client is client:
                self._asr_client = None

    def _reset_silence_timer(self) -> None:
        """Restart the text-stability timer (cancelling any previous one)."""
        if self._silence_timer_task and not self._silence_timer_task.done():
            self._silence_timer_task.cancel()
        self._silence_timer_task = asyncio.create_task(self._silence_timeout())

    async def _silence_timeout(self) -> None:
        """Wait for silence timeout, then submit the ASR text if it has been stable."""
        try:
            await asyncio.sleep(VOICE_LITE_SILENCE_TIMEOUT)
            if not self._running:
                return
            # Check if text has been stable (unchanged) for the timeout period
            now = asyncio.get_running_loop().time()
            text = self._current_asr_text
            if text and (now - self._last_text_change_time) >= VOICE_LITE_SILENCE_TIMEOUT:
                logger.info(f"[VoiceLite] Text stable for {VOICE_LITE_SILENCE_TIMEOUT}s, processing: '{text}'")
                await self._on_utterance_complete(text)
        except asyncio.CancelledError:
            pass

    async def _on_utterance_complete(self, text: str) -> None:
        """Called when a complete utterance is detected; starts a new agent task.

        Blank text is ignored. Any in-progress agent task is interrupted first.
        """
        if not text.strip():
            return

        async with self._utterance_lock:
            # Cancel silence timer — but never the task we are running in.
            # ``_silence_timeout`` calls this method, and cancelling the current
            # task would abort the utterance at the next await below.
            timer = self._silence_timer_task
            if timer is not None and timer is not asyncio.current_task() and not timer.done():
                timer.cancel()

            # Interrupt any in-progress agent+TTS
            await self._interrupt_current()

            # Send final ASR text to frontend
            if self._on_asr_text:
                await self._on_asr_text(text)

            self._current_asr_text = ""
            self._last_changed_text = ""
            self._agent_task = asyncio.create_task(self._process_utterance(text))

    async def _interrupt_current(self) -> None:
        """Cancel current agent+TTS task if running, and wait for it to finish.

        A cancellation aimed at the *caller* still propagates. ``asyncio.wait``
        never raises the awaited task's own outcome.
        """
        task = self._agent_task
        if task is None or task.done():
            return
        logger.info("[VoiceLite] Interrupting previous agent task")
        from utils.cancel_manager import trigger_cancel

        trigger_cancel(self.session_id)
        task.cancel()
        await asyncio.wait({task})
        if not task.cancelled() and task.exception() is not None:
            # Retrieve the exception so asyncio does not report it as unhandled.
            logger.warning(f"[VoiceLite] Interrupted agent task failed: {task.exception()}")
        self._agent_task = None

    async def _process_utterance(self, text: str) -> None:
        """Process a complete utterance: agent -> TTS pipeline.

        Streams the agent answer and forwards raw chunks to ``on_agent_stream``.
        Text that passes ``StreamTagFilter`` is split into sentences, cleaned
        with ``clean_markdown`` and synthesized. The full answer goes to
        ``on_agent_result``. Cancellation is re-raised; other errors are
        reported through ``on_error``.
        """
        try:
            logger.info(f"[VoiceLite] Processing utterance: '{text}'")
            await self._emit_status("thinking")

            accumulated_text = []
            tag_filter = StreamTagFilter()
            splitter = TTSSentenceSplitter()
            tts_client = StreamingTTSClient(speaker=self._speaker)
            speaking = False

            async def speak(sentences) -> None:
                """Clean and synthesize each non-empty sentence; emit "speaking" once."""
                nonlocal speaking
                for raw_sentence in sentences:
                    sentence = clean_markdown(raw_sentence)
                    if not sentence:
                        continue
                    if not speaking:
                        await self._emit_status("speaking")
                        speaking = True
                    await self._send_tts(tts_client, sentence)

            async for chunk in stream_v3_agent(
                user_text=text,
                bot_id=self.bot_id,
                bot_config=self._bot_config,
                session_id=self.session_id,
                user_identifier=self.user_identifier,
            ):
                accumulated_text.append(chunk)

                if self._on_agent_stream:
                    await self._on_agent_stream(chunk)

                passthrough = tag_filter.feed(chunk)
                if not passthrough:
                    if tag_filter.answer_ended:
                        await speak(splitter.flush())
                    continue

                # Feed raw passthrough to splitter (preserve newlines for splitting);
                # clean_markdown is applied per output sentence inside speak().
                await speak(splitter.feed(passthrough))

            # Handle remaining text
            await speak(splitter.flush())

            # Log full agent result (already streamed to the frontend chunk by chunk)
            full_result = "".join(accumulated_text)
            logger.info(f"[VoiceLite] Agent done ({len(full_result)} chars)")

            # Notify frontend that agent text stream is complete
            if self._on_agent_result:
                await self._on_agent_result(full_result)

            await self._emit_status("idle")

        except asyncio.CancelledError:
            logger.info("[VoiceLite] Agent task cancelled (user interrupted)")
            raise
        except Exception as e:
            logger.error(f"[VoiceLite] Error processing utterance: {e}", exc_info=True)
            await self._emit_error(f"Processing failed: {str(e)}")

    async def _send_tts(self, tts_client: "StreamingTTSClient", sentence: str) -> None:
        """Synthesize a sentence and emit each audio chunk via ``on_audio``."""
        logger.info(f"[VoiceLite] TTS sentence: '{sentence[:80]}'")
        async for audio_chunk in tts_client.synthesize(sentence):
            if self._on_audio:
                await self._on_audio(audio_chunk)

    async def _emit_status(self, status: str) -> None:
        """Send a status string to ``on_status`` if a callback is registered."""
        if self._on_status:
            await self._on_status(status)

    async def _emit_error(self, message: str) -> None:
        """Send an error message to ``on_error`` if a callback is registered."""
        if self._on_error:
            await self._on_error(message)
|