import asyncio
import gzip
import json
import logging
import os
import struct
import subprocess
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

try:
    import aiohttp
except ImportError:
    # Only AsrWsClient needs aiohttp; the protocol helpers below (header and
    # request building, response parsing, WAV parsing) are pure stdlib and stay
    # importable without it. AsrWsClient.__aenter__ reports the missing package.
    aiohttp = None

# Logging: write to run.log and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('run.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Constants
DEFAULT_SAMPLE_RATE = 16000  # Hz; also the rate declared in the full client request


class ProtocolVersion:
    """Protocol version, stored in the high nibble of header byte 0."""
    V1 = 0b0001


class MessageType:
    """Message type, stored in the high nibble of header byte 1."""
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ERROR_RESPONSE = 0b1111


class MessageTypeSpecificFlags:
    """Per-message flags, stored in the low nibble of header byte 1.

    When bit 0 is set, a 4-byte sequence number follows the header.
    Bit 1 marks the last packet; ResponseParser reports it as ``is_last_package``.
    """
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    NEG_SEQUENCE = 0b0010
    NEG_WITH_SEQUENCE = 0b0011


class SerializationType:
    """Payload serialization, stored in the high nibble of header byte 2."""
    NO_SERIALIZATION = 0b0000
    JSON = 0b0001


class CompressionType:
    """Payload compression, stored in the low nibble of header byte 2."""
    GZIP = 0b0001


class Config:
    """Credentials sent in the WebSocket handshake headers.

    Fill in the app key and access key obtained from the console below.
    Alternatively, set the ``ASR_APP_KEY`` / ``ASR_ACCESS_KEY`` environment
    variables, which take precedence over the placeholders.
    """

    def __init__(self):
        self.auth = {
            "app_key": os.environ.get("ASR_APP_KEY", "xxxxxxx"),
            "access_key": os.environ.get("ASR_ACCESS_KEY", "xxxxxxxxxxxx"),
        }

    @property
    def app_key(self) -> str:
        return self.auth["app_key"]

    @property
    def access_key(self) -> str:
        return self.auth["access_key"]


config = Config()


