This commit is contained in:
朱潮 2025-09-18 21:34:36 +08:00
parent d4ff3fd774
commit e6aa7f7be8
15 changed files with 824 additions and 0 deletions

BIN
.DS_Store vendored

Binary file not shown.

BIN
doubao/.DS_Store vendored Normal file

Binary file not shown.

37
doubao/README.md Normal file
View File

@ -0,0 +1,37 @@
# RealtimeDialog
实时语音对话程序,支持语音输入和语音输出。
## 使用说明
此demo使用python3.7环境进行开发调试,其他python版本可能会有兼容性问题,需要自己尝试解决。
1. 配置API密钥
- 打开 `config.py` 文件
- 修改以下两个字段:
```python
"X-Api-App-ID": "火山控制台上端到端大模型对应的App ID",
"X-Api-Access-Key": "火山控制台上端到端大模型对应的Access Key",
```
- 修改speaker字段指定发音人,本次支持四个发音人:
- `zh_female_vv_jupiter_bigtts`中文vv女声
- `zh_female_xiaohe_jupiter_bigtts`中文xiaohe女声
- `zh_male_yunzhou_jupiter_bigtts`:中文云洲男声
- `zh_male_xiaotian_jupiter_bigtts`:中文小天男声
2. 安装依赖
```bash
pip install -r requirements.txt
```
3. 通过麦克风运行程序
```bash
python main.py --format=pcm
```
4. 通过录音文件启动程序
```bash
python main.py --audio=whoareyou.wav
```
5. 通过纯文本输入和程序交互
```bash
python main.py --mod=text --recv_timeout=120
```

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

385
doubao/audio_manager.py Normal file
View File

@ -0,0 +1,385 @@
import asyncio
import queue
import random
import signal
import sys
import threading
import time
import uuid
import wave
from dataclasses import dataclass
from typing import Optional, Dict, Any
import pyaudio
import config
from realtime_dialog_client import RealtimeDialogClient
@dataclass
class AudioConfig:
    """Audio stream parameters.

    Instances are built via ``AudioConfig(**config.input_audio_config)`` /
    ``AudioConfig(**config.output_audio_config)``, so the fields mirror the
    keys of those dicts.
    """
    format: str       # sample format name, e.g. "pcm" or "pcm_s16le"
    bit_size: int     # pyaudio sample-format constant (e.g. pyaudio.paInt16)
    channels: int     # channel count
    sample_rate: int  # sampling rate in Hz
    chunk: int        # frames per buffer
class AudioDeviceManager:
    """Owns the PyAudio instance plus the input/output streams built from it."""

    def __init__(self, input_config: AudioConfig, output_config: AudioConfig):
        self.input_config = input_config
        self.output_config = output_config
        self.pyaudio = pyaudio.PyAudio()
        self.input_stream: Optional[pyaudio.Stream] = None
        self.output_stream: Optional[pyaudio.Stream] = None

    def _open(self, cfg: AudioConfig, *, is_input: bool) -> pyaudio.Stream:
        """Open a PyAudio stream for *cfg* in the requested direction."""
        stream_kwargs = dict(
            format=cfg.bit_size,
            channels=cfg.channels,
            rate=cfg.sample_rate,
            frames_per_buffer=cfg.chunk,
        )
        stream_kwargs["input" if is_input else "output"] = True
        return self.pyaudio.open(**stream_kwargs)

    def open_input_stream(self) -> pyaudio.Stream:
        """Open (and remember) the microphone capture stream."""
        self.input_stream = self._open(self.input_config, is_input=True)
        return self.input_stream

    def open_output_stream(self) -> pyaudio.Stream:
        """Open (and remember) the speaker playback stream."""
        self.output_stream = self._open(self.output_config, is_input=False)
        return self.output_stream

    def cleanup(self) -> None:
        """Stop and close any opened streams, then release PyAudio."""
        for stream in (self.input_stream, self.output_stream):
            if stream:
                stream.stop_stream()
                stream.close()
        self.pyaudio.terminate()
