import json
import re
import logging
from typing import Optional, AsyncGenerator
# App-wide logger; uses the fixed name 'app' (not __name__), so records merge
# with the rest of the application's 'app' logger configuration.
logger = logging.getLogger('app')

# Sentence-ending punctuation, both full-width (CJK) and ASCII, plus newline.
# NOTE(review): unused within this chunk — presumably used elsewhere in the
# file to split streamed text into TTS-sized sentences; verify against callers.
SENTENCE_END_RE = re.compile(r'[。!?;\n.!?;]')
class StreamTagFilter:
    """
    Filters streaming text based on tag blocks.

    Only passes through content inside [ANSWER] blocks.
    If no tags are found at all, passes through everything (fallback).
    Skips content inside [TOOL_CALL], [TOOL_RESPONSE], [THINK], [SOURCE], etc.
    Bracketed text that is not a known tag (e.g. a markdown link "[x](u)")
    is treated as literal text and re-emitted when passing through.

    Bug fix vs. previous version: a trailing unmatched "[" (e.g. the stream
    ends with "hi [trunc") was buffered forever and silently dropped.  Call
    flush() after the last feed() to drain that tail.
    """

    # Tags whose content must never reach TTS.  NOTE(review): currently any
    # known non-ANSWER tag causes skipping, so this set is informational.
    SKIP_TAGS = {"TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE"}
    # Full set of tags the model is known to emit; anything else in brackets
    # is treated as plain text.
    KNOWN_TAGS = {"ANSWER", "TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE", "PREAMBLE", "SUMMARY"}

    def __init__(self):
        self.state = "idle"          # one of: idle, answer, skip
        self.found_any_tag = False   # True once any KNOWN tag has been seen
        self._pending = ""           # unconsumed tail (may end mid-tag)
        self.answer_ended = False    # set by feed() when an ANSWER block closes

    def _passthrough(self) -> bool:
        """True when plain text should currently be emitted (inside an
        ANSWER block, or in the no-tags-at-all fallback mode)."""
        return self.state == "answer" or not self.found_any_tag

    def feed(self, chunk: str) -> str:
        """Feed a chunk, return text that should be passed to TTS."""
        # answer_ended is a per-call flag: it reports only transitions that
        # happened during *this* feed().
        self.answer_ended = False
        self._pending += chunk
        output = []

        while self._pending:
            bracket_pos = self._pending.find("[")

            if self.state in ("idle", "answer"):
                if bracket_pos == -1:
                    # No tag in sight: emit (or drop) the whole buffer.
                    if self._passthrough():
                        output.append(self._pending)
                    self._pending = ""
                    continue
                before = self._pending[:bracket_pos]
                if before and self._passthrough():
                    output.append(before)
            else:  # skip: discard everything up to the next bracket
                if bracket_pos == -1:
                    self._pending = ""
                    continue

            close_pos = self._pending.find("]", bracket_pos)
            if close_pos == -1:
                # Tag may be split across chunks: keep the tail for the next
                # feed() (or for flush() at end of stream).
                self._pending = self._pending[bracket_pos:]
                break

            tag_name = self._pending[bracket_pos + 1:close_pos]
            self._pending = self._pending[close_pos + 1:]

            if tag_name not in self.KNOWN_TAGS:
                # Not a control tag (e.g. markdown link text): re-emit
                # literally when passing through, drop while skipping.
                if self.state in ("idle", "answer") and self._passthrough():
                    output.append(f"[{tag_name}]")
                continue

            self.found_any_tag = True
            if tag_name == "ANSWER":
                self.state = "answer"
            else:
                if self.state == "answer":
                    self.answer_ended = True
                self.state = "skip"

        return "".join(output)

    def flush(self) -> str:
        """Drain and return any buffered tail at end-of-stream.

        feed() holds back text starting with an unmatched "[" in case the
        closing "]" arrives later; once the stream is finished that text
        would otherwise be lost.  The tail is returned only when it would
        have been passed through (answer state or no-tags fallback).
        """
        tail, self._pending = self._pending, ""
        return tail if self._passthrough() else ""
def clean_markdown(text: str) -> str:
    """Strip Markdown formatting characters for TTS readability.

    Images and links collapse to their alt/anchor text; heading, emphasis,
    blockquote, list and horizontal-rule markers are removed; runs of blank
    lines collapse to a single newline.
    """
    ml = re.MULTILINE
    # (pattern, replacement, flags) applied in order.
    rules = (
        (r'!\[([^\]]*)\]\([^)]*\)', r'\1', 0),   # images -> alt text
        (r'\[([^\]]*)\]\([^)]*\)', r'\1', 0),    # links -> anchor text
        (r'#{1,6}\s*', '', 0),                   # heading markers
        (r'\*{1,3}|_{1,3}|~~|`{1,3}', '', 0),    # emphasis / strike / code
        (r'^>\s*', '', ml),                      # blockquote markers
        (r'^\s*[-*+]\s+', '', ml),               # unordered list bullets
        (r'^\s*\d+\.\s+', '', ml),               # ordered list numbers
        (r'^[\s]*[-*_]{3,}[\s]*$', '', ml),      # horizontal rules
        (r'\n{2,}', '\n', 0),                    # collapse blank lines
    )
    for pattern, replacement, flags in rules:
        text = re.sub(pattern, replacement, text, flags=flags)
    return text.strip()
async def stream_v3_agent(
    user_text: str,
    bot_id: str,
    bot_config: dict,
    session_id: str,
    user_identifier: str,
) -> AsyncGenerator[str, None]:
    """Call v3 agent API in streaming mode, yield text chunks as they arrive.

    Builds a single-turn ChatRequestV3 for the given bot, runs the agent's
    SSE stream, and yields only the delta content strings.  Cancellation is
    re-raised; any other failure is logged and ends the stream silently.
    """
    import asyncio
    try:
        # Imported lazily to avoid import cycles at module load time.
        from utils.api_models import ChatRequestV3, Message
        from utils.fastapi_utils import (
            process_messages,
            create_project_directory,
        )
        from agent.agent_config import AgentConfig
        from routes.chat import enhanced_generate_stream_response

        language = bot_config.get("language", "zh")
        conversation = [Message(role="user", content=user_text)]

        chat_request = ChatRequestV3(
            messages=conversation,
            bot_id=bot_id,
            stream=True,
            session_id=session_id,
            user_identifier=user_identifier,
        )

        project_dir = create_project_directory(
            bot_config.get("dataset_ids", []),
            bot_id,
            bot_config.get("skills", []),
        )

        prepared_messages = process_messages(conversation, language)

        agent_config = await AgentConfig.from_v3_request(
            chat_request,
            bot_config,
            project_dir,
            prepared_messages,
            language,
        )
        agent_config.stream = True

        # Parse the SSE stream: keep only "data: ..." lines, stop on [DONE],
        # skip anything that is not well-formed JSON with delta content.
        async for raw_line in enhanced_generate_stream_response(agent_config):
            if not raw_line or not raw_line.startswith("data: "):
                continue
            payload = raw_line.strip().removeprefix("data: ")
            if payload == "[DONE]":
                break
            try:
                event = json.loads(payload)
            except json.JSONDecodeError:
                continue
            choices = event.get("choices", [])
            if not choices:
                continue
            text = choices[0].get("delta", {}).get("content", "")
            if text:
                yield text

    except asyncio.CancelledError:
        logger.info(f"[Voice] v3 agent call cancelled")
        raise
    except Exception as e:
        logger.error(f"[Voice] Error calling v3 agent: {e}", exc_info=True)