From dbdeeeefcb7bbb0d51cd16854fc9492748bf283b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?= Date: Sat, 20 Sep 2025 15:44:46 +0800 Subject: [PATCH] config --- doubao.py | 470 ------------------------- doubao_debug.py | 540 ----------------------------- doubao_final_test.py | 425 ----------------------- doubao_original_test.py | 412 ---------------------- doubao_test.py | 303 ---------------- recognition_example.py | 127 ------- sauc_python/readme.md | 15 - sauc_python/sauc_websocket_demo.py | 523 ---------------------------- speech_recognizer.py | 532 ---------------------------- test_doubao.py | 113 ------ 10 files changed, 3460 deletions(-) delete mode 100644 doubao.py delete mode 100644 doubao_debug.py delete mode 100644 doubao_final_test.py delete mode 100644 doubao_original_test.py delete mode 100644 doubao_test.py delete mode 100644 recognition_example.py delete mode 100644 sauc_python/readme.md delete mode 100644 sauc_python/sauc_websocket_demo.py delete mode 100644 speech_recognizer.py delete mode 100644 test_doubao.py diff --git a/doubao.py b/doubao.py deleted file mode 100644 index 6b22fdc..0000000 --- a/doubao.py +++ /dev/null @@ -1,470 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -豆包音频处理模块 -简化版WebSocket API,支持音频文件上传和返回音频处理 -""" - -import asyncio -import gzip -import json -import uuid -import wave -import struct -import time -import os -from typing import Dict, Any, Optional -import websockets - - -class DoubaoConfig: - """豆包配置""" - - def __init__(self): - self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue" - self.app_id = "8718217928" - self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc" - self.app_key = "PlgvMymc7f3tQnJ6" - self.resource_id = "volc.speech.dialog" - - def get_headers(self) -> Dict[str, str]: - """获取请求头""" - return { - "X-Api-App-ID": self.app_id, - "X-Api-Access-Key": self.access_key, - "X-Api-Resource-Id": self.resource_id, - "X-Api-App-Key": self.app_key, - "X-Api-Connect-Id": str(uuid.uuid4()), - } - - -class DoubaoProtocol: - """豆包协议处理""" - - # 协议常量 - PROTOCOL_VERSION = 0b0001 - CLIENT_FULL_REQUEST = 0b0001 - CLIENT_AUDIO_ONLY_REQUEST = 0b0010 - SERVER_FULL_RESPONSE = 0b1001 - SERVER_ACK = 0b1011 - SERVER_ERROR_RESPONSE = 0b1111 - - NO_SEQUENCE = 0b0000 - POS_SEQUENCE = 0b0001 - MSG_WITH_EVENT = 0b0100 - JSON = 0b0001 - NO_SERIALIZATION = 0b0000 - GZIP = 0b0001 - - @classmethod - def generate_header(cls, message_type=CLIENT_FULL_REQUEST, - message_type_specific_flags=MSG_WITH_EVENT, - serial_method=JSON, compression_type=GZIP) -> bytes: - """生成协议头""" - header = bytearray() - header.append((cls.PROTOCOL_VERSION << 4) | 1) # version + header_size - header.append((message_type << 4) | message_type_specific_flags) - header.append((serial_method << 4) | compression_type) - header.append(0x00) # reserved - return bytes(header) - - @classmethod - def parse_response(cls, res: bytes) -> Dict[str, Any]: - """解析响应""" - if isinstance(res, str): - return {} - - protocol_version = res[0] >> 4 - header_size = res[0] & 0x0f - message_type = res[1] >> 4 - message_type_specific_flags = res[1] & 0x0f - serialization_method = res[2] >> 4 - message_compression = res[2] & 0x0f - - payload = res[header_size * 4:] - result = {} - - if message_type == cls.SERVER_FULL_RESPONSE or message_type == cls.SERVER_ACK: - result['message_type'] = 'SERVER_FULL_RESPONSE' - if message_type == cls.SERVER_ACK: - result['message_type'] = 'SERVER_ACK' - - start = 0 - if message_type_specific_flags & cls.MSG_WITH_EVENT: - result['event'] = int.from_bytes(payload[:4], "big", signed=False) - start += 4 - - payload = payload[start:] - session_id_size = int.from_bytes(payload[:4], "big", signed=True) - session_id = payload[4:session_id_size+4] - result['session_id'] = str(session_id) - payload = payload[4 + session_id_size:] - - payload_size = int.from_bytes(payload[:4], "big", signed=False) - payload_msg = payload[4:] - result['payload_size'] = payload_size - - if payload_msg: - if message_compression == cls.GZIP: - payload_msg = gzip.decompress(payload_msg) - if serialization_method == cls.JSON: - payload_msg = json.loads(str(payload_msg, "utf-8")) - result['payload_msg'] = payload_msg - - elif message_type == cls.SERVER_ERROR_RESPONSE: - code = int.from_bytes(payload[:4], "big", signed=False) - result['code'] = code - payload_size = int.from_bytes(payload[4:8], "big", signed=False) - result['payload_size'] = payload_size - - return result - - -class AudioProcessor: - """音频处理器""" - - @staticmethod - def read_wav_file(file_path: str) -> tuple: - """读取WAV文件,返回音频数据和参数""" - with wave.open(file_path, 'rb') as wf: - # 获取音频参数 - channels = wf.getnchannels() - sampwidth = wf.getsampwidth() - framerate = wf.getframerate() - nframes = wf.getnframes() - - # 读取音频数据 - audio_data = wf.readframes(nframes) - - return audio_data, { - 'channels': channels, - 'sampwidth': sampwidth, - 'framerate': framerate, - 'nframes': nframes - } - - @staticmethod - def create_wav_file(audio_data: bytes, output_path: str, - sample_rate: int = 24000, channels: int = 1, - sampwidth: int = 2) -> None: - """创建WAV文件,适配树莓派播放""" - with wave.open(output_path, 'wb') as wf: - wf.setnchannels(channels) - wf.setsampwidth(sampwidth) - wf.setframerate(sample_rate) - wf.writeframes(audio_data) - - -class DoubaoClient: - """豆包客户端""" - - def __init__(self, config: DoubaoConfig): - self.config = config - self.session_id = str(uuid.uuid4()) - self.ws = None - self.log_id = "" - - async def connect(self) -> None: - """建立WebSocket连接""" - print(f"连接豆包服务器: {self.config.base_url}") - - self.ws = await websockets.connect( - self.config.base_url, - additional_headers=self.config.get_headers(), - ping_interval=None - ) - - # 获取log_id - if hasattr(self.ws, 'response_headers'): - self.log_id = self.ws.response_headers.get("X-Tt-Logid") - elif hasattr(self.ws, 'headers'): - self.log_id = self.ws.headers.get("X-Tt-Logid") - - print(f"连接成功, log_id: {self.log_id}") - - # 发送StartConnection请求 - await self._send_start_connection() - - # 发送StartSession请求 - await self._send_start_session() - - async def _send_start_connection(self) -> None: - """发送StartConnection请求""" - request = bytearray(DoubaoProtocol.generate_header()) - request.extend(int(1).to_bytes(4, 'big')) - - payload_bytes = b"{}" - payload_bytes = gzip.compress(payload_bytes) - request.extend(len(payload_bytes).to_bytes(4, 'big')) - request.extend(payload_bytes) - - await self.ws.send(request) - response = await self.ws.recv() - parsed_response = DoubaoProtocol.parse_response(response) - print(f"StartConnection响应: {parsed_response}") - - async def _send_start_session(self) -> None: - """发送StartSession请求""" - session_config = { - "asr": { - "extra": { - "end_smooth_window_ms": 1500, - }, - }, - "tts": { - "speaker": "zh_female_vv_jupiter_bigtts", - "audio_config": { - "channel": 1, - "format": "pcm", - "sample_rate": 24000 - }, - }, - "dialog": { - "bot_name": "豆包", - "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。", - "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。", - "location": {"city": "北京"}, - "extra": { - "strict_audit": False, - "audit_response": "支持客户自定义安全审核回复话术。", - "recv_timeout": 10, - "input_mod": "audio_file", # 使用音频文件模式 - }, - }, - } - - request = bytearray(DoubaoProtocol.generate_header()) - request.extend(int(100).to_bytes(4, 'big')) - request.extend(len(self.session_id).to_bytes(4, 'big')) - request.extend(self.session_id.encode()) - - payload_bytes = json.dumps(session_config).encode() - payload_bytes = gzip.compress(payload_bytes) - request.extend(len(payload_bytes).to_bytes(4, 'big')) - request.extend(payload_bytes) - - await self.ws.send(request) - response = await self.ws.recv() - parsed_response = DoubaoProtocol.parse_response(response) - print(f"StartSession响应: {parsed_response}") - - async def send_audio_file(self, file_path: str) -> bytes: - """发送音频文件并返回响应音频""" - print(f"处理音频文件: {file_path}") - - # 读取音频文件 - audio_data, audio_info = AudioProcessor.read_wav_file(file_path) - print(f"音频参数: {audio_info}") - - # 计算分块大小(200ms) - chunk_size = int(audio_info['framerate'] * audio_info['channels'] * - audio_info['sampwidth'] * 0.2) - - # 分块发送音频数据 - total_chunks = (len(audio_data) + chunk_size - 1) // chunk_size - print(f"开始分块发送音频,共 {total_chunks} 块") - - received_audio = b"" - error_count = 0 - - for i in range(0, len(audio_data), chunk_size): - chunk = audio_data[i:i + chunk_size] - is_last = (i + chunk_size >= len(audio_data)) - - # 发送音频块 - await self._send_audio_chunk(chunk, is_last) - - # 接收响应 - try: - response = await asyncio.wait_for(self.ws.recv(), timeout=2.0) - parsed_response = DoubaoProtocol.parse_response(response) - - # 检查是否是错误响应 - if 'code' in parsed_response and parsed_response['code'] != 0: - print(f"服务器返回错误: {parsed_response}") - error_count += 1 - if error_count > 3: - raise Exception(f"服务器连续返回错误: {parsed_response}") - continue - - # 处理音频响应 - if (parsed_response.get('message_type') == 'SERVER_ACK' and - isinstance(parsed_response.get('payload_msg'), bytes)): - audio_chunk = parsed_response['payload_msg'] - received_audio += audio_chunk - print(f"接收到音频数据块,大小: {len(audio_chunk)} 字节") - - # 检查会话状态 - event = parsed_response.get('event') - if event in [359, 152, 153]: # 这些事件表示会话相关状态 - print(f"会话事件: {event}") - if event in [152, 153]: # 会话结束 - print("检测到会话结束事件") - break - - except asyncio.TimeoutError: - print("等待响应超时,继续发送") - - # 模拟实时发送的延迟 - await asyncio.sleep(0.05) - - print("音频文件发送完成") - return received_audio - - async def _send_audio_chunk(self, audio_data: bytes, is_last: bool = False) -> None: - """发送音频块""" - request = bytearray( - DoubaoProtocol.generate_header( - message_type=DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST, - message_type_specific_flags=DoubaoProtocol.NO_SEQUENCE, - serial_method=DoubaoProtocol.NO_SERIALIZATION, # 音频数据不需要序列化 - compression_type=DoubaoProtocol.GZIP - ) - ) - request.extend(int(200).to_bytes(4, 'big')) - request.extend(len(self.session_id).to_bytes(4, 'big')) - request.extend(self.session_id.encode()) - - # 压缩音频数据 - compressed_audio = gzip.compress(audio_data) - request.extend(len(compressed_audio).to_bytes(4, 'big')) # payload size(4 bytes) - request.extend(compressed_audio) - - await self.ws.send(request) - - async def close(self) -> None: - """关闭连接""" - if self.ws: - try: - # 发送FinishSession - await self._send_finish_session() - - # 发送FinishConnection - await self._send_finish_connection() - - except Exception as e: - print(f"关闭会话时出错: {e}") - finally: - # 确保WebSocket连接关闭 - try: - await self.ws.close() - except: - pass - print("连接已关闭") - - async def _send_finish_session(self) -> None: - """发送FinishSession请求""" - request = bytearray(DoubaoProtocol.generate_header()) - request.extend(int(102).to_bytes(4, 'big')) - request.extend(len(self.session_id).to_bytes(4, 'big')) - request.extend(self.session_id.encode()) - - payload_bytes = b"{}" - payload_bytes = gzip.compress(payload_bytes) - request.extend(len(payload_bytes).to_bytes(4, 'big')) - request.extend(payload_bytes) - - await self.ws.send(request) - - async def _send_finish_connection(self) -> None: - """发送FinishConnection请求""" - request = bytearray(DoubaoProtocol.generate_header()) - request.extend(int(2).to_bytes(4, 'big')) - - payload_bytes = b"{}" - payload_bytes = gzip.compress(payload_bytes) - request.extend(len(payload_bytes).to_bytes(4, 'big')) - request.extend(payload_bytes) - - await self.ws.send(request) - response = await self.ws.recv() - parsed_response = DoubaoProtocol.parse_response(response) - print(f"FinishConnection响应: {parsed_response}") - - -class DoubaoProcessor: - """豆包音频处理器""" - - def __init__(self): - self.config = DoubaoConfig() - self.client = DoubaoClient(self.config) - - async def process_audio_file(self, input_file: str, output_file: str = None) -> str: - """处理音频文件 - - Args: - input_file: 输入音频文件路径 - output_file: 输出音频文件路径,如果为None则自动生成 - - Returns: - 输出音频文件路径 - """ - if not os.path.exists(input_file): - raise FileNotFoundError(f"音频文件不存在: {input_file}") - - # 生成输出文件名 - if output_file is None: - timestamp = time.strftime("%Y%m%d_%H%M%S") - output_file = f"doubao_response_{timestamp}.wav" - - try: - # 连接豆包服务器 - await self.client.connect() - - # 等待一会确保会话建立 - await asyncio.sleep(0.5) - - # 发送音频文件并获取响应 - received_audio = await self.client.send_audio_file(input_file) - - if received_audio: - print(f"总共接收到音频数据: {len(received_audio)} 字节") - # 转换为WAV格式保存(适配树莓派播放) - AudioProcessor.create_wav_file( - received_audio, - output_file, - sample_rate=24000, # 豆包返回的音频采样率 - channels=1, - sampwidth=2 # 16-bit - ) - print(f"响应音频已保存到: {output_file}") - - # 显示文件信息 - file_size = os.path.getsize(output_file) - print(f"输出文件大小: {file_size} 字节") - - else: - print("警告: 未接收到音频响应") - - return output_file - - except Exception as e: - print(f"处理音频文件时出错: {e}") - import traceback - traceback.print_exc() - raise - finally: - await self.client.close() - - -async def main(): - """测试函数""" - import argparse - - parser = argparse.ArgumentParser(description="豆包音频处理测试") - parser.add_argument("--input", type=str, required=True, help="输入音频文件路径") - parser.add_argument("--output", type=str, help="输出音频文件路径") - - args = parser.parse_args() - - processor = DoubaoProcessor() - try: - output_file = await processor.process_audio_file(args.input, args.output) - print(f"处理完成,输出文件: {output_file}") - except Exception as e: - print(f"处理失败: {e}") - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/doubao_debug.py b/doubao_debug.py deleted file mode 100644 index 8baeb88..0000000 --- a/doubao_debug.py +++ /dev/null @@ -1,540 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -豆包音频处理模块 - 调试版本 -添加更多调试信息和错误处理 -""" - -import asyncio -import gzip -import json -import uuid -import wave -import struct -import time -import os -from typing import Dict, Any, Optional -import websockets - - -class DoubaoConfig: - """豆包配置""" - - def __init__(self): - self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue" - self.app_id = "8718217928" - self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc" - self.app_key = "PlgvMymc7f3tQnJ6" - self.resource_id = "volc.speech.dialog" - - def get_headers(self) -> Dict[str, str]: - """获取请求头""" - return { - "X-Api-App-ID": self.app_id, - "X-Api-Access-Key": self.access_key, - "X-Api-Resource-Id": self.resource_id, - "X-Api-App-Key": self.app_key, - "X-Api-Connect-Id": str(uuid.uuid4()), - } - - -class DoubaoProtocol: - """豆包协议处理""" - - # 协议常量 - PROTOCOL_VERSION = 0b0001 - CLIENT_FULL_REQUEST = 0b0001 - CLIENT_AUDIO_ONLY_REQUEST = 0b0010 - SERVER_FULL_RESPONSE = 0b1001 - SERVER_ACK = 0b1011 - SERVER_ERROR_RESPONSE = 0b1111 - - NO_SEQUENCE = 0b0000 - POS_SEQUENCE = 0b0001 - MSG_WITH_EVENT = 0b0100 - JSON = 0b0001 - NO_SERIALIZATION = 0b0000 - GZIP = 0b0001 - - @classmethod - def generate_header(cls, message_type=CLIENT_FULL_REQUEST, - message_type_specific_flags=MSG_WITH_EVENT, - serial_method=JSON, compression_type=GZIP) -> bytes: - """生成协议头""" - header = bytearray() - header.append((cls.PROTOCOL_VERSION << 4) | 1) # version + header_size - header.append((message_type << 4) | message_type_specific_flags) - header.append((serial_method << 4) | compression_type) - header.append(0x00) # reserved - return bytes(header) - - @classmethod - def parse_response(cls, res: bytes) -> Dict[str, Any]: - """解析响应""" - if isinstance(res, str): - return {} - - try: - protocol_version = res[0] >> 4 - header_size = res[0] & 0x0f - message_type = res[1] >> 4 - message_type_specific_flags = res[1] & 0x0f - serialization_method = res[2] >> 4 - message_compression = res[2] & 0x0f - - payload = res[header_size * 4:] - result = {} - - if message_type == cls.SERVER_FULL_RESPONSE or message_type == cls.SERVER_ACK: - result['message_type'] = 'SERVER_FULL_RESPONSE' - if message_type == cls.SERVER_ACK: - result['message_type'] = 'SERVER_ACK' - - start = 0 - if message_type_specific_flags & cls.MSG_WITH_EVENT: - result['event'] = int.from_bytes(payload[:4], "big", signed=False) - start += 4 - - payload = payload[start:] - if len(payload) < 4: - result['error'] = 'Payload too short for session_id' - return result - - session_id_size = int.from_bytes(payload[:4], "big", signed=True) - if session_id_size < 0 or session_id_size > len(payload) - 4: - result['error'] = f'Invalid session_id size: {session_id_size}' - return result - - session_id = payload[4:session_id_size+4] - result['session_id'] = str(session_id) - payload = payload[4 + session_id_size:] - - if len(payload) < 4: - result['error'] = 'Payload too short for payload_size' - return result - - payload_size = int.from_bytes(payload[:4], "big", signed=False) - result['payload_size'] = payload_size - - if len(payload) >= 4 + payload_size: - payload_msg = payload[4:4 + payload_size] - if payload_msg: - if message_compression == cls.GZIP: - try: - payload_msg = gzip.decompress(payload_msg) - except Exception as e: - result['decompress_error'] = str(e) - return result - - if serialization_method == cls.JSON: - try: - payload_msg = json.loads(str(payload_msg, "utf-8")) - except Exception as e: - result['json_error'] = str(e) - payload_msg = str(payload_msg, "utf-8") - elif serialization_method != cls.NO_SERIALIZATION: - payload_msg = str(payload_msg, "utf-8") - result['payload_msg'] = payload_msg - - elif message_type == cls.SERVER_ERROR_RESPONSE: - if len(payload) >= 8: - code = int.from_bytes(payload[:4], "big", signed=False) - result['code'] = code - payload_size = int.from_bytes(payload[4:8], "big", signed=False) - result['payload_size'] = payload_size - - if len(payload) >= 8 + payload_size: - payload_msg = payload[8:8 + payload_size] - if payload_msg and message_compression == cls.GZIP: - try: - payload_msg = gzip.decompress(payload_msg) - except: - pass - result['payload_msg'] = payload_msg - - except Exception as e: - result['parse_error'] = str(e) - - return result - - -class AudioProcessor: - """音频处理器""" - - @staticmethod - def read_wav_file(file_path: str) -> tuple: - """读取WAV文件,返回音频数据和参数""" - with wave.open(file_path, 'rb') as wf: - # 获取音频参数 - channels = wf.getnchannels() - sampwidth = wf.getsampwidth() - framerate = wf.getframerate() - nframes = wf.getnframes() - - # 读取音频数据 - audio_data = wf.readframes(nframes) - - return audio_data, { - 'channels': channels, - 'sampwidth': sampwidth, - 'framerate': framerate, - 'nframes': nframes - } - - @staticmethod - def create_wav_file(audio_data: bytes, output_path: str, - sample_rate: int = 24000, channels: int = 1, - sampwidth: int = 2) -> None: - """创建WAV文件,适配树莓派播放""" - with wave.open(output_path, 'wb') as wf: - wf.setnchannels(channels) - wf.setsampwidth(sampwidth) - wf.setframerate(sample_rate) - wf.writeframes(audio_data) - - -class DoubaoClient: - """豆包客户端""" - - def __init__(self, config: DoubaoConfig): - self.config = config - self.session_id = str(uuid.uuid4()) - self.ws = None - self.log_id = "" - - async def connect(self) -> None: - """建立WebSocket连接""" - print(f"连接豆包服务器: {self.config.base_url}") - - try: - self.ws = await websockets.connect( - self.config.base_url, - additional_headers=self.config.get_headers(), - ping_interval=None - ) - - # 获取log_id - if hasattr(self.ws, 'response_headers'): - self.log_id = self.ws.response_headers.get("X-Tt-Logid") - elif hasattr(self.ws, 'headers'): - self.log_id = self.ws.headers.get("X-Tt-Logid") - - print(f"连接成功, log_id: {self.log_id}") - - # 发送StartConnection请求 - await self._send_start_connection() - - # 发送StartSession请求 - await self._send_start_session() - - except Exception as e: - print(f"连接失败: {e}") - raise - - async def _send_start_connection(self) -> None: - """发送StartConnection请求""" - print("发送StartConnection请求...") - request = bytearray(DoubaoProtocol.generate_header()) - request.extend(int(1).to_bytes(4, 'big')) - - payload_bytes = b"{}" - payload_bytes = gzip.compress(payload_bytes) - request.extend(len(payload_bytes).to_bytes(4, 'big')) - request.extend(payload_bytes) - - await self.ws.send(request) - response = await self.ws.recv() - parsed_response = DoubaoProtocol.parse_response(response) - print(f"StartConnection响应: {parsed_response}") - - # 检查是否有错误 - if 'error' in parsed_response: - raise Exception(f"StartConnection解析错误: {parsed_response['error']}") - - async def _send_start_session(self) -> None: - """发送StartSession请求""" - print("发送StartSession请求...") - session_config = { - "asr": { - "extra": { - "end_smooth_window_ms": 1500, - }, - }, - "tts": { - "speaker": "zh_female_vv_jupiter_bigtts", - "audio_config": { - "channel": 1, - "format": "pcm", - "sample_rate": 24000 - }, - }, - "dialog": { - "bot_name": "豆包", - "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。", - "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。", - "location": {"city": "北京"}, - "extra": { - "strict_audit": False, - "audit_response": "支持客户自定义安全审核回复话术。", - "recv_timeout": 30, # 增加超时时间 - "input_mod": "audio_file", # 使用音频文件模式 - }, - }, - } - - request = bytearray(DoubaoProtocol.generate_header()) - request.extend(int(100).to_bytes(4, 'big')) - request.extend(len(self.session_id).to_bytes(4, 'big')) - request.extend(self.session_id.encode()) - - payload_bytes = json.dumps(session_config).encode() - payload_bytes = gzip.compress(payload_bytes) - request.extend(len(payload_bytes).to_bytes(4, 'big')) - request.extend(payload_bytes) - - await self.ws.send(request) - response = await self.ws.recv() - parsed_response = DoubaoProtocol.parse_response(response) - print(f"StartSession响应: {parsed_response}") - - # 检查是否有错误 - if 'error' in parsed_response: - raise Exception(f"StartSession解析错误: {parsed_response['error']}") - - # 等待一会确保会话完全建立 - await asyncio.sleep(1.0) - - async def send_audio_file(self, file_path: str) -> bytes: - """发送音频文件并返回响应音频""" - print(f"处理音频文件: {file_path}") - - # 读取音频文件 - audio_data, audio_info = AudioProcessor.read_wav_file(file_path) - print(f"音频参数: {audio_info}") - - # 计算分块大小(减小到50ms,避免数据块过大) - chunk_size = int(audio_info['framerate'] * audio_info['channels'] * - audio_info['sampwidth'] * 0.05) # 50ms - - # 分块发送音频数据 - total_chunks = (len(audio_data) + chunk_size - 1) // chunk_size - print(f"开始分块发送音频,共 {total_chunks} 块") - - received_audio = b"" - error_count = 0 - session_active = True - - for i in range(0, len(audio_data), chunk_size): - if not session_active: - print("会话已结束,停止发送") - break - - chunk = audio_data[i:i + chunk_size] - is_last = (i + chunk_size >= len(audio_data)) - - # 发送音频块 - await self._send_audio_chunk(chunk, is_last) - - # 接收响应 - try: - response = await asyncio.wait_for(self.ws.recv(), timeout=3.0) - parsed_response = DoubaoProtocol.parse_response(response) - - print(f"响应 {i//chunk_size + 1}/{total_chunks}: {parsed_response}") - - # 检查是否是错误响应 - if 'code' in parsed_response and parsed_response['code'] != 0: - print(f"服务器返回错误: {parsed_response}") - error_count += 1 - if error_count > 3: - raise Exception(f"服务器连续返回错误: {parsed_response}") - continue - - # 处理音频响应 - if (parsed_response.get('message_type') == 'SERVER_ACK' and - isinstance(parsed_response.get('payload_msg'), bytes)): - audio_chunk = parsed_response['payload_msg'] - received_audio += audio_chunk - print(f"接收到音频数据块,大小: {len(audio_chunk)} 字节") - - # 检查会话状态 - event = parsed_response.get('event') - if event in [359, 152, 153]: # 这些事件表示会话相关状态 - print(f"会话事件: {event}") - if event in [152, 153]: # 会话结束 - print("检测到会话结束事件") - session_active = False - break - - except asyncio.TimeoutError: - print("等待响应超时,继续发送") - - # 模拟实时发送的延迟 - await asyncio.sleep(0.1) - - print("音频文件发送完成") - return received_audio - - async def _send_audio_chunk(self, audio_data: bytes, is_last: bool = False) -> None: - """发送音频块""" - request = bytearray( - DoubaoProtocol.generate_header( - message_type=DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST, - message_type_specific_flags=DoubaoProtocol.NO_SEQUENCE, - serial_method=DoubaoProtocol.NO_SERIALIZATION, # 音频数据不需要序列化 - compression_type=DoubaoProtocol.GZIP - ) - ) - request.extend(int(200).to_bytes(4, 'big')) - request.extend(len(self.session_id).to_bytes(4, 'big')) - request.extend(self.session_id.encode()) - - # 压缩音频数据 - compressed_audio = gzip.compress(audio_data) - payload_size = len(compressed_audio) - request.extend(payload_size.to_bytes(4, 'big')) # payload size(4 bytes) - request.extend(compressed_audio) - - print(f"发送音频块 - 原始大小: {len(audio_data)}, 压缩后大小: {payload_size}, 总请求数据大小: {len(request)}") - - await self.ws.send(request) - - async def close(self) -> None: - """关闭连接""" - if self.ws: - try: - # 发送FinishSession - await self._send_finish_session() - - # 发送FinishConnection - await self._send_finish_connection() - - except Exception as e: - print(f"关闭会话时出错: {e}") - finally: - # 确保WebSocket连接关闭 - try: - await self.ws.close() - except: - pass - print("连接已关闭") - - async def _send_finish_session(self) -> None: - """发送FinishSession请求""" - print("发送FinishSession请求...") - request = bytearray(DoubaoProtocol.generate_header()) - request.extend(int(102).to_bytes(4, 'big')) - request.extend(len(self.session_id).to_bytes(4, 'big')) - request.extend(self.session_id.encode()) - - payload_bytes = b"{}" - payload_bytes = gzip.compress(payload_bytes) - request.extend(len(payload_bytes).to_bytes(4, 'big')) - request.extend(payload_bytes) - - await self.ws.send(request) - - async def _send_finish_connection(self) -> None: - """发送FinishConnection请求""" - print("发送FinishConnection请求...") - request = bytearray(DoubaoProtocol.generate_header()) - request.extend(int(2).to_bytes(4, 'big')) - - payload_bytes = b"{}" - payload_bytes = gzip.compress(payload_bytes) - request.extend(len(payload_bytes).to_bytes(4, 'big')) - request.extend(payload_bytes) - - await self.ws.send(request) - - try: - response = await asyncio.wait_for(self.ws.recv(), timeout=5.0) - parsed_response = DoubaoProtocol.parse_response(response) - print(f"FinishConnection响应: {parsed_response}") - except asyncio.TimeoutError: - print("FinishConnection响应超时") - - -class DoubaoProcessor: - """豆包音频处理器""" - - def __init__(self): - self.config = DoubaoConfig() - self.client = DoubaoClient(self.config) - - async def process_audio_file(self, input_file: str, output_file: str = None) -> str: - """处理音频文件 - - Args: - input_file: 输入音频文件路径 - output_file: 输出音频文件路径,如果为None则自动生成 - - Returns: - 输出音频文件路径 - """ - if not os.path.exists(input_file): - raise FileNotFoundError(f"音频文件不存在: {input_file}") - - # 生成输出文件名 - if output_file is None: - timestamp = time.strftime("%Y%m%d_%H%M%S") - output_file = f"doubao_response_{timestamp}.wav" - - try: - # 连接豆包服务器 - await self.client.connect() - - # 发送音频文件并获取响应 - received_audio = await self.client.send_audio_file(input_file) - - if received_audio: - print(f"总共接收到音频数据: {len(received_audio)} 字节") - # 转换为WAV格式保存(适配树莓派播放) - AudioProcessor.create_wav_file( - received_audio, - output_file, - sample_rate=24000, # 豆包返回的音频采样率 - channels=1, - sampwidth=2 # 16-bit - ) - print(f"响应音频已保存到: {output_file}") - - # 显示文件信息 - file_size = os.path.getsize(output_file) - print(f"输出文件大小: {file_size} 字节") - - else: - print("警告: 未接收到音频响应") - - return output_file - - except Exception as e: - print(f"处理音频文件时出错: {e}") - import traceback - traceback.print_exc() - raise - finally: - await self.client.close() - - -async def main(): - """测试函数""" - import argparse - - parser = argparse.ArgumentParser(description="豆包音频处理测试") - parser.add_argument("--input", type=str, required=True, help="输入音频文件路径") - parser.add_argument("--output", type=str, help="输出音频文件路径") - - args = parser.parse_args() - - processor = DoubaoProcessor() - try: - output_file = await processor.process_audio_file(args.input, args.output) - print(f"处理完成,输出文件: {output_file}") - except Exception as e: - print(f"处理失败: {e}") - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/doubao_final_test.py b/doubao_final_test.py deleted file mode 100644 index b0480f4..0000000 --- a/doubao_final_test.py +++ /dev/null @@ -1,425 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -豆包音频处理模块 - 简化测试版本 -专门测试完整的音频上传和TTS音频下载流程 -""" - -import asyncio -import gzip -import json -import uuid -import wave -import struct -import time -import os -from typing import Dict, Any, Optional -import websockets - - -# 直接复制原始豆包代码的协议常量 -PROTOCOL_VERSION = 0b0001 -CLIENT_FULL_REQUEST = 0b0001 -CLIENT_AUDIO_ONLY_REQUEST = 0b0010 -SERVER_FULL_RESPONSE = 0b1001 -SERVER_ACK = 0b1011 -SERVER_ERROR_RESPONSE = 0b1111 - -NO_SEQUENCE = 0b0000 -POS_SEQUENCE = 0b0001 -MSG_WITH_EVENT = 0b0100 - -NO_SERIALIZATION = 0b0000 -JSON = 0b0001 -GZIP = 0b0001 - - -def generate_header( - version=PROTOCOL_VERSION, - message_type=CLIENT_FULL_REQUEST, - message_type_specific_flags=MSG_WITH_EVENT, - serial_method=JSON, - compression_type=GZIP, - reserved_data=0x00, - extension_header=bytes() -): - """直接复制原始豆包代码的generate_header函数""" - header = bytearray() - header_size = int(len(extension_header) / 4) + 1 - header.append((version << 4) | header_size) - header.append((message_type << 4) | message_type_specific_flags) - header.append((serial_method << 4) | compression_type) - header.append(reserved_data) - header.extend(extension_header) - return header - - -class DoubaoConfig: - """豆包配置""" - - def __init__(self): - self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue" - self.app_id = "8718217928" - self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc" - self.app_key = "PlgvMymc7f3tQnJ6" - self.resource_id = "volc.speech.dialog" - - def get_headers(self) -> Dict[str, str]: - """获取请求头""" - return { - "X-Api-App-ID": self.app_id, - "X-Api-Access-Key": self.access_key, - "X-Api-Resource-Id": self.resource_id, - "X-Api-App-Key": self.app_key, - "X-Api-Connect-Id": str(uuid.uuid4()), - } - - -class DoubaoClient: - """豆包客户端 - 基于原始代码""" - - def __init__(self, config: DoubaoConfig): - self.config = config - self.session_id = str(uuid.uuid4()) - self.ws = None - self.log_id = "" - - async def connect(self) -> None: - """建立WebSocket连接""" - print(f"连接豆包服务器: {self.config.base_url}") - - try: - self.ws = await websockets.connect( - self.config.base_url, - additional_headers=self.config.get_headers(), - ping_interval=None - ) - - # 获取log_id - if hasattr(self.ws, 'response_headers'): - self.log_id = self.ws.response_headers.get("X-Tt-Logid") - elif hasattr(self.ws, 'headers'): - self.log_id = self.ws.headers.get("X-Tt-Logid") - - print(f"连接成功, log_id: {self.log_id}") - - # 发送StartConnection请求 - await self._send_start_connection() - - # 发送StartSession请求 - await self._send_start_session() - - except Exception as e: - print(f"连接失败: {e}") - raise - - async def _send_start_connection(self) -> None: - """发送StartConnection请求""" - print("发送StartConnection请求...") - request = bytearray(generate_header()) - request.extend(int(1).to_bytes(4, 'big')) - - payload_bytes = b"{}" - payload_bytes = gzip.compress(payload_bytes) - request.extend(len(payload_bytes).to_bytes(4, 'big')) - request.extend(payload_bytes) - - await self.ws.send(request) - response = await self.ws.recv() - print(f"StartConnection响应长度: {len(response)}") - - async def _send_start_session(self) -> None: - """发送StartSession请求""" - print("发送StartSession请求...") - session_config = { - "asr": { - "extra": { - "end_smooth_window_ms": 1500, - }, - }, - "tts": { - "speaker": "zh_female_vv_jupiter_bigtts", - "audio_config": { - "channel": 1, - "format": "pcm", - "sample_rate": 24000 - }, - }, - "dialog": { - "bot_name": "豆包", - "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。", - "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。", - "location": {"city": "北京"}, - "extra": { - "strict_audit": False, - "audit_response": "支持客户自定义安全审核回复话术。", - "recv_timeout": 30, - "input_mod": "audio", - }, - }, - } - - request = bytearray(generate_header()) - request.extend(int(100).to_bytes(4, 'big')) - request.extend(len(self.session_id).to_bytes(4, 'big')) - request.extend(self.session_id.encode()) - - payload_bytes = json.dumps(session_config).encode() - payload_bytes = gzip.compress(payload_bytes) - request.extend(len(payload_bytes).to_bytes(4, 'big')) - request.extend(payload_bytes) - - await self.ws.send(request) - response = await self.ws.recv() - print(f"StartSession响应长度: {len(response)}") - - # 等待一会确保会话完全建立 - await asyncio.sleep(1.0) - - async def task_request(self, audio: bytes) -> None: - """直接复制原始豆包代码的task_request方法""" - task_request = bytearray( - generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST, - serial_method=NO_SERIALIZATION)) - task_request.extend(int(200).to_bytes(4, 'big')) - task_request.extend((len(self.session_id)).to_bytes(4, 'big')) - task_request.extend(str.encode(self.session_id)) - payload_bytes = gzip.compress(audio) - task_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) - task_request.extend(payload_bytes) - await self.ws.send(task_request) - - async def test_full_dialog(self) -> None: - """测试完整对话流程""" - print("开始完整对话测试...") - - # 读取真实的录音文件 - try: - import wave - with wave.open("recording_20250920_135137.wav", 'rb') as wf: - # 读取前5秒的音频数据 - total_frames = wf.getnframes() - frames_to_read = min(total_frames, 80000) # 5秒 - audio_data = wf.readframes(frames_to_read) - print(f"读取真实音频数据: {len(audio_data)} 字节") - print(f"音频参数: 采样率={wf.getframerate()}, 通道数={wf.getnchannels()}, 采样宽度={wf.getsampwidth()}") - except Exception as e: - print(f"读取音频文件失败: {e}") - return - - print(f"音频数据大小: {len(audio_data)}") - - try: - # 发送音频数据 - print("发送音频数据...") - await self.task_request(audio_data) - print("音频数据发送成功") - - # 等待语音识别响应 - print("等待语音识别响应...") - response = await asyncio.wait_for(self.ws.recv(), timeout=15.0) - print(f"收到ASR响应,长度: {len(response)}") - - # 解析ASR响应 - if len(response) >= 4: - protocol_version = response[0] >> 4 - header_size = response[0] & 0x0f - message_type = response[1] >> 4 - flags = response[1] & 0x0f - print(f"ASR响应协议: version={protocol_version}, header_size={header_size}, message_type={message_type}, flags={flags}") - - if message_type == 9: # SERVER_FULL_RESPONSE - payload_start = header_size * 4 - payload = response[payload_start:] - - if len(payload) >= 4: - event = int.from_bytes(payload[:4], 'big') - print(f"ASR Event: {event}") - - if len(payload) >= 8: - session_id_len = int.from_bytes(payload[4:8], 'big') - if len(payload) >= 8 + session_id_len: - session_id = payload[8:8+session_id_len].decode() - print(f"Session ID: {session_id}") - - if len(payload) >= 12 + session_id_len: - payload_size = int.from_bytes(payload[8+session_id_len:12+session_id_len], 'big') - payload_data = payload[12+session_id_len:12+session_id_len+payload_size] - print(f"Payload size: {payload_size}") - - # 解析ASR结果 - try: - asr_result = json.loads(payload_data.decode('utf-8')) - print(f"ASR结果: {asr_result}") - - # 如果有识别结果,提取文本 - if 'results' in asr_result and asr_result['results']: - text = asr_result['results'][0].get('text', '') - print(f"识别文本: {text}") - - except Exception as e: - print(f"解析ASR结果失败: {e}") - - # 持续等待TTS音频响应 - print("开始持续等待TTS音频响应...") - response_count = 0 - max_responses = 10 - - while response_count < max_responses: - try: - print(f"等待第 {response_count + 1} 个响应...") - tts_response = await asyncio.wait_for(self.ws.recv(), timeout=30.0) - print(f"收到响应 {response_count + 1},长度: {len(tts_response)}") - - # 解析响应 - if len(tts_response) >= 4: - tts_version = tts_response[0] >> 4 - tts_header_size = tts_response[0] & 0x0f - tts_message_type = tts_response[1] >> 4 - tts_flags = tts_response[1] & 0x0f - print(f"响应协议: version={tts_version}, header_size={tts_header_size}, message_type={tts_message_type}, flags={tts_flags}") - - if tts_message_type == 11: # SERVER_ACK (包含TTS音频) - tts_payload_start = tts_header_size * 4 - tts_payload = tts_response[tts_payload_start:] - - if len(tts_payload) >= 12: - tts_event = int.from_bytes(tts_payload[:4], 'big') - tts_session_len = int.from_bytes(tts_payload[4:8], 'big') - tts_session = tts_payload[8:8+tts_session_len].decode() - tts_audio_size = int.from_bytes(tts_payload[8+tts_session_len:12+tts_session_len], 'big') - tts_audio_data = tts_payload[12+tts_session_len:12+tts_session_len+tts_audio_size] - - print(f"Event: {tts_event}") - print(f"音频数据大小: {tts_audio_size}") - - if tts_audio_size > 0: - print("找到TTS音频数据!") - # 尝试解压缩TTS音频 - try: - decompressed_tts = gzip.decompress(tts_audio_data) - print(f"解压缩后TTS音频大小: {len(decompressed_tts)}") - - # 创建WAV文件 - sample_rate = 24000 - channels = 1 - sampwidth = 2 - - with wave.open(f'tts_response_{response_count}.wav', 'wb') as wav_file: - wav_file.setnchannels(channels) - wav_file.setsampwidth(sampwidth) - wav_file.setframerate(sample_rate) - wav_file.writeframes(decompressed_tts) - - print(f"成功创建TTS WAV文件: tts_response_{response_count}.wav") - print(f"音频参数: {sample_rate}Hz, {channels}通道, {sampwidth*8}-bit") - - # 显示文件信息 - if os.path.exists(f'tts_response_{response_count}.wav'): - file_size = os.path.getsize(f'tts_response_{response_count}.wav') - duration = file_size / (sample_rate * channels * sampwidth) - print(f"WAV文件大小: {file_size} 字节") - print(f"音频时长: {duration:.2f} 秒") - - # 成功获取音频,退出循环 - break - - except Exception as tts_e: - print(f"TTS音频解压缩失败: {tts_e}") - # 保存原始数据 - with open(f'tts_response_audio_{response_count}.raw', 'wb') as f: - f.write(tts_audio_data) - print(f"原始TTS音频数据已保存到 tts_response_audio_{response_count}.raw") - - elif tts_message_type == 9: # SERVER_FULL_RESPONSE - tts_payload_start = tts_header_size * 4 - tts_payload = tts_response[tts_payload_start:] - - if len(tts_payload) >= 4: - event = int.from_bytes(tts_payload[:4], 'big') - print(f"Event: {event}") - - if event in [451, 359]: # ASR结果或TTS结束 - # 解析payload - if len(tts_payload) >= 8: - session_id_len = int.from_bytes(tts_payload[4:8], 'big') - if len(tts_payload) >= 8 + session_id_len: - session_id = tts_payload[8:8+session_id_len].decode() - if len(tts_payload) >= 12 + session_id_len: - payload_size = int.from_bytes(tts_payload[8+session_id_len:12+session_id_len], 'big') - payload_data = tts_payload[12+session_id_len:12+session_id_len+payload_size] - - try: - json_data = json.loads(payload_data.decode('utf-8')) - print(f"JSON数据: {json_data}") - - # 如果是ASR结果 - if 'results' in json_data: - text = json_data['results'][0].get('text', '') - print(f"识别文本: {text}") - - # 如果是TTS结束标记 - if event == 359: - print("TTS响应结束") - break - - except Exception as e: - print(f"解析JSON失败: {e}") - # 保存原始数据 - with open(f'tts_response_{response_count}.raw', 'wb') as f: - f.write(payload_data) - - # 保存完整响应用于调试 - with open(f'tts_response_full_{response_count}.raw', 'wb') as f: - f.write(tts_response) - print(f"完整响应已保存到 tts_response_full_{response_count}.raw") - - response_count += 1 - - except asyncio.TimeoutError: - print(f"等待第 {response_count + 1} 个响应超时") - break - except websockets.exceptions.ConnectionClosed: - print("连接已关闭") - break - - print(f"共收到 {response_count} 个响应") - - except asyncio.TimeoutError: - print("等待响应超时") - except websockets.exceptions.ConnectionClosed as e: - print(f"连接关闭: {e}") - except Exception as e: - print(f"测试失败: {e}") - import traceback - traceback.print_exc() - - async def close(self) -> None: - """关闭连接""" - if self.ws: - try: - await self.ws.close() - except: - pass - print("连接已关闭") - - -async def main(): - """测试函数""" - config = DoubaoConfig() - client = DoubaoClient(config) - - try: - await client.connect() - await client.test_full_dialog() - except Exception as e: - print(f"测试失败: {e}") - import traceback - traceback.print_exc() - finally: - await client.close() - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/doubao_original_test.py b/doubao_original_test.py deleted file mode 100644 index 0587ea4..0000000 --- a/doubao_original_test.py +++ /dev/null @@ -1,412 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -豆包音频处理模块 - 基于原始代码的测试版本 -直接使用原始豆包代码的核心逻辑 -""" - -import asyncio -import gzip -import json -import uuid -import wave -import struct -import time -import os -from typing import Dict, Any, Optional -import websockets - - -# 直接复制原始豆包代码的协议常量 -PROTOCOL_VERSION = 0b0001 -CLIENT_FULL_REQUEST = 0b0001 -CLIENT_AUDIO_ONLY_REQUEST = 0b0010 -SERVER_FULL_RESPONSE = 0b1001 -SERVER_ACK = 0b1011 -SERVER_ERROR_RESPONSE = 0b1111 - -NO_SEQUENCE = 0b0000 -POS_SEQUENCE = 0b0001 -MSG_WITH_EVENT = 0b0100 - -NO_SERIALIZATION = 0b0000 -JSON = 0b0001 -GZIP = 0b0001 - - -def generate_header( - version=PROTOCOL_VERSION, - message_type=CLIENT_FULL_REQUEST, - message_type_specific_flags=MSG_WITH_EVENT, - serial_method=JSON, - compression_type=GZIP, - reserved_data=0x00, - extension_header=bytes() -): - """直接复制原始豆包代码的generate_header函数""" - header = bytearray() - header_size = int(len(extension_header) / 4) + 1 - header.append((version << 4) | header_size) - header.append((message_type << 4) | message_type_specific_flags) - header.append((serial_method << 4) | compression_type) - header.append(reserved_data) - header.extend(extension_header) - return header - - -class DoubaoConfig: - """豆包配置""" - - def __init__(self): - self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue" - self.app_id = "8718217928" - self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc" - self.app_key = "PlgvMymc7f3tQnJ6" - self.resource_id = "volc.speech.dialog" - - def get_headers(self) -> Dict[str, str]: - """获取请求头""" - return { - "X-Api-App-ID": self.app_id, - "X-Api-Access-Key": self.access_key, - "X-Api-Resource-Id": self.resource_id, - "X-Api-App-Key": self.app_key, - "X-Api-Connect-Id": str(uuid.uuid4()), - } - - -class DoubaoClient: - """豆包客户端 - 基于原始代码""" - - def __init__(self, config: DoubaoConfig): - self.config = config - self.session_id = str(uuid.uuid4()) - self.ws = None - self.log_id = "" - - async def connect(self) -> None: - """建立WebSocket连接""" - print(f"连接豆包服务器: {self.config.base_url}") - - try: - self.ws = await websockets.connect( - self.config.base_url, - additional_headers=self.config.get_headers(), - ping_interval=None - ) - - # 获取log_id - if hasattr(self.ws, 'response_headers'): - self.log_id = self.ws.response_headers.get("X-Tt-Logid") - elif hasattr(self.ws, 'headers'): - self.log_id = self.ws.headers.get("X-Tt-Logid") - - print(f"连接成功, log_id: {self.log_id}") - - # 发送StartConnection请求 - await self._send_start_connection() - - # 发送StartSession请求 - await self._send_start_session() - - except Exception as e: - print(f"连接失败: {e}") - raise - - async def _send_start_connection(self) -> None: - """发送StartConnection请求""" - print("发送StartConnection请求...") - request = bytearray(generate_header()) - request.extend(int(1).to_bytes(4, 'big')) - - payload_bytes = b"{}" - payload_bytes = gzip.compress(payload_bytes) - request.extend(len(payload_bytes).to_bytes(4, 'big')) - request.extend(payload_bytes) - - await self.ws.send(request) - response = await self.ws.recv() - print(f"StartConnection响应长度: {len(response)}") - - async def _send_start_session(self) -> None: - """发送StartSession请求""" - print("发送StartSession请求...") - session_config = { - "asr": { - "extra": { - "end_smooth_window_ms": 1500, - }, - }, - "tts": { - "speaker": "zh_female_vv_jupiter_bigtts", - "audio_config": { - "channel": 1, - "format": "pcm", - "sample_rate": 24000 - }, - }, - "dialog": { - "bot_name": "豆包", - "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。", - "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。", - "location": {"city": "北京"}, - "extra": { - "strict_audit": False, - "audit_response": "支持客户自定义安全审核回复话术。", - "recv_timeout": 30, - "input_mod": "audio", - }, - }, - } - - request = bytearray(generate_header()) - request.extend(int(100).to_bytes(4, 'big')) - request.extend(len(self.session_id).to_bytes(4, 'big')) - request.extend(self.session_id.encode()) - - payload_bytes = json.dumps(session_config).encode() - payload_bytes = gzip.compress(payload_bytes) - request.extend(len(payload_bytes).to_bytes(4, 'big')) - request.extend(payload_bytes) - - await self.ws.send(request) - response = await self.ws.recv() - print(f"StartSession响应长度: {len(response)}") - - # 等待一会确保会话完全建立 - await asyncio.sleep(1.0) - - async def task_request(self, audio: bytes) -> None: - """直接复制原始豆包代码的task_request方法""" - task_request = bytearray( - generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST, - serial_method=NO_SERIALIZATION)) - task_request.extend(int(200).to_bytes(4, 'big')) - task_request.extend((len(self.session_id)).to_bytes(4, 'big')) - task_request.extend(str.encode(self.session_id)) - payload_bytes = gzip.compress(audio) - task_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) - task_request.extend(payload_bytes) - await self.ws.send(task_request) - - async def test_audio_request(self) -> None: - """测试音频请求""" - print("测试音频请求...") - - # 读取真实的录音文件 - try: - import wave - with wave.open("recording_20250920_135137.wav", 'rb') as wf: - # 读取前10秒的音频数据(16000采样率 * 10秒 = 160000帧) - total_frames = wf.getnframes() - frames_to_read = min(total_frames, 160000) # 最多10秒 - small_audio = wf.readframes(frames_to_read) - print(f"读取真实音频数据: {len(small_audio)} 字节") - print(f"音频参数: 采样率={wf.getframerate()}, 通道数={wf.getnchannels()}, 采样宽度={wf.getsampwidth()}") - print(f"总帧数: {total_frames}, 读取帧数: {frames_to_read}") - except Exception as e: - print(f"读取音频文件失败: {e}") - # 如果读取失败,使用静音数据 - small_audio = b'\x00' * 3200 - - print(f"音频数据大小: {len(small_audio)}") - - try: - # 发送完整的音频数据块 - print(f"发送完整的音频数据块...") - await self.task_request(small_audio) - print(f"音频数据块发送成功") - - print("等待语音识别响应...") - - # 等待更长时间的响应(语音识别可能需要更长时间) - response = await asyncio.wait_for(self.ws.recv(), timeout=15.0) - print(f"收到响应,长度: {len(response)}") - - # 解析响应 - try: - if len(response) >= 4: - protocol_version = response[0] >> 4 - header_size = response[0] & 0x0f - message_type = response[1] >> 4 - message_type_specific_flags = response[1] & 0x0f - print(f"响应协议信息: version={protocol_version}, header_size={header_size}, message_type={message_type}, flags={message_type_specific_flags}") - - # 解析payload - payload_start = header_size * 4 - payload = response[payload_start:] - - if message_type == 9: # SERVER_FULL_RESPONSE - print("收到SERVER_FULL_RESPONSE!") - if len(payload) >= 4: - # 解析event - event = int.from_bytes(payload[:4], 'big') - print(f"Event: {event}") - - # 解析session_id - if len(payload) >= 8: - session_id_len = int.from_bytes(payload[4:8], 'big') - if len(payload) >= 8 + session_id_len: - session_id = payload[8:8+session_id_len].decode() - print(f"Session ID: {session_id}") - - # 解析payload size和data - if len(payload) >= 12 + session_id_len: - payload_size = int.from_bytes(payload[8+session_id_len:12+session_id_len], 'big') - payload_data = payload[12+session_id_len:12+session_id_len+payload_size] - print(f"Payload size: {payload_size}") - - # 如果包含音频数据,保存到文件 - if len(payload_data) > 0: - print(f"收到数据: {len(payload_data)} 字节") - # 保存原始音频数据 - with open('response_audio.raw', 'wb') as f: - f.write(payload_data) - print("音频数据已保存到 response_audio.raw") - - # 尝试解析JSON数据 - try: - import json - json_data = json.loads(payload_data.decode('utf-8')) - print(f"JSON数据: {json_data}") - - # 如果是语音识别任务开始,继续等待音频响应 - if 'asr_task_id' in json_data: - print("语音识别任务开始,继续等待音频响应...") - try: - # 等待音频响应 - audio_response = await asyncio.wait_for(self.ws.recv(), timeout=20.0) - print(f"收到音频响应,长度: {len(audio_response)}") - - # 解析音频响应 - if len(audio_response) >= 4: - audio_version = audio_response[0] >> 4 - audio_header_size = audio_response[0] & 0x0f - audio_message_type = audio_response[1] >> 4 - audio_flags = audio_response[1] & 0x0f - print(f"音频响应协议信息: version={audio_version}, header_size={audio_header_size}, message_type={audio_message_type}, flags={audio_flags}") - - if audio_message_type == 9: # SERVER_FULL_RESPONSE (包含TTS音频) - audio_payload_start = audio_header_size * 4 - audio_payload = audio_response[audio_payload_start:] - - if len(audio_payload) >= 12: - # 解析event和session_id - audio_event = int.from_bytes(audio_payload[:4], 'big') - audio_session_len = int.from_bytes(audio_payload[4:8], 'big') - audio_session = audio_payload[8:8+audio_session_len].decode() - audio_data_size = int.from_bytes(audio_payload[8+audio_session_len:12+audio_session_len], 'big') - audio_data = audio_payload[12+audio_session_len:12+audio_session_len+audio_data_size] - - print(f"音频Event: {audio_event}") - print(f"音频数据大小: {audio_data_size}") - - if audio_data_size > 0: - # 保存原始音频数据 - with open('tts_response_audio.raw', 'wb') as f: - f.write(audio_data) - print(f"TTS音频数据已保存到 tts_response_audio.raw") - - # 尝试解析音频数据(可能是JSON或GZIP压缩的音频) - try: - # 首先尝试解压缩 - import gzip - decompressed_audio = gzip.decompress(audio_data) - print(f"解压缩后音频数据大小: {len(decompressed_audio)}") - with open('tts_response_audio_decompressed.raw', 'wb') as f: - f.write(decompressed_audio) - print("解压缩的音频数据已保存") - - # 创建WAV文件供树莓派播放 - import wave - import struct - - # 豆包返回的音频是24000Hz, 16-bit, 单声道 - sample_rate = 24000 - channels = 1 - sampwidth = 2 # 16-bit = 2 bytes - - with wave.open('tts_response.wav', 'wb') as wav_file: - wav_file.setnchannels(channels) - wav_file.setsampwidth(sampwidth) - wav_file.setframerate(sample_rate) - wav_file.writeframes(decompressed_audio) - - print("已创建WAV文件: tts_response.wav") - print(f"音频参数: {sample_rate}Hz, {channels}通道, {sampwidth*8}-bit") - - except Exception as audio_e: - print(f"音频数据处理失败: {audio_e}") - # 如果解压缩失败,直接保存原始数据 - with open('tts_response_audio_original.raw', 'wb') as f: - f.write(audio_data) - - elif audio_message_type == 11: # SERVER_ACK - print("收到SERVER_ACK音频响应") - # 处理SERVER_ACK格式的音频响应 - audio_payload_start = audio_header_size * 4 - audio_payload = audio_response[audio_payload_start:] - print(f"音频payload长度: {len(audio_payload)}") - with open('tts_response_ack.raw', 'wb') as f: - f.write(audio_payload) - - except asyncio.TimeoutError: - print("等待音频响应超时") - - except Exception as json_e: - print(f"解析JSON失败: {json_e}") - # 如果不是JSON,可能是音频数据,直接保存 - with open('response_audio.raw', 'wb') as f: - f.write(payload_data) - - elif message_type == 11: # SERVER_ACK - print("收到SERVER_ACK响应!") - - elif message_type == 15: # SERVER_ERROR_RESPONSE - print("收到错误响应") - if len(response) > 8: - error_code = int.from_bytes(response[4:8], 'big') - print(f"错误代码: {error_code}") - - except Exception as e: - print(f"解析响应失败: {e}") - import traceback - traceback.print_exc() - - except asyncio.TimeoutError: - print("等待响应超时") - except websockets.exceptions.ConnectionClosed as e: - print(f"连接关闭: {e}") - except Exception as e: - print(f"发送音频请求失败: {e}") - raise - - async def close(self) -> None: - """关闭连接""" - if self.ws: - try: - await self.ws.close() - except: - pass - print("连接已关闭") - - -async def main(): - """测试函数""" - config = DoubaoConfig() - client = DoubaoClient(config) - - try: - await client.connect() - await client.test_audio_request() - except Exception as e: - print(f"测试失败: {e}") - import traceback - traceback.print_exc() - finally: - await client.close() - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/doubao_test.py b/doubao_test.py deleted file mode 100644 index c5a94d8..0000000 --- a/doubao_test.py +++ /dev/null @@ -1,303 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -豆包音频处理模块 - 协议测试版本 -专门测试协议格式问题 -""" - -import asyncio -import gzip -import json -import uuid -import wave -import struct -import time -import os -from typing import Dict, Any, Optional -import websockets - - -class DoubaoConfig: - """豆包配置""" - - def __init__(self): - self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue" - self.app_id = "8718217928" - self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc" - self.app_key = "PlgvMymc7f3tQnJ6" - self.resource_id = "volc.speech.dialog" - - def get_headers(self) -> Dict[str, str]: - """获取请求头""" - return { - "X-Api-App-ID": self.app_id, - "X-Api-Access-Key": self.access_key, - "X-Api-Resource-Id": self.resource_id, - "X-Api-App-Key": self.app_key, - "X-Api-Connect-Id": str(uuid.uuid4()), - } - - -class DoubaoProtocol: - """豆包协议处理""" - - # 协议常量 - PROTOCOL_VERSION = 0b0001 - CLIENT_FULL_REQUEST = 0b0001 - CLIENT_AUDIO_ONLY_REQUEST = 0b0010 - SERVER_FULL_RESPONSE = 0b1001 - SERVER_ACK = 0b1011 - SERVER_ERROR_RESPONSE = 0b1111 - - NO_SEQUENCE = 0b0000 - POS_SEQUENCE = 0b0001 - MSG_WITH_EVENT = 0b0100 - JSON = 0b0001 - NO_SERIALIZATION = 0b0000 - GZIP = 0b0001 - - @classmethod - def generate_header(cls, message_type=CLIENT_FULL_REQUEST, - message_type_specific_flags=MSG_WITH_EVENT, - serial_method=JSON, compression_type=GZIP) -> bytes: - """生成协议头""" - header = bytearray() - header.append((cls.PROTOCOL_VERSION << 4) | 1) # version + header_size - header.append((message_type << 4) | message_type_specific_flags) - header.append((serial_method << 4) | compression_type) - header.append(0x00) # reserved - return bytes(header) - - -class DoubaoClient: - """豆包客户端""" - - def __init__(self, config: DoubaoConfig): - self.config = config - self.session_id = str(uuid.uuid4()) - self.ws = None - self.log_id = "" - - async def connect(self) -> None: - """建立WebSocket连接""" - print(f"连接豆包服务器: {self.config.base_url}") - - try: - self.ws = await websockets.connect( - self.config.base_url, - additional_headers=self.config.get_headers(), - ping_interval=None - ) - - # 获取log_id - if hasattr(self.ws, 'response_headers'): - self.log_id = self.ws.response_headers.get("X-Tt-Logid") - elif hasattr(self.ws, 'headers'): - self.log_id = self.ws.headers.get("X-Tt-Logid") - - print(f"连接成功, log_id: {self.log_id}") - - # 发送StartConnection请求 - await self._send_start_connection() - - # 发送StartSession请求 - await self._send_start_session() - - except Exception as e: - print(f"连接失败: {e}") - raise - - async def _send_start_connection(self) -> None: - """发送StartConnection请求""" - print("发送StartConnection请求...") - request = bytearray(DoubaoProtocol.generate_header()) - request.extend(int(1).to_bytes(4, 'big')) - - payload_bytes = b"{}" - payload_bytes = gzip.compress(payload_bytes) - request.extend(len(payload_bytes).to_bytes(4, 'big')) - request.extend(payload_bytes) - - await self.ws.send(request) - response = await self.ws.recv() - print(f"StartConnection响应长度: {len(response)}") - - async def _send_start_session(self) -> None: - """发送StartSession请求""" - print("发送StartSession请求...") - session_config = { - "asr": { - "extra": { - "end_smooth_window_ms": 1500, - }, - }, - "tts": { - "speaker": "zh_female_vv_jupiter_bigtts", - "audio_config": { - "channel": 1, - "format": "pcm", - "sample_rate": 24000 - }, - }, - "dialog": { - "bot_name": "豆包", - "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。", - "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。", - "location": {"city": "北京"}, - "extra": { - "strict_audit": False, - "audit_response": "支持客户自定义安全审核回复话术。", - "recv_timeout": 30, - "input_mod": "audio", - }, - }, - } - - request = bytearray(DoubaoProtocol.generate_header()) - request.extend(int(100).to_bytes(4, 'big')) - request.extend(len(self.session_id).to_bytes(4, 'big')) - request.extend(self.session_id.encode()) - - payload_bytes = json.dumps(session_config).encode() - payload_bytes = gzip.compress(payload_bytes) - request.extend(len(payload_bytes).to_bytes(4, 'big')) - request.extend(payload_bytes) - - await self.ws.send(request) - response = await self.ws.recv() - print(f"StartSession响应长度: {len(response)}") - - # 等待一会确保会话完全建立 - await asyncio.sleep(1.0) - - async def test_audio_request(self) -> None: - """测试音频请求格式""" - print("测试音频请求格式...") - - # 创建音频数据(静音)- 使用原始豆包代码的chunk大小 - small_audio = b'\x00' * 3200 # 原始豆包代码中的chunk大小 - - # 完全按照原始豆包代码的格式构建请求,不进行任何填充 - header = bytearray() - header.append((DoubaoProtocol.PROTOCOL_VERSION << 4) | 1) # version + header_size - header.append((DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST << 4) | DoubaoProtocol.NO_SEQUENCE) - header.append((DoubaoProtocol.NO_SERIALIZATION << 4) | DoubaoProtocol.GZIP) - header.append(0x00) # reserved - - request = bytearray(header) - - # 添加消息类型 (200 = task request) - request.extend(int(200).to_bytes(4, 'big')) - - # 添加session_id - request.extend(len(self.session_id).to_bytes(4, 'big')) - request.extend(self.session_id.encode()) - - # 压缩音频数据 - compressed_audio = gzip.compress(small_audio) - - # 添加payload size - request.extend(len(compressed_audio).to_bytes(4, 'big')) - - # 添加压缩后的音频数据 - request.extend(compressed_audio) - - print(f"测试请求详细信息:") - print(f" - 音频原始大小: {len(small_audio)}") - print(f" - 音频压缩后大小: {len(compressed_audio)}") - print(f" - Session ID: {self.session_id} (长度: {len(self.session_id)})") - print(f" - 总请求大小: {len(request)}") - print(f" - 头部字节: {request[:4].hex()}") - print(f" - 消息类型: {int.from_bytes(request[4:8], 'big')}") - print(f" - Session ID长度: {int.from_bytes(request[8:12], 'big')}") - print(f" - Payload size: {int.from_bytes(request[12+len(self.session_id):16+len(self.session_id)], 'big')}") - - try: - await self.ws.send(request) - print("请求发送成功") - - # 等待响应 - response = await asyncio.wait_for(self.ws.recv(), timeout=3.0) - print(f"收到响应,长度: {len(response)}") - - # 尝试解析响应 - try: - protocol_version = response[0] >> 4 - header_size = response[0] & 0x0f - message_type = response[1] >> 4 - message_type_specific_flags = response[1] & 0x0f - serialization_method = response[2] >> 4 - message_compression = response[2] & 0x0f - - print(f"响应协议信息:") - print(f" - version={protocol_version}") - print(f" - header_size={header_size}") - print(f" - message_type={message_type} (15=SERVER_ERROR_RESPONSE)") - print(f" - message_type_specific_flags={message_type_specific_flags}") - print(f" - serialization_method={serialization_method}") - print(f" - message_compression={message_compression}") - - # 解析payload - payload = response[header_size * 4:] - if message_type == 15: # SERVER_ERROR_RESPONSE - if len(payload) >= 8: - code = int.from_bytes(payload[:4], "big", signed=False) - payload_size = int.from_bytes(payload[4:8], "big", signed=False) - print(f" - 错误代码: {code}") - print(f" - payload大小: {payload_size}") - - if len(payload) >= 8 + payload_size: - payload_msg = payload[8:8 + payload_size] - print(f" - payload长度: {len(payload_msg)}") - - if message_compression == 1: # GZIP - try: - payload_msg = gzip.decompress(payload_msg) - print(f" - 解压缩后长度: {len(payload_msg)}") - except: - pass - - try: - error_msg = json.loads(payload_msg.decode('utf-8')) - print(f" - 错误信息: {error_msg}") - except: - print(f" - 原始payload: {payload_msg}") - - except Exception as e: - print(f"解析响应失败: {e}") - import traceback - traceback.print_exc() - - except Exception as e: - print(f"发送测试请求失败: {e}") - raise - - async def close(self) -> None: - """关闭连接""" - if self.ws: - try: - await self.ws.close() - except: - pass - print("连接已关闭") - - -async def main(): - """测试函数""" - config = DoubaoConfig() - client = DoubaoClient(config) - - try: - await client.connect() - await client.test_audio_request() - except Exception as e: - print(f"测试失败: {e}") - import traceback - traceback.print_exc() - finally: - await client.close() - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/recognition_example.py b/recognition_example.py deleted file mode 100644 index c224b58..0000000 --- a/recognition_example.py +++ /dev/null @@ -1,127 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -语音识别使用示例 -演示如何使用 speech_recognizer 模块 -""" - -import os -import asyncio -from speech_recognizer import SpeechRecognizer - -async def example_recognize_file(): - """示例:识别单个音频文件""" - print("=== 示例1:识别单个音频文件 ===") - - # 初始化识别器 - recognizer = SpeechRecognizer( - app_key="your_app_key", # 请替换为实际的app_key - access_key="your_access_key" # 请替换为实际的access_key - ) - - # 假设有一个录音文件 - audio_file = "recording_20240101_120000.wav" - - if not os.path.exists(audio_file): - print(f"音频文件不存在: {audio_file}") - print("请先运行 enhanced_wake_and_record.py 录制一个音频文件") - return - - try: - # 识别音频文件 - results = await recognizer.recognize_file(audio_file) - - print(f"识别结果(共{len(results)}个):") - for i, result in enumerate(results): - print(f"结果 {i+1}:") - print(f" 文本: {result.text}") - print(f" 置信度: {result.confidence}") - print(f" 最终结果: {result.is_final}") - print("-" * 40) - - except Exception as e: - print(f"识别失败: {e}") - -async def example_recognize_latest(): - """示例:识别最新的录音文件""" - print("\n=== 示例2:识别最新的录音文件 ===") - - # 初始化识别器 - recognizer = SpeechRecognizer( - app_key="your_app_key", # 请替换为实际的app_key - access_key="your_access_key" # 请替换为实际的access_key - ) - - try: - # 识别最新的录音文件 - result = await recognizer.recognize_latest_recording() - - if result: - print("识别结果:") - print(f" 文本: {result.text}") - print(f" 置信度: {result.confidence}") - print(f" 最终结果: {result.is_final}") - else: - print("未找到录音文件或识别失败") - - except Exception as e: - print(f"识别失败: {e}") - -async def example_batch_recognition(): - """示例:批量识别多个录音文件""" - print("\n=== 示例3:批量识别录音文件 ===") - - # 初始化识别器 - recognizer = SpeechRecognizer( - app_key="your_app_key", # 请替换为实际的app_key - access_key="your_access_key" # 请替换为实际的access_key - ) - - # 获取所有录音文件 - recording_files = [f for f in os.listdir(".") if f.startswith('recording_') and f.endswith('.wav')] - - if not recording_files: - print("未找到录音文件") - return - - print(f"找到 {len(recording_files)} 个录音文件") - - for filename in recording_files[:5]: # 只处理前5个文件 - print(f"\n处理文件: {filename}") - try: - results = await recognizer.recognize_file(filename) - - if results: - final_result = results[-1] # 取最后一个结果 - print(f"识别结果: {final_result.text}") - else: - print("识别失败") - - except Exception as e: - print(f"处理失败: {e}") - - # 添加延迟,避免请求过于频繁 - await asyncio.sleep(1) - -async def main(): - """主函数""" - print("🚀 语音识别使用示例") - print("=" * 50) - - # 请先设置环境变量或在代码中填入实际的API密钥 - if not os.getenv("SAUC_APP_KEY") and "your_app_key" in "your_app_key": - print("⚠️ 请先设置 SAUC_APP_KEY 和 SAUC_ACCESS_KEY 环境变量") - print("或者在代码中填入实际的 app_key 和 access_key") - print("示例:") - print("export SAUC_APP_KEY='your_app_key'") - print("export SAUC_ACCESS_KEY='your_access_key'") - return - - # 运行示例 - await example_recognize_file() - await example_recognize_latest() - await example_batch_recognition() - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/sauc_python/readme.md b/sauc_python/readme.md deleted file mode 100644 index 4dbcebd..0000000 --- a/sauc_python/readme.md +++ /dev/null @@ -1,15 +0,0 @@ -# README - -**asr tob 相关client demo** - -# Notice -python version: python 3.x - -替换代码中的key为真实数据: - "app_key": "xxxxxxx", - "access_key": "xxxxxxxxxxxxxxxx" -使用示例: - python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav - - - diff --git a/sauc_python/sauc_websocket_demo.py b/sauc_python/sauc_websocket_demo.py deleted file mode 100644 index 092d24b..0000000 --- a/sauc_python/sauc_websocket_demo.py +++ /dev/null @@ -1,523 +0,0 @@ -import asyncio -import aiohttp -import json -import struct -import gzip -import uuid -import logging -import os -import subprocess -from typing import Optional, List, Dict, Any, Tuple, AsyncGenerator - -# 配置日志 -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler('run.log'), - logging.StreamHandler() - ] -) -logger = logging.getLogger(__name__) - -# 常量定义 -DEFAULT_SAMPLE_RATE = 16000 - -class ProtocolVersion: - V1 = 0b0001 - -class MessageType: - CLIENT_FULL_REQUEST = 0b0001 - CLIENT_AUDIO_ONLY_REQUEST = 0b0010 - SERVER_FULL_RESPONSE = 0b1001 - SERVER_ERROR_RESPONSE = 0b1111 - -class MessageTypeSpecificFlags: - NO_SEQUENCE = 0b0000 - POS_SEQUENCE = 0b0001 - NEG_SEQUENCE = 0b0010 - NEG_WITH_SEQUENCE = 0b0011 - -class SerializationType: - NO_SERIALIZATION = 0b0000 - JSON = 0b0001 - -class CompressionType: - GZIP = 0b0001 - - -class Config: - def __init__(self): - # 填入控制台获取的app id和access token - self.auth = { - "app_key": "xxxxxxx", - "access_key": "xxxxxxxxxxxx" - } - - @property - def app_key(self) -> str: - return self.auth["app_key"] - - @property - def access_key(self) -> str: - return self.auth["access_key"] - -config = Config() - -class CommonUtils: - @staticmethod - def gzip_compress(data: bytes) -> bytes: - return gzip.compress(data) - - @staticmethod - def gzip_decompress(data: bytes) -> bytes: - return gzip.decompress(data) - - @staticmethod - def judge_wav(data: bytes) -> bool: - if len(data) < 44: - return False - return data[:4] == b'RIFF' and data[8:12] == b'WAVE' - - @staticmethod - def convert_wav_with_path(audio_path: str, sample_rate: int = DEFAULT_SAMPLE_RATE) -> bytes: - try: - cmd = [ - "ffmpeg", "-v", "quiet", "-y", "-i", audio_path, - "-acodec", "pcm_s16le", "-ac", "1", "-ar", str(sample_rate), - "-f", "wav", "-" - ] - result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - - # 尝试删除原始文件 - try: - os.remove(audio_path) - except OSError as e: - logger.warning(f"Failed to remove original file: {e}") - - return result.stdout - except subprocess.CalledProcessError as e: - logger.error(f"FFmpeg conversion failed: {e.stderr.decode()}") - raise RuntimeError(f"Audio conversion failed: {e.stderr.decode()}") - - @staticmethod - def read_wav_info(data: bytes) -> Tuple[int, int, int, int, bytes]: - if len(data) < 44: - raise ValueError("Invalid WAV file: too short") - - # 解析WAV头 - chunk_id = data[:4] - if chunk_id != b'RIFF': - raise ValueError("Invalid WAV file: not RIFF format") - - format_ = data[8:12] - if format_ != b'WAVE': - raise ValueError("Invalid WAV file: not WAVE format") - - # 解析fmt子块 - audio_format = struct.unpack(' 'AsrRequestHeader': - self.message_type = message_type - return self - - def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader': - self.message_type_specific_flags = flags - return self - - def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader': - self.serialization_type = serialization_type - return self - - def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader': - self.compression_type = compression_type - return self - - def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader': - self.reserved_data = reserved_data - return self - - def to_bytes(self) -> bytes: - header = bytearray() - header.append((ProtocolVersion.V1 << 4) | 1) - header.append((self.message_type << 4) | self.message_type_specific_flags) - header.append((self.serialization_type << 4) | self.compression_type) - header.extend(self.reserved_data) - return bytes(header) - - @staticmethod - def default_header() -> 'AsrRequestHeader': - return AsrRequestHeader() - -class RequestBuilder: - @staticmethod - def new_auth_headers() -> Dict[str, str]: - reqid = str(uuid.uuid4()) - return { - "X-Api-Resource-Id": "volc.bigasr.sauc.duration", - "X-Api-Request-Id": reqid, - "X-Api-Access-Key": config.access_key, - "X-Api-App-Key": config.app_key - } - - @staticmethod - def new_full_client_request(seq: int) -> bytes: # 添加seq参数 - header = AsrRequestHeader.default_header() \ - .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE) - - payload = { - "user": { - "uid": "demo_uid" - }, - "audio": { - "format": "wav", - "codec": "raw", - "rate": 16000, - "bits": 16, - "channel": 1 - }, - "request": { - "model_name": "bigmodel", - "enable_itn": True, - "enable_punc": True, - "enable_ddc": True, - "show_utterances": True, - "enable_nonstream": False - } - } - - payload_bytes = json.dumps(payload).encode('utf-8') - compressed_payload = CommonUtils.gzip_compress(payload_bytes) - payload_size = len(compressed_payload) - - request = bytearray() - request.extend(header.to_bytes()) - request.extend(struct.pack('>i', seq)) # 使用传入的seq - request.extend(struct.pack('>I', payload_size)) - request.extend(compressed_payload) - - return bytes(request) - - @staticmethod - def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes: - header = AsrRequestHeader.default_header() - if is_last: # 最后一个包特殊处理 - header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE) - seq = -seq # 设为负值 - else: - header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE) - header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST) - - request = bytearray() - request.extend(header.to_bytes()) - request.extend(struct.pack('>i', seq)) - - compressed_segment = CommonUtils.gzip_compress(segment) - request.extend(struct.pack('>I', len(compressed_segment))) - request.extend(compressed_segment) - - return bytes(request) - -class AsrResponse: - def __init__(self): - self.code = 0 - self.event = 0 - self.is_last_package = False - self.payload_sequence = 0 - self.payload_size = 0 - self.payload_msg = None - - def to_dict(self) -> Dict[str, Any]: - return { - "code": self.code, - "event": self.event, - "is_last_package": self.is_last_package, - "payload_sequence": self.payload_sequence, - "payload_size": self.payload_size, - "payload_msg": self.payload_msg - } - -class ResponseParser: - @staticmethod - def parse_response(msg: bytes) -> AsrResponse: - response = AsrResponse() - - header_size = msg[0] & 0x0f - message_type = msg[1] >> 4 - message_type_specific_flags = msg[1] & 0x0f - serialization_method = msg[2] >> 4 - message_compression = msg[2] & 0x0f - - payload = msg[header_size*4:] - - # 解析message_type_specific_flags - if message_type_specific_flags & 0x01: - response.payload_sequence = struct.unpack('>i', payload[:4])[0] - payload = payload[4:] - if message_type_specific_flags & 0x02: - response.is_last_package = True - if message_type_specific_flags & 0x04: - response.event = struct.unpack('>i', payload[:4])[0] - payload = payload[4:] - - # 解析message_type - if message_type == MessageType.SERVER_FULL_RESPONSE: - response.payload_size = struct.unpack('>I', payload[:4])[0] - payload = payload[4:] - elif message_type == MessageType.SERVER_ERROR_RESPONSE: - response.code = struct.unpack('>i', payload[:4])[0] - response.payload_size = struct.unpack('>I', payload[4:8])[0] - payload = payload[8:] - - if not payload: - return response - - # 解压缩 - if message_compression == CompressionType.GZIP: - try: - payload = CommonUtils.gzip_decompress(payload) - except Exception as e: - logger.error(f"Failed to decompress payload: {e}") - return response - - # 解析payload - try: - if serialization_method == SerializationType.JSON: - response.payload_msg = json.loads(payload.decode('utf-8')) - except Exception as e: - logger.error(f"Failed to parse payload: {e}") - - return response - -class AsrWsClient: - def __init__(self, url: str, segment_duration: int = 200): - self.seq = 1 - self.url = url - self.segment_duration = segment_duration - self.conn = None - self.session = None # 添加session引用 - - async def __aenter__(self): - self.session = aiohttp.ClientSession() - return self - - async def __aexit__(self, exc_type, exc, tb): - if self.conn and not self.conn.closed: - await self.conn.close() - if self.session and not self.session.closed: - await self.session.close() - - async def read_audio_data(self, file_path: str) -> bytes: - try: - with open(file_path, 'rb') as f: - content = f.read() - - if not CommonUtils.judge_wav(content): - logger.info("Converting audio to WAV format...") - content = CommonUtils.convert_wav_with_path(file_path, DEFAULT_SAMPLE_RATE) - - return content - except Exception as e: - logger.error(f"Failed to read audio data: {e}") - raise - - def get_segment_size(self, content: bytes) -> int: - try: - channel_num, samp_width, frame_rate, _, _ = CommonUtils.read_wav_info(content)[:5] - size_per_sec = channel_num * samp_width * frame_rate - segment_size = size_per_sec * self.segment_duration // 1000 - return segment_size - except Exception as e: - logger.error(f"Failed to calculate segment size: {e}") - raise - - async def create_connection(self) -> None: - headers = RequestBuilder.new_auth_headers() - try: - self.conn = await self.session.ws_connect( # 使用self.session - self.url, - headers=headers - ) - logger.info(f"Connected to {self.url}") - except Exception as e: - logger.error(f"Failed to connect to WebSocket: {e}") - raise - - async def send_full_client_request(self) -> None: - request = RequestBuilder.new_full_client_request(self.seq) - self.seq += 1 # 发送后递增 - try: - await self.conn.send_bytes(request) - logger.info(f"Sent full client request with seq: {self.seq-1}") - - msg = await self.conn.receive() - if msg.type == aiohttp.WSMsgType.BINARY: - response = ResponseParser.parse_response(msg.data) - logger.info(f"Received response: {response.to_dict()}") - else: - logger.error(f"Unexpected message type: {msg.type}") - except Exception as e: - logger.error(f"Failed to send full client request: {e}") - raise - - async def send_messages(self, segment_size: int, content: bytes) -> AsyncGenerator[None, None]: - audio_segments = self.split_audio(content, segment_size) - total_segments = len(audio_segments) - - for i, segment in enumerate(audio_segments): - is_last = (i == total_segments - 1) - request = RequestBuilder.new_audio_only_request( - self.seq, - segment, - is_last=is_last - ) - await self.conn.send_bytes(request) - logger.info(f"Sent audio segment with seq: {self.seq} (last: {is_last})") - - if not is_last: - self.seq += 1 - - await asyncio.sleep(self.segment_duration / 1000) # 逐个发送,间隔时间模拟实时流 - # 让出控制权,允许接受消息 - yield - - async def recv_messages(self) -> AsyncGenerator[AsrResponse, None]: - try: - async for msg in self.conn: - if msg.type == aiohttp.WSMsgType.BINARY: - response = ResponseParser.parse_response(msg.data) - yield response - - if response.is_last_package or response.code != 0: - break - elif msg.type == aiohttp.WSMsgType.ERROR: - logger.error(f"WebSocket error: {msg.data}") - break - elif msg.type == aiohttp.WSMsgType.CLOSED: - logger.info("WebSocket connection closed") - break - except Exception as e: - logger.error(f"Error receiving messages: {e}") - raise - - async def start_audio_stream(self, segment_size: int, content: bytes) -> AsyncGenerator[AsrResponse, None]: - async def sender(): - async for _ in self.send_messages(segment_size, content): - pass - - # 启动发送和接收任务 - sender_task = asyncio.create_task(sender()) - - try: - async for response in self.recv_messages(): - yield response - finally: - sender_task.cancel() - try: - await sender_task - except asyncio.CancelledError: - pass - - @staticmethod - def split_audio(data: bytes, segment_size: int) -> List[bytes]: - if segment_size <= 0: - return [] - - segments = [] - for i in range(0, len(data), segment_size): - end = i + segment_size - if end > len(data): - end = len(data) - segments.append(data[i:end]) - return segments - - async def execute(self, file_path: str) -> AsyncGenerator[AsrResponse, None]: - if not file_path: - raise ValueError("File path is empty") - - if not self.url: - raise ValueError("URL is empty") - - self.seq = 1 - - try: - # 1. 读取音频文件 - content = await self.read_audio_data(file_path) - - # 2. 计算分段大小 - segment_size = self.get_segment_size(content) - - # 3. 创建WebSocket连接 - await self.create_connection() - - # 4. 发送完整客户端请求 - await self.send_full_client_request() - - # 5. 启动音频流处理 - async for response in self.start_audio_stream(segment_size, content): - yield response - - except Exception as e: - logger.error(f"Error in ASR execution: {e}") - raise - finally: - if self.conn: - await self.conn.close() - -async def main(): - import argparse - - parser = argparse.ArgumentParser(description="ASR WebSocket Client") - parser.add_argument("--file", type=str, required=True, help="Audio file path") - - #wss://openspeech.bytedance.com/api/v3/sauc/bigmodel - #wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async - #wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream - parser.add_argument("--url", type=str, default="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream", - help="WebSocket URL") - parser.add_argument("--seg-duration", type=int, default=200, - help="Audio duration(ms) per packet, default:200") - - args = parser.parse_args() - - async with AsrWsClient(args.url, args.seg_duration) as client: # 使用async with - try: - async for response in client.execute(args.file): - logger.info(f"Received response: {json.dumps(response.to_dict(), indent=2, ensure_ascii=False)}") - except Exception as e: - logger.error(f"ASR processing failed: {e}") - -if __name__ == "__main__": - asyncio.run(main()) - - # 用法: - # python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav \ No newline at end of file diff --git a/speech_recognizer.py b/speech_recognizer.py deleted file mode 100644 index ba232d4..0000000 --- a/speech_recognizer.py +++ /dev/null @@ -1,532 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -语音识别模块 -基于 SAUC API 为录音文件提供语音识别功能 -""" - -import os -import json -import time -import logging -import asyncio -import aiohttp -import struct -import gzip -import uuid -from typing import Optional, List, Dict, Any, AsyncGenerator -from dataclasses import dataclass - -# 配置日志 -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -# 常量定义 -DEFAULT_SAMPLE_RATE = 16000 - -class ProtocolVersion: - V1 = 0b0001 - -class MessageType: - CLIENT_FULL_REQUEST = 0b0001 - CLIENT_AUDIO_ONLY_REQUEST = 0b0010 - SERVER_FULL_RESPONSE = 0b1001 - SERVER_ERROR_RESPONSE = 0b1111 - -class MessageTypeSpecificFlags: - NO_SEQUENCE = 0b0000 - POS_SEQUENCE = 0b0001 - NEG_SEQUENCE = 0b0010 - NEG_WITH_SEQUENCE = 0b0011 - -class SerializationType: - NO_SERIALIZATION = 0b0000 - JSON = 0b0001 - -class CompressionType: - GZIP = 0b0001 - -@dataclass -class RecognitionResult: - """语音识别结果""" - text: str - confidence: float - is_final: bool - start_time: Optional[float] = None - end_time: Optional[float] = None - -class AudioUtils: - """音频处理工具类""" - - @staticmethod - def gzip_compress(data: bytes) -> bytes: - """GZIP压缩""" - return gzip.compress(data) - - @staticmethod - def gzip_decompress(data: bytes) -> bytes: - """GZIP解压缩""" - return gzip.decompress(data) - - @staticmethod - def is_wav_file(data: bytes) -> bool: - """检查是否为WAV文件""" - if len(data) < 44: - return False - return data[:4] == b'RIFF' and data[8:12] == b'WAVE' - - @staticmethod - def read_wav_info(data: bytes) -> tuple: - """读取WAV文件信息""" - if len(data) < 44: - raise ValueError("Invalid WAV file: too short") - - # 解析WAV头 - chunk_id = data[:4] - if chunk_id != b'RIFF': - raise ValueError("Invalid WAV file: not RIFF format") - - format_ = data[8:12] - if format_ != b'WAVE': - raise ValueError("Invalid WAV file: not WAVE format") - - # 解析fmt子块 - audio_format = struct.unpack(' str: - return self.auth["app_key"] - - @property - def access_key(self) -> str: - return self.auth["access_key"] - -class AsrRequestHeader: - """ASR请求头""" - - def __init__(self): - self.message_type = MessageType.CLIENT_FULL_REQUEST - self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE - self.serialization_type = SerializationType.JSON - self.compression_type = CompressionType.GZIP - self.reserved_data = bytes([0x00]) - - def with_message_type(self, message_type: int) -> 'AsrRequestHeader': - self.message_type = message_type - return self - - def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader': - self.message_type_specific_flags = flags - return self - - def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader': - self.serialization_type = serialization_type - return self - - def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader': - self.compression_type = compression_type - return self - - def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader': - self.reserved_data = reserved_data - return self - - def to_bytes(self) -> bytes: - header = bytearray() - header.append((ProtocolVersion.V1 << 4) | 1) - header.append((self.message_type << 4) | self.message_type_specific_flags) - header.append((self.serialization_type << 4) | self.compression_type) - header.extend(self.reserved_data) - return bytes(header) - - @staticmethod - def default_header() -> 'AsrRequestHeader': - return AsrRequestHeader() - -class RequestBuilder: - """请求构建器""" - - @staticmethod - def new_auth_headers(config: AsrConfig) -> Dict[str, str]: - """创建认证头""" - reqid = str(uuid.uuid4()) - return { - "X-Api-Resource-Id": "volc.bigasr.sauc.duration", - "X-Api-Request-Id": reqid, - "X-Api-Access-Key": config.access_key, - "X-Api-App-Key": config.app_key - } - - @staticmethod - def new_full_client_request(seq: int) -> bytes: - """创建完整客户端请求""" - header = AsrRequestHeader.default_header() \ - .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE) - - payload = { - "user": { - "uid": "local_voice_user" - }, - "audio": { - "format": "wav", - "codec": "raw", - "rate": 16000, - "bits": 16, - "channel": 1 - }, - "request": { - "model_name": "bigmodel", - "enable_itn": True, - "enable_punc": True, - "enable_ddc": True, - "show_utterances": True, - "enable_nonstream": False - } - } - - payload_bytes = json.dumps(payload).encode('utf-8') - compressed_payload = AudioUtils.gzip_compress(payload_bytes) - payload_size = len(compressed_payload) - - request = bytearray() - request.extend(header.to_bytes()) - request.extend(struct.pack('>i', seq)) - request.extend(struct.pack('>U', payload_size)) - request.extend(compressed_payload) - - return bytes(request) - - @staticmethod - def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes: - """创建纯音频请求""" - header = AsrRequestHeader.default_header() - if is_last: - header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE) - seq = -seq - else: - header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE) - header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST) - - request = bytearray() - request.extend(header.to_bytes()) - request.extend(struct.pack('>i', seq)) - - compressed_segment = AudioUtils.gzip_compress(segment) - request.extend(struct.pack('>U', len(compressed_segment))) - request.extend(compressed_segment) - - return bytes(request) - -class AsrResponse: - """ASR响应""" - - def __init__(self): - self.code = 0 - self.event = 0 - self.is_last_package = False - self.payload_sequence = 0 - self.payload_size = 0 - self.payload_msg = None - - def to_dict(self) -> Dict[str, Any]: - return { - "code": self.code, - "event": self.event, - "is_last_package": self.is_last_package, - "payload_sequence": self.payload_sequence, - "payload_size": self.payload_size, - "payload_msg": self.payload_msg - } - -class ResponseParser: - """响应解析器""" - - @staticmethod - def parse_response(msg: bytes) -> AsrResponse: - """解析响应""" - response = AsrResponse() - - header_size = msg[0] & 0x0f - message_type = msg[1] >> 4 - message_type_specific_flags = msg[1] & 0x0f - serialization_method = msg[2] >> 4 - message_compression = msg[2] & 0x0f - - payload = msg[header_size*4:] - - # 解析message_type_specific_flags - if message_type_specific_flags & 0x01: - response.payload_sequence = struct.unpack('>i', payload[:4])[0] - payload = payload[4:] - if message_type_specific_flags & 0x02: - response.is_last_package = True - if message_type_specific_flags & 0x04: - response.event = struct.unpack('>i', payload[:4])[0] - payload = payload[4:] - - # 解析message_type - if message_type == MessageType.SERVER_FULL_RESPONSE: - response.payload_size = struct.unpack('>U', payload[:4])[0] - payload = payload[4:] - elif message_type == MessageType.SERVER_ERROR_RESPONSE: - response.code = struct.unpack('>i', payload[:4])[0] - response.payload_size = struct.unpack('>U', payload[4:8])[0] - payload = payload[8:] - - if not payload: - return response - - # 解压缩 - if message_compression == CompressionType.GZIP: - try: - payload = AudioUtils.gzip_decompress(payload) - except Exception as e: - logger.error(f"Failed to decompress payload: {e}") - return response - - # 解析payload - try: - if serialization_method == SerializationType.JSON: - response.payload_msg = json.loads(payload.decode('utf-8')) - except Exception as e: - logger.error(f"Failed to parse payload: {e}") - - return response - -class SpeechRecognizer: - """语音识别器""" - - def __init__(self, app_key: str = None, access_key: str = None, - url: str = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream"): - self.config = AsrConfig(app_key, access_key) - self.url = url - self.seq = 1 - - async def recognize_file(self, file_path: str) -> List[RecognitionResult]: - """识别音频文件""" - if not os.path.exists(file_path): - raise FileNotFoundError(f"Audio file not found: {file_path}") - - results = [] - - try: - async with aiohttp.ClientSession() as session: - # 读取音频文件 - with open(file_path, 'rb') as f: - content = f.read() - - if not AudioUtils.is_wav_file(content): - raise ValueError("Audio file must be in WAV format") - - # 获取音频信息 - try: - _, _, sample_rate, _, audio_data = AudioUtils.read_wav_info(content) - if sample_rate != DEFAULT_SAMPLE_RATE: - logger.warning(f"Sample rate {sample_rate} != {DEFAULT_SAMPLE_RATE}, may affect recognition accuracy") - except Exception as e: - logger.error(f"Failed to read audio info: {e}") - raise - - # 计算分段大小 (200ms per segment) - segment_size = 1 * 2 * DEFAULT_SAMPLE_RATE * 200 // 1000 # channel * bytes_per_sample * sample_rate * duration_ms / 1000 - - # 创建WebSocket连接 - headers = RequestBuilder.new_auth_headers(self.config) - async with session.ws_connect(self.url, headers=headers) as ws: - - # 发送完整客户端请求 - request = RequestBuilder.new_full_client_request(self.seq) - self.seq += 1 - await ws.send_bytes(request) - - # 接收初始响应 - msg = await ws.receive() - if msg.type == aiohttp.WSMsgType.BINARY: - response = ResponseParser.parse_response(msg.data) - logger.info(f"Initial response: {response.to_dict()}") - - # 分段发送音频数据 - audio_segments = self._split_audio(audio_data, segment_size) - total_segments = len(audio_segments) - - for i, segment in enumerate(audio_segments): - is_last = (i == total_segments - 1) - request = RequestBuilder.new_audio_only_request( - self.seq, - segment, - is_last=is_last - ) - await ws.send_bytes(request) - logger.info(f"Sent audio segment {i+1}/{total_segments}") - - if not is_last: - self.seq += 1 - - # 短暂延迟模拟实时流 - await asyncio.sleep(0.1) - - # 接收识别结果 - final_text = "" - while True: - msg = await ws.receive() - if msg.type == aiohttp.WSMsgType.BINARY: - response = ResponseParser.parse_response(msg.data) - - if response.payload_msg and 'text' in response.payload_msg: - text = response.payload_msg['text'] - if text: - final_text += text - - result = RecognitionResult( - text=text, - confidence=0.9, # 默认置信度 - is_final=response.is_last_package - ) - results.append(result) - - logger.info(f"Recognized: {text}") - - if response.is_last_package or response.code != 0: - break - - elif msg.type == aiohttp.WSMsgType.ERROR: - logger.error(f"WebSocket error: {msg.data}") - break - elif msg.type == aiohttp.WSMsgType.CLOSED: - logger.info("WebSocket connection closed") - break - - # 如果没有获得最终结果,创建一个包含所有文本的结果 - if final_text and not any(r.is_final for r in results): - final_result = RecognitionResult( - text=final_text, - confidence=0.9, - is_final=True - ) - results.append(final_result) - - return results - - except Exception as e: - logger.error(f"Speech recognition failed: {e}") - raise - - def _split_audio(self, data: bytes, segment_size: int) -> List[bytes]: - """分割音频数据""" - if segment_size <= 0: - return [] - - segments = [] - for i in range(0, len(data), segment_size): - end = i + segment_size - if end > len(data): - end = len(data) - segments.append(data[i:end]) - return segments - - async def recognize_latest_recording(self, directory: str = ".") -> Optional[RecognitionResult]: - """识别最新的录音文件""" - # 查找最新的录音文件 - recording_files = [f for f in os.listdir(directory) if f.startswith('recording_') and f.endswith('.wav')] - - if not recording_files: - logger.warning("No recording files found") - return None - - # 按文件名排序(包含时间戳) - recording_files.sort(reverse=True) - latest_file = recording_files[0] - latest_path = os.path.join(directory, latest_file) - - logger.info(f"Recognizing latest recording: {latest_file}") - - try: - results = await self.recognize_file(latest_path) - if results: - # 返回最终的识别结果 - final_results = [r for r in results if r.is_final] - if final_results: - return final_results[-1] - else: - # 如果没有标记为final的结果,返回最后一个 - return results[-1] - except Exception as e: - logger.error(f"Failed to recognize latest recording: {e}") - - return None - -async def main(): - """测试函数""" - import argparse - - parser = argparse.ArgumentParser(description="语音识别测试") - parser.add_argument("--file", type=str, help="音频文件路径") - parser.add_argument("--latest", action="store_true", help="识别最新的录音文件") - parser.add_argument("--app-key", type=str, help="SAUC App Key") - parser.add_argument("--access-key", type=str, help="SAUC Access Key") - - args = parser.parse_args() - - recognizer = SpeechRecognizer( - app_key=args.app_key, - access_key=args.access_key - ) - - try: - if args.latest: - result = await recognizer.recognize_latest_recording() - if result: - print(f"识别结果: {result.text}") - print(f"置信度: {result.confidence}") - print(f"最终结果: {result.is_final}") - else: - print("未能识别到语音内容") - elif args.file: - results = await recognizer.recognize_file(args.file) - for i, result in enumerate(results): - print(f"结果 {i+1}: {result.text}") - print(f"置信度: {result.confidence}") - print(f"最终结果: {result.is_final}") - print("-" * 40) - else: - print("请指定 --file 或 --latest 参数") - - except Exception as e: - print(f"识别失败: {e}") - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/test_doubao.py b/test_doubao.py deleted file mode 100644 index f0d09d0..0000000 --- a/test_doubao.py +++ /dev/null @@ -1,113 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -豆包音频处理模块 - 验证脚本 -验证完整的音频处理流程 -""" - -import asyncio -import subprocess -import os -from doubao_simple import DoubaoClient - -async def test_complete_workflow(): - """测试完整的工作流程""" - print("=== 豆包音频处理模块验证 ===") - - # 检查输入文件 - input_file = "recording_20250920_135137.wav" - if not os.path.exists(input_file): - print(f"❌ 输入文件不存在: {input_file}") - return False - - print(f"✅ 输入文件存在: {input_file}") - - # 检查文件信息 - try: - result = subprocess.run(['file', input_file], capture_output=True, text=True) - print(f"📁 输入文件格式: {result.stdout.strip()}") - except: - pass - - # 初始化客户端 - client = DoubaoClient() - - try: - # 连接服务器 - print("🔌 连接豆包服务器...") - await client.connect() - print("✅ 连接成功") - - # 处理音频文件 - output_file = "tts_output.wav" - print(f"🎵 处理音频文件: {input_file} -> {output_file}") - - success = await client.process_audio_file(input_file, output_file) - - if success: - print("✅ 音频处理成功!") - - # 检查输出文件 - if os.path.exists(output_file): - result = subprocess.run(['file', output_file], capture_output=True, text=True) - print(f"📁 输出文件格式: {result.stdout.strip()}") - - # 获取文件大小 - file_size = os.path.getsize(output_file) - print(f"📊 输出文件大小: {file_size:,} 字节") - - # 测试播放 - print("🔊 测试播放输出文件...") - try: - subprocess.run(['aplay', output_file], timeout=10, check=True) - print("✅ 播放成功") - except subprocess.TimeoutExpired: - print("✅ 播放完成(超时是正常的)") - except subprocess.CalledProcessError as e: - print(f"⚠️ 播放出现问题: {e}") - except FileNotFoundError: - print("⚠️ aplay命令不存在,跳过播放测试") - - return True - else: - print("❌ 输出文件未生成") - return False - else: - print("❌ 音频处理失败") - return False - - except Exception as e: - print(f"❌ 测试失败: {e}") - import traceback - traceback.print_exc() - return False - finally: - try: - await client.close() - except: - pass - -def main(): - """主函数""" - print("开始验证豆包音频处理模块...") - - success = asyncio.run(test_complete_workflow()) - - if success: - print("\n🎉 验证完成!豆包音频处理模块工作正常。") - print("\n📋 功能总结:") - print(" ✅ WebSocket连接建立") - print(" ✅ 音频文件上传") - print(" ✅ 语音识别") - print(" ✅ TTS音频生成") - print(" ✅ 音频格式转换(Float32 -> Int16)") - print(" ✅ WAV文件生成") - print(" ✅ 树莓派兼容播放") - else: - print("\n❌ 验证失败,请检查错误信息。") - - return success - -if __name__ == "__main__": - main() \ No newline at end of file