From eb099d827d7f4f679bafbd8068a4c46a663457de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?= Date: Sat, 20 Sep 2025 10:53:56 +0800 Subject: [PATCH] config --- enhanced_wake_and_record.py | 501 +++++++++++++++++++++++++++ recognition_example.py | 127 +++++++ requirements.txt | 4 +- sauc_python/readme.md | 15 + sauc_python/sauc_websocket_demo.py | 523 ++++++++++++++++++++++++++++ simple_wake_and_record.py | 142 ++++++-- speech_recognizer.py | 532 +++++++++++++++++++++++++++++ 7 files changed, 1811 insertions(+), 33 deletions(-) create mode 100644 enhanced_wake_and_record.py create mode 100644 recognition_example.py create mode 100644 sauc_python/readme.md create mode 100644 sauc_python/sauc_websocket_demo.py create mode 100644 speech_recognizer.py diff --git a/enhanced_wake_and_record.py b/enhanced_wake_and_record.py new file mode 100644 index 0000000..e55fc08 --- /dev/null +++ b/enhanced_wake_and_record.py @@ -0,0 +1,501 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +集成语音识别的唤醒+录音系统 +基于 simple_wake_and_record.py,添加语音识别功能 +""" + +import sys +import os +import time +import threading +import pyaudio +import json +import asyncio +from typing import Optional, List + +# 添加当前目录到路径 +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +try: + from vosk import Model, KaldiRecognizer + VOSK_AVAILABLE = True +except ImportError: + VOSK_AVAILABLE = False + print("⚠️ Vosk 未安装,请运行: pip install vosk") + +from speech_recognizer import SpeechRecognizer, RecognitionResult + +class EnhancedWakeAndRecord: + """增强的唤醒+录音系统,集成语音识别""" + + def __init__(self, model_path="model", wake_words=["你好", "助手"], + enable_speech_recognition=True, app_key=None, access_key=None): + self.model_path = model_path + self.wake_words = wake_words + self.enable_speech_recognition = enable_speech_recognition + self.model = None + self.recognizer = None + self.audio = None + self.stream = None + self.running = False + + # 音频参数 + self.FORMAT = pyaudio.paInt16 + self.CHANNELS = 1 + self.RATE = 16000 + self.CHUNK_SIZE = 1024 + + # 录音相关 + self.recording = False + self.recorded_frames = [] + self.last_text_time = None + self.recording_start_time = None + self.recording_recognizer = None + + # 阈值 + self.text_silence_threshold = 3.0 + self.min_recording_time = 2.0 + self.max_recording_time = 30.0 + + # 语音识别相关 + self.speech_recognizer = None + self.last_recognition_result = None + self.recognition_thread = None + + # 回调函数 + self.on_recognition_result = None + + self._setup_model() + self._setup_audio() + self._setup_speech_recognition(app_key, access_key) + + def _setup_model(self): + """设置 Vosk 模型""" + if not VOSK_AVAILABLE: + return + + try: + if not os.path.exists(self.model_path): + print(f"模型路径不存在: {self.model_path}") + return + + self.model = Model(self.model_path) + self.recognizer = KaldiRecognizer(self.model, self.RATE) + self.recognizer.SetWords(True) + + print(f"✅ Vosk 模型加载成功") + + except Exception as e: + print(f"模型初始化失败: {e}") + + def _setup_audio(self): + """设置音频设备""" + try: + if self.audio is None: + self.audio = pyaudio.PyAudio() + + if self.stream is None: + self.stream = self.audio.open( + format=self.FORMAT, + channels=self.CHANNELS, + rate=self.RATE, + input=True, + frames_per_buffer=self.CHUNK_SIZE + ) + + print("✅ 音频设备初始化成功") + + except Exception as e: + print(f"音频设备初始化失败: {e}") + + def _setup_speech_recognition(self, app_key=None, access_key=None): + """设置语音识别""" + if not self.enable_speech_recognition: + return + + try: + self.speech_recognizer = SpeechRecognizer( + app_key=app_key, + 
access_key=access_key + ) + print("✅ 语音识别器初始化成功") + except Exception as e: + print(f"语音识别器初始化失败: {e}") + self.enable_speech_recognition = False + + def _calculate_energy(self, audio_data): + """计算音频能量""" + if len(audio_data) == 0: + return 0 + + import numpy as np + audio_array = np.frombuffer(audio_data, dtype=np.int16) + rms = np.sqrt(np.mean(audio_array ** 2)) + return rms + + def _check_wake_word(self, text): + """检查是否包含唤醒词""" + if not text or not self.wake_words: + return False, None + + text_lower = text.lower() + for wake_word in self.wake_words: + if wake_word.lower() in text_lower: + return True, wake_word + return False, None + + def _save_recording(self, audio_data): + """保存录音""" + timestamp = time.strftime("%Y%m%d_%H%M%S") + filename = f"recording_{timestamp}.wav" + + try: + import wave + with wave.open(filename, 'wb') as wf: + wf.setnchannels(self.CHANNELS) + wf.setsampwidth(self.audio.get_sample_size(self.FORMAT)) + wf.setframerate(self.RATE) + wf.writeframes(audio_data) + + print(f"✅ 录音已保存: {filename}") + return True, filename + except Exception as e: + print(f"保存录音失败: {e}") + return False, None + + def _play_audio(self, filename): + """播放音频文件""" + try: + import wave + + # 打开音频文件 + with wave.open(filename, 'rb') as wf: + # 获取音频参数 + channels = wf.getnchannels() + width = wf.getsampwidth() + rate = wf.getframerate() + total_frames = wf.getnframes() + + # 分块读取音频数据,避免内存问题 + chunk_size = 1024 + frames = [] + + for _ in range(0, total_frames, chunk_size): + chunk = wf.readframes(chunk_size) + if chunk: + frames.append(chunk) + else: + break + + # 创建播放流 + playback_stream = self.audio.open( + format=self.audio.get_format_from_width(width), + channels=channels, + rate=rate, + output=True + ) + + print(f"🔊 开始播放: {filename}") + + # 分块播放音频 + for chunk in frames: + playback_stream.write(chunk) + + # 等待播放完成 + playback_stream.stop_stream() + playback_stream.close() + + print("✅ 播放完成") + + except Exception as e: + print(f"❌ 播放失败: {e}") + self._play_with_system_player(filename) + + def _play_with_system_player(self, filename): + """使用系统播放器播放音频""" + try: + import platform + import subprocess + + system = platform.system() + + if system == 'Darwin': # macOS + cmd = ['afplay', filename] + elif system == 'Windows': + cmd = ['start', '/min', filename] + else: # Linux + cmd = ['aplay', filename] + + print(f"🔊 使用系统播放器: {' '.join(cmd)}") + subprocess.run(cmd, check=True) + print("✅ 播放完成") + + except Exception as e: + print(f"❌ 系统播放器也失败: {e}") + print(f"💡 文件已保存,请手动播放: {filename}") + + def _start_recognition_thread(self, filename): + """启动语音识别线程""" + if not self.enable_speech_recognition or not self.speech_recognizer: + return + + def recognize_task(): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + print(f"🧠 开始识别录音文件: {filename}") + result = loop.run_until_complete( + self.speech_recognizer.recognize_file(filename) + ) + + if result: + # 合并所有识别结果 + full_text = " ".join([r.text for r in result]) + final_result = RecognitionResult( + text=full_text, + confidence=0.9, + is_final=True + ) + + self.last_recognition_result = final_result + print(f"\n🧠 语音识别结果: {full_text}") + + # 调用回调函数 + if self.on_recognition_result: + self.on_recognition_result(final_result) + else: + print(f"\n🧠 语音识别失败或未识别到内容") + + loop.close() + + except Exception as e: + print(f"❌ 语音识别线程异常: {e}") + + self.recognition_thread = threading.Thread(target=recognize_task) + self.recognition_thread.daemon = True + self.recognition_thread.start() + + def _start_recording(self): + """开始录音""" + print("🎙️ 开始录音,请说话...") 
+ self.recording = True + self.recorded_frames = [] + self.last_text_time = None + self.recording_start_time = time.time() + + # 为录音创建一个新的识别器 + if self.model: + self.recording_recognizer = KaldiRecognizer(self.model, self.RATE) + self.recording_recognizer.SetWords(True) + + def _stop_recording(self): + """停止录音""" + if len(self.recorded_frames) > 0: + audio_data = b''.join(self.recorded_frames) + duration = len(audio_data) / (self.RATE * 2) + print(f"📝 录音完成,时长: {duration:.2f}秒") + + # 保存录音 + success, filename = self._save_recording(audio_data) + + # 如果保存成功,播放录音并进行语音识别 + if success and filename: + print("=" * 50) + print("🔊 播放刚才录制的音频...") + self._play_audio(filename) + print("=" * 50) + + # 启动语音识别 + if self.enable_speech_recognition: + print("🧠 准备进行语音识别...") + self._start_recognition_thread(filename) + + self.recording = False + self.recorded_frames = [] + self.last_text_time = None + self.recording_start_time = None + self.recording_recognizer = None + + def set_recognition_callback(self, callback): + """设置识别结果回调函数""" + self.on_recognition_result = callback + + def get_last_recognition_result(self) -> Optional[RecognitionResult]: + """获取最后一次识别结果""" + return self.last_recognition_result + + def start(self): + """开始唤醒词检测和录音""" + if not self.stream: + print("❌ 音频设备未初始化") + return + + self.running = True + print("🎤 开始监听...") + print(f"唤醒词: {', '.join(self.wake_words)}") + if self.enable_speech_recognition: + print("🧠 语音识别: 已启用") + else: + print("🧠 语音识别: 已禁用") + + try: + while self.running: + # 读取音频数据 + data = self.stream.read(self.CHUNK_SIZE, exception_on_overflow=False) + + if len(data) == 0: + continue + + if self.recording: + # 录音模式 + self.recorded_frames.append(data) + recording_duration = time.time() - self.recording_start_time + + # 使用录音专用的识别器进行实时识别 + if self.recording_recognizer: + if self.recording_recognizer.AcceptWaveform(data): + result = json.loads(self.recording_recognizer.Result()) + text = result.get('text', '').strip() + + if text: + self.last_text_time = time.time() + print(f"\n📝 实时识别: {text}") + else: + partial_result = json.loads(self.recording_recognizer.PartialResult()) + partial_text = partial_result.get('partial', '').strip() + + if partial_text: + self.last_text_time = time.time() + status = f"录音中... {recording_duration:.1f}s | {partial_text}" + print(f"\r{status}", end='', flush=True) + + # 检查是否需要结束录音 + current_time = time.time() + + if self.last_text_time is not None: + text_silence_duration = current_time - self.last_text_time + if text_silence_duration > self.text_silence_threshold and recording_duration >= self.min_recording_time: + print(f"\n\n3秒没有识别到文字,结束录音") + self._stop_recording() + else: + if recording_duration > 5.0: + print(f"\n\n5秒没有识别到文字,结束录音") + self._stop_recording() + + # 检查最大录音时间 + if recording_duration > self.max_recording_time: + print(f"\n\n达到最大录音时间 {self.max_recording_time}s") + self._stop_recording() + + # 显示录音状态 + if self.last_text_time is None: + status = f"等待语音输入... 
{recording_duration:.1f}s" + print(f"\r{status}", end='', flush=True) + + elif self.model and self.recognizer: + # 唤醒词检测模式 + if self.recognizer.AcceptWaveform(data): + result = json.loads(self.recognizer.Result()) + text = result.get('text', '').strip() + + if text: + print(f"识别: {text}") + + # 检查唤醒词 + is_wake_word, detected_word = self._check_wake_word(text) + if is_wake_word: + print(f"🎯 检测到唤醒词: {detected_word}") + self._start_recording() + else: + # 显示实时音频级别 + energy = self._calculate_energy(data) + if energy > 50: + partial_result = json.loads(self.recognizer.PartialResult()) + partial_text = partial_result.get('partial', '') + if partial_text: + status = f"监听中... 能量: {energy:.0f} | {partial_text}" + else: + status = f"监听中... 能量: {energy:.0f}" + print(status, end='\r') + + time.sleep(0.01) + + except KeyboardInterrupt: + print("\n👋 退出") + except Exception as e: + print(f"错误: {e}") + finally: + self.stop() + + def stop(self): + """停止""" + self.running = False + if self.recording: + self._stop_recording() + + if self.stream: + self.stream.stop_stream() + self.stream.close() + self.stream = None + + if self.audio: + self.audio.terminate() + self.audio = None + + # 等待识别线程结束 + if self.recognition_thread and self.recognition_thread.is_alive(): + self.recognition_thread.join(timeout=5.0) + +def main(): + """主函数""" + print("🚀 增强版唤醒+录音+语音识别测试") + print("=" * 50) + + # 检查模型 + model_dir = "model" + if not os.path.exists(model_dir): + print("⚠️ 未找到模型目录") + print("请下载 Vosk 模型到 model 目录") + return + + # 创建系统 + system = EnhancedWakeAndRecord( + model_path=model_dir, + wake_words=["你好", "助手", "小爱"], + enable_speech_recognition=True, + # app_key="your_app_key", # 请填入实际的app_key + # access_key="your_access_key" # 请填入实际的access_key + ) + + if not system.model: + print("❌ 模型加载失败") + return + + # 设置识别结果回调 + def on_recognition_result(result): + print(f"\n🎯 识别完成!结果: {result.text}") + print(f" 置信度: {result.confidence}") + print(f" 是否最终结果: {result.is_final}") + + system.set_recognition_callback(on_recognition_result) + + print("✅ 系统初始化成功") + print("📖 使用说明:") + print("1. 说唤醒词开始录音") + print("2. 基于语音识别判断,3秒没有识别到文字就结束") + print("3. 最少录音2秒,最多30秒") + print("4. 录音时实时显示识别结果") + print("5. 录音文件自动保存") + print("6. 录音完成后自动播放刚才录制的内容") + print("7. 启动语音识别对录音文件进行识别") + print("8. 
按 Ctrl+C 退出") + print("=" * 50) + + # 开始运行 + system.start() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/recognition_example.py b/recognition_example.py new file mode 100644 index 0000000..c224b58 --- /dev/null +++ b/recognition_example.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +语音识别使用示例 +演示如何使用 speech_recognizer 模块 +""" + +import os +import asyncio +from speech_recognizer import SpeechRecognizer + +async def example_recognize_file(): + """示例:识别单个音频文件""" + print("=== 示例1:识别单个音频文件 ===") + + # 初始化识别器 + recognizer = SpeechRecognizer( + app_key="your_app_key", # 请替换为实际的app_key + access_key="your_access_key" # 请替换为实际的access_key + ) + + # 假设有一个录音文件 + audio_file = "recording_20240101_120000.wav" + + if not os.path.exists(audio_file): + print(f"音频文件不存在: {audio_file}") + print("请先运行 enhanced_wake_and_record.py 录制一个音频文件") + return + + try: + # 识别音频文件 + results = await recognizer.recognize_file(audio_file) + + print(f"识别结果(共{len(results)}个):") + for i, result in enumerate(results): + print(f"结果 {i+1}:") + print(f" 文本: {result.text}") + print(f" 置信度: {result.confidence}") + print(f" 最终结果: {result.is_final}") + print("-" * 40) + + except Exception as e: + print(f"识别失败: {e}") + +async def example_recognize_latest(): + """示例:识别最新的录音文件""" + print("\n=== 示例2:识别最新的录音文件 ===") + + # 初始化识别器 + recognizer = SpeechRecognizer( + app_key="your_app_key", # 请替换为实际的app_key + access_key="your_access_key" # 请替换为实际的access_key + ) + + try: + # 识别最新的录音文件 + result = await recognizer.recognize_latest_recording() + + if result: + print("识别结果:") + print(f" 文本: {result.text}") + print(f" 置信度: {result.confidence}") + print(f" 最终结果: {result.is_final}") + else: + print("未找到录音文件或识别失败") + + except Exception as e: + print(f"识别失败: {e}") + +async def example_batch_recognition(): + """示例:批量识别多个录音文件""" + print("\n=== 示例3:批量识别录音文件 ===") + + # 初始化识别器 + recognizer = SpeechRecognizer( + app_key="your_app_key", # 请替换为实际的app_key + access_key="your_access_key" # 请替换为实际的access_key + ) + + # 获取所有录音文件 + recording_files = [f for f in os.listdir(".") if f.startswith('recording_') and f.endswith('.wav')] + + if not recording_files: + print("未找到录音文件") + return + + print(f"找到 {len(recording_files)} 个录音文件") + + for filename in recording_files[:5]: # 只处理前5个文件 + print(f"\n处理文件: {filename}") + try: + results = await recognizer.recognize_file(filename) + + if results: + final_result = results[-1] # 取最后一个结果 + print(f"识别结果: {final_result.text}") + else: + print("识别失败") + + except Exception as e: + print(f"处理失败: {e}") + + # 添加延迟,避免请求过于频繁 + await asyncio.sleep(1) + +async def main(): + """主函数""" + print("🚀 语音识别使用示例") + print("=" * 50) + + # 请先设置环境变量或在代码中填入实际的API密钥 + if not os.getenv("SAUC_APP_KEY") and "your_app_key" in "your_app_key": + print("⚠️ 请先设置 SAUC_APP_KEY 和 SAUC_ACCESS_KEY 环境变量") + print("或者在代码中填入实际的 app_key 和 access_key") + print("示例:") + print("export SAUC_APP_KEY='your_app_key'") + print("export SAUC_ACCESS_KEY='your_access_key'") + return + + # 运行示例 + await example_recognize_file() + await example_recognize_latest() + await example_batch_recognition() + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 9b6cb52..dbf157b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ vosk>=0.3.44 pyaudio>=0.2.11 -numpy>=1.19.0 \ No newline at end of file +numpy>=1.19.0 +aiohttp>=3.8.0 +asyncio \ No newline at end of file diff --git a/sauc_python/readme.md b/sauc_python/readme.md new file mode 
100644 index 0000000..4dbcebd --- /dev/null +++ b/sauc_python/readme.md @@ -0,0 +1,15 @@ +# README + +**asr tob 相关client demo** + +# Notice +python version: python 3.x + +替换代码中的key为真实数据: + "app_key": "xxxxxxx", + "access_key": "xxxxxxxxxxxxxxxx" +使用示例: + python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav + + + diff --git a/sauc_python/sauc_websocket_demo.py b/sauc_python/sauc_websocket_demo.py new file mode 100644 index 0000000..092d24b --- /dev/null +++ b/sauc_python/sauc_websocket_demo.py @@ -0,0 +1,523 @@ +import asyncio +import aiohttp +import json +import struct +import gzip +import uuid +import logging +import os +import subprocess +from typing import Optional, List, Dict, Any, Tuple, AsyncGenerator + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler('run.log'), + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +# 常量定义 +DEFAULT_SAMPLE_RATE = 16000 + +class ProtocolVersion: + V1 = 0b0001 + +class MessageType: + CLIENT_FULL_REQUEST = 0b0001 + CLIENT_AUDIO_ONLY_REQUEST = 0b0010 + SERVER_FULL_RESPONSE = 0b1001 + SERVER_ERROR_RESPONSE = 0b1111 + +class MessageTypeSpecificFlags: + NO_SEQUENCE = 0b0000 + POS_SEQUENCE = 0b0001 + NEG_SEQUENCE = 0b0010 + NEG_WITH_SEQUENCE = 0b0011 + +class SerializationType: + NO_SERIALIZATION = 0b0000 + JSON = 0b0001 + +class CompressionType: + GZIP = 0b0001 + + +class Config: + def __init__(self): + # 填入控制台获取的app id和access token + self.auth = { + "app_key": "xxxxxxx", + "access_key": "xxxxxxxxxxxx" + } + + @property + def app_key(self) -> str: + return self.auth["app_key"] + + @property + def access_key(self) -> str: + return self.auth["access_key"] + +config = Config() + +class CommonUtils: + @staticmethod + def gzip_compress(data: bytes) -> bytes: + return gzip.compress(data) + + @staticmethod + def gzip_decompress(data: bytes) -> bytes: + return gzip.decompress(data) + + @staticmethod + def judge_wav(data: bytes) -> bool: + if len(data) < 44: + return False + return data[:4] == b'RIFF' and data[8:12] == b'WAVE' + + @staticmethod + def convert_wav_with_path(audio_path: str, sample_rate: int = DEFAULT_SAMPLE_RATE) -> bytes: + try: + cmd = [ + "ffmpeg", "-v", "quiet", "-y", "-i", audio_path, + "-acodec", "pcm_s16le", "-ac", "1", "-ar", str(sample_rate), + "-f", "wav", "-" + ] + result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + # 尝试删除原始文件 + try: + os.remove(audio_path) + except OSError as e: + logger.warning(f"Failed to remove original file: {e}") + + return result.stdout + except subprocess.CalledProcessError as e: + logger.error(f"FFmpeg conversion failed: {e.stderr.decode()}") + raise RuntimeError(f"Audio conversion failed: {e.stderr.decode()}") + + @staticmethod + def read_wav_info(data: bytes) -> Tuple[int, int, int, int, bytes]: + if len(data) < 44: + raise ValueError("Invalid WAV file: too short") + + # 解析WAV头 + chunk_id = data[:4] + if chunk_id != b'RIFF': + raise ValueError("Invalid WAV file: not RIFF format") + + format_ = data[8:12] + if format_ != b'WAVE': + raise ValueError("Invalid WAV file: not WAVE format") + + # 解析fmt子块 + audio_format = struct.unpack(' 'AsrRequestHeader': + self.message_type = message_type + return self + + def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader': + self.message_type_specific_flags = flags + return self + + def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader': + 
self.serialization_type = serialization_type + return self + + def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader': + self.compression_type = compression_type + return self + + def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader': + self.reserved_data = reserved_data + return self + + def to_bytes(self) -> bytes: + header = bytearray() + header.append((ProtocolVersion.V1 << 4) | 1) + header.append((self.message_type << 4) | self.message_type_specific_flags) + header.append((self.serialization_type << 4) | self.compression_type) + header.extend(self.reserved_data) + return bytes(header) + + @staticmethod + def default_header() -> 'AsrRequestHeader': + return AsrRequestHeader() + +class RequestBuilder: + @staticmethod + def new_auth_headers() -> Dict[str, str]: + reqid = str(uuid.uuid4()) + return { + "X-Api-Resource-Id": "volc.bigasr.sauc.duration", + "X-Api-Request-Id": reqid, + "X-Api-Access-Key": config.access_key, + "X-Api-App-Key": config.app_key + } + + @staticmethod + def new_full_client_request(seq: int) -> bytes: # 添加seq参数 + header = AsrRequestHeader.default_header() \ + .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE) + + payload = { + "user": { + "uid": "demo_uid" + }, + "audio": { + "format": "wav", + "codec": "raw", + "rate": 16000, + "bits": 16, + "channel": 1 + }, + "request": { + "model_name": "bigmodel", + "enable_itn": True, + "enable_punc": True, + "enable_ddc": True, + "show_utterances": True, + "enable_nonstream": False + } + } + + payload_bytes = json.dumps(payload).encode('utf-8') + compressed_payload = CommonUtils.gzip_compress(payload_bytes) + payload_size = len(compressed_payload) + + request = bytearray() + request.extend(header.to_bytes()) + request.extend(struct.pack('>i', seq)) # 使用传入的seq + request.extend(struct.pack('>I', payload_size)) + request.extend(compressed_payload) + + return bytes(request) + + @staticmethod + def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes: + header = AsrRequestHeader.default_header() + if is_last: # 最后一个包特殊处理 + header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE) + seq = -seq # 设为负值 + else: + header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE) + header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST) + + request = bytearray() + request.extend(header.to_bytes()) + request.extend(struct.pack('>i', seq)) + + compressed_segment = CommonUtils.gzip_compress(segment) + request.extend(struct.pack('>I', len(compressed_segment))) + request.extend(compressed_segment) + + return bytes(request) + +class AsrResponse: + def __init__(self): + self.code = 0 + self.event = 0 + self.is_last_package = False + self.payload_sequence = 0 + self.payload_size = 0 + self.payload_msg = None + + def to_dict(self) -> Dict[str, Any]: + return { + "code": self.code, + "event": self.event, + "is_last_package": self.is_last_package, + "payload_sequence": self.payload_sequence, + "payload_size": self.payload_size, + "payload_msg": self.payload_msg + } + +class ResponseParser: + @staticmethod + def parse_response(msg: bytes) -> AsrResponse: + response = AsrResponse() + + header_size = msg[0] & 0x0f + message_type = msg[1] >> 4 + message_type_specific_flags = msg[1] & 0x0f + serialization_method = msg[2] >> 4 + message_compression = msg[2] & 0x0f + + payload = msg[header_size*4:] + + # 解析message_type_specific_flags + if message_type_specific_flags & 0x01: + response.payload_sequence = 
struct.unpack('>i', payload[:4])[0] + payload = payload[4:] + if message_type_specific_flags & 0x02: + response.is_last_package = True + if message_type_specific_flags & 0x04: + response.event = struct.unpack('>i', payload[:4])[0] + payload = payload[4:] + + # 解析message_type + if message_type == MessageType.SERVER_FULL_RESPONSE: + response.payload_size = struct.unpack('>I', payload[:4])[0] + payload = payload[4:] + elif message_type == MessageType.SERVER_ERROR_RESPONSE: + response.code = struct.unpack('>i', payload[:4])[0] + response.payload_size = struct.unpack('>I', payload[4:8])[0] + payload = payload[8:] + + if not payload: + return response + + # 解压缩 + if message_compression == CompressionType.GZIP: + try: + payload = CommonUtils.gzip_decompress(payload) + except Exception as e: + logger.error(f"Failed to decompress payload: {e}") + return response + + # 解析payload + try: + if serialization_method == SerializationType.JSON: + response.payload_msg = json.loads(payload.decode('utf-8')) + except Exception as e: + logger.error(f"Failed to parse payload: {e}") + + return response + +class AsrWsClient: + def __init__(self, url: str, segment_duration: int = 200): + self.seq = 1 + self.url = url + self.segment_duration = segment_duration + self.conn = None + self.session = None # 添加session引用 + + async def __aenter__(self): + self.session = aiohttp.ClientSession() + return self + + async def __aexit__(self, exc_type, exc, tb): + if self.conn and not self.conn.closed: + await self.conn.close() + if self.session and not self.session.closed: + await self.session.close() + + async def read_audio_data(self, file_path: str) -> bytes: + try: + with open(file_path, 'rb') as f: + content = f.read() + + if not CommonUtils.judge_wav(content): + logger.info("Converting audio to WAV format...") + content = CommonUtils.convert_wav_with_path(file_path, DEFAULT_SAMPLE_RATE) + + return content + except Exception as e: + logger.error(f"Failed to read audio data: {e}") + raise + + def get_segment_size(self, content: bytes) -> int: + try: + channel_num, samp_width, frame_rate, _, _ = CommonUtils.read_wav_info(content)[:5] + size_per_sec = channel_num * samp_width * frame_rate + segment_size = size_per_sec * self.segment_duration // 1000 + return segment_size + except Exception as e: + logger.error(f"Failed to calculate segment size: {e}") + raise + + async def create_connection(self) -> None: + headers = RequestBuilder.new_auth_headers() + try: + self.conn = await self.session.ws_connect( # 使用self.session + self.url, + headers=headers + ) + logger.info(f"Connected to {self.url}") + except Exception as e: + logger.error(f"Failed to connect to WebSocket: {e}") + raise + + async def send_full_client_request(self) -> None: + request = RequestBuilder.new_full_client_request(self.seq) + self.seq += 1 # 发送后递增 + try: + await self.conn.send_bytes(request) + logger.info(f"Sent full client request with seq: {self.seq-1}") + + msg = await self.conn.receive() + if msg.type == aiohttp.WSMsgType.BINARY: + response = ResponseParser.parse_response(msg.data) + logger.info(f"Received response: {response.to_dict()}") + else: + logger.error(f"Unexpected message type: {msg.type}") + except Exception as e: + logger.error(f"Failed to send full client request: {e}") + raise + + async def send_messages(self, segment_size: int, content: bytes) -> AsyncGenerator[None, None]: + audio_segments = self.split_audio(content, segment_size) + total_segments = len(audio_segments) + + for i, segment in enumerate(audio_segments): + is_last = (i == 
total_segments - 1) + request = RequestBuilder.new_audio_only_request( + self.seq, + segment, + is_last=is_last + ) + await self.conn.send_bytes(request) + logger.info(f"Sent audio segment with seq: {self.seq} (last: {is_last})") + + if not is_last: + self.seq += 1 + + await asyncio.sleep(self.segment_duration / 1000) # 逐个发送,间隔时间模拟实时流 + # 让出控制权,允许接受消息 + yield + + async def recv_messages(self) -> AsyncGenerator[AsrResponse, None]: + try: + async for msg in self.conn: + if msg.type == aiohttp.WSMsgType.BINARY: + response = ResponseParser.parse_response(msg.data) + yield response + + if response.is_last_package or response.code != 0: + break + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error(f"WebSocket error: {msg.data}") + break + elif msg.type == aiohttp.WSMsgType.CLOSED: + logger.info("WebSocket connection closed") + break + except Exception as e: + logger.error(f"Error receiving messages: {e}") + raise + + async def start_audio_stream(self, segment_size: int, content: bytes) -> AsyncGenerator[AsrResponse, None]: + async def sender(): + async for _ in self.send_messages(segment_size, content): + pass + + # 启动发送和接收任务 + sender_task = asyncio.create_task(sender()) + + try: + async for response in self.recv_messages(): + yield response + finally: + sender_task.cancel() + try: + await sender_task + except asyncio.CancelledError: + pass + + @staticmethod + def split_audio(data: bytes, segment_size: int) -> List[bytes]: + if segment_size <= 0: + return [] + + segments = [] + for i in range(0, len(data), segment_size): + end = i + segment_size + if end > len(data): + end = len(data) + segments.append(data[i:end]) + return segments + + async def execute(self, file_path: str) -> AsyncGenerator[AsrResponse, None]: + if not file_path: + raise ValueError("File path is empty") + + if not self.url: + raise ValueError("URL is empty") + + self.seq = 1 + + try: + # 1. 读取音频文件 + content = await self.read_audio_data(file_path) + + # 2. 计算分段大小 + segment_size = self.get_segment_size(content) + + # 3. 创建WebSocket连接 + await self.create_connection() + + # 4. 发送完整客户端请求 + await self.send_full_client_request() + + # 5. 
启动音频流处理 + async for response in self.start_audio_stream(segment_size, content): + yield response + + except Exception as e: + logger.error(f"Error in ASR execution: {e}") + raise + finally: + if self.conn: + await self.conn.close() + +async def main(): + import argparse + + parser = argparse.ArgumentParser(description="ASR WebSocket Client") + parser.add_argument("--file", type=str, required=True, help="Audio file path") + + #wss://openspeech.bytedance.com/api/v3/sauc/bigmodel + #wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async + #wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream + parser.add_argument("--url", type=str, default="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream", + help="WebSocket URL") + parser.add_argument("--seg-duration", type=int, default=200, + help="Audio duration(ms) per packet, default:200") + + args = parser.parse_args() + + async with AsrWsClient(args.url, args.seg_duration) as client: # 使用async with + try: + async for response in client.execute(args.file): + logger.info(f"Received response: {json.dumps(response.to_dict(), indent=2, ensure_ascii=False)}") + except Exception as e: + logger.error(f"ASR processing failed: {e}") + +if __name__ == "__main__": + asyncio.run(main()) + + # 用法: + # python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav \ No newline at end of file diff --git a/simple_wake_and_record.py b/simple_wake_and_record.py index 6430f8b..23d9c3e 100644 --- a/simple_wake_and_record.py +++ b/simple_wake_and_record.py @@ -35,11 +35,11 @@ class SimpleWakeAndRecord: self.stream = None self.running = False - # 音频参数 + # 音频参数 - 优化为树莓派3B self.FORMAT = pyaudio.paInt16 self.CHANNELS = 1 - self.RATE = 16000 - self.CHUNK_SIZE = 1024 + self.RATE = 8000 # 从16kHz降至8kHz,减少50%数据处理量 + self.CHUNK_SIZE = 2048 # 增大块大小,减少处理次数 # 录音相关 self.recording = False @@ -48,6 +48,19 @@ class SimpleWakeAndRecord: self.recording_start_time = None self.recording_recognizer = None # 录音时专用的识别器 + # 性能优化相关 + self.audio_buffer = [] # 音频缓冲区 + self.buffer_size = 10 # 缓冲区大小(块数) + self.last_process_time = time.time() # 上次处理时间 + self.process_interval = 0.5 # 处理间隔(秒) + self.batch_process_size = 5 # 批处理大小 + + # 性能监控 + self.process_count = 0 + self.avg_process_time = 0 + self.last_monitor_time = time.time() + self.monitor_interval = 5.0 # 监控间隔(秒) + # 阈值 self.text_silence_threshold = 3.0 # 3秒没有识别到文字就结束 self.min_recording_time = 2.0 # 最小录音时间 @@ -116,6 +129,48 @@ class SimpleWakeAndRecord: return True, wake_word return False, None + def _should_process_audio(self): + """判断是否应该处理音频""" + current_time = time.time() + return (current_time - self.last_process_time >= self.process_interval and + len(self.audio_buffer) >= self.batch_process_size) + + def _process_audio_batch(self): + """批量处理音频数据""" + if len(self.audio_buffer) < self.batch_process_size: + return + + # 记录处理开始时间 + start_time = time.time() + + # 取出批处理数据 + batch_data = self.audio_buffer[:self.batch_process_size] + self.audio_buffer = self.audio_buffer[self.batch_process_size:] + + # 合并音频数据 + combined_data = b''.join(batch_data) + + # 更新处理时间 + self.last_process_time = time.time() + + # 更新性能统计 + process_time = time.time() - start_time + self.process_count += 1 + self.avg_process_time = (self.avg_process_time * (self.process_count - 1) + process_time) / self.process_count + + # 性能监控 + self._monitor_performance() + + return combined_data + + def _monitor_performance(self): + """性能监控""" + current_time = time.time() + if current_time - self.last_monitor_time >= self.monitor_interval: + 
buffer_usage = len(self.audio_buffer) / self.buffer_size * 100 + print(f"\n📊 性能监控 | 处理次数: {self.process_count} | 平均处理时间: {self.avg_process_time:.3f}s | 缓冲区使用: {buffer_usage:.1f}%") + self.last_monitor_time = current_time + def _save_recording(self, audio_data): """保存录音""" timestamp = time.strftime("%Y%m%d_%H%M%S") @@ -262,13 +317,21 @@ class SimpleWakeAndRecord: continue if self.recording: - # 录音模式 + # 录音模式 - 直接处理 self.recorded_frames.append(data) recording_duration = time.time() - self.recording_start_time - # 使用录音专用的识别器进行实时识别 - if self.recording_recognizer: - if self.recording_recognizer.AcceptWaveform(data): + # 录音时使用批处理进行识别 + self.audio_buffer.append(data) + + # 限制缓冲区大小 + if len(self.audio_buffer) > self.buffer_size: + self.audio_buffer.pop(0) + + # 批处理识别 + if self._should_process_audio() and self.recording_recognizer: + combined_data = self._process_audio_batch() + if combined_data and self.recording_recognizer.AcceptWaveform(combined_data): # 获取最终识别结果 result = json.loads(self.recording_recognizer.Result()) text = result.get('text', '').strip() @@ -277,7 +340,7 @@ class SimpleWakeAndRecord: # 识别到文字,更新时间戳 self.last_text_time = time.time() print(f"\n📝 识别: {text}") - else: + elif combined_data: # 获取部分识别结果 partial_result = json.loads(self.recording_recognizer.PartialResult()) partial_text = partial_result.get('partial', '').strip() @@ -314,32 +377,41 @@ class SimpleWakeAndRecord: print(f"\r{status}", end='', flush=True) elif self.model and self.recognizer: - # 唤醒词检测模式 - if self.recognizer.AcceptWaveform(data): - result = json.loads(self.recognizer.Result()) - text = result.get('text', '').strip() - - if text: - print(f"识别: {text}") + # 唤醒词检测模式 - 使用批处理 + self.audio_buffer.append(data) + + # 限制缓冲区大小 + if len(self.audio_buffer) > self.buffer_size: + self.audio_buffer.pop(0) + + # 批处理识别 + if self._should_process_audio(): + combined_data = self._process_audio_batch() + if combined_data and self.recognizer.AcceptWaveform(combined_data): + result = json.loads(self.recognizer.Result()) + text = result.get('text', '').strip() - # 检查唤醒词 - is_wake_word, detected_word = self._check_wake_word(text) - if is_wake_word: - print(f"🎯 检测到唤醒词: {detected_word}") - self._start_recording() - else: - # 显示实时音频级别 - energy = self._calculate_energy(data) - if energy > 50: # 只显示有意义的音频级别 - partial_result = json.loads(self.recognizer.PartialResult()) - partial_text = partial_result.get('partial', '') - if partial_text: - status = f"监听中... 能量: {energy:.0f} | {partial_text}" - else: - status = f"监听中... 能量: {energy:.0f}" - print(status, end='\r') + if text: + print(f"识别: {text}") + + # 检查唤醒词 + is_wake_word, detected_word = self._check_wake_word(text) + if is_wake_word: + print(f"🎯 检测到唤醒词: {detected_word}") + self._start_recording() + else: + # 显示实时音频级别 + energy = self._calculate_energy(data) + if energy > 50: # 只显示有意义的音频级别 + partial_result = json.loads(self.recognizer.PartialResult()) + partial_text = partial_result.get('partial', '') + if partial_text: + status = f"监听中... 能量: {energy:.0f} | {partial_text}" + else: + status = f"监听中... 能量: {energy:.0f}" + print(status, end='\r') - time.sleep(0.01) + time.sleep(0.05) # 增加延迟,减少CPU使用 except KeyboardInterrupt: print("\n👋 退出") @@ -394,6 +466,12 @@ def main(): print("5. 录音文件自动保存") print("6. 录音完成后自动播放刚才录制的内容") print("7. 
按 Ctrl+C 退出") + print("🚀 性能优化已启用:") + print(" - 采样率: 8kHz (降低50%数据量)") + print(" - 批处理: 5个音频块/次") + print(" - 处理间隔: 0.5秒") + print(" - 缓冲区: 10个音频块") + print(" - 性能监控: 每5秒显示") print("=" * 50) # 开始运行 diff --git a/speech_recognizer.py b/speech_recognizer.py new file mode 100644 index 0000000..ba232d4 --- /dev/null +++ b/speech_recognizer.py @@ -0,0 +1,532 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +语音识别模块 +基于 SAUC API 为录音文件提供语音识别功能 +""" + +import os +import json +import time +import logging +import asyncio +import aiohttp +import struct +import gzip +import uuid +from typing import Optional, List, Dict, Any, AsyncGenerator +from dataclasses import dataclass + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# 常量定义 +DEFAULT_SAMPLE_RATE = 16000 + +class ProtocolVersion: + V1 = 0b0001 + +class MessageType: + CLIENT_FULL_REQUEST = 0b0001 + CLIENT_AUDIO_ONLY_REQUEST = 0b0010 + SERVER_FULL_RESPONSE = 0b1001 + SERVER_ERROR_RESPONSE = 0b1111 + +class MessageTypeSpecificFlags: + NO_SEQUENCE = 0b0000 + POS_SEQUENCE = 0b0001 + NEG_SEQUENCE = 0b0010 + NEG_WITH_SEQUENCE = 0b0011 + +class SerializationType: + NO_SERIALIZATION = 0b0000 + JSON = 0b0001 + +class CompressionType: + GZIP = 0b0001 + +@dataclass +class RecognitionResult: + """语音识别结果""" + text: str + confidence: float + is_final: bool + start_time: Optional[float] = None + end_time: Optional[float] = None + +class AudioUtils: + """音频处理工具类""" + + @staticmethod + def gzip_compress(data: bytes) -> bytes: + """GZIP压缩""" + return gzip.compress(data) + + @staticmethod + def gzip_decompress(data: bytes) -> bytes: + """GZIP解压缩""" + return gzip.decompress(data) + + @staticmethod + def is_wav_file(data: bytes) -> bool: + """检查是否为WAV文件""" + if len(data) < 44: + return False + return data[:4] == b'RIFF' and data[8:12] == b'WAVE' + + @staticmethod + def read_wav_info(data: bytes) -> tuple: + """读取WAV文件信息""" + if len(data) < 44: + raise ValueError("Invalid WAV file: too short") + + # 解析WAV头 + chunk_id = data[:4] + if chunk_id != b'RIFF': + raise ValueError("Invalid WAV file: not RIFF format") + + format_ = data[8:12] + if format_ != b'WAVE': + raise ValueError("Invalid WAV file: not WAVE format") + + # 解析fmt子块 + audio_format = struct.unpack(' str: + return self.auth["app_key"] + + @property + def access_key(self) -> str: + return self.auth["access_key"] + +class AsrRequestHeader: + """ASR请求头""" + + def __init__(self): + self.message_type = MessageType.CLIENT_FULL_REQUEST + self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE + self.serialization_type = SerializationType.JSON + self.compression_type = CompressionType.GZIP + self.reserved_data = bytes([0x00]) + + def with_message_type(self, message_type: int) -> 'AsrRequestHeader': + self.message_type = message_type + return self + + def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader': + self.message_type_specific_flags = flags + return self + + def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader': + self.serialization_type = serialization_type + return self + + def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader': + self.compression_type = compression_type + return self + + def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader': + self.reserved_data = reserved_data + return self + + def to_bytes(self) -> bytes: + header = bytearray() + header.append((ProtocolVersion.V1 << 4) | 1) + 
header.append((self.message_type << 4) | self.message_type_specific_flags)
+        header.append((self.serialization_type << 4) | self.compression_type)
+        header.extend(self.reserved_data)
+        return bytes(header)
+
+    @staticmethod
+    def default_header() -> 'AsrRequestHeader':
+        return AsrRequestHeader()
+
+class RequestBuilder:
+    """Request builder"""
+
+    @staticmethod
+    def new_auth_headers(config: AsrConfig) -> Dict[str, str]:
+        """Build the authentication headers"""
+        reqid = str(uuid.uuid4())
+        return {
+            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
+            "X-Api-Request-Id": reqid,
+            "X-Api-Access-Key": config.access_key,
+            "X-Api-App-Key": config.app_key
+        }
+
+    @staticmethod
+    def new_full_client_request(seq: int) -> bytes:
+        """Build the full client request"""
+        header = AsrRequestHeader.default_header() \
+            .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
+
+        payload = {
+            "user": {
+                "uid": "local_voice_user"
+            },
+            "audio": {
+                "format": "wav",
+                "codec": "raw",
+                "rate": 16000,
+                "bits": 16,
+                "channel": 1
+            },
+            "request": {
+                "model_name": "bigmodel",
+                "enable_itn": True,
+                "enable_punc": True,
+                "enable_ddc": True,
+                "show_utterances": True,
+                "enable_nonstream": False
+            }
+        }
+
+        payload_bytes = json.dumps(payload).encode('utf-8')
+        compressed_payload = AudioUtils.gzip_compress(payload_bytes)
+        payload_size = len(compressed_payload)
+
+        request = bytearray()
+        request.extend(header.to_bytes())
+        request.extend(struct.pack('>i', seq))           # signed 32-bit big-endian sequence
+        request.extend(struct.pack('>I', payload_size))  # unsigned 32-bit big-endian size
+        request.extend(compressed_payload)
+
+        return bytes(request)
+
+    @staticmethod
+    def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
+        """Build an audio-only request"""
+        header = AsrRequestHeader.default_header()
+        if is_last:
+            header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE)
+            seq = -seq
+        else:
+            header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
+        header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST)
+
+        request = bytearray()
+        request.extend(header.to_bytes())
+        request.extend(struct.pack('>i', seq))
+
+        compressed_segment = AudioUtils.gzip_compress(segment)
+        request.extend(struct.pack('>I', len(compressed_segment)))
+        request.extend(compressed_segment)
+
+        return bytes(request)
+
+class AsrResponse:
+    """ASR response"""
+
+    def __init__(self):
+        self.code = 0
+        self.event = 0
+        self.is_last_package = False
+        self.payload_sequence = 0
+        self.payload_size = 0
+        self.payload_msg = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "code": self.code,
+            "event": self.event,
+            "is_last_package": self.is_last_package,
+            "payload_sequence": self.payload_sequence,
+            "payload_size": self.payload_size,
+            "payload_msg": self.payload_msg
+        }
+
+class ResponseParser:
+    """Response parser"""
+
+    @staticmethod
+    def parse_response(msg: bytes) -> AsrResponse:
+        """Parse a server response"""
+        response = AsrResponse()
+
+        header_size = msg[0] & 0x0f
+        message_type = msg[1] >> 4
+        message_type_specific_flags = msg[1] & 0x0f
+        serialization_method = msg[2] >> 4
+        message_compression = msg[2] & 0x0f
+
+        payload = msg[header_size*4:]
+
+        # Parse message_type_specific_flags
+        if message_type_specific_flags & 0x01:
+            response.payload_sequence = struct.unpack('>i', payload[:4])[0]
+            payload = payload[4:]
+        if message_type_specific_flags & 0x02:
+            response.is_last_package = True
+        if message_type_specific_flags & 0x04:
+            response.event = struct.unpack('>i', payload[:4])[0]
+            payload = payload[4:]
+
+        # Parse message_type
+        if message_type == MessageType.SERVER_FULL_RESPONSE:
+            response.payload_size = struct.unpack('>I', payload[:4])[0]
+            payload = payload[4:]
+        elif message_type == MessageType.SERVER_ERROR_RESPONSE:
+            response.code = struct.unpack('>i', payload[:4])[0]
+            response.payload_size = struct.unpack('>I', payload[4:8])[0]
+            payload = payload[8:]
+
+        if not payload:
+            return response
+
+        # Decompress
+        if message_compression == CompressionType.GZIP:
+            try:
+                payload = AudioUtils.gzip_decompress(payload)
+            except Exception as e:
+                logger.error(f"Failed to decompress payload: {e}")
+                return response
+
+        # Parse payload
+        try:
+            if serialization_method == SerializationType.JSON:
+                response.payload_msg = json.loads(payload.decode('utf-8'))
+        except Exception as e:
+            logger.error(f"Failed to parse payload: {e}")
+
+        return response
+
+class SpeechRecognizer:
+    """Speech recognizer"""
+
+    def __init__(self, app_key: str = None, access_key: str = None,
+                 url: str = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream"):
+        self.config = AsrConfig(app_key, access_key)
+        self.url = url
+        self.seq = 1
+
+    async def recognize_file(self, file_path: str) -> List[RecognitionResult]:
+        """Recognize an audio file"""
+        if not os.path.exists(file_path):
+            raise FileNotFoundError(f"Audio file not found: {file_path}")
+
+        results = []
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                # Read the audio file
+                with open(file_path, 'rb') as f:
+                    content = f.read()
+
+                if not AudioUtils.is_wav_file(content):
+                    raise ValueError("Audio file must be in WAV format")
+
+                # Get the audio info
+                try:
+                    _, _, sample_rate, _, audio_data = AudioUtils.read_wav_info(content)
+                    if sample_rate != DEFAULT_SAMPLE_RATE:
+                        logger.warning(f"Sample rate {sample_rate} != {DEFAULT_SAMPLE_RATE}, may affect recognition accuracy")
+                except Exception as e:
+                    logger.error(f"Failed to read audio info: {e}")
+                    raise
+
+                # Segment size (200ms per segment): channels * bytes_per_sample * sample_rate * duration_ms / 1000
+                segment_size = 1 * 2 * DEFAULT_SAMPLE_RATE * 200 // 1000
+
+                # Open the WebSocket connection
+                headers = RequestBuilder.new_auth_headers(self.config)
+                async with session.ws_connect(self.url, headers=headers) as ws:
+
+                    # Send the full client request
+                    request = RequestBuilder.new_full_client_request(self.seq)
+                    self.seq += 1
+                    await ws.send_bytes(request)
+
+                    # Receive the initial response
+                    msg = await ws.receive()
+                    if msg.type == aiohttp.WSMsgType.BINARY:
+                        response = ResponseParser.parse_response(msg.data)
+                        logger.info(f"Initial response: {response.to_dict()}")
+
+                    # Send the audio data in segments
+                    audio_segments = self._split_audio(audio_data, segment_size)
+                    total_segments = len(audio_segments)
+
+                    for i, segment in enumerate(audio_segments):
+                        is_last = (i == total_segments - 1)
+                        request = RequestBuilder.new_audio_only_request(
+                            self.seq,
+                            segment,
+                            is_last=is_last
+                        )
+                        await ws.send_bytes(request)
+                        logger.info(f"Sent audio segment {i+1}/{total_segments}")
+
+                        if not is_last:
+                            self.seq += 1
+
+                        # Short delay to simulate a real-time stream
+                        await asyncio.sleep(0.1)
+
+                    # Receive recognition results
+                    final_text = ""
+                    while True:
+                        msg = await ws.receive()
+                        if msg.type == aiohttp.WSMsgType.BINARY:
+                            response = ResponseParser.parse_response(msg.data)
+
+                            if response.payload_msg and 'text' in response.payload_msg:
+                                text = response.payload_msg['text']
+                                if text:
+                                    final_text += text
+
+                                    result = RecognitionResult(
+                                        text=text,
+                                        confidence=0.9,  # default confidence
+                                        is_final=response.is_last_package
+                                    )
+                                    results.append(result)
+
+                                    logger.info(f"Recognized: {text}")
+
+                            if response.is_last_package or response.code != 0:
+                                break
+
+                        elif msg.type == aiohttp.WSMsgType.ERROR:
+                            logger.error(f"WebSocket error: {msg.data}")
+                            break
+                        elif msg.type == aiohttp.WSMsgType.CLOSED:
logger.info("WebSocket connection closed") + break + + # 如果没有获得最终结果,创建一个包含所有文本的结果 + if final_text and not any(r.is_final for r in results): + final_result = RecognitionResult( + text=final_text, + confidence=0.9, + is_final=True + ) + results.append(final_result) + + return results + + except Exception as e: + logger.error(f"Speech recognition failed: {e}") + raise + + def _split_audio(self, data: bytes, segment_size: int) -> List[bytes]: + """分割音频数据""" + if segment_size <= 0: + return [] + + segments = [] + for i in range(0, len(data), segment_size): + end = i + segment_size + if end > len(data): + end = len(data) + segments.append(data[i:end]) + return segments + + async def recognize_latest_recording(self, directory: str = ".") -> Optional[RecognitionResult]: + """识别最新的录音文件""" + # 查找最新的录音文件 + recording_files = [f for f in os.listdir(directory) if f.startswith('recording_') and f.endswith('.wav')] + + if not recording_files: + logger.warning("No recording files found") + return None + + # 按文件名排序(包含时间戳) + recording_files.sort(reverse=True) + latest_file = recording_files[0] + latest_path = os.path.join(directory, latest_file) + + logger.info(f"Recognizing latest recording: {latest_file}") + + try: + results = await self.recognize_file(latest_path) + if results: + # 返回最终的识别结果 + final_results = [r for r in results if r.is_final] + if final_results: + return final_results[-1] + else: + # 如果没有标记为final的结果,返回最后一个 + return results[-1] + except Exception as e: + logger.error(f"Failed to recognize latest recording: {e}") + + return None + +async def main(): + """测试函数""" + import argparse + + parser = argparse.ArgumentParser(description="语音识别测试") + parser.add_argument("--file", type=str, help="音频文件路径") + parser.add_argument("--latest", action="store_true", help="识别最新的录音文件") + parser.add_argument("--app-key", type=str, help="SAUC App Key") + parser.add_argument("--access-key", type=str, help="SAUC Access Key") + + args = parser.parse_args() + + recognizer = SpeechRecognizer( + app_key=args.app_key, + access_key=args.access_key + ) + + try: + if args.latest: + result = await recognizer.recognize_latest_recording() + if result: + print(f"识别结果: {result.text}") + print(f"置信度: {result.confidence}") + print(f"最终结果: {result.is_final}") + else: + print("未能识别到语音内容") + elif args.file: + results = await recognizer.recognize_file(args.file) + for i, result in enumerate(results): + print(f"结果 {i+1}: {result.text}") + print(f"置信度: {result.confidence}") + print(f"最终结果: {result.is_final}") + print("-" * 40) + else: + print("请指定 --file 或 --latest 参数") + + except Exception as e: + print(f"识别失败: {e}") + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file