config
This commit is contained in:
parent
ef39e31a4b
commit
eb099d827d
501
enhanced_wake_and_record.py
Normal file
501
enhanced_wake_and_record.py
Normal file
@ -0,0 +1,501 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
集成语音识别的唤醒+录音系统
|
||||||
|
基于 simple_wake_and_record.py,添加语音识别功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import pyaudio
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
# 添加当前目录到路径
|
||||||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vosk import Model, KaldiRecognizer
|
||||||
|
VOSK_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
VOSK_AVAILABLE = False
|
||||||
|
print("⚠️ Vosk 未安装,请运行: pip install vosk")
|
||||||
|
|
||||||
|
from speech_recognizer import SpeechRecognizer, RecognitionResult
|
||||||
|
|
||||||
|
class EnhancedWakeAndRecord:
    """Wake-word detection + recording pipeline with optional speech recognition.

    Flow: listen on the microphone with a local Vosk model; when a configured
    wake word is recognized, start buffering audio; stop once no text has been
    recognized for ``text_silence_threshold`` seconds (or a hard time limit is
    hit); save the recording as a WAV file, play it back, then hand the file to
    a background thread that runs the cloud ``SpeechRecognizer`` on it.
    """

    def __init__(self, model_path="model", wake_words=None,
                 enable_speech_recognition=True, app_key=None, access_key=None):
        """Initialize the model, the audio device and (optionally) the recognizer.

        Args:
            model_path: directory containing the Vosk model files.
            wake_words: list of wake words; defaults to ["你好", "助手"].
            enable_speech_recognition: whether to run cloud recognition on recordings.
            app_key: credential forwarded to SpeechRecognizer (may be None).
            access_key: credential forwarded to SpeechRecognizer (may be None).
        """
        self.model_path = model_path
        # Fix: the original used a mutable default argument (a list shared
        # across every instance); keep the same effective default via None.
        self.wake_words = ["你好", "助手"] if wake_words is None else wake_words
        self.enable_speech_recognition = enable_speech_recognition
        self.model = None
        self.recognizer = None
        self.audio = None
        self.stream = None
        self.running = False

        # Audio capture parameters: 16 kHz mono, 16-bit PCM.
        self.FORMAT = pyaudio.paInt16
        self.CHANNELS = 1
        self.RATE = 16000
        self.CHUNK_SIZE = 1024

        # Recording state.
        self.recording = False
        self.recorded_frames = []
        self.last_text_time = None          # last moment any text was recognized
        self.recording_start_time = None
        self.recording_recognizer = None    # per-recording Vosk recognizer

        # Thresholds, in seconds.
        self.text_silence_threshold = 3.0
        self.min_recording_time = 2.0
        self.max_recording_time = 30.0

        # Cloud speech recognition state.
        self.speech_recognizer = None
        self.last_recognition_result = None
        self.recognition_thread = None

        # Callback invoked with the final RecognitionResult.
        self.on_recognition_result = None

        self._setup_model()
        self._setup_audio()
        self._setup_speech_recognition(app_key, access_key)

    def _setup_model(self):
        """Load the local Vosk model used for wake-word spotting."""
        if not VOSK_AVAILABLE:
            return

        try:
            if not os.path.exists(self.model_path):
                print(f"模型路径不存在: {self.model_path}")
                return

            self.model = Model(self.model_path)
            self.recognizer = KaldiRecognizer(self.model, self.RATE)
            self.recognizer.SetWords(True)

            print("✅ Vosk 模型加载成功")
        except Exception as e:
            print(f"模型初始化失败: {e}")

    def _setup_audio(self):
        """Open the PyAudio input stream (idempotent)."""
        try:
            if self.audio is None:
                self.audio = pyaudio.PyAudio()

            if self.stream is None:
                self.stream = self.audio.open(
                    format=self.FORMAT,
                    channels=self.CHANNELS,
                    rate=self.RATE,
                    input=True,
                    frames_per_buffer=self.CHUNK_SIZE
                )

            print("✅ 音频设备初始化成功")
        except Exception as e:
            print(f"音频设备初始化失败: {e}")

    def _setup_speech_recognition(self, app_key=None, access_key=None):
        """Create the cloud SpeechRecognizer; disable the feature on failure."""
        if not self.enable_speech_recognition:
            return

        try:
            self.speech_recognizer = SpeechRecognizer(
                app_key=app_key,
                access_key=access_key
            )
            print("✅ 语音识别器初始化成功")
        except Exception as e:
            print(f"语音识别器初始化失败: {e}")
            self.enable_speech_recognition = False

    def _calculate_energy(self, audio_data):
        """Return the RMS energy of a chunk of 16-bit PCM bytes."""
        if len(audio_data) == 0:
            return 0

        import numpy as np
        audio_array = np.frombuffer(audio_data, dtype=np.int16)
        # Fix: squaring int16 samples overflows and wraps around, yielding a
        # bogus RMS for loud audio; promote to float64 before squaring.
        rms = np.sqrt(np.mean(audio_array.astype(np.float64) ** 2))
        return rms

    def _check_wake_word(self, text):
        """Case-insensitively look for any wake word inside *text*.

        Returns:
            (True, matched_word) on a hit, (False, None) otherwise.
        """
        if not text or not self.wake_words:
            return False, None

        text_lower = text.lower()
        for wake_word in self.wake_words:
            if wake_word.lower() in text_lower:
                return True, wake_word
        return False, None

    def _save_recording(self, audio_data):
        """Write raw PCM bytes to a timestamped WAV file.

        Returns:
            (True, filename) on success, (False, None) on failure.
        """
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        filename = f"recording_{timestamp}.wav"

        try:
            import wave
            with wave.open(filename, 'wb') as wf:
                wf.setnchannels(self.CHANNELS)
                wf.setsampwidth(self.audio.get_sample_size(self.FORMAT))
                wf.setframerate(self.RATE)
                wf.writeframes(audio_data)

            # Fix: the message previously did not interpolate the file name.
            print(f"✅ 录音已保存: {filename}")
            return True, filename
        except Exception as e:
            print(f"保存录音失败: {e}")
            return False, None

    def _play_audio(self, filename):
        """Play a WAV file through PyAudio; fall back to the system player."""
        try:
            import wave

            with wave.open(filename, 'rb') as wf:
                channels = wf.getnchannels()
                width = wf.getsampwidth()
                rate = wf.getframerate()
                total_frames = wf.getnframes()

                # Read in chunks to keep memory bounded.
                chunk_size = 1024
                frames = []
                for _ in range(0, total_frames, chunk_size):
                    chunk = wf.readframes(chunk_size)
                    if chunk:
                        frames.append(chunk)
                    else:
                        break

            playback_stream = self.audio.open(
                format=self.audio.get_format_from_width(width),
                channels=channels,
                rate=rate,
                output=True
            )

            # Fix: the message previously did not interpolate the file name.
            print(f"🔊 开始播放: {filename}")

            for chunk in frames:
                playback_stream.write(chunk)

            # write() blocks until buffered audio is consumed; just tear down.
            playback_stream.stop_stream()
            playback_stream.close()

            print("✅ 播放完成")
        except Exception as e:
            print(f"❌ 播放失败: {e}")
            self._play_with_system_player(filename)

    def _play_with_system_player(self, filename):
        """Fallback: play the WAV via the operating system's audio player."""
        try:
            import platform
            import subprocess

            system = platform.system()

            if system == 'Darwin':  # macOS
                cmd = ['afplay', filename]
            elif system == 'Windows':
                # Fix: 'start' is a cmd.exe builtin, not an executable, so it
                # cannot be spawned directly by subprocess with a bare list.
                cmd = ['cmd', '/c', 'start', '/min', filename]
            else:  # Linux
                cmd = ['aplay', filename]

            print(f"🔊 使用系统播放器: {' '.join(cmd)}")
            subprocess.run(cmd, check=True)
            print("✅ 播放完成")

        except Exception as e:
            print(f"❌ 系统播放器也失败: {e}")
            print(f"💡 文件已保存,请手动播放: {filename}")

    def _start_recognition_thread(self, filename):
        """Run cloud recognition on *filename* in a daemon thread."""
        if not self.enable_speech_recognition or not self.speech_recognizer:
            return

        def recognize_task():
            try:
                # recognize_file is a coroutine; give this worker its own loop.
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

                # Fix: the message previously did not interpolate the file name.
                print(f"🧠 开始识别录音文件: {filename}")
                result = loop.run_until_complete(
                    self.speech_recognizer.recognize_file(filename)
                )

                if result:
                    # Merge every partial result into a single final text.
                    full_text = " ".join([r.text for r in result])
                    final_result = RecognitionResult(
                        text=full_text,
                        confidence=0.9,
                        is_final=True
                    )

                    self.last_recognition_result = final_result
                    print(f"\n🧠 语音识别结果: {full_text}")

                    if self.on_recognition_result:
                        self.on_recognition_result(final_result)
                else:
                    print("\n🧠 语音识别失败或未识别到内容")

                loop.close()

            except Exception as e:
                print(f"❌ 语音识别线程异常: {e}")

        self.recognition_thread = threading.Thread(target=recognize_task)
        self.recognition_thread.daemon = True
        self.recognition_thread.start()

    def _start_recording(self):
        """Reset recording state and create a fresh per-recording recognizer."""
        print("🎙️ 开始录音,请说话...")
        self.recording = True
        self.recorded_frames = []
        self.last_text_time = None
        self.recording_start_time = time.time()

        # A dedicated recognizer so wake-word state doesn't leak into dictation.
        if self.model:
            self.recording_recognizer = KaldiRecognizer(self.model, self.RATE)
            self.recording_recognizer.SetWords(True)

    def _stop_recording(self):
        """Finish the recording: save, play back, kick off recognition, reset."""
        if len(self.recorded_frames) > 0:
            audio_data = b''.join(self.recorded_frames)
            # 2 bytes per sample (16-bit PCM mono).
            duration = len(audio_data) / (self.RATE * 2)
            print(f"📝 录音完成,时长: {duration:.2f}秒")

            success, filename = self._save_recording(audio_data)

            # On success, replay the recording and start cloud recognition.
            if success and filename:
                print("=" * 50)
                print("🔊 播放刚才录制的音频...")
                self._play_audio(filename)
                print("=" * 50)

                if self.enable_speech_recognition:
                    print("🧠 准备进行语音识别...")
                    self._start_recognition_thread(filename)

        self.recording = False
        self.recorded_frames = []
        self.last_text_time = None
        self.recording_start_time = None
        self.recording_recognizer = None

    def set_recognition_callback(self, callback):
        """Register *callback(result)* to receive the final RecognitionResult."""
        self.on_recognition_result = callback

    def get_last_recognition_result(self) -> Optional[RecognitionResult]:
        """Return the most recent cloud recognition result, or None."""
        return self.last_recognition_result

    def start(self):
        """Main loop: wake-word detection, then recording until text silence."""
        if not self.stream:
            print("❌ 音频设备未初始化")
            return

        self.running = True
        print("🎤 开始监听...")
        print(f"唤醒词: {', '.join(self.wake_words)}")
        if self.enable_speech_recognition:
            print("🧠 语音识别: 已启用")
        else:
            print("🧠 语音识别: 已禁用")

        try:
            while self.running:
                # Read one chunk from the microphone.
                data = self.stream.read(self.CHUNK_SIZE, exception_on_overflow=False)

                if len(data) == 0:
                    continue

                if self.recording:
                    # --- recording mode ---
                    self.recorded_frames.append(data)
                    recording_duration = time.time() - self.recording_start_time

                    # Live recognition with the recording-only recognizer.
                    if self.recording_recognizer:
                        if self.recording_recognizer.AcceptWaveform(data):
                            result = json.loads(self.recording_recognizer.Result())
                            text = result.get('text', '').strip()

                            if text:
                                self.last_text_time = time.time()
                                print(f"\n📝 实时识别: {text}")
                        else:
                            partial_result = json.loads(self.recording_recognizer.PartialResult())
                            partial_text = partial_result.get('partial', '').strip()

                            if partial_text:
                                self.last_text_time = time.time()
                                status = f"录音中... {recording_duration:.1f}s | {partial_text}"
                                print(f"\r{status}", end='', flush=True)

                    # Decide whether to stop recording.
                    current_time = time.time()

                    if self.last_text_time is not None:
                        text_silence_duration = current_time - self.last_text_time
                        if text_silence_duration > self.text_silence_threshold and recording_duration >= self.min_recording_time:
                            # Fix: message previously hard-coded "3秒" even
                            # though the threshold is configurable.
                            print(f"\n\n{self.text_silence_threshold:.0f}秒没有识别到文字,结束录音")
                            self._stop_recording()
                    else:
                        # No text ever recognized: give up after 5 seconds.
                        if recording_duration > 5.0:
                            print(f"\n\n5秒没有识别到文字,结束录音")
                            self._stop_recording()

                    # Hard cap on recording length.
                    if recording_duration > self.max_recording_time:
                        print(f"\n\n达到最大录音时间 {self.max_recording_time}s")
                        self._stop_recording()

                    # Status line while waiting for the first words.
                    if self.last_text_time is None:
                        status = f"等待语音输入... {recording_duration:.1f}s"
                        print(f"\r{status}", end='', flush=True)

                elif self.model and self.recognizer:
                    # --- wake-word detection mode ---
                    if self.recognizer.AcceptWaveform(data):
                        result = json.loads(self.recognizer.Result())
                        text = result.get('text', '').strip()

                        if text:
                            print(f"识别: {text}")

                            is_wake_word, detected_word = self._check_wake_word(text)
                            if is_wake_word:
                                print(f"🎯 检测到唤醒词: {detected_word}")
                                self._start_recording()
                    else:
                        # Show a live energy / partial-text status line.
                        energy = self._calculate_energy(data)
                        if energy > 50:
                            partial_result = json.loads(self.recognizer.PartialResult())
                            partial_text = partial_result.get('partial', '')
                            if partial_text:
                                status = f"监听中... 能量: {energy:.0f} | {partial_text}"
                            else:
                                status = f"监听中... 能量: {energy:.0f}"
                            print(status, end='\r')

                time.sleep(0.01)

        except KeyboardInterrupt:
            print("\n👋 退出")
        except Exception as e:
            print(f"错误: {e}")
        finally:
            self.stop()

    def stop(self):
        """Stop the loop, flush any active recording, and release audio resources."""
        self.running = False
        if self.recording:
            self._stop_recording()

        if self.stream:
            self.stream.stop_stream()
            self.stream.close()
            self.stream = None

        if self.audio:
            self.audio.terminate()
            self.audio = None

        # Give the background recognition thread a chance to finish.
        if self.recognition_thread and self.recognition_thread.is_alive():
            self.recognition_thread.join(timeout=5.0)
