This commit is contained in:
朱潮 2026-03-21 02:41:10 +08:00
parent 16c50fa261
commit 4b70da5bb0
2 changed files with 138 additions and 41 deletions

View File

@ -287,10 +287,12 @@ async def init_agent(config: AgentConfig):
middleware=middleware, middleware=middleware,
checkpointer=checkpointer, checkpointer=checkpointer,
shell_env={ shell_env={
"ASSISTANT_ID": config.bot_id, k: v for k, v in {
"USER_IDENTIFIER": config.user_identifier, "ASSISTANT_ID": config.bot_id,
"TRACE_ID": config.trace_id, "USER_IDENTIFIER": config.user_identifier,
**(config.shell_env or {}), "TRACE_ID": config.trace_id,
**(config.shell_env or {}),
}.items() if v is not None
} }
) )

View File

@ -1,14 +1,83 @@
import asyncio import asyncio
import json import json
import logging import logging
import re
import uuid import uuid
from typing import Optional, Callable, Awaitable from typing import Optional, Callable, Awaitable, AsyncGenerator
from services.realtime_voice_client import RealtimeDialogClient from services.realtime_voice_client import RealtimeDialogClient
logger = logging.getLogger('app') logger = logging.getLogger('app')
class _StreamTagFilter:
"""
Filters streaming text based on tag blocks.
Only passes through content inside [ANSWER] blocks.
If no tags are found at all, passes through everything (fallback).
Skips content inside [TOOL_CALL], [TOOL_RESPONSE], [THINK], [SOURCE], etc.
"""
SKIP_TAGS = {"TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE"}
def __init__(self):
self.state = "idle" # idle, answer, skip
self.found_any_tag = False
self._pending = "" # buffer for partial tag like "[TOO..."
def feed(self, chunk: str) -> str:
"""Feed a chunk, return text that should be passed to TTS."""
self._pending += chunk
output = []
while self._pending:
if self.state in ("idle", "answer"):
bracket_pos = self._pending.find("[")
if bracket_pos == -1:
if self.state == "answer" or not self.found_any_tag:
output.append(self._pending)
self._pending = ""
else:
before = self._pending[:bracket_pos]
if before and (self.state == "answer" or not self.found_any_tag):
output.append(before)
close_pos = self._pending.find("]", bracket_pos)
if close_pos == -1:
# Incomplete tag — wait for next chunk
self._pending = self._pending[bracket_pos:]
break
tag_name = self._pending[bracket_pos + 1:close_pos]
self._pending = self._pending[close_pos + 1:]
self.found_any_tag = True
if tag_name == "ANSWER":
self.state = "answer"
else:
self.state = "skip"
elif self.state == "skip":
bracket_pos = self._pending.find("[")
if bracket_pos == -1:
self._pending = ""
else:
close_pos = self._pending.find("]", bracket_pos)
if close_pos == -1:
self._pending = self._pending[bracket_pos:]
break
tag_name = self._pending[bracket_pos + 1:close_pos]
self._pending = self._pending[close_pos + 1:]
if tag_name == "ANSWER":
self.state = "answer"
else:
self.state = "skip"
return "".join(output)
class VoiceSession: class VoiceSession:
"""Manages a single voice dialogue session lifecycle""" """Manages a single voice dialogue session lifecycle"""
@ -197,51 +266,84 @@ class VoiceSession:
logger.error(f"[Voice] Server error: {error_msg}") logger.error(f"[Voice] Server error: {error_msg}")
await self._emit_error(error_msg) await self._emit_error(error_msg)
# Sentence-ending punctuation pattern for splitting TTS
_SENTENCE_END_RE = re.compile(r'[。!?;\n.!?;]')
async def _on_asr_text_received(self, text: str) -> None: async def _on_asr_text_received(self, text: str) -> None:
"""Called when ASR text is received — send comfort TTS, call agent, inject RAG""" """Called when ASR text is received — stream agent output, send TTS sentence by sentence"""
if not text.strip(): if not text.strip():
self._is_sending_chat_tts_text = False self._is_sending_chat_tts_text = False
return return
try: try:
# 1. Send comfort TTS (real Chinese text, not "...")
# logger.info(f"[Voice] Sending comfort TTS...")
# await self.realtime_client.chat_tts_text(
# content="请稍等,让我查一下。",
# start=True,
# end=False,
# )
# await self.realtime_client.chat_tts_text(
# content="",
# start=False,
# end=True,
# )
# 2. Call v3 agent (this may take a while)
logger.info(f"[Voice] Calling v3 agent with text: '{text}'") logger.info(f"[Voice] Calling v3 agent with text: '{text}'")
agent_result = await self._call_v3_agent(text)
logger.info(f"[Voice] Agent result ({len(agent_result)} chars): {agent_result[:200]}")
if self._on_agent_result and agent_result: accumulated_text = [] # full agent output for on_agent_result callback
await self._on_agent_result(agent_result) sentence_buf = "" # buffer for accumulating until sentence boundary
tts_started = False # whether we've sent the first TTS chunk
tag_filter = _StreamTagFilter()
# 3. Send agent result directly as TTS (bypass LLM) async for chunk in self._stream_v3_agent(text):
if agent_result: accumulated_text.append(chunk)
clean_result = self._extract_answer(agent_result)
logger.info(f"[Voice] Sending agent result as TTS ({len(clean_result)} chars)") if self._on_agent_stream:
await self._on_agent_stream(chunk)
# Filter out [TOOL_CALL], [TOOL_RESPONSE], [THINK] etc., only keep [ANSWER] content
passthrough = tag_filter.feed(chunk)
if not passthrough:
continue
sentence_buf += passthrough
# Check for sentence boundaries and send complete sentences to TTS
while True:
match = self._SENTENCE_END_RE.search(sentence_buf)
if not match:
break
# Split at sentence boundary (include the punctuation)
end_pos = match.end()
sentence = sentence_buf[:end_pos].strip()
sentence_buf = sentence_buf[end_pos:]
if sentence:
logger.info(f"[Voice] Sending TTS sentence: '{sentence[:80]}'")
await self.realtime_client.chat_tts_text(
content=sentence,
start=not tts_started,
end=False,
)
tts_started = True
# Handle remaining text in buffer (last sentence without ending punctuation)
remaining = sentence_buf.strip()
if remaining:
logger.info(f"[Voice] Sending TTS remaining: '{remaining[:80]}'")
await self.realtime_client.chat_tts_text( await self.realtime_client.chat_tts_text(
content=clean_result, content=remaining,
start=True, start=not tts_started,
end=False, end=False,
) )
tts_started = True
# Send TTS end signal
if tts_started:
await self.realtime_client.chat_tts_text( await self.realtime_client.chat_tts_text(
content="", content="",
start=False, start=False,
end=True, end=True,
) )
else: else:
logger.warning(f"[Voice] Agent returned empty result") logger.warning(f"[Voice] Agent returned no usable text for TTS")
self._is_sending_chat_tts_text = False self._is_sending_chat_tts_text = False
# Emit full agent result
full_result = "".join(accumulated_text)
logger.info(f"[Voice] Agent result ({len(full_result)} chars): {full_result[:200]}")
if self._on_agent_result and full_result:
await self._on_agent_result(full_result)
await self._emit_status("idle") await self._emit_status("idle")
except asyncio.CancelledError: except asyncio.CancelledError:
@ -253,8 +355,8 @@ class VoiceSession:
self._is_sending_chat_tts_text = False self._is_sending_chat_tts_text = False
await self._emit_error(f"Agent call failed: {str(e)}") await self._emit_error(f"Agent call failed: {str(e)}")
async def _call_v3_agent(self, user_text: str) -> str: async def _stream_v3_agent(self, user_text: str) -> AsyncGenerator[str, None]:
"""Call v3 agent API in streaming mode, accumulate text and return full result""" """Call v3 agent API in streaming mode, yield text chunks as they arrive"""
try: try:
from utils.api_models import ChatRequestV3, Message from utils.api_models import ChatRequestV3, Message
from utils.fastapi_utils import ( from utils.fastapi_utils import (
@ -294,8 +396,6 @@ class VoiceSession:
) )
config.stream = True config.stream = True
# Consume the async generator, parse SSE chunks to accumulate content
accumulated_text = []
async for sse_line in enhanced_generate_stream_response(config): async for sse_line in enhanced_generate_stream_response(config):
if not sse_line or not sse_line.startswith("data: "): if not sse_line or not sse_line.startswith("data: "):
continue continue
@ -309,20 +409,15 @@ class VoiceSession:
delta = choices[0].get("delta", {}) delta = choices[0].get("delta", {})
content = delta.get("content", "") content = delta.get("content", "")
if content: if content:
accumulated_text.append(content) yield content
if self._on_agent_stream:
await self._on_agent_stream(content)
except (json.JSONDecodeError, KeyError): except (json.JSONDecodeError, KeyError):
continue continue
return "".join(accumulated_text)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info(f"[Voice] v3 agent call cancelled") logger.info(f"[Voice] v3 agent call cancelled")
raise raise
except Exception as e: except Exception as e:
logger.error(f"[Voice] Error calling v3 agent: {e}", exc_info=True) logger.error(f"[Voice] Error calling v3 agent: {e}", exc_info=True)
return ""
@staticmethod @staticmethod
def _extract_answer(agent_result: str) -> str: def _extract_answer(agent_result: str) -> str: