qwen_agent/services/voice_session_manager.py
2026-03-21 01:00:02 +08:00

306 lines
12 KiB
Python

import asyncio
import json
import logging
import uuid
from typing import Optional, Callable, Awaitable
from services.realtime_voice_client import RealtimeDialogClient
# Log through the shared 'app' logger rather than a module-named logger.
logger = logging.getLogger('app')


class VoiceSession:
    """Manage the lifecycle of a single real-time voice dialogue session.

    A session fetches the bot's configuration, opens a ``RealtimeDialogClient``
    connection, forwards user audio/text to it, and dispatches the responses
    it receives to the optional async callbacks given at construction time.
    When ASR for a user turn completes, the session speaks a short comfort
    phrase, calls the v3 agent, and injects the agent's answer back into the
    dialogue as RAG text.
    """

    def __init__(
        self,
        bot_id: str,
        session_id: Optional[str] = None,
        user_identifier: Optional[str] = None,
        on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None,
        on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None,
        on_llm_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_status: Optional[Callable[[str], Awaitable[None]]] = None,
        on_error: Optional[Callable[[str], Awaitable[None]]] = None,
    ):
        """Create an idle session; no I/O happens until ``start()``.

        Args:
            bot_id: Bot whose configuration is loaded in ``start()``.
            session_id: Session identifier; a random UUID4 string when omitted.
            user_identifier: Caller identity passed to config lookup and the
                agent request; defaults to ``""``.
            on_audio: Receives TTS audio chunks (``bytes``) to play back.
            on_asr_text: Receives the final ASR text of a user turn.
            on_agent_result: Receives the raw v3 agent result text.
            on_llm_text: Stored but not invoked anywhere in this class.
            on_status: Receives status strings ("ready", "listening",
                "thinking", "idle").
            on_error: Receives human-readable error messages.
        """
        self.bot_id = bot_id
        self.session_id = session_id or str(uuid.uuid4())
        self.user_identifier = user_identifier or ""
        self.realtime_client: Optional[RealtimeDialogClient] = None
        self._bot_config: dict = {}
        # Callbacks
        self._on_audio = on_audio
        self._on_asr_text = on_asr_text
        self._on_agent_result = on_agent_result
        self._on_llm_text = on_llm_text
        self._on_status = on_status
        self._on_error = on_error
        self._running = False
        # Set on event 450, cleared on 459; not read elsewhere in this class.
        self._is_user_querying = False
        # Latest streaming ASR text (event 451) for the current user turn.
        self._current_asr_text = ""
        # When True, discard TTS audio from SERVER_ACK (comfort speech period)
        self._is_sending_chat_tts_text = False
        self._receive_task: Optional[asyncio.Task] = None
        # Strong references to fire-and-forget tasks: the event loop keeps only
        # weak references, so an unreferenced task may be garbage-collected.
        self._background_tasks: set = set()

    async def start(self) -> None:
        """Fetch bot config, connect to Volcengine and start receiving responses.

        Emits the ``"ready"`` status once the receive loop has been scheduled.
        """
        from utils.fastapi_utils import fetch_bot_config_from_db

        self._bot_config = await fetch_bot_config_from_db(self.bot_id, self.user_identifier)
        self.realtime_client = RealtimeDialogClient(
            session_id=self.session_id,
            speaker=self._bot_config.get("voice_speaker"),
            system_role=self._bot_config.get("voice_system_role"),
            speaking_style=self._bot_config.get("voice_speaking_style"),
            bot_name=self._bot_config.get("name", ""),
        )
        await self.realtime_client.connect()
        self._running = True
        self._receive_task = asyncio.create_task(self._receive_loop())
        await self._emit_status("ready")

    async def stop(self) -> None:
        """Gracefully stop the session.

        Safe to call even if ``start()`` never ran or failed before a client
        was created. Errors while finishing the session/connection are logged
        and swallowed; the receive task is always cancelled and an existing
        client is always closed.
        """
        self._running = False
        client = self.realtime_client
        try:
            if client is not None:
                await client.finish_session()
                # NOTE(review): presumably gives the server time to process
                # FinishSession before the connection is finished — confirm.
                await asyncio.sleep(0.5)
                await client.finish_connection()
        except Exception as e:
            logger.warning(f"Error during session cleanup: {e}")
        finally:
            if self._receive_task and not self._receive_task.done():
                self._receive_task.cancel()
            if client is not None:
                await client.close()

    async def handle_audio(self, audio_data: bytes) -> None:
        """Forward user audio to Volcengine while the session is running."""
        client = self.realtime_client
        if self._running and client is not None and client.ws:
            await client.send_audio(audio_data)

    async def handle_text(self, text: str) -> None:
        """Handle text input - send as text query while the session is running."""
        client = self.realtime_client
        if self._running and client is not None and client.ws:
            await client.chat_text_query(text)

    async def _receive_loop(self) -> None:
        """Continuously receive and dispatch Volcengine responses.

        Runs until ``_running`` is cleared, the task is cancelled, or an
        exception occurs (reported through ``on_error``). Always leaves
        ``_running`` False on exit.
        """
        try:
            while self._running:
                response = await self.realtime_client.receive_response()
                # NOTE(review): a falsy response is skipped without delay; if the
                # client returns None repeatedly this loop spins — confirm.
                if not response:
                    continue
                await self._handle_response(response)
        except asyncio.CancelledError:
            logger.info(f"Voice session receive loop cancelled: {self.session_id}")
        except Exception as e:
            logger.error(f"Voice session receive loop error: {e}")
            await self._emit_error(f"Connection error: {str(e)}")
        finally:
            self._running = False

    def _spawn_background(self, coro) -> asyncio.Task:
        """Schedule *coro* as a task, holding a reference until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def _handle_response(self, response: dict) -> None:
        """Dispatch one decoded server response by message type and event code."""
        msg_type = response.get('message_type', '')
        event = response.get('event')
        payload_msg = response.get('payload_msg', {})

        if msg_type == 'SERVER_ACK' and isinstance(payload_msg, bytes):
            # TTS audio data — discard during comfort speech period
            if self._is_sending_chat_tts_text:
                return
            if self._on_audio:
                await self._on_audio(payload_msg)

        elif msg_type == 'SERVER_FULL_RESPONSE':
            shown = payload_msg if not isinstance(payload_msg, bytes) else f'<{len(payload_msg)} bytes>'
            logger.info(f"[Voice] event={event}, payload_msg={shown}")

            if event == 450:
                # User started speaking — set querying flag, reset ASR text
                self._is_user_querying = True
                self._current_asr_text = ""
                await self._emit_status("listening")

            elif event == 451:
                # Streaming ASR result — each result replaces the previous text
                # (presumably the server sends the full hypothesis so far).
                if isinstance(payload_msg, dict):
                    results = payload_msg.get("results")
                    # Ignore malformed frames instead of crashing the receive loop.
                    if isinstance(results, list) and results and isinstance(results[0], dict):
                        text = results[0].get("text", "")
                        if text:
                            self._current_asr_text = text
                            logger.debug("[Voice] ASR streaming (451): '%s'", text)

            elif event == 459:
                # ASR completed — use the last text received from event 451
                self._is_user_querying = False
                asr_text = self._current_asr_text
                logger.info(f"[Voice] ASR result: '{asr_text}'")
                if self._on_asr_text and asr_text:
                    await self._on_asr_text(asr_text)
                await self._emit_status("thinking")
                # Trigger comfort TTS + agent call without blocking the receive loop
                self._is_sending_chat_tts_text = True
                self._spawn_background(self._on_asr_text_received(asr_text))

            elif event == 350:
                # TTS segment completed
                tts_type = ""
                if isinstance(payload_msg, dict):
                    tts_type = payload_msg.get("tts_type", "")
                logger.info(f"[Voice] TTS segment done, type={tts_type}, is_sending={self._is_sending_chat_tts_text}")
                # When comfort TTS or RAG TTS finishes, stop discarding audio
                if self._is_sending_chat_tts_text and tts_type in ("chat_tts_text", "external_rag"):
                    self._is_sending_chat_tts_text = False
                    logger.info("[Voice] Comfort/RAG TTS done, resuming audio forwarding")

            elif event == 359:
                # TTS fully completed (all segments done)
                logger.info("[Voice] TTS fully completed")
                await self._emit_status("idle")

            elif event in (152, 153):
                logger.info(f"[Voice] Session finished event: {event}")
                self._running = False

        elif msg_type == 'SERVER_ERROR':
            error_msg = str(payload_msg) if payload_msg else "Unknown server error"
            logger.error(f"[Voice] Server error: {error_msg}")
            await self._emit_error(error_msg)

    async def _on_asr_text_received(self, text: str) -> None:
        """Handle a completed user utterance: comfort TTS, agent call, RAG injection.

        Clears ``_is_sending_chat_tts_text`` on every path except a successful
        RAG injection. On that path the flag is cleared later by the TTS
        event 350 handler. Errors are logged and reported through ``on_error``.
        """
        if not text.strip():
            self._is_sending_chat_tts_text = False
            return
        try:
            # 1. Send comfort TTS (real Chinese text, not "...")
            logger.info("[Voice] Sending comfort TTS...")
            await self.realtime_client.chat_tts_text(
                content="请稍等,让我查一下。",
                start=True,
                end=False,
            )
            await self.realtime_client.chat_tts_text(
                content="",
                start=False,
                end=True,
            )
            # 2. Call v3 agent (this may take a while)
            logger.info(f"[Voice] Calling v3 agent with text: '{text}'")
            agent_result = await self._call_v3_agent(text)
            logger.info(f"[Voice] Agent result ({len(agent_result)} chars): {agent_result[:200]}")
            if self._on_agent_result and agent_result:
                await self._on_agent_result(agent_result)
            # 3. Inject RAG result so the built-in LLM can polish and TTS it
            if agent_result:
                clean_result = self._extract_answer(agent_result)
                logger.info(f"[Voice] Injecting RAG text ({len(clean_result)} chars): {clean_result[:200]}")
                await self.realtime_client.chat_rag_text(clean_result)
            else:
                logger.warning("[Voice] Agent returned empty result, skipping RAG injection")
                self._is_sending_chat_tts_text = False
        except Exception as e:
            logger.error(f"[Voice] Error in ASR text callback: {e}", exc_info=True)
            self._is_sending_chat_tts_text = False
            await self._emit_error(f"Agent call failed: {str(e)}")

    async def _call_v3_agent(self, user_text: str) -> str:
        """Call the v3 agent API internally (stream=False) and return its answer text.

        Returns ``""`` when the agent produced no choices or no content, or
        when any error occurs (the error is logged).
        """
        try:
            from utils.api_models import ChatRequestV3, Message
            from utils.fastapi_utils import (
                process_messages,
                create_project_directory,
            )
            from agent.agent_config import AgentConfig
            from routes.chat import create_agent_and_generate_response

            bot_config = self._bot_config
            language = bot_config.get("language", "zh")
            messages_obj = [Message(role="user", content=user_text)]
            request = ChatRequestV3(
                messages=messages_obj,
                bot_id=self.bot_id,
                stream=False,
                session_id=self.session_id,
                user_identifier=self.user_identifier,
            )
            project_dir = create_project_directory(
                bot_config.get("dataset_ids", []),
                self.bot_id,
                bot_config.get("skills", []),
            )
            processed_messages = process_messages(messages_obj, language)
            config = await AgentConfig.from_v3_request(
                request,
                bot_config,
                project_dir,
                processed_messages,
                language,
            )
            config.stream = False
            result = await create_agent_and_generate_response(config)

            choices = getattr(result, "choices", None)
            if not choices:
                return ""
            # Assumes each choice is a dict like {"message": {"content": ...}}.
            message = choices[0].get("message") or {}
            # Coerce a None content to "" so callers can always use str ops.
            return message.get("content") or ""
        except Exception as e:
            logger.error(f"[Voice] Error calling v3 agent: {e}", exc_info=True)
            return ""

    @staticmethod
    def _extract_answer(agent_result: str) -> str:
        """Extract the answer portion from agent result, stripping tags like [ANSWER], [THINK] etc.

        Lines after a line starting with ``[ANSWER]`` are collected. Any text
        after the marker on that same line is collected too. Collection stops
        at the next line that starts with ``[``. Every such line counts as a
        tag, so an answer line that itself begins with ``[`` ends the section.
        If no answer lines were found, the whole stripped input is returned.
        """
        marker = '[ANSWER]'
        collected = []
        inside = False
        for raw_line in agent_result.split('\n'):
            stripped = raw_line.strip()
            if stripped.startswith(marker):
                inside = True
                remainder = stripped[len(marker):].strip()
                if remainder:
                    collected.append(remainder)
            elif stripped.startswith('['):
                # Any other tag closes the answer section.
                inside = False
            elif inside:
                collected.append(raw_line)
        if collected:
            return '\n'.join(collected).strip()
        return agent_result.strip()

    async def _emit_status(self, status: str) -> None:
        """Send *status* to the ``on_status`` callback, if one was given."""
        if self._on_status:
            await self._on_status(status)

    async def _emit_error(self, message: str) -> None:
        """Send *message* to the ``on_error`` callback, if one was given."""
        if self._on_error:
            await self._on_error(message)