|
||||||
|
|
||||||
|
def main():
    """Entry point: set up the wake/record system, print usage, start listening."""
    print("🚀 增强版唤醒+录音+语音识别测试")
    print("=" * 50)

    # The Vosk model must exist before anything else is constructed.
    model_dir = "model"
    if not os.path.exists(model_dir):
        print("⚠️ 未找到模型目录")
        print("请下载 Vosk 模型到 model 目录")
        return

    system = EnhancedWakeAndRecord(
        model_path=model_dir,
        wake_words=["你好", "助手", "小爱"],
        enable_speech_recognition=True,
        # app_key="your_app_key",  # fill in the real app_key
        # access_key="your_access_key"  # fill in the real access_key
    )

    if not system.model:
        print("❌ 模型加载失败")
        return

    def handle_result(result):
        """Report the final cloud-recognition result."""
        print(f"\n🎯 识别完成!结果: {result.text}")
        print(f"  置信度: {result.confidence}")
        print(f"  是否最终结果: {result.is_final}")

    system.set_recognition_callback(handle_result)

    print("✅ 系统初始化成功")
    print("📖 使用说明:")
    print("1. 说唤醒词开始录音")
    print("2. 基于语音识别判断,3秒没有识别到文字就结束")
    print("3. 最少录音2秒,最多30秒")
    print("4. 录音时实时显示识别结果")
    print("5. 录音文件自动保存")
    print("6. 录音完成后自动播放刚才录制的内容")
    print("7. 启动语音识别对录音文件进行识别")
    print("8. 按 Ctrl+C 退出")
    print("=" * 50)

    system.start()


if __name__ == "__main__":
    main()
|
||||||
127
recognition_example.py
Normal file
127
recognition_example.py
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
语音识别使用示例
|
||||||
|
演示如何使用 speech_recognizer 模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from speech_recognizer import SpeechRecognizer
|
||||||
|
|
||||||
|
async def example_recognize_file():
    """Example 1: recognize a single, known audio file."""
    print("=== 示例1:识别单个音频文件 ===")

    # Build the recognizer with placeholder credentials.
    recognizer = SpeechRecognizer(
        app_key="your_app_key",  # replace with the real app_key
        access_key="your_access_key"  # replace with the real access_key
    )

    # A recording produced by the wake-and-record demo.
    audio_file = "recording_20240101_120000.wav"
    if not os.path.exists(audio_file):
        print(f"音频文件不存在: {audio_file}")
        print("请先运行 enhanced_wake_and_record.py 录制一个音频文件")
        return

    try:
        results = await recognizer.recognize_file(audio_file)

        print(f"识别结果(共{len(results)}个):")
        for idx, item in enumerate(results, start=1):
            print(f"结果 {idx}:")
            print(f"  文本: {item.text}")
            print(f"  置信度: {item.confidence}")
            print(f"  最终结果: {item.is_final}")
            print("-" * 40)
    except Exception as e:
        print(f"识别失败: {e}")
|
||||||
|
|
||||||
|
async def example_recognize_latest():
    """Example 2: recognize the most recent recording file."""
    print("\n=== 示例2:识别最新的录音文件 ===")

    # Build the recognizer with placeholder credentials.
    recognizer = SpeechRecognizer(
        app_key="your_app_key",  # replace with the real app_key
        access_key="your_access_key"  # replace with the real access_key
    )

    try:
        result = await recognizer.recognize_latest_recording()

        if not result:
            print("未找到录音文件或识别失败")
        else:
            print("识别结果:")
            print(f"  文本: {result.text}")
            print(f"  置信度: {result.confidence}")
            print(f"  最终结果: {result.is_final}")
    except Exception as e:
        print(f"识别失败: {e}")
