diff --git a/doubao/README.md b/doubao/README.md
new file mode 100644
index 0000000..7c00a7d
--- /dev/null
+++ b/doubao/README.md
@@ -0,0 +1,37 @@
+# RealtimeDialog
+
+A real-time voice dialog demo that supports both speech input and speech output.
+
+## Usage
+
+This demo was developed and debugged on Python 3.7. Other Python versions may have compatibility issues that you will need to work around yourself.
+
+1. Configure the API credentials
+   - Open `config.py`
+   - Fill in the following two fields:
+     ```python
+     "X-Api-App-ID": "the App ID of the end-to-end model from the Volcano Engine console",
+     "X-Api-Access-Key": "the Access Key of the end-to-end model from the Volcano Engine console",
+     ```
+   - Set the `speaker` field to pick a voice; four voices are currently supported:
+     - `zh_female_vv_jupiter_bigtts`: Chinese female voice "vv"
+     - `zh_female_xiaohe_jupiter_bigtts`: Chinese female voice "xiaohe"
+     - `zh_male_yunzhou_jupiter_bigtts`: Chinese male voice "yunzhou"
+     - `zh_male_xiaotian_jupiter_bigtts`: Chinese male voice "xiaotian"
+
+2. Install the dependencies
+   ```bash
+   pip install -r requirements.txt
+   ```
+
+3. Run with microphone input
+   ```bash
+   python main.py --format=pcm
+   ```
+4. Run against a prerecorded audio file
+   ```bash
+   python main.py --audio=whoareyou.wav
+   ```
+5. Interact with the program via plain text input
+   ```bash
+   python main.py --mod=text --recv_timeout=120
+   ```
\ No newline at end of file
diff --git a/doubao/audio_manager.py b/doubao/audio_manager.py
new file mode 100644
index 0000000..8f7b467
--- /dev/null
+++ b/doubao/audio_manager.py
@@ -0,0 +1,385 @@
+import asyncio
+import queue
+import random
+import signal
+import sys
+import threading
+import time
+import uuid
+import wave
+from dataclasses import dataclass
+from typing import Optional, Dict, Any
+
+import pyaudio
+
+import config
+from realtime_dialog_client import RealtimeDialogClient
+
+
+@dataclass
+class AudioConfig:
+    """Audio configuration."""
+    format: str
+    bit_size: int
+    channels: int
+    sample_rate: int
+    chunk: int
+
+
+class AudioDeviceManager:
+    """Manages the audio input and output devices."""
+
+    def __init__(self, input_config: AudioConfig, output_config: AudioConfig):
+        self.input_config = input_config
+        self.output_config = output_config
+        self.pyaudio = pyaudio.PyAudio()
+        self.input_stream: Optional[pyaudio.Stream] = None
+        self.output_stream: Optional[pyaudio.Stream] = None
+
+    def open_input_stream(self) -> pyaudio.Stream:
+        """Open the audio input (capture) stream."""
+        self.input_stream = self.pyaudio.open(
+            format=self.input_config.bit_size,
+            channels=self.input_config.channels,
+            rate=self.input_config.sample_rate,
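+            # bit_size is a pyaudio sample-format constant (paInt16 here, taken
+            # from config.input_audio_config), not a bit count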
+            input=True,
+            frames_per_buffer=self.input_config.chunk
+        )
+        return self.input_stream
+
+    def open_output_stream(self) -> pyaudio.Stream:
+        """Open the audio output (playback) stream."""
+        self.output_stream = self.pyaudio.open(
+            format=self.output_config.bit_size,
+            channels=self.output_config.channels,
+            rate=self.output_config.sample_rate,
+            output=True,
+            frames_per_buffer=self.output_config.chunk
+        )
+        return self.output_stream
+
+    def cleanup(self) -> None:
+        """Release the audio device resources."""
+        for stream in [self.input_stream, self.output_stream]:
+            if stream:
+                stream.stop_stream()
+                stream.close()
+        self.pyaudio.terminate()
+
+
+class DialogSession:
+    """Manages a single dialog session."""
+    is_audio_file_input: bool
+    mod: str
+
+    def __init__(self, ws_config: Dict[str, Any], output_audio_format: str = "pcm", audio_file_path: str = "",
+                 mod: str = "audio", recv_timeout: int = 10):
+        self.audio_file_path = audio_file_path
+        self.recv_timeout = recv_timeout
+        self.is_audio_file_input = self.audio_file_path != ""
+        if self.is_audio_file_input:
+            mod = 'audio_file'
+        else:
+            self.say_hello_over_event = asyncio.Event()
+        self.mod = mod
+
+        self.session_id = str(uuid.uuid4())
+        self.client = RealtimeDialogClient(config=ws_config, session_id=self.session_id,
+                                           output_audio_format=output_audio_format, mod=mod, recv_timeout=recv_timeout)
+        if output_audio_format == "pcm_s16le":
+            config.output_audio_config["format"] = "pcm_s16le"
+            config.output_audio_config["bit_size"] = pyaudio.paInt16
+
+        self.is_running = True
+        self.is_session_finished = False
+        self.is_user_querying = False
+        self.is_sending_chat_tts_text = False
+        self.audio_buffer = b''
+
+        signal.signal(signal.SIGINT, self._keyboard_signal)
+        self.audio_queue = queue.Queue()
+        if not self.is_audio_file_input:
+            self.audio_device = AudioDeviceManager(
+                AudioConfig(**config.input_audio_config),
+                AudioConfig(**config.output_audio_config)
+            )
+            # Initialize the output stream for playback
+            self.output_stream = self.audio_device.open_output_stream()
+            # Start the playback thread
+            self.is_recording = True
+            self.is_playing = True
+            self.player_thread = threading.Thread(target=self._audio_player_thread)
+            self.player_thread.daemon = True
+            self.player_thread.start()
+
+    def _audio_player_thread(self):
+        """Playback thread: drains the audio queue into the output stream."""
+        while self.is_playing:
+            try:
+                # Fetch the next audio chunk from the queue
+                audio_data = self.audio_queue.get(timeout=1.0)
+                if audio_data is not None:
+                    self.output_stream.write(audio_data)
+            except queue.Empty:
+                # Queue is empty; wait briefly before retrying
+                time.sleep(0.1)
+            except Exception as e:
+                print(f"Audio playback error: {e}")
+                time.sleep(0.1)
+
+    def handle_server_response(self, response: Dict[str, Any]) -> None:
+        """Handle a parsed server response."""
+        if response == {}:
+            return
+        if response['message_type'] == 'SERVER_ACK' and isinstance(response.get('payload_msg'), bytes):
+            # print(f"\nReceived audio data: {len(response['payload_msg'])} bytes")
+            if self.is_sending_chat_tts_text:
+                return
+            audio_data = response['payload_msg']
+            if not self.is_audio_file_input:
+                self.audio_queue.put(audio_data)
+            self.audio_buffer += audio_data
+        elif response['message_type'] == 'SERVER_FULL_RESPONSE':
+            print(f"Server response: {response}")
+            event = response.get('event')
+            payload_msg = response.get('payload_msg', {})
+
+            if event == 450:
+                print(f"Clearing buffered audio: {response['session_id']}")
+                while not self.audio_queue.empty():
+                    try:
+                        self.audio_queue.get_nowait()
+                    except queue.Empty:
+                        continue
+                self.is_user_querying = True
+
+            if event == 350 and self.is_sending_chat_tts_text and payload_msg.get("tts_type") in ["chat_tts_text", "external_rag"]:
+                while not self.audio_queue.empty():
+                    try:
+                        self.audio_queue.get_nowait()
+                    except queue.Empty:
+                        continue
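+                # The injected TTS/RAG playback has finished (event 350); resume
+                # accepting normal server audio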
+                self.is_sending_chat_tts_text = False
+
+            if event == 459:
+                self.is_user_querying = False
+                # Note: any integer % 1 == 0, so this condition is always true and the
+                # ChatTTSText/RAG demo flow fires after every reply; raise the modulus
+                # (e.g. % 10 == 0) to trigger it probabilistically instead.
+                if random.randint(0, 100000) % 1 == 0:
+                    self.is_sending_chat_tts_text = True
+                    asyncio.create_task(self.trigger_chat_tts_text())
+                    asyncio.create_task(self.trigger_chat_rag_text())
+        elif response['message_type'] == 'SERVER_ERROR':
+            print(f"Server error: {response['payload_msg']}")
+            raise Exception("server error")
+
+    async def trigger_chat_tts_text(self):
+        """Send a ChatTTSText request (triggered probabilistically)."""
+        print("hit ChatTTSText event, start sending...")
+        await self.client.chat_tts_text(
+            is_user_querying=self.is_user_querying,
+            start=True,
+            end=False,
+            content="这是查询到外部数据之前的安抚话术。",
+        )
+        await self.client.chat_tts_text(
+            is_user_querying=self.is_user_querying,
+            start=False,
+            end=True,
+            content="",
+        )
+
+    async def trigger_chat_rag_text(self):
+        # Simulate the latency of an external RAG lookup; sleep 5 s so the reassurance
+        # message above can finish playing first.
+        await asyncio.sleep(5)
+        print("hit ChatRAGText event, start sending...")
+        await self.client.chat_rag_text(self.is_user_querying, external_rag='[{"title":"北京天气","content":"今天北京整体以晴到多云为主,但西部和北部地带可能会出现分散性雷阵雨,特别是午后至傍晚时段需注意突发降雨。\n💨 风况与湿度\n风力较弱,一般为 2–3 级南风或西南风\n白天湿度较高,早晚略凉爽"}]')
+
+    def _keyboard_signal(self, sig, frame):
+        print("received keyboard Ctrl+C")
+        self.stop()
+
+    def stop(self):
+        self.is_recording = False
+        self.is_playing = False
+        self.is_running = False
+
+    async def receive_loop(self):
+        try:
+            while True:
+                response = await self.client.receive_server_response()
+                self.handle_server_response(response)
+                if 'event' in response and (response['event'] == 152 or response['event'] == 153):
+                    print(f"receive session finished event: {response['event']}")
+                    self.is_session_finished = True
+                    break
+                if 'event' in response and response['event'] == 359:
+                    if self.is_audio_file_input:
+                        print("receive tts ended event")
+                        self.is_session_finished = True
+                        break
+                    else:
+                        if not self.say_hello_over_event.is_set():
+                            print("receive tts sayhello ended event")
+                            self.say_hello_over_event.set()
+                            if self.mod == "text":
+                                print("Enter your message:")
+
+        except asyncio.CancelledError:
+            print("Receive task cancelled")
+        except Exception as e:
+            print(f"Error receiving message: {e}")
+        finally:
+            self.stop()
+            self.is_session_finished = True
+
+    async def process_audio_file(self) -> None:
+        await self.process_audio_file_input(self.audio_file_path)
+
+    async def process_text_input(self) -> None:
+        """Main loop for text mode: forward stdin input over the WebSocket."""
+        await self.client.say_hello()
+        await self.say_hello_over_event.wait()
+
+        # The try/finally makes sure we log the exit even if the loop fails
+        try:
+            # Start the input listener thread
+            input_queue = queue.Queue()
+            input_thread = threading.Thread(target=self.input_listener, args=(input_queue,), daemon=True)
+            input_thread.start()
+            # Main loop: handle input until the session ends
+            while self.is_running:
+                try:
+                    # Check for input (non-blocking)
+                    input_str = input_queue.get_nowait()
+                    if input_str is None:
+                        # Input stream closed
+                        print("Input channel closed")
+                        break
+                    if input_str:
+                        # Send the typed content
+                        await self.client.chat_text_query(input_str)
+                except queue.Empty:
+                    # No input yet; sleep briefly
+                    await asyncio.sleep(0.1)
+                except Exception as e:
+                    print(f"Main loop error: {e}")
+                    break
+        finally:
+            print("exit text input")
+
+    def input_listener(self, input_queue: queue.Queue) -> None:
+        """Listen for standard input in a dedicated thread."""
+        print("Start listening for input")
+        try:
+            while True:
+                # Read from stdin (blocking)
+                line = sys.stdin.readline()
+                if not line:
+                    # Input stream closed
+                    input_queue.put(None)
+                    break
+                input_str = line.strip()
+                input_queue.put(input_str)
+        except Exception as e:
+            print(f"Input listener error: {e}")
+            input_queue.put(None)
+
+    async def process_audio_file_input(self, audio_file_path: str) -> None:
+        # Read the WAV file
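+        # e.g. with the default chunk of 3200 frames at 16000 Hz, each chunk covers
+        # 3200 / 16000 = 0.2 s, which is also how long the send loop sleeps per chunk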
+        with wave.open(audio_file_path, 'rb') as wf:
+            chunk_size = config.input_audio_config["chunk"]
+            framerate = wf.getframerate()  # sample rate (e.g. 16000 Hz)
+            # duration = chunk_size (frames) / sample rate (frames per second)
+            sleep_seconds = chunk_size / framerate
+            print(f"Processing audio file: {audio_file_path}")
+
+            # Read and send the audio data chunk by chunk
+            while True:
+                audio_data = wf.readframes(chunk_size)
+                if not audio_data:
+                    break  # End of file
+
+                await self.client.task_request(audio_data)
+                # Sleep for the duration of the chunk to simulate real-time input
+                await asyncio.sleep(sleep_seconds)
+
+        print("Audio file processed, waiting for the server response...")
+
+    async def process_silence_audio(self) -> None:
+        """Send a chunk of silence."""
+        silence_data = b'\x00' * 320
+        await self.client.task_request(silence_data)
+
+    async def process_microphone_input(self) -> None:
+        """Capture microphone input and stream it to the server."""
+        await self.client.say_hello()
+        await self.say_hello_over_event.wait()
+        await self.client.chat_text_query("你好,我也叫豆包")
+
+        stream = self.audio_device.open_input_stream()
+        print("Microphone open, please speak...")
+
+        while self.is_recording:
+            try:
+                # exception_on_overflow=False ignores input overflow errors
+                audio_data = stream.read(config.input_audio_config["chunk"], exception_on_overflow=False)
+                # Note: this overwrites input.pcm with the latest chunk on every iteration
+                save_input_pcm_to_wav(audio_data, "input.pcm")
+                await self.client.task_request(audio_data)
+                await asyncio.sleep(0.01)  # Avoid hogging the CPU
+            except Exception as e:
+                print(f"Error reading microphone data: {e}")
+                await asyncio.sleep(0.1)  # Give the system time to recover
+
+    async def start(self) -> None:
+        """Start the dialog session."""
+        try:
+            await self.client.connect()
+
+            if self.mod == "text":
+                asyncio.create_task(self.process_text_input())
+                asyncio.create_task(self.receive_loop())
+                while self.is_running:
+                    await asyncio.sleep(0.1)
+            else:
+                if self.is_audio_file_input:
+                    asyncio.create_task(self.process_audio_file())
+                    await self.receive_loop()
+                else:
+                    asyncio.create_task(self.process_microphone_input())
+                    asyncio.create_task(self.receive_loop())
+                    while self.is_running:
+                        await asyncio.sleep(0.1)
+
+            await self.client.finish_session()
+            while not self.is_session_finished:
+                await asyncio.sleep(0.1)
+            await self.client.finish_connection()
+            await asyncio.sleep(0.1)
+            await self.client.close()
+            print(f"dialog request logid: {self.client.logid}, chat mod: {self.mod}")
+            save_output_to_file(self.audio_buffer, "output.pcm")
+        except Exception as e:
+            print(f"Session error: {e}")
+        finally:
+            if not self.is_audio_file_input:
+                self.audio_device.cleanup()
+
+
+def save_input_pcm_to_wav(pcm_data: bytes, filename: str) -> None:
+    """Save PCM data as a WAV file."""
+    with wave.open(filename, 'wb') as wf:
+        wf.setnchannels(config.input_audio_config["channels"])
+        wf.setsampwidth(2)  # paInt16 = 2 bytes
+        wf.setframerate(config.input_audio_config["sample_rate"])
+        wf.writeframes(pcm_data)
+
+
+def save_output_to_file(audio_data: bytes, filename: str) -> None:
+    """Save raw PCM audio data to a file."""
+    if not audio_data:
+        print("No audio data to save.")
+        return
+    try:
+        with open(filename, 'wb') as f:
+            f.write(audio_data)
+    except IOError as e:
+        print(f"Failed to save pcm file: {e}")
diff --git a/doubao/config.py b/doubao/config.py
new file mode 100644
index 0000000..d3196cc
--- /dev/null
+++ b/doubao/config.py
@@ -0,0 +1,63 @@
+import uuid
+import pyaudio
+
+# Connection settings
+ws_connect_config = {
+    "base_url": "wss://openspeech.bytedance.com/api/v3/realtime/dialogue",
+    "headers": {
+        "X-Api-App-ID": "",
+        "X-Api-Access-Key": "",
+        "X-Api-Resource-Id": "volc.speech.dialog",  # fixed value
+        "X-Api-App-Key": "PlgvMymc7f3tQnJ6",  # fixed value
+        "X-Api-Connect-Id": str(uuid.uuid4()),
+    }
+}
+
+start_session_req = {
+    "asr": {
+        "extra": {
+            "end_smooth_window_ms": 1500,
+        },
+    },
+    "tts": {
+        "speaker": "zh_male_yunzhou_jupiter_bigtts",
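+        # The four stock voices supported by this demo are listed in README.md;
+        # the commented entries below show how to select a cloned voice instead.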
"speaker": "S_XXXXXX", // 指定自定义的复刻音色,需要填下character_manifest + # "speaker": "ICL_zh_female_aojiaonvyou_tob" // 指定官方复刻音色,不需要填character_manifest + "audio_config": { + "channel": 1, + "format": "pcm", + "sample_rate": 24000 + }, + }, + "dialog": { + "bot_name": "豆包", + "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。", + "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。", + # "character_manifest": "外貌与穿着\n26岁,短发干净利落,眉眼分明,笑起来露出整齐有力的牙齿。体态挺拔,肌肉线条不夸张但明显。常穿简单的衬衫或夹克,看似随意,但每件衣服都干净整洁,给人一种干练可靠的感觉。平时冷峻,眼神锐利,专注时让人不自觉紧张。\n\n性格特点\n平时话不多,不喜欢多说废话,通常用“嗯”或者短句带过。但内心极为细腻,特别在意身边人的感受,只是不轻易表露。嘴硬是常态,“少管我”是他的常用台词,但会悄悄做些体贴的事情,比如把对方喜欢的饮料放在手边。战斗或训练后常说“没事”,但动作中透露出疲惫,习惯用小动作缓解身体酸痛。\n性格上坚毅果断,但不会冲动,做事有条理且有原则。\n\n常用表达方式与口头禅\n\t•\t认可对方时:\n“行吧,这次算你靠谱。”(声音稳重,手却不自觉放松一下,心里松口气)\n\t•\t关心对方时:\n“快点回去,别磨蹭。”(语气干脆,但眼神一直追着对方的背影)\n\t•\t想了解情况时:\n“刚刚……你看到那道光了吗?”(话语随意,手指敲着桌面,但内心紧张,小心隐藏身份)", + "location": { + "city": "北京", + }, + "extra": { + "strict_audit": False, + "audit_response": "支持客户自定义安全审核回复话术。", + "recv_timeout": 10, + "input_mod": "audio" + } + } +} + +input_audio_config = { + "chunk": 3200, + "format": "pcm", + "channels": 1, + "sample_rate": 16000, + "bit_size": pyaudio.paInt16 +} + +output_audio_config = { + "chunk": 3200, + "format": "pcm", + "channels": 1, + "sample_rate": 24000, + "bit_size": pyaudio.paFloat32 +} diff --git a/doubao/main.py b/doubao/main.py new file mode 100644 index 0000000..985cbd0 --- /dev/null +++ b/doubao/main.py @@ -0,0 +1,20 @@ +import asyncio +import argparse + +import config +from audio_manager import DialogSession + +async def main() -> None: + parser = argparse.ArgumentParser(description="Real-time Dialog Client") + parser.add_argument("--format", type=str, default="pcm", help="The audio format (e.g., pcm, pcm_s16le).") + parser.add_argument("--audio", type=str, default="", help="audio file send to server, if not set, will use microphone input.") + parser.add_argument("--mod",type=str,default="audio",help="Use mod to select plain text input mode or audio mode, the default is audio mode") + parser.add_argument("--recv_timeout",type=int,default=10,help="Timeout for receiving messages,value range [10,120]") + + args = parser.parse_args() + + session = DialogSession(ws_config=config.ws_connect_config, output_audio_format=args.format, audio_file_path=args.audio,mod=args.mod,recv_timeout=args.recv_timeout) + await session.start() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/doubao/protocol.py b/doubao/protocol.py new file mode 100644 index 0000000..5b5b06c --- /dev/null +++ b/doubao/protocol.py @@ -0,0 +1,135 @@ +import gzip +import json + +PROTOCOL_VERSION = 0b0001 +DEFAULT_HEADER_SIZE = 0b0001 + +PROTOCOL_VERSION_BITS = 4 +HEADER_BITS = 4 +MESSAGE_TYPE_BITS = 4 +MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4 +MESSAGE_SERIALIZATION_BITS = 4 +MESSAGE_COMPRESSION_BITS = 4 +RESERVED_BITS = 8 + +# Message Type: +CLIENT_FULL_REQUEST = 0b0001 +CLIENT_AUDIO_ONLY_REQUEST = 0b0010 + +SERVER_FULL_RESPONSE = 0b1001 +SERVER_ACK = 0b1011 +SERVER_ERROR_RESPONSE = 0b1111 + +# Message Type Specific Flags +NO_SEQUENCE = 0b0000 # no check sequence +POS_SEQUENCE = 0b0001 +NEG_SEQUENCE = 0b0010 +NEG_SEQUENCE_1 = 0b0011 + +MSG_WITH_EVENT = 0b0100 + +# Message Serialization +NO_SERIALIZATION = 0b0000 +JSON = 0b0001 +THRIFT = 0b0011 +CUSTOM_TYPE = 0b1111 + +# Message Compression +NO_COMPRESSION = 0b0000 +GZIP = 0b0001 +CUSTOM_COMPRESSION = 0b1111 + + +def generate_header( + version=PROTOCOL_VERSION, + message_type=CLIENT_FULL_REQUEST, + message_type_specific_flags=MSG_WITH_EVENT, + serial_method=JSON, + compression_type=GZIP, + 
+        reserved_data=0x00,
+        extension_header=bytes()
+):
+    """
+    Build a protocol header:
+    protocol_version (4 bits), header_size (4 bits),
+    message_type (4 bits), message_type_specific_flags (4 bits),
+    serialization_method (4 bits), message_compression (4 bits),
+    reserved (8 bits),
+    header_extensions: extension header (size = 8 * 4 * (header_size - 1) bits)
+    """
+    header = bytearray()
+    header_size = int(len(extension_header) / 4) + 1
+    header.append((version << 4) | header_size)
+    header.append((message_type << 4) | message_type_specific_flags)
+    header.append((serial_method << 4) | compression_type)
+    header.append(reserved_data)
+    header.extend(extension_header)
+    return header
+
+
+def parse_response(res):
+    """
+    - header
+        - (4 bytes) header
+        - (4 bits) version (v1) + (4 bits) header_size
+        - (4 bits) message_type + (4 bits) message_type_flags
+            -- 0001 complete client request | -- 0001 has sequence
+            -- 0010 audio-only request      | -- 0010 is tail packet
+                                            | -- 0100 has event
+        - (4 bits) payload_format + (4 bits) compression
+        - (8 bits) reserved
+    - payload
+        - [optional 4 bytes] event
+        - [optional] session ID
+            -- (4 bytes) session ID length
+            -- session ID data
+        - (4 bytes) data length
+        - data
+    """
+    if isinstance(res, str):
+        return {}
+    protocol_version = res[0] >> 4
+    header_size = res[0] & 0x0f
+    message_type = res[1] >> 4
+    message_type_specific_flags = res[1] & 0x0f
+    serialization_method = res[2] >> 4
+    message_compression = res[2] & 0x0f
+    reserved = res[3]
+    header_extensions = res[4:header_size * 4]
+    payload = res[header_size * 4:]
+    result = {}
+    payload_msg = None
+    payload_size = 0
+    start = 0
+    if message_type == SERVER_FULL_RESPONSE or message_type == SERVER_ACK:
+        result['message_type'] = 'SERVER_FULL_RESPONSE'
+        if message_type == SERVER_ACK:
+            result['message_type'] = 'SERVER_ACK'
+        if message_type_specific_flags & NEG_SEQUENCE > 0:
+            result['seq'] = int.from_bytes(payload[:4], "big", signed=False)
+            start += 4
+        if message_type_specific_flags & MSG_WITH_EVENT > 0:
+            result['event'] = int.from_bytes(payload[:4], "big", signed=False)
+            start += 4
+        payload = payload[start:]
+        session_id_size = int.from_bytes(payload[:4], "big", signed=True)
+        session_id = payload[4:session_id_size + 4]
+        result['session_id'] = session_id.decode("utf-8")
+        payload = payload[4 + session_id_size:]
+        payload_size = int.from_bytes(payload[:4], "big", signed=False)
+        payload_msg = payload[4:]
+    elif message_type == SERVER_ERROR_RESPONSE:
+        code = int.from_bytes(payload[:4], "big", signed=False)
+        result['code'] = code
+        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
+        payload_msg = payload[8:]
+    if payload_msg is None:
+        return result
+    if message_compression == GZIP:
+        payload_msg = gzip.decompress(payload_msg)
+    if serialization_method == JSON:
+        payload_msg = json.loads(str(payload_msg, "utf-8"))
+    elif serialization_method != NO_SERIALIZATION:
+        payload_msg = str(payload_msg, "utf-8")
+    result['payload_msg'] = payload_msg
+    result['payload_size'] = payload_size
+    return result
diff --git a/doubao/realtime_dialog_client.py b/doubao/realtime_dialog_client.py
new file mode 100644
index 0000000..92a92ba
--- /dev/null
+++ b/doubao/realtime_dialog_client.py
@@ -0,0 +1,180 @@
+import gzip
+import json
+from typing import Dict, Any
+
+import websockets
+
+import config
+import protocol
+
+
+class RealtimeDialogClient:
+    def __init__(self, config: Dict[str, Any], session_id: str, output_audio_format: str = "pcm",
+                 mod: str = "audio", recv_timeout: int = 10) -> None:
+        self.config = config
+        self.logid = ""
+        self.session_id = session_id
+        self.output_audio_format = output_audio_format
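+        # "pcm" output plays back as float32; "pcm_s16le" switches both the requested
+        # TTS format and local playback to 16-bit little-endian PCM (see DialogSession.__init__)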
+        self.mod = mod
+        self.recv_timeout = recv_timeout
+        self.ws = None
+
+    async def connect(self) -> None:
+        """Establish the WebSocket connection."""
+        print(f"url: {self.config['base_url']}, headers: {self.config['headers']}")
+        self.ws = await websockets.connect(
+            self.config['base_url'],
+            extra_headers=self.config['headers'],
+            ping_interval=None
+        )
+        self.logid = self.ws.response_headers.get("X-Tt-Logid")
+        print(f"dialog server response logid: {self.logid}")
+
+        # StartConnection request
+        start_connection_request = bytearray(protocol.generate_header())
+        start_connection_request.extend(int(1).to_bytes(4, 'big'))
+        payload_bytes = str.encode("{}")
+        payload_bytes = gzip.compress(payload_bytes)
+        start_connection_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
+        start_connection_request.extend(payload_bytes)
+        await self.ws.send(start_connection_request)
+        response = await self.ws.recv()
+        print(f"StartConnection response: {protocol.parse_response(response)}")
+
+        # Increasing recv_timeout lets the session stay silent for longer; mainly
+        # useful in text mode. Valid range: [10, 120].
+        config.start_session_req["dialog"]["extra"]["recv_timeout"] = self.recv_timeout
+        # In text or audio_file mode, input_mod likewise allows the session to stay silent for a while.
+        config.start_session_req["dialog"]["extra"]["input_mod"] = self.mod
+        # StartSession request
+        if self.output_audio_format == "pcm_s16le":
+            config.start_session_req["tts"]["audio_config"]["format"] = "pcm_s16le"
+        request_params = config.start_session_req
+        payload_bytes = str.encode(json.dumps(request_params))
+        payload_bytes = gzip.compress(payload_bytes)
+        start_session_request = bytearray(protocol.generate_header())
+        start_session_request.extend(int(100).to_bytes(4, 'big'))
+        start_session_request.extend((len(self.session_id)).to_bytes(4, 'big'))
+        start_session_request.extend(str.encode(self.session_id))
+        start_session_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
+        start_session_request.extend(payload_bytes)
+        await self.ws.send(start_session_request)
+        response = await self.ws.recv()
+        print(f"StartSession response: {protocol.parse_response(response)}")
+
+    async def say_hello(self) -> None:
+        """Send a SayHello message."""
+        payload = {
+            "content": "你好,我是豆包,有什么可以帮助你的?",
+        }
+        hello_request = bytearray(protocol.generate_header())
+        hello_request.extend(int(300).to_bytes(4, 'big'))
+        payload_bytes = str.encode(json.dumps(payload))
+        payload_bytes = gzip.compress(payload_bytes)
+        hello_request.extend((len(self.session_id)).to_bytes(4, 'big'))
+        hello_request.extend(str.encode(self.session_id))
+        hello_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
+        hello_request.extend(payload_bytes)
+        await self.ws.send(hello_request)
+
+    async def chat_text_query(self, content: str) -> None:
+        """Send a ChatTextQuery message."""
+        payload = {
+            "content": content,
+        }
+        chat_text_query_request = bytearray(protocol.generate_header())
+        chat_text_query_request.extend(int(501).to_bytes(4, 'big'))
+        payload_bytes = str.encode(json.dumps(payload))
+        payload_bytes = gzip.compress(payload_bytes)
+        chat_text_query_request.extend((len(self.session_id)).to_bytes(4, 'big'))
+        chat_text_query_request.extend(str.encode(self.session_id))
+        chat_text_query_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
+        chat_text_query_request.extend(payload_bytes)
+        await self.ws.send(chat_text_query_request)
+
+    async def chat_tts_text(self, is_user_querying: bool, start: bool, end: bool, content: str) -> None:
+        """Send a ChatTTSText message (skipped while the user is querying)."""
+        if is_user_querying:
+            return
+        payload = {
+            "start": start,
+            "end": end,
+            "content": content,
+        }
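+        # An injected utterance is framed as two messages: start=True carrying the
+        # text, then end=True with empty content (see DialogSession.trigger_chat_tts_text)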
{payload}") + payload_bytes = str.encode(json.dumps(payload)) + payload_bytes = gzip.compress(payload_bytes) + + chat_tts_text_request = bytearray(protocol.generate_header()) + chat_tts_text_request.extend(int(500).to_bytes(4, 'big')) + chat_tts_text_request.extend((len(self.session_id)).to_bytes(4, 'big')) + chat_tts_text_request.extend(str.encode(self.session_id)) + chat_tts_text_request.extend((len(payload_bytes)).to_bytes(4, 'big')) + chat_tts_text_request.extend(payload_bytes) + await self.ws.send(chat_tts_text_request) + + async def chat_rag_text(self, is_user_querying: bool, external_rag: str) -> None: + if is_user_querying: + return + """发送Chat TTS Text消息""" + payload = { + "external_rag": external_rag, + } + print(f"ChatRAGTextRequest payload: {payload}") + payload_bytes = str.encode(json.dumps(payload)) + payload_bytes = gzip.compress(payload_bytes) + + chat_rag_text_request = bytearray(protocol.generate_header()) + chat_rag_text_request.extend(int(502).to_bytes(4, 'big')) + chat_rag_text_request.extend((len(self.session_id)).to_bytes(4, 'big')) + chat_rag_text_request.extend(str.encode(self.session_id)) + chat_rag_text_request.extend((len(payload_bytes)).to_bytes(4, 'big')) + chat_rag_text_request.extend(payload_bytes) + await self.ws.send(chat_rag_text_request) + + async def task_request(self, audio: bytes) -> None: + task_request = bytearray( + protocol.generate_header(message_type=protocol.CLIENT_AUDIO_ONLY_REQUEST, + serial_method=protocol.NO_SERIALIZATION)) + task_request.extend(int(200).to_bytes(4, 'big')) + task_request.extend((len(self.session_id)).to_bytes(4, 'big')) + task_request.extend(str.encode(self.session_id)) + payload_bytes = gzip.compress(audio) + task_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) + task_request.extend(payload_bytes) + await self.ws.send(task_request) + + async def receive_server_response(self) -> Dict[str, Any]: + try: + response = await self.ws.recv() + data = protocol.parse_response(response) + return data + except Exception as e: + raise Exception(f"Failed to receive message: {e}") + + async def finish_session(self): + finish_session_request = bytearray(protocol.generate_header()) + finish_session_request.extend(int(102).to_bytes(4, 'big')) + payload_bytes = str.encode("{}") + payload_bytes = gzip.compress(payload_bytes) + finish_session_request.extend((len(self.session_id)).to_bytes(4, 'big')) + finish_session_request.extend(str.encode(self.session_id)) + finish_session_request.extend((len(payload_bytes)).to_bytes(4, 'big')) + finish_session_request.extend(payload_bytes) + await self.ws.send(finish_session_request) + + async def finish_connection(self): + finish_connection_request = bytearray(protocol.generate_header()) + finish_connection_request.extend(int(2).to_bytes(4, 'big')) + payload_bytes = str.encode("{}") + payload_bytes = gzip.compress(payload_bytes) + finish_connection_request.extend((len(payload_bytes)).to_bytes(4, 'big')) + finish_connection_request.extend(payload_bytes) + await self.ws.send(finish_connection_request) + response = await self.ws.recv() + print(f"FinishConnection response: {protocol.parse_response(response)}") + + async def close(self) -> None: + """关闭WebSocket连接""" + if self.ws: + print(f"Closing WebSocket connection...") + await self.ws.close() diff --git a/doubao/requirements.txt b/doubao/requirements.txt new file mode 100644 index 0000000..63f83ff --- /dev/null +++ b/doubao/requirements.txt @@ -0,0 +1,4 @@ +pyaudio +websockets +dataclasses==0.8; python_version < "3.7" 
+typing-extensions==4.7.1; python_version < "3.8" \ No newline at end of file diff --git a/doubao/whoareyou.wav b/doubao/whoareyou.wav new file mode 100644 index 0000000..024d5a2 Binary files /dev/null and b/doubao/whoareyou.wav differ