528 lines
22 KiB
Python
528 lines
22 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import re
|
|
import uuid
|
|
from typing import Optional, Callable, Awaitable, AsyncGenerator
|
|
|
|
from services.realtime_voice_client import RealtimeDialogClient
|
|
|
|
logger = logging.getLogger('app')
|
|
|
|
|
|
class _StreamTagFilter:
|
|
"""
|
|
Filters streaming text based on tag blocks.
|
|
Only passes through content inside [ANSWER] blocks.
|
|
If no tags are found at all, passes through everything (fallback).
|
|
Skips content inside [TOOL_CALL], [TOOL_RESPONSE], [THINK], [SOURCE], etc.
|
|
"""
|
|
|
|
SKIP_TAGS = {"TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE"}
|
|
KNOWN_TAGS = {"ANSWER", "TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE", "PREAMBLE", "SUMMARY"}
|
|
|
|
def __init__(self):
|
|
self.state = "idle" # idle, answer, skip
|
|
self.found_any_tag = False
|
|
self._pending = "" # buffer for partial tag like "[TOO..."
|
|
|
|
def feed(self, chunk: str) -> str:
|
|
"""Feed a chunk, return text that should be passed to TTS."""
|
|
self._pending += chunk
|
|
output = []
|
|
|
|
while self._pending:
|
|
if self.state in ("idle", "answer"):
|
|
bracket_pos = self._pending.find("[")
|
|
if bracket_pos == -1:
|
|
if self.state == "answer" or not self.found_any_tag:
|
|
output.append(self._pending)
|
|
self._pending = ""
|
|
else:
|
|
before = self._pending[:bracket_pos]
|
|
if before and (self.state == "answer" or not self.found_any_tag):
|
|
output.append(before)
|
|
|
|
close_pos = self._pending.find("]", bracket_pos)
|
|
if close_pos == -1:
|
|
# Incomplete tag — wait for next chunk
|
|
self._pending = self._pending[bracket_pos:]
|
|
break
|
|
|
|
tag_name = self._pending[bracket_pos + 1:close_pos]
|
|
self._pending = self._pending[close_pos + 1:]
|
|
|
|
if tag_name not in self.KNOWN_TAGS:
|
|
if self.state == "answer" or not self.found_any_tag:
|
|
output.append(f"[{tag_name}]")
|
|
continue
|
|
|
|
self.found_any_tag = True
|
|
if tag_name == "ANSWER":
|
|
self.state = "answer"
|
|
else:
|
|
self.state = "skip"
|
|
|
|
elif self.state == "skip":
|
|
bracket_pos = self._pending.find("[")
|
|
if bracket_pos == -1:
|
|
self._pending = ""
|
|
else:
|
|
close_pos = self._pending.find("]", bracket_pos)
|
|
if close_pos == -1:
|
|
self._pending = self._pending[bracket_pos:]
|
|
break
|
|
|
|
tag_name = self._pending[bracket_pos + 1:close_pos]
|
|
self._pending = self._pending[close_pos + 1:]
|
|
|
|
if tag_name not in self.KNOWN_TAGS:
|
|
continue
|
|
|
|
if tag_name == "ANSWER":
|
|
self.state = "answer"
|
|
else:
|
|
self.state = "skip"
|
|
|
|
return "".join(output)
|
|
|
|
|
|
class VoiceSession:
|
|
"""Manages a single voice dialogue session lifecycle"""
|
|
|
|
def __init__(
|
|
self,
|
|
bot_id: str,
|
|
session_id: Optional[str] = None,
|
|
user_identifier: Optional[str] = None,
|
|
on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None,
|
|
on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None,
|
|
on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None,
|
|
on_agent_stream: Optional[Callable[[str], Awaitable[None]]] = None,
|
|
on_llm_text: Optional[Callable[[str], Awaitable[None]]] = None,
|
|
on_status: Optional[Callable[[str], Awaitable[None]]] = None,
|
|
on_error: Optional[Callable[[str], Awaitable[None]]] = None,
|
|
):
|
|
self.bot_id = bot_id
|
|
self.session_id = session_id or str(uuid.uuid4())
|
|
self.user_identifier = user_identifier or ""
|
|
|
|
self.realtime_client: Optional[RealtimeDialogClient] = None
|
|
self._bot_config: dict = {}
|
|
|
|
# Callbacks
|
|
self._on_audio = on_audio
|
|
self._on_asr_text = on_asr_text
|
|
self._on_agent_result = on_agent_result
|
|
self._on_agent_stream = on_agent_stream
|
|
self._on_llm_text = on_llm_text
|
|
self._on_status = on_status
|
|
self._on_error = on_error
|
|
|
|
self._running = False
|
|
self._is_user_querying = False
|
|
self._current_asr_text = ""
|
|
# When True, discard TTS audio from SERVER_ACK (comfort speech period)
|
|
self._is_sending_chat_tts_text = False
|
|
# Set to True when event 350 fires for chat_tts_text, indicating the TTS segment is done
|
|
# and next TTS send must use start=True to begin a new session
|
|
self._tts_segment_done = False
|
|
# Signaled when event 359 fires (TTS fully completed), used to wait before starting new TTS
|
|
self._tts_complete_event: asyncio.Event = asyncio.Event()
|
|
self._receive_task: Optional[asyncio.Task] = None
|
|
self._agent_task: Optional[asyncio.Task] = None
|
|
|
|
async def start(self) -> None:
|
|
"""Fetch bot config, connect to Volcengine and start receiving responses"""
|
|
from utils.fastapi_utils import fetch_bot_config_from_db
|
|
|
|
self._bot_config = await fetch_bot_config_from_db(self.bot_id, self.user_identifier)
|
|
|
|
self.realtime_client = RealtimeDialogClient(
|
|
session_id=self.session_id,
|
|
speaker=self._bot_config.get("voice_speaker"),
|
|
system_role=self._bot_config.get("voice_system_role"),
|
|
speaking_style=self._bot_config.get("voice_speaking_style"),
|
|
bot_name=self._bot_config.get("name", ""),
|
|
)
|
|
|
|
await self.realtime_client.connect()
|
|
self._running = True
|
|
self._receive_task = asyncio.create_task(self._receive_loop())
|
|
await self._emit_status("ready")
|
|
|
|
async def stop(self) -> None:
|
|
"""Gracefully stop the session"""
|
|
self._running = False
|
|
try:
|
|
await self.realtime_client.finish_session()
|
|
await asyncio.sleep(0.5)
|
|
await self.realtime_client.finish_connection()
|
|
except Exception as e:
|
|
logger.warning(f"Error during session cleanup: {e}")
|
|
finally:
|
|
if self._agent_task and not self._agent_task.done():
|
|
from utils.cancel_manager import trigger_cancel
|
|
trigger_cancel(self.session_id)
|
|
self._agent_task.cancel()
|
|
if self._receive_task and not self._receive_task.done():
|
|
self._receive_task.cancel()
|
|
await self.realtime_client.close()
|
|
|
|
async def handle_audio(self, audio_data: bytes) -> None:
|
|
"""Forward user audio to Volcengine"""
|
|
if self._running and self.realtime_client.ws:
|
|
await self.realtime_client.send_audio(audio_data)
|
|
|
|
async def handle_text(self, text: str) -> None:
|
|
"""Handle text input - send as text query"""
|
|
if self._running and self.realtime_client.ws:
|
|
await self.realtime_client.chat_text_query(text)
|
|
|
|
async def _receive_loop(self) -> None:
|
|
"""Continuously receive and dispatch Volcengine responses"""
|
|
try:
|
|
while self._running:
|
|
response = await self.realtime_client.receive_response()
|
|
if not response:
|
|
continue
|
|
await self._handle_response(response)
|
|
except asyncio.CancelledError:
|
|
logger.info(f"Voice session receive loop cancelled: {self.session_id}")
|
|
except Exception as e:
|
|
logger.error(f"Voice session receive loop error: {e}")
|
|
await self._emit_error(f"Connection error: {str(e)}")
|
|
finally:
|
|
self._running = False
|
|
|
|
async def _handle_response(self, response: dict) -> None:
|
|
msg_type = response.get('message_type', '')
|
|
event = response.get('event')
|
|
payload_msg = response.get('payload_msg', {})
|
|
|
|
if msg_type == 'SERVER_ACK' and isinstance(payload_msg, bytes):
|
|
# TTS audio data — discard during comfort speech period
|
|
if self._is_sending_chat_tts_text:
|
|
return
|
|
if self._on_audio:
|
|
await self._on_audio(payload_msg)
|
|
|
|
elif msg_type == 'SERVER_FULL_RESPONSE':
|
|
logger.info(f"[Voice] event={event}, payload_msg={payload_msg if not isinstance(payload_msg, bytes) else f'<{len(payload_msg)} bytes>'}")
|
|
|
|
if event == 450:
|
|
# User started speaking — clear audio, set querying flag, reset ASR accumulator
|
|
self._is_user_querying = True
|
|
self._current_asr_text = ""
|
|
await self._emit_status("listening")
|
|
|
|
elif event == 451:
|
|
# Streaming ASR result — accumulate recognized text
|
|
if isinstance(payload_msg, dict):
|
|
results = payload_msg.get("results", [])
|
|
if results and isinstance(results, list) and len(results) > 0:
|
|
text = results[0].get("text", "")
|
|
if text:
|
|
self._current_asr_text = text
|
|
logger.debug(f"[Voice] ASR streaming (451): '{text}'")
|
|
|
|
elif event == 459:
|
|
# ASR completed — use accumulated text from event 451
|
|
self._is_user_querying = False
|
|
asr_text = self._current_asr_text
|
|
|
|
# Filter out ASR during thinking/speaking (TTS echo protection)
|
|
if self._is_sending_chat_tts_text:
|
|
logger.info(f"[Voice] Discarding ASR during thinking/speaking: '{asr_text}'")
|
|
return
|
|
|
|
logger.info(f"[Voice] ASR result: '{asr_text}'")
|
|
|
|
if self._on_asr_text and asr_text:
|
|
await self._on_asr_text(asr_text)
|
|
await self._emit_status("thinking")
|
|
|
|
# Cancel previous agent task if still running
|
|
if self._agent_task and not self._agent_task.done():
|
|
logger.info(f"[Voice] Interrupting previous agent task")
|
|
from utils.cancel_manager import trigger_cancel
|
|
trigger_cancel(self.session_id)
|
|
self._agent_task.cancel()
|
|
self._agent_task = None
|
|
|
|
# Trigger comfort TTS + agent call
|
|
self._is_sending_chat_tts_text = True
|
|
self._agent_task = asyncio.create_task(self._on_asr_text_received(asr_text))
|
|
|
|
elif event == 350:
|
|
# TTS segment completed
|
|
tts_type = ""
|
|
if isinstance(payload_msg, dict):
|
|
tts_type = payload_msg.get("tts_type", "")
|
|
logger.info(f"[Voice] TTS segment done, type={tts_type}, is_sending={self._is_sending_chat_tts_text}")
|
|
|
|
# When comfort TTS or RAG TTS finishes, stop discarding audio
|
|
if self._is_sending_chat_tts_text and tts_type == "chat_tts_text":
|
|
self._is_sending_chat_tts_text = False
|
|
logger.info(f"[Voice] Comfort/RAG TTS done, resuming audio forwarding")
|
|
|
|
# Mark TTS segment as done so next send uses start=True
|
|
if tts_type == "chat_tts_text":
|
|
self._tts_segment_done = True
|
|
|
|
elif event == 359:
|
|
# TTS fully completed (all segments done)
|
|
logger.info(f"[Voice] TTS fully completed")
|
|
self._tts_complete_event.set()
|
|
|
|
elif event in (152, 153):
|
|
logger.info(f"[Voice] Session finished event: {event}")
|
|
self._running = False
|
|
|
|
elif msg_type == 'SERVER_ERROR':
|
|
error_msg = str(payload_msg) if payload_msg else "Unknown server error"
|
|
logger.error(f"[Voice] Server error: {error_msg}")
|
|
await self._emit_error(error_msg)
|
|
|
|
# Sentence-ending punctuation pattern for splitting TTS
|
|
_SENTENCE_END_RE = re.compile(r'[。!?;\n.!?;]')
|
|
# Markdown syntax to strip before TTS
|
|
_MD_CLEAN_RE = re.compile(r'#{1,6}\s*|(?<!\w)\*{1,3}|(?<!\w)_{1,3}|\*{1,3}(?!\w)|_{1,3}(?!\w)|~~|`{1,3}|^>\s*|^\s*[-*+]\s+|^\s*\d+\.\s+|\[([^\]]*)\]\([^)]*\)|!\[([^\]]*)\]\([^)]*\)', re.MULTILINE)
|
|
|
|
@staticmethod
|
|
def _clean_markdown(text: str) -> str:
|
|
"""Strip Markdown formatting characters for TTS readability."""
|
|
# Replace links/images with their display text
|
|
text = re.sub(r'!\[([^\]]*)\]\([^)]*\)', r'\1', text)
|
|
text = re.sub(r'\[([^\]]*)\]\([^)]*\)', r'\1', text)
|
|
# Remove headings, bold, italic, strikethrough, code marks, blockquote
|
|
text = re.sub(r'#{1,6}\s*', '', text)
|
|
text = re.sub(r'\*{1,3}|_{1,3}|~~|`{1,3}', '', text)
|
|
text = re.sub(r'^>\s*', '', text, flags=re.MULTILINE)
|
|
# Remove list markers
|
|
text = re.sub(r'^\s*[-*+]\s+', '', text, flags=re.MULTILINE)
|
|
text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
|
|
# Remove horizontal rules
|
|
text = re.sub(r'^[\s]*[-*_]{3,}[\s]*$', '', text, flags=re.MULTILINE)
|
|
# Collapse extra whitespace
|
|
text = re.sub(r'\n{2,}', '\n', text)
|
|
return text.strip()
|
|
|
|
async def _on_asr_text_received(self, text: str) -> None:
|
|
"""Called when ASR text is received — stream agent output, send TTS sentence by sentence"""
|
|
if not text.strip():
|
|
self._is_sending_chat_tts_text = False
|
|
return
|
|
|
|
try:
|
|
logger.info(f"[Voice] Calling v3 agent with text: '{text}'")
|
|
|
|
accumulated_text = [] # full agent output for on_agent_result callback
|
|
sentence_buf = "" # buffer for accumulating until sentence boundary
|
|
tts_started = False # whether we've sent the first TTS chunk
|
|
tag_filter = _StreamTagFilter()
|
|
|
|
async for chunk in self._stream_v3_agent(text):
|
|
accumulated_text.append(chunk)
|
|
|
|
if self._on_agent_stream:
|
|
await self._on_agent_stream(chunk)
|
|
|
|
# Filter out [TOOL_CALL], [TOOL_RESPONSE], [THINK] etc., only keep [ANSWER] content
|
|
passthrough = tag_filter.feed(chunk)
|
|
|
|
if not passthrough:
|
|
continue
|
|
|
|
sentence_buf += passthrough
|
|
|
|
# Check for sentence boundaries and send complete sentences to TTS
|
|
while True:
|
|
match = self._SENTENCE_END_RE.search(sentence_buf)
|
|
if not match:
|
|
break
|
|
# Split at sentence boundary (include the punctuation)
|
|
end_pos = match.end()
|
|
sentence = sentence_buf[:end_pos].strip()
|
|
sentence_buf = sentence_buf[end_pos:]
|
|
|
|
if sentence:
|
|
sentence = self._clean_markdown(sentence)
|
|
if sentence:
|
|
# If previous TTS segment completed (e.g. gap during tool call),
|
|
# close old session, wait for TTS delivery to finish, then restart
|
|
if tts_started and self._tts_segment_done:
|
|
logger.info(f"[Voice] TTS segment done, closing session and waiting for delivery")
|
|
await self.realtime_client.chat_tts_text(content="", start=False, end=True)
|
|
self._tts_complete_event.clear()
|
|
try:
|
|
await asyncio.wait_for(self._tts_complete_event.wait(), timeout=10)
|
|
except asyncio.TimeoutError:
|
|
logger.warning(f"[Voice] Timeout waiting for TTS complete, proceeding anyway")
|
|
tts_started = False
|
|
self._tts_segment_done = False
|
|
logger.info(f"[Voice] TTS delivery done, starting new session")
|
|
|
|
logger.info(f"[Voice] Sending TTS sentence: '{sentence[:80]}'")
|
|
await self.realtime_client.chat_tts_text(
|
|
content=sentence,
|
|
start=not tts_started,
|
|
end=False,
|
|
)
|
|
if not tts_started:
|
|
await self._emit_status("speaking")
|
|
tts_started = True
|
|
self._tts_segment_done = False
|
|
|
|
# Handle remaining text in buffer (last sentence without ending punctuation)
|
|
remaining = sentence_buf.strip()
|
|
if remaining:
|
|
remaining = self._clean_markdown(remaining)
|
|
if remaining:
|
|
# If previous TTS segment completed, close and wait before restart
|
|
if tts_started and self._tts_segment_done:
|
|
logger.info(f"[Voice] TTS segment done, closing session and waiting for delivery (remaining)")
|
|
await self.realtime_client.chat_tts_text(content="", start=False, end=True)
|
|
self._tts_complete_event.clear()
|
|
try:
|
|
await asyncio.wait_for(self._tts_complete_event.wait(), timeout=10)
|
|
except asyncio.TimeoutError:
|
|
logger.warning(f"[Voice] Timeout waiting for TTS complete, proceeding anyway")
|
|
tts_started = False
|
|
self._tts_segment_done = False
|
|
logger.info(f"[Voice] TTS delivery done, starting new session for remaining")
|
|
|
|
logger.info(f"[Voice] Sending TTS remaining: '{remaining[:80]}'")
|
|
await self.realtime_client.chat_tts_text(
|
|
content=remaining,
|
|
start=not tts_started,
|
|
end=False,
|
|
)
|
|
if not tts_started:
|
|
await self._emit_status("speaking")
|
|
tts_started = True
|
|
self._tts_segment_done = False
|
|
|
|
# Send TTS end signal
|
|
if tts_started:
|
|
await self.realtime_client.chat_tts_text(
|
|
content="",
|
|
start=False,
|
|
end=True,
|
|
)
|
|
else:
|
|
logger.warning(f"[Voice] Agent returned no usable text for TTS")
|
|
self._is_sending_chat_tts_text = False
|
|
|
|
# Emit full agent result
|
|
full_result = "".join(accumulated_text)
|
|
logger.info(f"[Voice] Agent result ({len(full_result)} chars): {full_result[:200]}")
|
|
if self._on_agent_result and full_result:
|
|
await self._on_agent_result(full_result)
|
|
|
|
await self._emit_status("idle")
|
|
|
|
except asyncio.CancelledError:
|
|
logger.info(f"[Voice] Agent task cancelled (user interrupted)")
|
|
self._is_sending_chat_tts_text = False
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"[Voice] Error in ASR text callback: {e}", exc_info=True)
|
|
self._is_sending_chat_tts_text = False
|
|
await self._emit_error(f"Agent call failed: {str(e)}")
|
|
|
|
async def _stream_v3_agent(self, user_text: str) -> AsyncGenerator[str, None]:
|
|
"""Call v3 agent API in streaming mode, yield text chunks as they arrive"""
|
|
try:
|
|
from utils.api_models import ChatRequestV3, Message
|
|
from utils.fastapi_utils import (
|
|
process_messages,
|
|
create_project_directory,
|
|
)
|
|
from agent.agent_config import AgentConfig
|
|
from routes.chat import enhanced_generate_stream_response
|
|
|
|
bot_config = self._bot_config
|
|
language = bot_config.get("language", "zh")
|
|
|
|
messages_obj = [Message(role="user", content=user_text)]
|
|
|
|
request = ChatRequestV3(
|
|
messages=messages_obj,
|
|
bot_id=self.bot_id,
|
|
stream=True,
|
|
session_id=self.session_id,
|
|
user_identifier=self.user_identifier,
|
|
)
|
|
|
|
project_dir = create_project_directory(
|
|
bot_config.get("dataset_ids", []),
|
|
self.bot_id,
|
|
bot_config.get("skills", []),
|
|
)
|
|
|
|
processed_messages = process_messages(messages_obj, language)
|
|
|
|
config = await AgentConfig.from_v3_request(
|
|
request,
|
|
bot_config,
|
|
project_dir,
|
|
processed_messages,
|
|
language,
|
|
)
|
|
config.stream = True
|
|
|
|
async for sse_line in enhanced_generate_stream_response(config):
|
|
if not sse_line or not sse_line.startswith("data: "):
|
|
continue
|
|
data_str = sse_line.strip().removeprefix("data: ")
|
|
if data_str == "[DONE]":
|
|
break
|
|
try:
|
|
data = json.loads(data_str)
|
|
choices = data.get("choices", [])
|
|
if choices:
|
|
delta = choices[0].get("delta", {})
|
|
content = delta.get("content", "")
|
|
if content:
|
|
yield content
|
|
except (json.JSONDecodeError, KeyError):
|
|
continue
|
|
|
|
except asyncio.CancelledError:
|
|
logger.info(f"[Voice] v3 agent call cancelled")
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"[Voice] Error calling v3 agent: {e}", exc_info=True)
|
|
|
|
@staticmethod
|
|
def _extract_answer(agent_result: str) -> str:
|
|
"""Extract the answer portion from agent result, stripping tags like [ANSWER], [THINK] etc."""
|
|
lines = agent_result.split('\n')
|
|
answer_lines = []
|
|
in_answer = False
|
|
for line in lines:
|
|
if line.strip().startswith('[ANSWER]'):
|
|
in_answer = True
|
|
rest = line.strip()[len('[ANSWER]'):].strip()
|
|
if rest:
|
|
answer_lines.append(rest)
|
|
continue
|
|
if line.strip().startswith('[') and not line.strip().startswith('[ANSWER]'):
|
|
in_answer = False
|
|
continue
|
|
if in_answer:
|
|
answer_lines.append(line)
|
|
|
|
if answer_lines:
|
|
return '\n'.join(answer_lines).strip()
|
|
return agent_result.strip()
|
|
|
|
async def _emit_status(self, status: str) -> None:
|
|
if self._on_status:
|
|
await self._on_status(status)
|
|
|
|
async def _emit_error(self, message: str) -> None:
|
|
if self._on_error:
|
|
await self._on_error(message)
|