From 9108fd45821419ee88b83df93baf0668a686f035 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?= Date: Sat, 20 Sep 2025 14:35:54 +0800 Subject: [PATCH] config --- doubao.py | 470 ++++++++++++++++++++++++++++++++++ doubao_debug.py | 540 ++++++++++++++++++++++++++++++++++++++++ doubao_final_test.py | 425 +++++++++++++++++++++++++++++++ doubao_original_test.py | 412 ++++++++++++++++++++++++++++++ doubao_simple.py | 412 ++++++++++++++++++++++++++++++ doubao_test.py | 303 ++++++++++++++++++++++ test_doubao.py | 113 +++++++++ 7 files changed, 2675 insertions(+) create mode 100644 doubao.py create mode 100644 doubao_debug.py create mode 100644 doubao_final_test.py create mode 100644 doubao_original_test.py create mode 100644 doubao_simple.py create mode 100644 doubao_test.py create mode 100644 test_doubao.py diff --git a/doubao.py b/doubao.py new file mode 100644 index 0000000..6b22fdc --- /dev/null +++ b/doubao.py @@ -0,0 +1,470 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +豆包音频处理模块 +简化版WebSocket API,支持音频文件上传和返回音频处理 +""" + +import asyncio +import gzip +import json +import uuid +import wave +import struct +import time +import os +from typing import Dict, Any, Optional +import websockets + + +class DoubaoConfig: + """豆包配置""" + + def __init__(self): + self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue" + self.app_id = "8718217928" + self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc" + self.app_key = "PlgvMymc7f3tQnJ6" + self.resource_id = "volc.speech.dialog" + + def get_headers(self) -> Dict[str, str]: + """获取请求头""" + return { + "X-Api-App-ID": self.app_id, + "X-Api-Access-Key": self.access_key, + "X-Api-Resource-Id": self.resource_id, + "X-Api-App-Key": self.app_key, + "X-Api-Connect-Id": str(uuid.uuid4()), + } + + +class DoubaoProtocol: + """豆包协议处理""" + + # 协议常量 + PROTOCOL_VERSION = 0b0001 + CLIENT_FULL_REQUEST = 0b0001 + CLIENT_AUDIO_ONLY_REQUEST = 0b0010 + SERVER_FULL_RESPONSE = 0b1001 + SERVER_ACK = 0b1011 + SERVER_ERROR_RESPONSE = 0b1111 + + NO_SEQUENCE = 0b0000 + POS_SEQUENCE = 0b0001 + MSG_WITH_EVENT = 0b0100 + JSON = 0b0001 + NO_SERIALIZATION = 0b0000 + GZIP = 0b0001 + + @classmethod + def generate_header(cls, message_type=CLIENT_FULL_REQUEST, + message_type_specific_flags=MSG_WITH_EVENT, + serial_method=JSON, compression_type=GZIP) -> bytes: + """生成协议头""" + header = bytearray() + header.append((cls.PROTOCOL_VERSION << 4) | 1) # version + header_size + header.append((message_type << 4) | message_type_specific_flags) + header.append((serial_method << 4) | compression_type) + header.append(0x00) # reserved + return bytes(header) + + @classmethod + def parse_response(cls, res: bytes) -> Dict[str, Any]: + """解析响应""" + if isinstance(res, str): + return {} + + protocol_version = res[0] >> 4 + header_size = res[0] & 0x0f + message_type = res[1] >> 4 + message_type_specific_flags = res[1] & 0x0f + serialization_method = res[2] >> 4 + message_compression = res[2] & 0x0f + + payload = res[header_size * 4:] + result = {} + + if message_type == cls.SERVER_FULL_RESPONSE or message_type == cls.SERVER_ACK: + result['message_type'] = 'SERVER_FULL_RESPONSE' + if message_type == cls.SERVER_ACK: + result['message_type'] = 'SERVER_ACK' + + start = 0 + if message_type_specific_flags & cls.MSG_WITH_EVENT: + result['event'] = int.from_bytes(payload[:4], "big", signed=False) + start += 4 + + payload = payload[start:] + session_id_size = int.from_bytes(payload[:4], "big", signed=True) + session_id = payload[4:session_id_size+4] + result['session_id'] = 
str(session_id) + payload = payload[4 + session_id_size:] + + payload_size = int.from_bytes(payload[:4], "big", signed=False) + payload_msg = payload[4:] + result['payload_size'] = payload_size + + if payload_msg: + if message_compression == cls.GZIP: + payload_msg = gzip.decompress(payload_msg) + if serialization_method == cls.JSON: + payload_msg = json.loads(str(payload_msg, "utf-8")) + result['payload_msg'] = payload_msg + + elif message_type == cls.SERVER_ERROR_RESPONSE: + code = int.from_bytes(payload[:4], "big", signed=False) + result['code'] = code + payload_size = int.from_bytes(payload[4:8], "big", signed=False) + result['payload_size'] = payload_size + + return result + + +class AudioProcessor: + """音频处理器""" + + @staticmethod + def read_wav_file(file_path: str) -> tuple: + """读取WAV文件,返回音频数据和参数""" + with wave.open(file_path, 'rb') as wf: + # 获取音频参数 + channels = wf.getnchannels() + sampwidth = wf.getsampwidth() + framerate = wf.getframerate() + nframes = wf.getnframes() + + # 读取音频数据 + audio_data = wf.readframes(nframes) + + return audio_data, { + 'channels': channels, + 'sampwidth': sampwidth, + 'framerate': framerate, + 'nframes': nframes + } + + @staticmethod + def create_wav_file(audio_data: bytes, output_path: str, + sample_rate: int = 24000, channels: int = 1, + sampwidth: int = 2) -> None: + """创建WAV文件,适配树莓派播放""" + with wave.open(output_path, 'wb') as wf: + wf.setnchannels(channels) + wf.setsampwidth(sampwidth) + wf.setframerate(sample_rate) + wf.writeframes(audio_data) + + +class DoubaoClient: + """豆包客户端""" + + def __init__(self, config: DoubaoConfig): + self.config = config + self.session_id = str(uuid.uuid4()) + self.ws = None + self.log_id = "" + + async def connect(self) -> None: + """建立WebSocket连接""" + print(f"连接豆包服务器: {self.config.base_url}") + + self.ws = await websockets.connect( + self.config.base_url, + additional_headers=self.config.get_headers(), + ping_interval=None + ) + + # 获取log_id + if hasattr(self.ws, 'response_headers'): + self.log_id = self.ws.response_headers.get("X-Tt-Logid") + elif hasattr(self.ws, 'headers'): + self.log_id = self.ws.headers.get("X-Tt-Logid") + + print(f"连接成功, log_id: {self.log_id}") + + # 发送StartConnection请求 + await self._send_start_connection() + + # 发送StartSession请求 + await self._send_start_session() + + async def _send_start_connection(self) -> None: + """发送StartConnection请求""" + request = bytearray(DoubaoProtocol.generate_header()) + request.extend(int(1).to_bytes(4, 'big')) + + payload_bytes = b"{}" + payload_bytes = gzip.compress(payload_bytes) + request.extend(len(payload_bytes).to_bytes(4, 'big')) + request.extend(payload_bytes) + + await self.ws.send(request) + response = await self.ws.recv() + parsed_response = DoubaoProtocol.parse_response(response) + print(f"StartConnection响应: {parsed_response}") + + async def _send_start_session(self) -> None: + """发送StartSession请求""" + session_config = { + "asr": { + "extra": { + "end_smooth_window_ms": 1500, + }, + }, + "tts": { + "speaker": "zh_female_vv_jupiter_bigtts", + "audio_config": { + "channel": 1, + "format": "pcm", + "sample_rate": 24000 + }, + }, + "dialog": { + "bot_name": "豆包", + "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。", + "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。", + "location": {"city": "北京"}, + "extra": { + "strict_audit": False, + "audit_response": "支持客户自定义安全审核回复话术。", + "recv_timeout": 10, + "input_mod": "audio_file", # 使用音频文件模式 + }, + }, + } + + request = bytearray(DoubaoProtocol.generate_header()) + request.extend(int(100).to_bytes(4, 'big')) + 
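        # Reading aid (inferred from this client's own packing code, not an official spec):
        # the request assembled here is laid out as
        #   [4-byte binary header][4-byte event id, big-endian][4-byte session_id length]
        #   [session_id bytes][4-byte payload length][gzip-compressed JSON payload]
        # For StartSession the event id is 100 and the payload is the session_config above.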
request.extend(len(self.session_id).to_bytes(4, 'big')) + request.extend(self.session_id.encode()) + + payload_bytes = json.dumps(session_config).encode() + payload_bytes = gzip.compress(payload_bytes) + request.extend(len(payload_bytes).to_bytes(4, 'big')) + request.extend(payload_bytes) + + await self.ws.send(request) + response = await self.ws.recv() + parsed_response = DoubaoProtocol.parse_response(response) + print(f"StartSession响应: {parsed_response}") + + async def send_audio_file(self, file_path: str) -> bytes: + """发送音频文件并返回响应音频""" + print(f"处理音频文件: {file_path}") + + # 读取音频文件 + audio_data, audio_info = AudioProcessor.read_wav_file(file_path) + print(f"音频参数: {audio_info}") + + # 计算分块大小(200ms) + chunk_size = int(audio_info['framerate'] * audio_info['channels'] * + audio_info['sampwidth'] * 0.2) + + # 分块发送音频数据 + total_chunks = (len(audio_data) + chunk_size - 1) // chunk_size + print(f"开始分块发送音频,共 {total_chunks} 块") + + received_audio = b"" + error_count = 0 + + for i in range(0, len(audio_data), chunk_size): + chunk = audio_data[i:i + chunk_size] + is_last = (i + chunk_size >= len(audio_data)) + + # 发送音频块 + await self._send_audio_chunk(chunk, is_last) + + # 接收响应 + try: + response = await asyncio.wait_for(self.ws.recv(), timeout=2.0) + parsed_response = DoubaoProtocol.parse_response(response) + + # 检查是否是错误响应 + if 'code' in parsed_response and parsed_response['code'] != 0: + print(f"服务器返回错误: {parsed_response}") + error_count += 1 + if error_count > 3: + raise Exception(f"服务器连续返回错误: {parsed_response}") + continue + + # 处理音频响应 + if (parsed_response.get('message_type') == 'SERVER_ACK' and + isinstance(parsed_response.get('payload_msg'), bytes)): + audio_chunk = parsed_response['payload_msg'] + received_audio += audio_chunk + print(f"接收到音频数据块,大小: {len(audio_chunk)} 字节") + + # 检查会话状态 + event = parsed_response.get('event') + if event in [359, 152, 153]: # 这些事件表示会话相关状态 + print(f"会话事件: {event}") + if event in [152, 153]: # 会话结束 + print("检测到会话结束事件") + break + + except asyncio.TimeoutError: + print("等待响应超时,继续发送") + + # 模拟实时发送的延迟 + await asyncio.sleep(0.05) + + print("音频文件发送完成") + return received_audio + + async def _send_audio_chunk(self, audio_data: bytes, is_last: bool = False) -> None: + """发送音频块""" + request = bytearray( + DoubaoProtocol.generate_header( + message_type=DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST, + message_type_specific_flags=DoubaoProtocol.NO_SEQUENCE, + serial_method=DoubaoProtocol.NO_SERIALIZATION, # 音频数据不需要序列化 + compression_type=DoubaoProtocol.GZIP + ) + ) + request.extend(int(200).to_bytes(4, 'big')) + request.extend(len(self.session_id).to_bytes(4, 'big')) + request.extend(self.session_id.encode()) + + # 压缩音频数据 + compressed_audio = gzip.compress(audio_data) + request.extend(len(compressed_audio).to_bytes(4, 'big')) # payload size(4 bytes) + request.extend(compressed_audio) + + await self.ws.send(request) + + async def close(self) -> None: + """关闭连接""" + if self.ws: + try: + # 发送FinishSession + await self._send_finish_session() + + # 发送FinishConnection + await self._send_finish_connection() + + except Exception as e: + print(f"关闭会话时出错: {e}") + finally: + # 确保WebSocket连接关闭 + try: + await self.ws.close() + except: + pass + print("连接已关闭") + + async def _send_finish_session(self) -> None: + """发送FinishSession请求""" + request = bytearray(DoubaoProtocol.generate_header()) + request.extend(int(102).to_bytes(4, 'big')) + request.extend(len(self.session_id).to_bytes(4, 'big')) + request.extend(self.session_id.encode()) + + payload_bytes = b"{}" + payload_bytes = 
gzip.compress(payload_bytes) + request.extend(len(payload_bytes).to_bytes(4, 'big')) + request.extend(payload_bytes) + + await self.ws.send(request) + + async def _send_finish_connection(self) -> None: + """发送FinishConnection请求""" + request = bytearray(DoubaoProtocol.generate_header()) + request.extend(int(2).to_bytes(4, 'big')) + + payload_bytes = b"{}" + payload_bytes = gzip.compress(payload_bytes) + request.extend(len(payload_bytes).to_bytes(4, 'big')) + request.extend(payload_bytes) + + await self.ws.send(request) + response = await self.ws.recv() + parsed_response = DoubaoProtocol.parse_response(response) + print(f"FinishConnection响应: {parsed_response}") + + +class DoubaoProcessor: + """豆包音频处理器""" + + def __init__(self): + self.config = DoubaoConfig() + self.client = DoubaoClient(self.config) + + async def process_audio_file(self, input_file: str, output_file: str = None) -> str: + """处理音频文件 + + Args: + input_file: 输入音频文件路径 + output_file: 输出音频文件路径,如果为None则自动生成 + + Returns: + 输出音频文件路径 + """ + if not os.path.exists(input_file): + raise FileNotFoundError(f"音频文件不存在: {input_file}") + + # 生成输出文件名 + if output_file is None: + timestamp = time.strftime("%Y%m%d_%H%M%S") + output_file = f"doubao_response_{timestamp}.wav" + + try: + # 连接豆包服务器 + await self.client.connect() + + # 等待一会确保会话建立 + await asyncio.sleep(0.5) + + # 发送音频文件并获取响应 + received_audio = await self.client.send_audio_file(input_file) + + if received_audio: + print(f"总共接收到音频数据: {len(received_audio)} 字节") + # 转换为WAV格式保存(适配树莓派播放) + AudioProcessor.create_wav_file( + received_audio, + output_file, + sample_rate=24000, # 豆包返回的音频采样率 + channels=1, + sampwidth=2 # 16-bit + ) + print(f"响应音频已保存到: {output_file}") + + # 显示文件信息 + file_size = os.path.getsize(output_file) + print(f"输出文件大小: {file_size} 字节") + + else: + print("警告: 未接收到音频响应") + + return output_file + + except Exception as e: + print(f"处理音频文件时出错: {e}") + import traceback + traceback.print_exc() + raise + finally: + await self.client.close() + + +async def main(): + """测试函数""" + import argparse + + parser = argparse.ArgumentParser(description="豆包音频处理测试") + parser.add_argument("--input", type=str, required=True, help="输入音频文件路径") + parser.add_argument("--output", type=str, help="输出音频文件路径") + + args = parser.parse_args() + + processor = DoubaoProcessor() + try: + output_file = await processor.process_audio_file(args.input, args.output) + print(f"处理完成,输出文件: {output_file}") + except Exception as e: + print(f"处理失败: {e}") + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/doubao_debug.py b/doubao_debug.py new file mode 100644 index 0000000..8baeb88 --- /dev/null +++ b/doubao_debug.py @@ -0,0 +1,540 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +豆包音频处理模块 - 调试版本 +添加更多调试信息和错误处理 +""" + +import asyncio +import gzip +import json +import uuid +import wave +import struct +import time +import os +from typing import Dict, Any, Optional +import websockets + + +class DoubaoConfig: + """豆包配置""" + + def __init__(self): + self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue" + self.app_id = "8718217928" + self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc" + self.app_key = "PlgvMymc7f3tQnJ6" + self.resource_id = "volc.speech.dialog" + + def get_headers(self) -> Dict[str, str]: + """获取请求头""" + return { + "X-Api-App-ID": self.app_id, + "X-Api-Access-Key": self.access_key, + "X-Api-Resource-Id": self.resource_id, + "X-Api-App-Key": self.app_key, + "X-Api-Connect-Id": str(uuid.uuid4()), + } + + +class DoubaoProtocol: + """豆包协议处理""" + + 
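    # The 4-byte header built by generate_header() packs two 4-bit fields per byte
    # (a reading aid derived from the pack/parse code in this class):
    #   byte 0: (protocol_version << 4) | header_size_in_32bit_words  -> 0x11 here
    #   byte 1: (message_type << 4) | message_type_specific_flags     -> e.g. 0x14 = CLIENT_FULL_REQUEST | MSG_WITH_EVENT
    #   byte 2: (serial_method << 4) | compression_type               -> e.g. 0x11 = JSON | GZIP
    #   byte 3: reserved (0x00)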
# 协议常量 + PROTOCOL_VERSION = 0b0001 + CLIENT_FULL_REQUEST = 0b0001 + CLIENT_AUDIO_ONLY_REQUEST = 0b0010 + SERVER_FULL_RESPONSE = 0b1001 + SERVER_ACK = 0b1011 + SERVER_ERROR_RESPONSE = 0b1111 + + NO_SEQUENCE = 0b0000 + POS_SEQUENCE = 0b0001 + MSG_WITH_EVENT = 0b0100 + JSON = 0b0001 + NO_SERIALIZATION = 0b0000 + GZIP = 0b0001 + + @classmethod + def generate_header(cls, message_type=CLIENT_FULL_REQUEST, + message_type_specific_flags=MSG_WITH_EVENT, + serial_method=JSON, compression_type=GZIP) -> bytes: + """生成协议头""" + header = bytearray() + header.append((cls.PROTOCOL_VERSION << 4) | 1) # version + header_size + header.append((message_type << 4) | message_type_specific_flags) + header.append((serial_method << 4) | compression_type) + header.append(0x00) # reserved + return bytes(header) + + @classmethod + def parse_response(cls, res: bytes) -> Dict[str, Any]: + """解析响应""" + if isinstance(res, str): + return {} + + try: + protocol_version = res[0] >> 4 + header_size = res[0] & 0x0f + message_type = res[1] >> 4 + message_type_specific_flags = res[1] & 0x0f + serialization_method = res[2] >> 4 + message_compression = res[2] & 0x0f + + payload = res[header_size * 4:] + result = {} + + if message_type == cls.SERVER_FULL_RESPONSE or message_type == cls.SERVER_ACK: + result['message_type'] = 'SERVER_FULL_RESPONSE' + if message_type == cls.SERVER_ACK: + result['message_type'] = 'SERVER_ACK' + + start = 0 + if message_type_specific_flags & cls.MSG_WITH_EVENT: + result['event'] = int.from_bytes(payload[:4], "big", signed=False) + start += 4 + + payload = payload[start:] + if len(payload) < 4: + result['error'] = 'Payload too short for session_id' + return result + + session_id_size = int.from_bytes(payload[:4], "big", signed=True) + if session_id_size < 0 or session_id_size > len(payload) - 4: + result['error'] = f'Invalid session_id size: {session_id_size}' + return result + + session_id = payload[4:session_id_size+4] + result['session_id'] = str(session_id) + payload = payload[4 + session_id_size:] + + if len(payload) < 4: + result['error'] = 'Payload too short for payload_size' + return result + + payload_size = int.from_bytes(payload[:4], "big", signed=False) + result['payload_size'] = payload_size + + if len(payload) >= 4 + payload_size: + payload_msg = payload[4:4 + payload_size] + if payload_msg: + if message_compression == cls.GZIP: + try: + payload_msg = gzip.decompress(payload_msg) + except Exception as e: + result['decompress_error'] = str(e) + return result + + if serialization_method == cls.JSON: + try: + payload_msg = json.loads(str(payload_msg, "utf-8")) + except Exception as e: + result['json_error'] = str(e) + payload_msg = str(payload_msg, "utf-8") + elif serialization_method != cls.NO_SERIALIZATION: + payload_msg = str(payload_msg, "utf-8") + result['payload_msg'] = payload_msg + + elif message_type == cls.SERVER_ERROR_RESPONSE: + if len(payload) >= 8: + code = int.from_bytes(payload[:4], "big", signed=False) + result['code'] = code + payload_size = int.from_bytes(payload[4:8], "big", signed=False) + result['payload_size'] = payload_size + + if len(payload) >= 8 + payload_size: + payload_msg = payload[8:8 + payload_size] + if payload_msg and message_compression == cls.GZIP: + try: + payload_msg = gzip.decompress(payload_msg) + except: + pass + result['payload_msg'] = payload_msg + + except Exception as e: + result['parse_error'] = str(e) + + return result + + +class AudioProcessor: + """音频处理器""" + + @staticmethod + def read_wav_file(file_path: str) -> tuple: + """读取WAV文件,返回音频数据和参数""" 
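        # Returns a tuple: (raw PCM frames as bytes,
        #                   {'channels', 'sampwidth', 'framerate', 'nframes'}).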
+ with wave.open(file_path, 'rb') as wf: + # 获取音频参数 + channels = wf.getnchannels() + sampwidth = wf.getsampwidth() + framerate = wf.getframerate() + nframes = wf.getnframes() + + # 读取音频数据 + audio_data = wf.readframes(nframes) + + return audio_data, { + 'channels': channels, + 'sampwidth': sampwidth, + 'framerate': framerate, + 'nframes': nframes + } + + @staticmethod + def create_wav_file(audio_data: bytes, output_path: str, + sample_rate: int = 24000, channels: int = 1, + sampwidth: int = 2) -> None: + """创建WAV文件,适配树莓派播放""" + with wave.open(output_path, 'wb') as wf: + wf.setnchannels(channels) + wf.setsampwidth(sampwidth) + wf.setframerate(sample_rate) + wf.writeframes(audio_data) + + +class DoubaoClient: + """豆包客户端""" + + def __init__(self, config: DoubaoConfig): + self.config = config + self.session_id = str(uuid.uuid4()) + self.ws = None + self.log_id = "" + + async def connect(self) -> None: + """建立WebSocket连接""" + print(f"连接豆包服务器: {self.config.base_url}") + + try: + self.ws = await websockets.connect( + self.config.base_url, + additional_headers=self.config.get_headers(), + ping_interval=None + ) + + # 获取log_id + if hasattr(self.ws, 'response_headers'): + self.log_id = self.ws.response_headers.get("X-Tt-Logid") + elif hasattr(self.ws, 'headers'): + self.log_id = self.ws.headers.get("X-Tt-Logid") + + print(f"连接成功, log_id: {self.log_id}") + + # 发送StartConnection请求 + await self._send_start_connection() + + # 发送StartSession请求 + await self._send_start_session() + + except Exception as e: + print(f"连接失败: {e}") + raise + + async def _send_start_connection(self) -> None: + """发送StartConnection请求""" + print("发送StartConnection请求...") + request = bytearray(DoubaoProtocol.generate_header()) + request.extend(int(1).to_bytes(4, 'big')) + + payload_bytes = b"{}" + payload_bytes = gzip.compress(payload_bytes) + request.extend(len(payload_bytes).to_bytes(4, 'big')) + request.extend(payload_bytes) + + await self.ws.send(request) + response = await self.ws.recv() + parsed_response = DoubaoProtocol.parse_response(response) + print(f"StartConnection响应: {parsed_response}") + + # 检查是否有错误 + if 'error' in parsed_response: + raise Exception(f"StartConnection解析错误: {parsed_response['error']}") + + async def _send_start_session(self) -> None: + """发送StartSession请求""" + print("发送StartSession请求...") + session_config = { + "asr": { + "extra": { + "end_smooth_window_ms": 1500, + }, + }, + "tts": { + "speaker": "zh_female_vv_jupiter_bigtts", + "audio_config": { + "channel": 1, + "format": "pcm", + "sample_rate": 24000 + }, + }, + "dialog": { + "bot_name": "豆包", + "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。", + "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。", + "location": {"city": "北京"}, + "extra": { + "strict_audit": False, + "audit_response": "支持客户自定义安全审核回复话术。", + "recv_timeout": 30, # 增加超时时间 + "input_mod": "audio_file", # 使用音频文件模式 + }, + }, + } + + request = bytearray(DoubaoProtocol.generate_header()) + request.extend(int(100).to_bytes(4, 'big')) + request.extend(len(self.session_id).to_bytes(4, 'big')) + request.extend(self.session_id.encode()) + + payload_bytes = json.dumps(session_config).encode() + payload_bytes = gzip.compress(payload_bytes) + request.extend(len(payload_bytes).to_bytes(4, 'big')) + request.extend(payload_bytes) + + await self.ws.send(request) + response = await self.ws.recv() + parsed_response = DoubaoProtocol.parse_response(response) + print(f"StartSession响应: {parsed_response}") + + # 检查是否有错误 + if 'error' in parsed_response: + raise Exception(f"StartSession解析错误: {parsed_response['error']}") + + # 
等待一会确保会话完全建立 + await asyncio.sleep(1.0) + + async def send_audio_file(self, file_path: str) -> bytes: + """发送音频文件并返回响应音频""" + print(f"处理音频文件: {file_path}") + + # 读取音频文件 + audio_data, audio_info = AudioProcessor.read_wav_file(file_path) + print(f"音频参数: {audio_info}") + + # 计算分块大小(减小到50ms,避免数据块过大) + chunk_size = int(audio_info['framerate'] * audio_info['channels'] * + audio_info['sampwidth'] * 0.05) # 50ms + + # 分块发送音频数据 + total_chunks = (len(audio_data) + chunk_size - 1) // chunk_size + print(f"开始分块发送音频,共 {total_chunks} 块") + + received_audio = b"" + error_count = 0 + session_active = True + + for i in range(0, len(audio_data), chunk_size): + if not session_active: + print("会话已结束,停止发送") + break + + chunk = audio_data[i:i + chunk_size] + is_last = (i + chunk_size >= len(audio_data)) + + # 发送音频块 + await self._send_audio_chunk(chunk, is_last) + + # 接收响应 + try: + response = await asyncio.wait_for(self.ws.recv(), timeout=3.0) + parsed_response = DoubaoProtocol.parse_response(response) + + print(f"响应 {i//chunk_size + 1}/{total_chunks}: {parsed_response}") + + # 检查是否是错误响应 + if 'code' in parsed_response and parsed_response['code'] != 0: + print(f"服务器返回错误: {parsed_response}") + error_count += 1 + if error_count > 3: + raise Exception(f"服务器连续返回错误: {parsed_response}") + continue + + # 处理音频响应 + if (parsed_response.get('message_type') == 'SERVER_ACK' and + isinstance(parsed_response.get('payload_msg'), bytes)): + audio_chunk = parsed_response['payload_msg'] + received_audio += audio_chunk + print(f"接收到音频数据块,大小: {len(audio_chunk)} 字节") + + # 检查会话状态 + event = parsed_response.get('event') + if event in [359, 152, 153]: # 这些事件表示会话相关状态 + print(f"会话事件: {event}") + if event in [152, 153]: # 会话结束 + print("检测到会话结束事件") + session_active = False + break + + except asyncio.TimeoutError: + print("等待响应超时,继续发送") + + # 模拟实时发送的延迟 + await asyncio.sleep(0.1) + + print("音频文件发送完成") + return received_audio + + async def _send_audio_chunk(self, audio_data: bytes, is_last: bool = False) -> None: + """发送音频块""" + request = bytearray( + DoubaoProtocol.generate_header( + message_type=DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST, + message_type_specific_flags=DoubaoProtocol.NO_SEQUENCE, + serial_method=DoubaoProtocol.NO_SERIALIZATION, # 音频数据不需要序列化 + compression_type=DoubaoProtocol.GZIP + ) + ) + request.extend(int(200).to_bytes(4, 'big')) + request.extend(len(self.session_id).to_bytes(4, 'big')) + request.extend(self.session_id.encode()) + + # 压缩音频数据 + compressed_audio = gzip.compress(audio_data) + payload_size = len(compressed_audio) + request.extend(payload_size.to_bytes(4, 'big')) # payload size(4 bytes) + request.extend(compressed_audio) + + print(f"发送音频块 - 原始大小: {len(audio_data)}, 压缩后大小: {payload_size}, 总请求数据大小: {len(request)}") + + await self.ws.send(request) + + async def close(self) -> None: + """关闭连接""" + if self.ws: + try: + # 发送FinishSession + await self._send_finish_session() + + # 发送FinishConnection + await self._send_finish_connection() + + except Exception as e: + print(f"关闭会话时出错: {e}") + finally: + # 确保WebSocket连接关闭 + try: + await self.ws.close() + except: + pass + print("连接已关闭") + + async def _send_finish_session(self) -> None: + """发送FinishSession请求""" + print("发送FinishSession请求...") + request = bytearray(DoubaoProtocol.generate_header()) + request.extend(int(102).to_bytes(4, 'big')) + request.extend(len(self.session_id).to_bytes(4, 'big')) + request.extend(self.session_id.encode()) + + payload_bytes = b"{}" + payload_bytes = gzip.compress(payload_bytes) + request.extend(len(payload_bytes).to_bytes(4, 'big')) + 
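        # Client-side event ids used throughout this module when framing requests
        # (as observed in the senders here; not an exhaustive protocol list):
        #   1 = StartConnection, 2 = FinishConnection,
        #   100 = StartSession, 102 = FinishSession, 200 = audio TaskRequest.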
request.extend(payload_bytes) + + await self.ws.send(request) + + async def _send_finish_connection(self) -> None: + """发送FinishConnection请求""" + print("发送FinishConnection请求...") + request = bytearray(DoubaoProtocol.generate_header()) + request.extend(int(2).to_bytes(4, 'big')) + + payload_bytes = b"{}" + payload_bytes = gzip.compress(payload_bytes) + request.extend(len(payload_bytes).to_bytes(4, 'big')) + request.extend(payload_bytes) + + await self.ws.send(request) + + try: + response = await asyncio.wait_for(self.ws.recv(), timeout=5.0) + parsed_response = DoubaoProtocol.parse_response(response) + print(f"FinishConnection响应: {parsed_response}") + except asyncio.TimeoutError: + print("FinishConnection响应超时") + + +class DoubaoProcessor: + """豆包音频处理器""" + + def __init__(self): + self.config = DoubaoConfig() + self.client = DoubaoClient(self.config) + + async def process_audio_file(self, input_file: str, output_file: str = None) -> str: + """处理音频文件 + + Args: + input_file: 输入音频文件路径 + output_file: 输出音频文件路径,如果为None则自动生成 + + Returns: + 输出音频文件路径 + """ + if not os.path.exists(input_file): + raise FileNotFoundError(f"音频文件不存在: {input_file}") + + # 生成输出文件名 + if output_file is None: + timestamp = time.strftime("%Y%m%d_%H%M%S") + output_file = f"doubao_response_{timestamp}.wav" + + try: + # 连接豆包服务器 + await self.client.connect() + + # 发送音频文件并获取响应 + received_audio = await self.client.send_audio_file(input_file) + + if received_audio: + print(f"总共接收到音频数据: {len(received_audio)} 字节") + # 转换为WAV格式保存(适配树莓派播放) + AudioProcessor.create_wav_file( + received_audio, + output_file, + sample_rate=24000, # 豆包返回的音频采样率 + channels=1, + sampwidth=2 # 16-bit + ) + print(f"响应音频已保存到: {output_file}") + + # 显示文件信息 + file_size = os.path.getsize(output_file) + print(f"输出文件大小: {file_size} 字节") + + else: + print("警告: 未接收到音频响应") + + return output_file + + except Exception as e: + print(f"处理音频文件时出错: {e}") + import traceback + traceback.print_exc() + raise + finally: + await self.client.close() + + +async def main(): + """测试函数""" + import argparse + + parser = argparse.ArgumentParser(description="豆包音频处理测试") + parser.add_argument("--input", type=str, required=True, help="输入音频文件路径") + parser.add_argument("--output", type=str, help="输出音频文件路径") + + args = parser.parse_args() + + processor = DoubaoProcessor() + try: + output_file = await processor.process_audio_file(args.input, args.output) + print(f"处理完成,输出文件: {output_file}") + except Exception as e: + print(f"处理失败: {e}") + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/doubao_final_test.py b/doubao_final_test.py new file mode 100644 index 0000000..b0480f4 --- /dev/null +++ b/doubao_final_test.py @@ -0,0 +1,425 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +豆包音频处理模块 - 简化测试版本 +专门测试完整的音频上传和TTS音频下载流程 +""" + +import asyncio +import gzip +import json +import uuid +import wave +import struct +import time +import os +from typing import Dict, Any, Optional +import websockets + + +# 直接复制原始豆包代码的协议常量 +PROTOCOL_VERSION = 0b0001 +CLIENT_FULL_REQUEST = 0b0001 +CLIENT_AUDIO_ONLY_REQUEST = 0b0010 +SERVER_FULL_RESPONSE = 0b1001 +SERVER_ACK = 0b1011 +SERVER_ERROR_RESPONSE = 0b1111 + +NO_SEQUENCE = 0b0000 +POS_SEQUENCE = 0b0001 +MSG_WITH_EVENT = 0b0100 + +NO_SERIALIZATION = 0b0000 +JSON = 0b0001 +GZIP = 0b0001 + + +def generate_header( + version=PROTOCOL_VERSION, + message_type=CLIENT_FULL_REQUEST, + message_type_specific_flags=MSG_WITH_EVENT, + serial_method=JSON, + compression_type=GZIP, + reserved_data=0x00, + extension_header=bytes() +): + 
"""直接复制原始豆包代码的generate_header函数""" + header = bytearray() + header_size = int(len(extension_header) / 4) + 1 + header.append((version << 4) | header_size) + header.append((message_type << 4) | message_type_specific_flags) + header.append((serial_method << 4) | compression_type) + header.append(reserved_data) + header.extend(extension_header) + return header + + +class DoubaoConfig: + """豆包配置""" + + def __init__(self): + self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue" + self.app_id = "8718217928" + self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc" + self.app_key = "PlgvMymc7f3tQnJ6" + self.resource_id = "volc.speech.dialog" + + def get_headers(self) -> Dict[str, str]: + """获取请求头""" + return { + "X-Api-App-ID": self.app_id, + "X-Api-Access-Key": self.access_key, + "X-Api-Resource-Id": self.resource_id, + "X-Api-App-Key": self.app_key, + "X-Api-Connect-Id": str(uuid.uuid4()), + } + + +class DoubaoClient: + """豆包客户端 - 基于原始代码""" + + def __init__(self, config: DoubaoConfig): + self.config = config + self.session_id = str(uuid.uuid4()) + self.ws = None + self.log_id = "" + + async def connect(self) -> None: + """建立WebSocket连接""" + print(f"连接豆包服务器: {self.config.base_url}") + + try: + self.ws = await websockets.connect( + self.config.base_url, + additional_headers=self.config.get_headers(), + ping_interval=None + ) + + # 获取log_id + if hasattr(self.ws, 'response_headers'): + self.log_id = self.ws.response_headers.get("X-Tt-Logid") + elif hasattr(self.ws, 'headers'): + self.log_id = self.ws.headers.get("X-Tt-Logid") + + print(f"连接成功, log_id: {self.log_id}") + + # 发送StartConnection请求 + await self._send_start_connection() + + # 发送StartSession请求 + await self._send_start_session() + + except Exception as e: + print(f"连接失败: {e}") + raise + + async def _send_start_connection(self) -> None: + """发送StartConnection请求""" + print("发送StartConnection请求...") + request = bytearray(generate_header()) + request.extend(int(1).to_bytes(4, 'big')) + + payload_bytes = b"{}" + payload_bytes = gzip.compress(payload_bytes) + request.extend(len(payload_bytes).to_bytes(4, 'big')) + request.extend(payload_bytes) + + await self.ws.send(request) + response = await self.ws.recv() + print(f"StartConnection响应长度: {len(response)}") + + async def _send_start_session(self) -> None: + """发送StartSession请求""" + print("发送StartSession请求...") + session_config = { + "asr": { + "extra": { + "end_smooth_window_ms": 1500, + }, + }, + "tts": { + "speaker": "zh_female_vv_jupiter_bigtts", + "audio_config": { + "channel": 1, + "format": "pcm", + "sample_rate": 24000 + }, + }, + "dialog": { + "bot_name": "豆包", + "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。", + "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。", + "location": {"city": "北京"}, + "extra": { + "strict_audit": False, + "audit_response": "支持客户自定义安全审核回复话术。", + "recv_timeout": 30, + "input_mod": "audio", + }, + }, + } + + request = bytearray(generate_header()) + request.extend(int(100).to_bytes(4, 'big')) + request.extend(len(self.session_id).to_bytes(4, 'big')) + request.extend(self.session_id.encode()) + + payload_bytes = json.dumps(session_config).encode() + payload_bytes = gzip.compress(payload_bytes) + request.extend(len(payload_bytes).to_bytes(4, 'big')) + request.extend(payload_bytes) + + await self.ws.send(request) + response = await self.ws.recv() + print(f"StartSession响应长度: {len(response)}") + + # 等待一会确保会话完全建立 + await asyncio.sleep(1.0) + + async def task_request(self, audio: bytes) -> None: + """直接复制原始豆包代码的task_request方法""" + task_request = bytearray( + 
generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST, + serial_method=NO_SERIALIZATION)) + task_request.extend(int(200).to_bytes(4, 'big')) + task_request.extend((len(self.session_id)).to_bytes(4, 'big')) + task_request.extend(str.encode(self.session_id)) + payload_bytes = gzip.compress(audio) + task_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) + task_request.extend(payload_bytes) + await self.ws.send(task_request) + + async def test_full_dialog(self) -> None: + """测试完整对话流程""" + print("开始完整对话测试...") + + # 读取真实的录音文件 + try: + import wave + with wave.open("recording_20250920_135137.wav", 'rb') as wf: + # 读取前5秒的音频数据 + total_frames = wf.getnframes() + frames_to_read = min(total_frames, 80000) # 5秒 + audio_data = wf.readframes(frames_to_read) + print(f"读取真实音频数据: {len(audio_data)} 字节") + print(f"音频参数: 采样率={wf.getframerate()}, 通道数={wf.getnchannels()}, 采样宽度={wf.getsampwidth()}") + except Exception as e: + print(f"读取音频文件失败: {e}") + return + + print(f"音频数据大小: {len(audio_data)}") + + try: + # 发送音频数据 + print("发送音频数据...") + await self.task_request(audio_data) + print("音频数据发送成功") + + # 等待语音识别响应 + print("等待语音识别响应...") + response = await asyncio.wait_for(self.ws.recv(), timeout=15.0) + print(f"收到ASR响应,长度: {len(response)}") + + # 解析ASR响应 + if len(response) >= 4: + protocol_version = response[0] >> 4 + header_size = response[0] & 0x0f + message_type = response[1] >> 4 + flags = response[1] & 0x0f + print(f"ASR响应协议: version={protocol_version}, header_size={header_size}, message_type={message_type}, flags={flags}") + + if message_type == 9: # SERVER_FULL_RESPONSE + payload_start = header_size * 4 + payload = response[payload_start:] + + if len(payload) >= 4: + event = int.from_bytes(payload[:4], 'big') + print(f"ASR Event: {event}") + + if len(payload) >= 8: + session_id_len = int.from_bytes(payload[4:8], 'big') + if len(payload) >= 8 + session_id_len: + session_id = payload[8:8+session_id_len].decode() + print(f"Session ID: {session_id}") + + if len(payload) >= 12 + session_id_len: + payload_size = int.from_bytes(payload[8+session_id_len:12+session_id_len], 'big') + payload_data = payload[12+session_id_len:12+session_id_len+payload_size] + print(f"Payload size: {payload_size}") + + # 解析ASR结果 + try: + asr_result = json.loads(payload_data.decode('utf-8')) + print(f"ASR结果: {asr_result}") + + # 如果有识别结果,提取文本 + if 'results' in asr_result and asr_result['results']: + text = asr_result['results'][0].get('text', '') + print(f"识别文本: {text}") + + except Exception as e: + print(f"解析ASR结果失败: {e}") + + # 持续等待TTS音频响应 + print("开始持续等待TTS音频响应...") + response_count = 0 + max_responses = 10 + + while response_count < max_responses: + try: + print(f"等待第 {response_count + 1} 个响应...") + tts_response = await asyncio.wait_for(self.ws.recv(), timeout=30.0) + print(f"收到响应 {response_count + 1},长度: {len(tts_response)}") + + # 解析响应 + if len(tts_response) >= 4: + tts_version = tts_response[0] >> 4 + tts_header_size = tts_response[0] & 0x0f + tts_message_type = tts_response[1] >> 4 + tts_flags = tts_response[1] & 0x0f + print(f"响应协议: version={tts_version}, header_size={tts_header_size}, message_type={tts_message_type}, flags={tts_flags}") + + if tts_message_type == 11: # SERVER_ACK (包含TTS音频) + tts_payload_start = tts_header_size * 4 + tts_payload = tts_response[tts_payload_start:] + + if len(tts_payload) >= 12: + tts_event = int.from_bytes(tts_payload[:4], 'big') + tts_session_len = int.from_bytes(tts_payload[4:8], 'big') + tts_session = tts_payload[8:8+tts_session_len].decode() + tts_audio_size = 
int.from_bytes(tts_payload[8+tts_session_len:12+tts_session_len], 'big') + tts_audio_data = tts_payload[12+tts_session_len:12+tts_session_len+tts_audio_size] + + print(f"Event: {tts_event}") + print(f"音频数据大小: {tts_audio_size}") + + if tts_audio_size > 0: + print("找到TTS音频数据!") + # 尝试解压缩TTS音频 + try: + decompressed_tts = gzip.decompress(tts_audio_data) + print(f"解压缩后TTS音频大小: {len(decompressed_tts)}") + + # 创建WAV文件 + sample_rate = 24000 + channels = 1 + sampwidth = 2 + + with wave.open(f'tts_response_{response_count}.wav', 'wb') as wav_file: + wav_file.setnchannels(channels) + wav_file.setsampwidth(sampwidth) + wav_file.setframerate(sample_rate) + wav_file.writeframes(decompressed_tts) + + print(f"成功创建TTS WAV文件: tts_response_{response_count}.wav") + print(f"音频参数: {sample_rate}Hz, {channels}通道, {sampwidth*8}-bit") + + # 显示文件信息 + if os.path.exists(f'tts_response_{response_count}.wav'): + file_size = os.path.getsize(f'tts_response_{response_count}.wav') + duration = file_size / (sample_rate * channels * sampwidth) + print(f"WAV文件大小: {file_size} 字节") + print(f"音频时长: {duration:.2f} 秒") + + # 成功获取音频,退出循环 + break + + except Exception as tts_e: + print(f"TTS音频解压缩失败: {tts_e}") + # 保存原始数据 + with open(f'tts_response_audio_{response_count}.raw', 'wb') as f: + f.write(tts_audio_data) + print(f"原始TTS音频数据已保存到 tts_response_audio_{response_count}.raw") + + elif tts_message_type == 9: # SERVER_FULL_RESPONSE + tts_payload_start = tts_header_size * 4 + tts_payload = tts_response[tts_payload_start:] + + if len(tts_payload) >= 4: + event = int.from_bytes(tts_payload[:4], 'big') + print(f"Event: {event}") + + if event in [451, 359]: # ASR结果或TTS结束 + # 解析payload + if len(tts_payload) >= 8: + session_id_len = int.from_bytes(tts_payload[4:8], 'big') + if len(tts_payload) >= 8 + session_id_len: + session_id = tts_payload[8:8+session_id_len].decode() + if len(tts_payload) >= 12 + session_id_len: + payload_size = int.from_bytes(tts_payload[8+session_id_len:12+session_id_len], 'big') + payload_data = tts_payload[12+session_id_len:12+session_id_len+payload_size] + + try: + json_data = json.loads(payload_data.decode('utf-8')) + print(f"JSON数据: {json_data}") + + # 如果是ASR结果 + if 'results' in json_data: + text = json_data['results'][0].get('text', '') + print(f"识别文本: {text}") + + # 如果是TTS结束标记 + if event == 359: + print("TTS响应结束") + break + + except Exception as e: + print(f"解析JSON失败: {e}") + # 保存原始数据 + with open(f'tts_response_{response_count}.raw', 'wb') as f: + f.write(payload_data) + + # 保存完整响应用于调试 + with open(f'tts_response_full_{response_count}.raw', 'wb') as f: + f.write(tts_response) + print(f"完整响应已保存到 tts_response_full_{response_count}.raw") + + response_count += 1 + + except asyncio.TimeoutError: + print(f"等待第 {response_count + 1} 个响应超时") + break + except websockets.exceptions.ConnectionClosed: + print("连接已关闭") + break + + print(f"共收到 {response_count} 个响应") + + except asyncio.TimeoutError: + print("等待响应超时") + except websockets.exceptions.ConnectionClosed as e: + print(f"连接关闭: {e}") + except Exception as e: + print(f"测试失败: {e}") + import traceback + traceback.print_exc() + + async def close(self) -> None: + """关闭连接""" + if self.ws: + try: + await self.ws.close() + except: + pass + print("连接已关闭") + + +async def main(): + """测试函数""" + config = DoubaoConfig() + client = DoubaoClient(config) + + try: + await client.connect() + await client.test_full_dialog() + except Exception as e: + print(f"测试失败: {e}") + import traceback + traceback.print_exc() + finally: + await client.close() + + +if __name__ == "__main__": + 
asyncio.run(main()) \ No newline at end of file diff --git a/doubao_original_test.py b/doubao_original_test.py new file mode 100644 index 0000000..0587ea4 --- /dev/null +++ b/doubao_original_test.py @@ -0,0 +1,412 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +豆包音频处理模块 - 基于原始代码的测试版本 +直接使用原始豆包代码的核心逻辑 +""" + +import asyncio +import gzip +import json +import uuid +import wave +import struct +import time +import os +from typing import Dict, Any, Optional +import websockets + + +# 直接复制原始豆包代码的协议常量 +PROTOCOL_VERSION = 0b0001 +CLIENT_FULL_REQUEST = 0b0001 +CLIENT_AUDIO_ONLY_REQUEST = 0b0010 +SERVER_FULL_RESPONSE = 0b1001 +SERVER_ACK = 0b1011 +SERVER_ERROR_RESPONSE = 0b1111 + +NO_SEQUENCE = 0b0000 +POS_SEQUENCE = 0b0001 +MSG_WITH_EVENT = 0b0100 + +NO_SERIALIZATION = 0b0000 +JSON = 0b0001 +GZIP = 0b0001 + + +def generate_header( + version=PROTOCOL_VERSION, + message_type=CLIENT_FULL_REQUEST, + message_type_specific_flags=MSG_WITH_EVENT, + serial_method=JSON, + compression_type=GZIP, + reserved_data=0x00, + extension_header=bytes() +): + """直接复制原始豆包代码的generate_header函数""" + header = bytearray() + header_size = int(len(extension_header) / 4) + 1 + header.append((version << 4) | header_size) + header.append((message_type << 4) | message_type_specific_flags) + header.append((serial_method << 4) | compression_type) + header.append(reserved_data) + header.extend(extension_header) + return header + + +class DoubaoConfig: + """豆包配置""" + + def __init__(self): + self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue" + self.app_id = "8718217928" + self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc" + self.app_key = "PlgvMymc7f3tQnJ6" + self.resource_id = "volc.speech.dialog" + + def get_headers(self) -> Dict[str, str]: + """获取请求头""" + return { + "X-Api-App-ID": self.app_id, + "X-Api-Access-Key": self.access_key, + "X-Api-Resource-Id": self.resource_id, + "X-Api-App-Key": self.app_key, + "X-Api-Connect-Id": str(uuid.uuid4()), + } + + +class DoubaoClient: + """豆包客户端 - 基于原始代码""" + + def __init__(self, config: DoubaoConfig): + self.config = config + self.session_id = str(uuid.uuid4()) + self.ws = None + self.log_id = "" + + async def connect(self) -> None: + """建立WebSocket连接""" + print(f"连接豆包服务器: {self.config.base_url}") + + try: + self.ws = await websockets.connect( + self.config.base_url, + additional_headers=self.config.get_headers(), + ping_interval=None + ) + + # 获取log_id + if hasattr(self.ws, 'response_headers'): + self.log_id = self.ws.response_headers.get("X-Tt-Logid") + elif hasattr(self.ws, 'headers'): + self.log_id = self.ws.headers.get("X-Tt-Logid") + + print(f"连接成功, log_id: {self.log_id}") + + # 发送StartConnection请求 + await self._send_start_connection() + + # 发送StartSession请求 + await self._send_start_session() + + except Exception as e: + print(f"连接失败: {e}") + raise + + async def _send_start_connection(self) -> None: + """发送StartConnection请求""" + print("发送StartConnection请求...") + request = bytearray(generate_header()) + request.extend(int(1).to_bytes(4, 'big')) + + payload_bytes = b"{}" + payload_bytes = gzip.compress(payload_bytes) + request.extend(len(payload_bytes).to_bytes(4, 'big')) + request.extend(payload_bytes) + + await self.ws.send(request) + response = await self.ws.recv() + print(f"StartConnection响应长度: {len(response)}") + + async def _send_start_session(self) -> None: + """发送StartSession请求""" + print("发送StartSession请求...") + session_config = { + "asr": { + "extra": { + "end_smooth_window_ms": 1500, + }, + }, + "tts": { + "speaker": 
"zh_female_vv_jupiter_bigtts", + "audio_config": { + "channel": 1, + "format": "pcm", + "sample_rate": 24000 + }, + }, + "dialog": { + "bot_name": "豆包", + "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。", + "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。", + "location": {"city": "北京"}, + "extra": { + "strict_audit": False, + "audit_response": "支持客户自定义安全审核回复话术。", + "recv_timeout": 30, + "input_mod": "audio", + }, + }, + } + + request = bytearray(generate_header()) + request.extend(int(100).to_bytes(4, 'big')) + request.extend(len(self.session_id).to_bytes(4, 'big')) + request.extend(self.session_id.encode()) + + payload_bytes = json.dumps(session_config).encode() + payload_bytes = gzip.compress(payload_bytes) + request.extend(len(payload_bytes).to_bytes(4, 'big')) + request.extend(payload_bytes) + + await self.ws.send(request) + response = await self.ws.recv() + print(f"StartSession响应长度: {len(response)}") + + # 等待一会确保会话完全建立 + await asyncio.sleep(1.0) + + async def task_request(self, audio: bytes) -> None: + """直接复制原始豆包代码的task_request方法""" + task_request = bytearray( + generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST, + serial_method=NO_SERIALIZATION)) + task_request.extend(int(200).to_bytes(4, 'big')) + task_request.extend((len(self.session_id)).to_bytes(4, 'big')) + task_request.extend(str.encode(self.session_id)) + payload_bytes = gzip.compress(audio) + task_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) + task_request.extend(payload_bytes) + await self.ws.send(task_request) + + async def test_audio_request(self) -> None: + """测试音频请求""" + print("测试音频请求...") + + # 读取真实的录音文件 + try: + import wave + with wave.open("recording_20250920_135137.wav", 'rb') as wf: + # 读取前10秒的音频数据(16000采样率 * 10秒 = 160000帧) + total_frames = wf.getnframes() + frames_to_read = min(total_frames, 160000) # 最多10秒 + small_audio = wf.readframes(frames_to_read) + print(f"读取真实音频数据: {len(small_audio)} 字节") + print(f"音频参数: 采样率={wf.getframerate()}, 通道数={wf.getnchannels()}, 采样宽度={wf.getsampwidth()}") + print(f"总帧数: {total_frames}, 读取帧数: {frames_to_read}") + except Exception as e: + print(f"读取音频文件失败: {e}") + # 如果读取失败,使用静音数据 + small_audio = b'\x00' * 3200 + + print(f"音频数据大小: {len(small_audio)}") + + try: + # 发送完整的音频数据块 + print(f"发送完整的音频数据块...") + await self.task_request(small_audio) + print(f"音频数据块发送成功") + + print("等待语音识别响应...") + + # 等待更长时间的响应(语音识别可能需要更长时间) + response = await asyncio.wait_for(self.ws.recv(), timeout=15.0) + print(f"收到响应,长度: {len(response)}") + + # 解析响应 + try: + if len(response) >= 4: + protocol_version = response[0] >> 4 + header_size = response[0] & 0x0f + message_type = response[1] >> 4 + message_type_specific_flags = response[1] & 0x0f + print(f"响应协议信息: version={protocol_version}, header_size={header_size}, message_type={message_type}, flags={message_type_specific_flags}") + + # 解析payload + payload_start = header_size * 4 + payload = response[payload_start:] + + if message_type == 9: # SERVER_FULL_RESPONSE + print("收到SERVER_FULL_RESPONSE!") + if len(payload) >= 4: + # 解析event + event = int.from_bytes(payload[:4], 'big') + print(f"Event: {event}") + + # 解析session_id + if len(payload) >= 8: + session_id_len = int.from_bytes(payload[4:8], 'big') + if len(payload) >= 8 + session_id_len: + session_id = payload[8:8+session_id_len].decode() + print(f"Session ID: {session_id}") + + # 解析payload size和data + if len(payload) >= 12 + session_id_len: + payload_size = int.from_bytes(payload[8+session_id_len:12+session_id_len], 'big') + payload_data = 
payload[12+session_id_len:12+session_id_len+payload_size] + print(f"Payload size: {payload_size}") + + # 如果包含音频数据,保存到文件 + if len(payload_data) > 0: + print(f"收到数据: {len(payload_data)} 字节") + # 保存原始音频数据 + with open('response_audio.raw', 'wb') as f: + f.write(payload_data) + print("音频数据已保存到 response_audio.raw") + + # 尝试解析JSON数据 + try: + import json + json_data = json.loads(payload_data.decode('utf-8')) + print(f"JSON数据: {json_data}") + + # 如果是语音识别任务开始,继续等待音频响应 + if 'asr_task_id' in json_data: + print("语音识别任务开始,继续等待音频响应...") + try: + # 等待音频响应 + audio_response = await asyncio.wait_for(self.ws.recv(), timeout=20.0) + print(f"收到音频响应,长度: {len(audio_response)}") + + # 解析音频响应 + if len(audio_response) >= 4: + audio_version = audio_response[0] >> 4 + audio_header_size = audio_response[0] & 0x0f + audio_message_type = audio_response[1] >> 4 + audio_flags = audio_response[1] & 0x0f + print(f"音频响应协议信息: version={audio_version}, header_size={audio_header_size}, message_type={audio_message_type}, flags={audio_flags}") + + if audio_message_type == 9: # SERVER_FULL_RESPONSE (包含TTS音频) + audio_payload_start = audio_header_size * 4 + audio_payload = audio_response[audio_payload_start:] + + if len(audio_payload) >= 12: + # 解析event和session_id + audio_event = int.from_bytes(audio_payload[:4], 'big') + audio_session_len = int.from_bytes(audio_payload[4:8], 'big') + audio_session = audio_payload[8:8+audio_session_len].decode() + audio_data_size = int.from_bytes(audio_payload[8+audio_session_len:12+audio_session_len], 'big') + audio_data = audio_payload[12+audio_session_len:12+audio_session_len+audio_data_size] + + print(f"音频Event: {audio_event}") + print(f"音频数据大小: {audio_data_size}") + + if audio_data_size > 0: + # 保存原始音频数据 + with open('tts_response_audio.raw', 'wb') as f: + f.write(audio_data) + print(f"TTS音频数据已保存到 tts_response_audio.raw") + + # 尝试解析音频数据(可能是JSON或GZIP压缩的音频) + try: + # 首先尝试解压缩 + import gzip + decompressed_audio = gzip.decompress(audio_data) + print(f"解压缩后音频数据大小: {len(decompressed_audio)}") + with open('tts_response_audio_decompressed.raw', 'wb') as f: + f.write(decompressed_audio) + print("解压缩的音频数据已保存") + + # 创建WAV文件供树莓派播放 + import wave + import struct + + # 豆包返回的音频是24000Hz, 16-bit, 单声道 + sample_rate = 24000 + channels = 1 + sampwidth = 2 # 16-bit = 2 bytes + + with wave.open('tts_response.wav', 'wb') as wav_file: + wav_file.setnchannels(channels) + wav_file.setsampwidth(sampwidth) + wav_file.setframerate(sample_rate) + wav_file.writeframes(decompressed_audio) + + print("已创建WAV文件: tts_response.wav") + print(f"音频参数: {sample_rate}Hz, {channels}通道, {sampwidth*8}-bit") + + except Exception as audio_e: + print(f"音频数据处理失败: {audio_e}") + # 如果解压缩失败,直接保存原始数据 + with open('tts_response_audio_original.raw', 'wb') as f: + f.write(audio_data) + + elif audio_message_type == 11: # SERVER_ACK + print("收到SERVER_ACK音频响应") + # 处理SERVER_ACK格式的音频响应 + audio_payload_start = audio_header_size * 4 + audio_payload = audio_response[audio_payload_start:] + print(f"音频payload长度: {len(audio_payload)}") + with open('tts_response_ack.raw', 'wb') as f: + f.write(audio_payload) + + except asyncio.TimeoutError: + print("等待音频响应超时") + + except Exception as json_e: + print(f"解析JSON失败: {json_e}") + # 如果不是JSON,可能是音频数据,直接保存 + with open('response_audio.raw', 'wb') as f: + f.write(payload_data) + + elif message_type == 11: # SERVER_ACK + print("收到SERVER_ACK响应!") + + elif message_type == 15: # SERVER_ERROR_RESPONSE + print("收到错误响应") + if len(response) > 8: + error_code = int.from_bytes(response[4:8], 'big') + print(f"错误代码: {error_code}") + + 
except Exception as e: + print(f"解析响应失败: {e}") + import traceback + traceback.print_exc() + + except asyncio.TimeoutError: + print("等待响应超时") + except websockets.exceptions.ConnectionClosed as e: + print(f"连接关闭: {e}") + except Exception as e: + print(f"发送音频请求失败: {e}") + raise + + async def close(self) -> None: + """关闭连接""" + if self.ws: + try: + await self.ws.close() + except: + pass + print("连接已关闭") + + +async def main(): + """测试函数""" + config = DoubaoConfig() + client = DoubaoClient(config) + + try: + await client.connect() + await client.test_audio_request() + except Exception as e: + print(f"测试失败: {e}") + import traceback + traceback.print_exc() + finally: + await client.close() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/doubao_simple.py b/doubao_simple.py new file mode 100644 index 0000000..5d3a7f9 --- /dev/null +++ b/doubao_simple.py @@ -0,0 +1,412 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +豆包音频处理模块 - 最终简化版本 +实现音频文件上传和TTS音频下载的完整流程 +""" + +import asyncio +import gzip +import json +import uuid +import wave +import time +import os +from typing import Dict, Any, Optional +import websockets + + +# 协议常量 +PROTOCOL_VERSION = 0b0001 +CLIENT_FULL_REQUEST = 0b0001 +CLIENT_AUDIO_ONLY_REQUEST = 0b0010 +SERVER_FULL_RESPONSE = 0b1001 +SERVER_ACK = 0b1011 +SERVER_ERROR_RESPONSE = 0b1111 + +NO_SEQUENCE = 0b0000 +MSG_WITH_EVENT = 0b0100 + +NO_SERIALIZATION = 0b0000 +JSON = 0b0001 +GZIP = 0b0001 + + +def generate_header( + version=PROTOCOL_VERSION, + message_type=CLIENT_FULL_REQUEST, + message_type_specific_flags=MSG_WITH_EVENT, + serial_method=JSON, + compression_type=GZIP, + reserved_data=0x00, + extension_header=bytes() +): + """生成协议头""" + header = bytearray() + header_size = int(len(extension_header) / 4) + 1 + header.append((version << 4) | header_size) + header.append((message_type << 4) | message_type_specific_flags) + header.append((serial_method << 4) | compression_type) + header.append(reserved_data) + header.extend(extension_header) + return header + + +class DoubaoClient: + """豆包客户端""" + + def __init__(self): + self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue" + self.app_id = "8718217928" + self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc" + self.app_key = "PlgvMymc7f3tQnJ6" + self.resource_id = "volc.speech.dialog" + self.session_id = str(uuid.uuid4()) + self.ws = None + self.log_id = "" + + def get_headers(self) -> Dict[str, str]: + """获取请求头""" + return { + "X-Api-App-ID": self.app_id, + "X-Api-Access-Key": self.access_key, + "X-Api-Resource-Id": self.resource_id, + "X-Api-App-Key": self.app_key, + "X-Api-Connect-Id": str(uuid.uuid4()), + } + + async def connect(self) -> None: + """建立WebSocket连接""" + print(f"连接豆包服务器: {self.base_url}") + + try: + self.ws = await websockets.connect( + self.base_url, + additional_headers=self.get_headers(), + ping_interval=None + ) + + # 获取log_id + if hasattr(self.ws, 'response_headers'): + self.log_id = self.ws.response_headers.get("X-Tt-Logid") + elif hasattr(self.ws, 'headers'): + self.log_id = self.ws.headers.get("X-Tt-Logid") + + print(f"连接成功, log_id: {self.log_id}") + + # 发送StartConnection请求 + await self._send_start_connection() + + # 发送StartSession请求 + await self._send_start_session() + + except Exception as e: + print(f"连接失败: {e}") + raise + + async def _send_start_connection(self) -> None: + """发送StartConnection请求""" + print("发送StartConnection请求...") + request = bytearray(generate_header()) + request.extend(int(1).to_bytes(4, 'big')) + + payload_bytes = 
b"{}" + payload_bytes = gzip.compress(payload_bytes) + request.extend(len(payload_bytes).to_bytes(4, 'big')) + request.extend(payload_bytes) + + await self.ws.send(request) + response = await self.ws.recv() + print(f"StartConnection响应长度: {len(response)}") + + async def _send_start_session(self) -> None: + """发送StartSession请求""" + print("发送StartSession请求...") + session_config = { + "asr": { + "extra": { + "end_smooth_window_ms": 1500, + }, + }, + "tts": { + "speaker": "zh_female_vv_jupiter_bigtts", + "audio_config": { + "channel": 1, + "format": "pcm", + "sample_rate": 24000 + }, + }, + "dialog": { + "bot_name": "豆包", + "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。", + "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。", + "location": {"city": "北京"}, + "extra": { + "strict_audit": False, + "audit_response": "支持客户自定义安全审核回复话术。", + "recv_timeout": 30, + "input_mod": "audio", + }, + }, + } + + request = bytearray(generate_header()) + request.extend(int(100).to_bytes(4, 'big')) + request.extend(len(self.session_id).to_bytes(4, 'big')) + request.extend(self.session_id.encode()) + + payload_bytes = json.dumps(session_config).encode() + payload_bytes = gzip.compress(payload_bytes) + request.extend(len(payload_bytes).to_bytes(4, 'big')) + request.extend(payload_bytes) + + await self.ws.send(request) + response = await self.ws.recv() + print(f"StartSession响应长度: {len(response)}") + + # 等待一会确保会话完全建立 + await asyncio.sleep(1.0) + + async def task_request(self, audio: bytes) -> None: + """发送音频数据""" + task_request = bytearray( + generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST, + serial_method=NO_SERIALIZATION)) + task_request.extend(int(200).to_bytes(4, 'big')) + task_request.extend((len(self.session_id)).to_bytes(4, 'big')) + task_request.extend(str.encode(self.session_id)) + payload_bytes = gzip.compress(audio) + task_request.extend((len(payload_bytes)).to_bytes(4, 'big')) + task_request.extend(payload_bytes) + await self.ws.send(task_request) + + def parse_response(self, response): + """解析响应""" + if len(response) < 4: + return None + + protocol_version = response[0] >> 4 + header_size = response[0] & 0x0f + message_type = response[1] >> 4 + flags = response[1] & 0x0f + + payload_start = header_size * 4 + payload = response[payload_start:] + + result = { + 'protocol_version': protocol_version, + 'header_size': header_size, + 'message_type': message_type, + 'flags': flags, + 'payload': payload, + 'payload_size': len(payload) + } + + # 解析payload + if len(payload) >= 4: + result['event'] = int.from_bytes(payload[:4], 'big') + + if len(payload) >= 8: + session_id_len = int.from_bytes(payload[4:8], 'big') + if len(payload) >= 8 + session_id_len: + result['session_id'] = payload[8:8+session_id_len].decode() + + if len(payload) >= 12 + session_id_len: + data_size = int.from_bytes(payload[8+session_id_len:12+session_id_len], 'big') + result['data_size'] = data_size + result['data'] = payload[12+session_id_len:12+session_id_len+data_size] + + # 尝试解析JSON数据 + try: + result['json_data'] = json.loads(result['data'].decode('utf-8')) + except: + pass + + return result + + async def process_audio_file(self, input_file: str, output_file: str) -> bool: + """处理音频文件:上传并获得TTS响应""" + print(f"开始处理音频文件: {input_file}") + + try: + # 读取输入音频文件 + with wave.open(input_file, 'rb') as wf: + audio_data = wf.readframes(wf.getnframes()) + print(f"读取音频数据: {len(audio_data)} 字节") + print(f"音频参数: {wf.getframerate()}Hz, {wf.getnchannels()}通道, {wf.getsampwidth()*8}-bit") + + # 发送音频数据 + print("发送音频数据...") + await self.task_request(audio_data) + 
print("音频数据发送成功") + + # 接收响应序列 + print("开始接收响应...") + audio_chunks = [] + response_count = 0 + max_responses = 20 + + while response_count < max_responses: + try: + response = await asyncio.wait_for(self.ws.recv(), timeout=30.0) + response_count += 1 + + parsed = self.parse_response(response) + if not parsed: + continue + + print(f"响应 {response_count}: message_type={parsed['message_type']}, event={parsed.get('event', 'N/A')}, size={parsed['payload_size']}") + + # 处理不同类型的响应 + if parsed['message_type'] == 11: # SERVER_ACK - 可能包含音频 + if 'data' in parsed and parsed['data_size'] > 0: + audio_chunks.append(parsed['data']) + print(f"收集到音频块: {parsed['data_size']} 字节") + + elif parsed['message_type'] == 9: # SERVER_FULL_RESPONSE + event = parsed.get('event', 0) + + if event == 350: # TTS开始 + print("TTS音频生成开始") + elif event == 359: # TTS结束 + print("TTS音频生成结束") + break + elif event == 451: # ASR结果 + if 'json_data' in parsed and 'results' in parsed['json_data']: + text = parsed['json_data']['results'][0].get('text', '') + print(f"语音识别结果: {text}") + elif event == 550: # TTS音频数据 + if 'data' in parsed and parsed['data_size'] > 0: + # 检查是否是JSON(音频元数据)还是实际音频数据 + try: + json.loads(parsed['data'].decode('utf-8')) + print("收到TTS音频元数据") + except: + # 不是JSON,可能是音频数据 + audio_chunks.append(parsed['data']) + print(f"收集到TTS音频块: {parsed['data_size']} 字节") + + except asyncio.TimeoutError: + print(f"等待响应 {response_count + 1} 超时") + break + except websockets.exceptions.ConnectionClosed: + print("连接已关闭") + break + + print(f"共收到 {response_count} 个响应,收集到 {len(audio_chunks)} 个音频块") + + # 合并音频数据 + if audio_chunks: + combined_audio = b''.join(audio_chunks) + print(f"合并后的音频数据: {len(combined_audio)} 字节") + + # 检查是否是GZIP压缩数据 + try: + decompressed = gzip.decompress(combined_audio) + print(f"解压缩后音频数据: {len(decompressed)} 字节") + audio_to_write = decompressed + except: + print("音频数据不是GZIP压缩格式,直接使用原始数据") + audio_to_write = combined_audio + + # 创建输出WAV文件 + try: + # 豆包返回的音频是32位浮点格式,需要转换为16位整数 + import struct + + # 检查音频数据长度是否是4的倍数(32位浮点) + if len(audio_to_write) % 4 != 0: + print(f"警告:音频数据长度 {len(audio_to_write)} 不是4的倍数,截断到最近的倍数") + audio_to_write = audio_to_write[:len(audio_to_write) // 4 * 4] + + # 将32位浮点转换为16位整数 + float_count = len(audio_to_write) // 4 + int16_data = bytearray(float_count * 2) + + for i in range(float_count): + # 读取32位浮点数(小端序) + float_value = struct.unpack(' {len(int16_data)//2} 个16位整数样本") + + # 显示文件信息 + if os.path.exists(output_file): + file_size = os.path.getsize(output_file) + duration = file_size / (24000 * 1 * 2) + print(f"输出文件大小: {file_size} 字节,时长: {duration:.2f} 秒") + + return True + + except Exception as e: + print(f"创建WAV文件失败: {e}") + # 保存原始数据 + with open(output_file + '.raw', 'wb') as f: + f.write(audio_to_write) + print(f"原始音频数据已保存到: {output_file}.raw") + return False + else: + print("未收到音频数据") + return False + + except Exception as e: + print(f"处理音频文件失败: {e}") + import traceback + traceback.print_exc() + return False + + async def close(self) -> None: + """关闭连接""" + if self.ws: + try: + await self.ws.close() + except: + pass + print("连接已关闭") + + +async def main(): + """主函数""" + client = DoubaoClient() + + try: + await client.connect() + + # 处理录音文件 + input_file = "recording_20250920_135137.wav" + output_file = "tts_output.wav" + + success = await client.process_audio_file(input_file, output_file) + + if success: + print("音频处理成功!") + else: + print("音频处理失败") + + except Exception as e: + print(f"程序失败: {e}") + import traceback + traceback.print_exc() + finally: + await client.close() + + +if __name__ == 
"__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/doubao_test.py b/doubao_test.py new file mode 100644 index 0000000..c5a94d8 --- /dev/null +++ b/doubao_test.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +豆包音频处理模块 - 协议测试版本 +专门测试协议格式问题 +""" + +import asyncio +import gzip +import json +import uuid +import wave +import struct +import time +import os +from typing import Dict, Any, Optional +import websockets + + +class DoubaoConfig: + """豆包配置""" + + def __init__(self): + self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue" + self.app_id = "8718217928" + self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc" + self.app_key = "PlgvMymc7f3tQnJ6" + self.resource_id = "volc.speech.dialog" + + def get_headers(self) -> Dict[str, str]: + """获取请求头""" + return { + "X-Api-App-ID": self.app_id, + "X-Api-Access-Key": self.access_key, + "X-Api-Resource-Id": self.resource_id, + "X-Api-App-Key": self.app_key, + "X-Api-Connect-Id": str(uuid.uuid4()), + } + + +class DoubaoProtocol: + """豆包协议处理""" + + # 协议常量 + PROTOCOL_VERSION = 0b0001 + CLIENT_FULL_REQUEST = 0b0001 + CLIENT_AUDIO_ONLY_REQUEST = 0b0010 + SERVER_FULL_RESPONSE = 0b1001 + SERVER_ACK = 0b1011 + SERVER_ERROR_RESPONSE = 0b1111 + + NO_SEQUENCE = 0b0000 + POS_SEQUENCE = 0b0001 + MSG_WITH_EVENT = 0b0100 + JSON = 0b0001 + NO_SERIALIZATION = 0b0000 + GZIP = 0b0001 + + @classmethod + def generate_header(cls, message_type=CLIENT_FULL_REQUEST, + message_type_specific_flags=MSG_WITH_EVENT, + serial_method=JSON, compression_type=GZIP) -> bytes: + """生成协议头""" + header = bytearray() + header.append((cls.PROTOCOL_VERSION << 4) | 1) # version + header_size + header.append((message_type << 4) | message_type_specific_flags) + header.append((serial_method << 4) | compression_type) + header.append(0x00) # reserved + return bytes(header) + + +class DoubaoClient: + """豆包客户端""" + + def __init__(self, config: DoubaoConfig): + self.config = config + self.session_id = str(uuid.uuid4()) + self.ws = None + self.log_id = "" + + async def connect(self) -> None: + """建立WebSocket连接""" + print(f"连接豆包服务器: {self.config.base_url}") + + try: + self.ws = await websockets.connect( + self.config.base_url, + additional_headers=self.config.get_headers(), + ping_interval=None + ) + + # 获取log_id + if hasattr(self.ws, 'response_headers'): + self.log_id = self.ws.response_headers.get("X-Tt-Logid") + elif hasattr(self.ws, 'headers'): + self.log_id = self.ws.headers.get("X-Tt-Logid") + + print(f"连接成功, log_id: {self.log_id}") + + # 发送StartConnection请求 + await self._send_start_connection() + + # 发送StartSession请求 + await self._send_start_session() + + except Exception as e: + print(f"连接失败: {e}") + raise + + async def _send_start_connection(self) -> None: + """发送StartConnection请求""" + print("发送StartConnection请求...") + request = bytearray(DoubaoProtocol.generate_header()) + request.extend(int(1).to_bytes(4, 'big')) + + payload_bytes = b"{}" + payload_bytes = gzip.compress(payload_bytes) + request.extend(len(payload_bytes).to_bytes(4, 'big')) + request.extend(payload_bytes) + + await self.ws.send(request) + response = await self.ws.recv() + print(f"StartConnection响应长度: {len(response)}") + + async def _send_start_session(self) -> None: + """发送StartSession请求""" + print("发送StartSession请求...") + session_config = { + "asr": { + "extra": { + "end_smooth_window_ms": 1500, + }, + }, + "tts": { + "speaker": "zh_female_vv_jupiter_bigtts", + "audio_config": { + "channel": 1, + "format": "pcm", + "sample_rate": 24000 + }, + }, + "dialog": { 
+ "bot_name": "豆包", + "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。", + "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。", + "location": {"city": "北京"}, + "extra": { + "strict_audit": False, + "audit_response": "支持客户自定义安全审核回复话术。", + "recv_timeout": 30, + "input_mod": "audio", + }, + }, + } + + request = bytearray(DoubaoProtocol.generate_header()) + request.extend(int(100).to_bytes(4, 'big')) + request.extend(len(self.session_id).to_bytes(4, 'big')) + request.extend(self.session_id.encode()) + + payload_bytes = json.dumps(session_config).encode() + payload_bytes = gzip.compress(payload_bytes) + request.extend(len(payload_bytes).to_bytes(4, 'big')) + request.extend(payload_bytes) + + await self.ws.send(request) + response = await self.ws.recv() + print(f"StartSession响应长度: {len(response)}") + + # 等待一会确保会话完全建立 + await asyncio.sleep(1.0) + + async def test_audio_request(self) -> None: + """测试音频请求格式""" + print("测试音频请求格式...") + + # 创建音频数据(静音)- 使用原始豆包代码的chunk大小 + small_audio = b'\x00' * 3200 # 原始豆包代码中的chunk大小 + + # 完全按照原始豆包代码的格式构建请求,不进行任何填充 + header = bytearray() + header.append((DoubaoProtocol.PROTOCOL_VERSION << 4) | 1) # version + header_size + header.append((DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST << 4) | DoubaoProtocol.NO_SEQUENCE) + header.append((DoubaoProtocol.NO_SERIALIZATION << 4) | DoubaoProtocol.GZIP) + header.append(0x00) # reserved + + request = bytearray(header) + + # 添加消息类型 (200 = task request) + request.extend(int(200).to_bytes(4, 'big')) + + # 添加session_id + request.extend(len(self.session_id).to_bytes(4, 'big')) + request.extend(self.session_id.encode()) + + # 压缩音频数据 + compressed_audio = gzip.compress(small_audio) + + # 添加payload size + request.extend(len(compressed_audio).to_bytes(4, 'big')) + + # 添加压缩后的音频数据 + request.extend(compressed_audio) + + print(f"测试请求详细信息:") + print(f" - 音频原始大小: {len(small_audio)}") + print(f" - 音频压缩后大小: {len(compressed_audio)}") + print(f" - Session ID: {self.session_id} (长度: {len(self.session_id)})") + print(f" - 总请求大小: {len(request)}") + print(f" - 头部字节: {request[:4].hex()}") + print(f" - 消息类型: {int.from_bytes(request[4:8], 'big')}") + print(f" - Session ID长度: {int.from_bytes(request[8:12], 'big')}") + print(f" - Payload size: {int.from_bytes(request[12+len(self.session_id):16+len(self.session_id)], 'big')}") + + try: + await self.ws.send(request) + print("请求发送成功") + + # 等待响应 + response = await asyncio.wait_for(self.ws.recv(), timeout=3.0) + print(f"收到响应,长度: {len(response)}") + + # 尝试解析响应 + try: + protocol_version = response[0] >> 4 + header_size = response[0] & 0x0f + message_type = response[1] >> 4 + message_type_specific_flags = response[1] & 0x0f + serialization_method = response[2] >> 4 + message_compression = response[2] & 0x0f + + print(f"响应协议信息:") + print(f" - version={protocol_version}") + print(f" - header_size={header_size}") + print(f" - message_type={message_type} (15=SERVER_ERROR_RESPONSE)") + print(f" - message_type_specific_flags={message_type_specific_flags}") + print(f" - serialization_method={serialization_method}") + print(f" - message_compression={message_compression}") + + # 解析payload + payload = response[header_size * 4:] + if message_type == 15: # SERVER_ERROR_RESPONSE + if len(payload) >= 8: + code = int.from_bytes(payload[:4], "big", signed=False) + payload_size = int.from_bytes(payload[4:8], "big", signed=False) + print(f" - 错误代码: {code}") + print(f" - payload大小: {payload_size}") + + if len(payload) >= 8 + payload_size: + payload_msg = payload[8:8 + payload_size] + print(f" - payload长度: {len(payload_msg)}") + + if message_compression == 
+                                try:
+                                    payload_msg = gzip.decompress(payload_msg)
+                                    print(f" - 解压缩后长度: {len(payload_msg)}")
+                                except:
+                                    pass
+
+                            try:
+                                error_msg = json.loads(payload_msg.decode('utf-8'))
+                                print(f" - 错误信息: {error_msg}")
+                            except:
+                                print(f" - 原始payload: {payload_msg}")
+
+            except Exception as e:
+                print(f"解析响应失败: {e}")
+                import traceback
+                traceback.print_exc()
+
+        except Exception as e:
+            print(f"发送测试请求失败: {e}")
+            raise
+
+    async def close(self) -> None:
+        """关闭连接"""
+        if self.ws:
+            try:
+                await self.ws.close()
+            except:
+                pass
+            print("连接已关闭")
+
+
+async def main():
+    """测试函数"""
+    config = DoubaoConfig()
+    client = DoubaoClient(config)
+
+    try:
+        await client.connect()
+        await client.test_audio_request()
+    except Exception as e:
+        print(f"测试失败: {e}")
+        import traceback
+        traceback.print_exc()
+    finally:
+        await client.close()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
\ No newline at end of file
diff --git a/test_doubao.py b/test_doubao.py
new file mode 100644
index 0000000..f0d09d0
--- /dev/null
+++ b/test_doubao.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+豆包音频处理模块 - 验证脚本
+验证完整的音频处理流程
+"""
+
+import asyncio
+import subprocess
+import os
+from doubao_simple import DoubaoClient
+
+async def test_complete_workflow():
+    """测试完整的工作流程"""
+    print("=== 豆包音频处理模块验证 ===")
+
+    # 检查输入文件
+    input_file = "recording_20250920_135137.wav"
+    if not os.path.exists(input_file):
+        print(f"❌ 输入文件不存在: {input_file}")
+        return False
+
+    print(f"✅ 输入文件存在: {input_file}")
+
+    # 检查文件信息
+    try:
+        result = subprocess.run(['file', input_file], capture_output=True, text=True)
+        print(f"📁 输入文件格式: {result.stdout.strip()}")
+    except:
+        pass
+
+    # 初始化客户端
+    client = DoubaoClient()
+
+    try:
+        # 连接服务器
+        print("🔌 连接豆包服务器...")
+        await client.connect()
+        print("✅ 连接成功")
+
+        # 处理音频文件
+        output_file = "tts_output.wav"
+        print(f"🎵 处理音频文件: {input_file} -> {output_file}")
+
+        success = await client.process_audio_file(input_file, output_file)
+
+        if success:
+            print("✅ 音频处理成功!")
+
+            # 检查输出文件
+            if os.path.exists(output_file):
+                result = subprocess.run(['file', output_file], capture_output=True, text=True)
+                print(f"📁 输出文件格式: {result.stdout.strip()}")
+
+                # 获取文件大小
+                file_size = os.path.getsize(output_file)
+                print(f"📊 输出文件大小: {file_size:,} 字节")
+
+                # 测试播放
+                print("🔊 测试播放输出文件...")
+                try:
+                    subprocess.run(['aplay', output_file], timeout=10, check=True)
+                    print("✅ 播放成功")
+                except subprocess.TimeoutExpired:
+                    print("✅ 播放完成(超时是正常的)")
+                except subprocess.CalledProcessError as e:
+                    print(f"⚠️ 播放出现问题: {e}")
+                except FileNotFoundError:
+                    print("⚠️ aplay命令不存在,跳过播放测试")
+
+                return True
+            else:
+                print("❌ 输出文件未生成")
+                return False
+        else:
+            print("❌ 音频处理失败")
+            return False
+
+    except Exception as e:
+        print(f"❌ 测试失败: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+    finally:
+        try:
+            await client.close()
+        except:
+            pass
+
+def main():
+    """主函数"""
+    print("开始验证豆包音频处理模块...")
+
+    success = asyncio.run(test_complete_workflow())
+
+    if success:
+        print("\n🎉 验证完成!豆包音频处理模块工作正常。")
+        print("\n📋 功能总结:")
+        print("  ✅ WebSocket连接建立")
+        print("  ✅ 音频文件上传")
+        print("  ✅ 语音识别")
+        print("  ✅ TTS音频生成")
+        print("  ✅ 音频格式转换(Float32 -> Int16)")
+        print("  ✅ WAV文件生成")
+        print("  ✅ 树莓派兼容播放")
+    else:
+        print("\n❌ 验证失败,请检查错误信息。")
+
+    return success
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
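For reference, a minimal standalone sketch of the 4-byte binary header these scripts build and parse, with the field layout inferred from DoubaoProtocol.generate_header and the parsing code above (pack_header/unpack_header are illustrative names, not part of the patch):

    def pack_header(version=0b0001, header_size=1, message_type=0b0001,
                    flags=0b0100, serial=0b0001, compression=0b0001) -> bytes:
        """Each byte carries two 4-bit fields; the last byte is reserved."""
        return bytes([
            (version << 4) | header_size,   # 0x11: protocol version 1, header size 1 (x4 bytes)
            (message_type << 4) | flags,    # 0x14: CLIENT_FULL_REQUEST with MSG_WITH_EVENT
            (serial << 4) | compression,    # 0x11: JSON payload, GZIP compressed
            0x00,                           # reserved
        ])

    def unpack_header(frame: bytes) -> dict:
        """Reverse of pack_header; the payload begins at header_size * 4 bytes."""
        return {
            'version': frame[0] >> 4,
            'header_size': frame[0] & 0x0F,
            'message_type': frame[1] >> 4,
            'flags': frame[1] & 0x0F,
            'serialization': frame[2] >> 4,
            'compression': frame[2] & 0x0F,
        }

    # The default header produced by generate_header() is the byte sequence 11 14 11 00.
    assert pack_header() == bytes.fromhex('11141100')
    assert unpack_header(bytes.fromhex('11141100'))['message_type'] == 0b0001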