|
||||||
|
|
||||||
|
async def example_batch_recognition():
    """Example 3: batch-recognize up to five saved recording files."""
    print("\n=== 示例3:批量识别录音文件 ===")

    # Build the recognizer with placeholder credentials.
    recognizer = SpeechRecognizer(
        app_key="your_app_key",  # replace with the real app_key
        access_key="your_access_key"  # replace with the real access_key
    )

    # Collect recordings produced by the wake-and-record demo.
    recording_files = [f for f in os.listdir(".") if f.startswith('recording_') and f.endswith('.wav')]

    if not recording_files:
        print("未找到录音文件")
        return

    print(f"找到 {len(recording_files)} 个录音文件")

    for filename in recording_files[:5]:  # only process the first 5 files
        # Fix: the message previously did not interpolate the file name.
        print(f"\n处理文件: {filename}")
        try:
            results = await recognizer.recognize_file(filename)

            if results:
                final_result = results[-1]  # take the last (final) result
                print(f"识别结果: {final_result.text}")
            else:
                print("识别失败")

        except Exception as e:
            print(f"处理失败: {e}")

        # Throttle so requests are not sent too frequently.
        await asyncio.sleep(1)
|
||||||
|
|
||||||
|
async def main():
    """Run all recognition examples after a basic credential sanity check."""
    print("🚀 语音识别使用示例")
    print("=" * 50)

    # Fix: the original condition contained the tautology
    # '"your_app_key" in "your_app_key"' (always True) and never checked the
    # access key; require both environment variables, matching the help text.
    if not (os.getenv("SAUC_APP_KEY") and os.getenv("SAUC_ACCESS_KEY")):
        print("⚠️ 请先设置 SAUC_APP_KEY 和 SAUC_ACCESS_KEY 环境变量")
        print("或者在代码中填入实际的 app_key 和 access_key")
        print("示例:")
        print("export SAUC_APP_KEY='your_app_key'")
        print("export SAUC_ACCESS_KEY='your_access_key'")
        return

    # Run the examples sequentially.
    await example_recognize_file()
    await example_recognize_latest()
    await example_batch_recognition()


if __name__ == "__main__":
    asyncio.run(main())
|
||||||
@ -1,3 +1,5 @@
|
|||||||
vosk>=0.3.44
|
vosk>=0.3.44
|
||||||
pyaudio>=0.2.11
|
pyaudio>=0.2.11
|
||||||
numpy>=1.19.0
|
numpy>=1.19.0
|
||||||
|
aiohttp>=3.8.0
|
||||||
|
asyncio
|
||||||
15
sauc_python/readme.md
Normal file
15
sauc_python/readme.md
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# README
|
||||||
|
|
||||||
|
**asr tob 相关client demo**
|
||||||
|
|
||||||
|
# Notice
|
||||||
|
python version: python 3.x
|
||||||
|
|
||||||
|
替换代码中的key为真实数据:
|
||||||
|
"app_key": "xxxxxxx",
|
||||||
|
"access_key": "xxxxxxxxxxxxxxxx"
|
||||||
|
使用示例:
|
||||||
|
python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
523
sauc_python/sauc_websocket_demo.py
Normal file
523
sauc_python/sauc_websocket_demo.py
Normal file
@ -0,0 +1,523 @@
|
|||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
import json
|
||||||
|
import struct
|
||||||
|
import gzip
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from typing import Optional, List, Dict, Any, Tuple, AsyncGenerator
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler('run.log'),
|
||||||
|
logging.StreamHandler()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 常量定义
|
||||||
|
DEFAULT_SAMPLE_RATE = 16000
|
||||||
|
|
||||||
|
class ProtocolVersion:
    """Wire-protocol version nibble (upper 4 bits of header byte 0)."""

    V1 = 0b0001  # protocol version 1
|
||||||
|
|
||||||
|
class MessageType:
    """Message-type nibble (upper 4 bits of header byte 1)."""

    CLIENT_FULL_REQUEST = 0b0001        # client: JSON config request
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010  # client: raw audio segment
    SERVER_FULL_RESPONSE = 0b1001       # server: normal response
    SERVER_ERROR_RESPONSE = 0b1111      # server: error response
|
||||||
|
|
||||||
|
class MessageTypeSpecificFlags:
    """Flag nibble (lower 4 bits of header byte 1)."""

    NO_SEQUENCE = 0b0000        # frame carries no sequence number
    POS_SEQUENCE = 0b0001       # frame carries a positive sequence number
    NEG_SEQUENCE = 0b0010       # last frame, no sequence number
    NEG_WITH_SEQUENCE = 0b0011  # last frame, with (negative) sequence number
|
||||||
|
|
||||||
|
class SerializationType:
    """Payload-serialization nibble (upper 4 bits of header byte 2)."""

    NO_SERIALIZATION = 0b0000  # raw bytes
    JSON = 0b0001              # JSON-encoded payload
|
||||||
|
|
||||||
|
class CompressionType:
    """Payload-compression nibble (lower 4 bits of header byte 2)."""

    GZIP = 0b0001  # payload is gzip-compressed
|
||||||
|
|
||||||
|
|
||||||
|
class Config:
    """Holds the console-issued credentials for the ASR service."""

    def __init__(self):
        # Fill in the app id and access token obtained from the console.
        self.auth = dict(
            app_key="xxxxxxx",
            access_key="xxxxxxxxxxxx",
        )

    @property
    def app_key(self) -> str:
        """The application key used in auth headers."""
        return self.auth["app_key"]

    @property
    def access_key(self) -> str:
        """The access key used in auth headers."""
        return self.auth["access_key"]
|
||||||
|
|
||||||
|
# Module-level singleton consumed by RequestBuilder when building auth headers.
config = Config()
|
||||||
|
|
||||||
|
class CommonUtils:
    """Static helpers: gzip round-trip, WAV sniffing/parsing, ffmpeg conversion."""

    @staticmethod
    def gzip_compress(data: bytes) -> bytes:
        """Gzip-compress *data* and return the compressed bytes."""
        return gzip.compress(data)

    @staticmethod
    def gzip_decompress(data: bytes) -> bytes:
        """Gzip-decompress *data* and return the original bytes."""
        return gzip.decompress(data)

    @staticmethod
    def judge_wav(data: bytes) -> bool:
        """Return True if *data* starts with a RIFF/WAVE header (min 44 bytes)."""
        if len(data) < 44:
            return False
        return data[:4] == b'RIFF' and data[8:12] == b'WAVE'

    @staticmethod
    def convert_wav_with_path(audio_path: str, sample_rate: int = DEFAULT_SAMPLE_RATE) -> bytes:
        """Convert *audio_path* to 16-bit mono WAV bytes at *sample_rate* via ffmpeg.

        NOTE: deletes the source file after a successful conversion (removal
        failures are only logged, not raised).

        Raises:
            RuntimeError: if the ffmpeg subprocess fails.
        """
        try:
            # ffmpeg writes the converted WAV to stdout ("-" output).
            cmd = [
                "ffmpeg", "-v", "quiet", "-y", "-i", audio_path,
                "-acodec", "pcm_s16le", "-ac", "1", "-ar", str(sample_rate),
                "-f", "wav", "-"
            ]
            result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            # Best-effort removal of the original file.
            try:
                os.remove(audio_path)
            except OSError as e:
                logger.warning(f"Failed to remove original file: {e}")

            return result.stdout
        except subprocess.CalledProcessError as e:
            logger.error(f"FFmpeg conversion failed: {e.stderr.decode()}")
            raise RuntimeError(f"Audio conversion failed: {e.stderr.decode()}")

    @staticmethod
    def read_wav_info(data: bytes) -> Tuple[int, int, int, int, bytes]:
        """Parse a WAV blob into its audio parameters and raw PCM data.

        Returns:
            (num_channels, bytes_per_sample, sample_rate, num_frames, pcm_data).

        Raises:
            ValueError: if the RIFF/WAVE header or the 'data' subchunk is missing.

        NOTE(review): the fmt fields are read at the canonical fixed offsets
        (bytes 20-36), so non-canonical WAVs (extra chunks before fmt) would be
        misparsed — confirm inputs always come from ffmpeg/`wave`.
        """
        if len(data) < 44:
            raise ValueError("Invalid WAV file: too short")

        # RIFF container magic.
        chunk_id = data[:4]
        if chunk_id != b'RIFF':
            raise ValueError("Invalid WAV file: not RIFF format")

        format_ = data[8:12]
        if format_ != b'WAVE':
            raise ValueError("Invalid WAV file: not WAVE format")

        # fmt subchunk fields at their canonical offsets (audio_format is
        # read but not otherwise used).
        audio_format = struct.unpack('<H', data[20:22])[0]
        num_channels = struct.unpack('<H', data[22:24])[0]
        sample_rate = struct.unpack('<I', data[24:28])[0]
        bits_per_sample = struct.unpack('<H', data[34:36])[0]

        # Walk the subchunks after fmt until the 'data' chunk is found.
        pos = 36
        while pos < len(data) - 8:
            subchunk_id = data[pos:pos+4]
            subchunk_size = struct.unpack('<I', data[pos+4:pos+8])[0]
            if subchunk_id == b'data':
                wave_data = data[pos+8:pos+8+subchunk_size]
                return (
                    num_channels,
                    bits_per_sample // 8,
                    sample_rate,
                    # frame count = bytes / (channels * bytes-per-sample)
                    subchunk_size // (num_channels * (bits_per_sample // 8)),
                    wave_data
                )
            pos += 8 + subchunk_size

        raise ValueError("Invalid WAV file: no data subchunk found")
