qwen_agent/services/voice_lite_session.py
2026-03-23 17:32:07 +08:00

520 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import logging
import struct
import uuid
from typing import Optional, Callable, Awaitable
import webrtcvad
from services.streaming_asr_client import StreamingASRClient
from services.streaming_tts_client import StreamingTTSClient
from services.voice_utils import (
StreamTagFilter,
clean_markdown,
stream_v3_agent,
TTSSentenceSplitter,
)
from utils.settings import VOICE_LITE_SILENCE_TIMEOUT
logger = logging.getLogger('app')
class VoiceLiteSession:
    """Voice Lite session: ASR -> Agent -> TTS pipeline (cheaper alternative).

    Audio flow: the caller feeds PCM chunks to handle_audio(); webrtcvad
    decides when a streaming ASR connection is opened; recognized text that
    stays stable through a silence timeout is sent to the agent, whose
    streamed reply is split into sentences and synthesized via streaming
    TTS.  All results are surfaced through the optional async callbacks.
    """
    def __init__(
        self,
        bot_id: str,
        session_id: Optional[str] = None,
        user_identifier: Optional[str] = None,
        sample_rate: int = 24000,  # client PCM rate; 16000 enables resampling paths
        on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None,
        on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None,
        on_agent_stream: Optional[Callable[[str], Awaitable[None]]] = None,
        on_llm_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_status: Optional[Callable[[str], Awaitable[None]]] = None,
        on_error: Optional[Callable[[str], Awaitable[None]]] = None,
    ):
        """Store identifiers, callbacks, and initial ASR/VAD state.

        Args:
            bot_id: Bot whose config and voice are used for this session.
            session_id: Existing session id; a UUID4 is generated if omitted.
            user_identifier: Opaque user id forwarded to ASR/agent ("" if None).
            sample_rate: PCM sample rate of audio exchanged with the client.
            on_audio: Receives synthesized PCM chunks for playback.
            on_asr_text: Receives the finalized ASR transcript per utterance.
            on_agent_result: Receives the full agent reply when streaming ends.
            on_agent_stream: Receives each raw agent chunk as it streams.
            on_llm_text: Stored but not used in this file's visible code.
            on_status: Receives status strings (ready/listening/thinking/...).
            on_error: Receives human-readable error messages.
        """
        self.bot_id = bot_id
        self.session_id = session_id or str(uuid.uuid4())
        self.user_identifier = user_identifier or ""
        self._client_sample_rate = sample_rate
        self._bot_config: dict = {}   # populated by start()
        self._speaker: str = ""       # TTS voice name, from bot config
        # Callbacks
        self._on_audio = on_audio
        self._on_asr_text = on_asr_text
        self._on_agent_result = on_agent_result
        self._on_agent_stream = on_agent_stream
        self._on_llm_text = on_llm_text
        self._on_status = on_status
        self._on_error = on_error
        self._running = False
        self._asr_client: Optional[StreamingASRClient] = None
        self._asr_receive_task: Optional[asyncio.Task] = None
        self._agent_task: Optional[asyncio.Task] = None
        # Silence timeout tracking
        self._last_asr_time: float = 0            # last time any ASR result arrived
        self._silence_timer_task: Optional[asyncio.Task] = None
        self._current_asr_text: str = ""          # latest (partial) transcript
        self._last_text_change_time: float = 0    # when the transcript last changed
        self._last_changed_text: str = ""
        self._last_asr_emit_time: float = 0
        self._utterance_lock = asyncio.Lock()     # serializes utterance finalization
        # VAD (Voice Activity Detection) via webrtcvad
        self._vad = webrtcvad.Vad(2)  # aggressiveness 0-3 (2 = balanced)
        self._vad_speaking = False  # Whether user is currently speaking
        self._vad_silence_start: float = 0  # When silence started
        self._vad_finish_task: Optional[asyncio.Task] = None
        self._pre_buffer: list = []  # Buffer audio before VAD triggers
        self._vad_voice_streak: int = 0  # Consecutive voiced chunks count
        self._vad_silence_streak: int = 0  # Consecutive silent chunks count
    async def start(self) -> None:
        """Fetch the bot config from the DB and mark the session running.

        Caches "voice_speaker" from the config for later TTS calls, then
        emits a "ready" status to the client.
        """
        # NOTE(review): local import — presumably avoids an import cycle; confirm.
        from utils.fastapi_utils import fetch_bot_config_from_db
        self._bot_config = await fetch_bot_config_from_db(self.bot_id, self.user_identifier)
        self._speaker = self._bot_config.get("voice_speaker", "")
        self._running = True
        await self._emit_status("ready")
    async def stop(self) -> None:
        """Gracefully stop the session.

        Cancels pending VAD/silence/agent tasks, signals the shared cancel
        manager so in-flight agent work aborts cooperatively, then finishes
        and closes the ASR connection (best-effort).
        """
        self._running = False
        if self._vad_finish_task and not self._vad_finish_task.done():
            self._vad_finish_task.cancel()
        if self._silence_timer_task and not self._silence_timer_task.done():
            self._silence_timer_task.cancel()
        if self._agent_task and not self._agent_task.done():
            from utils.cancel_manager import trigger_cancel
            # Cooperative cancel for the agent pipeline plus a hard task cancel.
            trigger_cancel(self.session_id)
            self._agent_task.cancel()
        if self._asr_receive_task and not self._asr_receive_task.done():
            self._asr_receive_task.cancel()
        if self._asr_client:
            try:
                # Tell ASR the audio stream is over before closing; ignore
                # failures since the socket may already be gone.
                await self._asr_client.send_finish()
            except Exception:
                pass
            await self._asr_client.close()