class DialogSession:
    """Dialog session manager.

    Owns the WebSocket client, the (optional) local audio devices, and the
    bookkeeping flags that coordinate the sender coroutine, the receiver
    coroutine and the playback thread.  Input can come from the microphone,
    a WAV file, or plain-text stdin depending on *mod* / *audio_file_path*.
    """
    is_audio_file_input: bool  # True when input is replayed from a WAV file
    mod: str                   # "audio", "text", or (forced) "audio_file"

    def __init__(self, ws_config: Dict[str, Any], output_audio_format: str = "pcm", audio_file_path: str = "",
                 mod: str = "audio", recv_timeout: int = 10):
        self.audio_file_path = audio_file_path
        self.recv_timeout = recv_timeout
        self.is_audio_file_input = self.audio_file_path != ""
        if self.is_audio_file_input:
            # A file path always overrides the requested mode.
            mod = 'audio_file'
        else:
            # Set once the server finishes speaking the greeting (event 359).
            self.say_hello_over_event = asyncio.Event()
        self.mod = mod
        self.session_id = str(uuid.uuid4())
        self.client = RealtimeDialogClient(config=ws_config, session_id=self.session_id,
                                           output_audio_format=output_audio_format, mod=mod, recv_timeout=recv_timeout)
        if output_audio_format == "pcm_s16le":
            # Keep local playback parameters in sync with the server TTS format.
            config.output_audio_config["format"] = "pcm_s16le"
            config.output_audio_config["bit_size"] = pyaudio.paInt16
        self.is_running = True
        self.is_session_finished = False
        self.is_user_querying = False
        self.is_sending_chat_tts_text = False
        self.audio_buffer = b''  # accumulates all received TTS audio for output.pcm
        signal.signal(signal.SIGINT, self._keyboard_signal)
        self.audio_queue = queue.Queue()
        if not self.is_audio_file_input:
            self.audio_device = AudioDeviceManager(
                AudioConfig(**config.input_audio_config),
                AudioConfig(**config.output_audio_config)
            )
            # Initialize the output stream for playback.
            self.output_stream = self.audio_device.open_output_stream()
            # Start the playback thread.
            self.is_recording = True
            self.is_playing = True
            self.player_thread = threading.Thread(target=self._audio_player_thread)
            self.player_thread.daemon = True
            self.player_thread.start()

    def _audio_player_thread(self):
        """Playback thread: drains the audio queue into the output stream."""
        while self.is_playing:
            try:
                # Pull the next audio chunk from the queue.
                audio_data = self.audio_queue.get(timeout=1.0)
                if audio_data is not None:
                    self.output_stream.write(audio_data)
            except queue.Empty:
                # Queue is empty: back off briefly.
                time.sleep(0.1)
            except Exception as e:
                print(f"音频播放错误: {e}")
                time.sleep(0.1)

    def handle_server_response(self, response: Dict[str, Any]) -> None:
        """Dispatch one parsed server message (audio chunk, event, or error)."""
        if response == {}:
            return
        if response['message_type'] == 'SERVER_ACK' and isinstance(response.get('payload_msg'), bytes):
            # print(f"\n接收到音频数据: {len(response['payload_msg'])} 字节")
            if self.is_sending_chat_tts_text:
                # While injecting ChatTTSText, drop server audio chunks.
                return
            audio_data = response['payload_msg']
            if not self.is_audio_file_input:
                self.audio_queue.put(audio_data)
            self.audio_buffer += audio_data
        elif response['message_type'] == 'SERVER_FULL_RESPONSE':
            print(f"服务器响应: {response}")
            event = response.get('event')
            payload_msg = response.get('payload_msg', {})
            if event == 450:
                # Event 450: flush any audio still queued for playback and
                # mark that the user is currently querying.
                print(f"清空缓存音频: {response['session_id']}")
                while not self.audio_queue.empty():
                    try:
                        self.audio_queue.get_nowait()
                    except queue.Empty:
                        continue
                self.is_user_querying = True
            if event == 350 and self.is_sending_chat_tts_text and payload_msg.get("tts_type") in ["chat_tts_text", "external_rag"]:
                # Event 350 for injected TTS/RAG text: flush the queue and
                # clear the injection flag.
                while not self.audio_queue.empty():
                    try:
                        self.audio_queue.get_nowait()
                    except queue.Empty:
                        continue
                self.is_sending_chat_tts_text = False
            if event == 459:
                # Event 459: clear the user-querying flag and (possibly)
                # inject ChatTTSText / ChatRAGText requests.
                self.is_user_querying = False
                if random.randint(0, 100000) % 1 == 0:
                    # NOTE(review): `% 1 == 0` is always true, so this branch
                    # fires on every 459 event — confirm the intended odds.
                    self.is_sending_chat_tts_text = True
                    asyncio.create_task(self.trigger_chat_tts_text())
                    asyncio.create_task(self.trigger_chat_rag_text())
        elif response['message_type'] == 'SERVER_ERROR':
            print(f"服务器错误: {response['payload_msg']}")
            raise Exception("服务器错误")

    async def trigger_chat_tts_text(self):
        """Send a ChatTTSText pair (start + end) carrying a holding phrase."""
        print("hit ChatTTSText event, start sending...")
        await self.client.chat_tts_text(
            is_user_querying=self.is_user_querying,
            start=True,
            end=False,
            content="这是查询到外部数据之前的安抚话术。",
        )
        await self.client.chat_tts_text(
            is_user_querying=self.is_user_querying,
            start=False,
            end=True,
            content="",
        )

    async def trigger_chat_rag_text(self):
        """Simulate an external RAG lookup, then push its result to the server."""
        # Simulated RAG latency; sleeping 5s also lets the holding phrase play out.
        await asyncio.sleep(5)
        print("hit ChatRAGText event, start sending...")
        await self.client.chat_rag_text(self.is_user_querying, external_rag='[{"title":"北京天气","content":"今天北京整体以晴到多云为主,但西部和北部地带可能会出现分散性雷阵雨,特别是午后至傍晚时段需注意突发降雨。\n💨 风况与湿度\n风力较弱,一般为 23 级南风或西南风\n白天湿度较高,早晚略凉爽"}]')

    def _keyboard_signal(self, sig, frame):
        """SIGINT handler: request a clean shutdown."""
        print(f"receive keyboard Ctrl+C")
        self.stop()

    def stop(self):
        """Flip all run flags so the loops and the playback thread wind down."""
        self.is_recording = False
        self.is_playing = False
        self.is_running = False

    async def receive_loop(self):
        """Receive and dispatch server messages until the session ends."""
        try:
            while True:
                response = await self.client.receive_server_response()
                self.handle_server_response(response)
                if 'event' in response and (response['event'] == 152 or response['event'] == 153):
                    # Events 152/153: session finished.
                    print(f"receive session finished event: {response['event']}")
                    self.is_session_finished = True
                    break
                if 'event' in response and response['event'] == 359:
                    # Event 359: a TTS utterance finished.
                    if self.is_audio_file_input:
                        print(f"receive tts ended event")
                        self.is_session_finished = True
                        break
                    else:
                        if not self.say_hello_over_event.is_set():
                            # First 359 marks the end of the greeting.
                            print(f"receive tts sayhello ended event")
                            self.say_hello_over_event.set()
                            if self.mod == "text":
                                print("请输入内容:")
        except asyncio.CancelledError:
            print("接收任务已取消")
        except Exception as e:
            print(f"接收消息错误: {e}")
        finally:
            self.stop()
            self.is_session_finished = True

    async def process_audio_file(self) -> None:
        """Entry point for WAV-file input mode."""
        await self.process_audio_file_input(self.audio_file_path)

    async def process_text_input(self) -> None:
        """Main loop for text mode: read stdin lines and send them as queries."""
        await self.client.say_hello()
        await self.say_hello_over_event.wait()
        # Make sure we always log our exit from this loop.
        try:
            # Start the stdin listener thread.
            input_queue = queue.Queue()
            input_thread = threading.Thread(target=self.input_listener, args=(input_queue,), daemon=True)
            input_thread.start()
            # Main loop: forward typed input until shutdown.
            while self.is_running:
                try:
                    # Non-blocking check for pending input.
                    input_str = input_queue.get_nowait()
                    if input_str is None:
                        # stdin was closed.
                        print("Input channel closed")
                        break
                    if input_str:
                        # Send the typed content as a text query.
                        await self.client.chat_text_query(input_str)
                except queue.Empty:
                    # Nothing typed yet: yield briefly.
                    await asyncio.sleep(0.1)
                except Exception as e:
                    print(f"Main loop error: {e}")
                    break
        finally:
            print("exit text input")

    def input_listener(self, input_queue: queue.Queue) -> None:
        """Blocking stdin reader running in its own thread."""
        print("Start listening for input")
        try:
            while True:
                # Blocking read of one line from stdin.
                line = sys.stdin.readline()
                if not line:
                    # stdin closed: signal the consumer with None.
                    input_queue.put(None)
                    break
                input_str = line.strip()
                input_queue.put(input_str)
        except Exception as e:
            print(f"Input listener error: {e}")
            input_queue.put(None)

    async def process_audio_file_input(self, audio_file_path: str) -> None:
        """Stream a WAV file to the server in real-time-sized chunks."""
        # Read the WAV file.
        with wave.open(audio_file_path, 'rb') as wf:
            chunk_size = config.input_audio_config["chunk"]
            framerate = wf.getframerate()  # sample rate, e.g. 16000 Hz
            # duration = frames per chunk / frames per second
            sleep_seconds = chunk_size / framerate
            print(f"开始处理音频文件: {audio_file_path}")
            # Read and send the audio chunk by chunk.
            while True:
                audio_data = wf.readframes(chunk_size)
                if not audio_data:
                    break  # end of file
                await self.client.task_request(audio_data)
                # Sleep for the chunk's play duration to mimic live input.
                await asyncio.sleep(sleep_seconds)
            print(f"音频文件处理完成,等待服务器响应...")

    async def process_silence_audio(self) -> None:
        """Send a short chunk of silence to the server."""
        silence_data = b'\x00' * 320
        await self.client.task_request(silence_data)

    async def process_microphone_input(self) -> None:
        """Capture microphone audio and stream it to the server."""
        await self.client.say_hello()
        await self.say_hello_over_event.wait()
        await self.client.chat_text_query("你好,我也叫豆包")
        stream = self.audio_device.open_input_stream()
        print("已打开麦克风,请讲话...")
        while self.is_recording:
            try:
                # exception_on_overflow=False ignores input-overflow errors.
                audio_data = stream.read(config.input_audio_config["chunk"], exception_on_overflow=False)
                save_input_pcm_to_wav(audio_data, "input.pcm")
                await self.client.task_request(audio_data)
                await asyncio.sleep(0.01)  # keep CPU usage in check
            except Exception as e:
                print(f"读取麦克风数据出错: {e}")
                await asyncio.sleep(0.1)  # give the system time to recover

    async def start(self) -> None:
        """Run the full session: connect, pump input/receive loops, tear down."""
        try:
            await self.client.connect()
            if self.mod == "text":
                asyncio.create_task(self.process_text_input())
                asyncio.create_task(self.receive_loop())
                while self.is_running:
                    await asyncio.sleep(0.1)
            else:
                if self.is_audio_file_input:
                    # File mode: sender runs in background, receiver drives completion.
                    asyncio.create_task(self.process_audio_file())
                    await self.receive_loop()
                else:
                    asyncio.create_task(self.process_microphone_input())
                    asyncio.create_task(self.receive_loop())
                    while self.is_running:
                        await asyncio.sleep(0.1)
            await self.client.finish_session()
            while not self.is_session_finished:
                await asyncio.sleep(0.1)
            await self.client.finish_connection()
            await asyncio.sleep(0.1)
            await self.client.close()
            print(f"dialog request logid: {self.client.logid}, chat mod: {self.mod}")
            save_output_to_file(self.audio_buffer, "output.pcm")
        except Exception as e:
            print(f"会话错误: {e}")
        finally:
            if not self.is_audio_file_input:
                self.audio_device.cleanup()
