This commit is contained in:
朱潮 2025-09-20 14:58:49 +08:00
parent 9108fd4582
commit 43879961a2

840
voice_chat.py Normal file
View File

@ -0,0 +1,840 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
语音交互聊天系统 - 集成豆包AI
基于能量检测的录音 + 豆包语音识别 + TTS回复
"""
import sys
import os
import time
import threading
import asyncio
import subprocess
import wave
import struct
import json
import gzip
import uuid
from typing import Dict, Any, Optional
import pyaudio
import numpy as np
import websockets
# Doubao binary protocol constants (nibble values packed into the 4-byte header)
PROTOCOL_VERSION = 0b0001           # protocol version nibble
CLIENT_FULL_REQUEST = 0b0001        # client -> server: full (serialized) request
CLIENT_AUDIO_ONLY_REQUEST = 0b0010  # client -> server: raw audio frame
SERVER_FULL_RESPONSE = 0b1001       # server -> client: full response (events)
SERVER_ACK = 0b1011                 # server -> client: ack, may carry audio
SERVER_ERROR_RESPONSE = 0b1111      # server -> client: error frame
NO_SEQUENCE = 0b0000                # flags: no sequence number attached
MSG_WITH_EVENT = 0b0100             # flags: payload starts with a 4-byte event id
NO_SERIALIZATION = 0b0000           # serialization nibble: raw bytes
JSON = 0b0001                       # serialization nibble: JSON payload
GZIP = 0b0001                       # compression nibble: gzip (same value as JSON, different nibble)
class DoubaoClient:
    """Doubao realtime speech-dialog client (ASR + chat + TTS over one WebSocket).

    Implements the openspeech binary framing: a 4-byte header, a 4-byte
    big-endian event id, an optional length-prefixed session id, then a
    length-prefixed (usually gzip-compressed) payload.
    """

    def __init__(self):
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        # NOTE(review): credentials are hard-coded in source; they should be
        # loaded from environment/config and rotated — anyone with read access
        # to this file can use them.
        self.app_id = "8718217928"
        self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
        self.app_key = "PlgvMymc7f3tQnJ6"
        self.resource_id = "volc.speech.dialog"
        self.session_id = str(uuid.uuid4())  # one logical dialog session per client
        self.ws = None       # websockets connection, set by connect()
        self.log_id = ""     # server-side trace id taken from the handshake response

    def get_headers(self) -> Dict[str, str]:
        """Build the HTTP headers used for the WebSocket handshake."""
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }

    def generate_header(self, message_type=CLIENT_FULL_REQUEST,
                        message_type_specific_flags=MSG_WITH_EVENT,
                        serial_method=JSON, compression_type=GZIP) -> bytes:
        """Build the fixed 4-byte protocol header.

        Byte 0: version (high nibble) | header size in 4-byte units (low nibble).
        Byte 1: message type | type-specific flags.
        Byte 2: serialization method | compression type.
        Byte 3: reserved (zero).
        """
        header = bytearray()
        header.append((PROTOCOL_VERSION << 4) | 1)  # version + header_size (1 => 4 bytes)
        header.append((message_type << 4) | message_type_specific_flags)
        header.append((serial_method << 4) | compression_type)
        header.append(0x00)  # reserved
        return bytes(header)

    async def connect(self) -> None:
        """Open the WebSocket and perform the StartConnection/StartSession handshake."""
        print(f"🔗 连接豆包服务器...")
        try:
            self.ws = await websockets.connect(
                self.base_url,
                additional_headers=self.get_headers(),
                ping_interval=None  # liveness is handled by our own heartbeat, not WS pings
            )
            # The trace-id attribute name differs across websockets versions.
            if hasattr(self.ws, 'response_headers'):
                self.log_id = self.ws.response_headers.get("X-Tt-Logid")
            elif hasattr(self.ws, 'headers'):
                self.log_id = self.ws.headers.get("X-Tt-Logid")
            print(f"✅ 连接成功, log_id: {self.log_id}")
            # Handshake: StartConnection first, then StartSession.
            await self._send_start_connection()
            await self._send_start_session()
        except Exception as e:
            print(f"❌ 连接失败: {e}")
            raise

    def parse_response(self, response):
        """Decode one binary server frame into a dict.

        Returns None for frames shorter than the 4-byte header.  The dict
        always carries protocol/header/type/flags/payload keys; 'event',
        'session_id', 'data_size', 'data' and 'json_data' are added only
        when the payload is long enough to contain them.
        """
        if len(response) < 4:
            return None
        protocol_version = response[0] >> 4
        header_size = response[0] & 0x0f
        message_type = response[1] >> 4
        flags = response[1] & 0x0f
        payload_start = header_size * 4  # header size is counted in 4-byte units
        payload = response[payload_start:]
        result = {
            'protocol_version': protocol_version,
            'header_size': header_size,
            'message_type': message_type,
            'flags': flags,
            'payload': payload,
            'payload_size': len(payload)
        }
        # Payload layout: event(4) | session_id_len(4) | session_id |
        # data_size(4) | data.
        if len(payload) >= 4:
            result['event'] = int.from_bytes(payload[:4], 'big')
            if len(payload) >= 8:
                session_id_len = int.from_bytes(payload[4:8], 'big')
                if len(payload) >= 8 + session_id_len:
                    result['session_id'] = payload[8:8+session_id_len].decode()
                    if len(payload) >= 12 + session_id_len:
                        data_size = int.from_bytes(payload[8+session_id_len:12+session_id_len], 'big')
                        result['data_size'] = data_size
                        result['data'] = payload[12+session_id_len:12+session_id_len+data_size]
                        # Best-effort: also expose the data as JSON when it parses.
                        try:
                            result['json_data'] = json.loads(result['data'].decode('utf-8'))
                        except:  # noqa: E722 — non-JSON data (audio) is expected here
                            pass
        return result

    async def _send_start_connection(self) -> None:
        """Send the StartConnection request (event 1) and consume the ack."""
        request = bytearray(self.generate_header())
        request.extend(int(1).to_bytes(4, 'big'))  # event id 1 = StartConnection
        payload_bytes = b"{}"  # empty JSON config
        payload_bytes = gzip.compress(payload_bytes)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)
        await self.ws.send(request)
        response = await self.ws.recv()  # ack is read but not inspected

    async def _send_start_session(self) -> None:
        """Send the StartSession request (event 100) with ASR/TTS/dialog config."""
        session_config = {
            "asr": {"extra": {"end_smooth_window_ms": 1500}},
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {"channel": 1, "format": "pcm", "sample_rate": 24000}
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 30,
                    "input_mod": "audio",
                },
            },
        }
        request = bytearray(self.generate_header())
        request.extend(int(100).to_bytes(4, 'big'))  # event id 100 = StartSession
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())
        payload_bytes = json.dumps(session_config).encode()
        payload_bytes = gzip.compress(payload_bytes)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)
        await self.ws.send(request)
        response = await self.ws.recv()  # ack is read but not inspected
        await asyncio.sleep(1.0)  # give the server a moment to set the session up

    async def process_audio(self, audio_data: bytes) -> tuple[str, bytes]:
        """Send one utterance and collect (recognized text, TTS reply audio).

        `audio_data` is raw 16-bit PCM.  The collected TTS stream is
        converted from 32-bit float samples to 16-bit PCM before returning.
        Returns ("", b"") on failure.
        """
        try:
            # Send the audio as a raw (non-serialized) gzip payload, event 200 —
            # same framing as doubao_simple.py.
            task_request = bytearray(
                self.generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST,
                                     serial_method=NO_SERIALIZATION))
            task_request.extend(int(200).to_bytes(4, 'big'))
            task_request.extend(len(self.session_id).to_bytes(4, 'big'))
            task_request.extend(self.session_id.encode())
            payload_bytes = gzip.compress(audio_data)
            task_request.extend(len(payload_bytes).to_bytes(4, 'big'))
            task_request.extend(payload_bytes)
            await self.ws.send(task_request)
            print("📤 音频数据已发送")
            recognized_text = ""
            tts_audio = b""
            response_count = 0
            # Collect server frames until TTS finishes or limits are hit.
            audio_chunks = []
            max_responses = 30
            while response_count < max_responses:
                try:
                    response = await asyncio.wait_for(self.ws.recv(), timeout=30.0)
                    response_count += 1
                    parsed = self.parse_response(response)
                    if not parsed:
                        continue
                    print(f"📥 响应 {response_count}: message_type={parsed['message_type']}, event={parsed.get('event', 'N/A')}, size={parsed['payload_size']}")
                    # Dispatch on frame type.
                    if parsed['message_type'] == 11:  # SERVER_ACK — may carry audio
                        if 'data' in parsed and parsed['data_size'] > 0:
                            audio_chunks.append(parsed['data'])
                            print(f"收集到音频块: {parsed['data_size']} 字节")
                    elif parsed['message_type'] == 9:  # SERVER_FULL_RESPONSE
                        event = parsed.get('event', 0)
                        if event == 450:  # ASR started
                            print("🎤 ASR处理开始")
                        elif event == 451:  # ASR result
                            if 'json_data' in parsed and 'results' in parsed['json_data']:
                                text = parsed['json_data']['results'][0].get('text', '')
                                recognized_text = text
                                print(f"🧠 识别结果: {text}")
                        elif event == 459:  # ASR finished
                            print("✅ ASR处理结束")
                        elif event == 350:  # TTS started
                            print("🎵 TTS生成开始")
                        elif event == 359:  # TTS finished — stop collecting
                            print("✅ TTS生成结束")
                            break
                        elif event == 550:  # TTS audio data
                            if 'data' in parsed and parsed['data_size'] > 0:
                                # If the data parses as JSON it is metadata;
                                # otherwise treat it as audio bytes.
                                try:
                                    json.loads(parsed['data'].decode('utf-8'))
                                    print("收到TTS音频元数据")
                                except:  # noqa: E722 — binary audio won't decode as JSON
                                    audio_chunks.append(parsed['data'])
                                    print(f"收集到TTS音频块: {parsed['data_size']} 字节")
                except asyncio.TimeoutError:
                    print(f"⏰ 等待响应 {response_count + 1} 超时")
                    break
                except websockets.exceptions.ConnectionClosed:
                    print("🔌 连接已关闭")
                    break
            print(f"共收到 {response_count} 个响应,收集到 {len(audio_chunks)} 个音频块")
            # Concatenate the collected audio chunks.
            if audio_chunks:
                tts_audio = b''.join(audio_chunks)
                print(f"合并后的音频数据: {len(tts_audio)} 字节")
            # Convert the TTS stream (assumed float32 PCM) to int16 PCM.
            if tts_audio:
                # Some responses arrive gzip-compressed; try decompressing first.
                try:
                    decompressed = gzip.decompress(tts_audio)
                    print(f"解压缩后音频数据: {len(decompressed)} 字节")
                    audio_to_write = decompressed
                except:  # noqa: E722 — not gzip, use the raw bytes
                    print("音频数据不是GZIP压缩格式直接使用原始数据")
                    audio_to_write = tts_audio
                # Float32 samples are 4 bytes each; truncate any ragged tail.
                if len(audio_to_write) % 4 != 0:
                    print(f"警告:音频数据长度 {len(audio_to_write)} 不是4的倍数截断到最近的倍数")
                    audio_to_write = audio_to_write[:len(audio_to_write) // 4 * 4]
                # Sample-by-sample float32 -> int16 conversion.
                float_count = len(audio_to_write) // 4
                int16_data = bytearray(float_count * 2)
                for i in range(float_count):
                    # Read one little-endian float32 sample.
                    float_value = struct.unpack('<f', audio_to_write[i*4:i*4+4])[0]
                    # Clamp to [-1.0, 1.0] before scaling.
                    float_value = max(-1.0, min(1.0, float_value))
                    # Scale into the int16 range.
                    int16_value = int(float_value * 32767)
                    # Write back as little-endian int16.
                    int16_data[i*2:i*2+2] = struct.pack('<h', int16_value)
                tts_audio = bytes(int16_data)
                print(f"✅ 音频转换完成: {len(tts_audio)} 字节")
            return recognized_text, tts_audio
        except Exception as e:
            print(f"❌ 处理失败: {e}")
            import traceback
            traceback.print_exc()
            return "", b""

    async def send_silence_data(self, duration_ms=100) -> None:
        """Send a short burst of silence (event 200) as a keep-alive."""
        try:
            # Build duration_ms of 16 kHz 16-bit silence (all-zero bytes).
            samples = int(16000 * duration_ms / 1000)
            silence_data = bytes(samples * 2)
            # Frame it exactly like a normal audio chunk.
            task_request = bytearray(
                self.generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST,
                                     serial_method=NO_SERIALIZATION))
            task_request.extend(int(200).to_bytes(4, 'big'))
            task_request.extend(len(self.session_id).to_bytes(4, 'big'))
            task_request.extend(self.session_id.encode())
            payload_bytes = gzip.compress(silence_data)
            task_request.extend(len(payload_bytes).to_bytes(4, 'big'))
            task_request.extend(payload_bytes)
            await self.ws.send(task_request)
            print("💓 发送心跳数据保持连接")
            # Drain one response without inspecting its contents.
            try:
                response = await asyncio.wait_for(self.ws.recv(), timeout=5.0)
            except asyncio.TimeoutError:
                print("⚠️ 心跳响应超时")
            except websockets.exceptions.ConnectionClosed:
                print("❌ 心跳时连接已关闭")
                raise  # re-raised, then swallowed/printed by the outer handler below
        except Exception as e:
            print(f"❌ 发送心跳数据失败: {e}")

    async def close(self) -> None:
        """Close the WebSocket, ignoring errors from an already-dead socket."""
        if self.ws:
            try:
                await self.ws.close()
            except:  # noqa: E722 — best-effort shutdown
                pass
            print("🔌 连接已关闭")
class VoiceChatRecorder:
    """Energy/ZCR-gated voice recorder with optional Doubao AI round trips.

    Listens on the microphone, starts recording when zero-crossing-rate
    based voice activity is detected, stops on silence, then either plays
    the recording back (plain mode) or sends it to Doubao for ASR + a TTS
    reply (AI mode).
    """

    def __init__(self, enable_ai_chat=True):
        # Audio capture parameters (16 kHz mono 16-bit PCM).
        self.FORMAT = pyaudio.paInt16
        self.CHANNELS = 1
        self.RATE = 16000
        self.CHUNK_SIZE = 1024
        # Energy/silence detection parameters.
        self.energy_threshold = 500        # RMS threshold (tracked, not gating)
        self.silence_threshold = 2.0       # seconds of silence that end a take
        self.min_recording_time = 1.0      # seconds before silence may stop a take
        self.max_recording_time = 20.0     # hard cap on a single take, seconds
        # Runtime state.
        self.audio = None
        self.stream = None
        self.running = False
        self.recording = False
        self.recorded_frames = []
        self.recording_start_time = None
        self.last_sound_time = None
        self.energy_history = []
        self.zcr_history = []
        # AI chat plumbing.
        self.enable_ai_chat = enable_ai_chat
        self.doubao_client = None
        self.is_processing_ai = False
        self.heartbeat_thread = None
        self.last_heartbeat_time = time.time()
        self.heartbeat_interval = 10.0  # send a keep-alive every 10 seconds
        # Rolling pre-record buffer (~2 s) so speech onsets are not clipped.
        self.pre_record_buffer = []
        self.pre_record_max_frames = int(2.0 * self.RATE / self.CHUNK_SIZE)
        # Playback state.
        self.is_playing = False
        # ZCR-based silence counting.
        self.consecutive_low_zcr_count = 0
        self.low_zcr_threshold_count = 15
        self.voice_activity_history = []
        self._setup_audio()

    def _setup_audio(self):
        """Open PyAudio and an input stream with the configured parameters."""
        try:
            self.audio = pyaudio.PyAudio()
            self.stream = self.audio.open(
                format=self.FORMAT,
                channels=self.CHANNELS,
                rate=self.RATE,
                input=True,
                frames_per_buffer=self.CHUNK_SIZE
            )
            print("✅ 音频设备初始化成功")
        except Exception as e:
            print(f"❌ 音频设备初始化失败: {e}")

    def generate_silence_audio(self, duration_ms=100):
        """Return duration_ms of 16-bit PCM silence at the capture rate."""
        samples = int(self.RATE * duration_ms / 1000)
        silence_data = bytes(samples * 2)  # 2 bytes per 16-bit sample, all zeros
        return silence_data

    def calculate_energy(self, audio_data):
        """Return the RMS energy of one chunk; track ambient history while idle."""
        if len(audio_data) == 0:
            return 0
        audio_array = np.frombuffer(audio_data, dtype=np.int16)
        # NOTE(review): squaring int16 values can wrap (|x| > 181 overflows
        # int16); consider astype(np.float64) before ** 2 — confirm intended.
        rms = np.sqrt(np.mean(audio_array ** 2))
        # Only accumulate ambient-energy history while not recording.
        if not self.recording:
            self.energy_history.append(rms)
            if len(self.energy_history) > 50:
                self.energy_history.pop(0)
        return rms

    def calculate_zero_crossing_rate(self, audio_data):
        """Return the zero-crossing rate (crossings per second) of one chunk."""
        if len(audio_data) == 0:
            return 0
        audio_array = np.frombuffer(audio_data, dtype=np.int16)
        zero_crossings = np.sum(np.diff(np.sign(audio_array)) != 0)
        zcr = zero_crossings / len(audio_array) * self.RATE
        self.zcr_history.append(zcr)
        if len(self.zcr_history) > 30:
            self.zcr_history.pop(0)
        return zcr

    def is_voice_active(self, energy, zcr):
        """Voice-activity decision based purely on a ZCR band.

        `energy` is currently unused; the 2400-12000 crossings/s band is
        the speech range used at the 16 kHz capture rate.
        """
        zcr_condition = 2400 < zcr < 12000
        return zcr_condition

    def save_recording(self, audio_data, filename=None):
        """Write raw PCM bytes to a WAV file; return (success, filename)."""
        if filename is None:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            filename = f"recording_{timestamp}.wav"
        try:
            with wave.open(filename, 'wb') as wf:
                wf.setnchannels(self.CHANNELS)
                wf.setsampwidth(self.audio.get_sample_size(self.FORMAT))
                wf.setframerate(self.RATE)
                wf.writeframes(audio_data)
            # NOTE(review): "(unknown)" below looks like a mangled placeholder —
            # probably meant to interpolate {filename}; confirm against the
            # original commit before changing.
            print(f"✅ 录音已保存: (unknown)")
            return True, filename
        except Exception as e:
            print(f"❌ 保存录音失败: {e}")
            return False, None

    def play_audio(self, filename):
        """Play a WAV file via `aplay`, pausing capture while it plays."""
        try:
            # Abort any in-progress recording.
            if self.recording:
                self.recording = False
                self.recorded_frames = []
            # Release the input stream so the device is free during playback.
            if self.stream:
                self.stream.stop_stream()
                self.stream.close()
                self.stream = None
            self.is_playing = True
            time.sleep(0.2)
            # Delegate playback to the system player.
            # NOTE(review): "(unknown)" looks like a mangled placeholder —
            # probably meant to interpolate {filename}; confirm.
            print(f"🔊 播放: (unknown)")
            subprocess.run(['aplay', filename], check=True)
            print("✅ 播放完成")
        except Exception as e:
            print(f"❌ 播放失败: {e}")
        finally:
            self.is_playing = False
            time.sleep(0.2)
            self._setup_audio()  # reopen the input stream for capture

    def update_pre_record_buffer(self, audio_data):
        """Append a chunk to the rolling pre-record buffer (bounded length)."""
        self.pre_record_buffer.append(audio_data)
        if len(self.pre_record_buffer) > self.pre_record_max_frames:
            self.pre_record_buffer.pop(0)

    def start_recording(self):
        """Begin a take, seeding it with the buffered pre-roll audio."""
        print("🎙️ 检测到声音,开始录音...")
        self.recording = True
        self.recorded_frames = []
        self.recorded_frames.extend(self.pre_record_buffer)  # keep the speech onset
        self.pre_record_buffer = []
        self.recording_start_time = time.time()
        self.last_sound_time = time.time()
        self.consecutive_low_zcr_count = 0

    def stop_recording(self):
        """Finish a take and route it to AI processing or local playback."""
        if len(self.recorded_frames) > 0:
            audio_data = b''.join(self.recorded_frames)
            duration = len(audio_data) / (self.RATE * 2)  # bytes -> seconds (16-bit mono)
            print(f"📝 录音完成,时长: {duration:.2f}")
            if self.enable_ai_chat:
                # AI chat mode: hand the take to the Doubao pipeline.
                self.process_with_ai(audio_data)
            else:
                # Plain recorder mode: save then play back.
                success, filename = self.save_recording(audio_data)
                if success and filename:
                    print("=" * 50)
                    print("🔊 播放刚才录制的音频...")
                    self.play_audio(filename)
                    print("=" * 50)
        self.recording = False
        self.recorded_frames = []
        self.recording_start_time = None
        self.last_sound_time = None

    def process_with_ai(self, audio_data):
        """Kick off AI processing on a worker thread (one job at a time)."""
        if self.is_processing_ai:
            print("⏳ AI正在处理中请稍候...")
            return
        self.is_processing_ai = True
        # Run the whole AI round trip off the capture thread.
        ai_thread = threading.Thread(target=self._ai_processing_thread, args=(audio_data,))
        ai_thread.daemon = True
        ai_thread.start()

    def _heartbeat_thread(self):
        """Periodically send silence so the Doubao connection stays alive."""
        while self.running and self.doubao_client and self.doubao_client.ws:
            current_time = time.time()
            if current_time - self.last_heartbeat_time >= self.heartbeat_interval:
                try:
                    # Each beat runs on its own short-lived event loop.
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                    try:
                        loop.run_until_complete(self.doubao_client.send_silence_data())
                        self.last_heartbeat_time = current_time
                    except Exception as e:
                        print(f"❌ 心跳失败: {e}")
                        # A failed beat likely means the connection is gone.
                        break
                    finally:
                        loop.close()
                except Exception as e:
                    print(f"❌ 心跳线程异常: {e}")
                    break
            # Re-check roughly once a second.
            time.sleep(1.0)
        print("📡 心跳线程结束")

    def _ai_processing_thread(self, audio_data):
        """Worker thread: connect to Doubao, run ASR + TTS, play the reply."""
        try:
            print("🤖 开始AI处理...")
            print("🧠 正在进行语音识别...")
            # This thread owns its own event loop for the async client.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                # Fresh connection per utterance.
                self.doubao_client = DoubaoClient()
                loop.run_until_complete(self.doubao_client.connect())
                # Start keep-alives while we wait on the server.
                self.last_heartbeat_time = time.time()
                self.heartbeat_thread = threading.Thread(target=self._heartbeat_thread)
                self.heartbeat_thread.daemon = True
                self.heartbeat_thread.start()
                print("💓 心跳线程已启动")
                # ASR + TTS round trip.
                recognized_text, tts_audio = loop.run_until_complete(
                    self.doubao_client.process_audio(audio_data)
                )
                if recognized_text:
                    print(f"🗣️ 你说: {recognized_text}")
                if tts_audio:
                    # Persist the reply as a 24 kHz mono 16-bit WAV, then play it.
                    tts_filename = "ai_response.wav"
                    with wave.open(tts_filename, 'wb') as wav_file:
                        wav_file.setnchannels(1)
                        wav_file.setsampwidth(2)
                        wav_file.setframerate(24000)
                        wav_file.writeframes(tts_audio)
                    print("🎵 AI回复生成完成")
                    print("=" * 50)
                    print("🔊 播放AI回复...")
                    self.play_audio(tts_filename)
                    print("=" * 50)
                else:
                    print("❌ 未收到AI回复")
                # Linger briefly so the heartbeat keeps the session warm.
                print("⏳ 等待5秒后关闭连接...")
                time.sleep(5)
            except Exception as e:
                print(f"❌ AI处理失败: {e}")
            finally:
                # Drop our handle; the daemon heartbeat thread winds down on its own.
                if self.heartbeat_thread and self.heartbeat_thread.is_alive():
                    print("🛑 停止心跳线程")
                    self.heartbeat_thread = None
                # Tear down the connection and the loop.
                if self.doubao_client:
                    loop.run_until_complete(self.doubao_client.close())
                loop.close()
        except Exception as e:
            print(f"❌ AI处理线程失败: {e}")
        finally:
            self.is_processing_ai = False

    def run(self):
        """Main capture loop: listen, record, and dispatch finished takes."""
        if not self.stream:
            print("❌ 音频设备未初始化")
            return
        self.running = True
        if self.enable_ai_chat:
            print("🤖 语音聊天AI助手")
            print("=" * 50)
            print("🎯 功能特点:")
            print("- 🎙️ 智能语音检测")
            print("- 🧠 豆包AI语音识别")
            print("- 🗣️ AI智能回复")
            print("- 🔊 TTS语音播放")
            print("- 🔄 实时对话")
            print("=" * 50)
            print("📖 使用说明:")
            print("- 说话自动录音")
            print("- 静音2秒结束录音")
            print("- AI自动识别并回复")
            print("- 按 Ctrl+C 退出")
            print("=" * 50)
        else:
            print("🎙️ 智能录音系统")
            print("=" * 50)
            print("📖 使用说明:")
            print("- 说话自动录音")
            print("- 静音2秒结束录音")
            print("- 录音完成后自动播放")
            print("- 按 Ctrl+C 退出")
            print("=" * 50)
        try:
            while self.running:
                # Skip capture while playback or AI processing is busy.
                if self.is_playing or self.is_processing_ai:
                    status = "🤖 AI处理中..."
                    print(f"\r{status}", end='', flush=True)
                    time.sleep(0.1)
                    continue
                # Pull one chunk from the microphone.
                data = self.stream.read(self.CHUNK_SIZE, exception_on_overflow=False)
                if len(data) == 0:
                    continue
                # Per-chunk features.
                energy = self.calculate_energy(data)
                zcr = self.calculate_zero_crossing_rate(data)
                if self.recording:
                    # Recording mode: accumulate audio and watch for silence.
                    self.recorded_frames.append(data)
                    recording_duration = time.time() - self.recording_start_time
                    # Voice-activity bookkeeping.
                    if self.is_voice_active(energy, zcr):
                        self.last_sound_time = time.time()
                        self.consecutive_low_zcr_count = 0
                    else:
                        self.consecutive_low_zcr_count += 1
                    # Decide whether the take should end.
                    should_stop = False
                    # ZCR-based silence: N consecutive non-voice chunks.
                    if self.consecutive_low_zcr_count >= self.low_zcr_threshold_count:
                        should_stop = True
                    # Wall-clock silence fallback.
                    if not should_stop and time.time() - self.last_sound_time > self.silence_threshold:
                        should_stop = True
                    # Stop only once the minimum take length has elapsed.
                    if should_stop and recording_duration >= self.min_recording_time:
                        print(f"\n🔇 检测到静音,结束录音")
                        self.stop_recording()
                    # Hard cap on take length.
                    if recording_duration > self.max_recording_time:
                        print(f"\n⏰ 达到最大录音时间")
                        self.stop_recording()
                    # Status line while recording.
                    is_voice = self.is_voice_active(energy, zcr)
                    zcr_count = f"{self.consecutive_low_zcr_count}/{self.low_zcr_threshold_count}"
                    status = f"录音中... {recording_duration:.1f}s | ZCR: {zcr:.0f} | 语音: {is_voice} | 静音计数: {zcr_count}"
                    print(f"\r{status}", end='', flush=True)
                else:
                    # Listening mode: keep the rolling pre-roll buffer warm.
                    self.update_pre_record_buffer(data)
                    if self.is_voice_active(energy, zcr):
                        # Voice detected: begin a take.
                        self.start_recording()
                    else:
                        # Status line while idle.
                        is_voice = self.is_voice_active(energy, zcr)
                        buffer_usage = len(self.pre_record_buffer) / self.pre_record_max_frames * 100
                        status = f"监听中... ZCR: {zcr:.0f} | 语音: {is_voice} | 缓冲: {buffer_usage:.0f}%"
                        print(f"\r{status}", end='', flush=True)
                time.sleep(0.01)
        except KeyboardInterrupt:
            print("\n👋 退出")
        except Exception as e:
            print(f"❌ 错误: {e}")
        finally:
            self.stop()

    def stop(self):
        """Shut down capture, threads, the audio device, and any AI connection."""
        self.running = False
        # Drop our handle; the daemon heartbeat thread exits via `running`.
        if self.heartbeat_thread and self.heartbeat_thread.is_alive():
            print("🛑 停止心跳线程")
            self.heartbeat_thread = None
        if self.recording:
            self.stop_recording()
        if self.stream:
            self.stream.stop_stream()
            self.stream.close()
        if self.audio:
            self.audio.terminate()
        # Best-effort close of the Doubao connection on a throwaway loop.
        if self.doubao_client and self.doubao_client.ws:
            try:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                loop.run_until_complete(self.doubao_client.close())
                loop.close()
            except:  # noqa: E722 — shutting down anyway
                pass
def main():
    """Entry point: parse CLI flags, build the recorder, and run it."""
    import argparse

    arg_parser = argparse.ArgumentParser(description='语音聊天AI助手')
    arg_parser.add_argument('--no-ai', action='store_true', help='禁用AI功能仅录音')
    cli_args = arg_parser.parse_args()

    ai_enabled = not cli_args.no_ai
    # Banner reflects whether the Doubao pipeline is active.
    print("🚀 语音聊天AI助手" if ai_enabled else "🚀 智能录音系统")
    print("=" * 50)

    # Build the voice chat system and hand control to its capture loop.
    chat_system = VoiceChatRecorder(enable_ai_chat=ai_enabled)
    print("✅ 系统初始化成功")
    print("=" * 50)
    chat_system.run()


if __name__ == "__main__":
    main()