# VAD configuration
VAD_SILENCE_DURATION = 3.0 # Seconds of silence before sending finish
VAD_PRE_BUFFER_SIZE = 5 # Number of audio chunks to buffer before VAD triggers
VAD_SOURCE_RATE = 24000 # Input audio sample rate
VAD_TARGET_RATE = 16000 # webrtcvad supported sample rate
VAD_FRAME_DURATION_MS = 30 # Frame duration for webrtcvad (10, 20, or 30 ms)
VAD_SPEECH_CHUNKS = 3 # Consecutive voiced chunks required to start speech
VAD_SILENCE_CHUNKS = 5 # Consecutive silent chunks required to confirm silence
_audio_chunk_count = 0
@staticmethod
def _resample_24k_to_16k(pcm_data: bytes) -> bytes:
"""Downsample 16-bit PCM from 24kHz to 16kHz (ratio 3:2).
Takes every 2 out of 3 samples (simple decimation).
"""
n_samples = len(pcm_data) // 2
if n_samples == 0:
return b''
samples = struct.unpack(f'<{n_samples}h', pcm_data[:n_samples * 2])
# Pick samples at indices 0, 1.5, 3, 4.5, ... -> floor(i * 3/2) for output index i
out_len = (n_samples * 2) // 3
resampled = []
for i in range(out_len):
src_idx = (i * 3) // 2
if src_idx < n_samples:
resampled.append(samples[src_idx])
return struct.pack(f'<{len(resampled)}h', *resampled)
@staticmethod
def _resample_16k_to_24k(pcm_data: bytes) -> bytes:
"""Upsample 16-bit PCM from 16kHz to 24kHz (ratio 2:3).
For every 2 input samples, produces 3 output samples using linear interpolation.
"""
n_samples = len(pcm_data) // 2
if n_samples == 0:
return b''
samples = struct.unpack(f'<{n_samples}h', pcm_data[:n_samples * 2])
out_len = (n_samples * 3) // 2
resampled = []
for i in range(out_len):
src_pos = (i * 2) / 3
src_idx = int(src_pos)
frac = src_pos - src_idx
if src_idx + 1 < n_samples:
val = int(samples[src_idx] * (1 - frac) + samples[src_idx + 1] * frac)
elif src_idx < n_samples:
val = samples[src_idx]
else:
break
resampled.append(max(-32768, min(32767, val)))
return struct.pack(f'<{len(resampled)}h', *resampled)
def _resample_input(self, audio_data: bytes) -> bytes:
"""Resample incoming audio to 24kHz if needed."""
if self._client_sample_rate == 16000:
return self._resample_16k_to_24k(audio_data)
return audio_data
def _resample_output(self, audio_data: bytes) -> bytes:
"""Resample outgoing audio from 24kHz to client sample rate if needed."""
if self._client_sample_rate == 16000:
return self._resample_24k_to_16k(audio_data)
return audio_data
def _webrtcvad_detect(self, pcm_data: bytes) -> bool:
"""Run webrtcvad on audio data. Returns True if voice is detected in any frame.
Expects 24kHz PCM input, resamples to 16kHz for webrtcvad.
"""
resampled = self._resample_24k_to_16k(pcm_data)
return self._webrtcvad_check(resampled)
def _webrtcvad_detect_16k(self, pcm_data: bytes) -> bool:
"""Run webrtcvad directly on 16kHz PCM data (no resampling needed)."""
return self._webrtcvad_check(pcm_data)
def _webrtcvad_check(self, pcm_16k: bytes) -> bool:
"""Core webrtcvad check on 16kHz PCM data."""
frame_size = (self.VAD_TARGET_RATE * self.VAD_FRAME_DURATION_MS // 1000) * 2 # bytes per frame
if len(pcm_16k) < frame_size:
return False
voice_frames = 0
total_frames = 0
for offset in range(0, len(pcm_16k) - frame_size + 1, frame_size):
frame = pcm_16k[offset:offset + frame_size]
total_frames += 1
try:
if self._vad.is_speech(frame, self.VAD_TARGET_RATE):
voice_frames += 1
except Exception:
pass
return voice_frames > 0
    async def handle_audio(self, audio_data: bytes) -> None:
        """Forward user audio to ASR with VAD gating. Lazy-connect on speech start.

        State machine:
          idle     -> recent chunks are kept in a small pre-roll buffer until
                      VAD_SPEECH_CHUNKS consecutive voiced chunks are seen,
                      then the ASR connection is opened and the buffer flushed;
          speaking -> every chunk (voiced or not) is forwarded to ASR; once
                      VAD_SILENCE_CHUNKS consecutive silent chunks AND
                      VAD_SILENCE_DURATION of wall time have passed, a finish
                      task is scheduled (cancelled if voice resumes).
        """
        if not self._running:
            return
        # VAD check: run on the original sample rate directly, to avoid the
        # precision loss of resampling twice.
        if self._client_sample_rate == 16000:
            has_voice = self._webrtcvad_detect_16k(audio_data)
        else:
            has_voice = self._webrtcvad_detect(audio_data)
        # Upsampled copy for ASR (ASR expects 24kHz input).
        audio_for_asr = self._resample_input(audio_data)
        self._audio_chunk_count += 1
        now = asyncio.get_event_loop().time()
        # Update consecutive voiced/silent streaks (each resets the other).
        if has_voice:
            self._vad_voice_streak += 1
            self._vad_silence_streak = 0
        else:
            self._vad_silence_streak += 1
            self._vad_voice_streak = 0
        if has_voice:
            # Voice resumed: cancel any pending finish task.
            if self._vad_finish_task and not self._vad_finish_task.done():
                self._vad_finish_task.cancel()
                self._vad_finish_task = None
            if not self._vad_speaking and self._vad_voice_streak >= self.VAD_SPEECH_CHUNKS:
                # Speech just started — connect ASR lazily.
                self._vad_speaking = True
                logger.info(f"[VoiceLite] VAD: speech started (webrtcvad), connecting ASR...")
                try:
                    await self._connect_asr()
                    # Replay buffered pre-speech audio so the utterance start
                    # is not clipped.
                    for buffered in self._pre_buffer:
                        await self._asr_client.send_audio(buffered)
                    self._pre_buffer.clear()
                except Exception as e:
                    logger.error(f"[VoiceLite] VAD: ASR connect failed: {e}", exc_info=True)
                    # Fall back to idle so the next voiced run retries.
                    self._vad_speaking = False
                    return
            # Send current chunk (best-effort; a dead socket is cleaned up
            # by the receive loop).
            if self._asr_client:
                try:
                    await self._asr_client.send_audio(audio_for_asr)
                except Exception:
                    pass
            self._vad_silence_start = 0
        else:
            if self._vad_speaking:
                # Brief silence while speaking — keep sending for ASR context.
                if self._asr_client:
                    try:
                        await self._asr_client.send_audio(audio_for_asr)
                    except Exception:
                        pass
                if self._vad_silence_start == 0:
                    self._vad_silence_start = now
                # Require both consecutive silent chunks AND the time threshold
                # before ending the utterance.
                if (self._vad_silence_streak >= self.VAD_SILENCE_CHUNKS
                        and (now - self._vad_silence_start) >= self.VAD_SILENCE_DURATION):
                    if not self._vad_finish_task or self._vad_finish_task.done():
                        self._vad_finish_task = asyncio.create_task(self._vad_send_finish())
            else:
                # Not speaking — keep a bounded pre-roll buffer of recent audio.
                self._pre_buffer.append(audio_for_asr)
                if len(self._pre_buffer) > self.VAD_PRE_BUFFER_SIZE:
                    self._pre_buffer.pop(0)
async def _vad_send_finish(self) -> None:
"""Send finish signal to ASR after silence detected."""
logger.info(f"[VoiceLite] VAD: silence detected, sending finish to ASR")
self._vad_speaking = False
self._vad_silence_start = 0
self._vad_voice_streak = 0
self._vad_silence_streak = 0
if self._asr_client:
try:
await self._asr_client.send_finish()
except Exception as e:
logger.warning(f"[VoiceLite] VAD: send_finish failed: {e}")
async def handle_text(self, text: str) -> None:
"""Handle direct text input - bypass ASR and go straight to agent."""
if not self._running:
return
await self._interrupt_current()
self._agent_task = asyncio.create_task(self._process_utterance(text))
    async def _connect_asr(self) -> None:
        """Create and connect a new ASR client, then start its receive loop.

        Any previous ASR client and receive task are closed/cancelled first
        so only one streaming ASR session is live at a time.
        """
        if self._asr_client:
            try:
                await self._asr_client.close()
            except Exception:
                # Old socket may already be dead; nothing to do.
                pass
        if self._asr_receive_task and not self._asr_receive_task.done():
            self._asr_receive_task.cancel()
        self._asr_client = StreamingASRClient(uid=self.user_identifier or "voice_lite")
        await self._asr_client.connect()
        logger.info(f"[VoiceLite] ASR client connected")
        # Start receive loop for this ASR session
        self._asr_receive_task = asyncio.create_task(self._asr_receive_loop())
    async def _asr_receive_loop(self) -> None:
        """Receive ASR results from the current ASR session.

        Tracks the latest partial transcript and when it last changed; every
        incoming result restarts the stability timer that eventually
        finalizes the utterance.  Closes and clears the ASR client when the
        stream ends, for any reason.
        """
        try:
            _listening_emitted = False
            # NOTE: is_final from ASR is intentionally ignored here —
            # finalization is driven by the text-stability timer instead.
            async for text, is_final in self._asr_client.receive_results():
                if not self._running:
                    return
                if not _listening_emitted:
                    # Tell the client speech is being received (once per session).
                    await self._emit_status("listening")
                    _listening_emitted = True
                now = asyncio.get_event_loop().time()
                self._last_asr_time = now
                self._current_asr_text = text
                # Track text changes for stability detection
                if text != self._last_changed_text:
                    self._last_changed_text = text
                    self._last_text_change_time = now
                # Reset stability timer on every result
                self._reset_silence_timer()
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # Only worth logging if the session is still supposed to be live.
            if self._running:
                logger.warning(f"[VoiceLite] ASR session ended: {e}")
        finally:
            logger.info(f"[VoiceLite] ASR session done")
            # Clean up ASR client after session ends
            if self._asr_client:
                try:
                    await self._asr_client.close()
                except Exception:
                    pass
                self._asr_client = None
def _reset_silence_timer(self) -> None:
"""Reset the silence timeout timer."""
if self._silence_timer_task and not self._silence_timer_task.done():
self._silence_timer_task.cancel()
self._silence_timer_task = asyncio.create_task(self._silence_timeout())
    async def _silence_timeout(self) -> None:
        """Wait for the silence timeout, then finalize if the text is stable.

        If the current ASR text has not changed for VOICE_LITE_SILENCE_TIMEOUT
        seconds, treat it as a complete utterance.  This task is cancelled and
        restarted by _reset_silence_timer() whenever a new ASR result arrives.
        """
        try:
            await asyncio.sleep(VOICE_LITE_SILENCE_TIMEOUT)
            if not self._running:
                return
            # Check if text has been stable (unchanged) for the timeout period
            now = asyncio.get_event_loop().time()
            if (self._current_asr_text
                    and (now - self._last_text_change_time) >= VOICE_LITE_SILENCE_TIMEOUT):
                logger.info(f"[VoiceLite] Text stable for {VOICE_LITE_SILENCE_TIMEOUT}s, processing: '{self._current_asr_text}'")
                await self._on_utterance_complete(self._current_asr_text)
        except asyncio.CancelledError:
            # Superseded by a newer timer; nothing to do.
            pass
    async def _on_utterance_complete(self, text: str) -> None:
        """Finalize one utterance: reset VAD/timer state, interrupt any
        running reply, report the transcript, and start the agent pipeline.

        Serialized through _utterance_lock; blank/whitespace-only text is
        discarded up front.
        """
        if not text.strip():
            return
        async with self._utterance_lock:
            # Cancel silence timer
            if self._silence_timer_task and not self._silence_timer_task.done():
                self._silence_timer_task.cancel()
            # Reset VAD state for next utterance
            self._vad_speaking = False
            self._vad_silence_start = 0
            self._vad_voice_streak = 0
            self._vad_silence_streak = 0
            if self._vad_finish_task and not self._vad_finish_task.done():
                self._vad_finish_task.cancel()
                self._vad_finish_task = None
            # Interrupt any in-progress agent+TTS
            await self._interrupt_current()
            # Send final ASR text to frontend
            if self._on_asr_text:
                await self._on_asr_text(text)
            # Clear transcript state so the next utterance starts fresh.
            self._current_asr_text = ""
            self._last_changed_text = ""
            self._agent_task = asyncio.create_task(self._process_utterance(text))
    async def _interrupt_current(self) -> None:
        """Cancel the current agent+TTS task, if one is running.

        Signals the cancel manager (cooperative cancel inside the agent
        stream) and hard-cancels the task, then awaits it so cancellation
        has fully completed before a new utterance begins.
        """
        if self._agent_task and not self._agent_task.done():
            logger.info(f"[VoiceLite] Interrupting previous agent task")
            from utils.cancel_manager import trigger_cancel
            trigger_cancel(self.session_id)
            self._agent_task.cancel()
            try:
                await self._agent_task
            except (asyncio.CancelledError, Exception):
                # Swallow both the cancellation and any error raised by the
                # dying task — interruption is best-effort by design.
                pass
        self._agent_task = None
async def _process_utterance(self, text: str) -> None:
"""Process a complete utterance: agent -> TTS pipeline."""
try:
logger.info(f"[VoiceLite] Processing utterance: '{text}'")
await self._emit_status("thinking")
accumulated_text = []
tag_filter = StreamTagFilter()
splitter = TTSSentenceSplitter()
tts_client = StreamingTTSClient(speaker=self._speaker)
speaking = False
async for chunk in stream_v3_agent(
user_text=text,
bot_id=self.bot_id,
bot_config=self._bot_config,
session_id=self.session_id,
user_identifier=self.user_identifier,
):
accumulated_text.append(chunk)
if self._on_agent_stream:
await self._on_agent_stream(chunk)
passthrough = tag_filter.feed(chunk)
if not passthrough:
if tag_filter.answer_ended:
for sentence in splitter.flush():
sentence = clean_markdown(sentence)
if sentence:
if not speaking:
await self._emit_status("speaking")
speaking = True
await self._send_tts(tts_client, sentence)
continue
# Feed raw passthrough to splitter (preserve newlines for splitting),
# apply clean_markdown on output sentences
for sentence in splitter.feed(passthrough):
sentence = clean_markdown(sentence)
if sentence:
if not speaking:
await self._emit_status("speaking")
speaking = True
await self._send_tts(tts_client, sentence)
# Handle remaining text
for sentence in splitter.flush():
sentence = clean_markdown(sentence)
if sentence:
if not speaking:
await self._emit_status("speaking")
speaking = True
await self._send_tts(tts_client, sentence)
# Log full agent result (not sent to frontend, already streamed)
full_result = "".join(accumulated_text)
logger.info(f"[VoiceLite] Agent done ({len(full_result)} chars)")
# Notify frontend that agent text stream is complete
if self._on_agent_result:
await self._on_agent_result(full_result)
await self._emit_status("idle")
except asyncio.CancelledError:
logger.info(f"[VoiceLite] Agent task cancelled (user interrupted)")
raise
except Exception as e:
logger.error(f"[VoiceLite] Error processing utterance: {e}", exc_info=True)
await self._emit_error(f"Processing failed: {str(e)}")
    async def _send_tts(self, tts_client: StreamingTTSClient, sentence: str) -> None:
        """Synthesize one sentence and emit audio chunks via the on_audio callback."""
        logger.info(f"[VoiceLite] TTS sentence: '{sentence[:80]}'")
        if self._client_sample_rate != 24000:
            # Client needs non-24kHz: use raw int16 pipeline to allow resampling
            async for audio_chunk in tts_client.synthesize_raw(sentence):
                if self._on_audio:
                    await self._on_audio(self._resample_output(audio_chunk))
        else:
            # 24kHz client: forward synthesized chunks unchanged
            # (presumably the TTS output is already 24kHz — confirm).
            async for audio_chunk in tts_client.synthesize(sentence):
                if self._on_audio:
                    await self._on_audio(audio_chunk)
async def _emit_status(self, status: str) -> None:
if self._on_status:
await self._on_status(status)
async def _emit_error(self, message: str) -> None:
if self._on_error:
await self._on_error(message)