import copy
import gzip
import json
from typing import Any, Dict, Optional

import websockets

import config
import protocol


class RealtimeDialogClient:
    """Async WebSocket client for a realtime dialog service.

    Every outgoing message is a binary frame laid out as::

        header | code (4 bytes, big-endian)
               | [session-id length (4 bytes) | session-id bytes]
               | payload length (4 bytes) | gzip-compressed payload

    The header comes from ``protocol.generate_header``; incoming frames are
    decoded with ``protocol.parse_response``.

    Typical call order: ``connect`` -> (``say_hello`` / ``chat_*`` /
    ``task_request`` / ``receive_server_response``)* -> ``finish_session``
    -> ``finish_connection`` -> ``close``.
    """

    # Integer codes written right after the frame header. The names follow
    # the method names and log messages used in this class.
    _CODE_START_CONNECTION = 1
    _CODE_FINISH_CONNECTION = 2
    _CODE_START_SESSION = 100
    _CODE_FINISH_SESSION = 102
    _CODE_TASK_REQUEST = 200
    _CODE_SAY_HELLO = 300
    _CODE_CHAT_TTS_TEXT = 500
    _CODE_CHAT_TEXT_QUERY = 501
    _CODE_CHAT_RAG_TEXT = 502

    def __init__(self, config: Dict[str, Any], session_id: str,
                 output_audio_format: str = "pcm", mod: str = "audio",
                 recv_timeout: int = 10) -> None:
        """Store connection settings; no network I/O happens here.

        Args:
            config: Connection settings. ``connect`` reads the ``base_url``
                and ``headers`` keys. (Inside this method the parameter
                shadows the imported ``config`` module.)
            session_id: Identifier embedded in every session-scoped frame.
            output_audio_format: When equal to ``"pcm_s16le"``, ``connect``
                requests that format in the StartSession payload.
            mod: Sent as ``dialog.extra.input_mod`` in StartSession.
            recv_timeout: Sent as ``dialog.extra.recv_timeout`` in
                StartSession.
        """
        self.config = config
        self.logid = ""
        self.session_id = session_id
        self.output_audio_format = output_audio_format
        self.mod = mod
        self.recv_timeout = recv_timeout
        self.ws = None  # set by connect()

    def _build_request(self, code: int, payload: bytes, *,
                       with_session: bool = True,
                       header: Optional[Any] = None) -> bytearray:
        """Assemble one binary frame.

        Args:
            code: Integer written as 4 big-endian bytes after the header.
            payload: Raw payload; it is gzip-compressed here.
            with_session: Include the length-prefixed session id.
            header: Pre-built header; defaults to
                ``protocol.generate_header()``.

        Returns:
            The complete frame, ready for ``ws.send``.
        """
        if header is None:
            header = protocol.generate_header()
        frame = bytearray(header)
        frame.extend(code.to_bytes(4, 'big'))
        if with_session:
            # The length prefix must count encoded bytes, not characters,
            # otherwise a non-ASCII session id yields a corrupt frame.
            session_bytes = self.session_id.encode()
            frame.extend(len(session_bytes).to_bytes(4, 'big'))
            frame.extend(session_bytes)
        compressed = gzip.compress(payload)
        frame.extend(len(compressed).to_bytes(4, 'big'))
        frame.extend(compressed)
        return frame

    @staticmethod
    def _extract_logid(ws: Any) -> Optional[str]:
        """Return the ``X-Tt-Logid`` handshake response header.

        Probes the attribute layouts of different ``websockets`` client
        implementations and returns ``"unknown"`` when none of them is
        present. May return ``None`` if the header itself is absent.
        """
        if hasattr(ws, 'response_headers'):
            # Legacy client: headers are exposed directly on the protocol.
            return ws.response_headers.get("X-Tt-Logid")
        if hasattr(ws, 'headers'):
            return ws.headers.get("X-Tt-Logid")
        # Current asyncio client: the handshake response object carries
        # the headers.
        response = getattr(ws, 'response', None)
        if response is not None and hasattr(response, 'headers'):
            return response.headers.get("X-Tt-Logid")
        return "unknown"

    async def connect(self) -> None:
        """Open the WebSocket and perform StartConnection and StartSession.

        Prints the server's reply to each of the two requests.
        """
        # NOTE(review): this prints the request headers verbatim; if they
        # carry credentials they end up in the console/log output.
        print(f"url: {self.config['base_url']}, headers: {self.config['headers']}")
        # `additional_headers` is the keyword accepted by the current
        # websockets asyncio client; the legacy client named it
        # `extra_headers`.
        self.ws = await websockets.connect(
            self.config['base_url'],
            additional_headers=self.config['headers'],
            ping_interval=None
        )
        self.logid = self._extract_logid(self.ws)
        print(f"dialog server response logid: {self.logid}")

        # StartConnection request (connection-scoped, so no session id).
        await self.ws.send(self._build_request(
            self._CODE_START_CONNECTION, str.encode("{}"), with_session=False))
        response = await self.ws.recv()
        print(f"StartConnection response: {protocol.parse_response(response)}")

        # Work on a private copy so per-client settings do not leak into
        # the shared module-level template and affect later clients.
        request_params = copy.deepcopy(config.start_session_req)
        # Raising this value allows staying silent for a while; mainly
        # used in text mode. Valid range is [10, 120].
        request_params["dialog"]["extra"]["recv_timeout"] = self.recv_timeout
        # In text or audio_file mode this parameter allows staying silent
        # for a while.
        request_params["dialog"]["extra"]["input_mod"] = self.mod
        if self.output_audio_format == "pcm_s16le":
            request_params["tts"]["audio_config"]["format"] = "pcm_s16le"

        # StartSession request.
        await self.ws.send(self._build_request(
            self._CODE_START_SESSION, str.encode(json.dumps(request_params))))
        response = await self.ws.recv()
        print(f"StartSession response: {protocol.parse_response(response)}")

    async def say_hello(self) -> None:
        """Send the Hello message with a fixed greeting text."""
        payload = {
            "content": "你好,我是豆包,有什么可以帮助你的?",
        }
        await self.ws.send(self._build_request(
            self._CODE_SAY_HELLO, str.encode(json.dumps(payload))))

    async def chat_text_query(self, content: str) -> None:
        """Send a Chat Text Query message carrying ``content``."""
        payload = {
            "content": content,
        }
        await self.ws.send(self._build_request(
            self._CODE_CHAT_TEXT_QUERY, str.encode(json.dumps(payload))))

    async def chat_tts_text(self, is_user_querying: bool, start: bool,
                            end: bool, content: str) -> None:
        """Send a Chat TTS Text message.

        Does nothing when ``is_user_querying`` is true. ``start``, ``end``
        and ``content`` are forwarded unchanged in the JSON payload.
        """
        if is_user_querying:
            return
        payload = {
            "start": start,
            "end": end,
            "content": content,
        }
        print(f"ChatTTSTextRequest payload: {payload}")
        await self.ws.send(self._build_request(
            self._CODE_CHAT_TTS_TEXT, str.encode(json.dumps(payload))))

    async def chat_rag_text(self, is_user_querying: bool, external_rag: str) -> None:
        """Send a Chat RAG Text message.

        Does nothing when ``is_user_querying`` is true. ``external_rag`` is
        forwarded unchanged in the JSON payload.
        """
        if is_user_querying:
            return
        payload = {
            "external_rag": external_rag,
        }
        print(f"ChatRAGTextRequest payload: {payload}")
        await self.ws.send(self._build_request(
            self._CODE_CHAT_RAG_TEXT, str.encode(json.dumps(payload))))

    async def task_request(self, audio: bytes) -> None:
        """Send one chunk of audio.

        The frame uses an audio-only header with no serialization; the
        audio bytes are gzip-compressed like every other payload.
        """
        audio_header = protocol.generate_header(
            message_type=protocol.CLIENT_AUDIO_ONLY_REQUEST,
            serial_method=protocol.NO_SERIALIZATION)
        await self.ws.send(self._build_request(
            self._CODE_TASK_REQUEST, audio, header=audio_header))

    async def receive_server_response(self) -> Dict[str, Any]:
        """Receive one frame and return it parsed.

        Raises:
            Exception: If receiving or parsing fails; the underlying error
                is attached as ``__cause__``.
        """
        try:
            response = await self.ws.recv()
            return protocol.parse_response(response)
        except Exception as e:
            raise Exception(f"Failed to receive message: {e}") from e

    async def finish_session(self):
        """Send the FinishSession request; does not wait for a reply."""
        await self.ws.send(self._build_request(
            self._CODE_FINISH_SESSION, str.encode("{}")))

    async def finish_connection(self):
        """Send FinishConnection and print the server's reply."""
        await self.ws.send(self._build_request(
            self._CODE_FINISH_CONNECTION, str.encode("{}"), with_session=False))
        response = await self.ws.recv()
        print(f"FinishConnection response: {protocol.parse_response(response)}")

    async def close(self) -> None:
        """Close the WebSocket connection if one was opened."""
        if self.ws is not None:
            print("Closing WebSocket connection...")
            await self.ws.close()