#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Doubao audio processing module - simplified test version.

Exercises one full round trip against the realtime dialogue endpoint:
upload a recorded WAV clip, print whatever recognition result comes back,
and save the first synthesized (TTS) audio chunk as a WAV file.
"""

import asyncio
import gzip
import json
import uuid
import wave
import struct
import time
import os
import traceback
import zlib
from typing import Dict, Any, Optional

try:
    import websockets
except ImportError:
    # Keep the pure protocol helpers (generate_header / build_request /
    # parse_response) importable without the network dependency;
    # DoubaoClient.connect() reports the missing package explicitly.
    websockets = None

# Protocol constants (copied from the original Doubao sample code).
PROTOCOL_VERSION = 0b0001

CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010

SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

NO_SEQUENCE = 0b0000
POS_SEQUENCE = 0b0001
MSG_WITH_EVENT = 0b0100

NO_SERIALIZATION = 0b0000
JSON = 0b0001

GZIP = 0b0001

# Event ids used by this script.
EVENT_START_CONNECTION = 1
EVENT_START_SESSION = 100
EVENT_TASK_REQUEST = 200
EVENT_TTS_ENDED = 359
EVENT_ASR_RESPONSE = 451

# PCM format requested from the TTS engine in StartSession; the same values
# are used to write the received audio as a WAV file.
TTS_SAMPLE_RATE = 24000
TTS_CHANNELS = 1
TTS_SAMPLE_WIDTH = 2  # bytes per sample (16-bit)

DEFAULT_AUDIO_PATH = "recording_20250920_135137.wav"

# Everything gzip.decompress() can raise on malformed input.
_DECOMPRESS_ERRORS = (OSError, EOFError, zlib.error)


def generate_header(
    version=PROTOCOL_VERSION,
    message_type=CLIENT_FULL_REQUEST,
    message_type_specific_flags=MSG_WITH_EVENT,
    serial_method=JSON,
    compression_type=GZIP,
    reserved_data=0x00,
    extension_header=bytes()
):
    """Build the binary frame header.

    Layout (one byte each, two 4-bit fields per byte):
      byte 0: version | header size in 4-byte words
      byte 1: message type | message-type-specific flags
      byte 2: serialization method | compression type
      byte 3: reserved
    followed by the optional ``extension_header`` bytes.

    Returns:
        A ``bytearray`` holding the header.
    """
    header = bytearray()
    header_size = len(extension_header) // 4 + 1
    header.append((version << 4) | header_size)
    header.append((message_type << 4) | message_type_specific_flags)
    header.append((serial_method << 4) | compression_type)
    header.append(reserved_data)
    header.extend(extension_header)
    return header


def build_request(
    event: int,
    payload: bytes,
    session_id: Optional[str] = None,
    **header_kwargs: Any
) -> bytearray:
    """Assemble a complete client frame.

    Layout: header, 4-byte big-endian event id, optional
    (4-byte length + session id), 4-byte payload length, payload.

    Args:
        event: Event id written after the header.
        payload: Raw payload bytes. Gzip-compressed before framing unless
            ``compression_type`` in ``header_kwargs`` says otherwise, so the
            body always agrees with the header's compression field.
        session_id: When given, written length-prefixed before the payload.
        **header_kwargs: Forwarded unchanged to ``generate_header``.

    Returns:
        The frame as a ``bytearray``, ready to send.
    """
    request = bytearray(generate_header(**header_kwargs))
    request.extend(event.to_bytes(4, 'big'))
    if session_id is not None:
        session_bytes = session_id.encode()
        request.extend(len(session_bytes).to_bytes(4, 'big'))
        request.extend(session_bytes)
    if header_kwargs.get("compression_type", GZIP) == GZIP:
        payload = gzip.compress(payload)
    request.extend(len(payload).to_bytes(4, 'big'))
    request.extend(payload)
    return request


def parse_response(data: Any) -> Dict[str, Any]:
    """Split a received frame into header fields and, when present, body fields.

    Always-present keys for a frame of at least 4 bytes: ``version``,
    ``header_size``, ``message_type``, ``flags``, ``serialization``,
    ``compression``.

    For SERVER_FULL_RESPONSE and SERVER_ACK frames the body is read as
    event(4) + session-id length(4) + session id + payload length(4) +
    payload, adding ``event``, ``session_id``, ``payload_size`` and
    ``payload`` (still compressed, if it was) for as far as the frame
    actually extends. A truncated frame simply yields fewer keys.

    NOTE(review): the event id is read unconditionally, i.e. the frame is
    assumed to carry the MSG_WITH_EVENT layout - confirm against the
    protocol spec if other flag combinations must be supported.

    Returns:
        A dict of parsed fields; empty when ``data`` is not bytes or is
        shorter than the 4-byte fixed header.
    """
    frame: Dict[str, Any] = {}
    if not isinstance(data, (bytes, bytearray)) or len(data) < 4:
        return frame

    header_size = data[0] & 0x0f
    frame["version"] = data[0] >> 4
    frame["header_size"] = header_size
    frame["message_type"] = data[1] >> 4
    frame["flags"] = data[1] & 0x0f
    frame["serialization"] = data[2] >> 4
    frame["compression"] = data[2] & 0x0f

    if frame["message_type"] not in (SERVER_FULL_RESPONSE, SERVER_ACK):
        return frame

    body = bytes(data[header_size * 4:])
    if len(body) < 4:
        return frame
    frame["event"] = int.from_bytes(body[:4], 'big')

    if len(body) < 8:
        return frame
    session_end = 8 + int.from_bytes(body[4:8], 'big')
    if len(body) < session_end:
        return frame
    # errors="replace": a malformed session id must not abort the whole run.
    frame["session_id"] = body[8:session_end].decode('utf-8', errors='replace')

    payload_start = session_end + 4
    if len(body) < payload_start:
        return frame
    payload_size = int.from_bytes(body[session_end:payload_start], 'big')
    frame["payload_size"] = payload_size
    frame["payload"] = body[payload_start:payload_start + payload_size]
    return frame


def decode_payload(frame: Dict[str, Any]) -> bytes:
    """Return the frame's payload, gunzipped when the header says it is gzip.

    Raises:
        OSError, EOFError, zlib.error: the header claims gzip but the
            payload is not valid gzip data.
    """
    payload = frame.get("payload", b"")
    if frame.get("compression") == GZIP:
        return gzip.decompress(payload)
    return payload


def _load_json_payload(frame: Dict[str, Any]) -> Any:
    """Decode the frame's payload as UTF-8 JSON.

    Falls back to the raw payload bytes when decompression fails, so a
    plain-JSON body behind a wrong compression flag is still readable.

    Raises:
        ValueError: the payload is not valid UTF-8 / JSON.
    """
    try:
        raw = decode_payload(frame)
    except _DECOMPRESS_ERRORS:
        raw = frame.get("payload", b"")
    return json.loads(raw.decode('utf-8'))


class DoubaoConfig:
    """Connection settings for the Doubao realtime dialogue endpoint."""

    def __init__(self):
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        # SECURITY(review): these credentials are committed in source. They
        # remain as fallbacks so existing runs keep working, but should be
        # rotated and supplied only through the environment.
        self.app_id = os.environ.get("DOUBAO_APP_ID", "8718217928")
        self.access_key = os.environ.get(
            "DOUBAO_ACCESS_KEY", "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc")
        self.app_key = os.environ.get("DOUBAO_APP_KEY", "PlgvMymc7f3tQnJ6")
        self.resource_id = "volc.speech.dialog"

    def get_headers(self) -> Dict[str, str]:
        """Return the HTTP headers for the WebSocket handshake.

        A fresh random ``X-Api-Connect-Id`` is generated on every call.
        """
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }


class DoubaoClient:
    """Doubao client - based on the original sample code."""

    def __init__(self, config: DoubaoConfig):
        self.config = config
        self.session_id = str(uuid.uuid4())
        self.ws = None
        self.log_id = ""

    async def connect(self) -> None:
        """Open the WebSocket, then send StartConnection and StartSession.

        Raises:
            RuntimeError: the ``websockets`` package is not installed.
            Exception: any connection/handshake failure is printed and
                re-raised unchanged.
        """
        if websockets is None:
            raise RuntimeError(
                "the 'websockets' package is required: pip install websockets")

        print(f"连接豆包服务器: {self.config.base_url}")
        try:
            self.ws = await websockets.connect(
                self.config.base_url,
                additional_headers=self.config.get_headers(),
                ping_interval=None
            )
            self.log_id = self._handshake_log_id()
            print(f"连接成功, log_id: {self.log_id}")

            await self._send_start_connection()
            await self._send_start_session()
        except Exception as e:
            print(f"连接失败: {e}")
            raise

    def _handshake_log_id(self) -> str:
        """Return the ``X-Tt-Logid`` handshake response header, or ``""``.

        Looks at ``ws.response.headers`` first (where the websockets API
        that accepts ``additional_headers`` exposes the handshake response),
        then at the ``response_headers`` / ``headers`` attributes the
        original code probed.
        """
        response = getattr(self.ws, "response", None)
        candidates = (
            getattr(response, "headers", None),
            getattr(self.ws, "response_headers", None),
            getattr(self.ws, "headers", None),
        )
        for headers in candidates:
            if headers is not None:
                return headers.get("X-Tt-Logid") or ""
        return ""

    async def _send_start_connection(self) -> None:
        """Send the StartConnection request and read one response frame."""
        print("发送StartConnection请求...")
        await self.ws.send(build_request(EVENT_START_CONNECTION, b"{}"))
        response = await self.ws.recv()
        print(f"StartConnection响应长度: {len(response)}")

    async def _send_start_session(self) -> None:
        """Send the StartSession request and read one response frame."""
        print("发送StartSession请求...")
        session_config = {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {
                    "channel": TTS_CHANNELS,
                    "format": "pcm",
                    "sample_rate": TTS_SAMPLE_RATE
                },
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 30,
                    "input_mod": "audio",
                },
            },
        }

        request = build_request(
            EVENT_START_SESSION,
            json.dumps(session_config).encode(),
            session_id=self.session_id,
        )
        await self.ws.send(request)
        response = await self.ws.recv()
        print(f"StartSession响应长度: {len(response)}")

        # Give the server a moment to finish establishing the session.
        await asyncio.sleep(1.0)

    async def task_request(self, audio: bytes) -> None:
        """Send ``audio`` as one gzip-compressed audio-only frame (event 200)."""
        request = build_request(
            EVENT_TASK_REQUEST,
            audio,
            session_id=self.session_id,
            message_type=CLIENT_AUDIO_ONLY_REQUEST,
            serial_method=NO_SERIALIZATION,
        )
        await self.ws.send(request)

    async def test_full_dialog(
        self,
        audio_path: str = DEFAULT_AUDIO_PATH,
        max_seconds: float = 5.0,
        max_responses: int = 10,
    ) -> None:
        """Run the full dialogue test on an already connected client.

        Reads up to ``max_seconds`` of audio from ``audio_path``, uploads it,
        prints the first response, then reads up to ``max_responses`` further
        frames until a TTS audio chunk has been written to a WAV file, the
        TTS-ended event arrives, a receive times out, or the connection
        closes. Failures are printed rather than raised.

        Args:
            audio_path: WAV file to upload.
            max_seconds: Upper bound on uploaded audio. With the defaults this
                equals the former fixed 80000-frame limit for a 16 kHz file.
            max_responses: Maximum number of frames to read after the first.
        """
        print("开始完整对话测试...")

        audio_data = self._read_audio(audio_path, max_seconds)
        if audio_data is None:
            return
        print(f"音频数据大小: {len(audio_data)}")

        try:
            print("发送音频数据...")
            await self.task_request(audio_data)
            print("音频数据发送成功")

            print("等待语音识别响应...")
            response = await asyncio.wait_for(self.ws.recv(), timeout=15.0)
            print(f"收到ASR响应,长度: {len(response)}")
            self._report_asr_frame(parse_response(response))

            received = await self._collect_tts(max_responses)
            print(f"共收到 {received} 个响应")
        except asyncio.TimeoutError:
            print("等待响应超时")
        except websockets.exceptions.ConnectionClosed as e:
            print(f"连接关闭: {e}")
        except Exception as e:
            print(f"测试失败: {e}")
            traceback.print_exc()

    @staticmethod
    def _read_audio(audio_path: str, max_seconds: float) -> Optional[bytes]:
        """Read at most ``max_seconds`` of frames; ``None`` if unreadable."""
        try:
            with wave.open(audio_path, 'rb') as wf:
                frame_limit = int(wf.getframerate() * max_seconds)
                audio_data = wf.readframes(min(wf.getnframes(), frame_limit))
                print(f"读取真实音频数据: {len(audio_data)} 字节")
                print(f"音频参数: 采样率={wf.getframerate()}, 通道数={wf.getnchannels()}, 采样宽度={wf.getsampwidth()}")
        except (OSError, EOFError, wave.Error) as e:
            print(f"读取音频文件失败: {e}")
            return None
        return audio_data

    @staticmethod
    def _print_recognized_text(result: Any) -> None:
        """Print the first recognition text if ``result`` carries one."""
        results = result.get('results') if isinstance(result, dict) else None
        if isinstance(results, list) and results and isinstance(results[0], dict):
            text = results[0].get('text', '')
            print(f"识别文本: {text}")

    def _report_asr_frame(self, frame: Dict[str, Any]) -> None:
        """Print the fields of the first response after the audio upload."""
        if not frame:
            return
        print(f"ASR响应协议: version={frame['version']}, header_size={frame['header_size']}, message_type={frame['message_type']}, flags={frame['flags']}")
        if frame["message_type"] != SERVER_FULL_RESPONSE:
            return

        if "event" in frame:
            print(f"ASR Event: {frame['event']}")
        if "session_id" in frame:
            print(f"Session ID: {frame['session_id']}")
        if "payload" not in frame:
            return

        print(f"Payload size: {frame['payload_size']}")
        try:
            asr_result = _load_json_payload(frame)
        except ValueError as e:
            print(f"解析ASR结果失败: {e}")
            return
        print(f"ASR结果: {asr_result}")
        self._print_recognized_text(asr_result)

    async def _collect_tts(self, max_responses: int) -> int:
        """Read frames until audio is saved or the stream ends.

        Returns:
            The number of frames actually received.
        """
        print("开始持续等待TTS音频响应...")
        received = 0
        for index in range(max_responses):
            try:
                print(f"等待第 {index + 1} 个响应...")
                data = await asyncio.wait_for(self.ws.recv(), timeout=30.0)
            except asyncio.TimeoutError:
                print(f"等待第 {index + 1} 个响应超时")
                break
            except websockets.exceptions.ConnectionClosed:
                print("连接已关闭")
                break

            received += 1
            print(f"收到响应 {index + 1},长度: {len(data)}")

            frame = parse_response(data)
            if frame:
                print(f"响应协议: version={frame['version']}, header_size={frame['header_size']}, message_type={frame['message_type']}, flags={frame['flags']}")
                if self._handle_frame(frame, index):
                    break

            # Keep the complete frame on disk for debugging.
            full_name = f'tts_response_full_{index}.raw'
            with open(full_name, 'wb') as f:
                f.write(data if isinstance(data, (bytes, bytearray)) else str(data).encode())
            print(f"完整响应已保存到 {full_name}")
        return received

    def _handle_frame(self, frame: Dict[str, Any], index: int) -> bool:
        """Dispatch one frame by message type; ``True`` means stop reading."""
        if frame["message_type"] == SERVER_ACK:
            return self._handle_audio_frame(frame, index)
        if frame["message_type"] == SERVER_FULL_RESPONSE:
            return self._handle_event_frame(frame, index)
        return False

    @staticmethod
    def _handle_audio_frame(frame: Dict[str, Any], index: int) -> bool:
        """Write a SERVER_ACK audio payload to ``tts_response_<index>.wav``.

        When the payload is flagged gzip but cannot be decompressed, the raw
        bytes go to ``tts_response_audio_<index>.raw`` instead.

        Returns:
            ``True`` once a WAV file has been written, else ``False``.
        """
        if "payload" not in frame:
            return False
        print(f"Event: {frame['event']}")
        print(f"音频数据大小: {frame['payload_size']}")
        if frame["payload_size"] <= 0:
            return False

        print("找到TTS音频数据!")
        try:
            pcm = decode_payload(frame)
        except _DECOMPRESS_ERRORS as tts_e:
            print(f"TTS音频解压缩失败: {tts_e}")
            raw_name = f'tts_response_audio_{index}.raw'
            with open(raw_name, 'wb') as f:
                f.write(frame["payload"])
            print(f"原始TTS音频数据已保存到 {raw_name}")
            return False
        print(f"解压缩后TTS音频大小: {len(pcm)}")

        wav_name = f'tts_response_{index}.wav'
        with wave.open(wav_name, 'wb') as wav_file:
            wav_file.setnchannels(TTS_CHANNELS)
            wav_file.setsampwidth(TTS_SAMPLE_WIDTH)
            wav_file.setframerate(TTS_SAMPLE_RATE)
            wav_file.writeframes(pcm)
        print(f"成功创建TTS WAV文件: {wav_name}")
        print(f"音频参数: {TTS_SAMPLE_RATE}Hz, {TTS_CHANNELS}通道, {TTS_SAMPLE_WIDTH*8}-bit")

        if os.path.exists(wav_name):
            file_size = os.path.getsize(wav_name)
            # Duration from the PCM length, not the file size: the file also
            # contains the WAV header.
            duration = len(pcm) / (TTS_SAMPLE_RATE * TTS_CHANNELS * TTS_SAMPLE_WIDTH)
            print(f"WAV文件大小: {file_size} 字节")
            print(f"音频时长: {duration:.2f} 秒")
        return True

    def _handle_event_frame(self, frame: Dict[str, Any], index: int) -> bool:
        """Print an ASR-result / TTS-ended event frame.

        A payload that cannot be parsed as JSON is saved to
        ``tts_response_<index>.raw``.

        Returns:
            ``True`` when the frame is the TTS-ended event (even if its
            payload could not be parsed), else ``False``.
        """
        if "event" not in frame:
            return False
        event = frame["event"]
        print(f"Event: {event}")
        if event not in (EVENT_ASR_RESPONSE, EVENT_TTS_ENDED):
            return False

        if "payload" in frame:
            try:
                json_data = _load_json_payload(frame)
            except ValueError as e:
                print(f"解析JSON失败: {e}")
                with open(f'tts_response_{index}.raw', 'wb') as f:
                    f.write(frame["payload"])
            else:
                print(f"JSON数据: {json_data}")
                self._print_recognized_text(json_data)

        if event == EVENT_TTS_ENDED:
            print("TTS响应结束")
            return True
        return False

    async def close(self) -> None:
        """Close the WebSocket if one was opened (best effort)."""
        if self.ws:
            try:
                await self.ws.close()
            except Exception as e:
                # Closing is best-effort, but say why it failed.
                print(f"关闭连接时出错: {e}")
            print("连接已关闭")


async def main():
    """Connect, run the dialogue test, and always close the connection."""
    config = DoubaoConfig()
    client = DoubaoClient(config)

    try:
        await client.connect()
        await client.test_full_dialog()
    except Exception as e:
        print(f"测试失败: {e}")
        traceback.print_exc()
    finally:
        await client.close()


if __name__ == "__main__":
    asyncio.run(main())