|
||||||
|
|
||||||
|
class AsrRequestHeader:
    """Fluent builder for the 4-byte binary header prefixing every ASR request."""

    def __init__(self):
        # Defaults: full client request, positive sequence, JSON payload,
        # gzip compression, one zero reserved byte.
        self.message_type = MessageType.CLIENT_FULL_REQUEST
        self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE
        self.serialization_type = SerializationType.JSON
        self.compression_type = CompressionType.GZIP
        self.reserved_data = bytes([0x00])

    def with_message_type(self, message_type: int) -> 'AsrRequestHeader':
        """Set the message type nibble; returns self for chaining."""
        self.message_type = message_type
        return self

    def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader':
        """Set the flags nibble; returns self for chaining."""
        self.message_type_specific_flags = flags
        return self

    def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader':
        """Set the serialization nibble; returns self for chaining."""
        self.serialization_type = serialization_type
        return self

    def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader':
        """Set the compression nibble; returns self for chaining."""
        self.compression_type = compression_type
        return self

    def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader':
        """Replace the trailing reserved bytes; returns self for chaining."""
        self.reserved_data = reserved_data
        return self

    def to_bytes(self) -> bytes:
        """Serialize the header into its wire form (3 packed bytes + reserved)."""
        byte0 = (ProtocolVersion.V1 << 4) | 1  # version | header size in words
        byte1 = (self.message_type << 4) | self.message_type_specific_flags
        byte2 = (self.serialization_type << 4) | self.compression_type
        return bytes([byte0, byte1, byte2]) + self.reserved_data

    @staticmethod
    def default_header() -> 'AsrRequestHeader':
        """Return a header populated with the default field values."""
        return AsrRequestHeader()
|
||||||
|
|
||||||
|
class RequestBuilder:
    """Assembles the binary client frames (auth headers, config, audio) sent
    over the ASR WebSocket."""

    @staticmethod
    def new_auth_headers() -> Dict[str, str]:
        """Build the HTTP headers used when opening the WebSocket connection.

        A fresh request id (UUID4) is generated per call; credentials come
        from the module-level ``config`` singleton.
        """
        reqid = str(uuid.uuid4())
        return {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Request-Id": reqid,
            "X-Api-Access-Key": config.access_key,
            "X-Api-App-Key": config.app_key
        }

    @staticmethod
    def new_full_client_request(seq: int) -> bytes:  # seq: client-side sequence number
        """Build the initial full client request frame.

        Layout: 4-byte header, big-endian signed *seq*, big-endian unsigned
        payload size, then the gzip-compressed JSON payload describing the
        user, the audio format and the request options.
        """
        header = AsrRequestHeader.default_header() \
            .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)

        payload = {
            "user": {
                "uid": "demo_uid"
            },
            "audio": {
                "format": "wav",
                "codec": "raw",
                "rate": 16000,
                "bits": 16,
                "channel": 1
            },
            "request": {
                "model_name": "bigmodel",
                "enable_itn": True,
                "enable_punc": True,
                "enable_ddc": True,
                "show_utterances": True,
                "enable_nonstream": False
            }
        }

        payload_bytes = json.dumps(payload).encode('utf-8')
        compressed_payload = CommonUtils.gzip_compress(payload_bytes)
        payload_size = len(compressed_payload)

        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))  # caller-supplied sequence number
        request.extend(struct.pack('>I', payload_size))
        request.extend(compressed_payload)

        return bytes(request)

    @staticmethod
    def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
        """Build an audio-only frame carrying one gzip-compressed PCM segment.

        The final segment (``is_last=True``) is flagged NEG_WITH_SEQUENCE and
        its sequence number is negated, per the wire protocol.
        """
        header = AsrRequestHeader.default_header()
        if is_last:  # the last packet gets special treatment
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE)
            seq = -seq  # negative sequence marks the end of the stream
        else:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST)

        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))

        compressed_segment = CommonUtils.gzip_compress(segment)
        request.extend(struct.pack('>I', len(compressed_segment)))
        request.extend(compressed_segment)

        return bytes(request)
|
||||||
|
|
||||||
|
class AsrResponse:
    """Mutable holder for one parsed server message."""

    def __init__(self):
        self.code = 0                  # error code (0 = ok)
        self.event = 0                 # optional event id
        self.is_last_package = False   # True on the final message of a stream
        self.payload_sequence = 0      # server-side sequence number
        self.payload_size = 0          # declared payload length in bytes
        self.payload_msg = None        # decoded JSON payload, if any

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of all response fields."""
        field_names = ("code", "event", "is_last_package",
                       "payload_sequence", "payload_size", "payload_msg")
        return {name: getattr(self, name) for name in field_names}
|
||||||
|
|
||||||
|
class ResponseParser:
    """Decodes the binary server frames of the SAUC wire protocol."""

    @staticmethod
    def parse_response(msg: bytes) -> AsrResponse:
        """Parse one server frame into an AsrResponse.

        Layout: byte 0 = version + header size (in 4-byte words), byte 1 =
        message type + flags, byte 2 = serialization + compression method.
        The payload starts at ``header_size * 4`` and may be prefixed by
        sequence/event integers depending on the flags.
        """
        response = AsrResponse()

        # Header nibbles.
        header_size = msg[0] & 0x0f           # header length in 4-byte words
        message_type = msg[1] >> 4
        message_type_specific_flags = msg[1] & 0x0f
        serialization_method = msg[2] >> 4
        message_compression = msg[2] & 0x0f

        payload = msg[header_size*4:]

        # Flag-dependent integer prefixes before the actual payload.
        if message_type_specific_flags & 0x01:
            # Frame carries a sequence number.
            response.payload_sequence = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]
        if message_type_specific_flags & 0x02:
            # Final frame of the stream.
            response.is_last_package = True
        if message_type_specific_flags & 0x04:
            # Frame carries an event id.
            response.event = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]

        # Type-dependent size / error-code prefix.
        if message_type == MessageType.SERVER_FULL_RESPONSE:
            response.payload_size = struct.unpack('>I', payload[:4])[0]
            payload = payload[4:]
        elif message_type == MessageType.SERVER_ERROR_RESPONSE:
            response.code = struct.unpack('>i', payload[:4])[0]
            response.payload_size = struct.unpack('>I', payload[4:8])[0]
            payload = payload[8:]

        if not payload:
            return response

        # Decompress; on failure return the partially-filled response
        # instead of raising.
        if message_compression == CompressionType.GZIP:
            try:
                payload = CommonUtils.gzip_decompress(payload)
            except Exception as e:
                logger.error(f"Failed to decompress payload: {e}")
                return response

        # Deserialize — only JSON is understood here; other serialization
        # methods leave payload_msg as None.
        try:
            if serialization_method == SerializationType.JSON:
                response.payload_msg = json.loads(payload.decode('utf-8'))
        except Exception as e:
            logger.error(f"Failed to parse payload: {e}")

        return response
|
||||||
|
|
||||||
|
class AsrWsClient:
    """Streams a local audio file to the ASR WebSocket service and yields responses."""

    def __init__(self, url: str, segment_duration: int = 200):
        # url: WebSocket endpoint; segment_duration: milliseconds of audio per packet.
        self.seq = 1
        self.url = url
        self.segment_duration = segment_duration
        self.conn = None
        self.session = None  # aiohttp session owned by this client

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Close the WebSocket before the owning session.
        if self.conn and not self.conn.closed:
            await self.conn.close()
        if self.session and not self.session.closed:
            await self.session.close()

    async def read_audio_data(self, file_path: str) -> bytes:
        """Read the audio file, converting it to WAV if it is not one already."""
        try:
            with open(file_path, 'rb') as f:
                content = f.read()

            if not CommonUtils.judge_wav(content):
                logger.info("Converting audio to WAV format...")
                content = CommonUtils.convert_wav_with_path(file_path, DEFAULT_SAMPLE_RATE)

            return content
        except Exception as e:
            logger.error(f"Failed to read audio data: {e}")
            raise

    def get_segment_size(self, content: bytes) -> int:
        """Compute bytes per packet from the WAV parameters and segment_duration."""
        try:
            channel_num, samp_width, frame_rate, _, _ = CommonUtils.read_wav_info(content)[:5]
            size_per_sec = channel_num * samp_width * frame_rate
            segment_size = size_per_sec * self.segment_duration // 1000
            return segment_size
        except Exception as e:
            logger.error(f"Failed to calculate segment size: {e}")
            raise

    async def create_connection(self) -> None:
        """Open the WebSocket connection using the auth headers."""
        headers = RequestBuilder.new_auth_headers()
        try:
            self.conn = await self.session.ws_connect(  # reuse the client-owned session
                self.url,
                headers=headers
            )
            logger.info(f"Connected to {self.url}")
        except Exception as e:
            logger.error(f"Failed to connect to WebSocket: {e}")
            raise

    async def send_full_client_request(self) -> None:
        """Send the initial full-client request and log the server's reply."""
        request = RequestBuilder.new_full_client_request(self.seq)
        self.seq += 1  # increment after building the request
        try:
            await self.conn.send_bytes(request)
            logger.info(f"Sent full client request with seq: {self.seq-1}")

            msg = await self.conn.receive()
            if msg.type == aiohttp.WSMsgType.BINARY:
                response = ResponseParser.parse_response(msg.data)
                logger.info(f"Received response: {response.to_dict()}")
            else:
                logger.error(f"Unexpected message type: {msg.type}")
        except Exception as e:
            logger.error(f"Failed to send full client request: {e}")
            raise

    async def send_messages(self, segment_size: int, content: bytes) -> AsyncGenerator[None, None]:
        """Send the audio in packets, yielding after each one so the receiver can run."""
        audio_segments = self.split_audio(content, segment_size)
        total_segments = len(audio_segments)

        for i, segment in enumerate(audio_segments):
            is_last = (i == total_segments - 1)
            request = RequestBuilder.new_audio_only_request(
                self.seq,
                segment,
                is_last=is_last
            )
            await self.conn.send_bytes(request)
            logger.info(f"Sent audio segment with seq: {self.seq} (last: {is_last})")

            if not is_last:
                self.seq += 1

            await asyncio.sleep(self.segment_duration / 1000)  # pace packets to simulate a live stream
            # Yield control so responses can be received concurrently.
            yield

    async def recv_messages(self) -> AsyncGenerator[AsrResponse, None]:
        """Yield parsed responses until the last package, a non-zero code, or close."""
        try:
            async for msg in self.conn:
                if msg.type == aiohttp.WSMsgType.BINARY:
                    response = ResponseParser.parse_response(msg.data)
                    yield response

                    if response.is_last_package or response.code != 0:
                        break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(f"WebSocket error: {msg.data}")
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    logger.info("WebSocket connection closed")
                    break
        except Exception as e:
            logger.error(f"Error receiving messages: {e}")
            raise

    async def start_audio_stream(self, segment_size: int, content: bytes) -> AsyncGenerator[AsrResponse, None]:
        """Run the sender as a background task while yielding received responses."""
        async def sender():
            async for _ in self.send_messages(segment_size, content):
                pass

        # Start the send loop concurrently with receiving.
        sender_task = asyncio.create_task(sender())

        try:
            async for response in self.recv_messages():
                yield response
        finally:
            # Receiver finished (or the caller aborted): stop the sender cleanly.
            sender_task.cancel()
            try:
                await sender_task
            except asyncio.CancelledError:
                pass

    @staticmethod
    def split_audio(data: bytes, segment_size: int) -> List[bytes]:
        """Split data into chunks of segment_size bytes (last chunk may be shorter)."""
        if segment_size <= 0:
            return []

        segments = []
        for i in range(0, len(data), segment_size):
            end = i + segment_size
            if end > len(data):
                end = len(data)
            segments.append(data[i:end])
        return segments

    async def execute(self, file_path: str) -> AsyncGenerator[AsrResponse, None]:
        """Full recognition pipeline for one file; yields server responses."""
        if not file_path:
            raise ValueError("File path is empty")

        if not self.url:
            raise ValueError("URL is empty")

        # Reset the sequence counter for a fresh session.
        self.seq = 1

        try:
            # 1. Read the audio file.
            content = await self.read_audio_data(file_path)

            # 2. Compute the per-packet segment size.
            segment_size = self.get_segment_size(content)

            # 3. Open the WebSocket connection.
            await self.create_connection()

            # 4. Send the full client request.
            await self.send_full_client_request()

            # 5. Stream the audio and yield responses.
            async for response in self.start_audio_stream(segment_size, content):
                yield response

        except Exception as e:
            logger.error(f"Error in ASR execution: {e}")
            raise
        finally:
            if self.conn:
                await self.conn.close()
