This commit is contained in:
朱潮 2025-09-20 10:53:56 +08:00
parent ef39e31a4b
commit eb099d827d
7 changed files with 1811 additions and 33 deletions

501
enhanced_wake_and_record.py Normal file
View File

@ -0,0 +1,501 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
集成语音识别的唤醒+录音系统
基于 simple_wake_and_record.py添加语音识别功能
"""
import sys
import os
import time
import threading
import pyaudio
import json
import asyncio
from typing import Optional, List
# 添加当前目录到路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
try:
from vosk import Model, KaldiRecognizer
VOSK_AVAILABLE = True
except ImportError:
VOSK_AVAILABLE = False
print("⚠️ Vosk 未安装,请运行: pip install vosk")
from speech_recognizer import SpeechRecognizer, RecognitionResult
class EnhancedWakeAndRecord:
"""增强的唤醒+录音系统,集成语音识别"""
def __init__(self, model_path="model", wake_words=["你好", "助手"],
enable_speech_recognition=True, app_key=None, access_key=None):
self.model_path = model_path
self.wake_words = wake_words
self.enable_speech_recognition = enable_speech_recognition
self.model = None
self.recognizer = None
self.audio = None
self.stream = None
self.running = False
# 音频参数
self.FORMAT = pyaudio.paInt16
self.CHANNELS = 1
self.RATE = 16000
self.CHUNK_SIZE = 1024
# 录音相关
self.recording = False
self.recorded_frames = []
self.last_text_time = None
self.recording_start_time = None
self.recording_recognizer = None
# 阈值
self.text_silence_threshold = 3.0
self.min_recording_time = 2.0
self.max_recording_time = 30.0
# 语音识别相关
self.speech_recognizer = None
self.last_recognition_result = None
self.recognition_thread = None
# 回调函数
self.on_recognition_result = None
self._setup_model()
self._setup_audio()
self._setup_speech_recognition(app_key, access_key)
def _setup_model(self):
"""设置 Vosk 模型"""
if not VOSK_AVAILABLE:
return
try:
if not os.path.exists(self.model_path):
print(f"模型路径不存在: {self.model_path}")
return
self.model = Model(self.model_path)
self.recognizer = KaldiRecognizer(self.model, self.RATE)
self.recognizer.SetWords(True)
print(f"✅ Vosk 模型加载成功")
except Exception as e:
print(f"模型初始化失败: {e}")
def _setup_audio(self):
"""设置音频设备"""
try:
if self.audio is None:
self.audio = pyaudio.PyAudio()
if self.stream is None:
self.stream = self.audio.open(
format=self.FORMAT,
channels=self.CHANNELS,
rate=self.RATE,
input=True,
frames_per_buffer=self.CHUNK_SIZE
)
print("✅ 音频设备初始化成功")
except Exception as e:
print(f"音频设备初始化失败: {e}")
def _setup_speech_recognition(self, app_key=None, access_key=None):
"""设置语音识别"""
if not self.enable_speech_recognition:
return
try:
self.speech_recognizer = SpeechRecognizer(
app_key=app_key,
access_key=access_key
)
print("✅ 语音识别器初始化成功")
except Exception as e:
print(f"语音识别器初始化失败: {e}")
self.enable_speech_recognition = False
def _calculate_energy(self, audio_data):
"""计算音频能量"""
if len(audio_data) == 0:
return 0
import numpy as np
audio_array = np.frombuffer(audio_data, dtype=np.int16)
rms = np.sqrt(np.mean(audio_array ** 2))
return rms
def _check_wake_word(self, text):
"""检查是否包含唤醒词"""
if not text or not self.wake_words:
return False, None
text_lower = text.lower()
for wake_word in self.wake_words:
if wake_word.lower() in text_lower:
return True, wake_word
return False, None
def _save_recording(self, audio_data):
"""保存录音"""
timestamp = time.strftime("%Y%m%d_%H%M%S")
filename = f"recording_{timestamp}.wav"
try:
import wave
with wave.open(filename, 'wb') as wf:
wf.setnchannels(self.CHANNELS)
wf.setsampwidth(self.audio.get_sample_size(self.FORMAT))
wf.setframerate(self.RATE)
wf.writeframes(audio_data)
print(f"✅ 录音已保存: {filename}")
return True, filename
except Exception as e:
print(f"保存录音失败: {e}")
return False, None
def _play_audio(self, filename):
"""播放音频文件"""
try:
import wave
# 打开音频文件
with wave.open(filename, 'rb') as wf:
# 获取音频参数
channels = wf.getnchannels()
width = wf.getsampwidth()
rate = wf.getframerate()
total_frames = wf.getnframes()
# 分块读取音频数据,避免内存问题
chunk_size = 1024
frames = []
for _ in range(0, total_frames, chunk_size):
chunk = wf.readframes(chunk_size)
if chunk:
frames.append(chunk)
else:
break
# 创建播放流
playback_stream = self.audio.open(
format=self.audio.get_format_from_width(width),
channels=channels,
rate=rate,
output=True
)
print(f"🔊 开始播放: {filename}")
# 分块播放音频
for chunk in frames:
playback_stream.write(chunk)
# 等待播放完成
playback_stream.stop_stream()
playback_stream.close()
print("✅ 播放完成")
except Exception as e:
print(f"❌ 播放失败: {e}")
self._play_with_system_player(filename)
def _play_with_system_player(self, filename):
"""使用系统播放器播放音频"""
try:
import platform
import subprocess
system = platform.system()
if system == 'Darwin': # macOS
cmd = ['afplay', filename]
elif system == 'Windows':
cmd = ['start', '/min', filename]
else: # Linux
cmd = ['aplay', filename]
print(f"🔊 使用系统播放器: {' '.join(cmd)}")
subprocess.run(cmd, check=True)
print("✅ 播放完成")
except Exception as e:
print(f"❌ 系统播放器也失败: {e}")
print(f"💡 文件已保存,请手动播放: {filename}")
def _start_recognition_thread(self, filename):
"""启动语音识别线程"""
if not self.enable_speech_recognition or not self.speech_recognizer:
return
def recognize_task():
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
print(f"🧠 开始识别录音文件: {filename}")
result = loop.run_until_complete(
self.speech_recognizer.recognize_file(filename)
)
if result:
# 合并所有识别结果
full_text = " ".join([r.text for r in result])
final_result = RecognitionResult(
text=full_text,
confidence=0.9,
is_final=True
)
self.last_recognition_result = final_result
print(f"\n🧠 语音识别结果: {full_text}")
# 调用回调函数
if self.on_recognition_result:
self.on_recognition_result(final_result)
else:
print(f"\n🧠 语音识别失败或未识别到内容")
loop.close()
except Exception as e:
print(f"❌ 语音识别线程异常: {e}")
self.recognition_thread = threading.Thread(target=recognize_task)
self.recognition_thread.daemon = True
self.recognition_thread.start()
def _start_recording(self):
"""开始录音"""
print("🎙️ 开始录音,请说话...")
self.recording = True
self.recorded_frames = []
self.last_text_time = None
self.recording_start_time = time.time()
# 为录音创建一个新的识别器
if self.model:
self.recording_recognizer = KaldiRecognizer(self.model, self.RATE)
self.recording_recognizer.SetWords(True)
def _stop_recording(self):
"""停止录音"""
if len(self.recorded_frames) > 0:
audio_data = b''.join(self.recorded_frames)
duration = len(audio_data) / (self.RATE * 2)
print(f"📝 录音完成,时长: {duration:.2f}")
# 保存录音
success, filename = self._save_recording(audio_data)
# 如果保存成功,播放录音并进行语音识别
if success and filename:
print("=" * 50)
print("🔊 播放刚才录制的音频...")
self._play_audio(filename)
print("=" * 50)
# 启动语音识别
if self.enable_speech_recognition:
print("🧠 准备进行语音识别...")
self._start_recognition_thread(filename)
self.recording = False
self.recorded_frames = []
self.last_text_time = None
self.recording_start_time = None
self.recording_recognizer = None
def set_recognition_callback(self, callback):
"""设置识别结果回调函数"""
self.on_recognition_result = callback
def get_last_recognition_result(self) -> Optional[RecognitionResult]:
"""获取最后一次识别结果"""
return self.last_recognition_result
def start(self):
"""开始唤醒词检测和录音"""
if not self.stream:
print("❌ 音频设备未初始化")
return
self.running = True
print("🎤 开始监听...")
print(f"唤醒词: {', '.join(self.wake_words)}")
if self.enable_speech_recognition:
print("🧠 语音识别: 已启用")
else:
print("🧠 语音识别: 已禁用")
try:
while self.running:
# 读取音频数据
data = self.stream.read(self.CHUNK_SIZE, exception_on_overflow=False)
if len(data) == 0:
continue
if self.recording:
# 录音模式
self.recorded_frames.append(data)
recording_duration = time.time() - self.recording_start_time
# 使用录音专用的识别器进行实时识别
if self.recording_recognizer:
if self.recording_recognizer.AcceptWaveform(data):
result = json.loads(self.recording_recognizer.Result())
text = result.get('text', '').strip()
if text:
self.last_text_time = time.time()
print(f"\n📝 实时识别: {text}")
else:
partial_result = json.loads(self.recording_recognizer.PartialResult())
partial_text = partial_result.get('partial', '').strip()
if partial_text:
self.last_text_time = time.time()
status = f"录音中... {recording_duration:.1f}s | {partial_text}"
print(f"\r{status}", end='', flush=True)
# 检查是否需要结束录音
current_time = time.time()
if self.last_text_time is not None:
text_silence_duration = current_time - self.last_text_time
if text_silence_duration > self.text_silence_threshold and recording_duration >= self.min_recording_time:
print(f"\n\n3秒没有识别到文字结束录音")
self._stop_recording()
else:
if recording_duration > 5.0:
print(f"\n\n5秒没有识别到文字结束录音")
self._stop_recording()
# 检查最大录音时间
if recording_duration > self.max_recording_time:
print(f"\n\n达到最大录音时间 {self.max_recording_time}s")
self._stop_recording()
# 显示录音状态
if self.last_text_time is None:
status = f"等待语音输入... {recording_duration:.1f}s"
print(f"\r{status}", end='', flush=True)
elif self.model and self.recognizer:
# 唤醒词检测模式
if self.recognizer.AcceptWaveform(data):
result = json.loads(self.recognizer.Result())
text = result.get('text', '').strip()
if text:
print(f"识别: {text}")
# 检查唤醒词
is_wake_word, detected_word = self._check_wake_word(text)
if is_wake_word:
print(f"🎯 检测到唤醒词: {detected_word}")
self._start_recording()
else:
# 显示实时音频级别
energy = self._calculate_energy(data)
if energy > 50:
partial_result = json.loads(self.recognizer.PartialResult())
partial_text = partial_result.get('partial', '')
if partial_text:
status = f"监听中... 能量: {energy:.0f} | {partial_text}"
else:
status = f"监听中... 能量: {energy:.0f}"
print(status, end='\r')
time.sleep(0.01)
except KeyboardInterrupt:
print("\n👋 退出")
except Exception as e:
print(f"错误: {e}")
finally:
self.stop()
def stop(self):
"""停止"""
self.running = False
if self.recording:
self._stop_recording()
if self.stream:
self.stream.stop_stream()
self.stream.close()
self.stream = None
if self.audio:
self.audio.terminate()
self.audio = None
# 等待识别线程结束
if self.recognition_thread and self.recognition_thread.is_alive():
self.recognition_thread.join(timeout=5.0)
def main():
"""主函数"""
print("🚀 增强版唤醒+录音+语音识别测试")
print("=" * 50)
# 检查模型
model_dir = "model"
if not os.path.exists(model_dir):
print("⚠️ 未找到模型目录")
print("请下载 Vosk 模型到 model 目录")
return
# 创建系统
system = EnhancedWakeAndRecord(
model_path=model_dir,
wake_words=["你好", "助手", "小爱"],
enable_speech_recognition=True,
# app_key="your_app_key", # 请填入实际的app_key
# access_key="your_access_key" # 请填入实际的access_key
)
if not system.model:
print("❌ 模型加载失败")
return
# 设置识别结果回调
def on_recognition_result(result):
print(f"\n🎯 识别完成!结果: {result.text}")
print(f" 置信度: {result.confidence}")
print(f" 是否最终结果: {result.is_final}")
system.set_recognition_callback(on_recognition_result)
print("✅ 系统初始化成功")
print("📖 使用说明:")
print("1. 说唤醒词开始录音")
print("2. 基于语音识别判断3秒没有识别到文字就结束")
print("3. 最少录音2秒最多30秒")
print("4. 录音时实时显示识别结果")
print("5. 录音文件自动保存")
print("6. 录音完成后自动播放刚才录制的内容")
print("7. 启动语音识别对录音文件进行识别")
print("8. 按 Ctrl+C 退出")
print("=" * 50)
# 开始运行
system.start()
if __name__ == "__main__":
main()

127
recognition_example.py Normal file
View File

@ -0,0 +1,127 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
语音识别使用示例
演示如何使用 speech_recognizer 模块
"""
import os
import asyncio
from speech_recognizer import SpeechRecognizer
async def example_recognize_file():
"""示例:识别单个音频文件"""
print("=== 示例1识别单个音频文件 ===")
# 初始化识别器
recognizer = SpeechRecognizer(
app_key="your_app_key", # 请替换为实际的app_key
access_key="your_access_key" # 请替换为实际的access_key
)
# 假设有一个录音文件
audio_file = "recording_20240101_120000.wav"
if not os.path.exists(audio_file):
print(f"音频文件不存在: {audio_file}")
print("请先运行 enhanced_wake_and_record.py 录制一个音频文件")
return
try:
# 识别音频文件
results = await recognizer.recognize_file(audio_file)
print(f"识别结果(共{len(results)}个):")
for i, result in enumerate(results):
print(f"结果 {i+1}:")
print(f" 文本: {result.text}")
print(f" 置信度: {result.confidence}")
print(f" 最终结果: {result.is_final}")
print("-" * 40)
except Exception as e:
print(f"识别失败: {e}")
async def example_recognize_latest():
"""示例:识别最新的录音文件"""
print("\n=== 示例2识别最新的录音文件 ===")
# 初始化识别器
recognizer = SpeechRecognizer(
app_key="your_app_key", # 请替换为实际的app_key
access_key="your_access_key" # 请替换为实际的access_key
)
try:
# 识别最新的录音文件
result = await recognizer.recognize_latest_recording()
if result:
print("识别结果:")
print(f" 文本: {result.text}")
print(f" 置信度: {result.confidence}")
print(f" 最终结果: {result.is_final}")
else:
print("未找到录音文件或识别失败")
except Exception as e:
print(f"识别失败: {e}")
async def example_batch_recognition():
"""示例:批量识别多个录音文件"""
print("\n=== 示例3批量识别录音文件 ===")
# 初始化识别器
recognizer = SpeechRecognizer(
app_key="your_app_key", # 请替换为实际的app_key
access_key="your_access_key" # 请替换为实际的access_key
)
# 获取所有录音文件
recording_files = [f for f in os.listdir(".") if f.startswith('recording_') and f.endswith('.wav')]
if not recording_files:
print("未找到录音文件")
return
print(f"找到 {len(recording_files)} 个录音文件")
for filename in recording_files[:5]: # 只处理前5个文件
print(f"\n处理文件: {filename}")
try:
results = await recognizer.recognize_file(filename)
if results:
final_result = results[-1] # 取最后一个结果
print(f"识别结果: {final_result.text}")
else:
print("识别失败")
except Exception as e:
print(f"处理失败: {e}")
# 添加延迟,避免请求过于频繁
await asyncio.sleep(1)
async def main():
"""主函数"""
print("🚀 语音识别使用示例")
print("=" * 50)
# 请先设置环境变量或在代码中填入实际的API密钥
if not os.getenv("SAUC_APP_KEY") and "your_app_key" in "your_app_key":
print("⚠️ 请先设置 SAUC_APP_KEY 和 SAUC_ACCESS_KEY 环境变量")
print("或者在代码中填入实际的 app_key 和 access_key")
print("示例:")
print("export SAUC_APP_KEY='your_app_key'")
print("export SAUC_ACCESS_KEY='your_access_key'")
return
# 运行示例
await example_recognize_file()
await example_recognize_latest()
await example_batch_recognition()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -1,3 +1,5 @@
vosk>=0.3.44
pyaudio>=0.2.11
numpy>=1.19.0
numpy>=1.19.0
aiohttp>=3.8.0
asyncio