class CommonUtils:
    """Stateless helpers for compression and WAV handling."""

    @staticmethod
    def gzip_compress(data: bytes) -> bytes:
        """Return *data* gzip-compressed."""
        return gzip.compress(data)

    @staticmethod
    def gzip_decompress(data: bytes) -> bytes:
        """Return the gzip-decompressed form of *data*."""
        return gzip.decompress(data)

    @staticmethod
    def judge_wav(data: bytes) -> bool:
        """Return True if *data* is at least 44 bytes and starts with a RIFF/WAVE header."""
        if len(data) < 44:
            return False
        return data[:4] == b'RIFF' and data[8:12] == b'WAVE'

    @staticmethod
    def convert_wav_with_path(audio_path: str, sample_rate: int = DEFAULT_SAMPLE_RATE,
                              remove_source: bool = False) -> bytes:
        """Transcode *audio_path* with ffmpeg to 16-bit mono PCM WAV and return the bytes.

        Args:
            audio_path: input file in any format ffmpeg can read.
            sample_rate: output sample rate in Hz.
            remove_source: delete *audio_path* after a successful conversion.
                This defaults to False. Deleting unconditionally would destroy
                the user's input file (e.g. the file passed via ``--file``).

        Raises:
            RuntimeError: ffmpeg is not installed or the conversion failed.
        """
        cmd = [
            # "-v error" (not "quiet") so ffmpeg's failure reason reaches stderr.
            "ffmpeg", "-v", "error", "-y", "-i", audio_path,
            "-acodec", "pcm_s16le", "-ac", "1", "-ar", str(sample_rate),
            "-f", "wav", "-"
        ]
        try:
            result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except FileNotFoundError as e:
            logger.error("FFmpeg conversion failed: ffmpeg executable not found")
            raise RuntimeError("Audio conversion failed: ffmpeg executable not found") from e
        except subprocess.CalledProcessError as e:
            detail = e.stderr.decode(errors="replace")
            logger.error(f"FFmpeg conversion failed: {detail}")
            raise RuntimeError(f"Audio conversion failed: {detail}") from e

        if remove_source:
            try:
                os.remove(audio_path)
            except OSError as e:
                # Best effort: the conversion itself succeeded.
                logger.warning(f"Failed to remove original file: {e}")
        return result.stdout

    @staticmethod
    def read_wav_info(data: bytes) -> Tuple[int, int, int, int, bytes]:
        """Parse a RIFF/WAVE byte string.

        The parser walks the RIFF chunk list instead of assuming a fixed 44-byte
        header. Extra chunks, such as the ``LIST`` chunk ffmpeg writes before
        ``data``, are skipped. A ``data`` chunk whose declared size exceeds the
        available bytes is clamped to the end of *data*; ffmpeg writes
        0xFFFFFFFF when streaming to a pipe.

        Returns:
            (num_channels, sample_width_in_bytes, sample_rate, num_frames, pcm_data)

        Raises:
            ValueError: the data is not a well-formed WAV file.
        """
        if len(data) < 44:
            raise ValueError("Invalid WAV file: too short")
        if data[:4] != b'RIFF':
            raise ValueError("Invalid WAV file: not RIFF format")
        if data[8:12] != b'WAVE':
            raise ValueError("Invalid WAV file: not WAVE format")

        fmt = None
        pos = 12  # first sub-chunk follows "RIFF" <size> "WAVE"
        while pos + 8 <= len(data):
            chunk_id = data[pos:pos + 4]
            chunk_size = struct.unpack_from('<I', data, pos + 4)[0]
            body_start = pos + 8
            if chunk_id == b'fmt ':
                if chunk_size < 16 or body_start + 16 > len(data):
                    raise ValueError("Invalid WAV file: truncated fmt chunk")
                # audio_format, channels, sample_rate, byte_rate, block_align, bits_per_sample
                fmt = struct.unpack_from('<HHIIHH', data, body_start)
            elif chunk_id == b'data':
                if fmt is None:
                    raise ValueError("Invalid WAV file: data chunk before fmt chunk")
                _, num_channels, sample_rate, _, _, bits_per_sample = fmt
                sample_width = (bits_per_sample + 7) // 8
                if num_channels <= 0 or sample_width <= 0:
                    raise ValueError("Invalid WAV file: bad channel count or sample width")
                pcm = data[body_start:body_start + chunk_size]  # slicing clamps oversize lengths
                num_frames = len(pcm) // (num_channels * sample_width)
                return num_channels, sample_width, sample_rate, num_frames, pcm
            # Chunks are word-aligned: an odd-sized chunk is followed by one pad byte.
            pos = body_start + chunk_size + (chunk_size & 1)

        raise ValueError("Invalid WAV file: no data chunk found")


class AsrRequestHeader:
    """Fluent builder for the binary request header.

    The header consists of three bytes: (version | header_size),
    (message_type | flags) and (serialization | compression). It is followed by
    ``reserved_data``. header_size counts 4-byte words.
    """

    def __init__(self):
        self.message_type = MessageType.CLIENT_FULL_REQUEST
        self.message_type_specific_flags = MessageTypeSpecificFlags.NO_SEQUENCE
        self.serialization_type = SerializationType.JSON
        self.compression_type = CompressionType.GZIP
        self.reserved_data = bytes([0x00])  # pads the header to one 4-byte word

    def with_message_type(self, message_type: int) -> 'AsrRequestHeader':
        self.message_type = message_type
        return self

    def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader':
        self.message_type_specific_flags = flags
        return self

    def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader':
        self.serialization_type = serialization_type
        return self

    def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader':
        self.compression_type = compression_type
        return self

    def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader':
        self.reserved_data = reserved_data
        return self

    def to_bytes(self) -> bytes:
        """Serialize the header.

        The header-size nibble is derived from ``reserved_data`` rather than
        hard-coded. ``3 + len(reserved_data)`` must therefore be a multiple of 4.

        Raises:
            ValueError: the header does not fill whole 4-byte words or is too
                long for the 4-bit size field.
        """
        header_len = 3 + len(self.reserved_data)
        words, remainder = divmod(header_len, 4)
        if remainder or words > 0x0F:
            raise ValueError(
                f"Invalid header length {header_len}: must be a multiple of 4 and at most 60 bytes"
            )
        header = bytearray()
        header.append((ProtocolVersion.V1 << 4) | words)
        header.append((self.message_type << 4) | self.message_type_specific_flags)
        header.append((self.serialization_type << 4) | self.compression_type)
        header.extend(self.reserved_data)
        return bytes(header)

    @staticmethod
    def default_header() -> 'AsrRequestHeader':
        return AsrRequestHeader()


