179 lines
6.3 KiB
Python
179 lines
6.3 KiB
Python
import base64
|
|
import json
|
|
import logging
|
|
import re
|
|
import uuid
|
|
|
|
import httpx
|
|
import numpy as np
|
|
|
|
from utils.settings import (
|
|
VOLCENGINE_APP_ID,
|
|
VOLCENGINE_ACCESS_KEY,
|
|
VOLCENGINE_DEFAULT_SPEAKER,
|
|
)
|
|
|
|
# Volcengine unidirectional SSE text-to-speech endpoint.
VOLCENGINE_TTS_URL = "https://openspeech.bytedance.com/api/v3/tts/unidirectional/sse"

logger = logging.getLogger("app")

# Matches text that has no speakable content: only whitespace and/or
# emoji-related code points (the empty string matches as well).
#   U+2600-U+27BF    miscellaneous symbols and dingbats
#   U+1F300-U+1FAFF  pictographic emoji blocks
#   U+FE00-U+FE0F    variation selectors
#   U+200D           zero-width joiner
_EMOJI_ONLY_RE = re.compile(
    r"^[\s\U00002600-\U000027BF\U0001F300-\U0001FAFF\U0000FE00-\U0000FE0F\U0000200D]*$"
)
|
|
|
|
|
|
def convert_pcm_s16_to_f32(pcm_data: bytes) -> bytes:
    """Re-encode signed 16-bit PCM bytes as float32 PCM bytes for frontend playback.

    Every int16 sample is divided by 32768.0, so output samples fall in
    [-1.0, 1.0). Empty input yields empty output. ``np.frombuffer`` raises
    ``ValueError`` when the byte length is not a multiple of two.
    """
    int16_samples = np.frombuffer(pcm_data, dtype=np.int16)
    if int16_samples.size == 0:
        return b''
    # float32 array divided by a Python float stays float32.
    normalized = int16_samples.astype(np.float32) / 32768.0
    return normalized.tobytes()
|
|
|
|
|
|
class StreamingTTSClient:
    """Volcengine v3/tts/unidirectional/sse streaming TTS client."""

    def __init__(self, speaker: str = ""):
        """Create a client for ``speaker``; an empty value selects the configured default."""
        self._speaker = speaker or VOLCENGINE_DEFAULT_SPEAKER

    async def synthesize(self, text: str):
        """
        Synthesize text to audio via SSE streaming.

        Yields 24kHz float32 PCM audio chunks.
        """
        async for chunk in self._synthesize_internal(text, raw_int16=False):
            yield chunk

    async def synthesize_raw(self, text: str):
        """
        Synthesize text to audio via SSE streaming.

        Yields 24kHz int16 PCM audio chunks (no float32 conversion).
        """
        async for chunk in self._synthesize_internal(text, raw_int16=True):
            yield chunk

    async def _synthesize_internal(self, text: str, raw_int16: bool = False):
        """
        Request synthesis of ``text`` and stream back the audio chunks.

        Args:
            text: Text to speak. Blank or emoji-only text yields nothing and
                no request is sent.
            raw_int16: When True, yield the decoded int16 PCM bytes as
                received; when False, yield them converted to float32 PCM.

        Yields:
            ``bytes`` audio chunks; the request asks for 24kHz PCM.

        Non-200 responses and any exception raised while streaming are
        logged rather than raised, so the generator simply ends early.
        """
        if not text.strip():
            return

        # Skip pure emoji text
        if _EMOJI_ONLY_RE.match(text):
            logger.info(f"[TTS] Skipping emoji-only text: '{text}'")
            return

        headers = {
            "X-Api-App-Id": VOLCENGINE_APP_ID,
            "X-Api-Access-Key": VOLCENGINE_ACCESS_KEY,
            "X-Api-Resource-Id": "seed-tts-2.0",
            "Content-Type": "application/json",
            "Connection": "keep-alive",
        }

        body = {
            "user": {
                # A fresh random uid is generated for every request.
                "uid": str(uuid.uuid4()),
            },
            "req_params": {
                "text": text,
                "speaker": self._speaker,
                "audio_params": {
                    "format": "pcm",
                    "sample_rate": 24000,
                },
            },
        }

        try:
            logger.info(f"[TTS] Requesting: speaker={self._speaker}, text='{text[:50]}'")
            # 30s default timeout, with a longer 60s allowance between reads.
            async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, read=60.0)) as client:
                async with client.stream("POST", VOLCENGINE_TTS_URL, headers=headers, json=body) as response:
                    logger.info(f"[TTS] Response status: {response.status_code}")
                    if response.status_code != 200:
                        error_body = await response.aread()
                        logger.error(f"[TTS] HTTP {response.status_code}: {error_body.decode('utf-8', errors='replace')}")
                        return

                    chunk_count = 0
                    async for payload in self._iter_sse_payloads(response.aiter_lines()):
                        async for audio in self._process_sse_data(payload, raw_int16=raw_int16):
                            chunk_count += 1
                            yield audio

                    logger.info(f"[TTS] Stream done, yielded {chunk_count} audio chunks")

        except Exception as e:
            # Best-effort: a failed synthesis ends the stream instead of raising.
            logger.error(f"[TTS] Error: {e}", exc_info=True)

    @staticmethod
    async def _iter_sse_payloads(lines):
        """
        Group raw SSE lines into events and yield each event's data payload.

        ``lines`` is an async iterable of text lines. Lines are stripped;
        comment lines (leading ``:``) and fields other than ``data`` are
        ignored. Multiple ``data`` lines of one event are joined with
        newlines. A blank line ends the current event; data still pending
        when the input ends is yielded as a final payload.
        """
        data_lines = []
        first_line_logged = False

        async for raw_line in lines:
            if not first_line_logged:
                logger.info(f"[TTS] First SSE line: {raw_line[:200]}")
                first_line_logged = True

            line = raw_line.strip()

            if not line:
                # Blank line = end of one SSE event
                if data_lines:
                    yield "\n".join(data_lines)
                    data_lines = []
                continue

            if line.startswith(":"):
                # SSE comment, skip
                continue

            field, separator, value = line.partition(":")
            if separator and field == "data":
                data_lines.append(value.lstrip())

        # Handle remaining data without trailing blank line
        if data_lines:
            yield "\n".join(data_lines)

    async def _process_sse_data(self, data_str: str, raw_int16: bool = False):
        """
        Parse one SSE data payload and yield its audio chunk, if any.

        Args:
            data_str: The event's data field, expected to be a JSON object.
            raw_int16: When True, yield the base64-decoded bytes unchanged;
                otherwise yield them converted to float32 PCM.

        Malformed payloads (non-JSON, non-object JSON, undecodable base64)
        are logged and skipped so a single bad event does not abort the
        stream. Payloads with code 20000000 or a positive error code are
        logged and yield nothing.
        """
        data_str = data_str.rstrip("\n")
        if not data_str:
            return

        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            logger.warning(f"[TTS] Non-JSON SSE data: {data_str[:100]}")
            return

        if not isinstance(data, dict):
            # Valid JSON but not an object: there are no fields to read.
            logger.warning(f"[TTS] Unexpected SSE payload type {type(data).__name__}: {data_str[:100]}")
            return

        code = data.get("code", 0)
        encoded_audio = data.get("data")

        if code == 0 and encoded_audio:
            # Audio data chunk
            try:
                pcm_raw = base64.b64decode(encoded_audio)
            except (ValueError, TypeError) as e:
                # binascii.Error is a ValueError; TypeError covers non-string data.
                logger.warning(f"[TTS] Skipping undecodable audio chunk: {e}")
                return
            audio = pcm_raw if raw_int16 else convert_pcm_s16_to_f32(pcm_raw)
            if audio:
                yield audio

        elif code == 20000000:
            # End of stream
            logger.info("[TTS] End signal received")

        elif isinstance(code, (int, float)) and code > 0:
            error_msg = data.get("message", "Unknown TTS error")
            logger.error(f"[TTS] Error code={code}: {error_msg}")
|