This commit is contained in:
朱潮 2026-03-21 02:41:10 +08:00
parent 16c50fa261
commit 4b70da5bb0
2 changed files with 138 additions and 41 deletions

View File

@@ -287,10 +287,12 @@ async def init_agent(config: AgentConfig):
middleware=middleware,
checkpointer=checkpointer,
shell_env={
k: v for k, v in {
"ASSISTANT_ID": config.bot_id,
"USER_IDENTIFIER": config.user_identifier,
"TRACE_ID": config.trace_id,
**(config.shell_env or {}),
}.items() if v is not None
}
)

View File

@@ -1,14 +1,83 @@
import asyncio
import json
import logging
import re
import uuid
from typing import Optional, Callable, Awaitable
from typing import Optional, Callable, Awaitable, AsyncGenerator
from services.realtime_voice_client import RealtimeDialogClient
logger = logging.getLogger('app')
class _StreamTagFilter:
"""
Filters streaming text based on tag blocks.
Only passes through content inside [ANSWER] blocks.
If no tags are found at all, passes through everything (fallback).
Skips content inside [TOOL_CALL], [TOOL_RESPONSE], [THINK], [SOURCE], etc.
"""
SKIP_TAGS = {"TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE"}
def __init__(self):
self.state = "idle" # idle, answer, skip
self.found_any_tag = False
self._pending = "" # buffer for partial tag like "[TOO..."
def feed(self, chunk: str) -> str:
"""Feed a chunk, return text that should be passed to TTS."""
self._pending += chunk
output = []
while self._pending:
if self.state in ("idle", "answer"):
bracket_pos = self._pending.find("[")
if bracket_pos == -1:
if self.state == "answer" or not self.found_any_tag:
output.append(self._pending)
self._pending = ""
else:
before = self._pending[:bracket_pos]
if before and (self.state == "answer" or not self.found_any_tag):
output.append(before)
close_pos = self._pending.find("]", bracket_pos)
if close_pos == -1:
# Incomplete tag — wait for next chunk
self._pending = self._pending[bracket_pos:]
break
tag_name = self._pending[bracket_pos + 1:close_pos]
self._pending = self._pending[close_pos + 1:]
self.found_any_tag = True
if tag_name == "ANSWER":
self.state = "answer"
else:
self.state = "skip"
elif self.state == "skip":
bracket_pos = self._pending.find("[")
if bracket_pos == -1:
self._pending = ""
else:
close_pos = self._pending.find("]", bracket_pos)
if close_pos == -1:
self._pending = self._pending[bracket_pos:]
break
tag_name = self._pending[bracket_pos + 1:close_pos]
self._pending = self._pending[close_pos + 1:]
if tag_name == "ANSWER":
self.state = "answer"
else:
self.state = "skip"
return "".join(output)
class VoiceSession:
"""Manages a single voice dialogue session lifecycle"""
@@ -197,51 +266,84 @@ class VoiceSession:
logger.error(f"[Voice] Server error: {error_msg}")
await self._emit_error(error_msg)
# Sentence-ending punctuation pattern for splitting TTS
_SENTENCE_END_RE = re.compile(r'[。!?;\n.!?;]')
async def _on_asr_text_received(self, text: str) -> None:
"""Called when ASR text is received — send comfort TTS, call agent, inject RAG"""
"""Called when ASR text is received — stream agent output, send TTS sentence by sentence"""
if not text.strip():
self._is_sending_chat_tts_text = False
return
try:
# 1. Send comfort TTS (real Chinese text, not "...")
# logger.info(f"[Voice] Sending comfort TTS...")
# await self.realtime_client.chat_tts_text(
# content="请稍等,让我查一下。",
# start=True,
# end=False,
# )
# await self.realtime_client.chat_tts_text(
# content="",
# start=False,
# end=True,
# )
# 2. Call v3 agent (this may take a while)
logger.info(f"[Voice] Calling v3 agent with text: '{text}'")
agent_result = await self._call_v3_agent(text)
logger.info(f"[Voice] Agent result ({len(agent_result)} chars): {agent_result[:200]}")
if self._on_agent_result and agent_result:
await self._on_agent_result(agent_result)
accumulated_text = [] # full agent output for on_agent_result callback
sentence_buf = "" # buffer for accumulating until sentence boundary
tts_started = False # whether we've sent the first TTS chunk
tag_filter = _StreamTagFilter()
# 3. Send agent result directly as TTS (bypass LLM)
if agent_result:
clean_result = self._extract_answer(agent_result)
logger.info(f"[Voice] Sending agent result as TTS ({len(clean_result)} chars)")
async for chunk in self._stream_v3_agent(text):
accumulated_text.append(chunk)
if self._on_agent_stream:
await self._on_agent_stream(chunk)
# Filter out [TOOL_CALL], [TOOL_RESPONSE], [THINK] etc., only keep [ANSWER] content
passthrough = tag_filter.feed(chunk)
if not passthrough:
continue
sentence_buf += passthrough
# Check for sentence boundaries and send complete sentences to TTS
while True:
match = self._SENTENCE_END_RE.search(sentence_buf)
if not match:
break
# Split at sentence boundary (include the punctuation)
end_pos = match.end()
sentence = sentence_buf[:end_pos].strip()
sentence_buf = sentence_buf[end_pos:]
if sentence:
logger.info(f"[Voice] Sending TTS sentence: '{sentence[:80]}'")
await self.realtime_client.chat_tts_text(
content=clean_result,
start=True,
content=sentence,
start=not tts_started,
end=False,
)
tts_started = True
# Handle remaining text in buffer (last sentence without ending punctuation)
remaining = sentence_buf.strip()
if remaining:
logger.info(f"[Voice] Sending TTS remaining: '{remaining[:80]}'")
await self.realtime_client.chat_tts_text(
content=remaining,
start=not tts_started,
end=False,
)
tts_started = True
# Send TTS end signal
if tts_started:
await self.realtime_client.chat_tts_text(
content="",
start=False,
end=True,
)
else:
logger.warning(f"[Voice] Agent returned empty result")
logger.warning(f"[Voice] Agent returned no usable text for TTS")
self._is_sending_chat_tts_text = False
# Emit full agent result
full_result = "".join(accumulated_text)
logger.info(f"[Voice] Agent result ({len(full_result)} chars): {full_result[:200]}")
if self._on_agent_result and full_result:
await self._on_agent_result(full_result)
await self._emit_status("idle")
except asyncio.CancelledError:
@@ -253,8 +355,8 @@ class VoiceSession:
self._is_sending_chat_tts_text = False
await self._emit_error(f"Agent call failed: {str(e)}")
async def _call_v3_agent(self, user_text: str) -> str:
"""Call v3 agent API in streaming mode, accumulate text and return full result"""
async def _stream_v3_agent(self, user_text: str) -> AsyncGenerator[str, None]:
"""Call v3 agent API in streaming mode, yield text chunks as they arrive"""
try:
from utils.api_models import ChatRequestV3, Message
from utils.fastapi_utils import (
@@ -294,8 +396,6 @@ class VoiceSession:
)
config.stream = True
# Consume the async generator, parse SSE chunks to accumulate content
accumulated_text = []
async for sse_line in enhanced_generate_stream_response(config):
if not sse_line or not sse_line.startswith("data: "):
continue
@@ -309,20 +409,15 @@ class VoiceSession:
delta = choices[0].get("delta", {})
content = delta.get("content", "")
if content:
accumulated_text.append(content)
if self._on_agent_stream:
await self._on_agent_stream(content)
yield content
except (json.JSONDecodeError, KeyError):
continue
return "".join(accumulated_text)
except asyncio.CancelledError:
logger.info(f"[Voice] v3 agent call cancelled")
raise
except Exception as e:
logger.error(f"[Voice] Error calling v3 agent: {e}", exc_info=True)
return ""
@staticmethod
def _extract_answer(agent_result: str) -> str: