import gzip
import json
import logging
import uuid
from typing import Any, Dict, Optional

import websockets

from services import realtime_voice_protocol as protocol
from utils.settings import (
    VOLCENGINE_REALTIME_URL,
    VOLCENGINE_APP_ID,
    VOLCENGINE_ACCESS_KEY,
    VOLCENGINE_RESOURCE_ID,
    VOLCENGINE_APP_KEY,
    VOLCENGINE_DEFAULT_SPEAKER,
    VOLCENGINE_TTS_SAMPLE_RATE,
)

logger = logging.getLogger('app')


class RealtimeDialogClient:
    """WebSocket client for the Volcengine realtime dialogue API.

    Typical lifecycle: ``connect()`` (opens the socket, then sends
    StartConnection and StartSession), any number of ``send_audio`` /
    ``chat_*`` / ``receive_response`` calls, then ``finish_session()``,
    ``finish_connection()`` and ``close()``.

    Every outgoing frame is laid out as::

        header | event id (4 bytes, big-endian)
               | [session id length (4 bytes) | session id bytes]
               | payload length (4 bytes) | gzip-compressed payload

    The session-id section is omitted for connection-level events.
    """

    # Event ids used on the wire by this client.
    _EVENT_START_CONNECTION = 1
    _EVENT_FINISH_CONNECTION = 2
    _EVENT_START_SESSION = 100
    _EVENT_FINISH_SESSION = 102
    _EVENT_AUDIO = 200
    _EVENT_TTS_TEXT = 500
    _EVENT_TEXT_QUERY = 501
    _EVENT_RAG_TEXT = 502

    def __init__(
        self,
        session_id: str,
        speaker: Optional[str] = None,
        system_role: Optional[str] = None,
        speaking_style: Optional[str] = None,
        bot_name: Optional[str] = None,
        recv_timeout: int = 60,
        input_mod: str = "audio",
    ) -> None:
        """Store the session configuration; no network I/O happens here.

        Args:
            session_id: Identifier embedded in every session-level frame.
            speaker: TTS speaker; falls back to ``VOLCENGINE_DEFAULT_SPEAKER``
                when ``None`` or empty.
            system_role: Sent as ``dialog.system_role`` (empty string if unset).
            speaking_style: Sent as ``dialog.speaking_style`` (empty if unset).
            bot_name: Sent as ``dialog.bot_name`` (empty if unset).
            recv_timeout: Forwarded as ``dialog.extra.recv_timeout``
                (presumably seconds -- confirm against the API docs).
            input_mod: Forwarded as ``dialog.extra.input_mod``.
        """
        self.session_id = session_id
        self.speaker = speaker or VOLCENGINE_DEFAULT_SPEAKER
        self.system_role = system_role or ""
        self.speaking_style = speaking_style or ""
        self.bot_name = bot_name or ""
        self.recv_timeout = recv_timeout
        self.input_mod = input_mod
        # Filled from the "X-Tt-Logid" handshake response header in connect().
        self.logid = ""
        self.ws = None
        self._connect_id = str(uuid.uuid4())

    def _build_headers(self) -> Dict[str, str]:
        """Return the HTTP headers sent with the WebSocket handshake."""
        return {
            "X-Api-App-ID": VOLCENGINE_APP_ID,
            "X-Api-Access-Key": VOLCENGINE_ACCESS_KEY,
            "X-Api-Resource-Id": VOLCENGINE_RESOURCE_ID,
            "X-Api-App-Key": VOLCENGINE_APP_KEY,
            "X-Api-Connect-Id": self._connect_id,
        }

    def _build_session_params(self) -> Dict[str, Any]:
        """Return the JSON-serialisable payload of the StartSession event."""
        return {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": self.speaker,
                "audio_config": {
                    "channel": 1,
                    "format": "pcm",
                    "sample_rate": VOLCENGINE_TTS_SAMPLE_RATE,
                },
            },
            "dialog": {
                "bot_name": self.bot_name,
                "system_role": self.system_role,
                "speaking_style": self.speaking_style,
                "extra": {
                    "strict_audit": False,
                    "recv_timeout": self.recv_timeout,
                    "input_mod": self.input_mod,
                    "enable_volc_websearch": False,
                    "enable_music": False,
                    "model": "1.2.1.1"
                },
            }
        }

    @staticmethod
    def _length_prefixed(data: bytes) -> bytes:
        """Return *data* preceded by its byte length as a 4-byte big-endian int."""
        return len(data).to_bytes(4, 'big') + data

    def _session_id_section(self) -> bytes:
        """Return the length-prefixed, UTF-8 encoded session id.

        The prefix is the length of the *encoded bytes*, not the character
        count, so it stays correct for non-ASCII session ids.
        """
        return self._length_prefixed(self.session_id.encode("utf-8"))

    def _build_event_request(self, event_id: int, payload: dict, with_session: bool = True) -> bytearray:
        """Build a binary frame carrying a gzip-compressed JSON payload.

        Args:
            event_id: Numeric event id written after the header.
            payload: JSON-serialisable body of the event.
            with_session: When true, the session id section is included
                between the event id and the payload.

        Returns:
            The complete frame, ready to pass to ``ws.send``.
        """
        frame = bytearray(protocol.generate_header())
        frame.extend(int(event_id).to_bytes(4, 'big'))
        if with_session:
            frame.extend(self._session_id_section())
        body = gzip.compress(json.dumps(payload).encode("utf-8"))
        frame.extend(self._length_prefixed(body))
        return frame

    async def connect(self) -> None:
        """Open the WebSocket and perform the StartConnection/StartSession exchange.

        Raises:
            ConnectionError: If the server rejects the WebSocket handshake
                with an HTTP error status.
        """
        logger.info("Connecting to Volcengine Realtime API: %s", VOLCENGINE_REALTIME_URL)
        headers = self._build_headers()
        try:
            self.ws = await websockets.connect(
                VOLCENGINE_REALTIME_URL,
                additional_headers=headers,
                ping_interval=None,
                proxy=None,
            )
        except websockets.exceptions.InvalidStatus as e:
            body = ""
            if e.response and e.response.body:
                body = e.response.body.decode("utf-8", errors="replace")
            raise ConnectionError(
                f"Volcengine Realtime API rejected connection: HTTP {e.response.status_code} - {body}"
            ) from e

        self.logid = self.ws.response.headers.get("X-Tt-Logid", "")
        logger.info("Volcengine Realtime connected, logid: %s", self.logid)

        # StartConnection (event 1): connection-level, so no session id.
        await self.ws.send(
            self._build_event_request(self._EVENT_START_CONNECTION, {}, with_session=False)
        )
        response = await self.ws.recv()
        logger.info("StartConnection response: %s", protocol.parse_response(response))

        # StartSession (event 100): carries the session configuration.
        await self.ws.send(
            self._build_event_request(self._EVENT_START_SESSION, self._build_session_params())
        )
        response = await self.ws.recv()
        logger.info("StartSession response: %s", protocol.parse_response(response))

    async def send_audio(self, audio: bytes) -> None:
        """Send audio data (event 200).

        The audio bytes are gzip-compressed and sent in an audio-only frame
        with no payload serialisation.
        """
        frame = bytearray(
            protocol.generate_header(
                message_type=protocol.CLIENT_AUDIO_ONLY_REQUEST,
                serial_method=protocol.NO_SERIALIZATION,
            )
        )
        frame.extend(int(self._EVENT_AUDIO).to_bytes(4, 'big'))
        frame.extend(self._session_id_section())
        frame.extend(self._length_prefixed(gzip.compress(audio)))
        await self.ws.send(frame)

    async def chat_text_query(self, content: str) -> None:
        """Send text query (event 501)"""
        await self.ws.send(
            self._build_event_request(self._EVENT_TEXT_QUERY, {"content": content})
        )

    async def chat_tts_text(self, content: str, start: bool = True, end: bool = True) -> None:
        """Send TTS text for synthesis (event 500)"""
        await self.ws.send(
            self._build_event_request(
                self._EVENT_TTS_TEXT, {"start": start, "end": end, "content": content}
            )
        )

    async def chat_rag_text(self, external_rag: str) -> None:
        """Inject external RAG result (event 502)"""
        await self.ws.send(
            self._build_event_request(self._EVENT_RAG_TEXT, {"external_rag": external_rag})
        )

    async def receive_response(self) -> Dict[str, Any]:
        """Receive one message and return it parsed by ``protocol.parse_response``.

        Raises:
            Exception: Wraps any error raised while receiving or parsing; the
                original error is preserved as ``__cause__``.
        """
        try:
            response = await self.ws.recv()
            return protocol.parse_response(response)
        except Exception as e:
            raise Exception(f"Failed to receive message: {e}") from e

    async def finish_session(self) -> None:
        """Finish session (event 102)"""
        await self.ws.send(self._build_event_request(self._EVENT_FINISH_SESSION, {}))

    async def finish_connection(self) -> None:
        """Finish connection (event 2) and log the server's reply."""
        await self.ws.send(
            self._build_event_request(self._EVENT_FINISH_CONNECTION, {}, with_session=False)
        )
        response = await self.ws.recv()
        logger.info("FinishConnection response: %s", protocol.parse_response(response))

    async def close(self) -> None:
        """Close the WebSocket if one is open; safe to call repeatedly."""
        if self.ws:
            logger.info("Closing Volcengine Realtime WebSocket connection")
            try:
                await self.ws.close()
            finally:
                # Drop the reference even if close() fails, so a later
                # close() does not retry on a broken socket.
                self.ws = None