# qwen_agent/services/voice_utils.py
import json
import re
import logging
from typing import Optional, AsyncGenerator
logger = logging.getLogger('app')
SENTENCE_END_RE = re.compile(r'[。!?;\n.!?;]')

# Emoji pattern: matches Unicode emoji without touching CJK characters
_EMOJI_RE = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # symbols & pictographs
    "\U0001F680-\U0001F6FF"  # transport & map
    "\U0001F1E0-\U0001F1FF"  # flags
    "\U0001F900-\U0001F9FF"  # supplemental symbols
    "\U0001FA00-\U0001FA6F"  # chess symbols
    "\U0001FA70-\U0001FAFF"  # symbols extended-A
    "\U00002702-\U000027B0"  # dingbats
    "\U00002600-\U000026FF"  # misc symbols
    "\U0000FE00-\U0000FE0F"  # variation selectors
    "\U0000200D"             # zero width joiner
    "\U000024C2"             # Ⓜ enclosed letter
    "\U00002B50\U00002B55"   # star, circle
    "\U000023CF\U000023E9-\U000023F3\U000023F8-\U000023FA"  # media controls
    "\U0001F170-\U0001F251"  # enclosed alphanumeric supplement
    "]+",
    flags=re.UNICODE,
)

# Strong sentence-ending punctuation (excluding \n which is handled separately)
_STRONG_PUNCT_RE = re.compile(r'[。!?;.!?;~]')
# Soft punctuation (usable as split points when buffer is getting long)
_SOFT_PUNCT_RE = re.compile(r'[,:、)) \t]')


class TTSSentenceSplitter:
    """
    Intelligent sentence splitter for TTS streaming.

    Rules (in priority order):
      1. Split on newlines (LLM paragraph boundaries); a paragraph shorter
         than MIN_LENGTH is merged into the following one.
      2. Split on strong punctuation (。!?~ etc.) once the text before the
         punctuation mark is at least MIN_LENGTH characters.
      3. Once the buffer reaches SOFT_THRESHOLD, also split on soft
         punctuation (、 , : etc.).
      4. Once the buffer reaches MAX_LENGTH, force a split: prefer the last
         space/comma after MIN_LENGTH, otherwise cut at exactly MAX_LENGTH.

    Emitted sentences are cleaned for TTS — emoji stripped (TTS cannot
    pronounce them) and runs of spaces/tabs collapsed.  flush() returns any
    remaining text regardless of length.
    """

    MIN_LENGTH = 10       # Don't send sentences shorter than this
    SOFT_THRESHOLD = 30   # Start considering soft punctuation splits
    MAX_LENGTH = 80       # Force split even without punctuation

    def __init__(self):
        # Text received via feed() but not yet emitted as a sentence.
        self._buf = ""

    def _clean_for_tts(self, text: str) -> str:
        """Remove emoji and collapse runs of spaces/tabs; strip the result."""
        text = _EMOJI_RE.sub("", text)
        text = re.sub(r'[ \t]+', ' ', text)
        return text.strip()

    def feed(self, chunk: str) -> list[str]:
        """Feed a text chunk, return list of ready sentences (may be empty)."""
        self._buf += chunk
        results = []
        while self._buf:
            buf_len = len(self._buf)

            # 0. Newline split — highest priority.
            nl_pos = self._buf.find('\n')
            if nl_pos >= 0:
                before = self._buf[:nl_pos]
                rest = self._buf[nl_pos:].lstrip('\n')
                cleaned = self._clean_for_tts(before)
                if len(cleaned) >= self.MIN_LENGTH:
                    # Long enough, emit as a sentence.
                    self._buf = rest
                    results.append(cleaned)
                    continue
                elif not rest:
                    # No more text after newline, keep buffer and wait.
                    break
                else:
                    # Too short — merge with the next paragraph.
                    self._buf = before + rest
                    continue

            # 1. Strong punctuation: take the first split point whose
            #    prefix is long enough; shorter prefixes keep scanning.
            best_end = -1
            for match in _STRONG_PUNCT_RE.finditer(self._buf):
                end_pos = match.end()
                if len(self._buf[:end_pos].strip()) >= self.MIN_LENGTH:
                    best_end = end_pos
                    break
            if best_end > 0:
                sentence = self._clean_for_tts(self._buf[:best_end])
                self._buf = self._buf[best_end:]
                if sentence:
                    results.append(sentence)
                continue

            # 2. Buffer getting long: split at the last soft punctuation
            #    before SOFT_THRESHOLD (or the first one past MIN_LENGTH).
            if buf_len >= self.SOFT_THRESHOLD:
                best_soft = -1
                for m in _SOFT_PUNCT_RE.finditer(self._buf):
                    pos = m.end()
                    if pos >= self.MIN_LENGTH:
                        best_soft = pos
                        if pos >= self.SOFT_THRESHOLD:
                            break
                if best_soft >= self.MIN_LENGTH:
                    sentence = self._clean_for_tts(self._buf[:best_soft])
                    self._buf = self._buf[best_soft:]
                    if sentence:
                        results.append(sentence)
                    continue

            # 3. Buffer too long: force split, preferring a word boundary.
            if buf_len >= self.MAX_LENGTH:
                split_at = self.MAX_LENGTH
                search_region = self._buf[self.MIN_LENGTH:self.MAX_LENGTH]
                # BUG FIX: this previously called rfind('') with empty
                # needles (fullwidth characters lost in transit), which
                # always returns len(search_region) — forcing split_at to
                # MAX_LENGTH + 1 and disabling the boundary search.  Search
                # for ASCII and fullwidth space/comma explicitly.
                last_break = max(search_region.rfind(' '), search_region.rfind('　'),
                                 search_region.rfind(','), search_region.rfind('，'))
                if last_break >= 0:
                    split_at = self.MIN_LENGTH + last_break + 1
                sentence = self._clean_for_tts(self._buf[:split_at])
                self._buf = self._buf[split_at:]
                if sentence:
                    results.append(sentence)
                continue

            # Not enough text yet, wait for more.
            break
        return results

    def flush(self) -> list[str]:
        """Flush remaining buffer. Call at end of stream."""
        results = []
        if self._buf.strip():
            sentence = self._clean_for_tts(self._buf)
            if sentence:
                results.append(sentence)
        self._buf = ""
        return results