15
sauc_python/readme.md Normal file
View File

@ -0,0 +1,15 @@
# README
**asr tob 相关client demo**
# Notice
python version: python 3.x
替换代码中的key为真实数据:
"app_key": "xxxxxxx",
"access_key": "xxxxxxxxxxxxxxxx"
使用示例:
python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav

View File

@ -0,0 +1,523 @@
import asyncio
import aiohttp
import json
import struct
import gzip
import uuid
import logging
import os
import subprocess
from typing import Optional, List, Dict, Any, Tuple, AsyncGenerator
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('run.log'),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
# 常量定义
DEFAULT_SAMPLE_RATE = 16000
class ProtocolVersion:
V1 = 0b0001
class MessageType:
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ERROR_RESPONSE = 0b1111
class MessageTypeSpecificFlags:
NO_SEQUENCE = 0b0000
POS_SEQUENCE = 0b0001
NEG_SEQUENCE = 0b0010
NEG_WITH_SEQUENCE = 0b0011
class SerializationType:
NO_SERIALIZATION = 0b0000
JSON = 0b0001
class CompressionType:
GZIP = 0b0001
class Config:
def __init__(self):
# 填入控制台获取的app id和access token
self.auth = {
"app_key": "xxxxxxx",
"access_key": "xxxxxxxxxxxx"
}
@property
def app_key(self) -> str:
return self.auth["app_key"]
@property
def access_key(self) -> str:
return self.auth["access_key"]
config = Config()
class CommonUtils:
@staticmethod
def gzip_compress(data: bytes) -> bytes:
return gzip.compress(data)
@staticmethod
def gzip_decompress(data: bytes) -> bytes:
return gzip.decompress(data)
@staticmethod
def judge_wav(data: bytes) -> bool:
if len(data) < 44:
return False
return data[:4] == b'RIFF' and data[8:12] == b'WAVE'
@staticmethod
def convert_wav_with_path(audio_path: str, sample_rate: int = DEFAULT_SAMPLE_RATE) -> bytes:
try:
cmd = [
"ffmpeg", "-v", "quiet", "-y", "-i", audio_path,
"-acodec", "pcm_s16le", "-ac", "1", "-ar", str(sample_rate),
"-f", "wav", "-"
]
result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# 尝试删除原始文件
try:
os.remove(audio_path)
except OSError as e:
logger.warning(f"Failed to remove original file: {e}")
return result.stdout
except subprocess.CalledProcessError as e:
logger.error(f"FFmpeg conversion failed: {e.stderr.decode()}")
raise RuntimeError(f"Audio conversion failed: {e.stderr.decode()}")
@staticmethod
def read_wav_info(data: bytes) -> Tuple[int, int, int, int, bytes]:
if len(data) < 44:
raise ValueError("Invalid WAV file: too short")
# 解析WAV头
chunk_id = data[:4]
if chunk_id != b'RIFF':
raise ValueError("Invalid WAV file: not RIFF format")
format_ = data[8:12]
if format_ != b'WAVE':
raise ValueError("Invalid WAV file: not WAVE format")
# 解析fmt子块
audio_format = struct.unpack('<H', data[20:22])[0]
num_channels = struct.unpack('<H', data[22:24])[0]
sample_rate = struct.unpack('<I', data[24:28])[0]
bits_per_sample = struct.unpack('<H', data[34:36])[0]
# 查找data子块
pos = 36
while pos < len(data) - 8:
subchunk_id = data[pos:pos+4]
subchunk_size = struct.unpack('<I', data[pos+4:pos+8])[0]
if subchunk_id == b'data':
wave_data = data[pos+8:pos+8+subchunk_size]
return (
num_channels,
bits_per_sample // 8,
sample_rate,
subchunk_size // (num_channels * (bits_per_sample // 8)),
wave_data
)
pos += 8 + subchunk_size
raise ValueError("Invalid WAV file: no data subchunk found")
class AsrRequestHeader:
def __init__(self):
self.message_type = MessageType.CLIENT_FULL_REQUEST
self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE
self.serialization_type = SerializationType.JSON
self.compression_type = CompressionType.GZIP
self.reserved_data = bytes([0x00])
def with_message_type(self, message_type: int) -> 'AsrRequestHeader':
self.message_type = message_type
return self
def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader':
self.message_type_specific_flags = flags
return self
def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader':
self.serialization_type = serialization_type
return self
def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader':
self.compression_type = compression_type
return self
def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader':
self.reserved_data = reserved_data
return self
def to_bytes(self) -> bytes:
header = bytearray()
header.append((ProtocolVersion.V1 << 4) | 1)
header.append((self.message_type << 4) | self.message_type_specific_flags)
header.append((self.serialization_type << 4) | self.compression_type)
header.extend(self.reserved_data)
return bytes(header)
@staticmethod
def default_header() -> 'AsrRequestHeader':
return AsrRequestHeader()
class RequestBuilder:
@staticmethod
def new_auth_headers() -> Dict[str, str]:
reqid = str(uuid.uuid4())
return {
"X-Api-Resource-Id": "volc.bigasr.sauc.duration",
"X-Api-Request-Id": reqid,
"X-Api-Access-Key": config.access_key,
"X-Api-App-Key": config.app_key
}
@staticmethod
def new_full_client_request(seq: int) -> bytes: # 添加seq参数
header = AsrRequestHeader.default_header() \
.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
payload = {
"user": {
"uid": "demo_uid"
},
"audio": {
"format": "wav",
"codec": "raw",
"rate": 16000,
"bits": 16,
"channel": 1
},
"request": {
"model_name": "bigmodel",
"enable_itn": True,
"enable_punc": True,
"enable_ddc": True,
"show_utterances": True,
"enable_nonstream": False
}
}
payload_bytes = json.dumps(payload).encode('utf-8')
compressed_payload = CommonUtils.gzip_compress(payload_bytes)
payload_size = len(compressed_payload)
request = bytearray()
request.extend(header.to_bytes())
request.extend(struct.pack('>i', seq)) # 使用传入的seq
request.extend(struct.pack('>I', payload_size))
request.extend(compressed_payload)
return bytes(request)
@staticmethod
def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
header = AsrRequestHeader.default_header()
if is_last: # 最后一个包特殊处理
header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE)
seq = -seq # 设为负值
else:
header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST)
request = bytearray()
request.extend(header.to_bytes())
request.extend(struct.pack('>i', seq))
compressed_segment = CommonUtils.gzip_compress(segment)
request.extend(struct.pack('>I', len(compressed_segment)))
request.extend(compressed_segment)
return bytes(request)
class AsrResponse:
def __init__(self):
self.code = 0
self.event = 0
self.is_last_package = False
self.payload_sequence = 0
self.payload_size = 0
self.payload_msg = None
def to_dict(self) -> Dict[str, Any]:
return {
"code": self.code,
"event": self.event,
"is_last_package": self.is_last_package,
"payload_sequence": self.payload_sequence,
"payload_size": self.payload_size,
"payload_msg": self.payload_msg
}
class ResponseParser:
@staticmethod
def parse_response(msg: bytes) -> AsrResponse:
response = AsrResponse()
header_size = msg[0] & 0x0f
message_type = msg[1] >> 4
message_type_specific_flags = msg[1] & 0x0f
serialization_method = msg[2] >> 4
message_compression = msg[2] & 0x0f
payload = msg[header_size*4:]
# 解析message_type_specific_flags
if message_type_specific_flags & 0x01:
response.payload_sequence = struct.unpack('>i', payload[:4])[0]
payload = payload[4:]
if message_type_specific_flags & 0x02:
response.is_last_package = True
if message_type_specific_flags & 0x04:
response.event = struct.unpack('>i', payload[:4])[0]
payload = payload[4:]
# 解析message_type
if message_type == MessageType.SERVER_FULL_RESPONSE:
response.payload_size = struct.unpack('>I', payload[:4])[0]
payload = payload[4:]
elif message_type == MessageType.SERVER_ERROR_RESPONSE:
response.code = struct.unpack('>i', payload[:4])[0]
response.payload_size = struct.unpack('>I', payload[4:8])[0]
payload = payload[8:]
if not payload:
return response
# 解压缩
if message_compression == CompressionType.GZIP:
try:
payload = CommonUtils.gzip_decompress(payload)
except Exception as e:
logger.error(f"Failed to decompress payload: {e}")
return response
# 解析payload
try:
if serialization_method == SerializationType.JSON:
response.payload_msg = json.loads(payload.decode('utf-8'))
except Exception as e:
logger.error(f"Failed to parse payload: {e}")
return response
class AsrWsClient:
def __init__(self, url: str, segment_duration: int = 200):
self.seq = 1
self.url = url
self.segment_duration = segment_duration
self.conn = None
self.session = None # 添加session引用
async def __aenter__(self):
self.session = aiohttp.ClientSession()
return self
async def __aexit__(self, exc_type, exc, tb):
if self.conn and not self.conn.closed:
await self.conn.close()
if self.session and not self.session.closed:
await self.session.close()
async def read_audio_data(self, file_path: str) -> bytes:
try:
with open(file_path, 'rb') as f:
content = f.read()
if not CommonUtils.judge_wav(content):
logger.info("Converting audio to WAV format...")
content = CommonUtils.convert_wav_with_path(file_path, DEFAULT_SAMPLE_RATE)
return content
except Exception as e:
logger.error(f"Failed to read audio data: {e}")
raise
def get_segment_size(self, content: bytes) -> int:
try:
channel_num, samp_width, frame_rate, _, _ = CommonUtils.read_wav_info(content)[:5]
size_per_sec = channel_num * samp_width * frame_rate
segment_size = size_per_sec * self.segment_duration // 1000
return segment_size
except Exception as e:
logger.error(f"Failed to calculate segment size: {e}")
raise
async def create_connection(self) -> None:
headers = RequestBuilder.new_auth_headers()
try:
self.conn = await self.session.ws_connect( # 使用self.session
self.url,
headers=headers
)
logger.info(f"Connected to {self.url}")
except Exception as e:
logger.error(f"Failed to connect to WebSocket: {e}")
raise
async def send_full_client_request(self) -> None:
request = RequestBuilder.new_full_client_request(self.seq)
self.seq += 1 # 发送后递增
try:
await self.conn.send_bytes(request)
logger.info(f"Sent full client request with seq: {self.seq-1}")
msg = await self.conn.receive()
if msg.type == aiohttp.WSMsgType.BINARY:
response = ResponseParser.parse_response(msg.data)
logger.info(f"Received response: {response.to_dict()}")
else:
logger.error(f"Unexpected message type: {msg.type}")
except Exception as e:
logger.error(f"Failed to send full client request: {e}")
raise
async def send_messages(self, segment_size: int, content: bytes) -> AsyncGenerator[None, None]:
audio_segments = self.split_audio(content, segment_size)
total_segments = len(audio_segments)
for i, segment in enumerate(audio_segments):
is_last = (i == total_segments - 1)
request = RequestBuilder.new_audio_only_request(
self.seq,
segment,
is_last=is_last
)
await self.conn.send_bytes(request)
logger.info(f"Sent audio segment with seq: {self.seq} (last: {is_last})")
if not is_last:
self.seq += 1
await asyncio.sleep(self.segment_duration / 1000) # 逐个发送,间隔时间模拟实时流
# 让出控制权,允许接受消息
yield
async def recv_messages(self) -> AsyncGenerator[AsrResponse, None]:
try:
async for msg in self.conn:
if msg.type == aiohttp.WSMsgType.BINARY:
response = ResponseParser.parse_response(msg.data)
yield response
if response.is_last_package or response.code != 0:
break
elif msg.type == aiohttp.WSMsgType.ERROR:
logger.error(f"WebSocket error: {msg.data}")
break
elif msg.type == aiohttp.WSMsgType.CLOSED:
logger.info("WebSocket connection closed")
break
except Exception as e:
logger.error(f"Error receiving messages: {e}")
raise
async def start_audio_stream(self, segment_size: int, content: bytes) -> AsyncGenerator[AsrResponse, None]:
async def sender():
async for _ in self.send_messages(segment_size, content):
pass
# 启动发送和接收任务
sender_task = asyncio.create_task(sender())
try:
async for response in self.recv_messages():
yield response
finally:
sender_task.cancel()
try:
await sender_task
except asyncio.CancelledError:
pass
@staticmethod
def split_audio(data: bytes, segment_size: int) -> List[bytes]:
if segment_size <= 0:
return []
segments = []
for i in range(0, len(data), segment_size):
end = i + segment_size
if end > len(data):
end = len(data)
segments.append(data[i:end])
return segments
async def execute(self, file_path: str) -> AsyncGenerator[AsrResponse, None]:
if not file_path:
raise ValueError("File path is empty")
if not self.url:
raise ValueError("URL is empty")
self.seq = 1
try:
# 1. 读取音频文件
content = await self.read_audio_data(file_path)
# 2. 计算分段大小
segment_size = self.get_segment_size(content)
# 3. 创建WebSocket连接
await self.create_connection()
# 4. 发送完整客户端请求
await self.send_full_client_request()
# 5. 启动音频流处理
async for response in self.start_audio_stream(segment_size, content):
yield response
except Exception as e:
logger.error(f"Error in ASR execution: {e}")
raise
finally:
if self.conn:
await self.conn.close()
async def main():
import argparse
parser = argparse.ArgumentParser(description="ASR WebSocket Client")
parser.add_argument("--file", type=str, required=True, help="Audio file path")
#wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
#wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async
#wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream
parser.add_argument("--url", type=str, default="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream",
help="WebSocket URL")
parser.add_argument("--seg-duration", type=int, default=200,
help="Audio duration(ms) per packet, default:200")
args = parser.parse_args()
async with AsrWsClient(args.url, args.seg_duration) as client: # 使用async with
try:
async for response in client.execute(args.file):
logger.info(f"Received response: {json.dumps(response.to_dict(), indent=2, ensure_ascii=False)}")
except Exception as e:
logger.error(f"ASR processing failed: {e}")
if __name__ == "__main__":
asyncio.run(main())
# 用法:
# python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav

View File

@ -35,11 +35,11 @@ class SimpleWakeAndRecord:
self.stream = None
self.running = False
# 音频参数
# 音频参数 - 优化为树莓派3B
self.FORMAT = pyaudio.paInt16
self.CHANNELS = 1
self.RATE = 16000
self.CHUNK_SIZE = 1024
self.RATE = 8000 # 从16kHz降至8kHz减少50%数据处理量
self.CHUNK_SIZE = 2048 # 增大块大小,减少处理次数
# 录音相关
self.recording = False
@ -48,6 +48,19 @@ class SimpleWakeAndRecord:
self.recording_start_time = None
self.recording_recognizer = None # 录音时专用的识别器
# 性能优化相关
self.audio_buffer = [] # 音频缓冲区
self.buffer_size = 10 # 缓冲区大小(块数)
self.last_process_time = time.time() # 上次处理时间
self.process_interval = 0.5 # 处理间隔(秒)
self.batch_process_size = 5 # 批处理大小
# 性能监控
self.process_count = 0
self.avg_process_time = 0
self.last_monitor_time = time.time()
self.monitor_interval = 5.0 # 监控间隔(秒)
# 阈值
self.text_silence_threshold = 3.0 # 3秒没有识别到文字就结束
self.min_recording_time = 2.0 # 最小录音时间
@ -116,6 +129,48 @@ class SimpleWakeAndRecord:
return True, wake_word
return False, None
def _should_process_audio(self):
"""判断是否应该处理音频"""
current_time = time.time()
return (current_time - self.last_process_time >= self.process_interval and
len(self.audio_buffer) >= self.batch_process_size)
def _process_audio_batch(self):
"""批量处理音频数据"""
if len(self.audio_buffer) < self.batch_process_size:
return
# 记录处理开始时间
start_time = time.time()
# 取出批处理数据
batch_data = self.audio_buffer[:self.batch_process_size]
self.audio_buffer = self.audio_buffer[self.batch_process_size:]
# 合并音频数据
combined_data = b''.join(batch_data)
# 更新处理时间
self.last_process_time = time.time()
# 更新性能统计
process_time = time.time() - start_time
self.process_count += 1
self.avg_process_time = (self.avg_process_time * (self.process_count - 1) + process_time) / self.process_count
# 性能监控
self._monitor_performance()
return combined_data
def _monitor_performance(self):
"""性能监控"""
current_time = time.time()
if current_time - self.last_monitor_time >= self.monitor_interval:
buffer_usage = len(self.audio_buffer) / self.buffer_size * 100
print(f"\n📊 性能监控 | 处理次数: {self.process_count} | 平均处理时间: {self.avg_process_time:.3f}s | 缓冲区使用: {buffer_usage:.1f}%")
self.last_monitor_time = current_time
def _save_recording(self, audio_data):
"""保存录音"""
timestamp = time.strftime("%Y%m%d_%H%M%S")
@ -262,13 +317,21 @@ class SimpleWakeAndRecord:
continue
if self.recording:
# 录音模式
# 录音模式 - 直接处理
self.recorded_frames.append(data)
recording_duration = time.time() - self.recording_start_time
# 使用录音专用的识别器进行实时识别
if self.recording_recognizer:
if self.recording_recognizer.AcceptWaveform(data):
# 录音时使用批处理进行识别
self.audio_buffer.append(data)
# 限制缓冲区大小
if len(self.audio_buffer) > self.buffer_size:
self.audio_buffer.pop(0)
# 批处理识别
if self._should_process_audio() and self.recording_recognizer:
combined_data = self._process_audio_batch()
if combined_data and self.recording_recognizer.AcceptWaveform(combined_data):
# 获取最终识别结果
result = json.loads(self.recording_recognizer.Result())
text = result.get('text', '').strip()
@ -277,7 +340,7 @@ class SimpleWakeAndRecord:
# 识别到文字,更新时间戳
self.last_text_time = time.time()
print(f"\n📝 识别: {text}")
else:
elif combined_data:
# 获取部分识别结果
partial_result = json.loads(self.recording_recognizer.PartialResult())
partial_text = partial_result.get('partial', '').strip()
@ -314,32 +377,41 @@ class SimpleWakeAndRecord:
print(f"\r{status}", end='', flush=True)
elif self.model and self.recognizer:
# 唤醒词检测模式
if self.recognizer.AcceptWaveform(data):
result = json.loads(self.recognizer.Result())
text = result.get('text', '').strip()
if text:
print(f"识别: {text}")
# 唤醒词检测模式 - 使用批处理
self.audio_buffer.append(data)
# 限制缓冲区大小
if len(self.audio_buffer) > self.buffer_size:
self.audio_buffer.pop(0)
# 批处理识别
if self._should_process_audio():
combined_data = self._process_audio_batch()
if combined_data and self.recognizer.AcceptWaveform(combined_data):
result = json.loads(self.recognizer.Result())
text = result.get('text', '').strip()
# 检查唤醒词
is_wake_word, detected_word = self._check_wake_word(text)
if is_wake_word:
print(f"🎯 检测到唤醒词: {detected_word}")
self._start_recording()
else:
# 显示实时音频级别
energy = self._calculate_energy(data)
if energy > 50: # 只显示有意义的音频级别
partial_result = json.loads(self.recognizer.PartialResult())
partial_text = partial_result.get('partial', '')
if partial_text:
status = f"监听中... 能量: {energy:.0f} | {partial_text}"
else:
status = f"监听中... 能量: {energy:.0f}"
print(status, end='\r')
if text:
print(f"识别: {text}")
# 检查唤醒词
is_wake_word, detected_word = self._check_wake_word(text)
if is_wake_word:
print(f"🎯 检测到唤醒词: {detected_word}")
self._start_recording()
else:
# 显示实时音频级别
energy = self._calculate_energy(data)
if energy > 50: # 只显示有意义的音频级别
partial_result = json.loads(self.recognizer.PartialResult())
partial_text = partial_result.get('partial', '')
if partial_text:
status = f"监听中... 能量: {energy:.0f} | {partial_text}"
else:
status = f"监听中... 能量: {energy:.0f}"
print(status, end='\r')
time.sleep(0.01)
time.sleep(0.05) # 增加延迟减少CPU使用
except KeyboardInterrupt:
print("\n👋 退出")
@ -394,6 +466,12 @@ def main():
print("5. 录音文件自动保存")
print("6. 录音完成后自动播放刚才录制的内容")
print("7. 按 Ctrl+C 退出")
print("🚀 性能优化已启用:")
print(" - 采样率: 8kHz (降低50%数据量)")
print(" - 批处理: 5个音频块/次")
print(" - 处理间隔: 0.5秒")
print(" - 缓冲区: 10个音频块")
print(" - 性能监控: 每5秒显示")
print("=" * 50)
# 开始运行

