语音
This commit is contained in:
parent
16c50fa261
commit
4b70da5bb0
@ -287,10 +287,12 @@ async def init_agent(config: AgentConfig):
|
||||
middleware=middleware,
|
||||
checkpointer=checkpointer,
|
||||
shell_env={
|
||||
"ASSISTANT_ID": config.bot_id,
|
||||
"USER_IDENTIFIER": config.user_identifier,
|
||||
"TRACE_ID": config.trace_id,
|
||||
**(config.shell_env or {}),
|
||||
k: v for k, v in {
|
||||
"ASSISTANT_ID": config.bot_id,
|
||||
"USER_IDENTIFIER": config.user_identifier,
|
||||
"TRACE_ID": config.trace_id,
|
||||
**(config.shell_env or {}),
|
||||
}.items() if v is not None
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@ -1,14 +1,83 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Optional, Callable, Awaitable
|
||||
from typing import Optional, Callable, Awaitable, AsyncGenerator
|
||||
|
||||
from services.realtime_voice_client import RealtimeDialogClient
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
|
||||
class _StreamTagFilter:
|
||||
"""
|
||||
Filters streaming text based on tag blocks.
|
||||
Only passes through content inside [ANSWER] blocks.
|
||||
If no tags are found at all, passes through everything (fallback).
|
||||
Skips content inside [TOOL_CALL], [TOOL_RESPONSE], [THINK], [SOURCE], etc.
|
||||
"""
|
||||
|
||||
SKIP_TAGS = {"TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE"}
|
||||
|
||||
def __init__(self):
|
||||
self.state = "idle" # idle, answer, skip
|
||||
self.found_any_tag = False
|
||||
self._pending = "" # buffer for partial tag like "[TOO..."
|
||||
|
||||
def feed(self, chunk: str) -> str:
|
||||
"""Feed a chunk, return text that should be passed to TTS."""
|
||||
self._pending += chunk
|
||||
output = []
|
||||
|
||||
while self._pending:
|
||||
if self.state in ("idle", "answer"):
|
||||
bracket_pos = self._pending.find("[")
|
||||
if bracket_pos == -1:
|
||||
if self.state == "answer" or not self.found_any_tag:
|
||||
output.append(self._pending)
|
||||
self._pending = ""
|
||||
else:
|
||||
before = self._pending[:bracket_pos]
|
||||
if before and (self.state == "answer" or not self.found_any_tag):
|
||||
output.append(before)
|
||||
|
||||
close_pos = self._pending.find("]", bracket_pos)
|
||||
if close_pos == -1:
|
||||
# Incomplete tag — wait for next chunk
|
||||
self._pending = self._pending[bracket_pos:]
|
||||
break
|
||||
|
||||
tag_name = self._pending[bracket_pos + 1:close_pos]
|
||||
self._pending = self._pending[close_pos + 1:]
|
||||
self.found_any_tag = True
|
||||
|
||||
if tag_name == "ANSWER":
|
||||
self.state = "answer"
|
||||
else:
|
||||
self.state = "skip"
|
||||
|
||||
elif self.state == "skip":
|
||||
bracket_pos = self._pending.find("[")
|
||||
if bracket_pos == -1:
|
||||
self._pending = ""
|
||||
else:
|
||||
close_pos = self._pending.find("]", bracket_pos)
|
||||
if close_pos == -1:
|
||||
self._pending = self._pending[bracket_pos:]
|
||||
break
|
||||
|
||||
tag_name = self._pending[bracket_pos + 1:close_pos]
|
||||
self._pending = self._pending[close_pos + 1:]
|
||||
|
||||
if tag_name == "ANSWER":
|
||||
self.state = "answer"
|
||||
else:
|
||||
self.state = "skip"
|
||||
|
||||
return "".join(output)
|
||||
|
||||
|
||||
class VoiceSession:
|
||||
"""Manages a single voice dialogue session lifecycle"""
|
||||
|
||||
@ -197,51 +266,84 @@ class VoiceSession:
|
||||
logger.error(f"[Voice] Server error: {error_msg}")
|
||||
await self._emit_error(error_msg)
|
||||
|
||||
# Sentence-ending punctuation pattern for splitting TTS
|
||||
_SENTENCE_END_RE = re.compile(r'[。!?;\n.!?;]')
|
||||
|
||||
async def _on_asr_text_received(self, text: str) -> None:
|
||||
"""Called when ASR text is received — send comfort TTS, call agent, inject RAG"""
|
||||
"""Called when ASR text is received — stream agent output, send TTS sentence by sentence"""
|
||||
if not text.strip():
|
||||
self._is_sending_chat_tts_text = False
|
||||
return
|
||||
|
||||
try:
|
||||
# 1. Send comfort TTS (real Chinese text, not "...")
|
||||
# logger.info(f"[Voice] Sending comfort TTS...")
|
||||
# await self.realtime_client.chat_tts_text(
|
||||
# content="请稍等,让我查一下。",
|
||||
# start=True,
|
||||
# end=False,
|
||||
# )
|
||||
# await self.realtime_client.chat_tts_text(
|
||||
# content="",
|
||||
# start=False,
|
||||
# end=True,
|
||||
# )
|
||||
|
||||
# 2. Call v3 agent (this may take a while)
|
||||
logger.info(f"[Voice] Calling v3 agent with text: '{text}'")
|
||||
agent_result = await self._call_v3_agent(text)
|
||||
logger.info(f"[Voice] Agent result ({len(agent_result)} chars): {agent_result[:200]}")
|
||||
|
||||
if self._on_agent_result and agent_result:
|
||||
await self._on_agent_result(agent_result)
|
||||
accumulated_text = [] # full agent output for on_agent_result callback
|
||||
sentence_buf = "" # buffer for accumulating until sentence boundary
|
||||
tts_started = False # whether we've sent the first TTS chunk
|
||||
tag_filter = _StreamTagFilter()
|
||||
|
||||
# 3. Send agent result directly as TTS (bypass LLM)
|
||||
if agent_result:
|
||||
clean_result = self._extract_answer(agent_result)
|
||||
logger.info(f"[Voice] Sending agent result as TTS ({len(clean_result)} chars)")
|
||||
async for chunk in self._stream_v3_agent(text):
|
||||
accumulated_text.append(chunk)
|
||||
|
||||
if self._on_agent_stream:
|
||||
await self._on_agent_stream(chunk)
|
||||
|
||||
# Filter out [TOOL_CALL], [TOOL_RESPONSE], [THINK] etc., only keep [ANSWER] content
|
||||
passthrough = tag_filter.feed(chunk)
|
||||
|
||||
if not passthrough:
|
||||
continue
|
||||
|
||||
sentence_buf += passthrough
|
||||
|
||||
# Check for sentence boundaries and send complete sentences to TTS
|
||||
while True:
|
||||
match = self._SENTENCE_END_RE.search(sentence_buf)
|
||||
if not match:
|
||||
break
|
||||
# Split at sentence boundary (include the punctuation)
|
||||
end_pos = match.end()
|
||||
sentence = sentence_buf[:end_pos].strip()
|
||||
sentence_buf = sentence_buf[end_pos:]
|
||||
|
||||
if sentence:
|
||||
logger.info(f"[Voice] Sending TTS sentence: '{sentence[:80]}'")
|
||||
await self.realtime_client.chat_tts_text(
|
||||
content=sentence,
|
||||
start=not tts_started,
|
||||
end=False,
|
||||
)
|
||||
tts_started = True
|
||||
|
||||
# Handle remaining text in buffer (last sentence without ending punctuation)
|
||||
remaining = sentence_buf.strip()
|
||||
if remaining:
|
||||
logger.info(f"[Voice] Sending TTS remaining: '{remaining[:80]}'")
|
||||
await self.realtime_client.chat_tts_text(
|
||||
content=clean_result,
|
||||
start=True,
|
||||
content=remaining,
|
||||
start=not tts_started,
|
||||
end=False,
|
||||
)
|
||||
tts_started = True
|
||||
|
||||
# Send TTS end signal
|
||||
if tts_started:
|
||||
await self.realtime_client.chat_tts_text(
|
||||
content="",
|
||||
start=False,
|
||||
end=True,
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Voice] Agent returned empty result")
|
||||
logger.warning(f"[Voice] Agent returned no usable text for TTS")
|
||||
self._is_sending_chat_tts_text = False
|
||||
|
||||
# Emit full agent result
|
||||
full_result = "".join(accumulated_text)
|
||||
logger.info(f"[Voice] Agent result ({len(full_result)} chars): {full_result[:200]}")
|
||||
if self._on_agent_result and full_result:
|
||||
await self._on_agent_result(full_result)
|
||||
|
||||
await self._emit_status("idle")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
@ -253,8 +355,8 @@ class VoiceSession:
|
||||
self._is_sending_chat_tts_text = False
|
||||
await self._emit_error(f"Agent call failed: {str(e)}")
|
||||
|
||||
async def _call_v3_agent(self, user_text: str) -> str:
|
||||
"""Call v3 agent API in streaming mode, accumulate text and return full result"""
|
||||
async def _stream_v3_agent(self, user_text: str) -> AsyncGenerator[str, None]:
|
||||
"""Call v3 agent API in streaming mode, yield text chunks as they arrive"""
|
||||
try:
|
||||
from utils.api_models import ChatRequestV3, Message
|
||||
from utils.fastapi_utils import (
|
||||
@ -294,8 +396,6 @@ class VoiceSession:
|
||||
)
|
||||
config.stream = True
|
||||
|
||||
# Consume the async generator, parse SSE chunks to accumulate content
|
||||
accumulated_text = []
|
||||
async for sse_line in enhanced_generate_stream_response(config):
|
||||
if not sse_line or not sse_line.startswith("data: "):
|
||||
continue
|
||||
@ -309,20 +409,15 @@ class VoiceSession:
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
accumulated_text.append(content)
|
||||
if self._on_agent_stream:
|
||||
await self._on_agent_stream(content)
|
||||
yield content
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
|
||||
return "".join(accumulated_text)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[Voice] v3 agent call cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Voice] Error calling v3 agent: {e}", exc_info=True)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _extract_answer(agent_result: str) -> str:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user