def save_input_pcm_to_wav(pcm_data: bytes, filename: str,
                          channels: Optional[int] = None,
                          sample_rate: Optional[int] = None) -> None:
    """Wrap raw 16-bit PCM bytes in a WAV container and write them to *filename*.

    Generalized so callers are no longer tied to the global config: the
    channel count and sample rate may be passed explicitly; when omitted they
    fall back to ``config.input_audio_config`` as before.

    Args:
        pcm_data: raw little-endian 16-bit PCM samples (paInt16 capture).
        filename: output path; note the bytes written are a WAV file
            regardless of the extension.
        channels: channel count; defaults to config.input_audio_config["channels"].
        sample_rate: sampling rate in Hz; defaults to
            config.input_audio_config["sample_rate"].
    """
    if channels is None:
        channels = config.input_audio_config["channels"]
    if sample_rate is None:
        sample_rate = config.input_audio_config["sample_rate"]
    with wave.open(filename, 'wb') as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(2)  # paInt16 = 2 bytes per sample
        wf.setframerate(sample_rate)
        wf.writeframes(pcm_data)
def save_output_to_file(audio_data: bytes, filename: str) -> None:
    """Persist raw PCM audio bytes to *filename* (no WAV header is added).

    Empty input is a no-op apart from a console notice; write failures are
    reported rather than raised.
    """
    if audio_data:
        try:
            with open(filename, 'wb') as sink:
                sink.write(audio_data)
        except IOError as err:
            print(f"Failed to save pcm file: {err}")
    else:
        print("No audio data to save.")

63
doubao/config.py Normal file
View File

@ -0,0 +1,63 @@
import uuid
import pyaudio
# Configuration for the realtime dialogue demo.

# WebSocket endpoint and authentication headers.
ws_connect_config = {
    "base_url": "wss://openspeech.bytedance.com/api/v3/realtime/dialogue",
    "headers": {
        "X-Api-App-ID": "",  # App ID of the end-to-end model from the Volcano console
        "X-Api-Access-Key": "",  # Access Key of the end-to-end model from the Volcano console
        "X-Api-Resource-Id": "volc.speech.dialog",  # fixed value
        "X-Api-App-Key": "PlgvMymc7f3tQnJ6",  # fixed value
        "X-Api-Connect-Id": str(uuid.uuid4()),
    }
}

# Payload template for the StartSession request.  realtime_dialog_client
# patches dialog.extra (recv_timeout / input_mod) and the TTS format at runtime.
start_session_req = {
    "asr": {
        "extra": {
            "end_smooth_window_ms": 1500,
        },
    },
    "tts": {
        "speaker": "zh_male_yunzhou_jupiter_bigtts",
        # "speaker": "S_XXXXXX",  # custom cloned voice; also requires character_manifest below
        # "speaker": "ICL_zh_female_aojiaonvyou_tob",  # official cloned voice; no character_manifest needed
        "audio_config": {
            "channel": 1,
            "format": "pcm",
            "sample_rate": 24000
        },
    },
    "dialog": {
        "bot_name": "豆包",
        "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
        "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
        # "character_manifest": "...",  # long persona description (appearance,
        #     personality, catch-phrases) used together with a custom cloned voice;
        #     see the project README / original sample for a full example
        "location": {
            "city": "北京",
        },
        "extra": {
            "strict_audit": False,
            "audit_response": "支持客户自定义安全审核回复话术。",
            "recv_timeout": 10,  # allowed silence in seconds; valid range [10, 120]
            "input_mod": "audio"
        }
    }
}

