import asyncio
import logging
import struct
import uuid
from typing import Optional, Callable, Awaitable

import webrtcvad

from services.streaming_asr_client import StreamingASRClient
from services.streaming_tts_client import StreamingTTSClient
from services.voice_utils import (
    StreamTagFilter,
    clean_markdown,
    stream_v3_agent,
    TTSSentenceSplitter,
)
from utils.settings import VOICE_LITE_SILENCE_TIMEOUT

logger = logging.getLogger('app')


class VoiceLiteSession:
"""Voice Lite session: ASR -> Agent -> TTS pipeline (cheaper alternative)."""
|
||
|
||
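
    # Pipeline overview for one user turn (all stages are in this module):
    #   mic audio -> webrtcvad gating (handle_audio) -> StreamingASRClient
    #   -> text stability check (_silence_timeout) -> _process_utterance
    #   -> stream_v3_agent -> StreamTagFilter -> TTSSentenceSplitter
    #   -> clean_markdown -> StreamingTTSClient -> on_audio callback
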
    def __init__(
        self,
        bot_id: str,
        session_id: Optional[str] = None,
        user_identifier: Optional[str] = None,
        sample_rate: int = 24000,
        on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None,
        on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None,
        on_agent_stream: Optional[Callable[[str], Awaitable[None]]] = None,
        on_llm_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_status: Optional[Callable[[str], Awaitable[None]]] = None,
        on_error: Optional[Callable[[str], Awaitable[None]]] = None,
        get_pending_message: Optional[Callable[[], Awaitable[Optional[str]]]] = None,
    ):
        self.bot_id = bot_id
        self.session_id = session_id or str(uuid.uuid4())
        self.user_identifier = user_identifier or ""
        self._client_sample_rate = sample_rate

        self._bot_config: dict = {}
        self._speaker: str = ""

        # Callbacks
        self._on_audio = on_audio
        self._on_asr_text = on_asr_text
        self._on_agent_result = on_agent_result
        self._on_agent_stream = on_agent_stream
        self._on_llm_text = on_llm_text
        self._on_status = on_status
        self._on_error = on_error
        self._get_pending_message = get_pending_message

        self._running = False
        self._status: str = "ready"  # Current session status
        self._idle_check_task: Optional[asyncio.Task] = None
        self._asr_client: Optional[StreamingASRClient] = None
        self._asr_receive_task: Optional[asyncio.Task] = None
        self._agent_task: Optional[asyncio.Task] = None

        # Silence timeout tracking
        self._last_asr_time: float = 0
        self._silence_timer_task: Optional[asyncio.Task] = None
        self._current_asr_text: str = ""
        self._last_text_change_time: float = 0
        self._last_changed_text: str = ""
        self._last_asr_emit_time: float = 0
        self._utterance_lock = asyncio.Lock()

        # VAD (Voice Activity Detection) via webrtcvad
        self._vad = webrtcvad.Vad(2)  # aggressiveness 0-3 (2 = balanced)
        self._vad_speaking = False  # Whether user is currently speaking
        self._vad_silence_start: float = 0  # When silence started
        self._vad_finish_task: Optional[asyncio.Task] = None
        self._pre_buffer: list = []  # Buffer audio before VAD triggers
        self._vad_voice_streak: int = 0  # Consecutive voiced chunks count
        self._vad_silence_streak: int = 0  # Consecutive silent chunks count

    async def start(self) -> None:
        """Fetch bot config, mark session as running."""
        from utils.fastapi_utils import fetch_bot_config_from_db

        self._bot_config = await fetch_bot_config_from_db(self.bot_id, self.user_identifier)
        self._speaker = self._bot_config.get("voice_speaker", "")

        self._running = True
        await self._emit_status("ready")

        # Start idle check task for broadcast messages
        if self._get_pending_message:
            self._idle_check_task = asyncio.create_task(self._idle_check_loop())

    async def stop(self) -> None:
        """Gracefully stop the session."""
        self._running = False

        if self._idle_check_task and not self._idle_check_task.done():
            self._idle_check_task.cancel()

        if self._vad_finish_task and not self._vad_finish_task.done():
            self._vad_finish_task.cancel()

        if self._silence_timer_task and not self._silence_timer_task.done():
            self._silence_timer_task.cancel()

        if self._agent_task and not self._agent_task.done():
            from utils.cancel_manager import trigger_cancel
            trigger_cancel(self.session_id)
            self._agent_task.cancel()

        if self._asr_receive_task and not self._asr_receive_task.done():
            self._asr_receive_task.cancel()

        if self._asr_client:
            try:
                await self._asr_client.send_finish()
            except Exception:
                pass
            # Guard close() as well so a transport error cannot escape stop()
            try:
                await self._asr_client.close()
            except Exception:
                pass

    # VAD configuration
    VAD_SILENCE_DURATION = 3.0  # Seconds of silence before sending finish
    VAD_PRE_BUFFER_SIZE = 5  # Number of audio chunks to buffer before VAD triggers
    VAD_SOURCE_RATE = 24000  # Input audio sample rate
    VAD_TARGET_RATE = 16000  # webrtcvad supported sample rate
    VAD_FRAME_DURATION_MS = 30  # Frame duration for webrtcvad (10, 20, or 30 ms)
    VAD_SPEECH_CHUNKS = 3  # Consecutive voiced chunks required to start speech
    VAD_SILENCE_CHUNKS = 5  # Consecutive silent chunks required to confirm silence

    _audio_chunk_count = 0
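
    # Worked numbers for the configuration above: at VAD_TARGET_RATE = 16000 Hz
    # and VAD_FRAME_DURATION_MS = 30, one webrtcvad frame is
    # 16000 * 30 / 1000 = 480 samples, i.e. 960 bytes of 16-bit PCM. This is
    # the frame_size computed in _webrtcvad_check(), so chunks shorter than one
    # full frame are always reported as silence.
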
    @staticmethod
    def _resample_24k_to_16k(pcm_data: bytes) -> bytes:
        """Downsample 16-bit PCM from 24kHz to 16kHz (ratio 3:2).

        Keeps 2 out of every 3 samples (simple decimation).
        """
        n_samples = len(pcm_data) // 2
        if n_samples == 0:
            return b''
        samples = struct.unpack(f'<{n_samples}h', pcm_data[:n_samples * 2])
        # Pick samples at indices 0, 1.5, 3, 4.5, ... -> floor(i * 3/2) for output index i
        out_len = (n_samples * 2) // 3
        resampled = []
        for i in range(out_len):
            src_idx = (i * 3) // 2
            if src_idx < n_samples:
                resampled.append(samples[src_idx])
        return struct.pack(f'<{len(resampled)}h', *resampled)
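
    # Illustrative example (sample values chosen here, not from real traffic):
    # for six input samples (0, 100, 200, 300, 400, 500) at 24kHz, out_len is
    # 6 * 2 // 3 = 4 and the kept source indices are 0, 1, 3, 4, so the 16kHz
    # output is (0, 100, 300, 400).
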
    @staticmethod
    def _resample_16k_to_24k(pcm_data: bytes) -> bytes:
        """Upsample 16-bit PCM from 16kHz to 24kHz (ratio 2:3).

        For every 2 input samples, produces 3 output samples using linear interpolation.
        """
        n_samples = len(pcm_data) // 2
        if n_samples == 0:
            return b''
        samples = struct.unpack(f'<{n_samples}h', pcm_data[:n_samples * 2])
        out_len = (n_samples * 3) // 2
        resampled = []
        for i in range(out_len):
            src_pos = (i * 2) / 3
            src_idx = int(src_pos)
            frac = src_pos - src_idx
            if src_idx + 1 < n_samples:
                val = int(samples[src_idx] * (1 - frac) + samples[src_idx + 1] * frac)
            elif src_idx < n_samples:
                val = samples[src_idx]
            else:
                break
            resampled.append(max(-32768, min(32767, val)))
        return struct.pack(f'<{len(resampled)}h', *resampled)
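
    # Illustrative example (sample values chosen here): for four input samples
    # (0, 300, 600, 900) at 16kHz, out_len is 4 * 3 // 2 = 6; interpolating at
    # source positions 0, 2/3, 4/3, 2, 8/3, 10/3 yields approximately
    # (0, 200, 400, 600, 800, 900) at 24kHz. The last output falls past the
    # final sample pair, so the final sample is held rather than interpolated.
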
    def _resample_input(self, audio_data: bytes) -> bytes:
        """Resample incoming audio to 24kHz if needed."""
        if self._client_sample_rate == 16000:
            return self._resample_16k_to_24k(audio_data)
        return audio_data

    def _resample_output(self, audio_data: bytes) -> bytes:
        """Resample outgoing audio from 24kHz to the client sample rate if needed."""
        if self._client_sample_rate == 16000:
            return self._resample_24k_to_16k(audio_data)
        return audio_data

    def _webrtcvad_detect(self, pcm_data: bytes) -> bool:
        """Run webrtcvad on audio data. Returns True if voice is detected in any frame.

        Expects 24kHz PCM input, resamples to 16kHz for webrtcvad.
        """
        resampled = self._resample_24k_to_16k(pcm_data)
        return self._webrtcvad_check(resampled)

    def _webrtcvad_detect_16k(self, pcm_data: bytes) -> bool:
        """Run webrtcvad directly on 16kHz PCM data (no resampling needed)."""
        return self._webrtcvad_check(pcm_data)

    def _webrtcvad_check(self, pcm_16k: bytes) -> bool:
        """Core webrtcvad check on 16kHz PCM data.

        Returns True if any complete frame is voiced; data shorter than one
        frame is treated as silence.
        """
        frame_size = (self.VAD_TARGET_RATE * self.VAD_FRAME_DURATION_MS // 1000) * 2  # bytes per frame
        if len(pcm_16k) < frame_size:
            return False
        voice_frames = 0
        for offset in range(0, len(pcm_16k) - frame_size + 1, frame_size):
            frame = pcm_16k[offset:offset + frame_size]
            try:
                if self._vad.is_speech(frame, self.VAD_TARGET_RATE):
                    voice_frames += 1
            except Exception:
                pass
        return voice_frames > 0

    async def handle_audio(self, audio_data: bytes) -> None:
        """Forward user audio to ASR with VAD gating. Lazy-connect on speech start."""
        if not self._running:
            return

        # Run VAD at the original sample rate to avoid precision loss from double resampling
        if self._client_sample_rate == 16000:
            has_voice = self._webrtcvad_detect_16k(audio_data)
        else:
            has_voice = self._webrtcvad_detect(audio_data)

        # Upsample for ASR (ASR expects 24kHz input)
        audio_for_asr = self._resample_input(audio_data)

        self._audio_chunk_count += 1
        now = asyncio.get_running_loop().time()

        # Update consecutive streaks
        if has_voice:
            self._vad_voice_streak += 1
            self._vad_silence_streak = 0
        else:
            self._vad_silence_streak += 1
            self._vad_voice_streak = 0

        if has_voice:
            # Cancel any pending finish
            if self._vad_finish_task and not self._vad_finish_task.done():
                self._vad_finish_task.cancel()
                self._vad_finish_task = None

            if not self._vad_speaking and self._vad_voice_streak >= self.VAD_SPEECH_CHUNKS:
                # Speech just started; connect ASR
                self._vad_speaking = True
                logger.info("[VoiceLite] VAD: speech started (webrtcvad), connecting ASR...")
                try:
                    await self._connect_asr()
                    # Send buffered pre-speech audio
                    for buffered in self._pre_buffer:
                        await self._asr_client.send_audio(buffered)
                    self._pre_buffer.clear()
                except Exception as e:
                    logger.error(f"[VoiceLite] VAD: ASR connect failed: {e}", exc_info=True)
                    self._vad_speaking = False
                    return

            # Send current chunk
            if self._asr_client:
                try:
                    await self._asr_client.send_audio(audio_for_asr)
                except Exception:
                    pass

            self._vad_silence_start = 0
        else:
            if self._vad_speaking:
                # Brief silence while speaking; keep sending for ASR context
                if self._asr_client:
                    try:
                        await self._asr_client.send_audio(audio_for_asr)
                    except Exception:
                        pass

                if self._vad_silence_start == 0:
                    self._vad_silence_start = now

                # Require both consecutive silent chunks AND time threshold
                if (self._vad_silence_streak >= self.VAD_SILENCE_CHUNKS
                        and (now - self._vad_silence_start) >= self.VAD_SILENCE_DURATION):
                    if not self._vad_finish_task or self._vad_finish_task.done():
                        self._vad_finish_task = asyncio.create_task(self._vad_send_finish())
            else:
                # Not speaking; buffer recent audio for pre-speech context
                self._pre_buffer.append(audio_for_asr)
                if len(self._pre_buffer) > self.VAD_PRE_BUFFER_SIZE:
                    self._pre_buffer.pop(0)
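
    # Gating summary for handle_audio(): while idle, chunks are pre-buffered;
    # after VAD_SPEECH_CHUNKS consecutive voiced chunks the session connects
    # ASR, flushes the pre-buffer, and streams audio; after VAD_SILENCE_CHUNKS
    # consecutive silent chunks sustained for VAD_SILENCE_DURATION seconds,
    # _vad_send_finish() closes out the utterance.
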
    async def _vad_send_finish(self) -> None:
        """Send finish signal to ASR after silence detected."""
        logger.info("[VoiceLite] VAD: silence detected, sending finish to ASR")
        self._vad_speaking = False
        self._vad_silence_start = 0
        self._vad_voice_streak = 0
        self._vad_silence_streak = 0
        if self._asr_client:
            try:
                await self._asr_client.send_finish()
            except Exception as e:
                logger.warning(f"[VoiceLite] VAD: send_finish failed: {e}")

    async def handle_text(self, text: str) -> None:
        """Handle direct text input: bypass ASR and go straight to the agent."""
        if not self._running:
            return
        await self._interrupt_current()
        self._agent_task = asyncio.create_task(self._process_utterance(text))

    async def _connect_asr(self) -> None:
        """Create and connect a new ASR client, start receive loop."""
        if self._asr_client:
            try:
                await self._asr_client.close()
            except Exception:
                pass
        if self._asr_receive_task and not self._asr_receive_task.done():
            self._asr_receive_task.cancel()

        self._asr_client = StreamingASRClient(uid=self.user_identifier or "voice_lite")
        await self._asr_client.connect()
        logger.info("[VoiceLite] ASR client connected")

        # Start receive loop for this ASR session
        self._asr_receive_task = asyncio.create_task(self._asr_receive_loop())

    async def _asr_receive_loop(self) -> None:
        """Receive ASR results from the current ASR session."""
        try:
            _listening_emitted = False
            async for text, is_final in self._asr_client.receive_results():
                if not self._running:
                    return

                if not _listening_emitted:
                    await self._emit_status("listening")
                    _listening_emitted = True

                now = asyncio.get_running_loop().time()
                self._last_asr_time = now
                self._current_asr_text = text

                # Track text changes for stability detection
                if text != self._last_changed_text:
                    self._last_changed_text = text
                    self._last_text_change_time = now

                # Reset stability timer on every result
                self._reset_silence_timer()

        except asyncio.CancelledError:
            pass
        except Exception as e:
            if self._running:
                logger.warning(f"[VoiceLite] ASR session ended: {e}")
        finally:
            logger.info("[VoiceLite] ASR session done")
            # Clean up ASR client after session ends
            if self._asr_client:
                try:
                    await self._asr_client.close()
                except Exception:
                    pass
                self._asr_client = None

    def _reset_silence_timer(self) -> None:
        """Reset the silence timeout timer."""
        if self._silence_timer_task and not self._silence_timer_task.done():
            self._silence_timer_task.cancel()
        self._silence_timer_task = asyncio.create_task(self._silence_timeout())

    async def _silence_timeout(self) -> None:
        """Wait for the silence timeout, then check whether the text has been stable."""
        try:
            await asyncio.sleep(VOICE_LITE_SILENCE_TIMEOUT)
            if not self._running:
                return
            # Check if text has been stable (unchanged) for the timeout period
            now = asyncio.get_running_loop().time()
            if (self._current_asr_text
                    and (now - self._last_text_change_time) >= VOICE_LITE_SILENCE_TIMEOUT):
                logger.info(f"[VoiceLite] Text stable for {VOICE_LITE_SILENCE_TIMEOUT}s, processing: '{self._current_asr_text}'")
                await self._on_utterance_complete(self._current_asr_text)
        except asyncio.CancelledError:
            pass

    async def _on_utterance_complete(self, text: str) -> None:
        """Called when a complete utterance is detected."""
        if not text.strip():
            return

        async with self._utterance_lock:
            # Cancel silence timer
            if self._silence_timer_task and not self._silence_timer_task.done():
                self._silence_timer_task.cancel()

            # Reset VAD state for next utterance
            self._vad_speaking = False
            self._vad_silence_start = 0
            self._vad_voice_streak = 0
            self._vad_silence_streak = 0
            if self._vad_finish_task and not self._vad_finish_task.done():
                self._vad_finish_task.cancel()
                self._vad_finish_task = None

            # Interrupt any in-progress agent+TTS
            await self._interrupt_current()

            # Send final ASR text to frontend
            if self._on_asr_text:
                await self._on_asr_text(text)

            self._current_asr_text = ""
            self._last_changed_text = ""
            self._agent_task = asyncio.create_task(self._process_utterance(text))

    async def _interrupt_current(self) -> None:
        """Cancel the current agent+TTS task if running."""
        if self._agent_task and not self._agent_task.done():
            logger.info("[VoiceLite] Interrupting previous agent task")
            from utils.cancel_manager import trigger_cancel
            trigger_cancel(self.session_id)
            self._agent_task.cancel()
            try:
                await self._agent_task
            except (asyncio.CancelledError, Exception):
                # CancelledError subclasses BaseException (Python 3.8+),
                # so it must be listed explicitly alongside Exception
                pass
            self._agent_task = None

    async def _process_utterance(self, text: str) -> None:
        """Process a complete utterance: agent -> TTS pipeline."""
        try:
            logger.info(f"[VoiceLite] Processing utterance: '{text}'")
            await self._emit_status("thinking")

            accumulated_text = []
            tag_filter = StreamTagFilter()
            splitter = TTSSentenceSplitter()
            tts_client = StreamingTTSClient(speaker=self._speaker)
            speaking = False

            async def speak_sentence(sentence: str) -> None:
                # Clean the sentence, emit "speaking" once before the first
                # audio, then synthesize
                nonlocal speaking
                sentence = clean_markdown(sentence)
                if sentence:
                    if not speaking:
                        await self._emit_status("speaking")
                        speaking = True
                    await self._send_tts(tts_client, sentence)

            async for chunk in stream_v3_agent(
                user_text=text,
                bot_id=self.bot_id,
                bot_config=self._bot_config,
                session_id=self.session_id,
                user_identifier=self.user_identifier,
            ):
                accumulated_text.append(chunk)

                if self._on_agent_stream:
                    await self._on_agent_stream(chunk)

                passthrough = tag_filter.feed(chunk)

                if not passthrough:
                    if tag_filter.answer_ended:
                        for sentence in splitter.flush():
                            await speak_sentence(sentence)
                    continue

                # Feed raw passthrough to splitter (preserve newlines for splitting),
                # apply clean_markdown on output sentences
                for sentence in splitter.feed(passthrough):
                    await speak_sentence(sentence)

            # Handle remaining text
            for sentence in splitter.flush():
                await speak_sentence(sentence)

            # Log full agent result (not sent to frontend, already streamed)
            full_result = "".join(accumulated_text)
            logger.info(f"[VoiceLite] Agent done ({len(full_result)} chars)")

            # Notify frontend that agent text stream is complete
            if self._on_agent_result:
                await self._on_agent_result(full_result)

            await self._emit_status("idle")

        except asyncio.CancelledError:
            logger.info("[VoiceLite] Agent task cancelled (user interrupted)")
            raise
        except Exception as e:
            logger.error(f"[VoiceLite] Error processing utterance: {e}", exc_info=True)
            await self._emit_error(f"Processing failed: {str(e)}")

    async def _send_tts(self, tts_client: StreamingTTSClient, sentence: str) -> None:
        """Synthesize a sentence and emit audio chunks."""
        logger.info(f"[VoiceLite] TTS sentence: '{sentence[:80]}'")
        if self._client_sample_rate != 24000:
            # Client needs non-24kHz audio: use the raw int16 pipeline so we can resample
            async for audio_chunk in tts_client.synthesize_raw(sentence):
                if self._on_audio:
                    await self._on_audio(self._resample_output(audio_chunk))
        else:
            async for audio_chunk in tts_client.synthesize(sentence):
                if self._on_audio:
                    await self._on_audio(audio_chunk)

    async def _emit_status(self, status: str) -> None:
        self._status = status
        if self._on_status:
            await self._on_status(status)

    async def _emit_error(self, message: str) -> None:
        if self._on_error:
            await self._on_error(message)

    async def _idle_check_loop(self) -> None:
        """Background task: check and play pending broadcast messages when idle."""
        while self._running:
            try:
                await asyncio.sleep(1.0)  # Check every second
                # Check in both "ready" and "idle" states
                if self._status in ("ready", "idle") and self._get_pending_message:
                    msg = await self._get_pending_message()
                    if msg:
                        await self.speak_text(msg)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.warning(f"[VoiceLite] Idle check error: {e}")

    async def speak_text(self, text: str) -> None:
        """Play text directly via TTS (skip agent, used for broadcast messages)."""
        if not text.strip():
            return

        logger.info(f"[VoiceLite] Broadcasting: '{text[:80]}'")
        await self._emit_status("speaking")

        try:
            tts_client = StreamingTTSClient(speaker=self._speaker)
            if self._client_sample_rate != 24000:
                async for audio_chunk in tts_client.synthesize_raw(text):
                    if self._on_audio:
                        await self._on_audio(self._resample_output(audio_chunk))
            else:
                async for audio_chunk in tts_client.synthesize(text):
                    if self._on_audio:
                        await self._on_audio(audio_chunk)
        except Exception as e:
            logger.error(f"[VoiceLite] Broadcast TTS error: {e}", exc_info=True)
        finally:
            if self._running:
                await self._emit_status("idle")
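

# Minimal wiring sketch (illustrative only, not part of the production path):
# drives one direct-text turn through the agent/TTS pipeline and prints the
# emitted events. Assumes "demo-bot" is a valid bot id and that the ASR/TTS/
# agent backends behind services/* and utils/* are reachable.
if __name__ == "__main__":
    async def _demo() -> None:
        async def print_status(status: str) -> None:
            print(f"status: {status}")

        async def print_audio(chunk: bytes) -> None:
            print(f"audio: {len(chunk)} bytes")

        session = VoiceLiteSession(
            bot_id="demo-bot",  # hypothetical bot id
            sample_rate=24000,
            on_audio=print_audio,
            on_status=print_status,
        )
        await session.start()
        await session.handle_text("Hello there!")
        if session._agent_task:
            # handle_text() schedules the agent turn as a background task
            await session._agent_task
        await session.stop()

    asyncio.run(_demo())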