qwen_agent/services/realtime_voice_client.py
2026-03-21 23:50:51 +08:00

205 lines
7.9 KiB
Python

import gzip
import json
import uuid
import logging
from typing import Dict, Any, Optional
import websockets
from services import realtime_voice_protocol as protocol
from utils.settings import (
VOLCENGINE_APP_ID,
VOLCENGINE_ACCESS_KEY,
VOLCENGINE_DEFAULT_SPEAKER,
VOLCENGINE_TTS_SAMPLE_RATE,
)
# WebSocket endpoint that RealtimeDialogClient.connect() dials.
VOLCENGINE_REALTIME_URL = (
    "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
)

# Module logger; uses the logger registered under the name "app".
logger = logging.getLogger("app")
class RealtimeDialogClient:
    """Volcengine Realtime Dialogue API WebSocket client.

    Every outgoing frame built by this class has the layout::

        header | event id (4 bytes, big-endian)
               | [session id length (4 bytes) | session id (UTF-8)]
               | payload length (4 bytes) | gzip-compressed payload

    The session-id segment is present for session-scoped events and omitted
    for the connection-level events (StartConnection / FinishConnection).
    The header bytes come from ``protocol.generate_header``.
    """

    # Event ids sent by this client.
    _EVENT_START_CONNECTION = 1
    _EVENT_FINISH_CONNECTION = 2
    _EVENT_START_SESSION = 100
    _EVENT_FINISH_SESSION = 102
    _EVENT_AUDIO = 200
    _EVENT_TTS_TEXT = 500
    _EVENT_TEXT_QUERY = 501
    _EVENT_RAG_TEXT = 502

    def __init__(
        self,
        session_id: str,
        speaker: Optional[str] = None,
        system_role: Optional[str] = None,
        speaking_style: Optional[str] = None,
        bot_name: Optional[str] = None,
        recv_timeout: int = 60,
        input_mod: str = "audio",
    ) -> None:
        """Store the session configuration; no network I/O happens here.

        Args:
            session_id: Identifier embedded in every session-scoped frame.
            speaker: TTS speaker; falls back to VOLCENGINE_DEFAULT_SPEAKER
                when empty or None.
            system_role: Forwarded as ``dialog.system_role`` ("" if None).
            speaking_style: Forwarded as ``dialog.speaking_style``
                ("" if None).
            bot_name: Forwarded as ``dialog.bot_name`` ("" if None).
            recv_timeout: Forwarded as ``dialog.extra.recv_timeout``.
            input_mod: Forwarded as ``dialog.extra.input_mod``.
        """
        self.session_id = session_id
        self.speaker = speaker or VOLCENGINE_DEFAULT_SPEAKER
        self.system_role = system_role or ""
        self.speaking_style = speaking_style or ""
        self.bot_name = bot_name or ""
        self.recv_timeout = recv_timeout
        self.input_mod = input_mod
        # Filled from the "X-Tt-Logid" handshake response header in connect().
        self.logid = ""
        # Live websocket connection; None until connect() succeeds and again
        # after close().
        self.ws: Optional[Any] = None
        self._connect_id = str(uuid.uuid4())

    def _build_headers(self) -> Dict[str, str]:
        """Return the HTTP headers sent with the WebSocket handshake."""
        return {
            "X-Api-App-ID": VOLCENGINE_APP_ID,
            "X-Api-Access-Key": VOLCENGINE_ACCESS_KEY,
            "X-Api-Resource-Id": "volc.speech.dialog",
            # NOTE(review): hard-coded key; presumably the fixed value the
            # service documents for this API - confirm before changing.
            "X-Api-App-Key": "PlgvMymc7f3tQnJ6",
            "X-Api-Connect-Id": self._connect_id,
        }

    def _build_session_params(self) -> Dict[str, Any]:
        """Return the JSON-serialisable StartSession payload."""
        return {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": self.speaker,
                "audio_config": {
                    "channel": 1,
                    "format": "pcm",
                    "sample_rate": VOLCENGINE_TTS_SAMPLE_RATE,
                },
            },
            "dialog": {
                "bot_name": self.bot_name,
                "system_role": self.system_role,
                "speaking_style": self.speaking_style,
                "extra": {
                    "strict_audit": False,
                    "recv_timeout": self.recv_timeout,
                    "input_mod": self.input_mod,
                    "enable_volc_websearch": False,
                    "enable_music": False,
                    "model": "1.2.1.1",
                },
            },
        }

    @staticmethod
    def _length_prefixed(data: bytes) -> bytes:
        """Return *data* preceded by its byte length as a 4-byte big-endian int."""
        return len(data).to_bytes(4, "big") + data

    def _assemble_frame(
        self,
        header: bytes,
        event_id: int,
        payload: bytes,
        with_session: bool = True,
    ) -> bytearray:
        """Assemble one binary frame around an uncompressed payload.

        Args:
            header: Protocol header bytes placed at the start of the frame.
            event_id: Event id, written as a 4-byte big-endian integer.
            payload: Raw payload; it is gzip-compressed before framing.
            with_session: Include the length-prefixed session id segment.

        Returns:
            The complete frame, ready to pass to ``ws.send``.
        """
        frame = bytearray(header)
        frame.extend(int(event_id).to_bytes(4, "big"))
        if with_session:
            # The prefix must count encoded bytes, not characters; the two
            # differ as soon as the session id contains non-ASCII text.
            frame.extend(self._length_prefixed(self.session_id.encode("utf-8")))
        frame.extend(self._length_prefixed(gzip.compress(payload)))
        return frame

    def _build_event_request(
        self, event_id: int, payload: dict, with_session: bool = True
    ) -> bytearray:
        """Build a frame whose payload is *payload* serialised as JSON.

        Args:
            event_id: Event id to send.
            payload: JSON-serialisable mapping.
            with_session: Include the session id segment (omit it for
                connection-level events).

        Returns:
            The complete frame using the default protocol header.
        """
        body = json.dumps(payload).encode("utf-8")
        return self._assemble_frame(
            protocol.generate_header(), event_id, body, with_session
        )

    async def connect(self) -> None:
        """Open the WebSocket and perform StartConnection and StartSession.

        On success ``self.ws`` holds the live connection and ``self.logid``
        the value of the "X-Tt-Logid" response header.

        Raises:
            ConnectionError: The server rejected the WebSocket handshake with
                a non-success HTTP status.
            Exception: Any error raised while exchanging the StartConnection /
                StartSession frames; the socket is closed before it
                propagates.
        """
        logger.info("Connecting to Volcengine Realtime API: %s", VOLCENGINE_REALTIME_URL)
        headers = self._build_headers()
        try:
            self.ws = await websockets.connect(
                VOLCENGINE_REALTIME_URL,
                additional_headers=headers,
                ping_interval=None,
                proxy=None,
            )
        except websockets.exceptions.InvalidStatus as e:
            response = e.response
            status: Any = "unknown"
            body = ""
            if response is not None:
                status = response.status_code
                if response.body:
                    body = response.body.decode("utf-8", errors="replace")
            raise ConnectionError(
                f"Volcengine Realtime API rejected connection: HTTP {status} - {body}"
            ) from e

        try:
            self.logid = self.ws.response.headers.get("X-Tt-Logid", "")
            logger.info("Volcengine Realtime connected, logid: %s", self.logid)

            # StartConnection (event 1) - connection-level, no session id.
            await self.ws.send(
                self._build_event_request(
                    self._EVENT_START_CONNECTION, {}, with_session=False
                )
            )
            response = await self.ws.recv()
            logger.info(
                "StartConnection response: %s", protocol.parse_response(response)
            )

            # StartSession (event 100)
            await self.ws.send(
                self._build_event_request(
                    self._EVENT_START_SESSION, self._build_session_params()
                )
            )
            response = await self.ws.recv()
            logger.info(
                "StartSession response: %s", protocol.parse_response(response)
            )
        except Exception:
            # The socket is already open here; do not leak it when the
            # application-level handshake fails.
            try:
                await self.close()
            except Exception:
                logger.exception(
                    "Failed to close WebSocket after unsuccessful handshake"
                )
            raise

    async def send_audio(self, audio: bytes) -> None:
        """Send audio data (event 200).

        Args:
            audio: Raw audio bytes; sent gzip-compressed with an audio-only,
                no-serialization header.
        """
        header = protocol.generate_header(
            message_type=protocol.CLIENT_AUDIO_ONLY_REQUEST,
            serial_method=protocol.NO_SERIALIZATION,
        )
        await self.ws.send(self._assemble_frame(header, self._EVENT_AUDIO, audio))

    async def chat_text_query(self, content: str) -> None:
        """Send text query (event 501)."""
        await self.ws.send(
            self._build_event_request(self._EVENT_TEXT_QUERY, {"content": content})
        )

    async def chat_tts_text(
        self, content: str, start: bool = True, end: bool = True
    ) -> None:
        """Send TTS text for synthesis (event 500).

        Args:
            content: Text to synthesise.
            start: Sent as the payload's "start" flag.
            end: Sent as the payload's "end" flag.
        """
        await self.ws.send(
            self._build_event_request(
                self._EVENT_TTS_TEXT,
                {"start": start, "end": end, "content": content},
            )
        )

    async def chat_rag_text(self, external_rag: str) -> None:
        """Inject external RAG result (event 502)."""
        await self.ws.send(
            self._build_event_request(
                self._EVENT_RAG_TEXT, {"external_rag": external_rag}
            )
        )

    async def receive_response(self) -> Dict[str, Any]:
        """Receive one message and return ``protocol.parse_response`` of it.

        Raises:
            Exception: Receiving or parsing failed; the original error is
                attached as ``__cause__``.
        """
        try:
            response = await self.ws.recv()
            return protocol.parse_response(response)
        except Exception as e:
            raise Exception(f"Failed to receive message: {e}") from e

    async def finish_session(self) -> None:
        """Finish session (event 102)."""
        await self.ws.send(
            self._build_event_request(self._EVENT_FINISH_SESSION, {})
        )

    async def finish_connection(self) -> None:
        """Finish connection (event 2) and log the server's reply."""
        await self.ws.send(
            self._build_event_request(
                self._EVENT_FINISH_CONNECTION, {}, with_session=False
            )
        )
        response = await self.ws.recv()
        logger.info(
            "FinishConnection response: %s", protocol.parse_response(response)
        )

    async def close(self) -> None:
        """Close the WebSocket if one is open; safe to call repeatedly."""
        if self.ws is None:
            return
        logger.info("Closing Volcengine Realtime WebSocket connection")
        try:
            await self.ws.close()
        finally:
            # Drop the reference even if close() raises so the client never
            # keeps pointing at a half-closed socket.
            self.ws = None