config
This commit is contained in:
parent dbdeeeefcb
commit 97aecf0c30
143 README.md Normal file
@@ -0,0 +1,143 @@
# Smart Voice Assistant — User Guide

## Overview

This is a complete smart voice assistant that integrates voice recording, speech recognition, a large language model (LLM), and text-to-speech (TTS) to support spoken conversation.

## End-to-End Workflow

1. 🎙️ **Voice recording** - intelligent voice detection based on ZCR (zero-crossing rate)
2. 📝 **Save recording** - automatically saved as a WAV file
3. 🤖 **Speech recognition** - ByteDance ASR converts the speech to text
4. 💬 **AI reply** - the Doubao LLM generates a reply
5. 🔊 **Spoken reply** - ByteDance TTS converts the AI reply to speech
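The five steps form a simple linear pipeline. A minimal sketch of the flow (the stage functions here are stand-in lambdas, since the real ASR/LLM/TTS calls need hardware, network access, and API keys):

```python
def run_pipeline(record, asr, llm, tts):
    """Chain the five stages; each argument is a callable implementing one stage."""
    wav_path = record()       # steps 1-2: record and save a WAV file
    text = asr(wav_path)      # step 3: speech -> text
    reply = llm(text)         # step 4: text -> AI reply
    return tts(reply)         # step 5: reply -> audio file

# Stubbed example run (no hardware or network needed):
result = run_pipeline(
    record=lambda: "recording_demo.wav",
    asr=lambda path: "hello",
    llm=lambda text: f"you said: {text}",
    tts=lambda reply: "tts_demo.pcm",
)
print(result)  # → tts_demo.pcm
```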
## Environment Setup

### 1. Install dependencies

```bash
pip install websockets requests pyaudio numpy
```

### 2. Install an audio player (Raspberry Pi / Linux)

The system plays PCM audio, so only the basic ALSA playback tool is needed:

```bash
# Install alsa-utils (provides the aplay player)
sudo apt-get update
sudo apt-get install alsa-utils
```

> **Advantage**: PCM needs no extra decoder, so compatibility is better and resource usage is lower.

> **Note**: macOS and Windows normally support audio playback out of the box; nothing extra is needed.
### 3. Set the API key

To enable the LLM, set an environment variable:

```bash
# Linux/macOS
export ARK_API_KEY='your_api_key_here'

# Windows
set ARK_API_KEY=your_api_key_here
```

> **Note**: Speech recognition and text-to-speech use built-in API keys and need no extra configuration.
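recorder.py reads this variable at startup and disables the LLM feature when it is missing; the guard is essentially the following (a standalone sketch of the same check):

```python
import os

def llm_enabled() -> bool:
    """Mirror recorder.py's startup check: disable the LLM when ARK_API_KEY is unset."""
    api_key = os.environ.get("ARK_API_KEY", "")
    if not api_key:
        print("⚠️ ARK_API_KEY is not set; the LLM feature will be disabled")
        return False
    return True

os.environ["ARK_API_KEY"] = "demo-key"  # hypothetical value for illustration
print(llm_enabled())  # → True
```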
## Usage

### Basic usage

```bash
python recorder.py
```

### Features

- 🎯 **Automatic voice detection**: recording starts as soon as sound is detected
- ⏱️ **Smart stop**: recording stops automatically after 3 seconds of silence
- 🔊 **Auto playback**: the recording is played back automatically when it finishes
- 📝 **Speech recognition**: speech is transcribed to text automatically
- 🤖 **AI assistant**: the LLM is called automatically to generate a reply

### Configuration parameters

- `energy_threshold=200` - energy threshold (adjusts sensitivity)
- `silence_threshold=3.0` - silence threshold (seconds)
- `min_recording_time=2.0` - minimum recording duration (seconds)
- `max_recording_time=30.0` - maximum recording duration (seconds)
- `enable_asr=True` - enable speech recognition
- `enable_llm=True` - enable the large language model
- `enable_tts=True` - enable text-to-speech
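How the three timing parameters interact can be seen in a small stop-decision function (a simplified sketch, not the exact code in recorder.py):

```python
def should_stop(elapsed, silent_for,
                silence_threshold=3.0,
                min_recording_time=2.0,
                max_recording_time=30.0) -> bool:
    """Decide whether to stop recording, given seconds elapsed and seconds of silence."""
    if elapsed >= max_recording_time:       # hard cap on recording length
        return True
    if elapsed < min_recording_time:        # never stop before the minimum
        return False
    return silent_for >= silence_threshold  # stop after sustained silence

print(should_stop(elapsed=5.0, silent_for=3.2))  # → True
print(should_stop(elapsed=1.5, silent_for=3.2))  # → False
```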
## Sample Output

```
🎤 Listening...
Energy threshold: 200 (deprecated)
Silence threshold: 3.0s
📖 Instructions:
- Recording starts automatically when sound is detected
- Recording stops after 3 seconds of continuous silence
- Recordings last at least 2s and at most 30s
- After recording, speech recognition and the AI reply run automatically
- Press Ctrl+C to exit
==================================================
🎙️ Sound detected, recording...
📝 Recording finished, duration: 3.45s (includes 2.0s of pre-roll)
✅ Recording saved: recording_20250920_163022.wav
==================================================
📡 Audio input kept closed
🔄 Processing audio...
🤖 Starting speech recognition...
📝 Recognized text: Hello, what's the weather like today?
--------------------------------------------------
🤖 Calling the large language model...
💬 AI reply: Hello! I can't fetch live weather data; please check a forecast or a weather app for today's conditions. Is there anything else I can help with?
--------------------------------------------------
🔊 Starting text-to-speech...
TTS sentence info: {'code': 0, 'message': '', 'data': None, 'sentence': {'phonemes': [], 'text': "Hello! I can't fetch live weather data; please check a forecast or a weather app for today's conditions. Is there anything else I can help with?", 'words': [...]}}
✅ TTS audio saved: tts_response_20250920_163022.pcm
📁 File size: 128.75 KB
🔊 Playing the AI voice reply...
✅ AI voice reply finished
🔄 Preparing to reopen audio input
✅ Audio device initialized
📡 Audio input reopened
```
## Notes

1. **Network connection**: speech recognition, the LLM, and text-to-speech all require network access
2. **API key**: a valid ARK_API_KEY is required for the LLM feature
3. **Audio devices**: make sure the microphone and speaker work properly
4. **Permissions**: the program needs access to the microphone, network, and storage
5. **File storage**: the system saves both the recordings and the TTS-generated audio files

## Troubleshooting

- If speech recognition fails, check the network connection and API key
- If the LLM fails, check that ARK_API_KEY is set correctly
- If text-to-speech fails, check the TTS service status
- If recording fails, check microphone permissions and the device
- If playback fails, check audio device permissions
- If PCM files will not play, check that alsa-utils is installed:

```bash
# Raspberry Pi / Ubuntu / Debian
sudo apt-get install alsa-utils

# Or check whether aplay is installed
which aplay
```
## Technical Highlights

- 🎯 Accurate voice detection based on ZCR
- 🚀 Low-latency real-time processing
- 💾 Ring buffer prevents audio loss
- 🔧 Automatic energy-threshold adjustment
- 📊 Real-time performance monitoring
- 🌐 Complete voice-conversation pipeline
- 📁 Automatic file management and permission setup
- 🔊 PCM audio, no extra decoder required

## Generated Files

- `recording_*.wav` - recorded audio files
- `tts_response_*.pcm` - AI voice replies (PCM format)

## Why PCM

- **Compatibility**: natively supported by aplay; works on a Raspberry Pi out of the box
- **Low resource usage**: no decoding step, so lower CPU load
- **Lower latency**: played directly, with no format conversion
- **Stability**: fewer dependencies, so a more stable system
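The generated `tts_response_*.pcm` files are raw 16 kHz, 16-bit mono samples (matching the `aplay -f S16_LE -r 16000 -c 1` flags used in recorder.py). If a WAV container is more convenient, the standard-library `wave` module can wrap them; a minimal sketch (`pcm_to_wav` is a hypothetical helper, not part of the project):

```python
import wave

def pcm_to_wav(pcm_path: str, wav_path: str,
               rate: int = 16000, channels: int = 1, sampwidth: int = 2) -> None:
    """Wrap raw 16-bit little-endian PCM samples in a WAV container."""
    with open(pcm_path, "rb") as f:
        pcm = f.read()
    with wave.open(wav_path, "wb") as w:
        w.setnchannels(channels)
        w.setsampwidth(sampwidth)   # 2 bytes per sample = 16-bit audio
        w.setframerate(rate)
        w.writeframes(pcm)

# Example: wrap 0.5 s of silence (8000 frames at 16 kHz)
with open("demo.pcm", "wb") as f:
    f.write(b"\x00\x00" * 8000)
pcm_to_wav("demo.pcm", "demo.wav")
```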
74 install.sh Executable file
@@ -0,0 +1,74 @@
#!/bin/bash

# Smart voice assistant installation script
# For Raspberry Pi and Linux systems

echo "🚀 Smart Voice Assistant - Install Script"
echo "================================"

# Refuse to run as root
if [ "$EUID" -eq 0 ]; then
    echo "⚠️ Please do not run this script as root"
    echo "   Run it as a regular user: ./install.sh"
    exit 1
fi

# Update the package index
echo "📦 Updating package index..."
sudo apt-get update

# Install system dependencies
echo "🔧 Installing system dependencies..."
sudo apt-get install -y \
    python3 \
    python3-pip \
    portaudio19-dev \
    python3-dev \
    alsa-utils

# Install Python dependencies
echo "🐍 Installing Python dependencies..."
pip3 install --user \
    websockets \
    requests \
    pyaudio \
    numpy

# Check the audio player
echo "🔊 Checking audio player..."
if command -v aplay >/dev/null 2>&1; then
    echo "✅ aplay is installed (PCM/WAV playback supported)"
else
    echo "❌ aplay installation failed"
fi

# Check Python modules
echo "🧪 Checking Python modules..."
python3 -c "import websockets, requests, pyaudio, numpy" 2>/dev/null
if [ $? -eq 0 ]; then
    echo "✅ All Python dependencies are installed"
else
    echo "❌ Some Python dependencies failed to install"
fi

echo ""
echo "✅ Installation complete!"
echo ""
echo "📋 Usage:"
echo "1. Set the API key (needed for the LLM feature):"
echo "   export ARK_API_KEY='your_api_key_here'"
echo ""
echo "2. Run the program:"
echo "   python3 recorder.py"
echo ""
echo "3. Troubleshooting:"
echo "   - For permission problems, make sure your user is in the audio group:"
echo "     sudo usermod -a -G audio \$USER"
echo "   - Then log out and back in, or reboot"
echo ""
echo "🎯 Features:"
echo "- 🎙️ Smart voice recording"
echo "- 🤖 Online speech recognition"
echo "- 💬 AI conversation"
echo "- 🔊 Voice reply synthesis"
echo "- 📁 Automatic file management"
665 recorder.py
@@ -6,24 +6,75 @@
Optimized specifically for the Raspberry Pi 3B; the Vosk recognition dependency is fully removed
"""

import sys
import asyncio
import base64
import gzip
import hmac
import json
import os
import time
import sys
import threading
import pyaudio
import numpy as np
import time
import uuid
import wave
from io import BytesIO
from urllib.parse import urlparse

import numpy as np
import pyaudio
import requests

try:
    import websockets
except ImportError:
    print("⚠️ websockets is not installed; speech recognition will be unavailable")
    websockets = None
class EnergyBasedRecorder:
    """Energy-detection based recording system"""

    def __init__(self, energy_threshold=500, silence_threshold=1.5, min_recording_time=2.0, max_recording_time=30.0):
    def __init__(self, energy_threshold=500, silence_threshold=1.5, min_recording_time=2.0, max_recording_time=30.0, enable_asr=True, enable_llm=True, enable_tts=True):
        # Audio parameters - minimal and optimized
        self.FORMAT = pyaudio.paInt16
        self.CHANNELS = 1
        self.RATE = 16000  # 16 kHz sample rate
        self.CHUNK_SIZE = 1024  # moderate chunk size

        # Speech recognition (ASR) configuration
        self.enable_asr = enable_asr
        self.asr_appid = "8718217928"
        self.asr_token = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
        self.asr_cluster = "volcengine_input_common"
        self.asr_ws_url = "wss://openspeech.bytedance.com/api/v2/asr"

        # Large language model configuration
        self.enable_llm = enable_llm
        self.llm_api_url = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
        self.llm_model = "doubao-seed-1-6-flash-250828"
        self.llm_api_key = os.environ.get("ARK_API_KEY", "")

        # Check the API key
        if self.enable_llm and not self.llm_api_key:
            print("⚠️ ARK_API_KEY is not set; the LLM feature will be disabled")
            self.enable_llm = False

        # Text-to-speech configuration
        self.enable_tts = enable_tts
        self.tts_url = "https://openspeech.bytedance.com/api/v3/tts/unidirectional"
        self.tts_app_id = "8718217928"
        self.tts_access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
        self.tts_resource_id = "volc.service_type.10029"
        self.tts_app_key = "aGjiRDfUWi"
        self.tts_speaker = "zh_female_wanqudashu_moon_bigtts"

        # Check audio playback capability
        if self.enable_tts:
            self.audio_player_available = self._check_audio_player()
            if not self.audio_player_available:
                print("⚠️ No audio player found; TTS playback may be unavailable")
                print("   Suggested install: sudo apt-get install alsa-utils")
                # TTS itself stays enabled, since files can still be generated

        # Energy-detection parameters
        self.energy_threshold = energy_threshold  # energy above this counts as sound
        self.silence_threshold = silence_threshold  # how long below this counts as the end of speech
@@ -90,16 +141,22 @@ class EnergyBasedRecorder:
        # Convert the byte buffer to a numpy array
        audio_array = np.frombuffer(audio_data, dtype=np.int16)

        # Compute RMS energy
        rms = np.sqrt(np.mean(audio_array ** 2))

        # Update the energy history (only while not recording, so speech does not skew the background-noise estimate)
        if not self.recording:
            self.energy_history.append(rms)
            if len(self.energy_history) > self.max_energy_history:
                self.energy_history.pop(0)

        return rms
        # Compute RMS energy, guarding against invalid values
        try:
            rms = np.sqrt(np.mean(audio_array ** 2))
            # Reject NaN/inf results
            if np.isnan(rms) or np.isinf(rms):
                return 0

            # Update the energy history (only while not recording, so speech does not skew the background-noise estimate)
            if not self.recording:
                self.energy_history.append(rms)
                if len(self.energy_history) > self.max_energy_history:
                    self.energy_history.pop(0)

            return rms
        except:
            return 0

    def calculate_peak_energy(self, audio_data):
        """Compute peak energy (auxiliary signal)"""
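One caveat about the RMS computation above: squaring an int16 numpy array can wrap around, so it is safer to cast to float first. A standalone sketch of a safer variant (`rms_energy` is a hypothetical helper, not the recorder.py method):

```python
import numpy as np

def rms_energy(pcm_bytes: bytes) -> float:
    """RMS energy of 16-bit little-endian PCM; cast to float to avoid int16 overflow."""
    samples = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float64)
    if samples.size == 0:
        return 0.0
    value = float(np.sqrt(np.mean(samples ** 2)))
    return 0.0 if not np.isfinite(value) else value

# A constant-amplitude signal has RMS equal to that amplitude:
print(rms_energy(np.full(1024, 1000, dtype=np.int16).tobytes()))  # → 1000.0
```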
@@ -259,16 +316,147 @@ class EnergyBasedRecorder:
        """Play audio with the system player"""
        try:
            import subprocess
            cmd = ['aplay', filename]  # Linux
            import platform

            # Get the file extension
            file_ext = filename.lower().split('.')[-1] if '.' in filename else ''

            # Choose a playback command by file type and platform
            if file_ext == 'mp3':
                # MP3 playback
                system = platform.system().lower()

                if system == 'linux':
                    # Linux - try several MP3 players
                    mp3_players = [
                        ['mpg123', filename],  # the most common MP3 player
                        ['mpg321', filename],  # an alternative MP3 player
                        ['mplayer', filename],  # general-purpose media player
                        ['cvlc', '--play-and-exit', filename],  # command-line VLC
                        ['ffplay', '-nodisp', '-autoexit', filename]  # FFmpeg player
                    ]

                    cmd = None
                    for player in mp3_players:
                        try:
                            subprocess.run(['which', player[0]], capture_output=True, check=True)
                            cmd = player
                            break
                        except:
                            continue

                    if cmd is None:
                        raise Exception("No usable MP3 player found; install mpg123 or mplayer")

                elif system == 'darwin':  # macOS
                    cmd = ['afplay', filename]

                elif system == 'windows':
                    cmd = ['cmd', '/c', 'start', '/min', filename]

                else:
                    cmd = ['aplay', filename]  # default; may fail

            elif file_ext == 'pcm':
                # PCM playback - the sample format must be given explicitly
                cmd = ['aplay', '-f', 'S16_LE', '-r', '16000', '-c', '1', filename]

            else:
                # WAV or other formats
                cmd = ['aplay', filename]  # Linux

            print(f"🔊 Using system player: {' '.join(cmd)}")
            print("🚫 System player running; audio input stays closed")
            subprocess.run(cmd, check=True)
            print("✅ Playback finished")
            print("📡 Audio input kept closed")

        except Exception as e:
            print(f"❌ System player also failed: {e}")
            print(f"❌ System player failed: {e}")

            # Fall back to pygame
            try:
                self._play_with_pygame(filename)
            except Exception as pygame_error:
                print(f"❌ pygame playback also failed: {pygame_error}")
                raise e
    def play_audio_safe(self, filename):
    def _check_audio_player(self):
        """Check whether the system supports audio playback"""
        try:
            import subprocess
            import platform

            system = platform.system().lower()

            if system == 'linux':
                # Check for aplay (PCM and WAV playback)
                try:
                    subprocess.run(['which', 'aplay'], capture_output=True, check=True)
                    print("✅ Found audio player: aplay")
                    return True
                except:
                    pass

                # Check for pygame as a fallback
                try:
                    import pygame
                    print("✅ Found pygame as a playback fallback")
                    return True
                except ImportError:
                    pass

                return False

            elif system == 'darwin':  # macOS
                # Check for afplay
                try:
                    subprocess.run(['which', 'afplay'], capture_output=True, check=True)
                    print("✅ Found audio player: afplay")
                    return True
                except:
                    return False

            elif system == 'windows':
                # Windows generally supports audio playback
                return True

            else:
                return False

        except Exception as e:
            print(f"❌ Error while checking for an audio player: {e}")
            return False

    def _play_with_pygame(self, filename):
        """Play audio with pygame as a fallback"""
        try:
            import pygame
            pygame.mixer.init()

            print(f"🔊 Trying pygame playback: {filename}")

            # Load and play the audio
            sound = pygame.mixer.Sound(filename)
            sound.play()

            # Wait for playback to finish
            while pygame.mixer.get_busy():
                pygame.time.Clock().tick(10)

            print("✅ pygame playback finished")

        except ImportError:
            raise Exception("pygame is not installed")
        except Exception as e:
            raise Exception(f"pygame playback failed: {e}")
        finally:
            try:
                pygame.mixer.quit()
            except:
                pass

    def play_audio_safe(self, filename, reopen_input=False):
        """Safe playback - uses the system player"""
        try:
            print("🔇 Preparing to play; audio input fully stopped")
@@ -303,7 +491,8 @@ class EnergyBasedRecorder:
            # Use the system player
            self.play_with_system_player(filename)

            print("🔄 Preparing to reopen audio input")
            if reopen_input:
                print("🔄 Preparing to reopen audio input")

        except Exception as e:
            print(f"❌ Playback failed: {e}")
@@ -314,13 +503,15 @@ class EnergyBasedRecorder:
            # Wait for playback to end completely
            time.sleep(0.5)

            # Reopen the input stream
            self._setup_audio()

            # Reset all state
            self.energy_history = []
            self.zcr_history = []
            print("📡 Audio input reopened")
            # Reopen the input stream only when requested
            if reopen_input:
                # Reopen the input stream
                self._setup_audio()

                # Reset all state
                self.energy_history = []
                self.zcr_history = []
                print("📡 Audio input reopened")

    def update_pre_record_buffer(self, audio_data):
        """Update the pre-recording ring buffer"""
@@ -367,13 +558,57 @@ class EnergyBasedRecorder:
        # Save the recording
        success, filename = self.save_recording(audio_data)

        # If the save succeeded, play the recording
        # If the save succeeded, run the follow-up processing
        if success and filename:
            print("=" * 50)
            print("🔊 Playing back the recording...")
            # Prefer the system player to avoid echo
            self.play_audio_safe(filename)
            print("=" * 50)
            print("📡 Audio input kept closed")
            print("🔄 Processing audio...")

            # Speech recognition and the LLM call
            if self.enable_asr and websockets is not None:
                print("🤖 Starting speech recognition...")
                asr_result = self.recognize_audio_sync(filename)
                if asr_result and 'payload_msg' in asr_result:
                    result_text = asr_result['payload_msg'].get('result', [])
                    if result_text:
                        text = result_text[0].get('text', '识别失败')
                        print(f"📝 Recognized text: {text}")

                        # Call the large language model
                        if self.enable_llm and text != '识别失败':
                            print("-" * 50)
                            llm_response = self.call_llm(text)
                            if llm_response:
                                print(f"💬 AI reply: {llm_response}")

                                # Call text-to-speech
                                if self.enable_tts:
                                    print("-" * 50)
                                    tts_file = self.text_to_speech(llm_response)
                                    if tts_file:
                                        print("✅ AI voice reply finished")
                                    else:
                                        print("❌ Text-to-speech failed")
                                else:
                                    print("ℹ️ Text-to-speech is disabled")
                            else:
                                print("❌ LLM call failed")
                        else:
                            if not self.enable_llm:
                                print("ℹ️ The LLM feature is disabled")
                            elif not self.llm_api_key:
                                print("ℹ️ Set the ARK_API_KEY environment variable to enable the LLM")
                    else:
                        print("❌ Speech recognition failed: no result")
                else:
                    print("❌ Speech recognition failed")
            else:
                if not self.enable_asr:
                    print("ℹ️ Speech recognition is disabled")
                elif websockets is None:
                    print("ℹ️ Install the websockets library to enable speech recognition")

            print("🔄 Preparing to reopen audio input")

        self.recording = False
        self.recorded_frames = []
@@ -382,6 +617,12 @@ class EnergyBasedRecorder:
        self.energy_history = []
        self.zcr_history = []

    def get_average_energy(self):
        """Average energy over the history window"""
        if len(self.energy_history) == 0:
            return 0
        return np.mean(self.energy_history)

    def monitor_performance(self):
        """Performance monitoring"""
        self.frame_count += 1
@@ -435,8 +676,22 @@ class EnergyBasedRecorder:
        try:
            while self.running:
                # Check whether the audio stream is available
                if self.stream is None:
                    print("\n❌ Audio stream disconnected; trying to reconnect...")
                    self._setup_audio()
                    if self.stream is None:
                        print("❌ Audio stream reconnect failed; waiting...")
                        time.sleep(1)
                        continue

                # Read audio data
                data = self.stream.read(self.CHUNK_SIZE, exception_on_overflow=False)
                try:
                    data = self.stream.read(self.CHUNK_SIZE, exception_on_overflow=False)
                except Exception as e:
                    print(f"\n❌ Failed to read audio data: {e}")
                    self.stream = None
                    continue

                if len(data) == 0:
                    continue
@@ -555,10 +810,331 @@ class EnergyBasedRecorder:
        if self.audio:
            self.audio.terminate()

    def generate_asr_header(self, message_type=1, message_type_specific_flags=0):
        """Build the 4-byte ASR request header"""
        PROTOCOL_VERSION = 0b0001
        DEFAULT_HEADER_SIZE = 0b0001
        JSON = 0b0001
        GZIP = 0b0001

        header = bytearray()
        header.append((PROTOCOL_VERSION << 4) | DEFAULT_HEADER_SIZE)
        header.append((message_type << 4) | message_type_specific_flags)
        header.append((JSON << 4) | GZIP)
        header.append(0x00)  # reserved
        return header
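The header packs two 4-bit fields into each byte: protocol version | header size, message type | flags, serialization | compression, then a reserved byte. A standalone sketch of the same layout (outside the class, for illustration):

```python
def asr_header(message_type=1, flags=0) -> bytes:
    """4-byte header: version|header_size, type|flags, serialization|compression, reserved."""
    PROTOCOL_VERSION, HEADER_SIZE = 0b0001, 0b0001
    JSON, GZIP = 0b0001, 0b0001
    return bytes([
        (PROTOCOL_VERSION << 4) | HEADER_SIZE,  # 0x11
        (message_type << 4) | flags,
        (JSON << 4) | GZIP,                     # 0x11
        0x00,                                   # reserved
    ])

print(asr_header().hex())                # → 11101100 (full client request)
print(asr_header(0b0010, 0b0010).hex())  # → 11221100 (last audio-only chunk)
```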
    def parse_asr_response(self, res):
        """Parse an ASR response"""
        PROTOCOL_VERSION = res[0] >> 4
        header_size = res[0] & 0x0f
        message_type = res[1] >> 4
        message_type_specific_flags = res[1] & 0x0f
        serialization_method = res[2] >> 4
        message_compression = res[2] & 0x0f
        reserved = res[3]
        header_extensions = res[4:header_size * 4]
        payload = res[header_size * 4:]
        result = {}
        payload_msg = None
        payload_size = 0

        if message_type == 0b1001:  # SERVER_FULL_RESPONSE
            payload_size = int.from_bytes(payload[:4], "big", signed=True)
            payload_msg = payload[4:]
        elif message_type == 0b1011:  # SERVER_ACK
            seq = int.from_bytes(payload[:4], "big", signed=True)
            result['seq'] = seq
            if len(payload) >= 8:
                payload_size = int.from_bytes(payload[4:8], "big", signed=False)
                payload_msg = payload[8:]
        elif message_type == 0b1111:  # SERVER_ERROR_RESPONSE
            code = int.from_bytes(payload[:4], "big", signed=False)
            result['code'] = code
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            payload_msg = payload[8:]

        if payload_msg is None:
            return result

        if message_compression == 0b0001:  # GZIP
            payload_msg = gzip.decompress(payload_msg)

        if serialization_method == 0b0001:  # JSON
            payload_msg = json.loads(str(payload_msg, "utf-8"))

        result['payload_msg'] = payload_msg
        result['payload_size'] = payload_size
        return result
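A synthetic round trip through the same wire format shows how the pieces fit together (a standalone sketch mirroring the parser above, handling only the SERVER_FULL_RESPONSE case):

```python
import gzip, json

def parse_response(res: bytes) -> dict:
    """Minimal parser for a SERVER_FULL_RESPONSE frame with a gzip + JSON payload."""
    header_size = res[0] & 0x0F
    message_type = res[1] >> 4
    serialization = res[2] >> 4
    compression = res[2] & 0x0F
    payload = res[header_size * 4:]
    assert message_type == 0b1001  # SERVER_FULL_RESPONSE only, in this sketch
    body = payload[4:]  # skip the 4-byte payload-size field
    if compression == 0b0001:
        body = gzip.decompress(body)
    if serialization == 0b0001:
        body = json.loads(body.decode("utf-8"))
    return {"payload_msg": body}

# Build a fake server frame: 4-byte header, payload size, gzipped JSON body
body = gzip.compress(json.dumps({"code": 1000, "result": [{"text": "hello"}]}).encode())
frame = bytes([0x11, 0b1001 << 4, 0x11, 0x00]) + len(body).to_bytes(4, "big") + body
print(parse_response(frame)["payload_msg"]["result"][0]["text"])  # → hello
```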
    async def recognize_audio(self, audio_path):
        """Recognize an audio file"""
        if not self.enable_asr or websockets is None:
            return None

        try:
            print("🤖 Starting speech recognition...")

            # Read the audio file
            with open(audio_path, mode="rb") as f:
                audio_data = f.read()

            # Build the request
            reqid = str(uuid.uuid4())
            request_params = {
                'app': {
                    'appid': self.asr_appid,
                    'cluster': self.asr_cluster,
                    'token': self.asr_token,
                },
                'user': {
                    'uid': 'recorder_asr'
                },
                'request': {
                    'reqid': reqid,
                    'nbest': 1,
                    'workflow': 'audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate',
                    'show_language': False,
                    'show_utterances': False,
                    'result_type': 'full',
                    "sequence": 1
                },
                'audio': {
                    'format': 'wav',
                    'rate': self.RATE,
                    'language': 'zh-CN',
                    'bits': 16,
                    'channel': self.CHANNELS,
                    'codec': 'raw'
                }
            }

            # Build the header
            payload_bytes = str.encode(json.dumps(request_params))
            payload_bytes = gzip.compress(payload_bytes)
            full_client_request = bytearray(self.generate_asr_header())
            full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
            full_client_request.extend(payload_bytes)

            # Set the auth header
            additional_headers = {'Authorization': 'Bearer; {}'.format(self.asr_token)}

            # Connect over WebSocket and send the request
            async with websockets.connect(self.asr_ws_url, additional_headers=additional_headers, max_size=1000000000) as ws:
                # Send the full client request
                await ws.send(full_client_request)
                res = await ws.recv()
                result = self.parse_asr_response(res)

                if 'payload_msg' in result and result['payload_msg'].get('code') != 1000:
                    print(f"❌ ASR request failed: {result['payload_msg']}")
                    return None

                # Send the audio data in segments
                chunk_size = int(self.CHANNELS * 2 * self.RATE * 15000 / 1000)  # 15000 ms (15 s) segments

                for offset in range(0, len(audio_data), chunk_size):
                    chunk = audio_data[offset:offset + chunk_size]
                    last = (offset + chunk_size) >= len(audio_data)

                    payload_bytes = gzip.compress(chunk)
                    audio_only_request = bytearray(self.generate_asr_header(message_type=0b0010, message_type_specific_flags=0b0010 if last else 0))
                    audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
                    audio_only_request.extend(payload_bytes)

                    await ws.send(audio_only_request)
                    res = await ws.recv()
                    result = self.parse_asr_response(res)

                return result

        except Exception as e:
            print(f"❌ Speech recognition failed: {e}")
            return None
    def recognize_audio_sync(self, audio_path):
        """Synchronous wrapper around speech recognition"""
        if not self.enable_asr or websockets is None:
            return None

        try:
            return asyncio.run(self.recognize_audio(audio_path))
        except Exception as e:
            print(f"❌ Speech recognition failed: {e}")
            return None
    def call_llm(self, user_message):
        """Call the large-language-model API"""
        if not self.enable_llm:
            return None

        try:
            print("🤖 Calling the large language model...")

            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.llm_api_key}"
            }

            data = {
                "model": self.llm_model,
                "messages": [
                    {
                        "role": "system",
                        "content": "你是唐朝大诗人李白,用简短诗词和小朋友对话,每次回答不超过50字。"
                    },
                    {
                        "role": "user",
                        "content": user_message
                    }
                ],
                "max_tokens": 50
            }

            response = requests.post(self.llm_api_url, headers=headers, json=data, timeout=30)

            if response.status_code == 200:
                result = response.json()
                if "choices" in result and len(result["choices"]) > 0:
                    llm_response = result["choices"][0]["message"]["content"]
                    return llm_response.strip()
                else:
                    print("❌ Unexpected LLM API response format")
                    return None
            else:
                print(f"❌ LLM API call failed: {response.status_code}")
                print(f"Response body: {response.text}")
                return None

        except requests.exceptions.RequestException as e:
            print(f"❌ Network request failed: {e}")
            return None
        except Exception as e:
            print(f"❌ LLM call failed: {e}")
            return None
    def generate_tts_filename(self):
        """Generate a TTS output filename"""
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        return f"tts_response_{timestamp}.pcm"
    def text_to_speech(self, text):
        """Convert text to speech"""
        if not self.enable_tts:
            return None

        try:
            print("🔊 Starting text-to-speech...")

            # Generate the output filename
            output_file = self.generate_tts_filename()

            # Build the request headers
            headers = {
                "X-Api-App-Id": self.tts_app_id,
                "X-Api-Access-Key": self.tts_access_key,
                "X-Api-Resource-Id": self.tts_resource_id,
                "X-Api-App-Key": self.tts_app_key,
                "Content-Type": "application/json",
                "Connection": "keep-alive"
            }

            # Build the request payload
            payload = {
                "user": {
                    "uid": "recorder_tts"
                },
                "req_params": {
                    "text": text,
                    "speaker": self.tts_speaker,
                    "audio_params": {
                        "format": "pcm",
                        "sample_rate": 16000,
                        "enable_timestamp": True
                    },
                    "additions": "{\"explicit_language\":\"zh\",\"disable_markdown_filter\":true,\"enable_timestamp\":true}"
                }
            }

            # Send the request
            session = requests.Session()
            response = None
            try:
                response = session.post(self.tts_url, headers=headers, json=payload, stream=True)

                if response.status_code != 200:
                    print(f"❌ TTS request failed: {response.status_code}")
                    print(f"Response body: {response.text}")
                    return None

                # Process the streaming response
                audio_data = bytearray()
                total_audio_size = 0

                for chunk in response.iter_lines(decode_unicode=True):
                    if not chunk:
                        continue

                    try:
                        data = json.loads(chunk)

                        if data.get("code", 0) == 0 and "data" in data and data["data"]:
                            chunk_audio = base64.b64decode(data["data"])
                            audio_size = len(chunk_audio)
                            total_audio_size += audio_size
                            audio_data.extend(chunk_audio)
                            continue

                        if data.get("code", 0) == 0 and "sentence" in data and data["sentence"]:
                            print("TTS sentence info:", data["sentence"])
                            continue

                        if data.get("code", 0) == 20000000:
                            break

                        if data.get("code", 0) > 0:
                            print(f"❌ TTS error response: {data}")
                            break

                    except json.JSONDecodeError:
                        print(f"❌ Failed to parse TTS response: {chunk}")
                        continue

                # Save the audio file
                if audio_data:
                    with open(output_file, "wb") as f:
                        f.write(audio_data)
                    print(f"✅ TTS audio saved: {output_file}")
                    print(f"📁 File size: {len(audio_data) / 1024:.2f} KB")

                    # Make sure the file has the right permissions
                    os.chmod(output_file, 0o644)

                    # Play the generated audio
                    if hasattr(self, 'audio_player_available') and self.audio_player_available:
                        print("🔊 Playing the AI voice reply...")
                        self.play_audio_safe(output_file, reopen_input=False)
                    else:
                        print("ℹ️ Skipping TTS playback (no player available)")
                        print(f"📁 TTS audio saved to: {output_file}")

                    return output_file
                else:
                    print("❌ No TTS audio data received")
                    return None

            finally:
                if response is not None:
                    response.close()
                session.close()

        except Exception as e:
            print(f"❌ TTS conversion failed: {e}")
            return None
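The streaming loop above consumes newline-delimited JSON chunks, where audio arrives base64-encoded under `data` and `code == 20000000` marks the end of the stream. A self-contained sketch of that decode logic, fed synthetic chunks instead of a live HTTP response:

```python
import base64, json

def collect_tts_audio(lines):
    """Accumulate base64-encoded audio from newline-delimited JSON TTS chunks."""
    audio = bytearray()
    for line in lines:
        data = json.loads(line)
        if data.get("code", 0) == 0 and data.get("data"):
            audio.extend(base64.b64decode(data["data"]))
        elif data.get("code", 0) == 20000000:  # end-of-stream marker
            break
    return bytes(audio)

chunks = [
    json.dumps({"code": 0, "data": base64.b64encode(b"\x01\x02").decode()}),
    json.dumps({"code": 0, "data": base64.b64encode(b"\x03\x04").decode()}),
    json.dumps({"code": 20000000}),
]
print(collect_tts_audio(chunks))  # → b'\x01\x02\x03\x04'
```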
def main():
    """Main entry point"""
    print("🚀 Minimal energy-detection recording system")
    print("🤖 With integrated speech recognition")
    print("=" * 50)

    # Create the recorder
@@ -566,20 +1142,37 @@ def main():
        energy_threshold=200,  # energy threshold (lowered for higher sensitivity)
        silence_threshold=3.0,  # silence threshold (seconds) - changed to 3 s
        min_recording_time=2.0,  # minimum recording duration
        max_recording_time=30.0  # maximum recording duration
        max_recording_time=30.0,  # maximum recording duration
        enable_asr=True,  # enable speech recognition
        enable_llm=True,  # enable the large language model
        enable_tts=True  # enable text-to-speech
    )

    print("✅ System initialized")
    print("🎯 Optimizations:")
    print("🎯 Features:")
    print("   - Vosk recognition dependency fully removed")
    print("   - Energy-based detection with very low CPU usage")
    print("   - ZCR-based voice detection for accuracy")
    print("   - Integrated online speech recognition (ByteDance ASR)")
    print("   - Integrated large language model (Doubao)")
    print("   - Integrated text-to-speech (ByteDance TTS)")
    print("   - Automatic speech recognition after each recording")
    print("   - Automatic AI assistant call after recognition")
    print("   - AI replies automatically converted to speech")
    print("   - Pre-roll recording (includes the 2 s before sound starts)")
    print("   - Ring buffer prevents losing the start of the audio")
    print("   - Automatic energy-threshold adjustment")
    print("   - Real-time performance monitoring")
    print("   - Expected latency: <0.1 s")
    print("=" * 50)

    # Show the API-key status
    if not recorder.enable_llm:
        print("🔑 Tip: to enable the LLM feature, set the environment variable:")
        print("   export ARK_API_KEY='your_api_key_here'")
        print("=" * 50)

    # Start running
    recorder.run()

if __name__ == "__main__":
    main()
523 sauc_websocket_demo.py Normal file
@@ -0,0 +1,523 @@
import asyncio
import aiohttp
import json
import struct
import gzip
import uuid
import logging
import os
import subprocess
from typing import Optional, List, Dict, Any, Tuple, AsyncGenerator

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('run.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Constants
DEFAULT_SAMPLE_RATE = 16000

class ProtocolVersion:
    V1 = 0b0001

class MessageType:
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ERROR_RESPONSE = 0b1111

class MessageTypeSpecificFlags:
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    NEG_SEQUENCE = 0b0010
    NEG_WITH_SEQUENCE = 0b0011

class SerializationType:
    NO_SERIALIZATION = 0b0000
    JSON = 0b0001

class CompressionType:
    GZIP = 0b0001


class Config:
    def __init__(self):
        # Fill in the app id and access token from the console
        self.auth = {
            "app_key": "8718217928",
            "access_key": "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
        }

    @property
    def app_key(self) -> str:
        return self.auth["app_key"]

    @property
    def access_key(self) -> str:
        return self.auth["access_key"]

config = Config()
class CommonUtils:
    @staticmethod
    def gzip_compress(data: bytes) -> bytes:
        return gzip.compress(data)

    @staticmethod
    def gzip_decompress(data: bytes) -> bytes:
        return gzip.decompress(data)

    @staticmethod
    def judge_wav(data: bytes) -> bool:
        if len(data) < 44:
            return False
        return data[:4] == b'RIFF' and data[8:12] == b'WAVE'

    @staticmethod
    def convert_wav_with_path(audio_path: str, sample_rate: int = DEFAULT_SAMPLE_RATE) -> bytes:
        try:
            cmd = [
                "ffmpeg", "-v", "quiet", "-y", "-i", audio_path,
                "-acodec", "pcm_s16le", "-ac", "1", "-ar", str(sample_rate),
                "-f", "wav", "-"
            ]
            result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            # Try to remove the original file
            try:
                os.remove(audio_path)
            except OSError as e:
                logger.warning(f"Failed to remove original file: {e}")

            return result.stdout
        except subprocess.CalledProcessError as e:
            logger.error(f"FFmpeg conversion failed: {e.stderr.decode()}")
            raise RuntimeError(f"Audio conversion failed: {e.stderr.decode()}")
    @staticmethod
    def read_wav_info(data: bytes) -> Tuple[int, int, int, int, bytes]:
        if len(data) < 44:
            raise ValueError("Invalid WAV file: too short")

        # Parse the WAV header
        chunk_id = data[:4]
        if chunk_id != b'RIFF':
            raise ValueError("Invalid WAV file: not RIFF format")

        format_ = data[8:12]
        if format_ != b'WAVE':
            raise ValueError("Invalid WAV file: not WAVE format")

        # Parse the fmt subchunk
        audio_format = struct.unpack('<H', data[20:22])[0]
        num_channels = struct.unpack('<H', data[22:24])[0]
        sample_rate = struct.unpack('<I', data[24:28])[0]
        bits_per_sample = struct.unpack('<H', data[34:36])[0]

        # Find the data subchunk
        pos = 36
        while pos < len(data) - 8:
            subchunk_id = data[pos:pos+4]
            subchunk_size = struct.unpack('<I', data[pos+4:pos+8])[0]
            if subchunk_id == b'data':
                wave_data = data[pos+8:pos+8+subchunk_size]
                return (
                    num_channels,
                    bits_per_sample // 8,
                    sample_rate,
                    subchunk_size // (num_channels * (bits_per_sample // 8)),
                    wave_data
                )
            pos += 8 + subchunk_size

        raise ValueError("Invalid WAV file: no data subchunk found")
||||
class AsrRequestHeader:
|
||||
def __init__(self):
|
||||
self.message_type = MessageType.CLIENT_FULL_REQUEST
|
||||
self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE
|
||||
self.serialization_type = SerializationType.JSON
|
||||
self.compression_type = CompressionType.GZIP
|
||||
self.reserved_data = bytes([0x00])
|
||||
|
||||
def with_message_type(self, message_type: int) -> 'AsrRequestHeader':
|
||||
self.message_type = message_type
|
||||
return self
|
||||
|
||||
def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader':
|
||||
self.message_type_specific_flags = flags
|
||||
return self
|
||||
|
||||
def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader':
|
||||
self.serialization_type = serialization_type
|
||||
return self
|
||||
|
||||
def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader':
|
||||
self.compression_type = compression_type
|
||||
return self
|
||||
|
||||
def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader':
|
||||
self.reserved_data = reserved_data
|
||||
return self
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
header = bytearray()
|
||||
header.append((ProtocolVersion.V1 << 4) | 1)
|
||||
header.append((self.message_type << 4) | self.message_type_specific_flags)
|
||||
header.append((self.serialization_type << 4) | self.compression_type)
|
||||
header.extend(self.reserved_data)
|
||||
return bytes(header)
|
||||
|
||||
@staticmethod
|
||||
def default_header() -> 'AsrRequestHeader':
|
||||
return AsrRequestHeader()
|
||||
|
||||
class RequestBuilder:
|
||||
@staticmethod
|
||||
def new_auth_headers() -> Dict[str, str]:
|
||||
reqid = str(uuid.uuid4())
|
||||
return {
|
||||
"X-Api-Resource-Id": "volc.bigasr.sauc.duration",
|
||||
"X-Api-Request-Id": reqid,
|
||||
"X-Api-Access-Key": config.access_key,
|
||||
"X-Api-App-Key": config.app_key
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def new_full_client_request(seq: int) -> bytes: # 添加seq参数
|
||||
header = AsrRequestHeader.default_header() \
|
||||
.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
|
||||
|
||||
payload = {
|
||||
"user": {
|
||||
"uid": "demo_uid"
|
||||
},
|
||||
"audio": {
|
||||
"format": "wav",
|
||||
"codec": "raw",
|
||||
"rate": 16000,
|
||||
"bits": 16,
|
||||
"channel": 1
|
||||
},
|
||||
"request": {
|
||||
"model_name": "bigmodel",
|
||||
"enable_itn": True,
|
||||
"enable_punc": True,
|
||||
"enable_ddc": True,
|
||||
"show_utterances": True,
|
||||
"enable_nonstream": False
|
||||
}
|
||||
}
|
||||
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
compressed_payload = CommonUtils.gzip_compress(payload_bytes)
|
||||
payload_size = len(compressed_payload)
|
||||
|
||||
request = bytearray()
|
||||
request.extend(header.to_bytes())
|
||||
request.extend(struct.pack('>i', seq)) # 使用传入的seq
|
||||
request.extend(struct.pack('>I', payload_size))
|
||||
request.extend(compressed_payload)
|
||||
|
||||
return bytes(request)
|
||||
|
||||
@staticmethod
|
||||
def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
|
||||
header = AsrRequestHeader.default_header()
|
||||
if is_last: # 最后一个包特殊处理
|
||||
header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE)
|
||||
seq = -seq # 设为负值
|
||||
else:
|
||||
header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
|
||||
header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST)
|
||||
|
||||
request = bytearray()
|
||||
request.extend(header.to_bytes())
|
||||
request.extend(struct.pack('>i', seq))
|
||||
|
||||
compressed_segment = CommonUtils.gzip_compress(segment)
|
||||
request.extend(struct.pack('>I', len(compressed_segment)))
|
||||
request.extend(compressed_segment)
|
||||
|
||||
return bytes(request)
|
||||
|
||||
class AsrResponse:
|
||||
def __init__(self):
|
||||
self.code = 0
|
||||
self.event = 0
|
||||
self.is_last_package = False
|
||||
self.payload_sequence = 0
|
||||
self.payload_size = 0
|
||||
self.payload_msg = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"code": self.code,
|
||||
"event": self.event,
|
||||
"is_last_package": self.is_last_package,
|
||||
"payload_sequence": self.payload_sequence,
|
||||
"payload_size": self.payload_size,
|
||||
"payload_msg": self.payload_msg
|
||||
}
|
||||
|
||||
class ResponseParser:
|
||||
@staticmethod
|
||||
def parse_response(msg: bytes) -> AsrResponse:
|
||||
response = AsrResponse()
|
||||
|
||||
header_size = msg[0] & 0x0f
|
||||
message_type = msg[1] >> 4
|
||||
message_type_specific_flags = msg[1] & 0x0f
|
||||
serialization_method = msg[2] >> 4
|
||||
message_compression = msg[2] & 0x0f
|
||||
|
||||
payload = msg[header_size*4:]
|
||||
|
||||
# 解析message_type_specific_flags
|
||||
if message_type_specific_flags & 0x01:
|
||||
response.payload_sequence = struct.unpack('>i', payload[:4])[0]
|
||||
payload = payload[4:]
|
||||
if message_type_specific_flags & 0x02:
|
||||
response.is_last_package = True
|
||||
if message_type_specific_flags & 0x04:
|
||||
response.event = struct.unpack('>i', payload[:4])[0]
|
||||
payload = payload[4:]
|
||||
|
||||
# 解析message_type
|
||||
if message_type == MessageType.SERVER_FULL_RESPONSE:
|
||||
response.payload_size = struct.unpack('>I', payload[:4])[0]
|
||||
payload = payload[4:]
|
||||
elif message_type == MessageType.SERVER_ERROR_RESPONSE:
|
||||
response.code = struct.unpack('>i', payload[:4])[0]
|
||||
response.payload_size = struct.unpack('>I', payload[4:8])[0]
|
||||
payload = payload[8:]
|
||||
|
||||
if not payload:
|
||||
return response
|
||||
|
||||
# 解压缩
|
||||
if message_compression == CompressionType.GZIP:
|
||||
try:
|
||||
payload = CommonUtils.gzip_decompress(payload)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decompress payload: {e}")
|
||||
return response
|
||||
|
||||
# 解析payload
|
||||
try:
|
||||
if serialization_method == SerializationType.JSON:
|
||||
response.payload_msg = json.loads(payload.decode('utf-8'))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse payload: {e}")
|
||||
|
||||
return response
|
||||
|
||||
class AsrWsClient:
|
||||
def __init__(self, url: str, segment_duration: int = 200):
|
||||
self.seq = 1
|
||||
self.url = url
|
||||
self.segment_duration = segment_duration
|
||||
self.conn = None
|
||||
self.session = None # 添加session引用
|
||||
|
||||
async def __aenter__(self):
|
||||
self.session = aiohttp.ClientSession()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
if self.conn and not self.conn.closed:
|
||||
await self.conn.close()
|
||||
if self.session and not self.session.closed:
|
||||
await self.session.close()
|
||||
|
||||
async def read_audio_data(self, file_path: str) -> bytes:
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
|
||||
if not CommonUtils.judge_wav(content):
|
||||
logger.info("Converting audio to WAV format...")
|
||||
content = CommonUtils.convert_wav_with_path(file_path, DEFAULT_SAMPLE_RATE)
|
||||
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read audio data: {e}")
|
||||
raise
|
||||
|
||||
def get_segment_size(self, content: bytes) -> int:
|
||||
try:
|
||||
channel_num, samp_width, frame_rate, _, _ = CommonUtils.read_wav_info(content)[:5]
|
||||
size_per_sec = channel_num * samp_width * frame_rate
|
||||
segment_size = size_per_sec * self.segment_duration // 1000
|
||||
return segment_size
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to calculate segment size: {e}")
|
||||
raise
|
||||
|
||||
async def create_connection(self) -> None:
|
||||
headers = RequestBuilder.new_auth_headers()
|
||||
try:
|
||||
self.conn = await self.session.ws_connect( # 使用self.session
|
||||
self.url,
|
||||
headers=headers
|
||||
)
|
||||
logger.info(f"Connected to {self.url}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to WebSocket: {e}")
|
||||
raise
|
||||
|
||||
async def send_full_client_request(self) -> None:
|
||||
request = RequestBuilder.new_full_client_request(self.seq)
|
||||
self.seq += 1 # 发送后递增
|
||||
try:
|
||||
await self.conn.send_bytes(request)
|
||||
logger.info(f"Sent full client request with seq: {self.seq-1}")
|
||||
|
||||
msg = await self.conn.receive()
|
||||
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||
response = ResponseParser.parse_response(msg.data)
|
||||
logger.info(f"Received response: {response.to_dict()}")
|
||||
else:
|
||||
logger.error(f"Unexpected message type: {msg.type}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send full client request: {e}")
|
||||
raise
|
||||
|
||||
async def send_messages(self, segment_size: int, content: bytes) -> AsyncGenerator[None, None]:
|
||||
audio_segments = self.split_audio(content, segment_size)
|
||||
total_segments = len(audio_segments)
|
||||
|
||||
for i, segment in enumerate(audio_segments):
|
||||
is_last = (i == total_segments - 1)
|
||||
request = RequestBuilder.new_audio_only_request(
|
||||
self.seq,
|
||||
segment,
|
||||
is_last=is_last
|
||||
)
|
||||
await self.conn.send_bytes(request)
|
||||
logger.info(f"Sent audio segment with seq: {self.seq} (last: {is_last})")
|
||||
|
||||
if not is_last:
|
||||
self.seq += 1
|
||||
|
||||
await asyncio.sleep(self.segment_duration / 1000) # 逐个发送,间隔时间模拟实时流
|
||||
# 让出控制权,允许接受消息
|
||||
yield
|
||||
|
||||
async def recv_messages(self) -> AsyncGenerator[AsrResponse, None]:
|
||||
try:
|
||||
async for msg in self.conn:
|
||||
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||
response = ResponseParser.parse_response(msg.data)
|
||||
yield response
|
||||
|
||||
if response.is_last_package or response.code != 0:
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
logger.error(f"WebSocket error: {msg.data}")
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
logger.info("WebSocket connection closed")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error receiving messages: {e}")
|
||||
raise
|
||||
|
||||
async def start_audio_stream(self, segment_size: int, content: bytes) -> AsyncGenerator[AsrResponse, None]:
|
||||
async def sender():
|
||||
async for _ in self.send_messages(segment_size, content):
|
||||
pass
|
||||
|
||||
# 启动发送和接收任务
|
||||
sender_task = asyncio.create_task(sender())
|
||||
|
||||
try:
|
||||
async for response in self.recv_messages():
|
||||
yield response
|
||||
finally:
|
||||
sender_task.cancel()
|
||||
try:
|
||||
await sender_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def split_audio(data: bytes, segment_size: int) -> List[bytes]:
|
||||
if segment_size <= 0:
|
||||
return []
|
||||
|
||||
segments = []
|
||||
for i in range(0, len(data), segment_size):
|
||||
end = i + segment_size
|
||||
if end > len(data):
|
||||
end = len(data)
|
||||
segments.append(data[i:end])
|
||||
return segments
|
||||
|
||||
async def execute(self, file_path: str) -> AsyncGenerator[AsrResponse, None]:
|
||||
if not file_path:
|
||||
raise ValueError("File path is empty")
|
||||
|
||||
if not self.url:
|
||||
raise ValueError("URL is empty")
|
||||
|
||||
self.seq = 1
|
||||
|
||||
try:
|
||||
# 1. 读取音频文件
|
||||
content = await self.read_audio_data(file_path)
|
||||
|
||||
# 2. 计算分段大小
|
||||
segment_size = self.get_segment_size(content)
|
||||
|
||||
# 3. 创建WebSocket连接
|
||||
await self.create_connection()
|
||||
|
||||
# 4. 发送完整客户端请求
|
||||
await self.send_full_client_request()
|
||||
|
||||
# 5. 启动音频流处理
|
||||
async for response in self.start_audio_stream(segment_size, content):
|
||||
yield response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in ASR execution: {e}")
|
||||
raise
|
||||
finally:
|
||||
if self.conn:
|
||||
await self.conn.close()
|
||||
|
||||
async def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="ASR WebSocket Client")
|
||||
parser.add_argument("--file", type=str, required=True, help="Audio file path")
|
||||
|
||||
#wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
|
||||
#wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async
|
||||
#wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream
|
||||
parser.add_argument("--url", type=str, default="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream",
|
||||
help="WebSocket URL")
|
||||
parser.add_argument("--seg-duration", type=int, default=200,
|
||||
help="Audio duration(ms) per packet, default:200")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
async with AsrWsClient(args.url, args.seg_duration) as client: # 使用async with
|
||||
try:
|
||||
async for response in client.execute(args.file):
|
||||
logger.info(f"Received response: {json.dumps(response.to_dict(), indent=2, ensure_ascii=False)}")
|
||||
except Exception as e:
|
||||
logger.error(f"ASR processing failed: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
# 用法:
|
||||
# python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav
|
||||
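The nibble-packed 4-byte header built by `AsrRequestHeader.to_bytes` and decoded by `ResponseParser.parse_response` can be exercised without a server. A minimal sketch using plain integer constants that mirror the file's defaults (protocol v1, full client request, POS_SEQUENCE, JSON, gzip):

```python
import struct

# Field values mirroring the demo's defaults; each pair of 4-bit fields
# shares one byte, high nibble first.
version, header_size = 0b0001, 0b0001
message_type, flags = 0b0001, 0b0001
serialization, compression = 0b0001, 0b0001

header = bytes([
    (version << 4) | header_size,
    (message_type << 4) | flags,
    (serialization << 4) | compression,
    0x00,  # reserved byte
])
packet = header + struct.pack('>i', 1)  # header followed by big-endian sequence number

# Decoding reverses the nibble packing, as ResponseParser does.
assert packet[0] >> 4 == version and packet[0] & 0x0f == header_size
assert struct.unpack('>i', packet[4:8])[0] == 1
```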
363
streaming_asr_demo.py
Normal file
@ -0,0 +1,363 @@
# coding=utf-8

"""
requires Python 3.6 or later

pip install websockets
(asyncio is part of the standard library)
"""

import asyncio
import base64
import gzip
import hmac
import json
import logging
import os
import time
import uuid
import wave
from enum import Enum
from hashlib import sha256
from io import BytesIO
from typing import Generator, List, Tuple
from urllib.parse import urlparse

import websockets

appid = "8718217928"  # project appid
token = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"  # project token
cluster = "volcengine_input_common"  # request cluster
audio_path = "recording_20250920_161438.wav"  # local audio path
audio_format = "wav"  # "wav" or "mp3", set according to the actual audio format

PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001

PROTOCOL_VERSION_BITS = 4
HEADER_BITS = 4
MESSAGE_TYPE_BITS = 4
MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4
MESSAGE_SERIALIZATION_BITS = 4
MESSAGE_COMPRESSION_BITS = 4
RESERVED_BITS = 8

# Message Type:
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# Message Type Specific Flags
NO_SEQUENCE = 0b0000  # no check sequence
POS_SEQUENCE = 0b0001
NEG_SEQUENCE = 0b0010
NEG_SEQUENCE_1 = 0b0011

# Message Serialization
NO_SERIALIZATION = 0b0000
JSON = 0b0001
THRIFT = 0b0011
CUSTOM_TYPE = 0b1111

# Message Compression
NO_COMPRESSION = 0b0000
GZIP = 0b0001
CUSTOM_COMPRESSION = 0b1111


def generate_header(
    version=PROTOCOL_VERSION,
    message_type=CLIENT_FULL_REQUEST,
    message_type_specific_flags=NO_SEQUENCE,
    serial_method=JSON,
    compression_type=GZIP,
    reserved_data=0x00,
    extension_header=bytes()
):
    """
    protocol_version(4 bits), header_size(4 bits),
    message_type(4 bits), message_type_specific_flags(4 bits)
    serialization_method(4 bits) message_compression(4 bits)
    reserved (8 bits) reserved field
    header_extensions: extension header (size = 8 * 4 * (header_size - 1))
    """
    header = bytearray()
    header_size = int(len(extension_header) / 4) + 1
    header.append((version << 4) | header_size)
    header.append((message_type << 4) | message_type_specific_flags)
    header.append((serial_method << 4) | compression_type)
    header.append(reserved_data)
    header.extend(extension_header)
    return header


def generate_full_default_header():
    return generate_header()


def generate_audio_default_header():
    return generate_header(
        message_type=CLIENT_AUDIO_ONLY_REQUEST
    )


def generate_last_audio_default_header():
    return generate_header(
        message_type=CLIENT_AUDIO_ONLY_REQUEST,
        message_type_specific_flags=NEG_SEQUENCE
    )


def parse_response(res):
    """
    protocol_version(4 bits), header_size(4 bits),
    message_type(4 bits), message_type_specific_flags(4 bits)
    serialization_method(4 bits) message_compression(4 bits)
    reserved (8 bits) reserved field
    header_extensions: extension header (size = 8 * 4 * (header_size - 1))
    payload: similar to an HTTP request body
    """
    protocol_version = res[0] >> 4
    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    message_type_specific_flags = res[1] & 0x0f
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    reserved = res[3]
    header_extensions = res[4:header_size * 4]
    payload = res[header_size * 4:]
    result = {}
    payload_msg = None
    payload_size = 0
    if message_type == SERVER_FULL_RESPONSE:
        payload_size = int.from_bytes(payload[:4], "big", signed=True)
        payload_msg = payload[4:]
    elif message_type == SERVER_ACK:
        seq = int.from_bytes(payload[:4], "big", signed=True)
        result['seq'] = seq
        if len(payload) >= 8:
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            payload_msg = payload[8:]
    elif message_type == SERVER_ERROR_RESPONSE:
        code = int.from_bytes(payload[:4], "big", signed=False)
        result['code'] = code
        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
        payload_msg = payload[8:]
    if payload_msg is None:
        return result
    if message_compression == GZIP:
        payload_msg = gzip.decompress(payload_msg)
    if serialization_method == JSON:
        payload_msg = json.loads(str(payload_msg, "utf-8"))
    elif serialization_method != NO_SERIALIZATION:
        payload_msg = str(payload_msg, "utf-8")
    result['payload_msg'] = payload_msg
    result['payload_size'] = payload_size
    return result


def read_wav_info(data: bytes = None) -> Tuple[int, int, int, int, int]:
    with BytesIO(data) as _f:
        wave_fp = wave.open(_f, 'rb')
        nchannels, sampwidth, framerate, nframes = wave_fp.getparams()[:4]
        wave_bytes = wave_fp.readframes(nframes)
    return nchannels, sampwidth, framerate, nframes, len(wave_bytes)


class AudioType(Enum):
    LOCAL = 1  # use a local audio file


class AsrWsClient:
    def __init__(self, audio_path, cluster, **kwargs):
        """
        :param config: config
        """
        self.audio_path = audio_path
        self.cluster = cluster
        self.success_code = 1000  # success code, default is 1000
        self.seg_duration = int(kwargs.get("seg_duration", 15000))
        self.nbest = int(kwargs.get("nbest", 1))
        self.appid = kwargs.get("appid", "")
        self.token = kwargs.get("token", "")
        self.ws_url = kwargs.get("ws_url", "wss://openspeech.bytedance.com/api/v2/asr")
        self.uid = kwargs.get("uid", "streaming_asr_demo")
        self.workflow = kwargs.get("workflow", "audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate")
        self.show_language = kwargs.get("show_language", False)
        self.show_utterances = kwargs.get("show_utterances", False)
        self.result_type = kwargs.get("result_type", "full")
        self.format = kwargs.get("format", "wav")
        self.rate = kwargs.get("sample_rate", 16000)
        self.language = kwargs.get("language", "zh-CN")
        self.bits = kwargs.get("bits", 16)
        self.channel = kwargs.get("channel", 1)
        self.codec = kwargs.get("codec", "raw")
        self.audio_type = kwargs.get("audio_type", AudioType.LOCAL)
        self.secret = kwargs.get("secret", "access_secret")
        self.auth_method = kwargs.get("auth_method", "token")
        self.mp3_seg_size = int(kwargs.get("mp3_seg_size", 10000))

    def construct_request(self, reqid):
        req = {
            'app': {
                'appid': self.appid,
                'cluster': self.cluster,
                'token': self.token,
            },
            'user': {
                'uid': self.uid
            },
            'request': {
                'reqid': reqid,
                'nbest': self.nbest,
                'workflow': self.workflow,
                'show_language': self.show_language,
                'show_utterances': self.show_utterances,
                'result_type': self.result_type,
                "sequence": 1
            },
            'audio': {
                'format': self.format,
                'rate': self.rate,
                'language': self.language,
                'bits': self.bits,
                'channel': self.channel,
                'codec': self.codec
            }
        }
        return req

    @staticmethod
    def slice_data(data: bytes, chunk_size: int) -> Generator[Tuple[bytes, bool], None, None]:
        """
        slice data
        :param data: wav data
        :param chunk_size: the segment size in one request
        :return: segment data, last flag
        """
        data_len = len(data)
        offset = 0
        while offset + chunk_size < data_len:
            yield data[offset: offset + chunk_size], False
            offset += chunk_size
        yield data[offset: data_len], True

    def _real_processor(self, request_params: dict) -> dict:
        pass

    def token_auth(self):
        return {'Authorization': 'Bearer; {}'.format(self.token)}

    def signature_auth(self, data):
        header_dicts = {
            'Custom': 'auth_custom',
        }

        url_parse = urlparse(self.ws_url)
        input_str = 'GET {} HTTP/1.1\n'.format(url_parse.path)
        auth_headers = 'Custom'
        for header in auth_headers.split(','):
            input_str += '{}\n'.format(header_dicts[header])
        input_data = bytearray(input_str, 'utf-8')
        input_data += data
        mac = base64.urlsafe_b64encode(
            hmac.new(self.secret.encode('utf-8'), input_data, digestmod=sha256).digest())
        header_dicts['Authorization'] = 'HMAC256; access_token="{}"; mac="{}"; h="{}"'.format(
            self.token, str(mac, 'utf-8'), auth_headers)
        return header_dicts

    async def segment_data_processor(self, wav_data: bytes, segment_size: int):
        reqid = str(uuid.uuid4())
        # Build the full client request, then serialize and compress it
        request_params = self.construct_request(reqid)
        payload_bytes = str.encode(json.dumps(request_params))
        payload_bytes = gzip.compress(payload_bytes)
        full_client_request = bytearray(generate_full_default_header())
        full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size (4 bytes)
        full_client_request.extend(payload_bytes)  # payload
        additional_headers = None
        if self.auth_method == "token":
            additional_headers = self.token_auth()
        elif self.auth_method == "signature":
            additional_headers = self.signature_auth(full_client_request)

        connection_kwargs = {"max_size": 1000000000}
        if additional_headers:
            connection_kwargs["additional_headers"] = additional_headers

        async with websockets.connect(self.ws_url, **connection_kwargs) as ws:
            # Send the full client request
            await ws.send(full_client_request)
            res = await ws.recv()
            result = parse_response(res)
            if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code:
                return result
            for seq, (chunk, last) in enumerate(AsrWsClient.slice_data(wav_data, segment_size), 1):
                # if no compression, comment this line
                payload_bytes = gzip.compress(chunk)
                audio_only_request = bytearray(generate_audio_default_header())
                if last:
                    audio_only_request = bytearray(generate_last_audio_default_header())
                audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size (4 bytes)
                audio_only_request.extend(payload_bytes)  # payload
                # Send the audio-only client request
                await ws.send(audio_only_request)
                res = await ws.recv()
                result = parse_response(res)
                if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code:
                    return result
            return result

    async def execute(self):
        with open(self.audio_path, mode="rb") as _f:
            data = _f.read()
        audio_data = bytes(data)
        if self.format == "mp3":
            segment_size = self.mp3_seg_size
            return await self.segment_data_processor(audio_data, segment_size)
        if self.format != "wav":
            raise Exception("format should be wav or mp3")
        nchannels, sampwidth, framerate, nframes, wav_len = read_wav_info(audio_data)
        size_per_sec = nchannels * sampwidth * framerate
        segment_size = int(size_per_sec * self.seg_duration / 1000)
        return await self.segment_data_processor(audio_data, segment_size)


def execute_one(audio_item, cluster, **kwargs):
    """
    :param audio_item: {"id": xxx, "path": "xxx"}
    :param cluster: cluster name
    :return:
    """
    assert 'id' in audio_item
    assert 'path' in audio_item
    audio_id = audio_item['id']
    audio_path = audio_item['path']
    audio_type = AudioType.LOCAL
    asr_http_client = AsrWsClient(
        audio_path=audio_path,
        cluster=cluster,
        audio_type=audio_type,
        **kwargs
    )
    result = asyncio.run(asr_http_client.execute())
    return {"id": audio_id, "path": audio_path, "result": result}


def test_one():
    result = execute_one(
        {
            'id': 1,
            'path': audio_path
        },
        cluster=cluster,
        appid=appid,
        token=token,
        format=audio_format,
    )
    print(result)


if __name__ == '__main__':
    test_one()
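The chunking contract in `slice_data` — every chunk but the final one carries `False`, and the final chunk (even a short tail) carries `True` — can be checked standalone. A minimal sketch of the same generator:

```python
def slice_data(data: bytes, chunk_size: int):
    # Yield (chunk, is_last) pairs; only the final chunk is flagged as last.
    data_len = len(data)
    offset = 0
    while offset + chunk_size < data_len:
        yield data[offset: offset + chunk_size], False
        offset += chunk_size
    yield data[offset: data_len], True

chunks = list(slice_data(b"abcdefgh", 3))
# Three chunks; only the trailing 2-byte tail is flagged as last.
assert chunks == [(b"abc", False), (b"def", False), (b"gh", True)]
```

Note that even empty input yields one `(b"", True)` pair, so the server always receives a terminating packet.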
96
test_llm.py
Normal file
@ -0,0 +1,96 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Test the large language model API.
"""

import os
import requests
import json


def test_llm_api():
    """Test the large language model API."""

    # Check the API key
    api_key = os.environ.get("ARK_API_KEY")
    if not api_key:
        print("❌ ARK_API_KEY environment variable is not set")
        return False

    print(f"✅ API key is set: {api_key[:20]}...")

    # API configuration
    api_url = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
    model = "doubao-1-5-pro-32k-250115"

    # Test message
    test_message = "你好,请简单介绍一下自己"

    try:
        print("🤖 Testing the LLM API...")

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        data = {
            "model": model,
            "messages": [
                {
                    "role": "system",
                    "content": "你是一个智能助手,请根据用户的语音输入提供有帮助的回答。保持回答简洁明了。"
                },
                {
                    "role": "user",
                    "content": test_message
                }
            ]
        }

        response = requests.post(api_url, headers=headers, json=data, timeout=30)

        print(f"📡 HTTP status code: {response.status_code}")

        if response.status_code == 200:
            result = response.json()
            print("✅ API call succeeded")

            if "choices" in result and len(result["choices"]) > 0:
                llm_response = result["choices"][0]["message"]["content"]
                print(f"💬 AI reply: {llm_response}")

                # Show the full response structure
                print("\n📋 Full response structure:")
                print(json.dumps(result, indent=2, ensure_ascii=False))

                return True
            else:
                print("❌ Unexpected response format")
                print(f"Response body: {response.text}")
                return False
        else:
            print(f"❌ API call failed: {response.status_code}")
            print(f"Response body: {response.text}")
            return False

    except requests.exceptions.RequestException as e:
        print(f"❌ Network request failed: {e}")
        return False
    except Exception as e:
        print(f"❌ Test failed: {e}")
        return False


if __name__ == "__main__":
    print("🧪 Testing the LLM API")
    print("=" * 50)

    success = test_llm_api()

    if success:
        print("\n✅ LLM test passed!")
        print("🚀 The full voice assistant can now be run")
    else:
        print("\n❌ LLM test failed")
        print("🔧 Check the API key and network connection")
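The extraction step in `test_llm_api` (guarding on `"choices"` before indexing into the first message) can be checked without a network call. A minimal sketch with a mocked chat-completions payload (contents are illustrative, not real API output):

```python
import json

# A response shaped like the chat-completions payload the test inspects.
raw = json.dumps({
    "choices": [{"message": {"role": "assistant", "content": "你好!"}}]
})

result = json.loads(raw)
# Same defensive check the test script performs before indexing.
if "choices" in result and len(result["choices"]) > 0:
    reply = result["choices"][0]["message"]["content"]
else:
    reply = None
assert reply == "你好!"
```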
101
tts_http_demo.py
Normal file
@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
# @Project : tob_service
# @Company : ByteDance
# @Time : 2025/7/10 19:01
# @Author : SiNian
# @FileName: TTSv3HttpDemo.py
# @IDE: PyCharm
# @Motto: I,with no mountain to rely on,am the mountain myself.
import requests
import json
import base64
import os

# Python version: 3.11

# ---------- parameters the caller must fill in ----------
appID = "8718217928"
accessKey = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
resourceID = "volc.service_type.10029"
text = "这是一段测试文本,用于测试字节大模型语音合成http单向流式接口效果。"
# ---------- request URL ----------
url = "https://openspeech.bytedance.com/api/v3/tts/unidirectional"


def tts_http_stream(url, headers, params, audio_save_path):
    session = requests.Session()
    response = None  # so the finally block is safe if post() raises
    try:
        print('Request url:', url)
        print('Request headers:', headers)
        print('Request params:\n', params)
        response = session.post(url, headers=headers, json=params, stream=True)
        print(response)
        # Print the response headers
        print(f"code: {response.status_code} header: {response.headers}")
        logid = response.headers.get('X-Tt-Logid')
        print(f"X-Tt-Logid: {logid}")

        # Accumulate the audio data
        audio_data = bytearray()
        total_audio_size = 0
        for chunk in response.iter_lines(decode_unicode=True):
            if not chunk:
                continue
            data = json.loads(chunk)

            if data.get("code", 0) == 0 and "data" in data and data["data"]:
                chunk_audio = base64.b64decode(data["data"])
                audio_size = len(chunk_audio)
                total_audio_size += audio_size
                audio_data.extend(chunk_audio)
                continue
            if data.get("code", 0) == 0 and "sentence" in data and data["sentence"]:
                print("sentence_data:", data)
                continue
            if data.get("code", 0) == 20000000:  # end-of-stream marker
                break
            if data.get("code", 0) > 0:
                print(f"error response: {data}")
                break

        # Save the audio file
        if audio_data:
            with open(audio_save_path, "wb") as f:
                f.write(audio_data)
            print(f"Saved to {audio_save_path}, size: {len(audio_data) / 1024:.2f} KB")
            # Make sure the generated audio file has sane permissions
            os.chmod(audio_save_path, 0o644)

    except Exception as e:
        print(f"Request failed: {e}")
    finally:
        if response is not None:
            response.close()
        session.close()


if __name__ == "__main__":
    # ---------- request headers ----------
    headers = {
        "X-Api-App-Id": appID,
        "X-Api-Access-Key": accessKey,
        "X-Api-Resource-Id": resourceID,
        "X-Api-App-Key": "aGjiRDfUWi",
        "Content-Type": "application/json",
        "Connection": "keep-alive"
    }

    payload = {
        "user": {
            "uid": "123123"
        },
        "req_params": {
            "text": "其他人",
            "speaker": "zh_female_wanqudashu_moon_bigtts",
            "audio_params": {
                "format": "mp3",
                "sample_rate": 24000,
                "enable_timestamp": True
            },
            # "additions" is a JSON-encoded string
            "additions": "{\"explicit_language\":\"zh\",\"disable_markdown_filter\":true,\"enable_timestamp\":true}"
        }
    }

    tts_http_stream(url=url, headers=headers, params=payload, audio_save_path="tts_test.mp3")
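The streaming loop above decodes one JSON object per line and appends base64 audio until the end-of-stream code arrives. It can be exercised offline with mocked stream lines (the `20000000` sentinel is taken from the script; the byte contents are illustrative):

```python
import base64
import json

# Simulated NDJSON stream lines like the TTS endpoint returns: audio chunks
# arrive base64-encoded in "data", and code 20000000 marks end of stream.
lines = [
    json.dumps({"code": 0, "data": base64.b64encode(b"\x01\x02").decode()}),
    json.dumps({"code": 0, "data": base64.b64encode(b"\x03").decode()}),
    json.dumps({"code": 20000000}),
]

audio = bytearray()
for line in lines:
    msg = json.loads(line)
    if msg.get("code", 0) == 0 and msg.get("data"):
        audio.extend(base64.b64decode(msg["data"]))
        continue
    if msg.get("code", 0) == 20000000:
        break

assert bytes(audio) == b"\x01\x02\x03"
```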
BIN
tts_response_20250920_165657.mp3
Normal file
Binary file not shown.
BIN
tts_test.mp3
Normal file
Binary file not shown.