317 lines
12 KiB
Python
317 lines
12 KiB
Python
import json
|
||
import re
|
||
import logging
|
||
from typing import Optional, AsyncGenerator
|
||
|
||
# Shared application logger (module-level, per logging best practice).
logger = logging.getLogger('app')


# Sentence-ending characters: fullwidth (CJK) and ASCII variants plus newline.
# NOTE(review): not referenced anywhere in this file's visible code — possibly
# imported by other modules or dead; confirm before removing.
SENTENCE_END_RE = re.compile(r'[。!?;\n.!?;]')


# Emoji pattern: matches Unicode emoji without touching CJK characters.
# Built from explicit codepoint ranges so ordinary Chinese/Japanese/Korean
# text (U+4E00..U+9FFF etc.) is never stripped.
_EMOJI_RE = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # symbols & pictographs
    "\U0001F680-\U0001F6FF"  # transport & map
    "\U0001F1E0-\U0001F1FF"  # flags
    "\U0001F900-\U0001F9FF"  # supplemental symbols
    "\U0001FA00-\U0001FA6F"  # chess symbols
    "\U0001FA70-\U0001FAFF"  # symbols extended-A
    "\U00002702-\U000027B0"  # dingbats
    "\U00002600-\U000026FF"  # misc symbols
    "\U0000FE00-\U0000FE0F"  # variation selectors
    "\U0000200D"  # zero width joiner
    "\U000024C2"  # Ⓜ enclosed letter
    "\U00002B50\U00002B55"  # star, circle
    "\U000023CF\U000023E9-\U000023F3\U000023F8-\U000023FA"  # media controls
    "\U0001F170-\U0001F251"  # enclosed alphanumeric supplement
    "]+",
    flags=re.UNICODE,
)

# Strong sentence-ending punctuation (excluding \n which is handled separately).
# Contains both fullwidth and ASCII forms, including both tilde variants.
_STRONG_PUNCT_RE = re.compile(r'[。!?;.!?;~~]')

# Soft punctuation (usable as split points when buffer is getting long):
# commas, colons, closing parens, space and tab — fullwidth and ASCII forms.
_SOFT_PUNCT_RE = re.compile(r'[,,::、)) \t]')
|
||
|
||
|
||
class TTSSentenceSplitter:
    """
    Intelligent sentence splitter for TTS streaming.

    Accumulates streamed text chunks in an internal buffer and emits
    TTS-ready sentences as soon as a good split point is available.

    Rules (in priority order):
    1. Split on newlines unconditionally (LLM paragraph boundaries)
    2. Split on strong punctuation (。!?~ etc.) only if accumulated >= MIN_LENGTH
    3. If buffer reaches SOFT_THRESHOLD, also split on soft punctuation (,、etc.)
    4. If buffer reaches MAX_LENGTH, force split at best available position
    - Strip emoji from output (TTS cannot pronounce them)
    - On flush(), return any remaining text regardless of length
    """

    MIN_LENGTH = 10  # Don't send sentences shorter than this
    SOFT_THRESHOLD = 30  # Start considering soft punctuation splits
    MAX_LENGTH = 80  # Force split even without punctuation

    def __init__(self):
        # Unemitted text carried over between feed() calls.
        self._buf: str = ""

    def _clean_for_tts(self, text: str) -> str:
        """Remove emoji and collapse whitespace."""
        text = _EMOJI_RE.sub("", text)
        # Collapse runs of spaces/tabs (newlines are split on before cleaning).
        text = re.sub(r'[ \t]+', ' ', text)
        return text.strip()

    def feed(self, chunk: str) -> list[str]:
        """Feed a text chunk, return list of ready sentences (may be empty).

        Each loop iteration either emits one sentence (shrinking the buffer)
        or breaks to wait for more input, so the loop always terminates.
        """
        self._buf += chunk
        results: list[str] = []

        while self._buf:
            buf_len = len(self._buf)

            # 0. Newline split — highest priority
            nl_pos = self._buf.find('\n')
            if nl_pos >= 0:
                before = self._buf[:nl_pos]
                # Drop the newline(s) themselves; they never reach TTS.
                rest = self._buf[nl_pos:].lstrip('\n')
                cleaned = self._clean_for_tts(before)
                if len(cleaned) >= self.MIN_LENGTH:
                    # Long enough, emit as a sentence
                    self._buf = rest
                    results.append(cleaned)
                    continue
                elif not rest:
                    # No more text after newline, keep buffer and wait
                    break
                else:
                    # Too short — merge with next paragraph
                    self._buf = before + rest
                    continue

            # 1. Try strong punctuation split — scan for the best split point
            best_end = -1
            for match in _STRONG_PUNCT_RE.finditer(self._buf):
                end_pos = match.end()
                candidate = self._buf[:end_pos]
                # Length check uses the raw (stripped) text, before emoji
                # removal, so the scan stays cheap.
                if len(candidate.strip()) >= self.MIN_LENGTH:
                    best_end = end_pos
                    break  # Take the first valid (long enough) split
                # Short segment before this punct — skip and keep scanning

            if best_end > 0:
                sentence = self._clean_for_tts(self._buf[:best_end])
                self._buf = self._buf[best_end:]
                # Cleaning can empty the segment (e.g. emoji-only); skip those.
                if sentence:
                    results.append(sentence)
                continue

            # 2. Buffer getting long: try soft punctuation split
            if buf_len >= self.SOFT_THRESHOLD:
                best_soft = -1
                # Prefer the last soft punct at or past SOFT_THRESHOLD;
                # otherwise keep the furthest one past MIN_LENGTH.
                for m in _SOFT_PUNCT_RE.finditer(self._buf):
                    pos = m.end()
                    if pos >= self.MIN_LENGTH:
                        best_soft = pos
                        if pos >= self.SOFT_THRESHOLD:
                            break
                if best_soft >= self.MIN_LENGTH:
                    sentence = self._clean_for_tts(self._buf[:best_soft])
                    self._buf = self._buf[best_soft:]
                    if sentence:
                        results.append(sentence)
                    continue

            # 3. Buffer too long: force split at MAX_LENGTH
            if buf_len >= self.MAX_LENGTH:
                split_at = self.MAX_LENGTH
                # Look for the last space/comma between MIN and MAX so the
                # forced cut lands on a word boundary when possible.
                search_region = self._buf[self.MIN_LENGTH:self.MAX_LENGTH]
                last_space = max(search_region.rfind(' '), search_region.rfind(','),
                                 search_region.rfind(','), search_region.rfind('、'))
                if last_space >= 0:
                    # +1 keeps the separator with the emitted sentence.
                    split_at = self.MIN_LENGTH + last_space + 1

                sentence = self._clean_for_tts(self._buf[:split_at])
                self._buf = self._buf[split_at:]
                if sentence:
                    results.append(sentence)
                continue

            # Not enough text yet, wait for more
            break

        return results

    def flush(self) -> list[str]:
        """Flush remaining buffer. Call at end of stream.

        Returns any leftover text as a final sentence regardless of
        MIN_LENGTH, then resets the buffer.
        """
        results: list[str] = []
        if self._buf.strip():
            sentence = self._clean_for_tts(self._buf)
            if sentence:
                results.append(sentence)
        self._buf = ""
        return results
