qwen_agent/services/streaming_tts_client.py
2026-03-21 23:50:51 +08:00

159 lines
5.5 KiB
Python

import base64
import json
import logging
import re
import uuid
import httpx
import numpy as np
from utils.settings import (
VOLCENGINE_APP_ID,
VOLCENGINE_ACCESS_KEY,
VOLCENGINE_DEFAULT_SPEAKER,
)
# Endpoint of the Volcengine v3 unidirectional TTS API (server-sent events).
VOLCENGINE_TTS_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional/sse"

# Shared application logger used for every "[TTS]" message in this module.
logger = logging.getLogger("app")

# Matches text made up solely of whitespace and emoji-range code points
# (U+2600-U+27BF, U+1F300-U+1FAFF, variation selectors U+FE00-U+FE0F and the
# zero-width joiner U+200D). The empty string matches as well.
_EMOJI_ONLY_RE = re.compile(
    r'^[\s\U00002600-\U000027BF\U0001F300-\U0001FAFF\U0000FE00-\U0000FE0F\U0000200D]*$'
)
def convert_pcm_s16_to_f32(pcm_data: bytes) -> bytes:
    """Re-encode signed 16-bit PCM audio as float32 PCM for frontend playback.

    Every int16 sample is divided by 32768.0, so the resulting float32
    samples lie in the range [-1.0, 1.0).

    Args:
        pcm_data: Raw PCM bytes interpreted as native-endian int16 samples.

    Returns:
        The float32 samples as raw bytes, or ``b''`` when the input holds
        no samples.

    Raises:
        ValueError: If the byte length is not a multiple of the int16 size
            (raised by ``np.frombuffer``).
    """
    int16_samples = np.frombuffer(pcm_data, dtype=np.int16)
    if int16_samples.size == 0:
        return b''
    # float32 array / Python float stays float32, so no further cast is needed.
    normalized = int16_samples.astype(np.float32) / 32768.0
    return normalized.tobytes()
class StreamingTTSClient:
    """Volcengine v3/tts/unidirectional/sse streaming TTS client."""

    def __init__(self, speaker: str = ""):
        """Create a client bound to one speaker.

        Args:
            speaker: Speaker identifier sent with every request. An empty
                value falls back to ``VOLCENGINE_DEFAULT_SPEAKER``.
        """
        self._speaker = speaker or VOLCENGINE_DEFAULT_SPEAKER

    async def synthesize(self, text: str):
        """Synthesize text to audio via SSE streaming.

        Requests 24 kHz PCM from the service and yields each received audio
        chunk converted to float32 PCM bytes.

        Nothing is yielded when ``text`` is ``None``, empty, whitespace-only
        or emoji-only, or when the service answers with a non-200 status.
        Exceptions raised while requesting or streaming are logged and
        swallowed, so the generator simply ends early.

        Args:
            text: The text to speak.

        Yields:
            bytes: float32 PCM audio chunks.
        """
        # Guard against None as well as empty/whitespace-only input.
        if not text or not text.strip():
            return
        # Skip pure emoji text
        if _EMOJI_ONLY_RE.match(text):
            logger.info(f"[TTS] Skipping emoji-only text: '{text}'")
            return

        headers = {
            "X-Api-App-Id": VOLCENGINE_APP_ID,
            "X-Api-Access-Key": VOLCENGINE_ACCESS_KEY,
            "X-Api-Resource-Id": "seed-tts-2.0",
            "Content-Type": "application/json",
            "Connection": "keep-alive",
        }
        body = {
            "user": {
                "uid": str(uuid.uuid4()),
            },
            "req_params": {
                "text": text,
                "speaker": self._speaker,
                "audio_params": {
                    "format": "pcm",
                    "sample_rate": 24000,
                },
            },
        }

        try:
            logger.info(f"[TTS] Requesting: speaker={self._speaker}, text='{text[:50]}'")
            async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, read=60.0)) as client:
                async with client.stream("POST", VOLCENGINE_TTS_URL, headers=headers, json=body) as response:
                    logger.info(f"[TTS] Response status: {response.status_code}")
                    if response.status_code != 200:
                        error_body = await response.aread()
                        logger.error(f"[TTS] HTTP {response.status_code}: {error_body.decode('utf-8', errors='replace')}")
                        return

                    chunk_count = 0
                    # Parse SSE format: "field: value" lines, events separated
                    # by blank lines. Only "data" fields carry a payload here.
                    data_lines = []
                    raw_logged = False
                    async for raw_line in response.aiter_lines():
                        if not raw_logged:
                            logger.info(f"[TTS] First SSE line: {raw_line[:200]}")
                            raw_logged = True
                        line = raw_line.strip()

                        if not line:
                            # Blank line = end of one SSE event
                            if data_lines:
                                payload = "\n".join(data_lines)
                                async for audio in self._process_sse_data(payload):
                                    chunk_count += 1
                                    yield audio
                            data_lines = []
                            continue

                        if line.startswith(":"):
                            # SSE comment, skip
                            continue

                        field, sep, value = line.partition(":")
                        if sep and field == "data":
                            # Multi-line data fields are joined with newlines.
                            data_lines.append(value.lstrip())

                    # Handle remaining data without trailing blank line
                    if data_lines:
                        payload = "\n".join(data_lines)
                        async for audio in self._process_sse_data(payload):
                            chunk_count += 1
                            yield audio

                    logger.info(f"[TTS] Stream done, yielded {chunk_count} audio chunks")
        except Exception as e:
            # Best-effort: a TTS failure is logged and ends the stream rather
            # than propagating to the consumer.
            logger.error(f"[TTS] Error: {e}", exc_info=True)

    async def _process_sse_data(self, data_str: str):
        """Parse one SSE data payload and yield its audio chunk, if any.

        The payload is expected to be a JSON object. With ``code`` 0 (or
        absent) and a non-empty ``data`` field, ``data`` is base64-decoded
        and converted from int16 to float32 PCM. ``code`` 20000000 is logged
        as the end signal, and any other positive integer code is logged as
        an error. Malformed payloads are logged and skipped so that a single
        bad event does not abort the whole stream.

        Args:
            data_str: Concatenated ``data:`` field content of one SSE event.

        Yields:
            bytes: At most one float32 PCM chunk.
        """
        data_str = data_str.rstrip("\n")
        if not data_str:
            return

        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            logger.warning(f"[TTS] Non-JSON SSE data: {data_str[:100]}")
            return

        # Valid JSON that is not an object (null, list, string, ...) has no
        # "code"/"data" fields to inspect.
        if not isinstance(data, dict):
            logger.warning(f"[TTS] Unexpected SSE payload: {data_str[:100]}")
            return

        code = data.get("code", 0)
        if code == 0 and data.get("data"):
            # Audio data chunk
            try:
                pcm_raw = base64.b64decode(data["data"])
                pcm_f32 = convert_pcm_s16_to_f32(pcm_raw)
            except (ValueError, TypeError) as exc:
                # binascii.Error (bad base64) is a ValueError; an odd-length
                # PCM buffer also raises ValueError during conversion.
                logger.warning(f"[TTS] Skipping undecodable audio chunk: {exc}")
                return
            if pcm_f32:
                yield pcm_f32
        elif code == 20000000:
            # End of stream
            logger.info("[TTS] End signal received")
        elif isinstance(code, int) and code > 0:
            error_msg = data.get("message", "Unknown TTS error")
            logger.error(f"[TTS] Error code={code}: {error_msg}")