#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Speech recognition module

Provides speech recognition for recorded audio files via the SAUC API.
"""

import os
import json
import time
import logging
import asyncio
import aiohttp
import struct
import gzip
import uuid
from typing import Optional, List, Dict, Any, AsyncGenerator
from dataclasses import dataclass

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Constants
DEFAULT_SAMPLE_RATE = 16000

class ProtocolVersion:
    V1 = 0b0001

class MessageType:
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ERROR_RESPONSE = 0b1111

class MessageTypeSpecificFlags:
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    NEG_SEQUENCE = 0b0010
    NEG_WITH_SEQUENCE = 0b0011

class SerializationType:
    NO_SERIALIZATION = 0b0000
    JSON = 0b0001

class CompressionType:
    GZIP = 0b0001

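# A sketch of the 4-byte binary header these constants are packed into (see
# AsrRequestHeader.to_bytes below); the bit layout is inferred from that code,
# not from official SAUC documentation:
#
#   byte 0: protocol version (high 4 bits) | header size in 4-byte units (low 4 bits)
#   byte 1: message type     (high 4 bits) | message-type-specific flags  (low 4 bits)
#   byte 2: serialization    (high 4 bits) | compression type             (low 4 bits)
#   byte 3: reserved (0x00)
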
@dataclass
class RecognitionResult:
    """A single speech recognition result."""
    text: str
    confidence: float
    is_final: bool
    start_time: Optional[float] = None
    end_time: Optional[float] = None

class AudioUtils:
    """Audio processing helpers."""

    @staticmethod
    def gzip_compress(data: bytes) -> bytes:
        """GZIP-compress the given bytes."""
        return gzip.compress(data)

    @staticmethod
    def gzip_decompress(data: bytes) -> bytes:
        """GZIP-decompress the given bytes."""
        return gzip.decompress(data)

    @staticmethod
    def is_wav_file(data: bytes) -> bool:
        """Check whether the data looks like a WAV file."""
        if len(data) < 44:
            return False
        return data[:4] == b'RIFF' and data[8:12] == b'WAVE'

    @staticmethod
    def read_wav_info(data: bytes) -> tuple:
        """Read basic information from a WAV file."""
        if len(data) < 44:
            raise ValueError("Invalid WAV file: too short")

        # Parse the RIFF/WAVE header
        chunk_id = data[:4]
        if chunk_id != b'RIFF':
            raise ValueError("Invalid WAV file: not RIFF format")

        format_ = data[8:12]
        if format_ != b'WAVE':
            raise ValueError("Invalid WAV file: not WAVE format")

        # Parse the fmt subchunk
        audio_format = struct.unpack('<H', data[20:22])[0]
        num_channels = struct.unpack('<H', data[22:24])[0]
        sample_rate = struct.unpack('<I', data[24:28])[0]
        bits_per_sample = struct.unpack('<H', data[34:36])[0]

        # Find the data subchunk
        pos = 36
        while pos < len(data) - 8:
            subchunk_id = data[pos:pos+4]
            subchunk_size = struct.unpack('<I', data[pos+4:pos+8])[0]
            if subchunk_id == b'data':
                wave_data = data[pos+8:pos+8+subchunk_size]
                return (
                    num_channels,
                    bits_per_sample // 8,
                    sample_rate,
                    subchunk_size // (num_channels * (bits_per_sample // 8)),
                    wave_data
                )
            pos += 8 + subchunk_size

        raise ValueError("Invalid WAV file: no data subchunk found")

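# Note on AudioUtils.read_wav_info above: the returned tuple is, in order,
#   (num_channels, bytes_per_sample, sample_rate, num_frames, pcm_bytes).
# As a worked example, a 1-second 16 kHz mono 16-bit recording yields
# (1, 2, 16000, 16000, <32000 bytes of PCM data>).
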
class AsrConfig:
    """ASR credential configuration."""

    def __init__(self, app_key: str = None, access_key: str = None):
        self.auth = {
            "app_key": app_key or os.getenv("SAUC_APP_KEY", "your_app_key"),
            "access_key": access_key or os.getenv("SAUC_ACCESS_KEY", "your_access_key")
        }

    @property
    def app_key(self) -> str:
        return self.auth["app_key"]

    @property
    def access_key(self) -> str:
        return self.auth["access_key"]

class AsrRequestHeader:
    """ASR request header."""

    def __init__(self):
        self.message_type = MessageType.CLIENT_FULL_REQUEST
        self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE
        self.serialization_type = SerializationType.JSON
        self.compression_type = CompressionType.GZIP
        self.reserved_data = bytes([0x00])

    def with_message_type(self, message_type: int) -> 'AsrRequestHeader':
        self.message_type = message_type
        return self

    def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader':
        self.message_type_specific_flags = flags
        return self

    def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader':
        self.serialization_type = serialization_type
        return self

    def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader':
        self.compression_type = compression_type
        return self

    def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader':
        self.reserved_data = reserved_data
        return self

    def to_bytes(self) -> bytes:
        header = bytearray()
        header.append((ProtocolVersion.V1 << 4) | 1)
        header.append((self.message_type << 4) | self.message_type_specific_flags)
        header.append((self.serialization_type << 4) | self.compression_type)
        header.extend(self.reserved_data)
        return bytes(header)

    @staticmethod
    def default_header() -> 'AsrRequestHeader':
        return AsrRequestHeader()

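# Worked example: with the defaults above (protocol V1, header size 1,
# CLIENT_FULL_REQUEST, POS_SEQUENCE, JSON, GZIP, one reserved byte),
# AsrRequestHeader.default_header().to_bytes() evaluates to b'\x11\x11\x11\x00'.
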
class RequestBuilder:
    """Request builder."""

    @staticmethod
    def new_auth_headers(config: AsrConfig) -> Dict[str, str]:
        """Build the authentication headers."""
        reqid = str(uuid.uuid4())
        return {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Request-Id": reqid,
            "X-Api-Access-Key": config.access_key,
            "X-Api-App-Key": config.app_key
        }

    @staticmethod
    def new_full_client_request(seq: int) -> bytes:
        """Build the full client request."""
        header = AsrRequestHeader.default_header() \
            .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)

        payload = {
            "user": {
                "uid": "local_voice_user"
            },
            "audio": {
                "format": "wav",
                "codec": "raw",
                "rate": 16000,
                "bits": 16,
                "channel": 1
            },
            "request": {
                "model_name": "bigmodel",
                "enable_itn": True,
                "enable_punc": True,
                "enable_ddc": True,
                "show_utterances": True,
                "enable_nonstream": False
            }
        }

        payload_bytes = json.dumps(payload).encode('utf-8')
        compressed_payload = AudioUtils.gzip_compress(payload_bytes)
        payload_size = len(compressed_payload)

        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))
        request.extend(struct.pack('>I', payload_size))
        request.extend(compressed_payload)

        return bytes(request)

    @staticmethod
    def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
        """Build an audio-only request."""
        header = AsrRequestHeader.default_header()
        if is_last:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE)
            seq = -seq
        else:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST)

        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))

        compressed_segment = AudioUtils.gzip_compress(segment)
        request.extend(struct.pack('>I', len(compressed_segment)))
        request.extend(compressed_segment)

        return bytes(request)

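# Frame layout sketch for both request builders above (derived from this code,
# not from official protocol documentation):
#
#   [4-byte header][4-byte big-endian signed sequence]
#   [4-byte big-endian unsigned payload size][GZIP-compressed payload]
#
# For the last audio segment, the sequence number is sent negated and the
# NEG_WITH_SEQUENCE flag is set, which is how the end of the stream is signalled.
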
class AsrResponse:
    """ASR response."""

    def __init__(self):
        self.code = 0
        self.event = 0
        self.is_last_package = False
        self.payload_sequence = 0
        self.payload_size = 0
        self.payload_msg = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            "code": self.code,
            "event": self.event,
            "is_last_package": self.is_last_package,
            "payload_sequence": self.payload_sequence,
            "payload_size": self.payload_size,
            "payload_msg": self.payload_msg
        }

class ResponseParser:
    """Response parser."""

    @staticmethod
    def parse_response(msg: bytes) -> AsrResponse:
        """Parse a binary server response."""
        response = AsrResponse()

        header_size = msg[0] & 0x0f
        message_type = msg[1] >> 4
        message_type_specific_flags = msg[1] & 0x0f
        serialization_method = msg[2] >> 4
        message_compression = msg[2] & 0x0f

        payload = msg[header_size*4:]

        # Parse message_type_specific_flags
        if message_type_specific_flags & 0x01:
            response.payload_sequence = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]
        if message_type_specific_flags & 0x02:
            response.is_last_package = True
        if message_type_specific_flags & 0x04:
            response.event = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]

        # Parse message_type
        if message_type == MessageType.SERVER_FULL_RESPONSE:
            response.payload_size = struct.unpack('>I', payload[:4])[0]
            payload = payload[4:]
        elif message_type == MessageType.SERVER_ERROR_RESPONSE:
            response.code = struct.unpack('>i', payload[:4])[0]
            response.payload_size = struct.unpack('>I', payload[4:8])[0]
            payload = payload[8:]

        if not payload:
            return response

        # Decompress
        if message_compression == CompressionType.GZIP:
            try:
                payload = AudioUtils.gzip_decompress(payload)
            except Exception as e:
                logger.error(f"Failed to decompress payload: {e}")
                return response

        # Parse the payload
        try:
            if serialization_method == SerializationType.JSON:
                response.payload_msg = json.loads(payload.decode('utf-8'))
        except Exception as e:
            logger.error(f"Failed to parse payload: {e}")

        return response

class SpeechRecognizer:
    """Speech recognizer."""

    def __init__(self, app_key: str = None, access_key: str = None,
                 url: str = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream"):
        self.config = AsrConfig(app_key, access_key)
        self.url = url
        self.seq = 1

    async def recognize_file(self, file_path: str) -> List[RecognitionResult]:
        """Recognize speech in an audio file."""
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        results = []

        try:
            async with aiohttp.ClientSession() as session:
                # Read the audio file
                with open(file_path, 'rb') as f:
                    content = f.read()

                if not AudioUtils.is_wav_file(content):
                    raise ValueError("Audio file must be in WAV format")

                # Read the audio info
                try:
                    _, _, sample_rate, _, audio_data = AudioUtils.read_wav_info(content)
                    if sample_rate != DEFAULT_SAMPLE_RATE:
                        logger.warning(f"Sample rate {sample_rate} != {DEFAULT_SAMPLE_RATE}, may affect recognition accuracy")
                except Exception as e:
                    logger.error(f"Failed to read audio info: {e}")
                    raise

                # Segment size for 200 ms chunks:
                # channels * bytes_per_sample * sample_rate * duration_ms / 1000 = 6400 bytes
                segment_size = 1 * 2 * DEFAULT_SAMPLE_RATE * 200 // 1000

                # Open the WebSocket connection
                headers = RequestBuilder.new_auth_headers(self.config)
                async with session.ws_connect(self.url, headers=headers) as ws:

                    # Send the full client request
                    request = RequestBuilder.new_full_client_request(self.seq)
                    self.seq += 1
                    await ws.send_bytes(request)

                    # Receive the initial response
                    msg = await ws.receive()
                    if msg.type == aiohttp.WSMsgType.BINARY:
                        response = ResponseParser.parse_response(msg.data)
                        logger.info(f"Initial response: {response.to_dict()}")

                    # Send the audio data in segments
                    audio_segments = self._split_audio(audio_data, segment_size)
                    total_segments = len(audio_segments)

                    for i, segment in enumerate(audio_segments):
                        is_last = (i == total_segments - 1)
                        request = RequestBuilder.new_audio_only_request(
                            self.seq,
                            segment,
                            is_last=is_last
                        )
                        await ws.send_bytes(request)
                        logger.info(f"Sent audio segment {i+1}/{total_segments}")

                        if not is_last:
                            self.seq += 1

                        # Small delay to simulate a real-time stream
                        await asyncio.sleep(0.1)

                    # Receive recognition results
                    final_text = ""
                    while True:
                        msg = await ws.receive()
                        if msg.type == aiohttp.WSMsgType.BINARY:
                            response = ResponseParser.parse_response(msg.data)

                            if response.payload_msg and 'text' in response.payload_msg:
                                text = response.payload_msg['text']
                                if text:
                                    final_text += text

                                    result = RecognitionResult(
                                        text=text,
                                        confidence=0.9,  # default confidence
                                        is_final=response.is_last_package
                                    )
                                    results.append(result)

                                    logger.info(f"Recognized: {text}")

                            if response.is_last_package or response.code != 0:
                                break

                        elif msg.type == aiohttp.WSMsgType.ERROR:
                            logger.error(f"WebSocket error: {msg.data}")
                            break
                        elif msg.type == aiohttp.WSMsgType.CLOSED:
                            logger.info("WebSocket connection closed")
                            break

                # If no final result was received, build one containing all the text
                if final_text and not any(r.is_final for r in results):
                    final_result = RecognitionResult(
                        text=final_text,
                        confidence=0.9,
                        is_final=True
                    )
                    results.append(final_result)

                return results

        except Exception as e:
            logger.error(f"Speech recognition failed: {e}")
            raise

    def _split_audio(self, data: bytes, segment_size: int) -> List[bytes]:
        """Split audio data into fixed-size segments."""
        if segment_size <= 0:
            return []

        segments = []
        for i in range(0, len(data), segment_size):
            end = i + segment_size
            if end > len(data):
                end = len(data)
            segments.append(data[i:end])
        return segments

    async def recognize_latest_recording(self, directory: str = ".") -> Optional[RecognitionResult]:
        """Recognize the most recent recording file."""
        # Find recording files
        recording_files = [f for f in os.listdir(directory) if f.startswith('recording_') and f.endswith('.wav')]

        if not recording_files:
            logger.warning("No recording files found")
            return None

        # Sort by file name (which contains a timestamp)
        recording_files.sort(reverse=True)
        latest_file = recording_files[0]
        latest_path = os.path.join(directory, latest_file)

        logger.info(f"Recognizing latest recording: {latest_file}")

        try:
            results = await self.recognize_file(latest_path)
            if results:
                # Return the final recognition result
                final_results = [r for r in results if r.is_final]
                if final_results:
                    return final_results[-1]
                else:
                    # If no result is marked final, return the last one
                    return results[-1]
        except Exception as e:
            logger.error(f"Failed to recognize latest recording: {e}")

        return None

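# Minimal usage sketch (an illustration, not part of the module's API surface):
# it assumes SAUC_APP_KEY / SAUC_ACCESS_KEY are set in the environment and that
# a 16 kHz, 16-bit, mono WAV file exists at the given (hypothetical) path.
#
#   recognizer = SpeechRecognizer()
#   results = asyncio.run(recognizer.recognize_file("recording_example.wav"))
#   print("".join(r.text for r in results if r.is_final))
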
async def main():
    """Test entry point."""
    import argparse

    parser = argparse.ArgumentParser(description="Speech recognition test")
    parser.add_argument("--file", type=str, help="Path to an audio file")
    parser.add_argument("--latest", action="store_true", help="Recognize the latest recording file")
    parser.add_argument("--app-key", type=str, help="SAUC App Key")
    parser.add_argument("--access-key", type=str, help="SAUC Access Key")

    args = parser.parse_args()

    recognizer = SpeechRecognizer(
        app_key=args.app_key,
        access_key=args.access_key
    )

    try:
        if args.latest:
            result = await recognizer.recognize_latest_recording()
            if result:
                print(f"Recognized text: {result.text}")
                print(f"Confidence: {result.confidence}")
                print(f"Is final: {result.is_final}")
            else:
                print("No speech content recognized")
        elif args.file:
            results = await recognizer.recognize_file(args.file)
            for i, result in enumerate(results):
                print(f"Result {i+1}: {result.text}")
                print(f"Confidence: {result.confidence}")
                print(f"Is final: {result.is_final}")
                print("-" * 40)
        else:
            print("Please specify --file or --latest")

    except Exception as e:
        print(f"Recognition failed: {e}")

if __name__ == "__main__":
    asyncio.run(main())