支持语音合成和语音识别api
This commit is contained in:
parent
ace37fbec2
commit
99273a91d3
@ -7,6 +7,7 @@ from typing import Optional
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
|
||||
from services.voice_session_manager import VoiceSession
|
||||
from utils.settings import VOICE_DEFAULT_MODE
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
@ -35,7 +36,7 @@ async def voice_realtime(websocket: WebSocket):
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
session: Optional[VoiceSession] = None
|
||||
session = None
|
||||
|
||||
async def send_json(data: dict):
|
||||
try:
|
||||
@ -91,7 +92,9 @@ async def voice_realtime(websocket: WebSocket):
|
||||
await send_json({"type": "error", "message": "bot_id is required"})
|
||||
continue
|
||||
|
||||
session = VoiceSession(
|
||||
voice_mode = msg.get("voice_mode") or VOICE_DEFAULT_MODE
|
||||
|
||||
session_kwargs = dict(
|
||||
bot_id=bot_id,
|
||||
session_id=msg.get("session_id"),
|
||||
user_identifier=msg.get("user_identifier"),
|
||||
@ -104,6 +107,13 @@ async def voice_realtime(websocket: WebSocket):
|
||||
on_error=on_error,
|
||||
)
|
||||
|
||||
if voice_mode == "lite":
|
||||
from services.voice_lite_session import VoiceLiteSession
|
||||
session = VoiceLiteSession(**session_kwargs)
|
||||
logger.info(f"[Voice] Using lite mode for bot_id={bot_id}")
|
||||
else:
|
||||
session = VoiceSession(**session_kwargs)
|
||||
|
||||
try:
|
||||
await session.start()
|
||||
except Exception as e:
|
||||
|
||||
12
run.log
Normal file
12
run.log
Normal file
@ -0,0 +1,12 @@
|
||||
2026-03-21 22:07:08,577 - ERROR - Failed to connect to WebSocket: 403, message='Invalid response status', url='wss://openspeech.bytedance.com/api/v3/sauc/bigmodel'
|
||||
2026-03-21 22:07:08,578 - ERROR - Error in ASR execution: 403, message='Invalid response status', url='wss://openspeech.bytedance.com/api/v3/sauc/bigmodel'
|
||||
2026-03-21 22:12:34,063 - ERROR - Failed to connect to WebSocket: 400, message='Invalid response status', url='wss://openspeech.bytedance.com/api/v3/sauc/bigmodel'
|
||||
2026-03-21 22:12:34,065 - ERROR - Error in ASR execution: 400, message='Invalid response status', url='wss://openspeech.bytedance.com/api/v3/sauc/bigmodel'
|
||||
2026-03-21 22:12:45,595 - ERROR - Failed to connect to WebSocket: 400, message='Invalid response status', url='wss://openspeech.bytedance.com/api/v3/sauc/bigmodel'
|
||||
2026-03-21 22:12:45,595 - ERROR - Error in ASR execution: 400, message='Invalid response status', url='wss://openspeech.bytedance.com/api/v3/sauc/bigmodel'
|
||||
2026-03-21 22:13:00,103 - ERROR - Failed to connect to WebSocket: 401, message='Invalid response status', url='wss://openspeech.bytedance.com/api/v3/sauc/bigmodel'
|
||||
2026-03-21 22:13:00,106 - ERROR - Error in ASR execution: 401, message='Invalid response status', url='wss://openspeech.bytedance.com/api/v3/sauc/bigmodel'
|
||||
2026-03-21 22:13:15,842 - ERROR - Failed to connect to WebSocket: 401, message='Invalid response status', url='wss://openspeech.bytedance.com/api/v3/sauc/bigmodel'
|
||||
2026-03-21 22:13:15,843 - ERROR - Error in ASR execution: 401, message='Invalid response status', url='wss://openspeech.bytedance.com/api/v3/sauc/bigmodel'
|
||||
2026-03-21 22:19:22,676 - ERROR - Failed to connect to WebSocket: 403, message='Invalid response status', url='wss://openspeech.bytedance.com/api/v3/sauc/bigmodel'
|
||||
2026-03-21 22:19:22,677 - ERROR - Error in ASR execution: 403, message='Invalid response status', url='wss://openspeech.bytedance.com/api/v3/sauc/bigmodel'
|
||||
@ -8,15 +8,13 @@ import websockets
|
||||
|
||||
from services import realtime_voice_protocol as protocol
|
||||
from utils.settings import (
|
||||
VOLCENGINE_REALTIME_URL,
|
||||
VOLCENGINE_APP_ID,
|
||||
VOLCENGINE_ACCESS_KEY,
|
||||
VOLCENGINE_RESOURCE_ID,
|
||||
VOLCENGINE_APP_KEY,
|
||||
VOLCENGINE_DEFAULT_SPEAKER,
|
||||
VOLCENGINE_TTS_SAMPLE_RATE,
|
||||
)
|
||||
|
||||
VOLCENGINE_REALTIME_URL = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
|
||||
@ -48,8 +46,8 @@ class RealtimeDialogClient:
|
||||
return {
|
||||
"X-Api-App-ID": VOLCENGINE_APP_ID,
|
||||
"X-Api-Access-Key": VOLCENGINE_ACCESS_KEY,
|
||||
"X-Api-Resource-Id": VOLCENGINE_RESOURCE_ID,
|
||||
"X-Api-App-Key": VOLCENGINE_APP_KEY,
|
||||
"X-Api-Resource-Id": "volc.speech.dialog",
|
||||
"X-Api-App-Key": "PlgvMymc7f3tQnJ6",
|
||||
"X-Api-Connect-Id": self._connect_id,
|
||||
}
|
||||
|
||||
|
||||
250
services/streaming_asr_client.py
Normal file
250
services/streaming_asr_client.py
Normal file
@ -0,0 +1,250 @@
|
||||
import gzip
|
||||
import json
|
||||
import struct
|
||||
import uuid
|
||||
import logging
|
||||
from typing import AsyncGenerator, Tuple
|
||||
|
||||
import websockets
|
||||
|
||||
from utils.settings import (
|
||||
VOLCENGINE_ACCESS_KEY,
|
||||
VOLCENGINE_APP_ID,
|
||||
)
|
||||
|
||||
VOLCENGINE_ASR_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async"
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
# Protocol constants (v3/sauc)
|
||||
PROTOCOL_VERSION = 0b0001
|
||||
HEADER_SIZE = 0b0001
|
||||
|
||||
# Message types
|
||||
FULL_CLIENT_REQUEST = 0b0001
|
||||
AUDIO_ONLY_REQUEST = 0b0010
|
||||
FULL_SERVER_RESPONSE = 0b1001
|
||||
SERVER_ERROR_RESPONSE = 0b1111
|
||||
|
||||
# Flags
|
||||
POS_SEQUENCE = 0b0001
|
||||
NEG_SEQUENCE = 0b0010
|
||||
NEG_WITH_SEQUENCE = 0b0011
|
||||
|
||||
# Serialization / Compression
|
||||
JSON_SERIAL = 0b0001
|
||||
GZIP_COMPRESS = 0b0001
|
||||
|
||||
|
||||
def _build_header(msg_type: int, flags: int = POS_SEQUENCE,
|
||||
serial: int = JSON_SERIAL, compress: int = GZIP_COMPRESS) -> bytearray:
|
||||
header = bytearray(4)
|
||||
header[0] = (PROTOCOL_VERSION << 4) | HEADER_SIZE
|
||||
header[1] = (msg_type << 4) | flags
|
||||
header[2] = (serial << 4) | compress
|
||||
header[3] = 0x00
|
||||
return header
|
||||
|
||||
|
||||
class StreamingASRClient:
|
||||
"""Volcengine v3/sauc/bigmodel streaming ASR client."""
|
||||
|
||||
def __init__(self, uid: str = "voice_lite"):
|
||||
self._uid = uid
|
||||
self._ws = None
|
||||
self._seq = 1
|
||||
|
||||
def _build_config(self) -> dict:
|
||||
return {
|
||||
"user": {
|
||||
"uid": self._uid,
|
||||
},
|
||||
"audio": {
|
||||
"format": "pcm",
|
||||
"codec": "raw",
|
||||
"rate": 24000,
|
||||
"bits": 16,
|
||||
"channel": 1,
|
||||
},
|
||||
"request": {
|
||||
"model_name": "bigmodel",
|
||||
"enable_itn": True,
|
||||
"enable_punc": True,
|
||||
"enable_ddc": True,
|
||||
"show_utterances": True,
|
||||
"enable_nonstream": False,
|
||||
},
|
||||
}
|
||||
|
||||
def _build_auth_headers(self) -> dict:
|
||||
return {
|
||||
"X-Api-Resource-Id": "volc.seedasr.sauc.duration",
|
||||
"X-Api-Connect-Id": str(uuid.uuid4()),
|
||||
"X-Api-Access-Key": VOLCENGINE_ACCESS_KEY,
|
||||
"X-Api-App-Key": VOLCENGINE_APP_ID,
|
||||
}
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to ASR WebSocket and send initial full_client_request."""
|
||||
headers = self._build_auth_headers()
|
||||
logger.info(f"[ASR] Connecting to {VOLCENGINE_ASR_URL} with headers: {headers}")
|
||||
self._ws = await websockets.connect(
|
||||
VOLCENGINE_ASR_URL,
|
||||
additional_headers=headers,
|
||||
ping_interval=None,
|
||||
proxy=None
|
||||
)
|
||||
logger.info(f"[ASR] Connected to {VOLCENGINE_ASR_URL}")
|
||||
|
||||
# Send full_client_request with config (seq=1)
|
||||
self._seq = 1
|
||||
config = self._build_config()
|
||||
config_bytes = gzip.compress(json.dumps(config).encode())
|
||||
|
||||
frame = bytearray(_build_header(FULL_CLIENT_REQUEST, POS_SEQUENCE, JSON_SERIAL, GZIP_COMPRESS))
|
||||
frame.extend(struct.pack('>i', self._seq))
|
||||
frame.extend(struct.pack('>I', len(config_bytes)))
|
||||
frame.extend(config_bytes)
|
||||
self._seq += 1
|
||||
|
||||
await self._ws.send(bytes(frame))
|
||||
|
||||
# Wait for server ack
|
||||
resp = await self._ws.recv()
|
||||
parsed = self._parse_response(resp)
|
||||
if parsed and parsed.get("code", 0) != 0:
|
||||
raise ConnectionError(f"[ASR] Server rejected config: {parsed}")
|
||||
logger.info(f"[ASR] Config accepted, ready for audio")
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send an audio chunk to ASR with sequence number."""
|
||||
if not self._ws:
|
||||
return
|
||||
compressed = gzip.compress(chunk)
|
||||
frame = bytearray(_build_header(AUDIO_ONLY_REQUEST, POS_SEQUENCE, JSON_SERIAL, GZIP_COMPRESS))
|
||||
frame.extend(struct.pack('>i', self._seq))
|
||||
frame.extend(struct.pack('>I', len(compressed)))
|
||||
frame.extend(compressed)
|
||||
self._seq += 1
|
||||
await self._ws.send(bytes(frame))
|
||||
|
||||
async def send_finish(self) -> None:
|
||||
"""Send last audio frame with negative sequence to signal end."""
|
||||
if not self._ws:
|
||||
return
|
||||
payload = gzip.compress(b'')
|
||||
frame = bytearray(_build_header(AUDIO_ONLY_REQUEST, NEG_WITH_SEQUENCE, JSON_SERIAL, GZIP_COMPRESS))
|
||||
frame.extend(struct.pack('>i', -self._seq))
|
||||
frame.extend(struct.pack('>I', len(payload)))
|
||||
frame.extend(payload)
|
||||
await self._ws.send(bytes(frame))
|
||||
|
||||
async def receive_results(self) -> AsyncGenerator[Tuple[str, bool], None]:
|
||||
"""Yield (text, is_last) tuples from ASR responses."""
|
||||
if not self._ws:
|
||||
return
|
||||
try:
|
||||
async for message in self._ws:
|
||||
if isinstance(message, str):
|
||||
logger.info(f"[ASR] Received text message: {message[:200]}")
|
||||
continue
|
||||
parsed = self._parse_response(message)
|
||||
logger.info(f"[ASR] Received binary ({len(message)} bytes), parsed: {parsed}")
|
||||
if parsed is None:
|
||||
continue
|
||||
|
||||
code = parsed.get("code", 0)
|
||||
if code != 0:
|
||||
logger.warning(f"[ASR] Server error: {parsed}")
|
||||
return
|
||||
|
||||
is_last = parsed.get("is_last", False)
|
||||
payload_msg = parsed.get("payload_msg")
|
||||
|
||||
if payload_msg and isinstance(payload_msg, dict):
|
||||
text = self._extract_text(payload_msg)
|
||||
if text:
|
||||
yield (text, is_last)
|
||||
|
||||
if is_last:
|
||||
return
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.info("[ASR] Connection closed")
|
||||
|
||||
@staticmethod
|
||||
def _extract_text(payload: dict) -> str:
|
||||
"""Extract recognized text from payload."""
|
||||
result = payload.get("result")
|
||||
if not result or not isinstance(result, dict):
|
||||
return ""
|
||||
|
||||
# Try utterances first (show_utterances=True)
|
||||
utterances = result.get("utterances", [])
|
||||
if utterances:
|
||||
parts = []
|
||||
for utt in utterances:
|
||||
text = utt.get("text", "")
|
||||
if text:
|
||||
parts.append(text)
|
||||
return "".join(parts)
|
||||
|
||||
# Fallback to result.text
|
||||
text = result.get("text", "")
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
return ""
|
||||
|
||||
def _parse_response(self, data: bytes) -> dict:
|
||||
"""Parse binary ASR response into a dict."""
|
||||
if len(data) < 4:
|
||||
return None
|
||||
|
||||
msg_type = data[1] >> 4
|
||||
msg_flags = data[1] & 0x0f
|
||||
serial_method = data[2] >> 4
|
||||
compression = data[2] & 0x0f
|
||||
|
||||
header_size = data[0] & 0x0f
|
||||
payload = data[header_size * 4:]
|
||||
|
||||
result = {"code": 0, "is_last": False}
|
||||
|
||||
# Parse sequence and last flag
|
||||
if msg_flags & 0x01: # has sequence
|
||||
result["sequence"] = struct.unpack('>i', payload[:4])[0]
|
||||
payload = payload[4:]
|
||||
if msg_flags & 0x02: # is last package
|
||||
result["is_last"] = True
|
||||
|
||||
if msg_type == SERVER_ERROR_RESPONSE:
|
||||
result["code"] = struct.unpack('>i', payload[:4])[0]
|
||||
payload_size = struct.unpack('>I', payload[4:8])[0]
|
||||
payload = payload[8:]
|
||||
elif msg_type == FULL_SERVER_RESPONSE:
|
||||
payload_size = struct.unpack('>I', payload[:4])[0]
|
||||
payload = payload[4:]
|
||||
else:
|
||||
return result
|
||||
|
||||
if not payload:
|
||||
return result
|
||||
|
||||
if compression == GZIP_COMPRESS:
|
||||
try:
|
||||
payload = gzip.decompress(payload)
|
||||
except Exception:
|
||||
return result
|
||||
|
||||
if serial_method == JSON_SERIAL:
|
||||
try:
|
||||
result["payload_msg"] = json.loads(payload.decode('utf-8'))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return result
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._ws:
|
||||
logger.info("[ASR] Closing connection")
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
158
services/streaming_tts_client.py
Normal file
158
services/streaming_tts_client.py
Normal file
@ -0,0 +1,158 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
|
||||
from utils.settings import (
|
||||
VOLCENGINE_APP_ID,
|
||||
VOLCENGINE_ACCESS_KEY,
|
||||
VOLCENGINE_DEFAULT_SPEAKER,
|
||||
)
|
||||
|
||||
VOLCENGINE_TTS_URL= "https://openspeech.bytedance.com/api/v3/tts/unidirectional/sse"
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
# Regex to detect text that is only emoji/whitespace (no speakable content)
|
||||
_EMOJI_ONLY_RE = re.compile(
|
||||
r'^[\s\U00002600-\U000027BF\U0001F300-\U0001FAFF\U0000FE00-\U0000FE0F\U0000200D]*$'
|
||||
)
|
||||
|
||||
|
||||
def convert_pcm_s16_to_f32(pcm_data: bytes) -> bytes:
|
||||
"""Convert PCM int16 audio to float32 PCM for frontend playback."""
|
||||
samples = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
if len(samples) == 0:
|
||||
return b''
|
||||
return samples.astype(np.float32).tobytes()
|
||||
|
||||
|
||||
class StreamingTTSClient:
|
||||
"""Volcengine v3/tts/unidirectional/sse streaming TTS client."""
|
||||
|
||||
def __init__(self, speaker: str = ""):
|
||||
self._speaker = speaker or VOLCENGINE_DEFAULT_SPEAKER
|
||||
|
||||
async def synthesize(self, text: str):
|
||||
"""
|
||||
Synthesize text to audio via SSE streaming.
|
||||
Yields 24kHz float32 PCM audio chunks.
|
||||
"""
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
# Skip pure emoji text
|
||||
if _EMOJI_ONLY_RE.match(text):
|
||||
logger.info(f"[TTS] Skipping emoji-only text: '{text}'")
|
||||
return
|
||||
|
||||
headers = {
|
||||
"X-Api-App-Id": VOLCENGINE_APP_ID,
|
||||
"X-Api-Access-Key": VOLCENGINE_ACCESS_KEY,
|
||||
"X-Api-Resource-Id": "seed-tts-2.0",
|
||||
"Content-Type": "application/json",
|
||||
"Connection": "keep-alive",
|
||||
}
|
||||
|
||||
body = {
|
||||
"user": {
|
||||
"uid": str(uuid.uuid4()),
|
||||
},
|
||||
"req_params": {
|
||||
"text": text,
|
||||
"speaker": self._speaker,
|
||||
"audio_params": {
|
||||
"format": "pcm",
|
||||
"sample_rate": 24000,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
logger.info(f"[TTS] Requesting: speaker={self._speaker}, text='{text[:50]}'")
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, read=60.0)) as client:
|
||||
async with client.stream("POST", VOLCENGINE_TTS_URL, headers=headers, json=body) as response:
|
||||
logger.info(f"[TTS] Response status: {response.status_code}")
|
||||
if response.status_code != 200:
|
||||
error_body = await response.aread()
|
||||
logger.error(f"[TTS] HTTP {response.status_code}: {error_body.decode('utf-8', errors='replace')}")
|
||||
return
|
||||
|
||||
chunk_count = 0
|
||||
# Parse SSE format: lines prefixed with "event:", "data:", separated by blank lines
|
||||
current_event = ""
|
||||
current_data = ""
|
||||
raw_logged = False
|
||||
|
||||
async for raw_line in response.aiter_lines():
|
||||
if not raw_logged:
|
||||
logger.info(f"[TTS] First SSE line: {raw_line[:200]}")
|
||||
raw_logged = True
|
||||
|
||||
line = raw_line.strip()
|
||||
|
||||
if line == "":
|
||||
# Blank line = end of one SSE event
|
||||
if current_data:
|
||||
async for audio in self._process_sse_data(current_data):
|
||||
chunk_count += 1
|
||||
yield audio
|
||||
current_event = ""
|
||||
current_data = ""
|
||||
continue
|
||||
|
||||
if line.startswith(":"):
|
||||
# SSE comment, skip
|
||||
continue
|
||||
|
||||
if ":" in line:
|
||||
field, value = line.split(":", 1)
|
||||
value = value.lstrip()
|
||||
if field == "event":
|
||||
current_event = value
|
||||
elif field == "data":
|
||||
current_data += value + "\n"
|
||||
|
||||
# Handle remaining data without trailing blank line
|
||||
if current_data:
|
||||
async for audio in self._process_sse_data(current_data):
|
||||
chunk_count += 1
|
||||
yield audio
|
||||
|
||||
logger.info(f"[TTS] Stream done, yielded {chunk_count} audio chunks")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[TTS] Error: {e}", exc_info=True)
|
||||
|
||||
async def _process_sse_data(self, data_str: str):
|
||||
"""Parse SSE data field and yield audio chunks if present."""
|
||||
data_str = data_str.rstrip("\n")
|
||||
if not data_str:
|
||||
return
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"[TTS] Non-JSON SSE data: {data_str[:100]}")
|
||||
return
|
||||
|
||||
code = data.get("code", 0)
|
||||
|
||||
if code == 0 and data.get("data"):
|
||||
# Audio data chunk
|
||||
pcm_raw = base64.b64decode(data["data"])
|
||||
pcm_f32 = convert_pcm_s16_to_f32(pcm_raw)
|
||||
if pcm_f32:
|
||||
yield pcm_f32
|
||||
|
||||
elif code == 20000000:
|
||||
# End of stream
|
||||
logger.info(f"[TTS] End signal received")
|
||||
return
|
||||
|
||||
elif code > 0:
|
||||
error_msg = data.get("message", "Unknown TTS error")
|
||||
logger.error(f"[TTS] Error code={code}: {error_msg}")
|
||||
return
|
||||
407
services/voice_lite_session.py
Normal file
407
services/voice_lite_session.py
Normal file
@ -0,0 +1,407 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import struct
|
||||
import uuid
|
||||
from typing import Optional, Callable, Awaitable
|
||||
|
||||
from services.streaming_asr_client import StreamingASRClient
|
||||
from services.streaming_tts_client import StreamingTTSClient
|
||||
from services.voice_utils import (
|
||||
StreamTagFilter,
|
||||
clean_markdown,
|
||||
stream_v3_agent,
|
||||
SENTENCE_END_RE,
|
||||
)
|
||||
from utils.settings import VOICE_LITE_SILENCE_TIMEOUT
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
|
||||
class VoiceLiteSession:
|
||||
"""Voice Lite session: ASR -> Agent -> TTS pipeline (cheaper alternative)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot_id: str,
|
||||
session_id: Optional[str] = None,
|
||||
user_identifier: Optional[str] = None,
|
||||
on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None,
|
||||
on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None,
|
||||
on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None,
|
||||
on_agent_stream: Optional[Callable[[str], Awaitable[None]]] = None,
|
||||
on_llm_text: Optional[Callable[[str], Awaitable[None]]] = None,
|
||||
on_status: Optional[Callable[[str], Awaitable[None]]] = None,
|
||||
on_error: Optional[Callable[[str], Awaitable[None]]] = None,
|
||||
):
|
||||
self.bot_id = bot_id
|
||||
self.session_id = session_id or str(uuid.uuid4())
|
||||
self.user_identifier = user_identifier or ""
|
||||
|
||||
self._bot_config: dict = {}
|
||||
self._speaker: str = ""
|
||||
|
||||
# Callbacks
|
||||
self._on_audio = on_audio
|
||||
self._on_asr_text = on_asr_text
|
||||
self._on_agent_result = on_agent_result
|
||||
self._on_agent_stream = on_agent_stream
|
||||
self._on_llm_text = on_llm_text
|
||||
self._on_status = on_status
|
||||
self._on_error = on_error
|
||||
|
||||
self._running = False
|
||||
self._asr_client: Optional[StreamingASRClient] = None
|
||||
self._asr_receive_task: Optional[asyncio.Task] = None
|
||||
self._agent_task: Optional[asyncio.Task] = None
|
||||
|
||||
# Silence timeout tracking
|
||||
self._last_asr_time: float = 0
|
||||
self._silence_timer_task: Optional[asyncio.Task] = None
|
||||
self._current_asr_text: str = ""
|
||||
self._last_text_change_time: float = 0
|
||||
self._last_changed_text: str = ""
|
||||
self._last_asr_emit_time: float = 0
|
||||
self._utterance_lock = asyncio.Lock()
|
||||
|
||||
# VAD (Voice Activity Detection) state
|
||||
self._vad_speaking = False # Whether user is currently speaking
|
||||
self._vad_silence_start: float = 0 # When silence started
|
||||
self._vad_finish_task: Optional[asyncio.Task] = None
|
||||
self._pre_buffer: list = [] # Buffer audio before VAD triggers
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Fetch bot config, mark session as running."""
|
||||
from utils.fastapi_utils import fetch_bot_config_from_db
|
||||
|
||||
self._bot_config = await fetch_bot_config_from_db(self.bot_id, self.user_identifier)
|
||||
self._speaker = self._bot_config.get("voice_speaker", "")
|
||||
|
||||
self._running = True
|
||||
await self._emit_status("ready")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Gracefully stop the session."""
|
||||
self._running = False
|
||||
|
||||
if self._vad_finish_task and not self._vad_finish_task.done():
|
||||
self._vad_finish_task.cancel()
|
||||
|
||||
if self._silence_timer_task and not self._silence_timer_task.done():
|
||||
self._silence_timer_task.cancel()
|
||||
|
||||
if self._agent_task and not self._agent_task.done():
|
||||
from utils.cancel_manager import trigger_cancel
|
||||
trigger_cancel(self.session_id)
|
||||
self._agent_task.cancel()
|
||||
|
||||
if self._asr_receive_task and not self._asr_receive_task.done():
|
||||
self._asr_receive_task.cancel()
|
||||
|
||||
if self._asr_client:
|
||||
try:
|
||||
await self._asr_client.send_finish()
|
||||
except Exception:
|
||||
pass
|
||||
await self._asr_client.close()
|
||||
|
||||
# VAD configuration
|
||||
VAD_ENERGY_THRESHOLD = 500 # RMS energy threshold for voice detection
|
||||
VAD_SILENCE_DURATION = 1.5 # Seconds of silence before sending finish
|
||||
VAD_PRE_BUFFER_SIZE = 5 # Number of audio chunks to buffer before VAD triggers
|
||||
|
||||
_audio_chunk_count = 0
|
||||
|
||||
@staticmethod
|
||||
def _calc_rms(pcm_data: bytes) -> float:
|
||||
"""Calculate RMS energy of 16-bit PCM audio."""
|
||||
if len(pcm_data) < 2:
|
||||
return 0.0
|
||||
n_samples = len(pcm_data) // 2
|
||||
samples = struct.unpack(f'<{n_samples}h', pcm_data[:n_samples * 2])
|
||||
if not samples:
|
||||
return 0.0
|
||||
sum_sq = sum(s * s for s in samples)
|
||||
return (sum_sq / n_samples) ** 0.5
|
||||
|
||||
async def handle_audio(self, audio_data: bytes) -> None:
|
||||
"""Forward user audio to ASR with VAD gating. Lazy-connect on speech start."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._audio_chunk_count += 1
|
||||
rms = self._calc_rms(audio_data)
|
||||
has_voice = rms > self.VAD_ENERGY_THRESHOLD
|
||||
now = asyncio.get_event_loop().time()
|
||||
|
||||
if has_voice:
|
||||
# Cancel any pending finish
|
||||
if self._vad_finish_task and not self._vad_finish_task.done():
|
||||
self._vad_finish_task.cancel()
|
||||
self._vad_finish_task = None
|
||||
|
||||
if not self._vad_speaking:
|
||||
# Speech just started — connect ASR
|
||||
self._vad_speaking = True
|
||||
logger.info(f"[VoiceLite] VAD: speech started (rms={rms:.0f}), connecting ASR...")
|
||||
try:
|
||||
await self._connect_asr()
|
||||
# Send buffered pre-speech audio
|
||||
for buffered in self._pre_buffer:
|
||||
await self._asr_client.send_audio(buffered)
|
||||
self._pre_buffer.clear()
|
||||
except Exception as e:
|
||||
logger.error(f"[VoiceLite] VAD: ASR connect failed: {e}", exc_info=True)
|
||||
self._vad_speaking = False
|
||||
return
|
||||
|
||||
# Send current chunk
|
||||
if self._asr_client:
|
||||
try:
|
||||
await self._asr_client.send_audio(audio_data)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._vad_silence_start = 0
|
||||
else:
|
||||
if self._vad_speaking:
|
||||
# Brief silence while speaking — keep sending for ASR context
|
||||
if self._asr_client:
|
||||
try:
|
||||
await self._asr_client.send_audio(audio_data)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if self._vad_silence_start == 0:
|
||||
self._vad_silence_start = now
|
||||
|
||||
# Silence exceeded threshold -> send finish
|
||||
if (now - self._vad_silence_start) >= self.VAD_SILENCE_DURATION:
|
||||
if not self._vad_finish_task or self._vad_finish_task.done():
|
||||
self._vad_finish_task = asyncio.create_task(self._vad_send_finish())
|
||||
else:
|
||||
# Not speaking — buffer recent audio for pre-speech context
|
||||
self._pre_buffer.append(audio_data)
|
||||
if len(self._pre_buffer) > self.VAD_PRE_BUFFER_SIZE:
|
||||
self._pre_buffer.pop(0)
|
||||
|
||||
async def _vad_send_finish(self) -> None:
|
||||
"""Send finish signal to ASR after silence detected."""
|
||||
logger.info(f"[VoiceLite] VAD: silence detected, sending finish to ASR")
|
||||
self._vad_speaking = False
|
||||
self._vad_silence_start = 0
|
||||
if self._asr_client:
|
||||
try:
|
||||
await self._asr_client.send_finish()
|
||||
except Exception as e:
|
||||
logger.warning(f"[VoiceLite] VAD: send_finish failed: {e}")
|
||||
|
||||
async def handle_text(self, text: str) -> None:
|
||||
"""Handle direct text input - bypass ASR and go straight to agent."""
|
||||
if not self._running:
|
||||
return
|
||||
await self._interrupt_current()
|
||||
self._agent_task = asyncio.create_task(self._process_utterance(text))
|
||||
|
||||
async def _connect_asr(self) -> None:
|
||||
"""Create and connect a new ASR client, start receive loop."""
|
||||
if self._asr_client:
|
||||
try:
|
||||
await self._asr_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
if self._asr_receive_task and not self._asr_receive_task.done():
|
||||
self._asr_receive_task.cancel()
|
||||
|
||||
self._asr_client = StreamingASRClient(uid=self.user_identifier or "voice_lite")
|
||||
await self._asr_client.connect()
|
||||
logger.info(f"[VoiceLite] ASR client connected")
|
||||
|
||||
# Start receive loop for this ASR session
|
||||
self._asr_receive_task = asyncio.create_task(self._asr_receive_loop())
|
||||
|
||||
async def _asr_receive_loop(self) -> None:
|
||||
"""Receive ASR results from the current ASR session."""
|
||||
try:
|
||||
_listening_emitted = False
|
||||
async for text, is_final in self._asr_client.receive_results():
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
if not _listening_emitted:
|
||||
await self._emit_status("listening")
|
||||
_listening_emitted = True
|
||||
|
||||
now = asyncio.get_event_loop().time()
|
||||
self._last_asr_time = now
|
||||
self._current_asr_text = text
|
||||
|
||||
# Track text changes for stability detection
|
||||
if text != self._last_changed_text:
|
||||
self._last_changed_text = text
|
||||
self._last_text_change_time = now
|
||||
|
||||
# Reset stability timer on every result
|
||||
self._reset_silence_timer()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
if self._running:
|
||||
logger.warning(f"[VoiceLite] ASR session ended: {e}")
|
||||
finally:
|
||||
logger.info(f"[VoiceLite] ASR session done")
|
||||
# Clean up ASR client after session ends
|
||||
if self._asr_client:
|
||||
try:
|
||||
await self._asr_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._asr_client = None
|
||||
|
||||
def _reset_silence_timer(self) -> None:
|
||||
"""Reset the silence timeout timer."""
|
||||
if self._silence_timer_task and not self._silence_timer_task.done():
|
||||
self._silence_timer_task.cancel()
|
||||
self._silence_timer_task = asyncio.create_task(self._silence_timeout())
|
||||
|
||||
async def _silence_timeout(self) -> None:
|
||||
"""Wait for silence timeout, then check if text has been stable."""
|
||||
try:
|
||||
await asyncio.sleep(VOICE_LITE_SILENCE_TIMEOUT)
|
||||
if not self._running:
|
||||
return
|
||||
# Check if text has been stable (unchanged) for the timeout period
|
||||
now = asyncio.get_event_loop().time()
|
||||
if (self._current_asr_text
|
||||
and (now - self._last_text_change_time) >= VOICE_LITE_SILENCE_TIMEOUT):
|
||||
logger.info(f"[VoiceLite] Text stable for {VOICE_LITE_SILENCE_TIMEOUT}s, processing: '{self._current_asr_text}'")
|
||||
await self._on_utterance_complete(self._current_asr_text)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _on_utterance_complete(self, text: str) -> None:
|
||||
"""Called when a complete utterance is detected."""
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
async with self._utterance_lock:
|
||||
# Cancel silence timer
|
||||
if self._silence_timer_task and not self._silence_timer_task.done():
|
||||
self._silence_timer_task.cancel()
|
||||
|
||||
# Interrupt any in-progress agent+TTS
|
||||
await self._interrupt_current()
|
||||
|
||||
# Send final ASR text to frontend
|
||||
if self._on_asr_text:
|
||||
await self._on_asr_text(text)
|
||||
|
||||
self._current_asr_text = ""
|
||||
self._last_changed_text = ""
|
||||
self._agent_task = asyncio.create_task(self._process_utterance(text))
|
||||
|
||||
async def _interrupt_current(self) -> None:
|
||||
"""Cancel current agent+TTS task if running."""
|
||||
if self._agent_task and not self._agent_task.done():
|
||||
logger.info(f"[VoiceLite] Interrupting previous agent task")
|
||||
from utils.cancel_manager import trigger_cancel
|
||||
trigger_cancel(self.session_id)
|
||||
self._agent_task.cancel()
|
||||
try:
|
||||
await self._agent_task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
self._agent_task = None
|
||||
|
||||
async def _process_utterance(self, text: str) -> None:
|
||||
"""Process a complete utterance: agent -> TTS pipeline."""
|
||||
try:
|
||||
logger.info(f"[VoiceLite] Processing utterance: '{text}'")
|
||||
await self._emit_status("thinking")
|
||||
|
||||
accumulated_text = []
|
||||
sentence_buf = ""
|
||||
tag_filter = StreamTagFilter()
|
||||
tts_client = StreamingTTSClient(speaker=self._speaker)
|
||||
speaking = False
|
||||
|
||||
async for chunk in stream_v3_agent(
|
||||
user_text=text,
|
||||
bot_id=self.bot_id,
|
||||
bot_config=self._bot_config,
|
||||
session_id=self.session_id,
|
||||
user_identifier=self.user_identifier,
|
||||
):
|
||||
accumulated_text.append(chunk)
|
||||
|
||||
if self._on_agent_stream:
|
||||
await self._on_agent_stream(chunk)
|
||||
|
||||
passthrough = tag_filter.feed(chunk)
|
||||
|
||||
if not passthrough:
|
||||
if tag_filter.answer_ended and sentence_buf:
|
||||
flush = clean_markdown(sentence_buf.strip())
|
||||
sentence_buf = ""
|
||||
if flush:
|
||||
if not speaking:
|
||||
await self._emit_status("speaking")
|
||||
speaking = True
|
||||
await self._send_tts(tts_client, flush)
|
||||
continue
|
||||
|
||||
sentence_buf += passthrough
|
||||
|
||||
while True:
|
||||
match = SENTENCE_END_RE.search(sentence_buf)
|
||||
if not match:
|
||||
break
|
||||
end_pos = match.end()
|
||||
sentence = clean_markdown(sentence_buf[:end_pos].strip())
|
||||
sentence_buf = sentence_buf[end_pos:]
|
||||
|
||||
if sentence:
|
||||
if not speaking:
|
||||
await self._emit_status("speaking")
|
||||
speaking = True
|
||||
await self._send_tts(tts_client, sentence)
|
||||
|
||||
# Handle remaining text
|
||||
remaining = clean_markdown(sentence_buf.strip())
|
||||
if remaining:
|
||||
if not speaking:
|
||||
await self._emit_status("speaking")
|
||||
speaking = True
|
||||
await self._send_tts(tts_client, remaining)
|
||||
|
||||
# Log full agent result (not sent to frontend, already streamed)
|
||||
full_result = "".join(accumulated_text)
|
||||
logger.info(f"[VoiceLite] Agent done ({len(full_result)} chars)")
|
||||
|
||||
# Notify frontend that agent text stream is complete
|
||||
if self._on_agent_result:
|
||||
await self._on_agent_result(full_result)
|
||||
|
||||
await self._emit_status("idle")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[VoiceLite] Agent task cancelled (user interrupted)")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[VoiceLite] Error processing utterance: {e}", exc_info=True)
|
||||
await self._emit_error(f"Processing failed: {str(e)}")
|
||||
|
||||
async def _send_tts(self, tts_client: StreamingTTSClient, sentence: str) -> None:
    """Synthesize one sentence via the streaming TTS client and forward audio.

    Args:
        tts_client: Streaming TTS client whose ``synthesize`` yields audio chunks.
        sentence: Plain text to speak (markdown-cleaned upstream).
    """
    # Log only a prefix so very long sentences don't flood the log.
    logger.info(f"[VoiceLite] TTS sentence: '{sentence[:80]}'")
    async for audio_chunk in tts_client.synthesize(sentence):
        # Forward each synthesized chunk to the registered audio callback, if any.
        if self._on_audio:
            await self._on_audio(audio_chunk)
|
||||
|
||||
async def _emit_status(self, status: str) -> None:
    """Forward a status string (e.g. "speaking", "idle") to the status callback, if set."""
    callback = self._on_status
    if callback:
        await callback(status)
|
||||
|
||||
async def _emit_error(self, message: str) -> None:
    """Forward an error message to the error callback, if set."""
    callback = self._on_error
    if callback:
        await callback(message)
|
||||
@ -1,96 +1,14 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Optional, Callable, Awaitable, AsyncGenerator
|
||||
from typing import Optional, Callable, Awaitable
|
||||
|
||||
from services.realtime_voice_client import RealtimeDialogClient
|
||||
from services.voice_utils import StreamTagFilter, clean_markdown, stream_v3_agent, SENTENCE_END_RE
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
|
||||
class _StreamTagFilter:
    """
    Filters streaming text based on tag blocks.
    Only passes through content inside [ANSWER] blocks.
    If no tags are found at all, passes through everything (fallback).
    Skips content inside [TOOL_CALL], [TOOL_RESPONSE], [THINK], [SOURCE], etc.

    Stateful: feed() is called repeatedly with stream chunks; a tag split
    across chunk boundaries (e.g. "[TOO" + "L_CALL]") is buffered in
    ``_pending`` until its closing bracket arrives.
    """

    # NOTE(review): SKIP_TAGS appears unused — skipping is driven purely by
    # "known tag that is not ANSWER" below. Consider removing it.
    SKIP_TAGS = {"TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE"}
    KNOWN_TAGS = {"ANSWER", "TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE", "PREAMBLE", "SUMMARY"}

    def __init__(self):
        self.state = "idle"  # idle, answer, skip
        self.found_any_tag = False
        self._pending = ""  # buffer for partial tag like "[TOO..."
        self.answer_ended = False  # True when ANSWER block ends (e.g. hit [TOOL_CALL])

    def feed(self, chunk: str) -> str:
        """Feed a chunk, return text that should be passed to TTS."""
        # answer_ended is a per-feed flag: it reports only transitions that
        # happened during *this* call.
        self.answer_ended = False
        self._pending += chunk
        output = []

        while self._pending:
            if self.state in ("idle", "answer"):
                bracket_pos = self._pending.find("[")
                if bracket_pos == -1:
                    # No tag opener: pass text through while in ANSWER, or
                    # before any known tag has been seen (fallback mode).
                    if self.state == "answer" or not self.found_any_tag:
                        output.append(self._pending)
                    self._pending = ""
                else:
                    before = self._pending[:bracket_pos]
                    if before and (self.state == "answer" or not self.found_any_tag):
                        output.append(before)

                    close_pos = self._pending.find("]", bracket_pos)
                    if close_pos == -1:
                        # Incomplete tag — wait for next chunk
                        self._pending = self._pending[bracket_pos:]
                        break

                    tag_name = self._pending[bracket_pos + 1:close_pos]
                    self._pending = self._pending[close_pos + 1:]

                    if tag_name not in self.KNOWN_TAGS:
                        # Bracketed text that isn't a control tag (e.g. "[1]")
                        # is ordinary content: emit it verbatim when allowed.
                        if self.state == "answer" or not self.found_any_tag:
                            output.append(f"[{tag_name}]")
                        continue

                    self.found_any_tag = True
                    if tag_name == "ANSWER":
                        self.state = "answer"
                    else:
                        # Leaving an ANSWER block counts as "answer ended".
                        if self.state == "answer":
                            self.answer_ended = True
                        self.state = "skip"

            elif self.state == "skip":
                # In skip mode, plain text is discarded; we only scan for the
                # next complete tag to decide whether to resume.
                bracket_pos = self._pending.find("[")
                if bracket_pos == -1:
                    self._pending = ""
                else:
                    close_pos = self._pending.find("]", bracket_pos)
                    if close_pos == -1:
                        # Incomplete tag — keep it buffered for the next chunk.
                        self._pending = self._pending[bracket_pos:]
                        break

                    tag_name = self._pending[bracket_pos + 1:close_pos]
                    self._pending = self._pending[close_pos + 1:]

                    if tag_name not in self.KNOWN_TAGS:
                        continue

                    if tag_name == "ANSWER":
                        self.state = "answer"
                    else:
                        self.state = "skip"

        return "".join(output)
|
||||
|
||||
|
||||
class VoiceSession:
|
||||
"""Manages a single voice dialogue session lifecycle"""
|
||||
|
||||
@ -288,30 +206,6 @@ class VoiceSession:
|
||||
logger.error(f"[Voice] Server error: {error_msg}")
|
||||
await self._emit_error(error_msg)
|
||||
|
||||
# Sentence-ending punctuation pattern for splitting TTS
# (full-width 。!?; plus ASCII .!?; and newline).
_SENTENCE_END_RE = re.compile(r'[。!?;\n.!?;]')
# Markdown syntax to strip before TTS
# NOTE(review): this combined pattern appears unused in the visible code —
# _clean_markdown() below does the same cleanup with individual re.sub
# calls. Consider removing one of the two to avoid drift.
_MD_CLEAN_RE = re.compile(r'#{1,6}\s*|(?<!\w)\*{1,3}|(?<!\w)_{1,3}|\*{1,3}(?!\w)|_{1,3}(?!\w)|~~|`{1,3}|^>\s*|^\s*[-*+]\s+|^\s*\d+\.\s+|\[([^\]]*)\]\([^)]*\)|!\[([^\]]*)\]\([^)]*\)', re.MULTILINE)
|
||||
|
||||
@staticmethod
def _clean_markdown(text: str) -> str:
    """Strip Markdown formatting characters for TTS readability.

    Substitution order matters: link/image syntax must be unwrapped to its
    display text before emphasis markers are stripped.
    """
    # Replace links/images with their display text
    text = re.sub(r'!\[([^\]]*)\]\([^)]*\)', r'\1', text)
    text = re.sub(r'\[([^\]]*)\]\([^)]*\)', r'\1', text)
    # Remove headings, bold, italic, strikethrough, code marks, blockquote
    # NOTE(review): the heading pattern is unanchored, so a '#' mid-line is
    # also stripped, and '_' removal hits snake_case words — presumably
    # acceptable for speech output; confirm.
    text = re.sub(r'#{1,6}\s*', '', text)
    text = re.sub(r'\*{1,3}|_{1,3}|~~|`{1,3}', '', text)
    text = re.sub(r'^>\s*', '', text, flags=re.MULTILINE)
    # Remove list markers
    text = re.sub(r'^\s*[-*+]\s+', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
    # Remove horizontal rules
    text = re.sub(r'^[\s]*[-*_]{3,}[\s]*$', '', text, flags=re.MULTILINE)
    # Collapse extra whitespace
    text = re.sub(r'\n{2,}', '\n', text)
    return text.strip()
|
||||
|
||||
async def _on_asr_text_received(self, text: str) -> None:
|
||||
"""Called when ASR text is received — stream agent output, send TTS sentence by sentence"""
|
||||
if not text.strip():
|
||||
@ -324,9 +218,15 @@ class VoiceSession:
|
||||
accumulated_text = [] # full agent output for on_agent_result callback
|
||||
sentence_buf = "" # buffer for accumulating until sentence boundary
|
||||
tts_started = False # whether we've sent the first TTS chunk
|
||||
tag_filter = _StreamTagFilter()
|
||||
tag_filter = StreamTagFilter()
|
||||
|
||||
async for chunk in self._stream_v3_agent(text):
|
||||
async for chunk in stream_v3_agent(
|
||||
user_text=text,
|
||||
bot_id=self.bot_id,
|
||||
bot_config=self._bot_config,
|
||||
session_id=self.session_id,
|
||||
user_identifier=self.user_identifier,
|
||||
):
|
||||
accumulated_text.append(chunk)
|
||||
|
||||
if self._on_agent_stream:
|
||||
@ -341,7 +241,7 @@ class VoiceSession:
|
||||
flush = sentence_buf.strip()
|
||||
sentence_buf = ""
|
||||
if flush:
|
||||
flush = self._clean_markdown(flush)
|
||||
flush = clean_markdown(flush)
|
||||
if flush:
|
||||
if tts_started and self._tts_segment_done:
|
||||
logger.info(f"[Voice] TTS segment done, closing session and waiting for delivery (answer ended)")
|
||||
@ -371,7 +271,7 @@ class VoiceSession:
|
||||
|
||||
# Check for sentence boundaries and send complete sentences to TTS
|
||||
while True:
|
||||
match = self._SENTENCE_END_RE.search(sentence_buf)
|
||||
match = SENTENCE_END_RE.search(sentence_buf)
|
||||
if not match:
|
||||
break
|
||||
# Split at sentence boundary (include the punctuation)
|
||||
@ -380,7 +280,7 @@ class VoiceSession:
|
||||
sentence_buf = sentence_buf[end_pos:]
|
||||
|
||||
if sentence:
|
||||
sentence = self._clean_markdown(sentence)
|
||||
sentence = clean_markdown(sentence)
|
||||
if sentence:
|
||||
# If previous TTS segment completed (e.g. gap during tool call),
|
||||
# close old session, wait for TTS delivery to finish, then restart
|
||||
@ -410,7 +310,7 @@ class VoiceSession:
|
||||
# Handle remaining text in buffer (last sentence without ending punctuation)
|
||||
remaining = sentence_buf.strip()
|
||||
if remaining:
|
||||
remaining = self._clean_markdown(remaining)
|
||||
remaining = clean_markdown(remaining)
|
||||
if remaining:
|
||||
# If previous TTS segment completed, close and wait before restart
|
||||
if tts_started and self._tts_segment_done:
|
||||
@ -464,70 +364,6 @@ class VoiceSession:
|
||||
self._is_sending_chat_tts_text = False
|
||||
await self._emit_error(f"Agent call failed: {str(e)}")
|
||||
|
||||
async def _stream_v3_agent(self, user_text: str) -> AsyncGenerator[str, None]:
    """Call v3 agent API in streaming mode, yield text chunks as they arrive"""
    try:
        # Imported lazily — presumably to avoid circular imports at module
        # load time (routes.chat imports services in turn); confirm.
        from utils.api_models import ChatRequestV3, Message
        from utils.fastapi_utils import (
            process_messages,
            create_project_directory,
        )
        from agent.agent_config import AgentConfig
        from routes.chat import enhanced_generate_stream_response

        bot_config = self._bot_config
        language = bot_config.get("language", "zh")

        # A voice turn is sent as a single user message; any history handling
        # is expected to happen downstream via session_id.
        messages_obj = [Message(role="user", content=user_text)]

        request = ChatRequestV3(
            messages=messages_obj,
            bot_id=self.bot_id,
            stream=True,
            session_id=self.session_id,
            user_identifier=self.user_identifier,
        )

        project_dir = create_project_directory(
            bot_config.get("dataset_ids", []),
            self.bot_id,
            bot_config.get("skills", []),
        )

        processed_messages = process_messages(messages_obj, language)

        config = await AgentConfig.from_v3_request(
            request,
            bot_config,
            project_dir,
            processed_messages,
            language,
        )
        config.stream = True

        # Parse the SSE stream: only "data: ..." lines carry payloads, and
        # "[DONE]" terminates the stream.
        async for sse_line in enhanced_generate_stream_response(config):
            if not sse_line or not sse_line.startswith("data: "):
                continue
            data_str = sse_line.strip().removeprefix("data: ")
            if data_str == "[DONE]":
                break
            try:
                data = json.loads(data_str)
                choices = data.get("choices", [])
                if choices:
                    delta = choices[0].get("delta", {})
                    content = delta.get("content", "")
                    if content:
                        yield content
            except (json.JSONDecodeError, KeyError):
                # Malformed SSE payloads are skipped rather than aborting
                # the whole stream.
                continue

    except asyncio.CancelledError:
        logger.info(f"[Voice] v3 agent call cancelled")
        raise
    except Exception as e:
        # NOTE(review): errors are swallowed here — the generator simply
        # ends, so the caller sees a truncated/empty stream with no error
        # signal. Confirm this best-effort behavior is intended.
        logger.error(f"[Voice] Error calling v3 agent: {e}", exc_info=True)
|
||||
|
||||
@staticmethod
|
||||
def _extract_answer(agent_result: str) -> str:
|
||||
"""Extract the answer portion from agent result, stripping tags like [ANSWER], [THINK] etc."""
|
||||
|
||||
172
services/voice_utils.py
Normal file
172
services/voice_utils.py
Normal file
@ -0,0 +1,172 @@
|
||||
import json
|
||||
import re
|
||||
import logging
|
||||
from typing import Optional, AsyncGenerator
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
# Sentence-ending punctuation used to split streamed agent text into
# TTS-sized chunks: full-width 。!?; plus ASCII .!?; and newline.
# NOTE(review): ASCII '.' also matches decimals and abbreviations
# ("3.5", "e.g."), causing early splits — confirm this is acceptable.
SENTENCE_END_RE = re.compile(r'[。!?;\n.!?;]')
|
||||
|
||||
|
||||
class StreamTagFilter:
    """Incrementally filter a tagged text stream down to spoken content.

    Only text inside [ANSWER] blocks is passed through; content inside
    other known tag blocks ([TOOL_CALL], [THINK], [SOURCE], ...) is
    dropped. As a fallback, if no known tag has been seen yet, everything
    passes through. A tag split across chunk boundaries (e.g. "[TOO" +
    "L_CALL]") is buffered until its closing bracket arrives.
    """

    SKIP_TAGS = {"TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE"}
    KNOWN_TAGS = {"ANSWER", "TOOL_CALL", "TOOL_RESPONSE", "THINK", "SOURCE", "REFERENCE", "PREAMBLE", "SUMMARY"}

    def __init__(self):
        self.state = "idle"          # one of: idle, answer, skip
        self.found_any_tag = False   # once True, untagged text is no longer forwarded
        self._pending = ""           # carry-over buffer for a partially received tag
        self.answer_ended = False    # set per feed() when an [ANSWER] block closed

    def feed(self, chunk: str) -> str:
        """Feed a chunk, return text that should be passed to TTS."""
        # answer_ended reports only transitions during *this* call.
        self.answer_ended = False
        self._pending += chunk
        emitted = []

        while self._pending:
            in_skip = self.state == "skip"
            # Plain text is forwarded while inside [ANSWER], or before any
            # known tag has been seen (untagged-output fallback).
            emit_ok = not in_skip and (self.state == "answer" or not self.found_any_tag)

            open_idx = self._pending.find("[")
            if open_idx == -1:
                # No tag opener at all — flush (or drop) the rest.
                if emit_ok:
                    emitted.append(self._pending)
                self._pending = ""
                continue

            if emit_ok and open_idx:
                emitted.append(self._pending[:open_idx])

            close_idx = self._pending.find("]", open_idx)
            if close_idx == -1:
                # Opener without closer yet — buffer it for the next chunk.
                self._pending = self._pending[open_idx:]
                break

            tag = self._pending[open_idx + 1:close_idx]
            self._pending = self._pending[close_idx + 1:]

            if tag not in self.KNOWN_TAGS:
                # Bracketed text that isn't a control tag (e.g. "[1]") is
                # ordinary content: forward it verbatim when allowed.
                if emit_ok:
                    emitted.append(f"[{tag}]")
                continue

            if not in_skip:
                self.found_any_tag = True
            if tag == "ANSWER":
                self.state = "answer"
            else:
                # Leaving an ANSWER block counts as "answer ended".
                if not in_skip and self.state == "answer":
                    self.answer_ended = True
                self.state = "skip"

        return "".join(emitted)
|
||||
|
||||
|
||||
def clean_markdown(text: str) -> str:
    """Strip Markdown formatting characters for TTS readability.

    Applies a fixed sequence of regex substitutions: links/images are
    reduced to their display text first, then heading/emphasis/code
    markers, blockquotes, list markers and horizontal rules are removed,
    and finally runs of blank lines are collapsed.
    """
    # (pattern, replacement, flags) applied in order — order matters:
    # link/image syntax must be unwrapped before emphasis markers go.
    substitutions = (
        (r'!\[([^\]]*)\]\([^)]*\)', r'\1', 0),           # images -> alt text
        (r'\[([^\]]*)\]\([^)]*\)', r'\1', 0),            # links -> label
        (r'#{1,6}\s*', '', 0),                           # heading marks
        (r'\*{1,3}|_{1,3}|~~|`{1,3}', '', 0),            # bold/italic/strike/code
        (r'^>\s*', '', re.MULTILINE),                    # blockquotes
        (r'^\s*[-*+]\s+', '', re.MULTILINE),             # bullet markers
        (r'^\s*\d+\.\s+', '', re.MULTILINE),             # numbered markers
        (r'^[\s]*[-*_]{3,}[\s]*$', '', re.MULTILINE),    # horizontal rules
        (r'\n{2,}', '\n', 0),                            # collapse blank lines
    )
    for pattern, replacement, flags in substitutions:
        text = re.sub(pattern, replacement, text, flags=flags)
    return text.strip()
|
||||
|
||||
|
||||
async def stream_v3_agent(
    user_text: str,
    bot_id: str,
    bot_config: dict,
    session_id: str,
    user_identifier: str,
) -> AsyncGenerator[str, None]:
    """Call v3 agent API in streaming mode, yield text chunks as they arrive.

    Args:
        user_text: Recognized user utterance for this turn.
        bot_id: Target bot identifier.
        bot_config: Bot configuration dict (language, dataset_ids, skills, ...).
        session_id: Conversation identifier passed through to the agent.
        user_identifier: End-user identifier passed through to the agent.

    Yields:
        Incremental text chunks extracted from the agent's SSE stream.
    """
    import asyncio
    try:
        # Imported lazily — presumably to avoid circular imports at module
        # load time (routes.chat imports services in turn); confirm.
        from utils.api_models import ChatRequestV3, Message
        from utils.fastapi_utils import (
            process_messages,
            create_project_directory,
        )
        from agent.agent_config import AgentConfig
        from routes.chat import enhanced_generate_stream_response

        language = bot_config.get("language", "zh")
        # A voice turn is sent as a single user message; history handling
        # is expected downstream via session_id.
        messages_obj = [Message(role="user", content=user_text)]

        request = ChatRequestV3(
            messages=messages_obj,
            bot_id=bot_id,
            stream=True,
            session_id=session_id,
            user_identifier=user_identifier,
        )

        project_dir = create_project_directory(
            bot_config.get("dataset_ids", []),
            bot_id,
            bot_config.get("skills", []),
        )

        processed_messages = process_messages(messages_obj, language)

        config = await AgentConfig.from_v3_request(
            request,
            bot_config,
            project_dir,
            processed_messages,
            language,
        )
        config.stream = True

        # Parse the SSE stream: only "data: ..." lines carry payloads, and
        # "[DONE]" terminates the stream.
        async for sse_line in enhanced_generate_stream_response(config):
            if not sse_line or not sse_line.startswith("data: "):
                continue
            data_str = sse_line.strip().removeprefix("data: ")
            if data_str == "[DONE]":
                break
            try:
                data = json.loads(data_str)
                choices = data.get("choices", [])
                if choices:
                    delta = choices[0].get("delta", {})
                    content = delta.get("content", "")
                    if content:
                        yield content
            except (json.JSONDecodeError, KeyError):
                # Malformed SSE payloads are skipped rather than aborting
                # the whole stream.
                continue

    except asyncio.CancelledError:
        logger.info(f"[Voice] v3 agent call cancelled")
        raise
    except Exception as e:
        # NOTE(review): errors are swallowed — the generator simply ends,
        # so callers see a truncated/empty stream with no error signal.
        logger.error(f"[Voice] Error calling v3 agent: {e}", exc_info=True)
|
||||
@ -121,19 +121,20 @@ NEW_API_ADMIN_KEY = os.getenv("NEW_API_ADMIN_KEY", "")
|
||||
# ============================================================
# Volcengine Realtime Dialogue Configuration
# ============================================================
# NOTE(review): real-looking app IDs / access keys are committed here as
# defaults. They should be rotated and supplied only via environment
# variables (empty default + fail-fast). The 403 responses in run.log
# suggest a credential/resource mismatch — verify against the console.
VOLCENGINE_REALTIME_URL = os.getenv(
    "VOLCENGINE_REALTIME_URL",
    "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
)
VOLCENGINE_RESOURCE_ID = os.getenv("VOLCENGINE_RESOURCE_ID", "volc.speech.dialog")
VOLCENGINE_APP_KEY = os.getenv("VOLCENGINE_APP_KEY", "PlgvMymc7f3tQnJ6")
# Removed duplicate earlier assignments of the two names below — they were
# immediately shadowed by these values and never observable.
VOLCENGINE_APP_ID = os.getenv("VOLCENGINE_APP_ID", "2511880162")
VOLCENGINE_ACCESS_KEY = os.getenv("VOLCENGINE_ACCESS_KEY", "pjLbaqR1lHFfkv1xcJAYnvKV0HAvsBvt")
VOLCENGINE_DEFAULT_SPEAKER = os.getenv(
    "VOLCENGINE_DEFAULT_SPEAKER", "zh_male_yunzhou_jupiter_bigtts"
)
# TTS output sample rate in Hz.
VOLCENGINE_TTS_SAMPLE_RATE = int(os.getenv("VOLCENGINE_TTS_SAMPLE_RATE", "24000"))

# ============================================================
# Voice Lite Configuration (ASR + Agent + TTS pipeline)
# ============================================================
VOICE_DEFAULT_MODE = os.getenv("VOICE_DEFAULT_MODE", "lite")  # "realtime" | "lite"
# Silence timeout (seconds) - ASR considers user done speaking after this
VOICE_LITE_SILENCE_TIMEOUT = float(os.getenv("VOICE_LITE_SILENCE_TIMEOUT", "1.5"))
|
||||
|
||||
# ============================================================
|
||||
# Single Agent Mode Configuration
|
||||
# ============================================================
|
||||
|
||||
Loading…
Reference in New Issue
Block a user