qwen_agent/services/voice_lite_session.py
2026-03-22 00:52:11 +08:00

449 lines
18 KiB
Python

import asyncio
import logging
import struct
import uuid
from typing import Optional, Callable, Awaitable
import webrtcvad
from services.streaming_asr_client import StreamingASRClient
from services.streaming_tts_client import StreamingTTSClient
from services.voice_utils import (
StreamTagFilter,
clean_markdown,
stream_v3_agent,
TTSSentenceSplitter,
)
from utils.settings import VOICE_LITE_SILENCE_TIMEOUT
# Module-wide logger; every message from this module goes to the "app" logger.
logger: logging.Logger = logging.getLogger(name="app")
class VoiceLiteSession:
    """Voice Lite session: ASR -> Agent -> TTS pipeline (cheaper alternative).

    Incoming 16-bit little-endian PCM (``VAD_SOURCE_RATE`` Hz) is gated by
    webrtcvad. When speech is detected a streaming ASR session is opened
    lazily and fed the buffered pre-speech audio plus the live stream. ASR
    text that stays unchanged for ``VOICE_LITE_SILENCE_TIMEOUT`` seconds is
    treated as a complete utterance and sent to the agent. The agent's reply
    is split into sentences and synthesized with streaming TTS. Results are
    reported through the optional async callbacks given to the constructor.
    """

    # VAD configuration
    VAD_SILENCE_DURATION = 1.5  # Seconds of silence before sending finish
    VAD_PRE_BUFFER_SIZE = 5  # Number of audio chunks to buffer before VAD triggers
    VAD_SOURCE_RATE = 24000  # Input audio sample rate
    VAD_TARGET_RATE = 16000  # webrtcvad supported sample rate
    VAD_FRAME_DURATION_MS = 30  # Frame duration for webrtcvad (10, 20, or 30 ms)
    VAD_SPEECH_CHUNKS = 3  # Consecutive voiced chunks required to start speech
    VAD_SILENCE_CHUNKS = 5  # Consecutive silent chunks required to confirm silence

    # Incremented once per handle_audio() call (diagnostics only).
    _audio_chunk_count = 0

    def __init__(
        self,
        bot_id: str,
        session_id: Optional[str] = None,
        user_identifier: Optional[str] = None,
        on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None,
        on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None,
        on_agent_stream: Optional[Callable[[str], Awaitable[None]]] = None,
        on_llm_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_status: Optional[Callable[[str], Awaitable[None]]] = None,
        on_error: Optional[Callable[[str], Awaitable[None]]] = None,
    ):
        """Create an idle session; call ``start()`` before feeding input.

        Args:
            bot_id: Bot whose config is loaded in ``start()``.
            session_id: Conversation id forwarded to the agent; a random
                UUID4 string is used when omitted.
            user_identifier: Forwarded to the agent and used as the ASR uid.
            on_audio: Receives each synthesized TTS audio chunk.
            on_asr_text: Receives the final text of each completed utterance.
            on_agent_result: Receives the full agent reply once it finishes.
            on_agent_stream: Receives each raw agent output chunk.
            on_llm_text: Stored only; not invoked anywhere in this class.
            on_status: Receives status strings ("ready", "listening",
                "thinking", "speaking", "idle").
            on_error: Receives a human-readable error message.
        """
        self.bot_id = bot_id
        self.session_id = session_id or str(uuid.uuid4())
        self.user_identifier = user_identifier or ""
        self._bot_config: dict = {}
        self._speaker: str = ""

        # Callbacks
        self._on_audio = on_audio
        self._on_asr_text = on_asr_text
        self._on_agent_result = on_agent_result
        self._on_agent_stream = on_agent_stream
        self._on_llm_text = on_llm_text
        self._on_status = on_status
        self._on_error = on_error

        self._running = False
        # Current ASR client. Whoever detaches it (sets the attribute to
        # something else) is responsible for closing it.
        self._asr_client: Optional[StreamingASRClient] = None
        self._asr_receive_task: Optional[asyncio.Task] = None
        self._agent_task: Optional[asyncio.Task] = None

        # Silence timeout tracking
        self._last_asr_time: float = 0
        self._silence_timer_task: Optional[asyncio.Task] = None
        self._current_asr_text: str = ""
        self._last_text_change_time: float = 0
        self._last_changed_text: str = ""
        self._last_asr_emit_time: float = 0
        self._utterance_lock = asyncio.Lock()

        # VAD (Voice Activity Detection) via webrtcvad
        self._vad = webrtcvad.Vad(2)  # aggressiveness 0-3 (2 = balanced)
        self._vad_speaking = False  # Whether user is currently speaking
        self._vad_silence_start: float = 0  # When silence started (0 = not in silence)
        self._vad_finish_task: Optional[asyncio.Task] = None
        self._pre_buffer: list = []  # Audio kept from before speech was confirmed
        self._vad_voice_streak: int = 0  # Consecutive voiced chunks count
        self._vad_silence_streak: int = 0  # Consecutive silent chunks count

    async def start(self) -> None:
        """Fetch bot config, mark session as running and emit "ready"."""
        from utils.fastapi_utils import fetch_bot_config_from_db
        self._bot_config = await fetch_bot_config_from_db(self.bot_id, self.user_identifier)
        self._speaker = self._bot_config.get("voice_speaker", "")
        self._running = True
        await self._emit_status("ready")

    async def stop(self) -> None:
        """Gracefully stop the session.

        Cancels pending timers and the agent task, then finishes and closes
        the ASR client. The client is detached before any await, so the
        cancelled receive loop cannot clear it underneath us, and it is
        closed exactly once. Errors during ASR teardown are logged, not
        raised.
        """
        self._running = False
        if self._vad_finish_task and not self._vad_finish_task.done():
            self._vad_finish_task.cancel()
        if self._silence_timer_task and not self._silence_timer_task.done():
            self._silence_timer_task.cancel()
        if self._agent_task and not self._agent_task.done():
            from utils.cancel_manager import trigger_cancel
            trigger_cancel(self.session_id)
            self._agent_task.cancel()
        if self._asr_receive_task and not self._asr_receive_task.done():
            self._asr_receive_task.cancel()

        client, self._asr_client = self._asr_client, None
        if client is not None:
            try:
                await client.send_finish()
            except Exception:
                pass  # Best effort: the session is being torn down anyway.
            try:
                await client.close()
            except Exception as e:
                logger.warning(f"[VoiceLite] ASR close failed: {e}")

    @staticmethod
    def _resample_24k_to_16k(pcm_data: bytes) -> bytes:
        """Downsample 16-bit little-endian PCM from 24kHz to 16kHz (ratio 3:2).

        Simple decimation with no low-pass filter: output sample ``i`` is
        input sample ``floor(3*i/2)``, i.e. samples 0, 1, 3, 4, 6, 7, ...
        (2 of every 3). A trailing odd byte is ignored.
        """
        n_samples = len(pcm_data) // 2
        if n_samples == 0:
            return b''
        samples = struct.unpack(f'<{n_samples}h', pcm_data[:n_samples * 2])
        out_len = (n_samples * 2) // 3
        # (3*i)//2 < n_samples holds for every i < out_len, so no bounds check needed.
        resampled = [samples[(i * 3) // 2] for i in range(out_len)]
        return struct.pack(f'<{len(resampled)}h', *resampled)

    def _webrtcvad_detect(self, pcm_data: bytes) -> bool:
        """Run webrtcvad on audio data. Return True if any full frame is voiced.

        Audio is resampled to ``VAD_TARGET_RATE`` and cut into
        ``VAD_FRAME_DURATION_MS`` frames. A trailing partial frame is ignored.
        A chunk shorter than one frame counts as silence.
        """
        resampled = self._resample_24k_to_16k(pcm_data)
        frame_size = (self.VAD_TARGET_RATE * self.VAD_FRAME_DURATION_MS // 1000) * 2  # bytes per frame
        if len(resampled) < frame_size:
            return False
        voiced = False
        # Every complete frame is passed to the VAD (no early exit).
        for offset in range(0, len(resampled) - frame_size + 1, frame_size):
            frame = resampled[offset:offset + frame_size]
            try:
                if self._vad.is_speech(frame, self.VAD_TARGET_RATE):
                    voiced = True
            except Exception:
                pass  # A frame the VAD rejects is treated as non-speech.
        return voiced

    async def handle_audio(self, audio_data: bytes) -> None:
        """Forward user audio to ASR with VAD gating. Lazy-connect on speech start.

        Speech starts after ``VAD_SPEECH_CHUNKS`` consecutive voiced chunks.
        It ends (ASR finish) after ``VAD_SILENCE_CHUNKS`` consecutive silent
        chunks spanning at least ``VAD_SILENCE_DURATION`` seconds. Audio that
        arrives before speech is confirmed, voiced or not, goes into a bounded
        pre-buffer that is replayed to ASR when the session opens.
        """
        if not self._running:
            return
        self._audio_chunk_count += 1
        has_voice = self._webrtcvad_detect(audio_data)
        now = asyncio.get_running_loop().time()

        # Update consecutive streaks (debounces single-chunk VAD flips).
        if has_voice:
            self._vad_voice_streak += 1
            self._vad_silence_streak = 0
        else:
            self._vad_silence_streak += 1
            self._vad_voice_streak = 0

        if has_voice:
            await self._on_voiced_chunk(audio_data)
        elif self._vad_speaking:
            await self._on_silent_chunk_while_speaking(audio_data, now)
        else:
            # Not speaking — buffer recent audio for pre-speech context
            self._buffer_pre_speech(audio_data)

    async def _on_voiced_chunk(self, audio_data: bytes) -> None:
        """Handle a voiced chunk: cancel a pending finish, start speech or stream."""
        # Cancel any pending finish
        if self._vad_finish_task and not self._vad_finish_task.done():
            self._vad_finish_task.cancel()
            self._vad_finish_task = None
        self._vad_silence_start = 0

        if not self._vad_speaking:
            if self._vad_voice_streak < self.VAD_SPEECH_CHUNKS:
                # Possible speech onset, not yet confirmed. Keep it so the
                # start of the utterance is replayed to ASR instead of dropped
                # (or sent to a stale, already-finished ASR session).
                self._buffer_pre_speech(audio_data)
                return
            # Speech just started — connect ASR
            self._vad_speaking = True
            logger.info(f"[VoiceLite] VAD: speech started (webrtcvad), connecting ASR...")
            try:
                await self._connect_asr()
                client = self._asr_client
                # Send buffered pre-speech audio
                for buffered in self._pre_buffer:
                    await client.send_audio(buffered)
                self._pre_buffer.clear()
            except Exception as e:
                logger.error(f"[VoiceLite] VAD: ASR connect failed: {e}", exc_info=True)
                self._vad_speaking = False
                return

        # Send current chunk
        await self._send_audio_to_asr(audio_data)

    async def _on_silent_chunk_while_speaking(self, audio_data: bytes, now: float) -> None:
        """Stream a silent chunk and schedule ASR finish once silence is confirmed."""
        # Brief silence while speaking — keep sending for ASR context
        await self._send_audio_to_asr(audio_data)
        if self._vad_silence_start == 0:
            self._vad_silence_start = now
        # Require both consecutive silent chunks AND time threshold
        if (self._vad_silence_streak >= self.VAD_SILENCE_CHUNKS
                and (now - self._vad_silence_start) >= self.VAD_SILENCE_DURATION):
            if not self._vad_finish_task or self._vad_finish_task.done():
                self._vad_finish_task = asyncio.create_task(self._vad_send_finish())

    def _buffer_pre_speech(self, audio_data: bytes) -> None:
        """Append a chunk to the pre-speech buffer, keeping only the newest ones."""
        self._pre_buffer.append(audio_data)
        excess = len(self._pre_buffer) - self.VAD_PRE_BUFFER_SIZE
        if excess > 0:
            del self._pre_buffer[:excess]

    async def _send_audio_to_asr(self, audio_data: bytes) -> None:
        """Best-effort send of one chunk to the current ASR client, if any."""
        client = self._asr_client
        if client is None:
            return
        try:
            await client.send_audio(audio_data)
        except Exception as e:
            logger.debug(f"[VoiceLite] ASR send_audio failed: {e}")

    async def _vad_send_finish(self) -> None:
        """Send finish signal to ASR after silence detected and reset VAD state."""
        logger.info(f"[VoiceLite] VAD: silence detected, sending finish to ASR")
        self._vad_speaking = False
        self._vad_silence_start = 0
        self._vad_voice_streak = 0
        self._vad_silence_streak = 0
        client = self._asr_client
        if client:
            try:
                await client.send_finish()
            except Exception as e:
                logger.warning(f"[VoiceLite] VAD: send_finish failed: {e}")

    async def handle_text(self, text: str) -> None:
        """Handle direct text input - bypass ASR and go straight to agent."""
        if not self._running:
            return
        await self._interrupt_current()
        self._agent_task = asyncio.create_task(self._process_utterance(text))

    async def _connect_asr(self) -> None:
        """Replace any existing ASR client with a freshly connected one.

        The old client is detached before any await, so its receive loop
        sees it is no longer current and leaves the new client alone.
        """
        old_client, self._asr_client = self._asr_client, None
        if old_client:
            try:
                await old_client.close()
            except Exception:
                pass
        if self._asr_receive_task and not self._asr_receive_task.done():
            self._asr_receive_task.cancel()

        client = StreamingASRClient(uid=self.user_identifier or "voice_lite")
        self._asr_client = client
        await client.connect()
        logger.info(f"[VoiceLite] ASR client connected")
        # Start receive loop bound to this specific ASR session
        self._asr_receive_task = asyncio.create_task(self._asr_receive_loop(client))

    async def _asr_receive_loop(self, client=None) -> None:
        """Receive ASR results from one ASR session.

        Args:
            client: The ASR client this loop owns. Defaults to the current
                ``self._asr_client`` at call time.

        Each result updates the current text and restarts the stability
        timer. On exit the client is closed and cleared only if it is still
        the session's current client, so a newer client is never touched.
        """
        if client is None:
            client = self._asr_client
        try:
            if client is None:
                return
            listening_emitted = False
            async for text, is_final in client.receive_results():
                if not self._running:
                    return
                if not listening_emitted:
                    await self._emit_status("listening")
                    listening_emitted = True
                now = asyncio.get_running_loop().time()
                self._last_asr_time = now
                self._current_asr_text = text
                # Track text changes for stability detection
                if text != self._last_changed_text:
                    self._last_changed_text = text
                    self._last_text_change_time = now
                # Reset stability timer on every result
                self._reset_silence_timer()
        except asyncio.CancelledError:
            pass
        except Exception as e:
            if self._running:
                logger.warning(f"[VoiceLite] ASR session ended: {e}")
        finally:
            logger.info(f"[VoiceLite] ASR session done")
            # Clean up only if nobody else has taken ownership of the client.
            if client is not None and self._asr_client is client:
                self._asr_client = None
                try:
                    await client.close()
                except Exception:
                    pass

    def _reset_silence_timer(self) -> None:
        """Restart the text-stability timer."""
        if self._silence_timer_task and not self._silence_timer_task.done():
            self._silence_timer_task.cancel()
        self._silence_timer_task = asyncio.create_task(self._silence_timeout())

    async def _silence_timeout(self) -> None:
        """Wait for the silence timeout, then submit the text if it is stable."""
        try:
            await asyncio.sleep(VOICE_LITE_SILENCE_TIMEOUT)
            if not self._running:
                return
            # Check if text has been stable (unchanged) for the timeout period
            now = asyncio.get_running_loop().time()
            if (self._current_asr_text
                    and (now - self._last_text_change_time) >= VOICE_LITE_SILENCE_TIMEOUT):
                logger.info(f"[VoiceLite] Text stable for {VOICE_LITE_SILENCE_TIMEOUT}s, processing: '{self._current_asr_text}'")
                await self._on_utterance_complete(self._current_asr_text)
        except asyncio.CancelledError:
            pass

    async def _on_utterance_complete(self, text: str) -> None:
        """Handle a complete utterance: interrupt, report it, start the agent.

        Usually called from inside the silence-timer task itself, so that
        task must not be cancelled here. Cancelling the current task would
        abort this method at its next suspension point and drop the
        utterance.
        """
        if not text.strip():
            return
        async with self._utterance_lock:
            # Cancel silence timer (unless it is the task running this code)
            timer = self._silence_timer_task
            if (timer is not None and timer is not asyncio.current_task()
                    and not timer.done()):
                timer.cancel()
            # Interrupt any in-progress agent+TTS
            await self._interrupt_current()
            # Send final ASR text to frontend
            if self._on_asr_text:
                await self._on_asr_text(text)
            self._current_asr_text = ""
            self._last_changed_text = ""
            self._agent_task = asyncio.create_task(self._process_utterance(text))

    async def _interrupt_current(self) -> None:
        """Cancel the current agent+TTS task (if running) and wait for it to end."""
        if self._agent_task and not self._agent_task.done():
            logger.info(f"[VoiceLite] Interrupting previous agent task")
            from utils.cancel_manager import trigger_cancel
            trigger_cancel(self.session_id)
            self._agent_task.cancel()
            try:
                await self._agent_task
            except (asyncio.CancelledError, Exception):
                pass
        self._agent_task = None

    async def _process_utterance(self, text: str) -> None:
        """Process a complete utterance: agent -> TTS pipeline.

        Agent chunks are streamed to ``on_agent_stream``. Text that passes
        the tag filter is split into sentences and spoken as it arrives.
        Cancellation is re-raised; other errors are reported via ``on_error``.
        """
        try:
            logger.info(f"[VoiceLite] Processing utterance: '{text}'")
            await self._emit_status("thinking")
            accumulated_text = []
            tag_filter = StreamTagFilter()
            splitter = TTSSentenceSplitter()
            tts_client = StreamingTTSClient(speaker=self._speaker)
            speaking = False
            async for chunk in stream_v3_agent(
                user_text=text,
                bot_id=self.bot_id,
                bot_config=self._bot_config,
                session_id=self.session_id,
                user_identifier=self.user_identifier,
            ):
                accumulated_text.append(chunk)
                if self._on_agent_stream:
                    await self._on_agent_stream(chunk)
                passthrough = tag_filter.feed(chunk)
                if not passthrough:
                    if tag_filter.answer_ended:
                        speaking = await self._speak_sentences(tts_client, splitter.flush(), speaking)
                    continue
                # Feed raw passthrough to splitter (preserve newlines for splitting);
                # clean_markdown is applied to each output sentence.
                speaking = await self._speak_sentences(tts_client, splitter.feed(passthrough), speaking)
            # Handle remaining text
            await self._speak_sentences(tts_client, splitter.flush(), speaking)
            # Log full agent result (not sent to frontend, already streamed)
            full_result = "".join(accumulated_text)
            logger.info(f"[VoiceLite] Agent done ({len(full_result)} chars)")
            # Notify frontend that agent text stream is complete
            if self._on_agent_result:
                await self._on_agent_result(full_result)
            await self._emit_status("idle")
        except asyncio.CancelledError:
            logger.info(f"[VoiceLite] Agent task cancelled (user interrupted)")
            raise
        except Exception as e:
            logger.error(f"[VoiceLite] Error processing utterance: {e}", exc_info=True)
            await self._emit_error(f"Processing failed: {str(e)}")

    async def _speak_sentences(self, tts_client, sentences, speaking: bool) -> bool:
        """Clean and synthesize each sentence; return the updated ``speaking`` flag.

        Sentences that ``clean_markdown`` reduces to empty are skipped. The
        "speaking" status is emitted once, before the first spoken sentence.
        """
        for sentence in sentences:
            sentence = clean_markdown(sentence)
            if not sentence:
                continue
            if not speaking:
                await self._emit_status("speaking")
                speaking = True
            await self._send_tts(tts_client, sentence)
        return speaking

    async def _send_tts(self, tts_client: StreamingTTSClient, sentence: str) -> None:
        """Synthesize a sentence and emit audio chunks."""
        logger.info(f"[VoiceLite] TTS sentence: '{sentence[:80]}'")
        async for audio_chunk in tts_client.synthesize(sentence):
            if self._on_audio:
                await self._on_audio(audio_chunk)

    async def _emit_status(self, status: str) -> None:
        """Forward a status string to ``on_status`` if set."""
        if self._on_status:
            await self._on_status(status)

    async def _emit_error(self, message: str) -> None:
        """Forward an error message to ``on_error`` if set."""
        if self._on_error:
            await self._on_error(message)