287 lines
10 KiB
Python
287 lines
10 KiB
Python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
语音识别诊断工具

用于测试和诊断语音识别功能的具体问题

(Speech-recognition diagnostic tool: probes the ASR service step by step to
pinpoint where recognition fails.)
"""
import asyncio
import gzip
import json
import os
import uuid
import wave
from typing import Optional

import numpy as np
class ASRDiagnostic:
    """ASR diagnostic tool.

    Streams synthetic audio (silence, random noise) and local
    ``recording_*.wav`` files to the openspeech v2 ASR WebSocket endpoint and
    prints every protocol step, so connection, authentication, request and
    recognition failures can be told apart.
    """

    # --- Binary protocol constants ---------------------------------------
    # Every frame starts with a 4-byte header made of 4-bit fields:
    #   byte 0: protocol version | header size (in 4-byte words)
    #   byte 1: message type     | message-type-specific flags
    #   byte 2: serialization    | compression
    #   byte 3: reserved
    PROTOCOL_VERSION = 0b0001
    DEFAULT_HEADER_SIZE = 0b0001

    # Message types.
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ACK = 0b1011
    SERVER_ERROR_RESPONSE = 0b1111

    # Message-type-specific flag that marks the final audio packet.
    NEG_SEQUENCE = 0b0010

    # Serialization / compression methods.
    NO_SERIALIZATION = 0b0000
    JSON = 0b0001
    GZIP = 0b0001

    # Payload 'code' the service documents for success.
    SUCCESS_CODE = 1000

    def __init__(self):
        # NOTE(review): credentials are committed to source. Prefer the
        # VOLC_ASR_* environment variables, and rotate the embedded token.
        self.api_config = {
            'asr': {
                'appid': os.environ.get('VOLC_ASR_APPID', "8718217928"),
                # NOTE(review): "volcano_tts" looks like a TTS cluster name; the
                # ASR endpoint probably needs an ASR cluster -- confirm in console.
                'cluster': os.environ.get('VOLC_ASR_CLUSTER', "volcano_tts"),
                'token': os.environ.get('VOLC_ASR_TOKEN', "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"),
                'ws_url': os.environ.get('VOLC_ASR_WS_URL', "wss://openspeech.bytedance.com/api/v2/asr"),
            }
        }

    def generate_asr_header(self, message_type=1, message_type_specific_flags=0):
        """Build the 4-byte protocol header (JSON serialization, gzip compression).

        Args:
            message_type: 4-bit message type (1 = full client request,
                2 = audio-only request).
            message_type_specific_flags: 4-bit flags (0b0010 marks the last
                audio packet).

        Returns:
            A 4-byte ``bytearray``.
        """
        header = bytearray()
        header.append((self.PROTOCOL_VERSION << 4) | self.DEFAULT_HEADER_SIZE)
        header.append((message_type << 4) | message_type_specific_flags)
        header.append((self.JSON << 4) | self.GZIP)
        header.append(0x00)  # reserved
        return header

    def _pack_message(self, header, payload):
        """Frame *payload* as: header + 4-byte big-endian length + gzip(payload)."""
        body = gzip.compress(payload)
        frame = bytearray(header)
        frame.extend(len(body).to_bytes(4, 'big'))
        frame.extend(body)
        return frame

    def parse_asr_response(self, res):
        """Parse one binary server frame.

        Returns:
            A dict with 'message_type' plus, when present: 'seq' (ACK frames),
            'code' (error frames), 'payload_size', and 'payload_msg'. The
            'payload_msg' value is already gunzipped and JSON-decoded. Returns
            {} when the frame is too short or cannot be decoded.
        """
        print(f"🔍 解析响应,原始大小: {len(res)} 字节")

        if not isinstance(res, (bytes, bytearray)):
            # A text frame is not part of this binary protocol.
            print(f"❌ 响应不是二进制帧: {res!r}")
            return {}

        if len(res) < 8:
            print(f"❌ 响应太短,无法解析")
            return {}

        try:
            header_len = (res[0] & 0x0F) * 4  # header size is counted in 4-byte words
            message_type = res[1] >> 4
            serialization = res[2] >> 4
            compression = res[2] & 0x0F
            body = res[header_len:]

            result = {'message_type': message_type}
            payload_msg = None
            payload_size = 0

            if message_type == self.SERVER_FULL_RESPONSE:
                payload_size = int.from_bytes(body[:4], "big", signed=False)
                payload_msg = body[4:4 + payload_size]
            elif message_type == self.SERVER_ACK:
                result['seq'] = int.from_bytes(body[:4], "big", signed=True)
                if len(body) >= 8:
                    payload_size = int.from_bytes(body[4:8], "big", signed=False)
                    payload_msg = body[8:8 + payload_size]
            elif message_type == self.SERVER_ERROR_RESPONSE:
                result['code'] = int.from_bytes(body[:4], "big", signed=False)
                payload_size = int.from_bytes(body[4:8], "big", signed=False)
                payload_msg = body[8:8 + payload_size]

            print(f"📋 消息类型: {message_type}, 载荷大小: {payload_size}")

            if not payload_msg:
                return result

            if compression == self.GZIP:
                payload_msg = gzip.decompress(payload_msg)
            if serialization == self.JSON:
                payload_msg = json.loads(payload_msg.decode('utf-8'))
                print(f"✅ 成功解析JSON响应")
            elif serialization != self.NO_SERIALIZATION:
                payload_msg = payload_msg.decode('utf-8', errors='replace')

            result['payload_msg'] = payload_msg
            result['payload_size'] = payload_size

            if message_type == self.SERVER_ERROR_RESPONSE:
                print(f"❌ 服务端错误响应: code={result['code']}, {payload_msg}")
            return result

        except Exception as e:
            # Best-effort parser for a diagnostic tool: report and carry on.
            print(f"❌ 响应解析异常: {e}")
            return {}

    def _is_success(self, result):
        """Return False for error frames or payloads whose 'code' is not success.

        An empty or code-less result counts as success, so parsing gaps never
        abort the run on their own.
        """
        if 'code' in result:  # only error frames carry a frame-level code
            return False
        payload = result.get('payload_msg')
        if isinstance(payload, dict) and 'code' in payload:
            return payload['code'] == self.SUCCESS_CODE
        return True

    async def test_asr_with_audio_file(self, audio_file_path: str):
        """Recognise a local WAV file.

        Only 16-bit PCM is accepted. Multi-channel audio is down-mixed to mono
        by averaging channels. The file's own sample rate is sent to the server
        instead of assuming 16 kHz.

        Returns:
            The recognised text, or None on failure.
        """
        print(f"🎵 测试ASR - 音频文件: {audio_file_path}")

        if not os.path.exists(audio_file_path):
            print(f"❌ 音频文件不存在: {audio_file_path}")
            return

        try:
            with wave.open(audio_file_path, 'rb') as wf:
                channels = wf.getnchannels()
                width = wf.getsampwidth()
                rate = wf.getframerate()
                frames = wf.readframes(wf.getnframes())
        except (wave.Error, EOFError, OSError) as e:
            print(f"❌ 音频文件处理失败: {e}")
            return None

        print(f"📊 音频信息: 采样率={rate}Hz, 声道={channels}, 位深={width*8}bits")
        print(f"📊 音频大小: {len(frames)} 字节")

        if width != 2:
            # The request declares 16-bit samples, and the down-mix below reads int16.
            print(f"❌ 仅支持16bit PCM音频 (当前位深={width*8}bits)")
            return None

        if channels > 1:
            frame_bytes = channels * width
            usable = len(frames) - len(frames) % frame_bytes  # drop a truncated tail frame
            # WAV PCM samples are little-endian signed 16-bit.
            samples = np.frombuffer(frames[:usable], dtype='<i2').reshape(-1, channels)
            frames = samples.mean(axis=1).astype('<i2').tobytes()
            print(f"🔄 已转换为单声道")

        return await self._test_asr_connection(frames, sample_rate=rate)

    async def test_asr_with_silence(self):
        """Send 3 seconds of digital silence (16 kHz, 16-bit, mono)."""
        print(f"🔇 测试ASR - 静音音频")

        duration = 3  # seconds
        sample_rate = 16000
        silence_data = bytes(duration * sample_rate * 2)  # 2 bytes per sample

        return await self._test_asr_connection(silence_data, sample_rate=sample_rate)

    async def test_asr_with_noise(self):
        """Send 3 seconds of uniform random 16-bit noise (16 kHz, mono)."""
        print(f"📢 测试ASR - 噪音音频")

        duration = 3  # seconds
        sample_rate = 16000
        # randint's upper bound is exclusive; 32768 makes 32767 reachable.
        noise = np.random.randint(-32768, 32768, duration * sample_rate, dtype=np.int16)

        return await self._test_asr_connection(noise.tobytes(), sample_rate=sample_rate)

    def _build_request_params(self, reqid, sample_rate):
        """Return the JSON body of the initial full-client (config) request."""
        asr = self.api_config['asr']
        return {
            'app': {
                'appid': asr['appid'],
                'cluster': asr['cluster'],
                'token': asr['token'],
            },
            'user': {
                'uid': 'asr_diagnostic'
            },
            'request': {
                'reqid': reqid,
                'nbest': 1,
                'workflow': 'audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate',
                'show_language': False,
                'show_utterances': False,
                'result_type': 'full',
                "sequence": 1,
            },
            'audio': {
                # Every caller sends headerless PCM (wave.readframes or a
                # synthetic buffer), so the container is 'raw', not 'wav'.
                'format': 'raw',
                'rate': sample_rate,
                'language': 'zh-CN',
                'bits': 16,
                'channel': 1,
                'codec': 'raw',
            },
        }

    async def _test_asr_connection(self, audio_data: bytes, sample_rate: int = 16000,
                                   segment_duration_ms: int = 15000) -> Optional[str]:
        """Stream *audio_data* to the ASR service and print the outcome.

        Args:
            audio_data: Headerless 16-bit mono PCM.
            sample_rate: Sample rate of *audio_data* in Hz. It is reported to
                the server, which has 'resample' in its workflow.
            segment_duration_ms: Audio duration carried by each packet. The
                default of 15000 ms matches the original chunk size.

        Returns:
            The text of the best hypothesis, or None on any failure.
        """
        try:
            # Third-party; imported lazily so the rest of the module loads without it.
            import websockets

            reqid = str(uuid.uuid4())
            request_params = self._build_request_params(reqid, sample_rate)
            token = request_params['app']['token']

            print(f"📋 ASR请求参数:")
            print(f" - AppID: {request_params['app']['appid']}")
            print(f" - Cluster: {request_params['app']['cluster']}")
            # Print only a short prefix so the log does not expose the credential.
            print(f" - Token: {token[:4]}...")
            print(f" - RequestID: {reqid}")

            full_client_request = self._pack_message(
                self.generate_asr_header(message_type=self.CLIENT_FULL_REQUEST),
                json.dumps(request_params).encode('utf-8'),
            )
            additional_headers = {'Authorization': 'Bearer; {}'.format(token)}

            print(f"📡 连接WebSocket...")

            # NOTE(review): `additional_headers` is the keyword used by the
            # websockets >= 14 client; legacy clients call it `extra_headers`.
            async with websockets.connect(
                self.api_config['asr']['ws_url'],
                additional_headers=additional_headers,
                max_size=1000000000,
            ) as ws:
                print(f"✅ WebSocket连接成功")

                print(f"📤 发送ASR配置请求...")
                await ws.send(full_client_request)
                result = self.parse_asr_response(await ws.recv())
                print(f"📥 配置响应: {result}")
                if not self._is_success(result):
                    print(f"❌ ASR配置请求失败,停止发送音频")
                    return None

                # 16-bit mono PCM: 2 bytes per sample. max() guards a zero step.
                chunk_size = max(1, 2 * sample_rate * segment_duration_ms // 1000)
                total_chunks = 0

                # max(len, 1): empty audio still yields one empty final packet,
                # so the server always receives the end-of-stream flag.
                for offset in range(0, max(len(audio_data), 1), chunk_size):
                    chunk = audio_data[offset:offset + chunk_size]
                    last = offset + chunk_size >= len(audio_data)

                    header = self.generate_asr_header(
                        message_type=self.CLIENT_AUDIO_ONLY_REQUEST,
                        message_type_specific_flags=self.NEG_SEQUENCE if last else 0,
                    )
                    await ws.send(self._pack_message(header, chunk))
                    result = self.parse_asr_response(await ws.recv())
                    total_chunks += 1

                    if not self._is_success(result):
                        print(f"❌ 第{total_chunks}块音频识别失败: {result}")
                        return None
                    if last:
                        print(f"📨 发送最后一块音频数据 (总计{total_chunks}块)")

                # The response to the last packet carries the final result.
                print(f"🎯 等待最终识别结果...")
                payload = result.get('payload_msg')
                if isinstance(payload, dict) and 'result' in payload:
                    results = payload['result']
                    print(f"📝 ASR返回结果数量: {len(results)}")
                    if results:
                        text = results[0].get('text', '识别失败')
                        print(f"✅ 识别结果: {text}")
                        return text
                    print(f"❌ ASR结果为空")
                else:
                    print(f"❌ ASR响应格式异常: {result.keys()}")
                    print(f"📋 完整响应: {result}")

                return None

        except Exception as e:
            # Diagnostic boundary: report everything, including the traceback.
            print(f"❌ ASR连接异常: {e}")
            import traceback
            print(f"❌ 详细错误:\n{traceback.format_exc()}")
            return None

    async def run_diagnostic(self):
        """Run all checks: silence, noise, then up to three local recordings."""
        print("🔧 ASR诊断工具")
        print("=" * 50)

        # 1. Silence
        print("\n1️⃣ 测试静音识别...")
        await self.test_asr_with_silence()

        # 2. Noise
        print("\n2️⃣ 测试噪音识别...")
        await self.test_asr_with_noise()

        # 3. Recording files in the working directory, if any. Sorted so that
        # runs are reproducible; os.listdir order is arbitrary.
        recording_files = sorted(
            f for f in os.listdir('.') if f.startswith('recording_') and f.endswith('.wav')
        )
        if recording_files:
            print(f"\n3️⃣ 测试录音文件...")
            for file in recording_files[:3]:  # test at most 3 files
                await self.test_asr_with_audio_file(file)
        else:
            print(f"\n3️⃣ 跳过录音文件测试 (无录音文件)")

        print(f"\n✅ 诊断完成")
def main():
    """Run the full ASR diagnostic and return a process exit status.

    Returns:
        0 when the diagnostic completes, 130 when interrupted with Ctrl+C
        (the usual 128 + SIGINT convention), and 1 when the tool itself
        raises an unexpected exception.
    """
    diagnostic = ASRDiagnostic()

    try:
        asyncio.run(diagnostic.run_diagnostic())
    except KeyboardInterrupt:
        print(f"\n🛑 诊断被用户中断")
        return 130
    except Exception as e:
        # Top-level boundary: report, and signal failure through the exit code.
        print(f"❌ 诊断工具异常: {e}")
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())