|
||
|
||
|
||
class StreamTagFilter:
    """
    State machine that filters streaming text by bracketed tag blocks.

    Only content inside [ANSWER] blocks is forwarded to the caller.
    Content inside [TOOL_CALL], [TOOL_RESPONSE], [THINK], [SOURCE], etc.
    is swallowed. As a fallback, a stream that never produces any known
    tag is passed through verbatim.
    """

    SKIP_TAGS = {"TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE"}
    KNOWN_TAGS = {"ANSWER", "TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE", "PREAMBLE", "SUMMARY"}

    def __init__(self):
        self.state = "idle"  # idle, answer, skip
        self.found_any_tag = False  # True once any KNOWN_TAGS marker is seen
        self._pending = ""  # unconsumed input; may end mid-tag
        self.answer_ended = False  # True when this feed() left an [ANSWER] block

    def feed(self, chunk: str) -> str:
        """Feed a chunk, return text that should be passed to TTS."""
        self.answer_ended = False
        self._pending += chunk
        emitted = []

        while self._pending:
            open_idx = self._pending.find("[")

            if self.state == "skip":
                if open_idx == -1:
                    # Whole remainder belongs to a skipped block — drop it.
                    self._pending = ""
                    continue
                close_idx = self._pending.find("]", open_idx)
                if close_idx == -1:
                    # Tag possibly cut off at the chunk boundary; hold it.
                    self._pending = self._pending[open_idx:]
                    break
                name = self._pending[open_idx + 1:close_idx]
                self._pending = self._pending[close_idx + 1:]
                if name in self.KNOWN_TAGS:
                    self.state = "answer" if name == "ANSWER" else "skip"
                continue

            # state is "idle" or "answer": text is forwarded while inside an
            # answer block, or as long as no known tag has appeared (fallback).
            emitting = self.state == "answer" or not self.found_any_tag
            if open_idx == -1:
                if emitting:
                    emitted.append(self._pending)
                self._pending = ""
                continue

            head = self._pending[:open_idx]
            if head and emitting:
                emitted.append(head)

            close_idx = self._pending.find("]", open_idx)
            if close_idx == -1:
                # Tag possibly cut off at the chunk boundary; hold it.
                self._pending = self._pending[open_idx:]
                break

            name = self._pending[open_idx + 1:close_idx]
            self._pending = self._pending[close_idx + 1:]

            if name not in self.KNOWN_TAGS:
                # Bracketed text that isn't a control tag — emit it literally.
                if emitting:
                    emitted.append(f"[{name}]")
                continue

            self.found_any_tag = True
            if name == "ANSWER":
                self.state = "answer"
            else:
                if self.state == "answer":
                    self.answer_ended = True
                self.state = "skip"

        return "".join(emitted)
|
||
|
||
|
||
def clean_markdown(text: str) -> str:
    """Strip Markdown formatting characters for TTS readability.

    Removes images and links (keeping alt/label text), heading markers,
    emphasis/strikethrough/code markers, blockquote prefixes, list bullets
    and numbering, horizontal rules, and collapses blank lines.

    Args:
        text: Raw Markdown text (typically an LLM response).

    Returns:
        Plain text suitable for speech synthesis.
    """
    # Images first so the link rule below doesn't leave a stray "!".
    text = re.sub(r'!\[([^\]]*)\]\([^)]*\)', r'\1', text)
    # Links: keep the label, drop the URL.
    text = re.sub(r'\[([^\]]*)\]\([^)]*\)', r'\1', text)
    # Heading markers — anchored to line start with MULTILINE (ATX headings
    # only exist at line start), so inline "#" such as "issue #42" or "C#"
    # is left intact. The previous unanchored pattern stripped those too.
    text = re.sub(r'^#{1,6}\s*', '', text, flags=re.MULTILINE)
    # Emphasis, strikethrough, inline/fenced code markers.
    text = re.sub(r'\*{1,3}|_{1,3}|~~|`{1,3}', '', text)
    # Blockquote prefixes.
    text = re.sub(r'^>\s*', '', text, flags=re.MULTILINE)
    # Unordered list bullets.
    text = re.sub(r'^\s*[-*+]\s+', '', text, flags=re.MULTILINE)
    # Ordered list numbering.
    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
    # Horizontal rules (---, ***, ___ on their own line).
    text = re.sub(r'^[\s]*[-*_]{3,}[\s]*$', '', text, flags=re.MULTILINE)
    # Collapse runs of blank lines into single newlines.
    text = re.sub(r'\n{2,}', '\n', text)
    return text.strip()
|
||
|
||
|
||
async def stream_v3_agent(
    user_text: str,
    bot_id: str,
    bot_config: dict,
    session_id: str,
    user_identifier: str,
) -> AsyncGenerator[str, None]:
    """Call v3 agent API in streaming mode, yield text chunks as they arrive.

    Args:
        user_text: The user's message to send to the agent.
        bot_id: Identifier of the bot configuration to use.
        bot_config: Bot settings dict; the "language", "dataset_ids" and
            "skills" keys are read here (all with safe defaults).
        session_id: Conversation session identifier.
        user_identifier: Identifier of the end user.

    Yields:
        Incremental text content extracted from the agent's SSE stream
        (OpenAI-style ``choices[0].delta.content`` payloads).

    Notes:
        Cancellation is re-raised; any other error is logged and swallowed,
        ending the stream early (best-effort behavior for the voice path).
    """
    import asyncio
    try:
        # Imported lazily — presumably to avoid circular imports at module
        # load time; keep inside the function.
        from utils.api_models import ChatRequestV3, Message
        from utils.fastapi_utils import (
            process_messages,
            create_project_directory,
        )
        from agent.agent_config import AgentConfig
        from routes.chat import enhanced_generate_stream_response

        language = bot_config.get("language", "zh")
        messages_obj = [Message(role="user", content=user_text)]

        request = ChatRequestV3(
            messages=messages_obj,
            bot_id=bot_id,
            stream=True,
            session_id=session_id,
            user_identifier=user_identifier,
        )

        project_dir = create_project_directory(
            bot_config.get("dataset_ids", []),
            bot_id,
            bot_config.get("skills", []),
        )

        processed_messages = process_messages(messages_obj, language)

        config = await AgentConfig.from_v3_request(
            request,
            bot_config,
            project_dir,
            processed_messages,
            language,
        )
        config.stream = True

        async for sse_line in enhanced_generate_stream_response(config):
            # Only SSE data lines carry payload; skip comments/blank lines.
            if not sse_line or not sse_line.startswith("data: "):
                continue
            data_str = sse_line.strip().removeprefix("data: ")
            if data_str == "[DONE]":
                break
            try:
                data = json.loads(data_str)
                choices = data.get("choices", [])
                if choices:
                    delta = choices[0].get("delta", {})
                    content = delta.get("content", "")
                    if content:
                        yield content
            except (json.JSONDecodeError, KeyError):
                # Malformed SSE payload — skip this line, keep streaming.
                continue

    except asyncio.CancelledError:
        # Plain string: the old f-string here had no placeholders.
        logger.info("[Voice] v3 agent call cancelled")
        raise
    except Exception as e:
        # Lazy %s-style args per logging best practice (no eager f-string).
        logger.error("[Voice] Error calling v3 agent: %s", e, exc_info=True)