|
||||||
|
|
||||||
|
async def main():
    """CLI entry point: recognize one audio file over the ASR WebSocket API."""
    import argparse

    parser = argparse.ArgumentParser(description="ASR WebSocket Client")
    parser.add_argument("--file", type=str, required=True, help="Audio file path")

    # Available endpoints:
    #wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
    #wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async
    #wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream
    parser.add_argument("--url", type=str, default="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream",
                        help="WebSocket URL")
    parser.add_argument("--seg-duration", type=int, default=200,
                        help="Audio duration(ms) per packet, default:200")

    args = parser.parse_args()

    # The async context manager closes the WebSocket and session on exit.
    async with AsrWsClient(args.url, args.seg_duration) as client:
        try:
            async for response in client.execute(args.file):
                logger.info(f"Received response: {json.dumps(response.to_dict(), indent=2, ensure_ascii=False)}")
        except Exception as e:
            logger.error(f"ASR processing failed: {e}")


if __name__ == "__main__":
    asyncio.run(main())
|
||||||
|
|
||||||
|
# 用法:
|
||||||
|
# python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav
|
||||||
@ -35,11 +35,11 @@ class SimpleWakeAndRecord:
|
|||||||
self.stream = None
|
self.stream = None
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
# 音频参数
|
# 音频参数 - 优化为树莓派3B
|
||||||
self.FORMAT = pyaudio.paInt16
|
self.FORMAT = pyaudio.paInt16
|
||||||
self.CHANNELS = 1
|
self.CHANNELS = 1
|
||||||
self.RATE = 16000
|
self.RATE = 8000 # 从16kHz降至8kHz,减少50%数据处理量
|
||||||
self.CHUNK_SIZE = 1024
|
self.CHUNK_SIZE = 2048 # 增大块大小,减少处理次数
|
||||||
|
|
||||||
# 录音相关
|
# 录音相关
|
||||||
self.recording = False
|
self.recording = False
|
||||||
@ -48,6 +48,19 @@ class SimpleWakeAndRecord:
|
|||||||
self.recording_start_time = None
|
self.recording_start_time = None
|
||||||
self.recording_recognizer = None # 录音时专用的识别器
|
self.recording_recognizer = None # 录音时专用的识别器
|
||||||
|
|
||||||
|
# 性能优化相关
|
||||||
|
self.audio_buffer = [] # 音频缓冲区
|
||||||
|
self.buffer_size = 10 # 缓冲区大小(块数)
|
||||||
|
self.last_process_time = time.time() # 上次处理时间
|
||||||
|
self.process_interval = 0.5 # 处理间隔(秒)
|
||||||
|
self.batch_process_size = 5 # 批处理大小
|
||||||
|
|
||||||
|
# 性能监控
|
||||||
|
self.process_count = 0
|
||||||
|
self.avg_process_time = 0
|
||||||
|
self.last_monitor_time = time.time()
|
||||||
|
self.monitor_interval = 5.0 # 监控间隔(秒)
|
||||||
|
|
||||||
# 阈值
|
# 阈值
|
||||||
self.text_silence_threshold = 3.0 # 3秒没有识别到文字就结束
|
self.text_silence_threshold = 3.0 # 3秒没有识别到文字就结束
|
||||||
self.min_recording_time = 2.0 # 最小录音时间
|
self.min_recording_time = 2.0 # 最小录音时间
|
||||||
@ -116,6 +129,48 @@ class SimpleWakeAndRecord:
|
|||||||
return True, wake_word
|
return True, wake_word
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
|
def _should_process_audio(self):
|
||||||
|
"""判断是否应该处理音频"""
|
||||||
|
current_time = time.time()
|
||||||
|
return (current_time - self.last_process_time >= self.process_interval and
|
||||||
|
len(self.audio_buffer) >= self.batch_process_size)
|
||||||
|
|
||||||
|
def _process_audio_batch(self):
|
||||||
|
"""批量处理音频数据"""
|
||||||
|
if len(self.audio_buffer) < self.batch_process_size:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 记录处理开始时间
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 取出批处理数据
|
||||||
|
batch_data = self.audio_buffer[:self.batch_process_size]
|
||||||
|
self.audio_buffer = self.audio_buffer[self.batch_process_size:]
|
||||||
|
|
||||||
|
# 合并音频数据
|
||||||
|
combined_data = b''.join(batch_data)
|
||||||
|
|
||||||
|
# 更新处理时间
|
||||||
|
self.last_process_time = time.time()
|
||||||
|
|
||||||
|
# 更新性能统计
|
||||||
|
process_time = time.time() - start_time
|
||||||
|
self.process_count += 1
|
||||||
|
self.avg_process_time = (self.avg_process_time * (self.process_count - 1) + process_time) / self.process_count
|
||||||
|
|
||||||
|
# 性能监控
|
||||||
|
self._monitor_performance()
|
||||||
|
|
||||||
|
return combined_data
|
||||||
|
|
||||||
|
def _monitor_performance(self):
|
||||||
|
"""性能监控"""
|
||||||
|
current_time = time.time()
|
||||||
|
if current_time - self.last_monitor_time >= self.monitor_interval:
|
||||||
|
buffer_usage = len(self.audio_buffer) / self.buffer_size * 100
|
||||||
|
print(f"\n📊 性能监控 | 处理次数: {self.process_count} | 平均处理时间: {self.avg_process_time:.3f}s | 缓冲区使用: {buffer_usage:.1f}%")
|
||||||
|
self.last_monitor_time = current_time
|
||||||
|
|
||||||
def _save_recording(self, audio_data):
|
def _save_recording(self, audio_data):
|
||||||
"""保存录音"""
|
"""保存录音"""
|
||||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||||
@ -262,13 +317,21 @@ class SimpleWakeAndRecord:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if self.recording:
|
if self.recording:
|
||||||
# 录音模式
|
# 录音模式 - 直接处理
|
||||||
self.recorded_frames.append(data)
|
self.recorded_frames.append(data)
|
||||||
recording_duration = time.time() - self.recording_start_time
|
recording_duration = time.time() - self.recording_start_time
|
||||||
|
|
||||||
# 使用录音专用的识别器进行实时识别
|
# 录音时使用批处理进行识别
|
||||||
if self.recording_recognizer:
|
self.audio_buffer.append(data)
|
||||||
if self.recording_recognizer.AcceptWaveform(data):
|
|
||||||
|
# 限制缓冲区大小
|
||||||
|
if len(self.audio_buffer) > self.buffer_size:
|
||||||
|
self.audio_buffer.pop(0)
|
||||||
|
|
||||||
|
# 批处理识别
|
||||||
|
if self._should_process_audio() and self.recording_recognizer:
|
||||||
|
combined_data = self._process_audio_batch()
|
||||||
|
if combined_data and self.recording_recognizer.AcceptWaveform(combined_data):
|
||||||
# 获取最终识别结果
|
# 获取最终识别结果
|
||||||
result = json.loads(self.recording_recognizer.Result())
|
result = json.loads(self.recording_recognizer.Result())
|
||||||
text = result.get('text', '').strip()
|
text = result.get('text', '').strip()
|
||||||
@ -277,7 +340,7 @@ class SimpleWakeAndRecord:
|
|||||||
# 识别到文字,更新时间戳
|
# 识别到文字,更新时间戳
|
||||||
self.last_text_time = time.time()
|
self.last_text_time = time.time()
|
||||||
print(f"\n📝 识别: {text}")
|
print(f"\n📝 识别: {text}")
|
||||||
else:
|
elif combined_data:
|
||||||
# 获取部分识别结果
|
# 获取部分识别结果
|
||||||
partial_result = json.loads(self.recording_recognizer.PartialResult())
|
partial_result = json.loads(self.recording_recognizer.PartialResult())
|
||||||
partial_text = partial_result.get('partial', '').strip()
|
partial_text = partial_result.get('partial', '').strip()
|
||||||
@ -314,32 +377,41 @@ class SimpleWakeAndRecord:
|
|||||||
print(f"\r{status}", end='', flush=True)
|
print(f"\r{status}", end='', flush=True)
|
||||||
|
|
||||||
elif self.model and self.recognizer:
|
elif self.model and self.recognizer:
|
||||||
# 唤醒词检测模式
|
# 唤醒词检测模式 - 使用批处理
|
||||||
if self.recognizer.AcceptWaveform(data):
|
self.audio_buffer.append(data)
|
||||||
result = json.loads(self.recognizer.Result())
|
|
||||||
text = result.get('text', '').strip()
|
|
||||||
|
|
||||||
if text:
|
# 限制缓冲区大小
|
||||||
print(f"识别: {text}")
|
if len(self.audio_buffer) > self.buffer_size:
|
||||||
|
self.audio_buffer.pop(0)
|
||||||
|
|
||||||
# 检查唤醒词
|
# 批处理识别
|
||||||
is_wake_word, detected_word = self._check_wake_word(text)
|
if self._should_process_audio():
|
||||||
if is_wake_word:
|
combined_data = self._process_audio_batch()
|
||||||
print(f"🎯 检测到唤醒词: {detected_word}")
|
if combined_data and self.recognizer.AcceptWaveform(combined_data):
|
||||||
self._start_recording()
|
result = json.loads(self.recognizer.Result())
|
||||||
else:
|
text = result.get('text', '').strip()
|
||||||
# 显示实时音频级别
|
|
||||||
energy = self._calculate_energy(data)
|
|
||||||
if energy > 50: # 只显示有意义的音频级别
|
|
||||||
partial_result = json.loads(self.recognizer.PartialResult())
|
|
||||||
partial_text = partial_result.get('partial', '')
|
|
||||||
if partial_text:
|
|
||||||
status = f"监听中... 能量: {energy:.0f} | {partial_text}"
|
|
||||||
else:
|
|
||||||
status = f"监听中... 能量: {energy:.0f}"
|
|
||||||
print(status, end='\r')
|
|
||||||
|
|
||||||
time.sleep(0.01)
|
if text:
|
||||||
|
print(f"识别: {text}")
|
||||||
|
|
||||||
|
# 检查唤醒词
|
||||||
|
is_wake_word, detected_word = self._check_wake_word(text)
|
||||||
|
if is_wake_word:
|
||||||
|
print(f"🎯 检测到唤醒词: {detected_word}")
|
||||||
|
self._start_recording()
|
||||||
|
else:
|
||||||
|
# 显示实时音频级别
|
||||||
|
energy = self._calculate_energy(data)
|
||||||
|
if energy > 50: # 只显示有意义的音频级别
|
||||||
|
partial_result = json.loads(self.recognizer.PartialResult())
|
||||||
|
partial_text = partial_result.get('partial', '')
|
||||||
|
if partial_text:
|
||||||
|
status = f"监听中... 能量: {energy:.0f} | {partial_text}"
|
||||||
|
else:
|
||||||
|
status = f"监听中... 能量: {energy:.0f}"
|
||||||
|
print(status, end='\r')
|
||||||
|
|
||||||
|
time.sleep(0.05) # 增加延迟,减少CPU使用
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\n👋 退出")
|
print("\n👋 退出")
|
||||||
@ -394,6 +466,12 @@ def main():
|
|||||||
print("5. 录音文件自动保存")
|
print("5. 录音文件自动保存")
|
||||||
print("6. 录音完成后自动播放刚才录制的内容")
|
print("6. 录音完成后自动播放刚才录制的内容")
|
||||||
print("7. 按 Ctrl+C 退出")
|
print("7. 按 Ctrl+C 退出")
|
||||||
|
print("🚀 性能优化已启用:")
|
||||||
|
print(" - 采样率: 8kHz (降低50%数据量)")
|
||||||
|
print(" - 批处理: 5个音频块/次")
|
||||||
|
print(" - 处理间隔: 0.5秒")
|
||||||
|
print(" - 缓冲区: 10个音频块")
|
||||||
|
print(" - 性能监控: 每5秒显示")
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
||||||
# 开始运行
|
# 开始运行
|
||||||
|
|||||||
532
speech_recognizer.py
Normal file
532
speech_recognizer.py
Normal file
@ -0,0 +1,532 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
语音识别模块
|
||||||
|
基于 SAUC API 为录音文件提供语音识别功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
import struct
|
||||||
|
import gzip
|
||||||
|
import uuid
|
||||||
|
from typing import Optional, List, Dict, Any, AsyncGenerator
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 常量定义
|
||||||
|
DEFAULT_SAMPLE_RATE = 16000
|
||||||
|
|
||||||
|
class ProtocolVersion:
    # Wire-protocol version nibble (high 4 bits of frame byte 0).
    V1 = 0b0001