# Microphone capture: 16 kHz mono 16-bit PCM, 3200 frames per chunk.
input_audio_config = {
    "chunk": 3200,
    "format": "pcm",
    "channels": 1,
    "sample_rate": 16000,
    "bit_size": pyaudio.paInt16
}

# Playback: 24 kHz mono float32 PCM by default (switched to paInt16 for pcm_s16le).
output_audio_config = {
    "chunk": 3200,
    "format": "pcm",
    "channels": 1,
    "sample_rate": 24000,
    "bit_size": pyaudio.paFloat32
}

20
doubao/main.py Normal file
View File

@ -0,0 +1,20 @@
import asyncio
import argparse
import config
from audio_manager import DialogSession
async def main() -> None:
    """Parse command-line options and run a single dialog session."""
    parser = argparse.ArgumentParser(description="Real-time Dialog Client")
    parser.add_argument("--format", type=str, default="pcm",
                        help="The audio format (e.g., pcm, pcm_s16le).")
    parser.add_argument("--audio", type=str, default="",
                        help="audio file send to server, if not set, will use microphone input.")
    parser.add_argument("--mod", type=str, default="audio",
                        help="Use mod to select plain text input mode or audio mode, the default is audio mode")
    parser.add_argument("--recv_timeout", type=int, default=10,
                        help="Timeout for receiving messages,value range [10,120]")
    args = parser.parse_args()

    session = DialogSession(
        ws_config=config.ws_connect_config,
        output_audio_format=args.format,
        audio_file_path=args.audio,
        mod=args.mod,
        recv_timeout=args.recv_timeout,
    )
    await session.start()


if __name__ == "__main__":
    asyncio.run(main())

135
doubao/protocol.py Normal file
View File

@ -0,0 +1,135 @@
import gzip
import json
PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001

# Bit widths of the header fields (documentation only).
PROTOCOL_VERSION_BITS = 4
HEADER_BITS = 4
MESSAGE_TYPE_BITS = 4
MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4
MESSAGE_SERIALIZATION_BITS = 4
MESSAGE_COMPRESSION_BITS = 4
RESERVED_BITS = 8

# Message Type:
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# Message Type Specific Flags
NO_SEQUENCE = 0b0000  # no check sequence
POS_SEQUENCE = 0b0001
NEG_SEQUENCE = 0b0010
NEG_SEQUENCE_1 = 0b0011
MSG_WITH_EVENT = 0b0100