class RequestBuilder:
    """Builds handshake headers and binary request frames."""

    @staticmethod
    def new_auth_headers() -> Dict[str, str]:
        """Return the HTTP headers for the WebSocket handshake, with a fresh request id."""
        reqid = str(uuid.uuid4())
        return {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Request-Id": reqid,
            "X-Api-Access-Key": config.access_key,
            "X-Api-App-Key": config.app_key
        }

    @staticmethod
    def new_full_client_request(seq: int) -> bytes:
        """Build the initial full-client request frame.

        Frame layout: header, big-endian signed ``seq``, big-endian payload
        length, then the gzipped JSON session configuration.
        """
        header = AsrRequestHeader.default_header() \
            .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        payload = {
            "user": {
                "uid": "demo_uid"
            },
            "audio": {
                "format": "wav",
                "codec": "raw",
                # read_audio_data guarantees the audio matches these values.
                "rate": DEFAULT_SAMPLE_RATE,
                "bits": 16,
                "channel": 1
            },
            "request": {
                "model_name": "bigmodel",
                "enable_itn": True,
                "enable_punc": True,
                "enable_ddc": True,
                "show_utterances": True,
                "enable_nonstream": False
            }
        }
        compressed_payload = CommonUtils.gzip_compress(json.dumps(payload).encode('utf-8'))
        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))
        request.extend(struct.pack('>I', len(compressed_payload)))
        request.extend(compressed_payload)
        return bytes(request)

    @staticmethod
    def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
        """Build an audio-only frame carrying the gzipped *segment*.

        The final frame is flagged NEG_WITH_SEQUENCE and carries ``-seq``.
        """
        header = AsrRequestHeader.default_header()
        if is_last:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE)
            seq = -seq  # the last packet is signalled by a negative sequence number
        else:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST)

        compressed_segment = CommonUtils.gzip_compress(segment)
        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))
        request.extend(struct.pack('>I', len(compressed_segment)))
        request.extend(compressed_segment)
        return bytes(request)


class AsrResponse:
    """Decoded server frame as produced by ResponseParser.parse_response."""

    def __init__(self):
        self.code = 0
        self.event = 0
        self.is_last_package = False
        self.payload_sequence = 0
        self.payload_size = 0
        self.payload_msg = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            "code": self.code,
            "event": self.event,
            "is_last_package": self.is_last_package,
            "payload_sequence": self.payload_sequence,
            "payload_size": self.payload_size,
            "payload_msg": self.payload_msg
        }


class ResponseParser:
    """Decodes binary server frames."""

    @staticmethod
    def parse_response(msg: bytes) -> AsrResponse:
        """Decode one server frame.

        Decompression or JSON failures are logged, and the partially filled
        response is returned.

        Raises:
            ValueError: *msg* is shorter than the 4-byte minimum header.
        """
        if len(msg) < 4:
            raise ValueError(f"Invalid response: {len(msg)} bytes is shorter than the header")
        response = AsrResponse()

        header_size = msg[0] & 0x0f  # in 4-byte words
        message_type = msg[1] >> 4
        message_type_specific_flags = msg[1] & 0x0f
        serialization_method = msg[2] >> 4
        message_compression = msg[2] & 0x0f

        payload = msg[header_size * 4:]

        # Optional fields selected by the flags nibble.
        if message_type_specific_flags & 0x01:
            response.payload_sequence = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]
        if message_type_specific_flags & 0x02:
            response.is_last_package = True
        if message_type_specific_flags & 0x04:
            response.event = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]

        # Type-specific prefix: payload size, plus an error code for error frames.
        if message_type == MessageType.SERVER_FULL_RESPONSE:
            response.payload_size = struct.unpack('>I', payload[:4])[0]
            payload = payload[4:]
        elif message_type == MessageType.SERVER_ERROR_RESPONSE:
            response.code = struct.unpack('>i', payload[:4])[0]
            response.payload_size = struct.unpack('>I', payload[4:8])[0]
            payload = payload[8:]

        if not payload:
            return response

        if message_compression == CompressionType.GZIP:
            try:
                payload = CommonUtils.gzip_decompress(payload)
            except Exception as e:
                logger.error(f"Failed to decompress payload: {e}")
                return response

        try:
            if serialization_method == SerializationType.JSON:
                response.payload_msg = json.loads(payload.decode('utf-8'))
        except Exception as e:
            logger.error(f"Failed to parse payload: {e}")

        return response