|
||||||
|
|
||||||
|
class MessageType:
    # Frame message types (high nibble of frame byte 1).
    CLIENT_FULL_REQUEST = 0b0001        # initial request carrying the full JSON config
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010  # subsequent audio-only packets
    SERVER_FULL_RESPONSE = 0b1001       # normal server response
    SERVER_ERROR_RESPONSE = 0b1111      # server error response (carries an error code)
|
||||||
|
|
||||||
|
class MessageTypeSpecificFlags:
    # Flag nibble (low 4 bits of frame byte 1) describing the sequence field.
    NO_SEQUENCE = 0b0000        # frame carries no sequence number
    POS_SEQUENCE = 0b0001       # positive sequence number follows the header
    NEG_SEQUENCE = 0b0010       # negative sequence (stream-end marker)
    NEG_WITH_SEQUENCE = 0b0011  # last frame, carries a negated sequence number
|
||||||
|
|
||||||
|
class SerializationType:
    # Payload serialization method (high nibble of frame byte 2).
    NO_SERIALIZATION = 0b0000
    JSON = 0b0001
|
||||||
|
|
||||||
|
class CompressionType:
    # Payload compression method (low nibble of frame byte 2).
    GZIP = 0b0001
|
||||||
|
|
||||||
|
@dataclass
class RecognitionResult:
    """A single speech recognition result."""
    text: str                           # recognized text
    confidence: float                   # recognition confidence (this module uses a fixed default)
    is_final: bool                      # True when this is a final (not partial) result
    start_time: Optional[float] = None  # segment start time in seconds, if known
    end_time: Optional[float] = None    # segment end time in seconds, if known
