#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Speech recognition module.

Provides speech recognition for recorded WAV files via the SAUC
(Volcengine streaming big-model ASR) WebSocket API.
"""

import os
import json
import time
import logging
import asyncio
import aiohttp
import struct
import gzip
import uuid
from typing import Optional, List, Dict, Any, AsyncGenerator
from dataclasses import dataclass

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Constants
DEFAULT_SAMPLE_RATE = 16000


class ProtocolVersion:
    """Binary protocol version nibble."""
    V1 = 0b0001


class MessageType:
    """Message-type nibble of the binary protocol header."""
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ERROR_RESPONSE = 0b1111


class MessageTypeSpecificFlags:
    """Flags nibble: presence/sign of the sequence number, last-packet marker."""
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    NEG_SEQUENCE = 0b0010
    NEG_WITH_SEQUENCE = 0b0011


class SerializationType:
    """Payload serialization nibble."""
    NO_SERIALIZATION = 0b0000
    JSON = 0b0001


class CompressionType:
    """Payload compression nibble."""
    GZIP = 0b0001


@dataclass
class RecognitionResult:
    """A single speech recognition result."""
    text: str
    confidence: float
    is_final: bool
    start_time: Optional[float] = None
    end_time: Optional[float] = None


class AudioUtils:
    """Audio processing helpers."""

    @staticmethod
    def gzip_compress(data: bytes) -> bytes:
        """GZIP-compress *data*."""
        return gzip.compress(data)

    @staticmethod
    def gzip_decompress(data: bytes) -> bytes:
        """GZIP-decompress *data*."""
        return gzip.decompress(data)

    @staticmethod
    def is_wav_file(data: bytes) -> bool:
        """Return True if *data* starts with a RIFF/WAVE header."""
        if len(data) < 44:
            return False
        return data[:4] == b'RIFF' and data[8:12] == b'WAVE'

    @staticmethod
    def read_wav_info(data: bytes) -> tuple:
        """Parse a PCM WAV header.

        Returns ``(num_channels, bits_per_sample, sample_rate, data_size,
        audio_data)``.

        Raises:
            ValueError: if the buffer is too short or not RIFF/WAVE/PCM.

        NOTE(review): the body of this method was corrupted in the source;
        it has been reconstructed from the standard 44-byte PCM WAV header
        layout and the caller's 5-tuple unpacking — confirm against the
        original implementation.
        """
        if len(data) < 44:
            raise ValueError("Invalid WAV file: too short")

        # RIFF chunk descriptor
        chunk_id = data[:4]
        if chunk_id != b'RIFF':
            raise ValueError("Invalid WAV file: not RIFF format")
        format_ = data[8:12]
        if format_ != b'WAVE':
            raise ValueError("Invalid WAV file: not WAVE format")

        # fmt sub-chunk: little-endian fields at the canonical PCM offsets
        audio_format = struct.unpack('<H', data[20:22])[0]
        num_channels = struct.unpack('<H', data[22:24])[0]
        sample_rate = struct.unpack('<I', data[24:28])[0]
        bits_per_sample = struct.unpack('<H', data[34:36])[0]
        if audio_format != 1:  # 1 == uncompressed PCM
            raise ValueError("Invalid WAV file: not PCM format")

        # Assumes the data chunk starts right after the 44-byte header
        # (true for canonical PCM files) — TODO confirm for files with
        # extra chunks (LIST, fact, ...).
        audio_data = data[44:]
        return num_channels, bits_per_sample, sample_rate, len(audio_data), audio_data


class AsrConfig:
    """SAUC authentication configuration.

    NOTE(review): the constructor and the ``app_key`` property header were
    corrupted in the source; reconstructed from the surviving
    ``self.auth[...]`` accesses and the ``AsrConfig(app_key, access_key)``
    call site. The environment-variable fallback is an assumption — confirm.
    """

    def __init__(self, app_key: str = None, access_key: str = None):
        # Fall back to environment variables when keys are not supplied.
        self.auth = {
            "app_key": app_key or os.environ.get("SAUC_APP_KEY", ""),
            "access_key": access_key or os.environ.get("SAUC_ACCESS_KEY", ""),
        }

    @property
    def app_key(self) -> str:
        return self.auth["app_key"]

    @property
    def access_key(self) -> str:
        return self.auth["access_key"]


class AsrRequestHeader:
    """Builder for the 4-byte binary request header."""

    def __init__(self):
        self.message_type = MessageType.CLIENT_FULL_REQUEST
        self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE
        self.serialization_type = SerializationType.JSON
        self.compression_type = CompressionType.GZIP
        self.reserved_data = bytes([0x00])

    def with_message_type(self, message_type: int) -> 'AsrRequestHeader':
        self.message_type = message_type
        return self

    def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader':
        self.message_type_specific_flags = flags
        return self

    def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader':
        self.serialization_type = serialization_type
        return self

    def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader':
        self.compression_type = compression_type
        return self

    def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader':
        self.reserved_data = reserved_data
        return self

    def to_bytes(self) -> bytes:
        """Serialize the header: version|size, type|flags, serialization|compression, reserved."""
        header = bytearray()
        # Low nibble 1 == header size in 4-byte units.
        header.append((ProtocolVersion.V1 << 4) | 1)
        header.append((self.message_type << 4) | self.message_type_specific_flags)
        header.append((self.serialization_type << 4) | self.compression_type)
        header.extend(self.reserved_data)
        return bytes(header)

    @staticmethod
    def default_header() -> 'AsrRequestHeader':
        return AsrRequestHeader()


class RequestBuilder:
    """Builds the binary request frames sent over the WebSocket."""

    @staticmethod
    def new_auth_headers(config: AsrConfig) -> Dict[str, str]:
        """Create the HTTP authentication headers for the handshake."""
        reqid = str(uuid.uuid4())
        return {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Request-Id": reqid,
            "X-Api-Access-Key": config.access_key,
            "X-Api-App-Key": config.app_key
        }

    @staticmethod
    def new_full_client_request(seq: int) -> bytes:
        """Create the initial full-client request frame (JSON config payload)."""
        header = AsrRequestHeader.default_header() \
            .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        payload = {
            "user": {
                "uid": "local_voice_user"
            },
            "audio": {
                "format": "wav",
                "codec": "raw",
                "rate": 16000,
                "bits": 16,
                "channel": 1
            },
            "request": {
                "model_name": "bigmodel",
                "enable_itn": True,
                "enable_punc": True,
                "enable_ddc": True,
                "show_utterances": True,
                "enable_nonstream": False
            }
        }
        payload_bytes = json.dumps(payload).encode('utf-8')
        compressed_payload = AudioUtils.gzip_compress(payload_bytes)
        payload_size = len(compressed_payload)

        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))
        # BUGFIX: '>U' is not a valid struct format character — the payload
        # size is a big-endian 4-byte unsigned int ('>I').
        request.extend(struct.pack('>I', payload_size))
        request.extend(compressed_payload)
        return bytes(request)

    @staticmethod
    def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
        """Create an audio-only request frame; the last segment carries a negated sequence."""
        header = AsrRequestHeader.default_header()
        if is_last:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE)
            seq = -seq
        else:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST)

        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))
        compressed_segment = AudioUtils.gzip_compress(segment)
        # BUGFIX: '>U' -> '>I' (big-endian unsigned payload size).
        request.extend(struct.pack('>I', len(compressed_segment)))
        request.extend(compressed_segment)
        return bytes(request)


class AsrResponse:
    """Parsed server response frame."""

    def __init__(self):
        self.code = 0
        self.event = 0
        self.is_last_package = False
        self.payload_sequence = 0
        self.payload_size = 0
        self.payload_msg = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            "code": self.code,
            "event": self.event,
            "is_last_package": self.is_last_package,
            "payload_sequence": self.payload_sequence,
            "payload_size": self.payload_size,
            "payload_msg": self.payload_msg
        }


class ResponseParser:
    """Parses binary server frames into AsrResponse objects."""

    @staticmethod
    def parse_response(msg: bytes) -> AsrResponse:
        """Parse one binary response frame.

        Decompression and JSON-decoding failures are logged and leave
        ``payload_msg`` as None rather than raising.
        """
        response = AsrResponse()

        # Header: size in 4-byte units, then type/flags and
        # serialization/compression nibbles.
        header_size = msg[0] & 0x0f
        message_type = msg[1] >> 4
        message_type_specific_flags = msg[1] & 0x0f
        serialization_method = msg[2] >> 4
        message_compression = msg[2] & 0x0f
        payload = msg[header_size * 4:]

        # Flags: optional sequence number, last-package bit, optional event.
        if message_type_specific_flags & 0x01:
            response.payload_sequence = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]
        if message_type_specific_flags & 0x02:
            response.is_last_package = True
        if message_type_specific_flags & 0x04:
            response.event = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]

        # Message type determines the fixed fields before the payload body.
        if message_type == MessageType.SERVER_FULL_RESPONSE:
            # BUGFIX: '>U' -> '>I' (big-endian unsigned payload size).
            response.payload_size = struct.unpack('>I', payload[:4])[0]
            payload = payload[4:]
        elif message_type == MessageType.SERVER_ERROR_RESPONSE:
            response.code = struct.unpack('>i', payload[:4])[0]
            # BUGFIX: '>U' -> '>I'.
            response.payload_size = struct.unpack('>I', payload[4:8])[0]
            payload = payload[8:]

        if not payload:
            return response

        # Decompress.
        if message_compression == CompressionType.GZIP:
            try:
                payload = AudioUtils.gzip_decompress(payload)
            except Exception as e:
                logger.error(f"Failed to decompress payload: {e}")
                return response

        # Parse payload body.
        try:
            if serialization_method == SerializationType.JSON:
                response.payload_msg = json.loads(payload.decode('utf-8'))
        except Exception as e:
            logger.error(f"Failed to parse payload: {e}")

        return response


class SpeechRecognizer:
    """High-level recognizer: streams a WAV file to SAUC and collects results."""

    def __init__(self, app_key: str = None, access_key: str = None,
                 url: str = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream"):
        self.config = AsrConfig(app_key, access_key)
        self.url = url
        self.seq = 1

    async def recognize_file(self, file_path: str) -> List[RecognitionResult]:
        """Recognize a WAV audio file.

        Raises:
            FileNotFoundError: if *file_path* does not exist.
            ValueError: if the file is not WAV format.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        results = []
        try:
            async with aiohttp.ClientSession() as session:
                # Read the audio file.
                with open(file_path, 'rb') as f:
                    content = f.read()

                if not AudioUtils.is_wav_file(content):
                    raise ValueError("Audio file must be in WAV format")

                # Extract audio parameters.
                try:
                    _, _, sample_rate, _, audio_data = AudioUtils.read_wav_info(content)
                    if sample_rate != DEFAULT_SAMPLE_RATE:
                        logger.warning(f"Sample rate {sample_rate} != {DEFAULT_SAMPLE_RATE}, may affect recognition accuracy")
                except Exception as e:
                    logger.error(f"Failed to read audio info: {e}")
                    raise

                # Segment size for 200ms chunks:
                # channel * bytes_per_sample * sample_rate * duration_ms / 1000
                segment_size = 1 * 2 * DEFAULT_SAMPLE_RATE * 200 // 1000

                # Open the WebSocket connection.
                headers = RequestBuilder.new_auth_headers(self.config)
                async with session.ws_connect(self.url, headers=headers) as ws:
                    # Send the full client request.
                    request = RequestBuilder.new_full_client_request(self.seq)
                    self.seq += 1
                    await ws.send_bytes(request)

                    # Receive the initial response.
                    msg = await ws.receive()
                    if msg.type == aiohttp.WSMsgType.BINARY:
                        response = ResponseParser.parse_response(msg.data)
                        logger.info(f"Initial response: {response.to_dict()}")

                    # Stream the audio in segments.
                    audio_segments = self._split_audio(audio_data, segment_size)
                    total_segments = len(audio_segments)

                    for i, segment in enumerate(audio_segments):
                        is_last = (i == total_segments - 1)
                        request = RequestBuilder.new_audio_only_request(
                            self.seq, segment, is_last=is_last
                        )
                        await ws.send_bytes(request)
                        logger.info(f"Sent audio segment {i+1}/{total_segments}")
                        if not is_last:
                            self.seq += 1
                        # Short delay to simulate a real-time stream.
                        await asyncio.sleep(0.1)

                    # Collect recognition results.
                    final_text = ""
                    while True:
                        msg = await ws.receive()
                        if msg.type == aiohttp.WSMsgType.BINARY:
                            response = ResponseParser.parse_response(msg.data)
                            if response.payload_msg and 'text' in response.payload_msg:
                                text = response.payload_msg['text']
                                if text:
                                    final_text += text
                                    result = RecognitionResult(
                                        text=text,
                                        confidence=0.9,  # default confidence
                                        is_final=response.is_last_package
                                    )
                                    results.append(result)
                                    logger.info(f"Recognized: {text}")
                            if response.is_last_package or response.code != 0:
                                break
                        elif msg.type == aiohttp.WSMsgType.ERROR:
                            logger.error(f"WebSocket error: {msg.data}")
                            break
                        elif msg.type == aiohttp.WSMsgType.CLOSED:
                            logger.info("WebSocket connection closed")
                            break

                    # If no result was marked final, emit one aggregate result.
                    if final_text and not any(r.is_final for r in results):
                        final_result = RecognitionResult(
                            text=final_text,
                            confidence=0.9,
                            is_final=True
                        )
                        results.append(final_result)

            return results
        except Exception as e:
            logger.error(f"Speech recognition failed: {e}")
            raise

    def _split_audio(self, data: bytes, segment_size: int) -> List[bytes]:
        """Split audio data into segments of at most *segment_size* bytes."""
        if segment_size <= 0:
            return []
        segments = []
        for i in range(0, len(data), segment_size):
            end = i + segment_size
            if end > len(data):
                end = len(data)
            segments.append(data[i:end])
        return segments

    async def recognize_latest_recording(self, directory: str = ".") -> Optional[RecognitionResult]:
        """Recognize the most recent recording_*.wav file in *directory*."""
        recording_files = [f for f in os.listdir(directory)
                           if f.startswith('recording_') and f.endswith('.wav')]
        if not recording_files:
            logger.warning("No recording files found")
            return None

        # Sort by filename (contains a timestamp), newest first.
        recording_files.sort(reverse=True)
        latest_file = recording_files[0]
        latest_path = os.path.join(directory, latest_file)

        logger.info(f"Recognizing latest recording: {latest_file}")

        try:
            results = await self.recognize_file(latest_path)
            if results:
                # Prefer the last result marked final.
                final_results = [r for r in results if r.is_final]
                if final_results:
                    return final_results[-1]
                else:
                    # Otherwise fall back to the last result.
                    return results[-1]
            # BUGFIX: made the empty-results path an explicit None return.
            return None
        except Exception as e:
            logger.error(f"Failed to recognize latest recording: {e}")
            return None


async def main():
    """CLI test entry point."""
    import argparse

    parser = argparse.ArgumentParser(description="语音识别测试")
    parser.add_argument("--file", type=str, help="音频文件路径")
    parser.add_argument("--latest", action="store_true", help="识别最新的录音文件")
    parser.add_argument("--app-key", type=str, help="SAUC App Key")
    parser.add_argument("--access-key", type=str, help="SAUC Access Key")

    args = parser.parse_args()

    recognizer = SpeechRecognizer(
        app_key=args.app_key,
        access_key=args.access_key
    )

    try:
        if args.latest:
            result = await recognizer.recognize_latest_recording()
            if result:
                print(f"识别结果: {result.text}")
                print(f"置信度: {result.confidence}")
                print(f"最终结果: {result.is_final}")
            else:
                print("未能识别到语音内容")
        elif args.file:
            results = await recognizer.recognize_file(args.file)
            for i, result in enumerate(results):
                print(f"结果 {i+1}: {result.text}")
                print(f"置信度: {result.confidence}")
                print(f"最终结果: {result.is_final}")
                print("-" * 40)
        else:
            print("请指定 --file 或 --latest 参数")
    except Exception as e:
        print(f"识别失败: {e}")


if __name__ == "__main__":
    asyncio.run(main())