Local-Voice/doubao/realtime_dialog_client.py
2025-09-18 23:34:55 +08:00

188 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import copy
import gzip
import json
from typing import Any, Dict, Optional

import websockets

import config
import protocol
class RealtimeDialogClient:
    """Client for the realtime dialog service's binary WebSocket protocol.

    Every outgoing frame is assembled by ``_build_frame`` as::

        header | event id (4B big-endian)
               | [session id length (4B BE) | session id (UTF-8)]   # session-scoped events only
               | payload length (4B BE) | payload

    JSON payloads are gzip-compressed; audio payloads are gzip-compressed raw bytes.
    """

    # Event ids written after the header of each request frame.
    _START_CONNECTION = 1
    _FINISH_CONNECTION = 2
    _START_SESSION = 100
    _FINISH_SESSION = 102
    _TASK_REQUEST = 200
    _SAY_HELLO = 300
    _CHAT_TTS_TEXT = 500
    _CHAT_TEXT_QUERY = 501
    _CHAT_RAG_TEXT = 502

    def __init__(self, config: Dict[str, Any], session_id: str, output_audio_format: str = "pcm",
                 mod: str = "audio", recv_timeout: int = 10) -> None:
        """Store connection settings; no network I/O happens until ``connect``.

        Args:
            config: Connection settings. ``connect`` reads its ``"base_url"`` and
                ``"headers"`` keys.
            session_id: Session identifier embedded in every session-scoped frame.
            output_audio_format: When ``"pcm_s16le"``, StartSession requests that
                TTS audio format; otherwise the template's format is kept.
            mod: Sent as ``dialog.extra.input_mod`` in StartSession.
            recv_timeout: Sent as ``dialog.extra.recv_timeout`` in StartSession.
        """
        self.config = config
        self.logid = ""
        self.session_id = session_id
        self.output_audio_format = output_audio_format
        self.mod = mod
        self.recv_timeout = recv_timeout
        self.ws = None

    @staticmethod
    def _build_frame(header: bytes, event_id: int, payload: bytes,
                     session_id: Optional[str] = None) -> bytearray:
        """Assemble one request frame (see the class docstring for the layout).

        Args:
            header: Protocol header bytes, e.g. from ``protocol.generate_header()``.
            event_id: Event id written as a 4-byte big-endian integer.
            payload: Already-compressed payload bytes.
            session_id: When given, a length-prefixed UTF-8 session id is inserted.
                The prefix is the encoded byte length, so non-ASCII ids stay valid.
        """
        frame = bytearray(header)
        frame.extend(event_id.to_bytes(4, 'big'))
        if session_id is not None:
            session_bytes = session_id.encode("utf-8")
            frame.extend(len(session_bytes).to_bytes(4, 'big'))
            frame.extend(session_bytes)
        frame.extend(len(payload).to_bytes(4, 'big'))
        frame.extend(payload)
        return frame

    @staticmethod
    def _gzip_json(obj: Any) -> bytes:
        """Serialize ``obj`` to JSON (UTF-8) and gzip-compress it."""
        return gzip.compress(json.dumps(obj).encode("utf-8"))

    def _read_logid(self) -> Optional[str]:
        """Return the ``X-Tt-Logid`` handshake response header, or ``"unknown"``.

        Where the response headers live depends on the websockets version:
        ``response_headers`` on the legacy client, ``response.headers`` on the
        newer ``ClientConnection``.
        """
        ws = self.ws
        if hasattr(ws, 'response_headers'):
            return ws.response_headers.get("X-Tt-Logid")
        if hasattr(ws, 'headers'):
            return ws.headers.get("X-Tt-Logid")
        response = getattr(ws, 'response', None)
        if response is not None and hasattr(response, 'headers'):
            return response.headers.get("X-Tt-Logid")
        return "unknown"

    def _start_session_params(self) -> Dict[str, Any]:
        """Return the StartSession request body customised for this client.

        Works on a deep copy of the module-level ``config.start_session_req``
        template, so one client's settings (e.g. ``pcm_s16le``) never leak into
        clients created later.
        """
        params = copy.deepcopy(config.start_session_req)
        # Raising recv_timeout lets the session stay silent for a while; mainly
        # for text mode. The original note gives the valid range as [10, 120].
        params["dialog"]["extra"]["recv_timeout"] = self.recv_timeout
        # In text or audio_file mode, input_mod lets the session stay silent for a while.
        params["dialog"]["extra"]["input_mod"] = self.mod
        if self.output_audio_format == "pcm_s16le":
            params["tts"]["audio_config"]["format"] = "pcm_s16le"
        return params

    async def connect(self) -> None:
        """Open the WebSocket, then perform the StartConnection and StartSession handshakes.

        Each handshake sends one frame and awaits exactly one response, which is
        parsed and printed.
        """
        # NOTE(review): this prints the request headers, which may carry credentials.
        print(f"url: {self.config['base_url']}, headers: {self.config['headers']}")
        # ``additional_headers`` is the keyword used by the newer websockets client
        # API (the legacy client called it ``extra_headers``). Keepalive pings are disabled.
        self.ws = await websockets.connect(
            self.config['base_url'],
            additional_headers=self.config['headers'],
            ping_interval=None
        )
        self.logid = self._read_logid()
        print(f"dialog server response logid: {self.logid}")

        # StartConnection: connection-scoped (no session id), empty JSON object payload.
        await self.ws.send(self._build_frame(
            protocol.generate_header(), self._START_CONNECTION, gzip.compress(b"{}")))
        response = await self.ws.recv()
        print(f"StartConnection response: {protocol.parse_response(response)}")

        # StartSession: session-scoped, carries the customised session parameters.
        await self.ws.send(self._build_frame(
            protocol.generate_header(), self._START_SESSION,
            self._gzip_json(self._start_session_params()), self.session_id))
        response = await self.ws.recv()
        print(f"StartSession response: {protocol.parse_response(response)}")

    async def say_hello(self) -> None:
        """Send the Hello message carrying a fixed greeting text."""
        payload = {
            "content": "你好,我是豆包,有什么可以帮助你的?",
        }
        await self.ws.send(self._build_frame(
            protocol.generate_header(), self._SAY_HELLO,
            self._gzip_json(payload), self.session_id))

    async def chat_text_query(self, content: str) -> None:
        """Send a Chat Text Query message with ``content`` as the query text."""
        payload = {
            "content": content,
        }
        await self.ws.send(self._build_frame(
            protocol.generate_header(), self._CHAT_TEXT_QUERY,
            self._gzip_json(payload), self.session_id))

    async def chat_tts_text(self, is_user_querying: bool, start: bool, end: bool, content: str) -> None:
        """Send a Chat TTS Text message; does nothing while ``is_user_querying`` is true.

        Args:
            is_user_querying: When true, the call returns without sending anything.
            start: Value of the payload's ``"start"`` field.
            end: Value of the payload's ``"end"`` field.
            content: Text placed in the payload's ``"content"`` field.
        """
        if is_user_querying:
            return
        payload = {
            "start": start,
            "end": end,
            "content": content,
        }
        print(f"ChatTTSTextRequest payload: {payload}")
        await self.ws.send(self._build_frame(
            protocol.generate_header(), self._CHAT_TTS_TEXT,
            self._gzip_json(payload), self.session_id))

    async def chat_rag_text(self, is_user_querying: bool, external_rag: str) -> None:
        """Send a Chat RAG Text message; does nothing while ``is_user_querying`` is true.

        Args:
            is_user_querying: When true, the call returns without sending anything.
            external_rag: Value placed in the payload's ``"external_rag"`` field.
        """
        if is_user_querying:
            return
        payload = {
            "external_rag": external_rag,
        }
        print(f"ChatRAGTextRequest payload: {payload}")
        await self.ws.send(self._build_frame(
            protocol.generate_header(), self._CHAT_RAG_TEXT,
            self._gzip_json(payload), self.session_id))

    async def task_request(self, audio: bytes) -> None:
        """Send one chunk of audio as an audio-only, unserialized, gzip-compressed frame."""
        header = protocol.generate_header(message_type=protocol.CLIENT_AUDIO_ONLY_REQUEST,
                                          serial_method=protocol.NO_SERIALIZATION)
        await self.ws.send(self._build_frame(
            header, self._TASK_REQUEST, gzip.compress(audio), self.session_id))

    async def receive_server_response(self) -> Dict[str, Any]:
        """Receive one frame and return it as parsed by ``protocol.parse_response``.

        Raises:
            Exception: Wraps any receive or parse failure; the original error is
                kept as ``__cause__``.
        """
        try:
            response = await self.ws.recv()
            return protocol.parse_response(response)
        except Exception as e:
            raise Exception(f"Failed to receive message: {e}") from e

    async def finish_session(self):
        """Send FinishSession for this session with an empty JSON object payload."""
        await self.ws.send(self._build_frame(
            protocol.generate_header(), self._FINISH_SESSION,
            gzip.compress(b"{}"), self.session_id))

    async def finish_connection(self):
        """Send FinishConnection (connection-scoped), then await and print one response."""
        await self.ws.send(self._build_frame(
            protocol.generate_header(), self._FINISH_CONNECTION, gzip.compress(b"{}")))
        response = await self.ws.recv()
        print(f"FinishConnection response: {protocol.parse_response(response)}")

    async def close(self) -> None:
        """Close the WebSocket if ``connect`` has opened one."""
        if self.ws:
            print("Closing WebSocket connection...")
            await self.ws.close()