|
||||||
|
|
||||||
|
class AudioUtils:
    """Helpers for GZIP (de)compression and minimal WAV-header parsing."""

    @staticmethod
    def gzip_compress(data: bytes) -> bytes:
        """Compress raw bytes with GZIP."""
        compressed = gzip.compress(data)
        return compressed

    @staticmethod
    def gzip_decompress(data: bytes) -> bytes:
        """Inflate GZIP-compressed bytes."""
        restored = gzip.decompress(data)
        return restored

    @staticmethod
    def is_wav_file(data: bytes) -> bool:
        """Return True when the buffer is large enough and starts with a RIFF/WAVE header."""
        if len(data) < 44:
            return False
        is_riff = data.startswith(b'RIFF')
        is_wave = data[8:12] == b'WAVE'
        return is_riff and is_wave

    @staticmethod
    def read_wav_info(data: bytes) -> tuple:
        """Parse a WAV buffer.

        Returns (channels, bytes_per_sample, sample_rate, frame_count, pcm_bytes).
        Raises ValueError when the buffer is not a well-formed RIFF/WAVE file.
        """
        if len(data) < 44:
            raise ValueError("Invalid WAV file: too short")

        # Validate the RIFF container and WAVE form type.
        if data[:4] != b'RIFF':
            raise ValueError("Invalid WAV file: not RIFF format")
        if data[8:12] != b'WAVE':
            raise ValueError("Invalid WAV file: not WAVE format")

        # fmt sub-chunk fields at their canonical offsets (little-endian).
        # NOTE(review): assumes the canonical 44-byte header layout with the
        # fmt chunk at offset 12 — non-canonical files would mis-parse.
        audio_format = struct.unpack('<H', data[20:22])[0]
        num_channels = struct.unpack('<H', data[22:24])[0]
        sample_rate = struct.unpack('<I', data[24:28])[0]
        bits_per_sample = struct.unpack('<H', data[34:36])[0]

        # Walk the sub-chunks after the fmt block looking for 'data'.
        pos = 36
        while pos < len(data) - 8:
            chunk_tag = data[pos:pos + 4]
            chunk_len = struct.unpack('<I', data[pos + 4:pos + 8])[0]
            if chunk_tag == b'data':
                sample_width = bits_per_sample // 8
                frame_count = chunk_len // (num_channels * sample_width)
                pcm = data[pos + 8:pos + 8 + chunk_len]
                return (num_channels, sample_width, sample_rate, frame_count, pcm)
            pos += 8 + chunk_len

        raise ValueError("Invalid WAV file: no data subchunk found")
|
||||||
|
|
||||||
|
class AsrConfig:
    """ASR credentials, taken from arguments or the SAUC_* environment variables."""

    def __init__(self, app_key: str = None, access_key: str = None):
        # Explicit (truthy) arguments win; fall back to env vars, then placeholders.
        resolved_app = app_key if app_key else os.getenv("SAUC_APP_KEY", "your_app_key")
        resolved_access = access_key if access_key else os.getenv("SAUC_ACCESS_KEY", "your_access_key")
        self.auth = {"app_key": resolved_app, "access_key": resolved_access}

    @property
    def app_key(self) -> str:
        """Application key used in the auth headers."""
        return self.auth["app_key"]

    @property
    def access_key(self) -> str:
        """Access key used in the auth headers."""
        return self.auth["access_key"]
|
||||||
|
|
||||||
|
class AsrRequestHeader:
    """Builder for the 4-byte binary header that prefixes every client frame."""

    def __init__(self):
        # Defaults describe a full client request: JSON payload, GZIP compressed,
        # positive sequence numbering, one reserved zero byte.
        self.message_type = MessageType.CLIENT_FULL_REQUEST
        self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE
        self.serialization_type = SerializationType.JSON
        self.compression_type = CompressionType.GZIP
        self.reserved_data = bytes([0x00])

    # Fluent setters: each mutates one field and returns self for chaining.

    def with_message_type(self, message_type: int) -> 'AsrRequestHeader':
        self.message_type = message_type
        return self

    def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader':
        self.message_type_specific_flags = flags
        return self

    def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader':
        self.serialization_type = serialization_type
        return self

    def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader':
        self.compression_type = compression_type
        return self

    def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader':
        self.reserved_data = reserved_data
        return self

    def to_bytes(self) -> bytes:
        """Pack the header: version|size, type|flags, serialization|compression, reserved."""
        packed = bytes((
            (ProtocolVersion.V1 << 4) | 1,  # header size fixed at 1 (x 4 bytes)
            (self.message_type << 4) | self.message_type_specific_flags,
            (self.serialization_type << 4) | self.compression_type,
        ))
        return packed + self.reserved_data

    @staticmethod
    def default_header() -> 'AsrRequestHeader':
        """Fresh header with the default (full-request) settings."""
        return AsrRequestHeader()
|
||||||
|
|
||||||
|
class RequestBuilder:
    """Builds the HTTP auth headers and binary frames sent to the ASR service."""

    @staticmethod
    def new_auth_headers(config: AsrConfig) -> Dict[str, str]:
        """Create the authentication headers for the WebSocket handshake."""
        reqid = str(uuid.uuid4())
        return {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Request-Id": reqid,
            "X-Api-Access-Key": config.access_key,
            "X-Api-App-Key": config.app_key
        }

    @staticmethod
    def new_full_client_request(seq: int) -> bytes:
        """Create the initial frame: header + seq + payload size + gzip(JSON config)."""
        header = AsrRequestHeader.default_header() \
            .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)

        payload = {
            "user": {
                "uid": "local_voice_user"
            },
            "audio": {
                "format": "wav",
                "codec": "raw",
                "rate": 16000,
                "bits": 16,
                "channel": 1
            },
            "request": {
                "model_name": "bigmodel",
                "enable_itn": True,
                "enable_punc": True,
                "enable_ddc": True,
                "show_utterances": True,
                "enable_nonstream": False
            }
        }

        payload_bytes = json.dumps(payload).encode('utf-8')
        compressed_payload = AudioUtils.gzip_compress(payload_bytes)
        payload_size = len(compressed_payload)

        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))
        # BUG FIX: the original used struct.pack('>U', ...); 'U' is not a valid
        # struct format character and raised struct.error at runtime. The size
        # field is an unsigned 32-bit big-endian int, i.e. '>I' (matching the
        # parser in this repo, which unpacks the size with '>I').
        request.extend(struct.pack('>I', payload_size))
        request.extend(compressed_payload)

        return bytes(request)

    @staticmethod
    def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
        """Create an audio-only frame; the last frame carries a negated sequence number."""
        header = AsrRequestHeader.default_header()
        if is_last:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE)
            seq = -seq
        else:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST)

        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))

        compressed_segment = AudioUtils.gzip_compress(segment)
        # BUG FIX: '>U' -> '>I' (see new_full_client_request above).
        request.extend(struct.pack('>I', len(compressed_segment)))
        request.extend(compressed_segment)

        return bytes(request)
|
||||||
|
|
||||||
|
class AsrResponse:
    """Mutable container for one parsed server response frame."""

    def __init__(self):
        self.code = 0                 # server error code (0 = OK)
        self.event = 0                # optional event code from the frame flags
        self.is_last_package = False  # True on the final frame of a stream
        self.payload_sequence = 0     # sequence number echoed by the server
        self.payload_size = 0         # declared payload size in bytes
        self.payload_msg = None       # decoded JSON payload, if any

    def to_dict(self) -> Dict[str, Any]:
        """Snapshot of all fields as a plain dict (for logging/serialization)."""
        field_names = ("code", "event", "is_last_package",
                       "payload_sequence", "payload_size", "payload_msg")
        return {name: getattr(self, name) for name in field_names}