# Message Serialization
NO_SERIALIZATION = 0b0000
JSON = 0b0001
THRIFT = 0b0011
CUSTOM_TYPE = 0b1111

# Message Compression
NO_COMPRESSION = 0b0000
GZIP = 0b0001
CUSTOM_COMPRESSION = 0b1111


def generate_header(
        version=PROTOCOL_VERSION,
        message_type=CLIENT_FULL_REQUEST,
        message_type_specific_flags=MSG_WITH_EVENT,
        serial_method=JSON,
        compression_type=GZIP,
        reserved_data=0x00,
        extension_header=bytes()
):
    """Build the 4-byte protocol header (plus optional extension bytes).

    Field layout:
      protocol_version (4 bits) | header_size (4 bits, in 4-byte words)
      message_type (4 bits)     | message_type_specific_flags (4 bits)
      serialization_method (4 bits) | message_compression (4 bits)
      reserved (8 bits)
      header_extensions (8 * 4 * (header_size - 1) bits)

    Returns a bytearray ready to have event id / session id / payload appended.
    """
    header = bytearray()
    header_size = int(len(extension_header) / 4) + 1
    header.append((version << 4) | header_size)
    header.append((message_type << 4) | message_type_specific_flags)
    header.append((serial_method << 4) | compression_type)
    header.append(reserved_data)
    header.extend(extension_header)
    return header


def parse_response(res):
    """Parse one binary server frame into a dict.

    Frame layout (big-endian):
      - 4-byte header:
          (4 bits) version | (4 bits) header_size (in 4-byte words)
          (4 bits) message_type | (4 bits) message_type_specific_flags
          (4 bits) serialization_method | (4 bits) compression
          (8 bits) reserved
      - payload:
          [optional 4 bytes] sequence   (NEG_SEQUENCE flag)
          [optional 4 bytes] event      (MSG_WITH_EVENT flag)
          [full/ack frames] 4-byte session-id length + session-id bytes
          4-byte payload length + payload data

    Returns {} for text frames; otherwise a dict with 'message_type' plus,
    depending on the frame, 'seq'/'event'/'session_id'/'code' and
    'payload_msg'/'payload_size'.
    """
    if isinstance(res, str):
        # Text frames carry no binary payload to parse.
        return {}
    protocol_version = res[0] >> 4
    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    message_type_specific_flags = res[1] & 0x0f
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    reserved = res[3]
    header_extensions = res[4:header_size * 4]
    payload = res[header_size * 4:]
    result = {}
    payload_msg = None
    payload_size = 0
    start = 0
    if message_type == SERVER_FULL_RESPONSE or message_type == SERVER_ACK:
        result['message_type'] = 'SERVER_ACK' if message_type == SERVER_ACK else 'SERVER_FULL_RESPONSE'
        if message_type_specific_flags & NEG_SEQUENCE > 0:
            result['seq'] = int.from_bytes(payload[:4], "big", signed=False)
            start += 4
        if message_type_specific_flags & MSG_WITH_EVENT > 0:
            # NOTE(review): when both the sequence and event flags are set,
            # both fields read payload[:4] — kept as original behavior;
            # verify against the protocol spec.
            result['event'] = int.from_bytes(payload[:4], "big", signed=False)
            start += 4
        payload = payload[start:]
        session_id_size = int.from_bytes(payload[:4], "big", signed=True)
        # BUGFIX: decode the session id; str(bytes) produced the literal
        # "b'...'" representation.
        result['session_id'] = payload[4:session_id_size + 4].decode("utf-8")
        payload = payload[4 + session_id_size:]
        payload_size = int.from_bytes(payload[:4], "big", signed=False)
        payload_msg = payload[4:]
    elif message_type == SERVER_ERROR_RESPONSE:
        # BUGFIX: tag error frames so callers branching on
        # response['message_type'] (see handle_server_response) do not
        # hit a KeyError.
        result['message_type'] = 'SERVER_ERROR'
        result['code'] = int.from_bytes(payload[:4], "big", signed=False)
        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
        payload_msg = payload[8:]
    if payload_msg is None:
        return result
    if message_compression == GZIP:
        payload_msg = gzip.decompress(payload_msg)
    if serialization_method == JSON:
        payload_msg = json.loads(str(payload_msg, "utf-8"))
    elif serialization_method != NO_SERIALIZATION:
        payload_msg = str(payload_msg, "utf-8")
    result['payload_msg'] = payload_msg
    result['payload_size'] = payload_size
    return result

View File

@ -0,0 +1,180 @@
import gzip
import json
from typing import Dict, Any
import websockets
import config
import protocol
class RealtimeDialogClient:
    """Thin wrapper around the realtime-dialogue WebSocket wire protocol.

    Each client event is sent as: protocol header (see protocol.generate_header)
    + 4-byte event id + (for session-scoped events) 4-byte session-id length and
    session-id bytes + 4-byte payload length + gzip-compressed payload.
    """

    def __init__(self, config: Dict[str, Any], session_id: str, output_audio_format: str = "pcm",
                 mod: str = "audio", recv_timeout: int = 10) -> None:
        self.config = config  # ws_connect_config-style dict: base_url + headers
        self.logid = ""       # server-assigned log id (X-Tt-Logid response header)
        self.session_id = session_id
        self.output_audio_format = output_audio_format
        self.mod = mod
        self.recv_timeout = recv_timeout
        self.ws = None

    async def connect(self) -> None:
        """Open the WebSocket and perform the StartConnection / StartSession handshake."""
        print(f"url: {self.config['base_url']}, headers: {self.config['headers']}")
        self.ws = await websockets.connect(
            self.config['base_url'],
            extra_headers=self.config['headers'],
            ping_interval=None
        )
        self.logid = self.ws.response_headers.get("X-Tt-Logid")
        print(f"dialog server response logid: {self.logid}")
        # StartConnection request (event 1) with an empty JSON payload.
        start_connection_request = bytearray(protocol.generate_header())
        start_connection_request.extend(int(1).to_bytes(4, 'big'))
        payload_bytes = str.encode("{}")
        payload_bytes = gzip.compress(payload_bytes)
        start_connection_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        start_connection_request.extend(payload_bytes)
        await self.ws.send(start_connection_request)
        response = await self.ws.recv()
        print(f"StartConnection response: {protocol.parse_response(response)}")
        # Raising this keeps the session alive while silent (mainly for text
        # mode); valid range [10, 120].
        config.start_session_req["dialog"]["extra"]["recv_timeout"] = self.recv_timeout
        # In text or audio_file mode this lets the session stay silent for a while.
        config.start_session_req["dialog"]["extra"]["input_mod"] = self.mod
        # StartSession request (event 100).
        if self.output_audio_format == "pcm_s16le":
            config.start_session_req["tts"]["audio_config"]["format"] = "pcm_s16le"
        request_params = config.start_session_req
        payload_bytes = str.encode(json.dumps(request_params))
        payload_bytes = gzip.compress(payload_bytes)
        start_session_request = bytearray(protocol.generate_header())
        start_session_request.extend(int(100).to_bytes(4, 'big'))
        start_session_request.extend((len(self.session_id)).to_bytes(4, 'big'))
        start_session_request.extend(str.encode(self.session_id))
        start_session_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        start_session_request.extend(payload_bytes)
        await self.ws.send(start_session_request)
        response = await self.ws.recv()
        print(f"StartSession response: {protocol.parse_response(response)}")

    async def say_hello(self) -> None:
        """Send a SayHello message (event 300) carrying the greeting text."""
        payload = {
            "content": "你好,我是豆包,有什么可以帮助你的?",
        }
        hello_request = bytearray(protocol.generate_header())
        hello_request.extend(int(300).to_bytes(4, 'big'))
        payload_bytes = str.encode(json.dumps(payload))
        payload_bytes = gzip.compress(payload_bytes)
        hello_request.extend((len(self.session_id)).to_bytes(4, 'big'))
        hello_request.extend(str.encode(self.session_id))
        hello_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        hello_request.extend(payload_bytes)
        await self.ws.send(hello_request)

    async def chat_text_query(self, content: str) -> None:
        """Send a ChatTextQuery message (event 501) with the user's text."""
        payload = {
            "content": content,
        }
        chat_text_query_request = bytearray(protocol.generate_header())
        chat_text_query_request.extend(int(501).to_bytes(4, 'big'))
        payload_bytes = str.encode(json.dumps(payload))
        payload_bytes = gzip.compress(payload_bytes)
        chat_text_query_request.extend((len(self.session_id)).to_bytes(4, 'big'))
        chat_text_query_request.extend(str.encode(self.session_id))
        chat_text_query_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        chat_text_query_request.extend(payload_bytes)
        await self.ws.send(chat_text_query_request)

    async def chat_tts_text(self, is_user_querying: bool, start: bool, end: bool, content: str) -> None:
        """Send a ChatTTSText message (event 500); skipped while the user is querying."""
        if is_user_querying:
            return
        payload = {
            "start": start,
            "end": end,
            "content": content,
        }
        print(f"ChatTTSTextRequest payload: {payload}")
        payload_bytes = str.encode(json.dumps(payload))
        payload_bytes = gzip.compress(payload_bytes)
        chat_tts_text_request = bytearray(protocol.generate_header())
        chat_tts_text_request.extend(int(500).to_bytes(4, 'big'))
        chat_tts_text_request.extend((len(self.session_id)).to_bytes(4, 'big'))
        chat_tts_text_request.extend(str.encode(self.session_id))
        chat_tts_text_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        chat_tts_text_request.extend(payload_bytes)
        await self.ws.send(chat_tts_text_request)

    async def chat_rag_text(self, is_user_querying: bool, external_rag: str) -> None:
        """Send a ChatRAGText message (event 502); skipped while the user is querying."""
        if is_user_querying:
            return
        payload = {
            "external_rag": external_rag,
        }
        print(f"ChatRAGTextRequest payload: {payload}")
        payload_bytes = str.encode(json.dumps(payload))
        payload_bytes = gzip.compress(payload_bytes)
        chat_rag_text_request = bytearray(protocol.generate_header())
        chat_rag_text_request.extend(int(502).to_bytes(4, 'big'))
        chat_rag_text_request.extend((len(self.session_id)).to_bytes(4, 'big'))
        chat_rag_text_request.extend(str.encode(self.session_id))
        chat_rag_text_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        chat_rag_text_request.extend(payload_bytes)
        await self.ws.send(chat_rag_text_request)

    async def task_request(self, audio: bytes) -> None:
        """Send one chunk of input audio (event 200, audio-only frame)."""
        task_request = bytearray(
            protocol.generate_header(message_type=protocol.CLIENT_AUDIO_ONLY_REQUEST,
                                     serial_method=protocol.NO_SERIALIZATION))
        task_request.extend(int(200).to_bytes(4, 'big'))
        task_request.extend((len(self.session_id)).to_bytes(4, 'big'))
        task_request.extend(str.encode(self.session_id))
        payload_bytes = gzip.compress(audio)
        task_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size (4 bytes)
        task_request.extend(payload_bytes)
        await self.ws.send(task_request)

    async def receive_server_response(self) -> Dict[str, Any]:
        """Receive one frame from the server and return it parsed into a dict."""
        try:
            response = await self.ws.recv()
            data = protocol.parse_response(response)
            return data
        except Exception as e:
            raise Exception(f"Failed to receive message: {e}")

    async def finish_session(self):
        """Send a FinishSession message (event 102)."""
        finish_session_request = bytearray(protocol.generate_header())
        finish_session_request.extend(int(102).to_bytes(4, 'big'))
        payload_bytes = str.encode("{}")
        payload_bytes = gzip.compress(payload_bytes)
        finish_session_request.extend((len(self.session_id)).to_bytes(4, 'big'))
        finish_session_request.extend(str.encode(self.session_id))
        finish_session_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        finish_session_request.extend(payload_bytes)
        await self.ws.send(finish_session_request)

    async def finish_connection(self):
        """Send a FinishConnection message (event 2) and print the server reply."""
        finish_connection_request = bytearray(protocol.generate_header())
        finish_connection_request.extend(int(2).to_bytes(4, 'big'))
        payload_bytes = str.encode("{}")
        payload_bytes = gzip.compress(payload_bytes)
        finish_connection_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        finish_connection_request.extend(payload_bytes)
        await self.ws.send(finish_connection_request)
        response = await self.ws.recv()
        print(f"FinishConnection response: {protocol.parse_response(response)}")

    async def close(self) -> None:
        """Close the WebSocket connection if it is open."""
        if self.ws:
            print(f"Closing WebSocket connection...")
            await self.ws.close()

4
doubao/requirements.txt Normal file
View File

@ -0,0 +1,4 @@
pyaudio
websockets
dataclasses==0.8; python_version < "3.7"
typing-extensions==4.7.1; python_version < "3.8"

BIN
doubao/whoareyou.wav Normal file

Binary file not shown.