class AsrWsClient:
    """Streams an audio file to the ASR WebSocket service.

    Use it as ``async with AsrWsClient(url) as client:``. The context manager
    owns the aiohttp session.
    """

    def __init__(self, url: str, segment_duration: int = 200):
        self.seq = 1
        self.url = url
        self.segment_duration = segment_duration  # milliseconds of audio per packet
        self.conn = None
        self.session = None  # created in __aenter__

    async def __aenter__(self):
        if aiohttp is None:
            raise RuntimeError("aiohttp is required for AsrWsClient (pip install aiohttp)")
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if self.conn and not self.conn.closed:
            await self.conn.close()
        if self.session and not self.session.closed:
            await self.session.close()

    @staticmethod
    def _is_target_wav(content: bytes) -> bool:
        """Return True if *content* is already 16-bit mono WAV at DEFAULT_SAMPLE_RATE."""
        if not CommonUtils.judge_wav(content):
            return False
        try:
            channels, sample_width, sample_rate, _, _ = CommonUtils.read_wav_info(content)
        except ValueError:
            return False
        return (channels, sample_width, sample_rate) == (1, 2, DEFAULT_SAMPLE_RATE)

    async def read_audio_data(self, file_path: str) -> bytes:
        """Read *file_path*, converting it with ffmpeg unless it already matches the declared format.

        A WAV with a different rate, width or channel count is converted too.
        Otherwise its samples would contradict the parameters declared in the
        full client request.
        """
        try:
            with open(file_path, 'rb') as f:
                content = f.read()
            if not self._is_target_wav(content):
                logger.info("Converting audio to WAV format...")
                content = CommonUtils.convert_wav_with_path(file_path, DEFAULT_SAMPLE_RATE)
            return content
        except Exception as e:
            logger.error(f"Failed to read audio data: {e}")
            raise

    def get_segment_size(self, content: bytes) -> int:
        """Return the number of bytes that holds ``segment_duration`` ms of *content*.

        Raises:
            ValueError: the WAV is malformed or the resulting size is not positive.
                A non-positive size would produce no packets, so the stream
                would never be terminated.
        """
        try:
            channel_num, samp_width, frame_rate, _, _ = CommonUtils.read_wav_info(content)
            size_per_sec = channel_num * samp_width * frame_rate
            segment_size = size_per_sec * self.segment_duration // 1000
            if segment_size <= 0:
                raise ValueError(
                    f"Segment size must be positive (segment_duration={self.segment_duration}ms)"
                )
            return segment_size
        except Exception as e:
            logger.error(f"Failed to calculate segment size: {e}")
            raise

    async def create_connection(self) -> None:
        """Open the WebSocket connection using the authenticated handshake headers."""
        if self.session is None:
            raise RuntimeError("AsrWsClient must be used via 'async with AsrWsClient(...)'")
        headers = RequestBuilder.new_auth_headers()
        try:
            self.conn = await self.session.ws_connect(self.url, headers=headers)
            logger.info(f"Connected to {self.url}")
        except Exception as e:
            logger.error(f"Failed to connect to WebSocket: {e}")
            raise

    async def send_full_client_request(self) -> None:
        """Send the session-configuration frame and check the server's reply.

        Raises:
            RuntimeError: the reply is not a binary frame, or it carries a
                non-zero error code. Streaming audio after a rejection is pointless.
        """
        request = RequestBuilder.new_full_client_request(self.seq)
        sent_seq = self.seq
        self.seq += 1  # increment after building the frame
        try:
            await self.conn.send_bytes(request)
            logger.info(f"Sent full client request with seq: {sent_seq}")

            msg = await self.conn.receive()
            if msg.type != aiohttp.WSMsgType.BINARY:
                logger.error(f"Unexpected message type: {msg.type}")
                raise RuntimeError(f"Unexpected message type: {msg.type}")
            response = ResponseParser.parse_response(msg.data)
            logger.info(f"Received response: {response.to_dict()}")
            if response.code != 0:
                raise RuntimeError(
                    f"Full client request rejected: code={response.code}, payload={response.payload_msg}"
                )
        except Exception as e:
            logger.error(f"Failed to send full client request: {e}")
            raise

    async def send_messages(self, segment_size: int, content: bytes) -> AsyncGenerator[None, None]:
        """Send *content* as audio-only frames, yielding after each one."""
        audio_segments = self.split_audio(content, segment_size)
        if not audio_segments:
            # Always end with a final (negative-sequence) packet. Otherwise the
            # server never sends its last response and recv_messages waits forever.
            audio_segments = [b""]
        last_index = len(audio_segments) - 1

        for i, segment in enumerate(audio_segments):
            is_last = i == last_index
            request = RequestBuilder.new_audio_only_request(self.seq, segment, is_last=is_last)
            await self.conn.send_bytes(request)
            logger.info(f"Sent audio segment with seq: {self.seq} (last: {is_last})")

            if not is_last:
                self.seq += 1

            # Send one segment at a time; the sleep paces the stream like real-time audio.
            await asyncio.sleep(self.segment_duration / 1000)
            # Give control back so the receiver can process incoming messages.
            yield

    async def recv_messages(self) -> AsyncGenerator[AsrResponse, None]:
        """Yield decoded server frames until the last package, an error code, or connection close."""
        try:
            async for msg in self.conn:
                if msg.type == aiohttp.WSMsgType.BINARY:
                    response = ResponseParser.parse_response(msg.data)
                    yield response
                    if response.is_last_package or response.code != 0:
                        break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(f"WebSocket error: {msg.data}")
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    logger.info("WebSocket connection closed")
                    break
        except Exception as e:
            logger.error(f"Error receiving messages: {e}")
            raise

    async def start_audio_stream(self, segment_size: int, content: bytes) -> AsyncGenerator[AsrResponse, None]:
        """Send audio in a background task while yielding server responses."""
        async def sender():
            async for _ in self.send_messages(segment_size, content):
                pass

        sender_task = asyncio.create_task(sender())
        try:
            async for response in self.recv_messages():
                yield response
        finally:
            # Stop sending once receiving ends (normally, on error, or early exit).
            sender_task.cancel()
            try:
                await sender_task
            except asyncio.CancelledError:
                pass

    @staticmethod
    def split_audio(data: bytes, segment_size: int) -> List[bytes]:
        """Split *data* into chunks of *segment_size* bytes; the last may be shorter.

        Returns an empty list when *segment_size* is not positive.
        """
        if segment_size <= 0:
            return []
        return [data[i:i + segment_size] for i in range(0, len(data), segment_size)]

    async def execute(self, file_path: str) -> AsyncGenerator[AsrResponse, None]:
        """Recognize *file_path*, yielding each server response.

        Raises:
            ValueError: *file_path* or the URL is empty.
        """
        if not file_path:
            raise ValueError("File path is empty")
        if not self.url:
            raise ValueError("URL is empty")

        self.seq = 1
        try:
            # 1. Read the audio (converted to 16 kHz mono 16-bit WAV if needed).
            content = await self.read_audio_data(file_path)
            # 2. Bytes per packet for the configured segment duration.
            segment_size = self.get_segment_size(content)
            # 3. Open the WebSocket connection.
            await self.create_connection()
            # 4. Send the full client request (session configuration).
            await self.send_full_client_request()
            # 5. Stream the audio and yield server responses.
            async for response in self.start_audio_stream(segment_size, content):
                yield response
        except Exception as e:
            logger.error(f"Error in ASR execution: {e}")
            raise
        finally:
            if self.conn:
                await self.conn.close()


async def main() -> int:
    """Command-line entry point; returns the process exit status (0 success, 1 failure)."""
    import argparse
    parser = argparse.ArgumentParser(description="ASR WebSocket Client")
    parser.add_argument("--file", type=str, required=True, help="Audio file path")

    # wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
    # wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async
    # wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream
    parser.add_argument("--url", type=str, default="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream",
                        help="WebSocket URL")
    parser.add_argument("--seg-duration", type=int, default=200,
                        help="Audio duration(ms) per packet, default:200")

    args = parser.parse_args()

    async with AsrWsClient(args.url, args.seg_duration) as client:
        try:
            async for response in client.execute(args.file):
                logger.info(f"Received response: {json.dumps(response.to_dict(), indent=2, ensure_ascii=False)}")
        except Exception as e:
            logger.error(f"ASR processing failed: {e}")
            return 1
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shell scripts can detect failures.
    raise SystemExit(asyncio.run(main()))

# Usage:
# python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav