commit eb099d827d (parent ef39e31a4b): config

enhanced_wake_and_record.py — 501 lines (new file)
@@ -0,0 +1,501 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Wake-word + recording system with integrated speech recognition.
Based on simple_wake_and_record.py, with speech recognition added.
"""

import sys
import os
import time
import threading
import pyaudio
import json
import asyncio
from typing import Optional, List

# Add the current directory to the module search path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

try:
    from vosk import Model, KaldiRecognizer
    VOSK_AVAILABLE = True
except ImportError:
    VOSK_AVAILABLE = False
    print("⚠️ Vosk is not installed; run: pip install vosk")

from speech_recognizer import SpeechRecognizer, RecognitionResult


class EnhancedWakeAndRecord:
    """Enhanced wake-word + recording system with integrated speech recognition."""

    def __init__(self, model_path="model", wake_words=["你好", "助手"],
                 enable_speech_recognition=True, app_key=None, access_key=None):
        self.model_path = model_path
        self.wake_words = wake_words
        self.enable_speech_recognition = enable_speech_recognition
        self.model = None
        self.recognizer = None
        self.audio = None
        self.stream = None
        self.running = False

        # Audio parameters
        self.FORMAT = pyaudio.paInt16
        self.CHANNELS = 1
        self.RATE = 16000
        self.CHUNK_SIZE = 1024

        # Recording state
        self.recording = False
        self.recorded_frames = []
        self.last_text_time = None
        self.recording_start_time = None
        self.recording_recognizer = None

        # Thresholds
        self.text_silence_threshold = 3.0
        self.min_recording_time = 2.0
        self.max_recording_time = 30.0

        # Speech recognition state
        self.speech_recognizer = None
        self.last_recognition_result = None
        self.recognition_thread = None

        # Callback
        self.on_recognition_result = None

        self._setup_model()
        self._setup_audio()
        self._setup_speech_recognition(app_key, access_key)

    def _setup_model(self):
        """Set up the Vosk model."""
        if not VOSK_AVAILABLE:
            return

        try:
            if not os.path.exists(self.model_path):
                print(f"Model path does not exist: {self.model_path}")
                return

            self.model = Model(self.model_path)
            self.recognizer = KaldiRecognizer(self.model, self.RATE)
            self.recognizer.SetWords(True)

            print("✅ Vosk model loaded")

        except Exception as e:
            print(f"Model initialization failed: {e}")

    def _setup_audio(self):
        """Set up the audio device."""
        try:
            if self.audio is None:
                self.audio = pyaudio.PyAudio()

            if self.stream is None:
                self.stream = self.audio.open(
                    format=self.FORMAT,
                    channels=self.CHANNELS,
                    rate=self.RATE,
                    input=True,
                    frames_per_buffer=self.CHUNK_SIZE
                )

            print("✅ Audio device initialized")

        except Exception as e:
            print(f"Audio device initialization failed: {e}")

    def _setup_speech_recognition(self, app_key=None, access_key=None):
        """Set up speech recognition."""
        if not self.enable_speech_recognition:
            return

        try:
            self.speech_recognizer = SpeechRecognizer(
                app_key=app_key,
                access_key=access_key
            )
            print("✅ Speech recognizer initialized")
        except Exception as e:
            print(f"Speech recognizer initialization failed: {e}")
            self.enable_speech_recognition = False

    def _calculate_energy(self, audio_data):
        """Compute the RMS energy of an audio chunk."""
        if len(audio_data) == 0:
            return 0

        import numpy as np
        audio_array = np.frombuffer(audio_data, dtype=np.int16)
        rms = np.sqrt(np.mean(audio_array ** 2))
        return rms

    def _check_wake_word(self, text):
        """Check whether the text contains a wake word."""
        if not text or not self.wake_words:
            return False, None

        text_lower = text.lower()
        for wake_word in self.wake_words:
            if wake_word.lower() in text_lower:
                return True, wake_word
        return False, None

    def _save_recording(self, audio_data):
        """Save the recording to a WAV file."""
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        filename = f"recording_{timestamp}.wav"

        try:
            import wave
            with wave.open(filename, 'wb') as wf:
                wf.setnchannels(self.CHANNELS)
                wf.setsampwidth(self.audio.get_sample_size(self.FORMAT))
                wf.setframerate(self.RATE)
                wf.writeframes(audio_data)

            print(f"✅ Recording saved: {filename}")
            return True, filename
        except Exception as e:
            print(f"Failed to save recording: {e}")
            return False, None

    def _play_audio(self, filename):
        """Play back an audio file."""
        try:
            import wave

            # Open the audio file
            with wave.open(filename, 'rb') as wf:
                # Read the audio parameters
                channels = wf.getnchannels()
                width = wf.getsampwidth()
                rate = wf.getframerate()
                total_frames = wf.getnframes()

                # Read the audio in chunks to avoid memory pressure
                chunk_size = 1024
                frames = []

                for _ in range(0, total_frames, chunk_size):
                    chunk = wf.readframes(chunk_size)
                    if chunk:
                        frames.append(chunk)
                    else:
                        break

            # Create the playback stream
            playback_stream = self.audio.open(
                format=self.audio.get_format_from_width(width),
                channels=channels,
                rate=rate,
                output=True
            )

            print(f"🔊 Playing: {filename}")

            # Play the audio chunk by chunk
            for chunk in frames:
                playback_stream.write(chunk)

            # Wait for playback to finish
            playback_stream.stop_stream()
            playback_stream.close()

            print("✅ Playback finished")

        except Exception as e:
            print(f"❌ Playback failed: {e}")
            self._play_with_system_player(filename)

    def _play_with_system_player(self, filename):
        """Play an audio file with the system player."""
        try:
            import platform
            import subprocess

            system = platform.system()

            if system == 'Darwin':  # macOS
                cmd = ['afplay', filename]
            elif system == 'Windows':
                # 'start' is a cmd built-in, so it has to run through cmd;
                # the empty string fills the window-title argument.
                cmd = ['cmd', '/c', 'start', '/min', '', filename]
            else:  # Linux
                cmd = ['aplay', filename]

            print(f"🔊 Using system player: {' '.join(cmd)}")
            subprocess.run(cmd, check=True)
            print("✅ Playback finished")

        except Exception as e:
            print(f"❌ System player failed as well: {e}")
            print(f"💡 The file was saved; please play it manually: {filename}")

    def _start_recognition_thread(self, filename):
        """Start the speech recognition thread."""
        if not self.enable_speech_recognition or not self.speech_recognizer:
            return

        def recognize_task():
            try:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

                print(f"🧠 Recognizing recorded file: {filename}")
                result = loop.run_until_complete(
                    self.speech_recognizer.recognize_file(filename)
                )

                if result:
                    # Merge all recognition results
                    full_text = " ".join([r.text for r in result])
                    final_result = RecognitionResult(
                        text=full_text,
                        confidence=0.9,
                        is_final=True
                    )

                    self.last_recognition_result = final_result
                    print(f"\n🧠 Recognition result: {full_text}")

                    # Invoke the callback
                    if self.on_recognition_result:
                        self.on_recognition_result(final_result)
                else:
                    print("\n🧠 Recognition failed or produced no text")

                loop.close()

            except Exception as e:
                print(f"❌ Recognition thread error: {e}")

        self.recognition_thread = threading.Thread(target=recognize_task)
        self.recognition_thread.daemon = True
        self.recognition_thread.start()

    def _start_recording(self):
        """Start recording."""
        print("🎙️ Recording started, please speak...")
        self.recording = True
        self.recorded_frames = []
        self.last_text_time = None
        self.recording_start_time = time.time()

        # Create a fresh recognizer dedicated to this recording
        if self.model:
            self.recording_recognizer = KaldiRecognizer(self.model, self.RATE)
            self.recording_recognizer.SetWords(True)

    def _stop_recording(self):
        """Stop recording."""
        if len(self.recorded_frames) > 0:
            audio_data = b''.join(self.recorded_frames)
            duration = len(audio_data) / (self.RATE * 2)
            print(f"📝 Recording finished, duration: {duration:.2f}s")

            # Save the recording
            success, filename = self._save_recording(audio_data)

            # If saved, play it back and run speech recognition
            if success and filename:
                print("=" * 50)
                print("🔊 Playing back what was just recorded...")
                self._play_audio(filename)
                print("=" * 50)

                # Kick off speech recognition
                if self.enable_speech_recognition:
                    print("🧠 Starting speech recognition...")
                    self._start_recognition_thread(filename)

        self.recording = False
        self.recorded_frames = []
        self.last_text_time = None
        self.recording_start_time = None
        self.recording_recognizer = None

    def set_recognition_callback(self, callback):
        """Set the recognition result callback."""
        self.on_recognition_result = callback

    def get_last_recognition_result(self) -> Optional[RecognitionResult]:
        """Return the most recent recognition result."""
        return self.last_recognition_result

    def start(self):
        """Start wake-word detection and recording."""
        if not self.stream:
            print("❌ Audio device not initialized")
            return

        self.running = True
        print("🎤 Listening...")
        print(f"Wake words: {', '.join(self.wake_words)}")
        if self.enable_speech_recognition:
            print("🧠 Speech recognition: enabled")
        else:
            print("🧠 Speech recognition: disabled")

        try:
            while self.running:
                # Read audio data
                data = self.stream.read(self.CHUNK_SIZE, exception_on_overflow=False)

                if len(data) == 0:
                    continue

                if self.recording:
                    # Recording mode
                    self.recorded_frames.append(data)
                    recording_duration = time.time() - self.recording_start_time

                    # Run live recognition with the recording-specific recognizer
                    if self.recording_recognizer:
                        if self.recording_recognizer.AcceptWaveform(data):
                            result = json.loads(self.recording_recognizer.Result())
                            text = result.get('text', '').strip()

                            if text:
                                self.last_text_time = time.time()
                                print(f"\n📝 Live recognition: {text}")
                        else:
                            partial_result = json.loads(self.recording_recognizer.PartialResult())
                            partial_text = partial_result.get('partial', '').strip()

                            if partial_text:
                                self.last_text_time = time.time()
                                status = f"Recording... {recording_duration:.1f}s | {partial_text}"
                                print(f"\r{status}", end='', flush=True)

                    # Decide whether to stop recording
                    current_time = time.time()

                    if self.last_text_time is not None:
                        text_silence_duration = current_time - self.last_text_time
                        if text_silence_duration > self.text_silence_threshold and recording_duration >= self.min_recording_time:
                            print("\n\nNo text recognized for 3 seconds, stopping recording")
                            self._stop_recording()
                    else:
                        if recording_duration > 5.0:
                            print("\n\nNo text recognized for 5 seconds, stopping recording")
                            self._stop_recording()

                    # Enforce the maximum recording time
                    if recording_duration > self.max_recording_time:
                        print(f"\n\nReached maximum recording time of {self.max_recording_time}s")
                        self._stop_recording()

                    # Show recording status
                    if self.last_text_time is None:
                        status = f"Waiting for speech... {recording_duration:.1f}s"
                        print(f"\r{status}", end='', flush=True)

                elif self.model and self.recognizer:
                    # Wake-word detection mode
                    if self.recognizer.AcceptWaveform(data):
                        result = json.loads(self.recognizer.Result())
                        text = result.get('text', '').strip()

                        if text:
                            print(f"Recognized: {text}")

                            # Check for a wake word
                            is_wake_word, detected_word = self._check_wake_word(text)
                            if is_wake_word:
                                print(f"🎯 Wake word detected: {detected_word}")
                                self._start_recording()
                    else:
                        # Show the live audio level
                        energy = self._calculate_energy(data)
                        if energy > 50:
                            partial_result = json.loads(self.recognizer.PartialResult())
                            partial_text = partial_result.get('partial', '')
                            if partial_text:
                                status = f"Listening... energy: {energy:.0f} | {partial_text}"
                            else:
                                status = f"Listening... energy: {energy:.0f}"
                            print(status, end='\r')

                time.sleep(0.01)

        except KeyboardInterrupt:
            print("\n👋 Exiting")
        except Exception as e:
            print(f"Error: {e}")
        finally:
            self.stop()

    def stop(self):
        """Stop everything and release resources."""
        self.running = False
        if self.recording:
            self._stop_recording()

        if self.stream:
            self.stream.stop_stream()
            self.stream.close()
            self.stream = None

        if self.audio:
            self.audio.terminate()
            self.audio = None

        # Wait for the recognition thread to finish
        if self.recognition_thread and self.recognition_thread.is_alive():
            self.recognition_thread.join(timeout=5.0)


def main():
    """Entry point."""
    print("🚀 Enhanced wake + record + speech recognition test")
    print("=" * 50)

    # Check for the model
    model_dir = "model"
    if not os.path.exists(model_dir):
        print("⚠️ Model directory not found")
        print("Please download a Vosk model into the 'model' directory")
        return

    # Create the system
    system = EnhancedWakeAndRecord(
        model_path=model_dir,
        wake_words=["你好", "助手", "小爱"],
        enable_speech_recognition=True,
        # app_key="your_app_key",        # fill in your real app_key
        # access_key="your_access_key"   # fill in your real access_key
    )

    if not system.model:
        print("❌ Model failed to load")
        return

    # Set the recognition result callback
    def on_recognition_result(result):
        print(f"\n🎯 Recognition complete! Result: {result.text}")
        print(f"   Confidence: {result.confidence}")
        print(f"   Final result: {result.is_final}")

    system.set_recognition_callback(on_recognition_result)

    print("✅ System initialized")
    print("📖 Usage:")
    print("1. Say a wake word to start recording")
    print("2. Recording stops after 3 seconds with no recognized text")
    print("3. Recordings are at least 2 seconds and at most 30 seconds")
    print("4. Recognition results are shown live while recording")
    print("5. Recordings are saved automatically")
    print("6. After recording, the captured audio is played back")
    print("7. Speech recognition then runs on the recorded file")
    print("8. Press Ctrl+C to exit")
    print("=" * 50)

    # Run
    system.start()


if __name__ == "__main__":
    main()
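The stop logic in start() above combines three thresholds: a text-silence timeout, minimum and maximum durations, plus a 5-second cutoff if nothing is ever recognized. As a reading aid, it reduces to this pure function (an editorial sketch, not part of the committed file):

# Sketch only: mirrors the stop conditions in EnhancedWakeAndRecord.start().
# `now`, `started`, `last_text` are time.time() timestamps; `last_text` is
# None until the first recognized text.
def should_stop(now, started, last_text,
                silence=3.0, min_len=2.0, max_len=30.0, no_text_timeout=5.0):
    duration = now - started
    if duration > max_len:
        return True                        # hard cap at 30 s
    if last_text is None:
        return duration > no_text_timeout  # nothing recognized within 5 s
    return (now - last_text) > silence and duration >= min_len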
recognition_example.py — 127 lines (new file)
@@ -0,0 +1,127 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Speech recognition usage examples.
Demonstrates how to use the speech_recognizer module.
"""

import os
import asyncio
from speech_recognizer import SpeechRecognizer


async def example_recognize_file():
    """Example: recognize a single audio file."""
    print("=== Example 1: recognize a single audio file ===")

    # Initialize the recognizer
    recognizer = SpeechRecognizer(
        app_key="your_app_key",        # replace with your real app_key
        access_key="your_access_key"   # replace with your real access_key
    )

    # Assume a recording file exists
    audio_file = "recording_20240101_120000.wav"

    if not os.path.exists(audio_file):
        print(f"Audio file not found: {audio_file}")
        print("Run enhanced_wake_and_record.py first to record an audio file")
        return

    try:
        # Recognize the audio file
        results = await recognizer.recognize_file(audio_file)

        print(f"Recognition results ({len(results)} total):")
        for i, result in enumerate(results):
            print(f"Result {i+1}:")
            print(f"  Text: {result.text}")
            print(f"  Confidence: {result.confidence}")
            print(f"  Final: {result.is_final}")
            print("-" * 40)

    except Exception as e:
        print(f"Recognition failed: {e}")


async def example_recognize_latest():
    """Example: recognize the most recent recording."""
    print("\n=== Example 2: recognize the most recent recording ===")

    # Initialize the recognizer
    recognizer = SpeechRecognizer(
        app_key="your_app_key",        # replace with your real app_key
        access_key="your_access_key"   # replace with your real access_key
    )

    try:
        # Recognize the most recent recording
        result = await recognizer.recognize_latest_recording()

        if result:
            print("Recognition result:")
            print(f"  Text: {result.text}")
            print(f"  Confidence: {result.confidence}")
            print(f"  Final: {result.is_final}")
        else:
            print("No recording found or recognition failed")

    except Exception as e:
        print(f"Recognition failed: {e}")


async def example_batch_recognition():
    """Example: recognize several recordings in a batch."""
    print("\n=== Example 3: batch-recognize recordings ===")

    # Initialize the recognizer
    recognizer = SpeechRecognizer(
        app_key="your_app_key",        # replace with your real app_key
        access_key="your_access_key"   # replace with your real access_key
    )

    # Collect all recording files
    recording_files = [f for f in os.listdir(".") if f.startswith('recording_') and f.endswith('.wav')]

    if not recording_files:
        print("No recording files found")
        return

    print(f"Found {len(recording_files)} recording files")

    for filename in recording_files[:5]:  # only process the first 5 files
        print(f"\nProcessing file: {filename}")
        try:
            results = await recognizer.recognize_file(filename)

            if results:
                final_result = results[-1]  # take the last result
                print(f"Recognition result: {final_result.text}")
            else:
                print("Recognition failed")

        except Exception as e:
            print(f"Processing failed: {e}")

        # Add a delay to avoid hammering the API
        await asyncio.sleep(1)


async def main():
    """Entry point."""
    print("🚀 Speech recognition examples")
    print("=" * 50)

    # Set the environment variables, or fill in real API keys in the code above
    if not os.getenv("SAUC_APP_KEY"):
        print("⚠️ Please set the SAUC_APP_KEY and SAUC_ACCESS_KEY environment variables")
        print("or fill in a real app_key and access_key in the code")
        print("Example:")
        print("export SAUC_APP_KEY='your_app_key'")
        print("export SAUC_ACCESS_KEY='your_access_key'")
        return

    # Run the examples
    await example_recognize_file()
    await example_recognize_latest()
    await example_batch_recognition()


if __name__ == "__main__":
    asyncio.run(main())
requirements.txt
@@ -1,3 +1,5 @@
 vosk>=0.3.44
 pyaudio>=0.2.11
 numpy>=1.19.0
+aiohttp>=3.8.0
+asyncio
sauc_python/readme.md — 15 lines (new file)
@@ -0,0 +1,15 @@
# README

**Client demos for the ASR ToB APIs**

# Notice

Python version: Python 3.x

Replace the keys in the code with real values:

    "app_key": "xxxxxxx",
    "access_key": "xxxxxxxxxxxxxxxx"

Usage example:

    python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav
sauc_python/sauc_websocket_demo.py — 523 lines (new file)
@@ -0,0 +1,523 @@
import asyncio
import aiohttp
import json
import struct
import gzip
import uuid
import logging
import os
import subprocess
from typing import Optional, List, Dict, Any, Tuple, AsyncGenerator

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('run.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Constants
DEFAULT_SAMPLE_RATE = 16000


class ProtocolVersion:
    V1 = 0b0001


class MessageType:
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ERROR_RESPONSE = 0b1111


class MessageTypeSpecificFlags:
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    NEG_SEQUENCE = 0b0010
    NEG_WITH_SEQUENCE = 0b0011


class SerializationType:
    NO_SERIALIZATION = 0b0000
    JSON = 0b0001


class CompressionType:
    GZIP = 0b0001


class Config:
    def __init__(self):
        # Fill in the app id and access token from the console
        self.auth = {
            "app_key": "xxxxxxx",
            "access_key": "xxxxxxxxxxxx"
        }

    @property
    def app_key(self) -> str:
        return self.auth["app_key"]

    @property
    def access_key(self) -> str:
        return self.auth["access_key"]


config = Config()


class CommonUtils:
    @staticmethod
    def gzip_compress(data: bytes) -> bytes:
        return gzip.compress(data)

    @staticmethod
    def gzip_decompress(data: bytes) -> bytes:
        return gzip.decompress(data)

    @staticmethod
    def judge_wav(data: bytes) -> bool:
        if len(data) < 44:
            return False
        return data[:4] == b'RIFF' and data[8:12] == b'WAVE'

    @staticmethod
    def convert_wav_with_path(audio_path: str, sample_rate: int = DEFAULT_SAMPLE_RATE) -> bytes:
        try:
            cmd = [
                "ffmpeg", "-v", "quiet", "-y", "-i", audio_path,
                "-acodec", "pcm_s16le", "-ac", "1", "-ar", str(sample_rate),
                "-f", "wav", "-"
            ]
            result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            # Try to remove the original file
            try:
                os.remove(audio_path)
            except OSError as e:
                logger.warning(f"Failed to remove original file: {e}")

            return result.stdout
        except subprocess.CalledProcessError as e:
            logger.error(f"FFmpeg conversion failed: {e.stderr.decode()}")
            raise RuntimeError(f"Audio conversion failed: {e.stderr.decode()}")

    @staticmethod
    def read_wav_info(data: bytes) -> Tuple[int, int, int, int, bytes]:
        if len(data) < 44:
            raise ValueError("Invalid WAV file: too short")

        # Parse the WAV header
        chunk_id = data[:4]
        if chunk_id != b'RIFF':
            raise ValueError("Invalid WAV file: not RIFF format")

        format_ = data[8:12]
        if format_ != b'WAVE':
            raise ValueError("Invalid WAV file: not WAVE format")

        # Parse the fmt subchunk
        audio_format = struct.unpack('<H', data[20:22])[0]
        num_channels = struct.unpack('<H', data[22:24])[0]
        sample_rate = struct.unpack('<I', data[24:28])[0]
        bits_per_sample = struct.unpack('<H', data[34:36])[0]

        # Locate the data subchunk
        pos = 36
        while pos < len(data) - 8:
            subchunk_id = data[pos:pos+4]
            subchunk_size = struct.unpack('<I', data[pos+4:pos+8])[0]
            if subchunk_id == b'data':
                wave_data = data[pos+8:pos+8+subchunk_size]
                return (
                    num_channels,
                    bits_per_sample // 8,
                    sample_rate,
                    subchunk_size // (num_channels * (bits_per_sample // 8)),
                    wave_data
                )
            pos += 8 + subchunk_size

        raise ValueError("Invalid WAV file: no data subchunk found")


class AsrRequestHeader:
    def __init__(self):
        self.message_type = MessageType.CLIENT_FULL_REQUEST
        self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE
        self.serialization_type = SerializationType.JSON
        self.compression_type = CompressionType.GZIP
        self.reserved_data = bytes([0x00])

    def with_message_type(self, message_type: int) -> 'AsrRequestHeader':
        self.message_type = message_type
        return self

    def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader':
        self.message_type_specific_flags = flags
        return self

    def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader':
        self.serialization_type = serialization_type
        return self

    def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader':
        self.compression_type = compression_type
        return self

    def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader':
        self.reserved_data = reserved_data
        return self

    def to_bytes(self) -> bytes:
        header = bytearray()
        header.append((ProtocolVersion.V1 << 4) | 1)
        header.append((self.message_type << 4) | self.message_type_specific_flags)
        header.append((self.serialization_type << 4) | self.compression_type)
        header.extend(self.reserved_data)
        return bytes(header)

    @staticmethod
    def default_header() -> 'AsrRequestHeader':
        return AsrRequestHeader()


class RequestBuilder:
    @staticmethod
    def new_auth_headers() -> Dict[str, str]:
        reqid = str(uuid.uuid4())
        return {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Request-Id": reqid,
            "X-Api-Access-Key": config.access_key,
            "X-Api-App-Key": config.app_key
        }

    @staticmethod
    def new_full_client_request(seq: int) -> bytes:  # takes the sequence number
        header = AsrRequestHeader.default_header() \
            .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)

        payload = {
            "user": {
                "uid": "demo_uid"
            },
            "audio": {
                "format": "wav",
                "codec": "raw",
                "rate": 16000,
                "bits": 16,
                "channel": 1
            },
            "request": {
                "model_name": "bigmodel",
                "enable_itn": True,
                "enable_punc": True,
                "enable_ddc": True,
                "show_utterances": True,
                "enable_nonstream": False
            }
        }

        payload_bytes = json.dumps(payload).encode('utf-8')
        compressed_payload = CommonUtils.gzip_compress(payload_bytes)
        payload_size = len(compressed_payload)

        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))  # use the caller-supplied seq
        request.extend(struct.pack('>I', payload_size))
        request.extend(compressed_payload)

        return bytes(request)

    @staticmethod
    def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
        header = AsrRequestHeader.default_header()
        if is_last:  # the last packet is handled specially
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE)
            seq = -seq  # negate the sequence number
        else:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST)

        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))

        compressed_segment = CommonUtils.gzip_compress(segment)
        request.extend(struct.pack('>I', len(compressed_segment)))
        request.extend(compressed_segment)

        return bytes(request)


class AsrResponse:
    def __init__(self):
        self.code = 0
        self.event = 0
        self.is_last_package = False
        self.payload_sequence = 0
        self.payload_size = 0
        self.payload_msg = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            "code": self.code,
            "event": self.event,
            "is_last_package": self.is_last_package,
            "payload_sequence": self.payload_sequence,
            "payload_size": self.payload_size,
            "payload_msg": self.payload_msg
        }


class ResponseParser:
    @staticmethod
    def parse_response(msg: bytes) -> AsrResponse:
        response = AsrResponse()

        header_size = msg[0] & 0x0f
        message_type = msg[1] >> 4
        message_type_specific_flags = msg[1] & 0x0f
        serialization_method = msg[2] >> 4
        message_compression = msg[2] & 0x0f

        payload = msg[header_size*4:]

        # Parse message_type_specific_flags
        if message_type_specific_flags & 0x01:
            response.payload_sequence = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]
        if message_type_specific_flags & 0x02:
            response.is_last_package = True
        if message_type_specific_flags & 0x04:
            response.event = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]

        # Parse message_type
        if message_type == MessageType.SERVER_FULL_RESPONSE:
            response.payload_size = struct.unpack('>I', payload[:4])[0]
            payload = payload[4:]
        elif message_type == MessageType.SERVER_ERROR_RESPONSE:
            response.code = struct.unpack('>i', payload[:4])[0]
            response.payload_size = struct.unpack('>I', payload[4:8])[0]
            payload = payload[8:]

        if not payload:
            return response

        # Decompress
        if message_compression == CompressionType.GZIP:
            try:
                payload = CommonUtils.gzip_decompress(payload)
            except Exception as e:
                logger.error(f"Failed to decompress payload: {e}")
                return response

        # Parse the payload
        try:
            if serialization_method == SerializationType.JSON:
                response.payload_msg = json.loads(payload.decode('utf-8'))
        except Exception as e:
            logger.error(f"Failed to parse payload: {e}")

        return response


class AsrWsClient:
    def __init__(self, url: str, segment_duration: int = 200):
        self.seq = 1
        self.url = url
        self.segment_duration = segment_duration
        self.conn = None
        self.session = None  # keep a reference to the session

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if self.conn and not self.conn.closed:
            await self.conn.close()
        if self.session and not self.session.closed:
            await self.session.close()

    async def read_audio_data(self, file_path: str) -> bytes:
        try:
            with open(file_path, 'rb') as f:
                content = f.read()

            if not CommonUtils.judge_wav(content):
                logger.info("Converting audio to WAV format...")
                content = CommonUtils.convert_wav_with_path(file_path, DEFAULT_SAMPLE_RATE)

            return content
        except Exception as e:
            logger.error(f"Failed to read audio data: {e}")
            raise

    def get_segment_size(self, content: bytes) -> int:
        try:
            channel_num, samp_width, frame_rate, _, _ = CommonUtils.read_wav_info(content)[:5]
            size_per_sec = channel_num * samp_width * frame_rate
            segment_size = size_per_sec * self.segment_duration // 1000
            return segment_size
        except Exception as e:
            logger.error(f"Failed to calculate segment size: {e}")
            raise

    async def create_connection(self) -> None:
        headers = RequestBuilder.new_auth_headers()
        try:
            self.conn = await self.session.ws_connect(  # use self.session
                self.url,
                headers=headers
            )
            logger.info(f"Connected to {self.url}")
        except Exception as e:
            logger.error(f"Failed to connect to WebSocket: {e}")
            raise

    async def send_full_client_request(self) -> None:
        request = RequestBuilder.new_full_client_request(self.seq)
        self.seq += 1  # increment after sending
        try:
            await self.conn.send_bytes(request)
            logger.info(f"Sent full client request with seq: {self.seq-1}")

            msg = await self.conn.receive()
            if msg.type == aiohttp.WSMsgType.BINARY:
                response = ResponseParser.parse_response(msg.data)
                logger.info(f"Received response: {response.to_dict()}")
            else:
                logger.error(f"Unexpected message type: {msg.type}")
        except Exception as e:
            logger.error(f"Failed to send full client request: {e}")
            raise

    async def send_messages(self, segment_size: int, content: bytes) -> AsyncGenerator[None, None]:
        audio_segments = self.split_audio(content, segment_size)
        total_segments = len(audio_segments)

        for i, segment in enumerate(audio_segments):
            is_last = (i == total_segments - 1)
            request = RequestBuilder.new_audio_only_request(
                self.seq,
                segment,
                is_last=is_last
            )
            await self.conn.send_bytes(request)
            logger.info(f"Sent audio segment with seq: {self.seq} (last: {is_last})")

            if not is_last:
                self.seq += 1

            await asyncio.sleep(self.segment_duration / 1000)  # pace packets to simulate a live stream
            # Yield control so that messages can be received
            yield

    async def recv_messages(self) -> AsyncGenerator[AsrResponse, None]:
        try:
            async for msg in self.conn:
                if msg.type == aiohttp.WSMsgType.BINARY:
                    response = ResponseParser.parse_response(msg.data)
                    yield response

                    if response.is_last_package or response.code != 0:
                        break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(f"WebSocket error: {msg.data}")
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    logger.info("WebSocket connection closed")
                    break
        except Exception as e:
            logger.error(f"Error receiving messages: {e}")
            raise

    async def start_audio_stream(self, segment_size: int, content: bytes) -> AsyncGenerator[AsrResponse, None]:
        async def sender():
            async for _ in self.send_messages(segment_size, content):
                pass

        # Start the sender task alongside the receiver
        sender_task = asyncio.create_task(sender())

        try:
            async for response in self.recv_messages():
                yield response
        finally:
            sender_task.cancel()
            try:
                await sender_task
            except asyncio.CancelledError:
                pass

    @staticmethod
    def split_audio(data: bytes, segment_size: int) -> List[bytes]:
        if segment_size <= 0:
            return []

        segments = []
        for i in range(0, len(data), segment_size):
            end = i + segment_size
            if end > len(data):
                end = len(data)
            segments.append(data[i:end])
        return segments

    async def execute(self, file_path: str) -> AsyncGenerator[AsrResponse, None]:
        if not file_path:
            raise ValueError("File path is empty")

        if not self.url:
            raise ValueError("URL is empty")

        self.seq = 1

        try:
            # 1. Read the audio file
            content = await self.read_audio_data(file_path)

            # 2. Compute the segment size
            segment_size = self.get_segment_size(content)

            # 3. Open the WebSocket connection
            await self.create_connection()

            # 4. Send the full client request
            await self.send_full_client_request()

            # 5. Start streaming the audio
            async for response in self.start_audio_stream(segment_size, content):
                yield response

        except Exception as e:
            logger.error(f"Error in ASR execution: {e}")
            raise
        finally:
            if self.conn:
                await self.conn.close()


async def main():
    import argparse

    parser = argparse.ArgumentParser(description="ASR WebSocket Client")
    parser.add_argument("--file", type=str, required=True, help="Audio file path")

    # wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
    # wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async
    # wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream
    parser.add_argument("--url", type=str, default="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream",
                        help="WebSocket URL")
    parser.add_argument("--seg-duration", type=int, default=200,
                        help="Audio duration(ms) per packet, default:200")

    args = parser.parse_args()

    async with AsrWsClient(args.url, args.seg_duration) as client:  # session managed via async with
        try:
            async for response in client.execute(args.file):
                logger.info(f"Received response: {json.dumps(response.to_dict(), indent=2, ensure_ascii=False)}")
        except Exception as e:
            logger.error(f"ASR processing failed: {e}")


if __name__ == "__main__":
    asyncio.run(main())


# Usage:
# python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav
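A quick sanity check of the wire format built above (an editorial sketch, not part of the commit): a full client request frame is the 4-byte header from AsrRequestHeader.to_bytes(), a big-endian int32 sequence number, a big-endian uint32 payload size, and the gzip-compressed JSON payload.

import gzip, json, struct

# 4-byte header with the defaults (full client request, JSON, GZIP):
header = bytes([
    (0b0001 << 4) | 1,       # protocol V1 | header size (1 unit of 4 bytes)
    (0b0001 << 4) | 0b0001,  # CLIENT_FULL_REQUEST | POS_SEQUENCE
    (0b0001 << 4) | 0b0001,  # JSON | GZIP
    0x00,                    # reserved
])
assert header == b'\x11\x11\x11\x00'

# Frame body: int32 seq, uint32 payload size, gzip(JSON payload)
payload = gzip.compress(json.dumps({"user": {"uid": "demo_uid"}}).encode('utf-8'))
frame = header + struct.pack('>i', 1) + struct.pack('>I', len(payload)) + payload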
simple_wake_and_record.py (modified)
@@ -35,11 +35,11 @@ class SimpleWakeAndRecord:
         self.stream = None
         self.running = False
 
-        # Audio parameters
+        # Audio parameters - tuned for the Raspberry Pi 3B
         self.FORMAT = pyaudio.paInt16
         self.CHANNELS = 1
-        self.RATE = 16000
-        self.CHUNK_SIZE = 1024
+        self.RATE = 8000  # dropped from 16 kHz to 8 kHz, halving the data to process
+        self.CHUNK_SIZE = 2048  # larger chunks, fewer processing passes
 
         # Recording state
         self.recording = False
@@ -48,6 +48,19 @@ class SimpleWakeAndRecord:
         self.recording_start_time = None
         self.recording_recognizer = None  # recognizer dedicated to recording
 
+        # Performance optimization
+        self.audio_buffer = []  # audio buffer
+        self.buffer_size = 10  # buffer size (in chunks)
+        self.last_process_time = time.time()  # time of the last processing pass
+        self.process_interval = 0.5  # processing interval (seconds)
+        self.batch_process_size = 5  # batch size
+
+        # Performance monitoring
+        self.process_count = 0
+        self.avg_process_time = 0
+        self.last_monitor_time = time.time()
+        self.monitor_interval = 5.0  # monitoring interval (seconds)
+
         # Thresholds
         self.text_silence_threshold = 3.0  # stop after 3 seconds with no recognized text
         self.min_recording_time = 2.0  # minimum recording time
@@ -116,6 +129,48 @@ class SimpleWakeAndRecord:
             return True, wake_word
         return False, None
 
+    def _should_process_audio(self):
+        """Decide whether the buffered audio should be processed now."""
+        current_time = time.time()
+        return (current_time - self.last_process_time >= self.process_interval and
+                len(self.audio_buffer) >= self.batch_process_size)
+
+    def _process_audio_batch(self):
+        """Process one batch of buffered audio data."""
+        if len(self.audio_buffer) < self.batch_process_size:
+            return
+
+        # Record the start of processing
+        start_time = time.time()
+
+        # Take one batch out of the buffer
+        batch_data = self.audio_buffer[:self.batch_process_size]
+        self.audio_buffer = self.audio_buffer[self.batch_process_size:]
+
+        # Merge the audio data
+        combined_data = b''.join(batch_data)
+
+        # Update the processing timestamp
+        self.last_process_time = time.time()
+
+        # Update the performance statistics
+        process_time = time.time() - start_time
+        self.process_count += 1
+        self.avg_process_time = (self.avg_process_time * (self.process_count - 1) + process_time) / self.process_count
+
+        # Performance monitoring
+        self._monitor_performance()
+
+        return combined_data
+
+    def _monitor_performance(self):
+        """Report performance statistics periodically."""
+        current_time = time.time()
+        if current_time - self.last_monitor_time >= self.monitor_interval:
+            buffer_usage = len(self.audio_buffer) / self.buffer_size * 100
+            print(f"\n📊 Performance | passes: {self.process_count} | avg time: {self.avg_process_time:.3f}s | buffer usage: {buffer_usage:.1f}%")
+            self.last_monitor_time = current_time
+
     def _save_recording(self, audio_data):
         """Save the recording."""
         timestamp = time.strftime("%Y%m%d_%H%M%S")
@@ -262,13 +317,21 @@ class SimpleWakeAndRecord:
                     continue
 
                 if self.recording:
-                    # Recording mode
+                    # Recording mode - process directly
                     self.recorded_frames.append(data)
                     recording_duration = time.time() - self.recording_start_time
 
-                    # Run live recognition with the recording-specific recognizer
-                    if self.recording_recognizer:
-                        if self.recording_recognizer.AcceptWaveform(data):
+                    # While recording, recognize in batches
+                    self.audio_buffer.append(data)
+
+                    # Cap the buffer size
+                    if len(self.audio_buffer) > self.buffer_size:
+                        self.audio_buffer.pop(0)
+
+                    # Batch recognition
+                    if self._should_process_audio() and self.recording_recognizer:
+                        combined_data = self._process_audio_batch()
+                        if combined_data and self.recording_recognizer.AcceptWaveform(combined_data):
                             # Fetch the final recognition result
                             result = json.loads(self.recording_recognizer.Result())
                             text = result.get('text', '').strip()
@@ -277,7 +340,7 @@ class SimpleWakeAndRecord:
                             # Text recognized: refresh the timestamp
                             self.last_text_time = time.time()
                             print(f"\n📝 Recognized: {text}")
-                        else:
+                        elif combined_data:
                             # Fetch the partial recognition result
                             partial_result = json.loads(self.recording_recognizer.PartialResult())
                             partial_text = partial_result.get('partial', '').strip()
@@ -314,32 +377,41 @@ class SimpleWakeAndRecord:
                         print(f"\r{status}", end='', flush=True)
 
                 elif self.model and self.recognizer:
-                    # Wake-word detection mode
-                    if self.recognizer.AcceptWaveform(data):
-                        result = json.loads(self.recognizer.Result())
-                        text = result.get('text', '').strip()
-
-                        if text:
-                            print(f"Recognized: {text}")
-
-                            # Check for a wake word
-                            is_wake_word, detected_word = self._check_wake_word(text)
-                            if is_wake_word:
-                                print(f"🎯 Wake word detected: {detected_word}")
-                                self._start_recording()
-                    else:
-                        # Show the live audio level
-                        energy = self._calculate_energy(data)
-                        if energy > 50:  # only show meaningful audio levels
-                            partial_result = json.loads(self.recognizer.PartialResult())
-                            partial_text = partial_result.get('partial', '')
-                            if partial_text:
-                                status = f"Listening... energy: {energy:.0f} | {partial_text}"
-                            else:
-                                status = f"Listening... energy: {energy:.0f}"
-                            print(status, end='\r')
+                    # Wake-word detection mode - batched
+                    self.audio_buffer.append(data)
+
+                    # Cap the buffer size
+                    if len(self.audio_buffer) > self.buffer_size:
+                        self.audio_buffer.pop(0)
+
+                    # Batch recognition
+                    if self._should_process_audio():
+                        combined_data = self._process_audio_batch()
+                        if combined_data and self.recognizer.AcceptWaveform(combined_data):
+                            result = json.loads(self.recognizer.Result())
+                            text = result.get('text', '').strip()
+
+                            if text:
+                                print(f"Recognized: {text}")
+
+                                # Check for a wake word
+                                is_wake_word, detected_word = self._check_wake_word(text)
+                                if is_wake_word:
+                                    print(f"🎯 Wake word detected: {detected_word}")
+                                    self._start_recording()
+                        else:
+                            # Show the live audio level
+                            energy = self._calculate_energy(data)
+                            if energy > 50:  # only show meaningful audio levels
+                                partial_result = json.loads(self.recognizer.PartialResult())
+                                partial_text = partial_result.get('partial', '')
+                                if partial_text:
+                                    status = f"Listening... energy: {energy:.0f} | {partial_text}"
+                                else:
+                                    status = f"Listening... energy: {energy:.0f}"
+                                print(status, end='\r')
 
-                time.sleep(0.01)
+                time.sleep(0.05)  # longer sleep to reduce CPU usage
 
         except KeyboardInterrupt:
             print("\n👋 Exiting")
@@ -394,6 +466,12 @@ def main():
     print("5. Recordings are saved automatically")
     print("6. After recording, the captured audio is played back")
     print("7. Press Ctrl+C to exit")
+    print("🚀 Performance optimizations enabled:")
+    print("   - Sample rate: 8 kHz (50% less data)")
+    print("   - Batching: 5 audio chunks per pass")
+    print("   - Processing interval: 0.5 s")
+    print("   - Buffer: 10 audio chunks")
+    print("   - Performance report every 5 s")
     print("=" * 50)
 
     # Run
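For scale, a back-of-envelope check of what the new settings mean per batch (a sketch assuming paInt16 mono, i.e. 2 bytes per frame, as configured above):

rate, chunk_frames, batch = 8000, 2048, 5
print(chunk_frames / rate)          # 0.256 s of audio per chunk read
print(chunk_frames / rate * batch)  # ~1.28 s of audio per 5-chunk batch
print(chunk_frames * 2)             # 4096 bytes per chunk at 2 bytes/frame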
speech_recognizer.py — 532 lines (new file)
@@ -0,0 +1,532 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Speech recognition module.
Provides speech recognition for recorded files via the SAUC API.
"""

import os
import json
import time
import logging
import asyncio
import aiohttp
import struct
import gzip
import uuid
from typing import Optional, List, Dict, Any, AsyncGenerator
from dataclasses import dataclass

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Constants
DEFAULT_SAMPLE_RATE = 16000


class ProtocolVersion:
    V1 = 0b0001


class MessageType:
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ERROR_RESPONSE = 0b1111


class MessageTypeSpecificFlags:
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    NEG_SEQUENCE = 0b0010
    NEG_WITH_SEQUENCE = 0b0011


class SerializationType:
    NO_SERIALIZATION = 0b0000
    JSON = 0b0001


class CompressionType:
    GZIP = 0b0001


@dataclass
class RecognitionResult:
    """A speech recognition result."""
    text: str
    confidence: float
    is_final: bool
    start_time: Optional[float] = None
    end_time: Optional[float] = None


class AudioUtils:
    """Audio processing helpers."""

    @staticmethod
    def gzip_compress(data: bytes) -> bytes:
        """GZIP compression."""
        return gzip.compress(data)

    @staticmethod
    def gzip_decompress(data: bytes) -> bytes:
        """GZIP decompression."""
        return gzip.decompress(data)

    @staticmethod
    def is_wav_file(data: bytes) -> bool:
        """Check whether the bytes look like a WAV file."""
        if len(data) < 44:
            return False
        return data[:4] == b'RIFF' and data[8:12] == b'WAVE'

    @staticmethod
    def read_wav_info(data: bytes) -> tuple:
        """Read WAV file metadata."""
        if len(data) < 44:
            raise ValueError("Invalid WAV file: too short")

        # Parse the WAV header
        chunk_id = data[:4]
        if chunk_id != b'RIFF':
            raise ValueError("Invalid WAV file: not RIFF format")

        format_ = data[8:12]
        if format_ != b'WAVE':
            raise ValueError("Invalid WAV file: not WAVE format")

        # Parse the fmt subchunk
        audio_format = struct.unpack('<H', data[20:22])[0]
        num_channels = struct.unpack('<H', data[22:24])[0]
        sample_rate = struct.unpack('<I', data[24:28])[0]
        bits_per_sample = struct.unpack('<H', data[34:36])[0]

        # Locate the data subchunk
        pos = 36
        while pos < len(data) - 8:
            subchunk_id = data[pos:pos+4]
            subchunk_size = struct.unpack('<I', data[pos+4:pos+8])[0]
            if subchunk_id == b'data':
                wave_data = data[pos+8:pos+8+subchunk_size]
                return (
                    num_channels,
                    bits_per_sample // 8,
                    sample_rate,
                    subchunk_size // (num_channels * (bits_per_sample // 8)),
                    wave_data
                )
            pos += 8 + subchunk_size

        raise ValueError("Invalid WAV file: no data subchunk found")


class AsrConfig:
    """ASR configuration."""

    def __init__(self, app_key: str = None, access_key: str = None):
        self.auth = {
            "app_key": app_key or os.getenv("SAUC_APP_KEY", "your_app_key"),
            "access_key": access_key or os.getenv("SAUC_ACCESS_KEY", "your_access_key")
        }

    @property
    def app_key(self) -> str:
        return self.auth["app_key"]

    @property
    def access_key(self) -> str:
        return self.auth["access_key"]


class AsrRequestHeader:
    """ASR request header."""

    def __init__(self):
        self.message_type = MessageType.CLIENT_FULL_REQUEST
        self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE
        self.serialization_type = SerializationType.JSON
        self.compression_type = CompressionType.GZIP
        self.reserved_data = bytes([0x00])

    def with_message_type(self, message_type: int) -> 'AsrRequestHeader':
        self.message_type = message_type
        return self

    def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader':
        self.message_type_specific_flags = flags
        return self

    def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader':
        self.serialization_type = serialization_type
        return self

    def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader':
        self.compression_type = compression_type
        return self

    def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader':
        self.reserved_data = reserved_data
        return self

    def to_bytes(self) -> bytes:
        header = bytearray()
        header.append((ProtocolVersion.V1 << 4) | 1)
        header.append((self.message_type << 4) | self.message_type_specific_flags)
        header.append((self.serialization_type << 4) | self.compression_type)
        header.extend(self.reserved_data)
        return bytes(header)

    @staticmethod
    def default_header() -> 'AsrRequestHeader':
        return AsrRequestHeader()


class RequestBuilder:
    """Request builder."""

    @staticmethod
    def new_auth_headers(config: AsrConfig) -> Dict[str, str]:
        """Build the authentication headers."""
        reqid = str(uuid.uuid4())
        return {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Request-Id": reqid,
            "X-Api-Access-Key": config.access_key,
            "X-Api-App-Key": config.app_key
        }

    @staticmethod
    def new_full_client_request(seq: int) -> bytes:
        """Build the full client request."""
        header = AsrRequestHeader.default_header() \
            .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)

        payload = {
            "user": {
                "uid": "local_voice_user"
            },
            "audio": {
                "format": "wav",
                "codec": "raw",
                "rate": 16000,
                "bits": 16,
                "channel": 1
            },
            "request": {
                "model_name": "bigmodel",
                "enable_itn": True,
                "enable_punc": True,
                "enable_ddc": True,
                "show_utterances": True,
                "enable_nonstream": False
            }
        }

        payload_bytes = json.dumps(payload).encode('utf-8')
        compressed_payload = AudioUtils.gzip_compress(payload_bytes)
        payload_size = len(compressed_payload)

        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))
        request.extend(struct.pack('>I', payload_size))  # '>I' (unsigned int); the original '>U' is not a valid struct format
        request.extend(compressed_payload)

        return bytes(request)

    @staticmethod
    def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
        """Build an audio-only request."""
        header = AsrRequestHeader.default_header()
        if is_last:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE)
            seq = -seq
        else:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST)

        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))

        compressed_segment = AudioUtils.gzip_compress(segment)
        request.extend(struct.pack('>I', len(compressed_segment)))
        request.extend(compressed_segment)

        return bytes(request)


class AsrResponse:
    """ASR response."""

    def __init__(self):
        self.code = 0
        self.event = 0
        self.is_last_package = False
        self.payload_sequence = 0
        self.payload_size = 0
        self.payload_msg = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            "code": self.code,
            "event": self.event,
            "is_last_package": self.is_last_package,
            "payload_sequence": self.payload_sequence,
            "payload_size": self.payload_size,
            "payload_msg": self.payload_msg
        }


class ResponseParser:
    """Response parser."""

    @staticmethod
    def parse_response(msg: bytes) -> AsrResponse:
        """Parse a binary response frame."""
        response = AsrResponse()

        header_size = msg[0] & 0x0f
        message_type = msg[1] >> 4
        message_type_specific_flags = msg[1] & 0x0f
        serialization_method = msg[2] >> 4
        message_compression = msg[2] & 0x0f

        payload = msg[header_size*4:]

        # Parse message_type_specific_flags
        if message_type_specific_flags & 0x01:
            response.payload_sequence = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]
        if message_type_specific_flags & 0x02:
            response.is_last_package = True
        if message_type_specific_flags & 0x04:
            response.event = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]

        # Parse message_type
        if message_type == MessageType.SERVER_FULL_RESPONSE:
            response.payload_size = struct.unpack('>I', payload[:4])[0]
            payload = payload[4:]
        elif message_type == MessageType.SERVER_ERROR_RESPONSE:
            response.code = struct.unpack('>i', payload[:4])[0]
            response.payload_size = struct.unpack('>I', payload[4:8])[0]
            payload = payload[8:]

        if not payload:
            return response

        # Decompress
        if message_compression == CompressionType.GZIP:
            try:
                payload = AudioUtils.gzip_decompress(payload)
            except Exception as e:
                logger.error(f"Failed to decompress payload: {e}")
                return response

        # Parse the payload
        try:
            if serialization_method == SerializationType.JSON:
                response.payload_msg = json.loads(payload.decode('utf-8'))
        except Exception as e:
            logger.error(f"Failed to parse payload: {e}")

        return response


class SpeechRecognizer:
    """Speech recognizer."""

    def __init__(self, app_key: str = None, access_key: str = None,
                 url: str = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream"):
        self.config = AsrConfig(app_key, access_key)
        self.url = url
        self.seq = 1

    async def recognize_file(self, file_path: str) -> List[RecognitionResult]:
        """Recognize an audio file."""
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        results = []

        try:
            async with aiohttp.ClientSession() as session:
                # Read the audio file
                with open(file_path, 'rb') as f:
                    content = f.read()

                if not AudioUtils.is_wav_file(content):
                    raise ValueError("Audio file must be in WAV format")

                # Read the audio metadata
                try:
                    _, _, sample_rate, _, audio_data = AudioUtils.read_wav_info(content)
                    if sample_rate != DEFAULT_SAMPLE_RATE:
                        logger.warning(f"Sample rate {sample_rate} != {DEFAULT_SAMPLE_RATE}, may affect recognition accuracy")
                except Exception as e:
                    logger.error(f"Failed to read audio info: {e}")
                    raise

                # Segment size (200 ms per segment)
                segment_size = 1 * 2 * DEFAULT_SAMPLE_RATE * 200 // 1000  # channels * bytes_per_sample * sample_rate * duration_ms / 1000

                # Open the WebSocket connection
                headers = RequestBuilder.new_auth_headers(self.config)
                async with session.ws_connect(self.url, headers=headers) as ws:

                    # Send the full client request
                    request = RequestBuilder.new_full_client_request(self.seq)
                    self.seq += 1
                    await ws.send_bytes(request)

                    # Receive the initial response
                    msg = await ws.receive()
                    if msg.type == aiohttp.WSMsgType.BINARY:
                        response = ResponseParser.parse_response(msg.data)
                        logger.info(f"Initial response: {response.to_dict()}")

                    # Send the audio data in segments
                    audio_segments = self._split_audio(audio_data, segment_size)
                    total_segments = len(audio_segments)

                    for i, segment in enumerate(audio_segments):
                        is_last = (i == total_segments - 1)
                        request = RequestBuilder.new_audio_only_request(
                            self.seq,
                            segment,
                            is_last=is_last
                        )
                        await ws.send_bytes(request)
                        logger.info(f"Sent audio segment {i+1}/{total_segments}")

                        if not is_last:
                            self.seq += 1

                        # Short delay to simulate a live stream
                        await asyncio.sleep(0.1)

                    # Receive the recognition results
                    final_text = ""
                    while True:
                        msg = await ws.receive()
                        if msg.type == aiohttp.WSMsgType.BINARY:
                            response = ResponseParser.parse_response(msg.data)

                            if response.payload_msg and 'text' in response.payload_msg:
                                text = response.payload_msg['text']
                                if text:
                                    final_text += text

                                    result = RecognitionResult(
                                        text=text,
                                        confidence=0.9,  # default confidence
                                        is_final=response.is_last_package
                                    )
                                    results.append(result)

                                    logger.info(f"Recognized: {text}")

                            if response.is_last_package or response.code != 0:
                                break

                        elif msg.type == aiohttp.WSMsgType.ERROR:
                            logger.error(f"WebSocket error: {msg.data}")
                            break
                        elif msg.type == aiohttp.WSMsgType.CLOSED:
                            logger.info("WebSocket connection closed")
                            break

            # If no final result arrived, create one combining all the text
            if final_text and not any(r.is_final for r in results):
                final_result = RecognitionResult(
                    text=final_text,
                    confidence=0.9,
                    is_final=True
                )
                results.append(final_result)

            return results

        except Exception as e:
            logger.error(f"Speech recognition failed: {e}")
            raise

    def _split_audio(self, data: bytes, segment_size: int) -> List[bytes]:
        """Split audio data into segments."""
        if segment_size <= 0:
            return []

        segments = []
        for i in range(0, len(data), segment_size):
            end = i + segment_size
            if end > len(data):
                end = len(data)
            segments.append(data[i:end])
        return segments

    async def recognize_latest_recording(self, directory: str = ".") -> Optional[RecognitionResult]:
        """Recognize the most recent recording file."""
        # Find the newest recording file
        recording_files = [f for f in os.listdir(directory) if f.startswith('recording_') and f.endswith('.wav')]

        if not recording_files:
            logger.warning("No recording files found")
            return None

        # Sort by file name (which embeds a timestamp)
        recording_files.sort(reverse=True)
        latest_file = recording_files[0]
        latest_path = os.path.join(directory, latest_file)

        logger.info(f"Recognizing latest recording: {latest_file}")

        try:
            results = await self.recognize_file(latest_path)
            if results:
                # Return the final recognition result
                final_results = [r for r in results if r.is_final]
                if final_results:
                    return final_results[-1]
                else:
                    # If nothing is marked final, return the last result
                    return results[-1]
        except Exception as e:
            logger.error(f"Failed to recognize latest recording: {e}")

        return None


async def main():
    """Test entry point."""
    import argparse

    parser = argparse.ArgumentParser(description="Speech recognition test")
    parser.add_argument("--file", type=str, help="Audio file path")
    parser.add_argument("--latest", action="store_true", help="Recognize the most recent recording")
    parser.add_argument("--app-key", type=str, help="SAUC App Key")
    parser.add_argument("--access-key", type=str, help="SAUC Access Key")

    args = parser.parse_args()

    recognizer = SpeechRecognizer(
        app_key=args.app_key,
        access_key=args.access_key
    )

    try:
        if args.latest:
            result = await recognizer.recognize_latest_recording()
            if result:
                print(f"Recognition result: {result.text}")
                print(f"Confidence: {result.confidence}")
                print(f"Final: {result.is_final}")
            else:
                print("No speech recognized")
        elif args.file:
            results = await recognizer.recognize_file(args.file)
            for i, result in enumerate(results):
                print(f"Result {i+1}: {result.text}")
                print(f"Confidence: {result.confidence}")
                print(f"Final: {result.is_final}")
                print("-" * 40)
        else:
            print("Please pass --file or --latest")

    except Exception as e:
        print(f"Recognition failed: {e}")


if __name__ == "__main__":
    asyncio.run(main())
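As a sanity check on the segment sizing in recognize_file() above (a sketch, not part of the file): 200 ms of 16-bit mono audio at 16 kHz works out to

# channels * bytes_per_sample * sample_rate * duration_ms // 1000
segment_size = 1 * 2 * 16000 * 200 // 1000
print(segment_size)  # 6400 bytes per 200 ms segment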