532
speech_recognizer.py Normal file
View File

@ -0,0 +1,532 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
语音识别模块
基于 SAUC API 为录音文件提供语音识别功能
"""
import os
import json
import time
import logging
import asyncio
import aiohttp
import struct
import gzip
import uuid
from typing import Optional, List, Dict, Any, AsyncGenerator
from dataclasses import dataclass
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# 常量定义
DEFAULT_SAMPLE_RATE = 16000
class ProtocolVersion:
V1 = 0b0001
class MessageType:
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ERROR_RESPONSE = 0b1111
class MessageTypeSpecificFlags:
NO_SEQUENCE = 0b0000
POS_SEQUENCE = 0b0001
NEG_SEQUENCE = 0b0010
NEG_WITH_SEQUENCE = 0b0011
class SerializationType:
NO_SERIALIZATION = 0b0000
JSON = 0b0001
class CompressionType:
GZIP = 0b0001
@dataclass
class RecognitionResult:
"""语音识别结果"""
text: str
confidence: float
is_final: bool
start_time: Optional[float] = None
end_time: Optional[float] = None
class AudioUtils:
"""音频处理工具类"""
@staticmethod
def gzip_compress(data: bytes) -> bytes:
"""GZIP压缩"""
return gzip.compress(data)
@staticmethod
def gzip_decompress(data: bytes) -> bytes:
"""GZIP解压缩"""
return gzip.decompress(data)
@staticmethod
def is_wav_file(data: bytes) -> bool:
"""检查是否为WAV文件"""
if len(data) < 44:
return False
return data[:4] == b'RIFF' and data[8:12] == b'WAVE'
@staticmethod
def read_wav_info(data: bytes) -> tuple:
"""读取WAV文件信息"""
if len(data) < 44:
raise ValueError("Invalid WAV file: too short")
# 解析WAV头
chunk_id = data[:4]
if chunk_id != b'RIFF':
raise ValueError("Invalid WAV file: not RIFF format")
format_ = data[8:12]
if format_ != b'WAVE':
raise ValueError("Invalid WAV file: not WAVE format")
# 解析fmt子块
audio_format = struct.unpack('<H', data[20:22])[0]
num_channels = struct.unpack('<H', data[22:24])[0]
sample_rate = struct.unpack('<I', data[24:28])[0]
bits_per_sample = struct.unpack('<H', data[34:36])[0]
# 查找data子块
pos = 36
while pos < len(data) - 8:
subchunk_id = data[pos:pos+4]
subchunk_size = struct.unpack('<I', data[pos+4:pos+8])[0]
if subchunk_id == b'data':
wave_data = data[pos+8:pos+8+subchunk_size]
return (
num_channels,
bits_per_sample // 8,
sample_rate,
subchunk_size // (num_channels * (bits_per_sample // 8)),
wave_data
)
pos += 8 + subchunk_size
raise ValueError("Invalid WAV file: no data subchunk found")
class AsrConfig:
"""ASR配置"""
def __init__(self, app_key: str = None, access_key: str = None):
self.auth = {
"app_key": app_key or os.getenv("SAUC_APP_KEY", "your_app_key"),
"access_key": access_key or os.getenv("SAUC_ACCESS_KEY", "your_access_key")
}
@property
def app_key(self) -> str:
return self.auth["app_key"]
@property
def access_key(self) -> str:
return self.auth["access_key"]
class AsrRequestHeader:
"""ASR请求头"""
def __init__(self):
self.message_type = MessageType.CLIENT_FULL_REQUEST
self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE
self.serialization_type = SerializationType.JSON
self.compression_type = CompressionType.GZIP
self.reserved_data = bytes([0x00])
def with_message_type(self, message_type: int) -> 'AsrRequestHeader':
self.message_type = message_type
return self
def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader':
self.message_type_specific_flags = flags
return self
def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader':
self.serialization_type = serialization_type
return self
def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader':
self.compression_type = compression_type
return self
def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader':
self.reserved_data = reserved_data
return self
def to_bytes(self) -> bytes:
header = bytearray()
header.append((ProtocolVersion.V1 << 4) | 1)
header.append((self.message_type << 4) | self.message_type_specific_flags)
header.append((self.serialization_type << 4) | self.compression_type)
header.extend(self.reserved_data)
return bytes(header)
@staticmethod
def default_header() -> 'AsrRequestHeader':
return AsrRequestHeader()
class RequestBuilder:
"""请求构建器"""
@staticmethod
def new_auth_headers(config: AsrConfig) -> Dict[str, str]:
"""创建认证头"""
reqid = str(uuid.uuid4())
return {
"X-Api-Resource-Id": "volc.bigasr.sauc.duration",
"X-Api-Request-Id": reqid,
"X-Api-Access-Key": config.access_key,
"X-Api-App-Key": config.app_key
}
@staticmethod
def new_full_client_request(seq: int) -> bytes:
"""创建完整客户端请求"""
header = AsrRequestHeader.default_header() \
.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
payload = {
"user": {
"uid": "local_voice_user"
},
"audio": {
"format": "wav",
"codec": "raw",
"rate": 16000,
"bits": 16,
"channel": 1
},
"request": {
"model_name": "bigmodel",
"enable_itn": True,
"enable_punc": True,
"enable_ddc": True,
"show_utterances": True,
"enable_nonstream": False
}
}
payload_bytes = json.dumps(payload).encode('utf-8')
compressed_payload = AudioUtils.gzip_compress(payload_bytes)
payload_size = len(compressed_payload)
request = bytearray()
request.extend(header.to_bytes())
request.extend(struct.pack('>i', seq))
request.extend(struct.pack('>U', payload_size))
request.extend(compressed_payload)
return bytes(request)
@staticmethod
def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
"""创建纯音频请求"""
header = AsrRequestHeader.default_header()
if is_last:
header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE)
seq = -seq
else:
header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST)
request = bytearray()
request.extend(header.to_bytes())
request.extend(struct.pack('>i', seq))
compressed_segment = AudioUtils.gzip_compress(segment)
request.extend(struct.pack('>U', len(compressed_segment)))
request.extend(compressed_segment)
return bytes(request)
class AsrResponse:
"""ASR响应"""
def __init__(self):
self.code = 0
self.event = 0
self.is_last_package = False
self.payload_sequence = 0
self.payload_size = 0
self.payload_msg = None
def to_dict(self) -> Dict[str, Any]:
return {
"code": self.code,
"event": self.event,
"is_last_package": self.is_last_package,
"payload_sequence": self.payload_sequence,
"payload_size": self.payload_size,
"payload_msg": self.payload_msg
}
class ResponseParser:
"""响应解析器"""
@staticmethod
def parse_response(msg: bytes) -> AsrResponse:
"""解析响应"""
response = AsrResponse()
header_size = msg[0] & 0x0f
message_type = msg[1] >> 4
message_type_specific_flags = msg[1] & 0x0f
serialization_method = msg[2] >> 4
message_compression = msg[2] & 0x0f
payload = msg[header_size*4:]
# 解析message_type_specific_flags
if message_type_specific_flags & 0x01:
response.payload_sequence = struct.unpack('>i', payload[:4])[0]
payload = payload[4:]
if message_type_specific_flags & 0x02:
response.is_last_package = True
if message_type_specific_flags & 0x04:
response.event = struct.unpack('>i', payload[:4])[0]
payload = payload[4:]
# 解析message_type
if message_type == MessageType.SERVER_FULL_RESPONSE:
response.payload_size = struct.unpack('>U', payload[:4])[0]
payload = payload[4:]
elif message_type == MessageType.SERVER_ERROR_RESPONSE:
response.code = struct.unpack('>i', payload[:4])[0]
response.payload_size = struct.unpack('>U', payload[4:8])[0]
payload = payload[8:]
if not payload:
return response
# 解压缩
if message_compression == CompressionType.GZIP:
try:
payload = AudioUtils.gzip_decompress(payload)
except Exception as e:
logger.error(f"Failed to decompress payload: {e}")
return response
# 解析payload
try:
if serialization_method == SerializationType.JSON:
response.payload_msg = json.loads(payload.decode('utf-8'))
except Exception as e:
logger.error(f"Failed to parse payload: {e}")
return response
class SpeechRecognizer:
"""语音识别器"""
def __init__(self, app_key: str = None, access_key: str = None,
url: str = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream"):
self.config = AsrConfig(app_key, access_key)
self.url = url
self.seq = 1
async def recognize_file(self, file_path: str) -> List[RecognitionResult]:
"""识别音频文件"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"Audio file not found: {file_path}")
results = []
try:
async with aiohttp.ClientSession() as session:
# 读取音频文件
with open(file_path, 'rb') as f:
content = f.read()
if not AudioUtils.is_wav_file(content):
raise ValueError("Audio file must be in WAV format")
# 获取音频信息
try:
_, _, sample_rate, _, audio_data = AudioUtils.read_wav_info(content)
if sample_rate != DEFAULT_SAMPLE_RATE:
logger.warning(f"Sample rate {sample_rate} != {DEFAULT_SAMPLE_RATE}, may affect recognition accuracy")
except Exception as e:
logger.error(f"Failed to read audio info: {e}")
raise
# 计算分段大小 (200ms per segment)
segment_size = 1 * 2 * DEFAULT_SAMPLE_RATE * 200 // 1000 # channel * bytes_per_sample * sample_rate * duration_ms / 1000
# 创建WebSocket连接
headers = RequestBuilder.new_auth_headers(self.config)
async with session.ws_connect(self.url, headers=headers) as ws:
# 发送完整客户端请求
request = RequestBuilder.new_full_client_request(self.seq)
self.seq += 1
await ws.send_bytes(request)
# 接收初始响应
msg = await ws.receive()
if msg.type == aiohttp.WSMsgType.BINARY:
response = ResponseParser.parse_response(msg.data)
logger.info(f"Initial response: {response.to_dict()}")
# 分段发送音频数据
audio_segments = self._split_audio(audio_data, segment_size)
total_segments = len(audio_segments)
for i, segment in enumerate(audio_segments):
is_last = (i == total_segments - 1)
request = RequestBuilder.new_audio_only_request(
self.seq,
segment,
is_last=is_last
)
await ws.send_bytes(request)
logger.info(f"Sent audio segment {i+1}/{total_segments}")
if not is_last:
self.seq += 1
# 短暂延迟模拟实时流
await asyncio.sleep(0.1)
# 接收识别结果
final_text = ""
while True:
msg = await ws.receive()
if msg.type == aiohttp.WSMsgType.BINARY:
response = ResponseParser.parse_response(msg.data)
if response.payload_msg and 'text' in response.payload_msg:
text = response.payload_msg['text']
if text:
final_text += text
result = RecognitionResult(
text=text,
confidence=0.9, # 默认置信度
is_final=response.is_last_package
)
results.append(result)
logger.info(f"Recognized: {text}")
if response.is_last_package or response.code != 0:
break
elif msg.type == aiohttp.WSMsgType.ERROR:
logger.error(f"WebSocket error: {msg.data}")
break
elif msg.type == aiohttp.WSMsgType.CLOSED:
logger.info("WebSocket connection closed")
break
# 如果没有获得最终结果,创建一个包含所有文本的结果
if final_text and not any(r.is_final for r in results):
final_result = RecognitionResult(
text=final_text,
confidence=0.9,
is_final=True
)
results.append(final_result)
return results
except Exception as e:
logger.error(f"Speech recognition failed: {e}")
raise
def _split_audio(self, data: bytes, segment_size: int) -> List[bytes]:
"""分割音频数据"""
if segment_size <= 0:
return []
segments = []
for i in range(0, len(data), segment_size):
end = i + segment_size
if end > len(data):
end = len(data)
segments.append(data[i:end])
return segments
async def recognize_latest_recording(self, directory: str = ".") -> Optional[RecognitionResult]:
"""识别最新的录音文件"""
# 查找最新的录音文件
recording_files = [f for f in os.listdir(directory) if f.startswith('recording_') and f.endswith('.wav')]
if not recording_files:
logger.warning("No recording files found")
return None
# 按文件名排序(包含时间戳)
recording_files.sort(reverse=True)
latest_file = recording_files[0]
latest_path = os.path.join(directory, latest_file)
logger.info(f"Recognizing latest recording: {latest_file}")
try:
results = await self.recognize_file(latest_path)
if results:
# 返回最终的识别结果
final_results = [r for r in results if r.is_final]
if final_results:
return final_results[-1]
else:
# 如果没有标记为final的结果返回最后一个
return results[-1]
except Exception as e:
logger.error(f"Failed to recognize latest recording: {e}")
return None
async def main():
"""测试函数"""
import argparse
parser = argparse.ArgumentParser(description="语音识别测试")
parser.add_argument("--file", type=str, help="音频文件路径")
parser.add_argument("--latest", action="store_true", help="识别最新的录音文件")
parser.add_argument("--app-key", type=str, help="SAUC App Key")
parser.add_argument("--access-key", type=str, help="SAUC Access Key")
args = parser.parse_args()
recognizer = SpeechRecognizer(
app_key=args.app_key,
access_key=args.access_key
)
try:
if args.latest:
result = await recognizer.recognize_latest_recording()
if result:
print(f"识别结果: {result.text}")
print(f"置信度: {result.confidence}")
print(f"最终结果: {result.is_final}")
else:
print("未能识别到语音内容")
elif args.file:
results = await recognizer.recognize_file(args.file)
for i, result in enumerate(results):
print(f"结果 {i+1}: {result.text}")
print(f"置信度: {result.confidence}")
print(f"最终结果: {result.is_final}")
print("-" * 40)
else:
print("请指定 --file 或 --latest 参数")
except Exception as e:
print(f"识别失败: {e}")
if __name__ == "__main__":
asyncio.run(main())