Local-Voice/asr_diagnostic.py
2025-09-21 03:00:11 +08:00

287 lines
10 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
语音识别诊断工具
用于测试和诊断语音识别功能的具体问题
"""
import asyncio
import json
import gzip
import uuid
import numpy as np
import wave
import os
from typing import Optional
class ASRDiagnostic:
    """Diagnostic client for the streaming ASR (speech recognition) WebSocket API.

    Sends synthetic audio (silence, random noise) or recorded WAV files to the
    configured endpoint using the binary frame layout produced by
    ``generate_asr_header`` and prints every step so failures can be located.
    """

    # Message type (high nibble of header byte 1) that carries a full server
    # response; it is the only type ``parse_asr_response`` decodes.
    _SERVER_FULL_RESPONSE = 0b1001
    # Compression id (low nibble of header byte 2), same value the client
    # writes in ``generate_asr_header``.
    _GZIP = 0b0001
    # Every gzip stream starts with these two bytes.
    _GZIP_MAGIC = b'\x1f\x8b'

    def __init__(self):
        """Build ``self.api_config``.

        Each value may be overridden with an environment variable
        (``ASR_APPID``, ``ASR_CLUSTER``, ``ASR_TOKEN``, ``ASR_WS_URL``); the
        literals remain the defaults so existing invocations keep working.
        """
        # NOTE(review): credentials are committed in source. Prefer the
        # environment overrides and consider rotating this token.
        env = os.environ
        self.api_config = {
            'asr': {
                'appid': env.get('ASR_APPID', "8718217928"),
                'cluster': env.get('ASR_CLUSTER', "volcano_tts"),
                'token': env.get('ASR_TOKEN', "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"),
                'ws_url': env.get('ASR_WS_URL', "wss://openspeech.bytedance.com/api/v2/asr"),
            }
        }

    def generate_asr_header(self, message_type=1, message_type_specific_flags=0):
        """Return the 4-byte frame header as a ``bytearray``.

        Layout (one nibble per field):
          byte 0: protocol version | header size
          byte 1: ``message_type`` | ``message_type_specific_flags``
          byte 2: serialization (JSON) | compression (GZIP)
          byte 3: reserved, always 0
        """
        PROTOCOL_VERSION = 0b0001
        DEFAULT_HEADER_SIZE = 0b0001
        JSON = 0b0001
        GZIP = 0b0001
        return bytearray((
            (PROTOCOL_VERSION << 4) | DEFAULT_HEADER_SIZE,
            (message_type << 4) | message_type_specific_flags,
            (JSON << 4) | GZIP,
            0x00,  # reserved
        ))

    def parse_asr_response(self, res):
        """Decode a server frame and return its JSON payload as a dict.

        Only full server responses (message type ``0b1001``) are decoded.
        The payload is gunzipped first when the header's compression nibble
        says GZIP and the payload really starts with the gzip magic bytes.
        This mirrors the compression the client itself declares in
        ``generate_asr_header``.

        Returns:
            The decoded JSON object, or ``{}`` when the frame is too short,
            is of another message type, is not JSON, or cannot be parsed.
        """
        print(f"🔍 解析响应,原始大小: {len(res)} 字节")
        if len(res) < 8:
            print("❌ 响应太短,无法解析")
            return {}
        try:
            # The header-size nibble is assumed to count 4-byte words (the
            # client writes 1 for its 4-byte header). Fall back to 4 bytes
            # if the field is zero.
            header_words = res[0] & 0x0F
            header_len = 4 * header_words if header_words else 4
            message_type = res[1] >> 4
            compression = res[2] & 0x0F
            body_start = header_len + 4
            payload_size = int.from_bytes(res[header_len:body_start], "big", signed=False)
            payload_msg = res[body_start:body_start + payload_size]
            print(f"📋 消息类型: {message_type}, 载荷大小: {payload_size}")
            if message_type == self._SERVER_FULL_RESPONSE:
                try:
                    if compression == self._GZIP and payload_msg[:2] == self._GZIP_MAGIC:
                        payload_msg = gzip.decompress(payload_msg)
                    if payload_msg.startswith(b'{'):
                        result = json.loads(payload_msg.decode('utf-8'))
                        print("✅ 成功解析JSON响应")
                        return result
                    print("❌ 响应不是JSON格式")
                except Exception as e:
                    print(f"❌ JSON解析失败: {e}")
        except Exception as e:
            print(f"❌ 响应解析异常: {e}")
        return {}

    async def test_asr_with_audio_file(self, audio_file_path: str):
        """Send the PCM frames of a WAV file to the ASR service.

        Multi-channel audio is averaged down to mono before sending.

        Returns:
            The recognized text, or ``None`` if the file is missing, cannot
            be read, or recognition fails.
        """
        print(f"🎵 测试ASR - 音频文件: {audio_file_path}")
        if not os.path.exists(audio_file_path):
            print(f"❌ 音频文件不存在: {audio_file_path}")
            return
        try:
            with wave.open(audio_file_path, 'rb') as wf:
                channels = wf.getnchannels()
                width = wf.getsampwidth()
                rate = wf.getframerate()
                frames = wf.readframes(wf.getnframes())
            print(f"📊 音频信息: 采样率={rate}Hz, 声道={channels}, 位深={width*8}bits")
            print(f"📊 音频大小: {len(frames)} 字节")
            if channels > 1:
                # NOTE(review): the down-mix interprets samples as int16
                # regardless of ``width``, and the request always declares
                # 16 kHz / 16 bit; other formats are sent as-is.
                audio_array = np.frombuffer(frames, dtype=np.int16)
                audio_array = audio_array.reshape(-1, channels)
                audio_array = np.mean(audio_array, axis=1).astype(np.int16)
                frames = audio_array.tobytes()
                print("🔄 已转换为单声道")
            return await self._test_asr_connection(frames)
        except Exception as e:
            print(f"❌ 音频文件处理失败: {e}")
            return None

    async def test_asr_with_silence(self):
        """Send 3 seconds of digital silence (16 kHz, 16-bit, mono)."""
        print("🔇 测试ASR - 静音音频")
        duration = 3  # seconds
        sample_rate = 16000
        silence_data = bytes(duration * sample_rate * 2)  # 2 bytes per sample
        return await self._test_asr_connection(silence_data)

    async def test_asr_with_noise(self):
        """Send 3 seconds of uniformly random int16 noise at 16 kHz."""
        print("📢 测试ASR - 噪音音频")
        duration = 3  # seconds
        sample_rate = 16000
        noise_data = np.random.randint(-32768, 32767, duration * sample_rate, dtype=np.int16)
        return await self._test_asr_connection(noise_data.tobytes())

    @staticmethod
    def _frame(header, payload: bytes) -> bytearray:
        """Return ``header`` + 4-byte big-endian payload length + ``payload``."""
        frame = bytearray(header)
        frame.extend(len(payload).to_bytes(4, 'big'))
        frame.extend(payload)
        return frame

    @staticmethod
    def _extract_results(result):
        """Return the recognition candidate list from a parsed response.

        Accepts both the object returned by ``parse_asr_response`` (the JSON
        payload itself) and a wrapper of the form ``{'payload_msg': {...}}``.
        Returns ``None`` when no ``'result'`` entry is present.
        """
        if not isinstance(result, dict):
            return None
        payload = result.get('payload_msg', result)
        if isinstance(payload, dict) and 'result' in payload:
            return payload['result']
        return None

    async def _test_asr_connection(self, audio_data: bytes):
        """Run one recognition session over WebSocket.

        Sends the gzip-compressed JSON configuration frame, then the audio in
        gzip-compressed chunks (the last chunk is flagged), reading one
        response after every frame.

        Returns:
            The text of the first recognition candidate, or ``None`` when
            the response carries no result or any error occurs.
        """
        try:
            import websockets

            asr_cfg = self.api_config['asr']
            reqid = str(uuid.uuid4())
            request_params = {
                'app': {
                    'appid': asr_cfg['appid'],
                    'cluster': asr_cfg['cluster'],
                    'token': asr_cfg['token'],
                },
                'user': {
                    'uid': 'asr_diagnostic'
                },
                'request': {
                    'reqid': reqid,
                    'nbest': 1,
                    'workflow': 'audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate',
                    'show_language': False,
                    'show_utterances': False,
                    'result_type': 'full',
                    "sequence": 1
                },
                'audio': {
                    'format': 'wav',
                    'rate': 16000,
                    'language': 'zh-CN',
                    'bits': 16,
                    'channel': 1,
                    'codec': 'raw'
                }
            }
            print("📋 ASR请求参数:")
            print(f" - AppID: {request_params['app']['appid']}")
            print(f" - Cluster: {request_params['app']['cluster']}")
            print(f" - Token: {request_params['app']['token'][:20]}...")
            print(f" - RequestID: {reqid}")

            # Configuration frame: gzip-compressed JSON behind the default header.
            payload_bytes = gzip.compress(str.encode(json.dumps(request_params)))
            full_client_request = self._frame(self.generate_asr_header(), payload_bytes)

            additional_headers = {'Authorization': 'Bearer; {}'.format(asr_cfg['token'])}
            print("📡 连接WebSocket...")
            async with websockets.connect(
                asr_cfg['ws_url'],
                additional_headers=additional_headers,
                max_size=1000000000
            ) as ws:
                print("✅ WebSocket连接成功")
                print("📤 发送ASR配置请求...")
                await ws.send(full_client_request)
                res = await ws.recv()
                result = self.parse_asr_response(res)
                print(f"📥 配置响应: {result}")

                # channels * bytes-per-sample * sample-rate * milliseconds / 1000:
                # 15-second chunks of 16 kHz 16-bit mono audio (480000 bytes).
                chunk_size = int(1 * 2 * 16000 * 15000 / 1000)
                total_chunks = 0
                for offset in range(0, len(audio_data), chunk_size):
                    chunk = audio_data[offset:offset + chunk_size]
                    last = (offset + chunk_size) >= len(audio_data)
                    audio_only_request = self._frame(
                        self.generate_asr_header(
                            message_type=0b0010,
                            # Flag 0b0010 marks the final audio frame.
                            message_type_specific_flags=0b0010 if last else 0
                        ),
                        gzip.compress(chunk),
                    )
                    await ws.send(audio_only_request)
                    res = await ws.recv()
                    result = self.parse_asr_response(res)
                    total_chunks += 1
                    if last:
                        print(f"📨 发送最后一块音频数据 (总计{total_chunks}块)")

                print("🎯 等待最终识别结果...")
                # ``result`` is the response to the last frame sent.
                results = self._extract_results(result)
                if results is not None:
                    print(f"📝 ASR返回结果数量: {len(results)}")
                    if results:
                        text = results[0].get('text', '识别失败')
                        print(f"✅ 识别结果: {text}")
                        return text
                    print("❌ ASR结果为空")
                else:
                    print(f"❌ ASR响应格式异常: {result.keys()}")
                    print(f"📋 完整响应: {result}")
                return None
        except Exception as e:
            print(f"❌ ASR连接异常: {e}")
            import traceback
            print(f"❌ 详细错误:\n{traceback.format_exc()}")
            return None

    async def run_diagnostic(self):
        """Run the silence test, the noise test, then up to three recordings.

        Recordings are files named ``recording_*.wav`` in the current working
        directory, taken in sorted order so repeated runs pick the same files.
        """
        print("🔧 ASR诊断工具")
        print("=" * 50)

        print("\n1⃣ 测试静音识别...")
        await self.test_asr_with_silence()

        print("\n2⃣ 测试噪音识别...")
        await self.test_asr_with_noise()

        recording_files = sorted(
            name for name in os.listdir('.')
            if name.startswith('recording_') and name.endswith('.wav')
        )
        if recording_files:
            print("\n3⃣ 测试录音文件...")
            for name in recording_files[:3]:  # test at most 3 files
                await self.test_asr_with_audio_file(name)
        else:
            print("\n3⃣ 跳过录音文件测试 (无录音文件)")
        print("\n✅ 诊断完成")
def main():
    """Script entry point: run the full diagnostic and report fatal errors.

    Ctrl-C and any other exception raised while the diagnostic runs are
    reported on stdout instead of propagating.
    """
    tool = ASRDiagnostic()
    try:
        asyncio.run(tool.run_diagnostic())
    except KeyboardInterrupt:
        print("\n🛑 诊断被用户中断")
    except Exception as exc:
        print(f"❌ 诊断工具异常: {exc}")


# Run the diagnostic only when executed as a script, not on import.
if __name__ == "__main__":
    main()