Local-Voice/speech_recognizer.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Speech recognition module.
Provides speech recognition for recorded audio files via the SAUC API.
"""
import os
import json
import time
import logging
import asyncio
import aiohttp
import struct
import gzip
import uuid
from typing import Optional, List, Dict, Any, AsyncGenerator
from dataclasses import dataclass

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Constants
DEFAULT_SAMPLE_RATE = 16000


class ProtocolVersion:
    V1 = 0b0001


class MessageType:
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ERROR_RESPONSE = 0b1111


class MessageTypeSpecificFlags:
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    NEG_SEQUENCE = 0b0010
    NEG_WITH_SEQUENCE = 0b0011


class SerializationType:
    NO_SERIALIZATION = 0b0000
    JSON = 0b0001


class CompressionType:
    GZIP = 0b0001
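
# Frame header layout used by this module (see AsrRequestHeader.to_bytes and
# ResponseParser.parse_response below): every message starts with 4 bytes --
#   byte 0: protocol version (high nibble) | header size in 4-byte units (low nibble)
#   byte 1: message type (high nibble)     | message-type-specific flags (low nibble)
#   byte 2: serialization type (high nibble) | compression type (low nibble)
#   byte 3: reserved
# The header is followed by big-endian sequence/size fields and the (gzip-compressed) payload.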

@dataclass
class RecognitionResult:
    """Speech recognition result."""
    text: str
    confidence: float
    is_final: bool
    start_time: Optional[float] = None
    end_time: Optional[float] = None

class AudioUtils:
    """Audio processing utilities."""

    @staticmethod
    def gzip_compress(data: bytes) -> bytes:
        """GZIP-compress data."""
        return gzip.compress(data)

    @staticmethod
    def gzip_decompress(data: bytes) -> bytes:
        """GZIP-decompress data."""
        return gzip.decompress(data)

    @staticmethod
    def is_wav_file(data: bytes) -> bool:
        """Check whether the data is a WAV file."""
        if len(data) < 44:
            return False
        return data[:4] == b'RIFF' and data[8:12] == b'WAVE'

    @staticmethod
    def read_wav_info(data: bytes) -> tuple:
        """Read WAV file information."""
        if len(data) < 44:
            raise ValueError("Invalid WAV file: too short")
        # Parse the RIFF/WAVE header
        chunk_id = data[:4]
        if chunk_id != b'RIFF':
            raise ValueError("Invalid WAV file: not RIFF format")
        format_ = data[8:12]
        if format_ != b'WAVE':
            raise ValueError("Invalid WAV file: not WAVE format")
        # Parse the fmt subchunk
        audio_format = struct.unpack('<H', data[20:22])[0]
        num_channels = struct.unpack('<H', data[22:24])[0]
        sample_rate = struct.unpack('<I', data[24:28])[0]
        bits_per_sample = struct.unpack('<H', data[34:36])[0]
        # Locate the data subchunk
        pos = 36
        while pos < len(data) - 8:
            subchunk_id = data[pos:pos+4]
            subchunk_size = struct.unpack('<I', data[pos+4:pos+8])[0]
            if subchunk_id == b'data':
                wave_data = data[pos+8:pos+8+subchunk_size]
                return (
                    num_channels,
                    bits_per_sample // 8,
                    sample_rate,
                    subchunk_size // (num_channels * (bits_per_sample // 8)),
                    wave_data
                )
            pos += 8 + subchunk_size
        raise ValueError("Invalid WAV file: no data subchunk found")

class AsrConfig:
    """ASR configuration."""

    def __init__(self, app_key: str = None, access_key: str = None):
        self.auth = {
            "app_key": app_key or os.getenv("SAUC_APP_KEY", "your_app_key"),
            "access_key": access_key or os.getenv("SAUC_ACCESS_KEY", "your_access_key")
        }

    @property
    def app_key(self) -> str:
        return self.auth["app_key"]

    @property
    def access_key(self) -> str:
        return self.auth["access_key"]

class AsrRequestHeader:
    """ASR request header."""

    def __init__(self):
        self.message_type = MessageType.CLIENT_FULL_REQUEST
        self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE
        self.serialization_type = SerializationType.JSON
        self.compression_type = CompressionType.GZIP
        self.reserved_data = bytes([0x00])

    def with_message_type(self, message_type: int) -> 'AsrRequestHeader':
        self.message_type = message_type
        return self

    def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader':
        self.message_type_specific_flags = flags
        return self

    def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader':
        self.serialization_type = serialization_type
        return self

    def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader':
        self.compression_type = compression_type
        return self

    def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader':
        self.reserved_data = reserved_data
        return self

    def to_bytes(self) -> bytes:
        header = bytearray()
        header.append((ProtocolVersion.V1 << 4) | 1)
        header.append((self.message_type << 4) | self.message_type_specific_flags)
        header.append((self.serialization_type << 4) | self.compression_type)
        header.extend(self.reserved_data)
        return bytes(header)

    @staticmethod
    def default_header() -> 'AsrRequestHeader':
        return AsrRequestHeader()

class RequestBuilder:
    """Request builder."""

    @staticmethod
    def new_auth_headers(config: AsrConfig) -> Dict[str, str]:
        """Create authentication headers."""
        reqid = str(uuid.uuid4())
        return {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Request-Id": reqid,
            "X-Api-Access-Key": config.access_key,
            "X-Api-App-Key": config.app_key
        }

    @staticmethod
    def new_full_client_request(seq: int) -> bytes:
        """Create the full client request."""
        header = AsrRequestHeader.default_header() \
            .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        payload = {
            "user": {
                "uid": "local_voice_user"
            },
            "audio": {
                "format": "wav",
                "codec": "raw",
                "rate": 16000,
                "bits": 16,
                "channel": 1
            },
            "request": {
                "model_name": "bigmodel",
                "enable_itn": True,
                "enable_punc": True,
                "enable_ddc": True,
                "show_utterances": True,
                "enable_nonstream": False
            }
        }
        payload_bytes = json.dumps(payload).encode('utf-8')
        compressed_payload = AudioUtils.gzip_compress(payload_bytes)
        payload_size = len(compressed_payload)
        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))
        request.extend(struct.pack('>I', payload_size))
        request.extend(compressed_payload)
        return bytes(request)

    @staticmethod
    def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
        """Create an audio-only request."""
        header = AsrRequestHeader.default_header()
        if is_last:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE)
            seq = -seq
        else:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST)
        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))
        compressed_segment = AudioUtils.gzip_compress(segment)
        request.extend(struct.pack('>I', len(compressed_segment)))
        request.extend(compressed_segment)
        return bytes(request)
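
# Note: each client frame produced by RequestBuilder is laid out as
#     [4-byte header][4-byte big-endian seq][4-byte big-endian payload size][gzip payload]
# and the final audio chunk is sent with the NEG_WITH_SEQUENCE flags plus a negated
# sequence number so the server can tell that the audio stream has ended.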

class AsrResponse:
    """ASR response."""

    def __init__(self):
        self.code = 0
        self.event = 0
        self.is_last_package = False
        self.payload_sequence = 0
        self.payload_size = 0
        self.payload_msg = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            "code": self.code,
            "event": self.event,
            "is_last_package": self.is_last_package,
            "payload_sequence": self.payload_sequence,
            "payload_size": self.payload_size,
            "payload_msg": self.payload_msg
        }

class ResponseParser:
    """Response parser."""

    @staticmethod
    def parse_response(msg: bytes) -> AsrResponse:
        """Parse a server response frame."""
        response = AsrResponse()
        header_size = msg[0] & 0x0f
        message_type = msg[1] >> 4
        message_type_specific_flags = msg[1] & 0x0f
        serialization_method = msg[2] >> 4
        message_compression = msg[2] & 0x0f
        payload = msg[header_size*4:]
        # Interpret message_type_specific_flags
        if message_type_specific_flags & 0x01:
            response.payload_sequence = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]
        if message_type_specific_flags & 0x02:
            response.is_last_package = True
        if message_type_specific_flags & 0x04:
            response.event = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]
        # Interpret message_type
        if message_type == MessageType.SERVER_FULL_RESPONSE:
            response.payload_size = struct.unpack('>I', payload[:4])[0]
            payload = payload[4:]
        elif message_type == MessageType.SERVER_ERROR_RESPONSE:
            response.code = struct.unpack('>i', payload[:4])[0]
            response.payload_size = struct.unpack('>I', payload[4:8])[0]
            payload = payload[8:]
        if not payload:
            return response
        # Decompress the payload
        if message_compression == CompressionType.GZIP:
            try:
                payload = AudioUtils.gzip_decompress(payload)
            except Exception as e:
                logger.error(f"Failed to decompress payload: {e}")
                return response
        # Deserialize the payload
        try:
            if serialization_method == SerializationType.JSON:
                response.payload_msg = json.loads(payload.decode('utf-8'))
        except Exception as e:
            logger.error(f"Failed to parse payload: {e}")
        return response

class SpeechRecognizer:
    """Speech recognizer."""

    def __init__(self, app_key: str = None, access_key: str = None,
                 url: str = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream"):
        self.config = AsrConfig(app_key, access_key)
        self.url = url
        self.seq = 1

    async def recognize_file(self, file_path: str) -> List[RecognitionResult]:
        """Recognize speech in an audio file."""
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")
        results = []
        try:
            async with aiohttp.ClientSession() as session:
                # Read the audio file
                with open(file_path, 'rb') as f:
                    content = f.read()
                if not AudioUtils.is_wav_file(content):
                    raise ValueError("Audio file must be in WAV format")
                # Read the audio metadata
                try:
                    _, _, sample_rate, _, audio_data = AudioUtils.read_wav_info(content)
                    if sample_rate != DEFAULT_SAMPLE_RATE:
                        logger.warning(f"Sample rate {sample_rate} != {DEFAULT_SAMPLE_RATE}, may affect recognition accuracy")
                except Exception as e:
                    logger.error(f"Failed to read audio info: {e}")
                    raise
                # Compute the segment size (200 ms per segment)
                segment_size = 1 * 2 * DEFAULT_SAMPLE_RATE * 200 // 1000  # channels * bytes_per_sample * sample_rate * duration_ms / 1000
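                # With the defaults above this works out to 1 * 2 * 16000 * 200 / 1000
                # = 6400 bytes per segment, i.e. 200 ms of 16 kHz mono 16-bit PCM.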
                # Open the WebSocket connection
                headers = RequestBuilder.new_auth_headers(self.config)
                async with session.ws_connect(self.url, headers=headers) as ws:
                    # Send the full client request
                    request = RequestBuilder.new_full_client_request(self.seq)
                    self.seq += 1
                    await ws.send_bytes(request)
                    # Receive the initial response
                    msg = await ws.receive()
                    if msg.type == aiohttp.WSMsgType.BINARY:
                        response = ResponseParser.parse_response(msg.data)
                        logger.info(f"Initial response: {response.to_dict()}")
                    # Send the audio data in segments
                    audio_segments = self._split_audio(audio_data, segment_size)
                    total_segments = len(audio_segments)
                    for i, segment in enumerate(audio_segments):
                        is_last = (i == total_segments - 1)
                        request = RequestBuilder.new_audio_only_request(
                            self.seq,
                            segment,
                            is_last=is_last
                        )
                        await ws.send_bytes(request)
                        logger.info(f"Sent audio segment {i+1}/{total_segments}")
                        if not is_last:
                            self.seq += 1
                        # Short delay to simulate a real-time stream
                        await asyncio.sleep(0.1)
                    # Receive recognition results
                    final_text = ""
                    while True:
                        msg = await ws.receive()
                        if msg.type == aiohttp.WSMsgType.BINARY:
                            response = ResponseParser.parse_response(msg.data)
                            if response.payload_msg and 'text' in response.payload_msg:
                                text = response.payload_msg['text']
                                if text:
                                    final_text += text
                                    result = RecognitionResult(
                                        text=text,
                                        confidence=0.9,  # default confidence
                                        is_final=response.is_last_package
                                    )
                                    results.append(result)
                                    logger.info(f"Recognized: {text}")
                            if response.is_last_package or response.code != 0:
                                break
                        elif msg.type == aiohttp.WSMsgType.ERROR:
                            logger.error(f"WebSocket error: {msg.data}")
                            break
                        elif msg.type == aiohttp.WSMsgType.CLOSED:
                            logger.info("WebSocket connection closed")
                            break
            # If no final result was received, create one containing all recognized text
            if final_text and not any(r.is_final for r in results):
                final_result = RecognitionResult(
                    text=final_text,
                    confidence=0.9,
                    is_final=True
                )
                results.append(final_result)
            return results
        except Exception as e:
            logger.error(f"Speech recognition failed: {e}")
            raise

    def _split_audio(self, data: bytes, segment_size: int) -> List[bytes]:
        """Split audio data into fixed-size segments."""
        if segment_size <= 0:
            return []
        segments = []
        for i in range(0, len(data), segment_size):
            end = i + segment_size
            if end > len(data):
                end = len(data)
            segments.append(data[i:end])
        return segments

    async def recognize_latest_recording(self, directory: str = ".") -> Optional[RecognitionResult]:
        """Recognize the most recent recording file."""
        # Find the most recent recording file
        recording_files = [f for f in os.listdir(directory) if f.startswith('recording_') and f.endswith('.wav')]
        if not recording_files:
            logger.warning("No recording files found")
            return None
        # Sort by file name (which contains a timestamp)
        recording_files.sort(reverse=True)
        latest_file = recording_files[0]
        latest_path = os.path.join(directory, latest_file)
        logger.info(f"Recognizing latest recording: {latest_file}")
        try:
            results = await self.recognize_file(latest_path)
            if results:
                # Return the final recognition result
                final_results = [r for r in results if r.is_final]
                if final_results:
                    return final_results[-1]
                else:
                    # If no result is marked as final, return the last one
                    return results[-1]
        except Exception as e:
            logger.error(f"Failed to recognize latest recording: {e}")
            return None

async def main():
    """Test entry point."""
    import argparse
    parser = argparse.ArgumentParser(description="Speech recognition test")
    parser.add_argument("--file", type=str, help="Path to an audio file")
    parser.add_argument("--latest", action="store_true", help="Recognize the most recent recording file")
    parser.add_argument("--app-key", type=str, help="SAUC App Key")
    parser.add_argument("--access-key", type=str, help="SAUC Access Key")
    args = parser.parse_args()
    recognizer = SpeechRecognizer(
        app_key=args.app_key,
        access_key=args.access_key
    )
    try:
        if args.latest:
            result = await recognizer.recognize_latest_recording()
            if result:
                print(f"Recognition result: {result.text}")
                print(f"Confidence: {result.confidence}")
                print(f"Final: {result.is_final}")
            else:
                print("No speech content recognized")
        elif args.file:
            results = await recognizer.recognize_file(args.file)
            for i, result in enumerate(results):
                print(f"Result {i+1}: {result.text}")
                print(f"Confidence: {result.confidence}")
                print(f"Final: {result.is_final}")
                print("-" * 40)
        else:
            print("Please specify --file or --latest")
    except Exception as e:
        print(f"Recognition failed: {e}")


if __name__ == "__main__":
    asyncio.run(main())