class StreamTagFilter:
    """
    Stream-level tag filter for TTS.

    Routes only the text inside [ANSWER] blocks to the caller; content
    inside any other known tag block ([TOOL_CALL], [THINK], [SOURCE], ...)
    is discarded.  As a fallback, a stream that never contains a known tag
    is passed through untouched.  Unknown bracketed text (e.g. citations)
    is re-emitted verbatim.
    """

    SKIP_TAGS = {"TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE"}
    KNOWN_TAGS = {"ANSWER", "TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE", "PREAMBLE", "SUMMARY"}

    def __init__(self):
        self.state = "idle"          # one of: idle, answer, skip
        self.found_any_tag = False   # True once any known tag has appeared
        self._pending = ""           # unconsumed text (may end mid-tag)
        self.answer_ended = False    # an [ANSWER] block closed during last feed()

    def _passthrough(self) -> bool:
        # Emit text while inside an answer, or before any known tag appeared.
        return self.state == "answer" or not self.found_any_tag

    def feed(self, chunk: str) -> str:
        """Feed a chunk, return text that should be passed to TTS."""
        self.answer_ended = False
        self._pending += chunk
        emitted = []
        while self._pending:
            open_idx = self._pending.find("[")
            if self.state == "skip":
                if open_idx == -1:
                    # Everything here belongs to a skipped block; drop it.
                    self._pending = ""
                    continue
                close_idx = self._pending.find("]", open_idx)
                if close_idx == -1:
                    # Tag may be split across chunks; hold from "[" onward.
                    self._pending = self._pending[open_idx:]
                    break
                tag = self._pending[open_idx + 1:close_idx]
                self._pending = self._pending[close_idx + 1:]
                if tag == "ANSWER":
                    self.state = "answer"
                elif tag in self.KNOWN_TAGS:
                    self.state = "skip"
                # Unknown brackets inside a skipped block are dropped.
            else:  # idle or answer
                if open_idx == -1:
                    if self._passthrough():
                        emitted.append(self._pending)
                    self._pending = ""
                    continue
                head = self._pending[:open_idx]
                if head and self._passthrough():
                    emitted.append(head)
                close_idx = self._pending.find("]", open_idx)
                if close_idx == -1:
                    self._pending = self._pending[open_idx:]
                    break
                tag = self._pending[open_idx + 1:close_idx]
                self._pending = self._pending[close_idx + 1:]
                if tag not in self.KNOWN_TAGS:
                    # Not a control tag — re-emit it verbatim.
                    if self._passthrough():
                        emitted.append(f"[{tag}]")
                    continue
                self.found_any_tag = True
                if tag == "ANSWER":
                    self.state = "answer"
                else:
                    if self.state == "answer":
                        self.answer_ended = True
                    self.state = "skip"
        return "".join(emitted)
def clean_markdown(text: str) -> str:
    """Strip Markdown formatting characters for TTS readability."""
    # (pattern, replacement, flags) applied in order.  Order matters:
    # images must collapse before plain links so ![a](u) is not
    # half-matched by the link rule.
    rules = (
        (r'!\[([^\]]*)\]\([^)]*\)', r'\1', 0),          # images -> alt text
        (r'\[([^\]]*)\]\([^)]*\)', r'\1', 0),           # links -> link text
        (r'#{1,6}\s*', '', 0),                          # heading markers
        (r'\*{1,3}|_{1,3}|~~|`{1,3}', '', 0),           # emphasis / code marks
        (r'^>\s*', '', re.MULTILINE),                   # blockquote markers
        (r'^\s*[-*+]\s+', '', re.MULTILINE),            # bullet list markers
        (r'^\s*\d+\.\s+', '', re.MULTILINE),            # ordered list markers
        (r'^[\s]*[-*_]{3,}[\s]*$', '', re.MULTILINE),   # horizontal rules
        (r'\n{2,}', '\n', 0),                           # collapse blank lines
    )
    for pattern, repl, flags in rules:
        text = re.sub(pattern, repl, text, flags=flags)
    return text.strip()
async def stream_v3_agent(
    user_text: str,
    bot_id: str,
    bot_config: dict,
    session_id: str,
    user_identifier: str,
) -> AsyncGenerator[str, None]:
    """Call v3 agent API in streaming mode, yield text chunks as they arrive.

    Args:
        user_text: The user utterance, sent as a single user message.
        bot_id: Identifier of the bot whose config is used.
        bot_config: Bot configuration dict; this function reads
            "language" (defaults to "zh"), "dataset_ids" and "skills".
        session_id: Conversation session identifier forwarded to the agent.
        user_identifier: Caller identity forwarded to the agent.

    Yields:
        Incremental text extracted from the agent's SSE stream — the
        ``choices[0].delta.content`` field of each ``data: {...}`` line.

    Notes:
        On any error other than cancellation, the exception is logged and
        the generator simply ends: callers see a truncated stream rather
        than an exception (deliberate best-effort behavior).
    """
    import asyncio
    try:
        # Project-local imports are deferred — presumably to avoid circular
        # imports with the routes layer; confirm before hoisting to top level.
        from utils.api_models import ChatRequestV3, Message
        from utils.fastapi_utils import (
            process_messages,
            create_project_directory,
        )
        from agent.agent_config import AgentConfig
        from routes.chat import enhanced_generate_stream_response

        language = bot_config.get("language", "zh")
        messages_obj = [Message(role="user", content=user_text)]
        request = ChatRequestV3(
            messages=messages_obj,
            bot_id=bot_id,
            stream=True,
            session_id=session_id,
            user_identifier=user_identifier,
        )
        # Working directory prepared from the bot's datasets and skills.
        project_dir = create_project_directory(
            bot_config.get("dataset_ids", []),
            bot_id,
            bot_config.get("skills", []),
        )
        processed_messages = process_messages(messages_obj, language)
        config = await AgentConfig.from_v3_request(
            request,
            bot_config,
            project_dir,
            processed_messages,
            language,
        )
        config.stream = True
        # The agent emits Server-Sent Events ("data: ..." lines, ending
        # with "data: [DONE]"); forward only the delta text to the caller.
        async for sse_line in enhanced_generate_stream_response(config):
            if not sse_line or not sse_line.startswith("data: "):
                continue
            data_str = sse_line.strip().removeprefix("data: ")
            if data_str == "[DONE]":
                break
            try:
                data = json.loads(data_str)
                choices = data.get("choices", [])
                if choices:
                    delta = choices[0].get("delta", {})
                    content = delta.get("content", "")
                    if content:
                        yield content
            except (json.JSONDecodeError, KeyError):
                # Skip malformed SSE payloads instead of aborting the stream.
                continue
    except asyncio.CancelledError:
        # Propagate cancellation so the caller's task machinery works.
        logger.info(f"[Voice] v3 agent call cancelled")
        raise
    except Exception as e:
        # Best-effort: log and end the stream silently (see docstring note).
        logger.error(f"[Voice] Error calling v3 agent: {e}", exc_info=True)