#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Doubao audio processing module.

Simplified WebSocket client for the Doubao realtime-dialogue API:
uploads a WAV audio file and saves the audio returned by the service.
"""

import asyncio
import gzip
import json
import os
import struct
import time
import uuid
import wave
from typing import Any, Dict, Optional

# ``websockets`` is third-party; import it defensively so the pure
# protocol/audio helpers in this module stay usable (and testable)
# even when the dependency is not installed.
try:
    import websockets
except ImportError:  # pragma: no cover
    websockets = None


class DoubaoConfig:
    """Connection settings for the Doubao realtime dialogue endpoint."""

    def __init__(self):
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        # SECURITY NOTE(review): credentials are hardcoded in source.
        # They should be loaded from environment variables or a secrets
        # store before this ships anywhere public.
        self.app_id = "8718217928"
        self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
        self.app_key = "PlgvMymc7f3tQnJ6"
        self.resource_id = "volc.speech.dialog"

    def get_headers(self) -> Dict[str, str]:
        """Build the HTTP headers required to open the WebSocket."""
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            # Fresh connect id per connection attempt.
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }


class DoubaoProtocol:
    """Binary framing helpers for the Doubao wire protocol.

    Frame layout: 4-byte header (version/size, type/flags,
    serialization/compression, reserved) followed by an optional 4-byte
    event, optional length-prefixed session id, and a length-prefixed,
    possibly gzip-compressed payload.
    """

    # --- protocol constants -------------------------------------------------
    PROTOCOL_VERSION = 0b0001

    # Message types
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ACK = 0b1011
    SERVER_ERROR_RESPONSE = 0b1111

    # Message-type-specific flags
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    MSG_WITH_EVENT = 0b0100

    # Serialization methods
    JSON = 0b0001
    NO_SERIALIZATION = 0b0000

    # Compression methods
    GZIP = 0b0001

    @classmethod
    def generate_header(cls,
                        message_type=CLIENT_FULL_REQUEST,
                        message_type_specific_flags=MSG_WITH_EVENT,
                        serial_method=JSON,
                        compression_type=GZIP) -> bytes:
        """Build the fixed 4-byte protocol header.

        Byte 0: protocol version (hi nibble) | header size in 4-byte
        words (lo nibble, always 1 here).
        Byte 1: message type | flags.  Byte 2: serialization |
        compression.  Byte 3: reserved.
        """
        return bytes((
            (cls.PROTOCOL_VERSION << 4) | 1,
            (message_type << 4) | message_type_specific_flags,
            (serial_method << 4) | compression_type,
            0x00,
        ))

    @classmethod
    def parse_response(cls, res: bytes) -> Dict[str, Any]:
        """Parse a binary server frame into a dict.

        Returns (depending on message type):
        * SERVER_FULL_RESPONSE / SERVER_ACK: ``message_type``, optional
          ``event``, ``session_id``, ``payload_size`` and
          ``payload_msg`` (JSON-decoded object, or raw bytes for audio).
        * SERVER_ERROR_RESPONSE: ``code`` and ``payload_size``.
        * Text frames / unknown types: ``{}`` (or partial dict).
        """
        if isinstance(res, str):
            # Text frames carry no binary payload.
            return {}

        header_size = res[0] & 0x0F            # in 4-byte words
        message_type = res[1] >> 4
        flags = res[1] & 0x0F
        serialization_method = res[2] >> 4
        message_compression = res[2] & 0x0F
        payload = res[header_size * 4:]

        result: Dict[str, Any] = {}
        if message_type in (cls.SERVER_FULL_RESPONSE, cls.SERVER_ACK):
            result['message_type'] = ('SERVER_ACK'
                                      if message_type == cls.SERVER_ACK
                                      else 'SERVER_FULL_RESPONSE')
            if flags & cls.MSG_WITH_EVENT:
                result['event'] = int.from_bytes(payload[:4], "big",
                                                 signed=False)
                payload = payload[4:]
            # Length-prefixed session id (size is signed per protocol).
            session_id_size = int.from_bytes(payload[:4], "big", signed=True)
            # BUGFIX: decode the bytes; the original used str(bytes),
            # which produced a "b'...'" repr instead of the actual id.
            result['session_id'] = payload[4:session_id_size + 4].decode(
                "utf-8", errors="replace")
            payload = payload[4 + session_id_size:]

            payload_size = int.from_bytes(payload[:4], "big", signed=False)
            payload_msg = payload[4:]
            result['payload_size'] = payload_size
            if payload_msg:
                if message_compression == cls.GZIP:
                    payload_msg = gzip.decompress(payload_msg)
                if serialization_method == cls.JSON:
                    payload_msg = json.loads(str(payload_msg, "utf-8"))
                # NO_SERIALIZATION payloads (e.g. TTS audio) stay bytes.
                result['payload_msg'] = payload_msg
        elif message_type == cls.SERVER_ERROR_RESPONSE:
            result['code'] = int.from_bytes(payload[:4], "big", signed=False)
            result['payload_size'] = int.from_bytes(payload[4:8], "big",
                                                    signed=False)
        return result


class AudioProcessor:
    """WAV file read/write helpers."""

    @staticmethod
    def read_wav_file(file_path: str) -> tuple:
        """Read a WAV file.

        Returns ``(raw_pcm_bytes, info_dict)`` where the dict holds
        ``channels``, ``sampwidth`` (bytes per sample), ``framerate``
        and ``nframes``.
        """
        with wave.open(file_path, 'rb') as wf:
            params = {
                'channels': wf.getnchannels(),
                'sampwidth': wf.getsampwidth(),
                'framerate': wf.getframerate(),
                'nframes': wf.getnframes(),
            }
            audio_data = wf.readframes(params['nframes'])
        return audio_data, params

    @staticmethod
    def create_wav_file(audio_data: bytes,
                        output_path: str,
                        sample_rate: int = 24000,
                        channels: int = 1,
                        sampwidth: int = 2) -> None:
        """Write raw PCM bytes out as a WAV file (Raspberry-Pi friendly
        defaults: 24 kHz, mono, 16-bit)."""
        with wave.open(output_path, 'wb') as wf:
            wf.setnchannels(channels)
            wf.setsampwidth(sampwidth)
            wf.setframerate(sample_rate)
            wf.writeframes(audio_data)


class DoubaoClient:
    """Stateful WebSocket client handling the Doubao session lifecycle:
    StartConnection -> StartSession -> audio chunks -> FinishSession ->
    FinishConnection."""

    def __init__(self, config: DoubaoConfig):
        self.config = config
        self.session_id = str(uuid.uuid4())
        self.ws = None
        self.log_id = ""

    async def connect(self) -> None:
        """Open the WebSocket and perform the StartConnection /
        StartSession handshake.

        Raises:
            RuntimeError: if the ``websockets`` package is unavailable.
        """
        if websockets is None:
            raise RuntimeError(
                "the 'websockets' package is required for DoubaoClient.connect")
        print(f"连接豆包服务器: {self.config.base_url}")
        self.ws = await websockets.connect(
            self.config.base_url,
            additional_headers=self.config.get_headers(),
            ping_interval=None,  # server manages keepalive
        )

        # The response-headers attribute differs across websockets
        # versions; try both spellings.
        if hasattr(self.ws, 'response_headers'):
            self.log_id = self.ws.response_headers.get("X-Tt-Logid", "")
        elif hasattr(self.ws, 'headers'):
            self.log_id = self.ws.headers.get("X-Tt-Logid", "")
        print(f"连接成功, log_id: {self.log_id}")

        await self._send_start_connection()
        await self._send_start_session()

    async def _send_start_connection(self) -> None:
        """Send the StartConnection request (event 1) and log the reply."""
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(1).to_bytes(4, 'big'))       # event id
        payload_bytes = gzip.compress(b"{}")
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)
        await self.ws.send(request)

        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"StartConnection响应: {parsed_response}")

    async def _send_start_session(self) -> None:
        """Send the StartSession request (event 100) with ASR/TTS/dialog
        configuration, and log the reply."""
        session_config = {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {
                    "channel": 1,
                    "format": "pcm",
                    "sample_rate": 24000,
                },
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 10,
                    "input_mod": "audio_file",  # audio-file input mode
                },
            },
        }

        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(100).to_bytes(4, 'big'))     # event id
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())
        payload_bytes = gzip.compress(json.dumps(session_config).encode())
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)
        await self.ws.send(request)

        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"StartSession响应: {parsed_response}")

    async def send_audio_file(self, file_path: str) -> bytes:
        """Stream a WAV file to the server in ~200 ms chunks and collect
        the audio the server sends back.

        Returns the concatenated raw PCM bytes received.

        Raises:
            Exception: after more than 3 consecutive server error frames.
        """
        print(f"处理音频文件: {file_path}")

        audio_data, audio_info = AudioProcessor.read_wav_file(file_path)
        print(f"音频参数: {audio_info}")

        # Bytes per 200 ms of audio at the file's own rate/width.
        chunk_size = int(audio_info['framerate'] * audio_info['channels']
                         * audio_info['sampwidth'] * 0.2)
        total_chunks = (len(audio_data) + chunk_size - 1) // chunk_size
        print(f"开始分块发送音频,共 {total_chunks} 块")

        received_audio = b""
        error_count = 0

        for i in range(0, len(audio_data), chunk_size):
            chunk = audio_data[i:i + chunk_size]
            is_last = (i + chunk_size >= len(audio_data))
            await self._send_audio_chunk(chunk, is_last)

            # Poll for one response per chunk; timeouts are expected
            # while the server is still buffering input.
            try:
                response = await asyncio.wait_for(self.ws.recv(), timeout=2.0)
                parsed_response = DoubaoProtocol.parse_response(response)

                # Error frames: tolerate up to 3 in a row, then abort.
                if 'code' in parsed_response and parsed_response['code'] != 0:
                    print(f"服务器返回错误: {parsed_response}")
                    error_count += 1
                    if error_count > 3:
                        raise Exception(f"服务器连续返回错误: {parsed_response}")
                    continue

                # ACK frames with a bytes payload carry TTS audio.
                if (parsed_response.get('message_type') == 'SERVER_ACK'
                        and isinstance(parsed_response.get('payload_msg'),
                                       bytes)):
                    audio_chunk = parsed_response['payload_msg']
                    received_audio += audio_chunk
                    print(f"接收到音频数据块,大小: {len(audio_chunk)} 字节")

                # Session-state events; 152/153 mean the session ended.
                event = parsed_response.get('event')
                if event in [359, 152, 153]:
                    print(f"会话事件: {event}")
                    if event in [152, 153]:
                        print("检测到会话结束事件")
                        break

            except asyncio.TimeoutError:
                print("等待响应超时,继续发送")

            # Pace the upload to roughly simulate realtime capture.
            await asyncio.sleep(0.05)

        print("音频文件发送完成")
        return received_audio

    async def _send_audio_chunk(self, audio_data: bytes,
                                is_last: bool = False) -> None:
        """Send one gzip-compressed audio chunk (event 200).

        ``is_last`` is currently unused on the wire but kept for
        interface stability.
        """
        request = bytearray(
            DoubaoProtocol.generate_header(
                message_type=DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST,
                message_type_specific_flags=DoubaoProtocol.NO_SEQUENCE,
                serial_method=DoubaoProtocol.NO_SERIALIZATION,  # raw audio
                compression_type=DoubaoProtocol.GZIP,
            )
        )
        request.extend(int(200).to_bytes(4, 'big'))     # event id
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())
        compressed_audio = gzip.compress(audio_data)
        request.extend(len(compressed_audio).to_bytes(4, 'big'))
        request.extend(compressed_audio)
        await self.ws.send(request)

    async def close(self) -> None:
        """Tear down the session and connection, then close the socket.

        Best-effort: failures during the Finish handshake are logged,
        and the underlying WebSocket is always closed.
        """
        if not self.ws:
            return
        try:
            await self._send_finish_session()
            await self._send_finish_connection()
        except Exception as e:
            print(f"关闭会话时出错: {e}")
        finally:
            try:
                await self.ws.close()
            except Exception:
                # Socket may already be gone; nothing more to do.
                pass
            print("连接已关闭")

    async def _send_finish_session(self) -> None:
        """Send the FinishSession request (event 102); no reply awaited."""
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(102).to_bytes(4, 'big'))     # event id
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())
        payload_bytes = gzip.compress(b"{}")
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)
        await self.ws.send(request)

    async def _send_finish_connection(self) -> None:
        """Send the FinishConnection request (event 2) and log the reply."""
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(2).to_bytes(4, 'big'))       # event id
        payload_bytes = gzip.compress(b"{}")
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)
        await self.ws.send(request)

        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"FinishConnection响应: {parsed_response}")


class DoubaoProcessor:
    """High-level one-shot processor: file in, response file out."""

    def __init__(self):
        self.config = DoubaoConfig()
        self.client = DoubaoClient(self.config)

    async def process_audio_file(self, input_file: str,
                                 output_file: Optional[str] = None) -> str:
        """Send ``input_file`` to Doubao and save the audio response.

        Args:
            input_file: path to the input WAV file.
            output_file: path for the response WAV; auto-generated with
                a timestamp when ``None``.

        Returns:
            The output file path.

        Raises:
            FileNotFoundError: if ``input_file`` does not exist.
        """
        if not os.path.exists(input_file):
            raise FileNotFoundError(f"音频文件不存在: {input_file}")

        if output_file is None:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            output_file = f"doubao_response_{timestamp}.wav"

        try:
            await self.client.connect()
            # Give the server a moment to establish the session.
            await asyncio.sleep(0.5)

            received_audio = await self.client.send_audio_file(input_file)

            if received_audio:
                print(f"总共接收到音频数据: {len(received_audio)} 字节")
                # Doubao returns 24 kHz mono 16-bit PCM.
                AudioProcessor.create_wav_file(
                    received_audio,
                    output_file,
                    sample_rate=24000,
                    channels=1,
                    sampwidth=2,
                )
                print(f"响应音频已保存到: {output_file}")
                file_size = os.path.getsize(output_file)
                print(f"输出文件大小: {file_size} 字节")
            else:
                print("警告: 未接收到音频响应")

            return output_file

        except Exception as e:
            print(f"处理音频文件时出错: {e}")
            import traceback
            traceback.print_exc()
            raise
        finally:
            await self.client.close()


async def main():
    """CLI entry point for ad-hoc testing."""
    import argparse
    parser = argparse.ArgumentParser(description="豆包音频处理测试")
    parser.add_argument("--input", type=str, required=True,
                        help="输入音频文件路径")
    parser.add_argument("--output", type=str, help="输出音频文件路径")
    args = parser.parse_args()

    processor = DoubaoProcessor()
    try:
        output_file = await processor.process_audio_file(args.input,
                                                         args.output)
        print(f"处理完成,输出文件: {output_file}")
    except Exception as e:
        print(f"处理失败: {e}")


if __name__ == "__main__":
    asyncio.run(main())