|
||||||
|
|
||||||
|
class ResponseParser:
    """Parses binary ASR server frames into AsrResponse objects."""

    @staticmethod
    def parse_response(msg: bytes) -> AsrResponse:
        """Decode one binary frame from the server.

        Header layout: byte0 = version|header_size(words), byte1 = type|flags,
        byte2 = serialization|compression; optional seq/event/size fields then
        the (possibly gzipped) payload.
        """
        response = AsrResponse()

        # Header nibbles.
        header_size = msg[0] & 0x0f
        message_type = msg[1] >> 4
        message_type_specific_flags = msg[1] & 0x0f
        serialization_method = msg[2] >> 4
        message_compression = msg[2] & 0x0f

        # Header length is expressed in 4-byte words.
        payload = msg[header_size*4:]

        # Flag bits: 0x01 = sequence number present, 0x02 = last package,
        # 0x04 = event code present.
        if message_type_specific_flags & 0x01:
            response.payload_sequence = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]
        if message_type_specific_flags & 0x02:
            response.is_last_package = True
        if message_type_specific_flags & 0x04:
            response.event = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]

        # Parse message_type.
        if message_type == MessageType.SERVER_FULL_RESPONSE:
            # BUG FIX: the original used struct.unpack('>U', ...); 'U' is not a
            # valid struct format character and raised struct.error on every
            # frame. The size field is an unsigned 32-bit big-endian int
            # ('>I'), matching the sibling parser in this repo.
            response.payload_size = struct.unpack('>I', payload[:4])[0]
            payload = payload[4:]
        elif message_type == MessageType.SERVER_ERROR_RESPONSE:
            response.code = struct.unpack('>i', payload[:4])[0]
            # BUG FIX: '>U' -> '>I' (see above).
            response.payload_size = struct.unpack('>I', payload[4:8])[0]
            payload = payload[8:]

        if not payload:
            return response

        # Decompress; on failure return the partially-parsed response.
        if message_compression == CompressionType.GZIP:
            try:
                payload = AudioUtils.gzip_decompress(payload)
            except Exception as e:
                logger.error(f"Failed to decompress payload: {e}")
                return response

        # Deserialize the payload (JSON only).
        try:
            if serialization_method == SerializationType.JSON:
                response.payload_msg = json.loads(payload.decode('utf-8'))
        except Exception as e:
            logger.error(f"Failed to parse payload: {e}")

        return response
|
||||||
|
|
||||||
|
class SpeechRecognizer:
|
||||||
|
"""语音识别器"""
|
||||||
|
|
||||||
|
def __init__(self, app_key: str = None, access_key: str = None,
|
||||||
|
url: str = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream"):
|
||||||
|
self.config = AsrConfig(app_key, access_key)
|
||||||
|
self.url = url
|
||||||
|
self.seq = 1
|
||||||
|
|
||||||
|
async def recognize_file(self, file_path: str) -> List[RecognitionResult]:
|
||||||
|
"""识别音频文件"""
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise FileNotFoundError(f"Audio file not found: {file_path}")
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# 读取音频文件
|
||||||
|
with open(file_path, 'rb') as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
if not AudioUtils.is_wav_file(content):
|
||||||
|
raise ValueError("Audio file must be in WAV format")
|
||||||
|
|
||||||
|
# 获取音频信息
|
||||||
|
try:
|
||||||
|
_, _, sample_rate, _, audio_data = AudioUtils.read_wav_info(content)
|
||||||
|
if sample_rate != DEFAULT_SAMPLE_RATE:
|
||||||
|
logger.warning(f"Sample rate {sample_rate} != {DEFAULT_SAMPLE_RATE}, may affect recognition accuracy")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to read audio info: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# 计算分段大小 (200ms per segment)
|
||||||
|
segment_size = 1 * 2 * DEFAULT_SAMPLE_RATE * 200 // 1000 # channel * bytes_per_sample * sample_rate * duration_ms / 1000
|
||||||
|
|
||||||
|
# 创建WebSocket连接
|
||||||
|
headers = RequestBuilder.new_auth_headers(self.config)
|
||||||
|
async with session.ws_connect(self.url, headers=headers) as ws:
|
||||||
|
|
||||||
|
# 发送完整客户端请求
|
||||||
|
request = RequestBuilder.new_full_client_request(self.seq)
|
||||||
|
self.seq += 1
|
||||||
|
await ws.send_bytes(request)
|
||||||
|
|
||||||
|
# 接收初始响应
|
||||||
|
msg = await ws.receive()
|
||||||
|
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||||
|
response = ResponseParser.parse_response(msg.data)
|
||||||
|
logger.info(f"Initial response: {response.to_dict()}")
|
||||||
|
|
||||||
|
# 分段发送音频数据
|
||||||
|
audio_segments = self._split_audio(audio_data, segment_size)
|
||||||
|
total_segments = len(audio_segments)
|
||||||
|
|
||||||
|
for i, segment in enumerate(audio_segments):
|
||||||
|
is_last = (i == total_segments - 1)
|
||||||
|
request = RequestBuilder.new_audio_only_request(
|
||||||
|
self.seq,
|
||||||
|
segment,
|
||||||
|
is_last=is_last
|
||||||
|
)
|
||||||
|
await ws.send_bytes(request)
|
||||||
|
logger.info(f"Sent audio segment {i+1}/{total_segments}")
|
||||||
|
|
||||||
|
if not is_last:
|
||||||
|
self.seq += 1
|
||||||
|
|
||||||
|
# 短暂延迟模拟实时流
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
# 接收识别结果
|
||||||
|
final_text = ""
|
||||||
|
while True:
|
||||||
|
msg = await ws.receive()
|
||||||
|
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||||
|
response = ResponseParser.parse_response(msg.data)
|
||||||
|
|
||||||
|
if response.payload_msg and 'text' in response.payload_msg:
|
||||||
|
text = response.payload_msg['text']
|
||||||
|
if text:
|
||||||
|
final_text += text
|
||||||
|
|
||||||
|
result = RecognitionResult(
|
||||||
|
text=text,
|
||||||
|
confidence=0.9, # 默认置信度
|
||||||
|
is_final=response.is_last_package
|
||||||
|
)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
logger.info(f"Recognized: {text}")
|
||||||
|
|
||||||
|
if response.is_last_package or response.code != 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||||
|
logger.error(f"WebSocket error: {msg.data}")
|
||||||
|
break
|
||||||
|
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||||
|
logger.info("WebSocket connection closed")
|
||||||
|
break
|
||||||
|
|
||||||
|
# 如果没有获得最终结果,创建一个包含所有文本的结果
|
||||||
|
if final_text and not any(r.is_final for r in results):
|
||||||
|
final_result = RecognitionResult(
|
||||||
|
text=final_text,
|
||||||
|
confidence=0.9,
|
||||||
|
is_final=True
|
||||||
|
)
|
||||||
|
results.append(final_result)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Speech recognition failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _split_audio(self, data: bytes, segment_size: int) -> List[bytes]:
    """Split raw audio bytes into fixed-size segments.

    Args:
        data: Raw audio byte stream to split.
        segment_size: Maximum size in bytes of each segment.

    Returns:
        Ordered list of segments; the final segment may be shorter than
        ``segment_size``. Returns an empty list when ``segment_size`` is
        non-positive or ``data`` is empty.
    """
    if segment_size <= 0:
        return []
    # Python slicing clamps past-the-end indices, so no manual end
    # clamping is needed for the final (possibly short) segment.
    return [data[i:i + segment_size] for i in range(0, len(data), segment_size)]
|
||||||
|
|
||||||
|
async def recognize_latest_recording(self, directory: str = ".") -> Optional[RecognitionResult]:
    """Recognize the most recent recording file in *directory*.

    Looks for files named ``recording_*.wav``; the filenames embed a
    timestamp, so the lexicographically largest name is the newest.
    Returns the last final recognition result, falling back to the last
    partial result, or ``None`` when no recording exists or recognition
    fails.
    """
    candidates = [
        name for name in os.listdir(directory)
        if name.startswith('recording_') and name.endswith('.wav')
    ]

    if not candidates:
        logger.warning("No recording files found")
        return None

    # Timestamped filenames: lexicographic maximum == most recent file.
    latest_file = max(candidates)
    latest_path = os.path.join(directory, latest_file)

    logger.info(f"Recognizing latest recording: {latest_file}")

    try:
        results = await self.recognize_file(latest_path)
        if results:
            final_results = [r for r in results if r.is_final]
            # Prefer the last result marked final; otherwise fall back
            # to the last partial result.
            return final_results[-1] if final_results else results[-1]
    except Exception as e:
        logger.error(f"Failed to recognize latest recording: {e}")

    return None
|
||||||
|
|
||||||
|
async def main():
    """CLI entry point: recognize a given audio file or the latest recording."""
    import argparse

    parser = argparse.ArgumentParser(description="语音识别测试")
    parser.add_argument("--file", type=str, help="音频文件路径")
    parser.add_argument("--latest", action="store_true", help="识别最新的录音文件")
    parser.add_argument("--app-key", type=str, help="SAUC App Key")
    parser.add_argument("--access-key", type=str, help="SAUC Access Key")
    opts = parser.parse_args()

    recognizer = SpeechRecognizer(app_key=opts.app_key, access_key=opts.access_key)

    try:
        if opts.latest:
            # --latest: recognize the newest recording_*.wav file.
            outcome = await recognizer.recognize_latest_recording()
            if outcome:
                print(f"识别结果: {outcome.text}")
                print(f"置信度: {outcome.confidence}")
                print(f"最终结果: {outcome.is_final}")
            else:
                print("未能识别到语音内容")
        elif opts.file:
            # --file: recognize an explicit audio file and dump all results.
            for idx, item in enumerate(await recognizer.recognize_file(opts.file)):
                print(f"结果 {idx+1}: {item.text}")
                print(f"置信度: {item.confidence}")
                print(f"最终结果: {item.is_final}")
                print("-" * 40)
        else:
            print("请指定 --file 或 --latest 参数")

    except Exception as e:
        print(f"识别失败: {e}")
|
||||||
|
|
||||||
|
# Run the async CLI entry point when executed as a script.
if __name__ == "__main__":
    asyncio.run(main())
|
||||||
Loading…
Reference in New Issue
Block a user