#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Voice-interaction chat system integrated with the Doubao AI service.

Energy/ZCR-based voice-activity recording + Doubao speech recognition (ASR)
+ TTS reply playback.
"""

import sys
import os
import time
import threading
import asyncio
import subprocess
import wave
import struct
import json
import gzip
import uuid
from typing import Dict, Any, Optional

import pyaudio
import numpy as np
import websockets

# Doubao binary-protocol constants (4-bit header fields).
PROTOCOL_VERSION = 0b0001
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111
NO_SEQUENCE = 0b0000
MSG_WITH_EVENT = 0b0100
NO_SERIALIZATION = 0b0000
JSON = 0b0001
GZIP = 0b0001


class DoubaoClient:
    """Client for the Doubao realtime-dialogue WebSocket API."""

    def __init__(self):
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        # SECURITY: credentials should come from the environment, not source
        # control.  The literal fallbacks preserve the previous behaviour but
        # these keys are now exposed and ought to be rotated and removed.
        self.app_id = os.environ.get("DOUBAO_APP_ID", "8718217928")
        self.access_key = os.environ.get(
            "DOUBAO_ACCESS_KEY", "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc")
        self.app_key = os.environ.get("DOUBAO_APP_KEY", "PlgvMymc7f3tQnJ6")
        self.resource_id = "volc.speech.dialog"
        self.session_id = str(uuid.uuid4())
        self.ws = None   # websockets connection, set by connect()
        self.log_id = ""

    def get_headers(self) -> Dict[str, str]:
        """Build the authentication headers for the WebSocket handshake."""
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }

    def generate_header(self, message_type=CLIENT_FULL_REQUEST,
                        message_type_specific_flags=MSG_WITH_EVENT,
                        serial_method=JSON, compression_type=GZIP) -> bytes:
        """Build the 4-byte binary protocol header."""
        header = bytearray()
        header.append((PROTOCOL_VERSION << 4) | 1)  # version + header size (in 32-bit words)
        header.append((message_type << 4) | message_type_specific_flags)
        header.append((serial_method << 4) | compression_type)
        header.append(0x00)  # reserved
        return bytes(header)

    async def connect(self) -> None:
        """Open the WebSocket and run the StartConnection/StartSession handshake."""
        print(f"🔗 连接豆包服务器...")
        try:
            self.ws = await websockets.connect(
                self.base_url,
                additional_headers=self.get_headers(),
                ping_interval=None
            )

            # The attribute that holds the handshake response headers differs
            # between websockets versions; guard against a missing header too.
            if hasattr(self.ws, 'response_headers'):
                self.log_id = self.ws.response_headers.get("X-Tt-Logid") or ""
            elif hasattr(self.ws, 'headers'):
                self.log_id = self.ws.headers.get("X-Tt-Logid") or ""

            print(f"✅ 连接成功, log_id: {self.log_id}")

            # Protocol handshake: StartConnection, then StartSession.
            await self._send_start_connection()
            await self._send_start_session()

        except Exception as e:
            print(f"❌ 连接失败: {e}")
            raise

    def parse_response(self, response):
        """Parse a binary server frame into a dict; return None if too short.

        Payload layout after the 4-byte header: event (4B) | session-id
        length (4B) | session id | data size (4B) | data.
        """
        if len(response) < 4:
            return None

        protocol_version = response[0] >> 4
        header_size = response[0] & 0x0f
        message_type = response[1] >> 4
        flags = response[1] & 0x0f

        payload_start = header_size * 4  # header size is counted in 32-bit words
        payload = response[payload_start:]

        result = {
            'protocol_version': protocol_version,
            'header_size': header_size,
            'message_type': message_type,
            'flags': flags,
            'payload': payload,
            'payload_size': len(payload)
        }

        if len(payload) >= 4:
            result['event'] = int.from_bytes(payload[:4], 'big')

        if len(payload) >= 8:
            session_id_len = int.from_bytes(payload[4:8], 'big')
            if len(payload) >= 8 + session_id_len:
                result['session_id'] = payload[8:8 + session_id_len].decode()

            if len(payload) >= 12 + session_id_len:
                data_size = int.from_bytes(
                    payload[8 + session_id_len:12 + session_id_len], 'big')
                result['data_size'] = data_size
                result['data'] = payload[12 + session_id_len:12 + session_id_len + data_size]

                # Data may be JSON metadata; binary audio fails to decode and
                # is left as raw bytes.
                try:
                    result['json_data'] = json.loads(result['data'].decode('utf-8'))
                except (UnicodeDecodeError, json.JSONDecodeError):
                    pass

        return result

    async def _send_start_connection(self) -> None:
        """Send the StartConnection request (event 1) and consume the ack."""
        request = bytearray(self.generate_header())
        request.extend(int(1).to_bytes(4, 'big'))  # event id

        payload_bytes = gzip.compress(b"{}")
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)

        await self.ws.send(request)
        await self.ws.recv()  # ack content is not inspected

    async def _send_start_session(self) -> None:
        """Send the StartSession request (event 100) with the dialog config."""
        session_config = {
            "asr": {"extra": {"end_smooth_window_ms": 1500}},
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {"channel": 1, "format": "pcm", "sample_rate": 24000}
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 30,
                    "input_mod": "audio",
                },
            },
        }

        request = bytearray(self.generate_header())
        request.extend(int(100).to_bytes(4, 'big'))  # event id
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())

        payload_bytes = gzip.compress(json.dumps(session_config).encode())
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)

        await self.ws.send(request)
        await self.ws.recv()  # ack content is not inspected
        # Give the server a moment to finish session setup before audio flows.
        await asyncio.sleep(1.0)

    async def process_audio(self, audio_data: bytes) -> tuple[str, bytes]:
        """Send one recorded utterance and collect the reply.

        Returns (recognized text, TTS audio as 16-bit little-endian PCM);
        both may be empty on failure.
        """
        try:
            # Audio-only request, event 200 — same framing as doubao_simple.py.
            task_request = bytearray(
                self.generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST,
                                     serial_method=NO_SERIALIZATION))
            task_request.extend(int(200).to_bytes(4, 'big'))
            task_request.extend(len(self.session_id).to_bytes(4, 'big'))
            task_request.extend(self.session_id.encode())
            payload_bytes = gzip.compress(audio_data)
            task_request.extend(len(payload_bytes).to_bytes(4, 'big'))
            task_request.extend(payload_bytes)
            await self.ws.send(task_request)
            print("📤 音频数据已发送")

            recognized_text = ""
            tts_audio = b""
            response_count = 0

            # Receive responses — same parsing logic as doubao_simple.py.
            audio_chunks = []
            max_responses = 30  # hard cap so a silent server cannot hang us

            while response_count < max_responses:
                try:
                    response = await asyncio.wait_for(self.ws.recv(), timeout=30.0)
                    response_count += 1

                    parsed = self.parse_response(response)
                    if not parsed:
                        continue

                    print(f"📥 响应 {response_count}: message_type={parsed['message_type']}, event={parsed.get('event', 'N/A')}, size={parsed['payload_size']}")

                    if parsed['message_type'] == 11:  # SERVER_ACK — may carry audio
                        if 'data' in parsed and parsed['data_size'] > 0:
                            audio_chunks.append(parsed['data'])
                            print(f"收集到音频块: {parsed['data_size']} 字节")

                    elif parsed['message_type'] == 9:  # SERVER_FULL_RESPONSE
                        event = parsed.get('event', 0)

                        if event == 450:    # ASR started
                            print("🎤 ASR处理开始")
                        elif event == 451:  # ASR result
                            if 'json_data' in parsed and 'results' in parsed['json_data']:
                                text = parsed['json_data']['results'][0].get('text', '')
                                recognized_text = text
                                print(f"🧠 识别结果: {text}")
                        elif event == 459:  # ASR finished
                            print("✅ ASR处理结束")
                        elif event == 350:  # TTS started
                            print("🎵 TTS生成开始")
                        elif event == 359:  # TTS finished — stop reading
                            print("✅ TTS生成结束")
                            break
                        elif event == 550:  # TTS audio data
                            if 'data' in parsed and parsed['data_size'] > 0:
                                # JSON payloads are audio metadata; anything
                                # that fails to parse is real audio.
                                try:
                                    json.loads(parsed['data'].decode('utf-8'))
                                    print("收到TTS音频元数据")
                                except (UnicodeDecodeError, json.JSONDecodeError):
                                    audio_chunks.append(parsed['data'])
                                    print(f"收集到TTS音频块: {parsed['data_size']} 字节")

                except asyncio.TimeoutError:
                    print(f"⏰ 等待响应 {response_count + 1} 超时")
                    break
                except websockets.exceptions.ConnectionClosed:
                    print("🔌 连接已关闭")
                    break

            print(f"共收到 {response_count} 个响应,收集到 {len(audio_chunks)} 个音频块")

            if audio_chunks:
                tts_audio = b''.join(audio_chunks)
                print(f"合并后的音频数据: {len(tts_audio)} 字节")

            # Convert the TTS stream (32-bit float PCM) into 16-bit integer PCM.
            if tts_audio:
                # Some servers gzip the audio payload; fall back to raw bytes.
                try:
                    audio_to_write = gzip.decompress(tts_audio)
                    print(f"解压缩后音频数据: {len(audio_to_write)} 字节")
                except OSError:
                    print("音频数据不是GZIP压缩格式,直接使用原始数据")
                    audio_to_write = tts_audio

                # float32 samples are 4 bytes each — truncate any trailing partial.
                if len(audio_to_write) % 4 != 0:
                    print(f"警告:音频数据长度 {len(audio_to_write)} 不是4的倍数,截断到最近的倍数")
                    audio_to_write = audio_to_write[:len(audio_to_write) // 4 * 4]

                # NOTE(review): samples are assumed to be little-endian float32
                # in [-1, 1] — confirm against the Doubao TTS format docs.
                samples = np.frombuffer(audio_to_write, dtype='<f4')
                tts_audio = (np.clip(samples, -1.0, 1.0) * 32767.0).astype('<i2').tobytes()

            return recognized_text, tts_audio

        except Exception as e:
            print(f"❌ 音频处理失败: {e}")
            return "", b""

    async def send_silence_data(self, duration_ms: int = 100) -> None:
        """Send a short burst of silence to keep the dialogue session alive."""
        try:
            samples = int(16000 * duration_ms / 1000)  # 16 kHz sample rate
            silence_data = bytes(samples * 2)          # 16-bit PCM zeros

            # Same audio-only framing (event 200) as process_audio().
            task_request = bytearray(
                self.generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST,
                                     serial_method=NO_SERIALIZATION))
            task_request.extend(int(200).to_bytes(4, 'big'))
            task_request.extend(len(self.session_id).to_bytes(4, 'big'))
            task_request.extend(self.session_id.encode())
            payload_bytes = gzip.compress(silence_data)
            task_request.extend(len(payload_bytes).to_bytes(4, 'big'))
            task_request.extend(payload_bytes)
            await self.ws.send(task_request)
            print("💓 发送心跳数据保持连接")

            # Drain one response so the server-side queue does not grow;
            # the content is intentionally ignored.
            try:
                await asyncio.wait_for(self.ws.recv(), timeout=5.0)
            except asyncio.TimeoutError:
                print("⚠️ 心跳响应超时")
            except websockets.exceptions.ConnectionClosed:
                print("❌ 心跳时连接已关闭")
                raise

        except Exception as e:
            print(f"❌ 发送心跳数据失败: {e}")

    async def close(self) -> None:
        """Close the WebSocket, ignoring errors from an already-dead socket."""
        if self.ws:
            try:
                await self.ws.close()
            except Exception:
                pass
            print("🔌 连接已关闭")


class VoiceChatRecorder:
    """Voice-activity-triggered recorder with optional Doubao AI replies."""

    def __init__(self, enable_ai_chat=True):
        # Audio parameters.
        self.FORMAT = pyaudio.paInt16
        self.CHANNELS = 1
        self.RATE = 16000
        self.CHUNK_SIZE = 1024

        # Energy/silence detection parameters (seconds unless noted).
        self.energy_threshold = 500
        self.silence_threshold = 2.0
        self.min_recording_time = 1.0
        self.max_recording_time = 20.0

        # Runtime state.
        self.audio = None
        self.stream = None
        self.running = False
        self.recording = False
        self.recorded_frames = []
        self.recording_start_time = None
        self.last_sound_time = None
        self.energy_history = []
        self.zcr_history = []

        # AI chat state.
        self.enable_ai_chat = enable_ai_chat
        self.doubao_client = None
        self.is_processing_ai = False
        self.heartbeat_thread = None
        self.last_heartbeat_time = time.time()
        self.heartbeat_interval = 10.0  # send a keep-alive every 10 s

        # Pre-record ring buffer (~2 s) so the start of an utterance is kept.
        self.pre_record_buffer = []
        self.pre_record_max_frames = int(2.0 * self.RATE / self.CHUNK_SIZE)

        # Playback state.
        self.is_playing = False

        # ZCR-based silence detection.
        self.consecutive_low_zcr_count = 0
        self.low_zcr_threshold_count = 15
        self.voice_activity_history = []

        self._setup_audio()

    def _setup_audio(self):
        """Open the PyAudio input stream; leaves self.stream None on failure."""
        try:
            self.audio = pyaudio.PyAudio()
            self.stream = self.audio.open(
                format=self.FORMAT,
                channels=self.CHANNELS,
                rate=self.RATE,
                input=True,
                frames_per_buffer=self.CHUNK_SIZE
            )
            print("✅ 音频设备初始化成功")
        except Exception as e:
            print(f"❌ 音频设备初始化失败: {e}")

    def generate_silence_audio(self, duration_ms=100):
        """Return duration_ms of silence as 16-bit PCM zeros."""
        samples = int(self.RATE * duration_ms / 1000)
        return bytes(samples * 2)  # 2 bytes per 16-bit sample

    def calculate_energy(self, audio_data):
        """Return the RMS energy of one chunk of 16-bit PCM audio."""
        if len(audio_data) == 0:
            return 0

        audio_array = np.frombuffer(audio_data, dtype=np.int16)
        # Cast before squaring: squaring int16 in place overflows for
        # amplitudes above ~181 and corrupts the RMS value.
        rms = np.sqrt(np.mean(audio_array.astype(np.float64) ** 2))

        # Track ambient energy only while listening, not while recording.
        if not self.recording:
            self.energy_history.append(rms)
            if len(self.energy_history) > 50:
                self.energy_history.pop(0)

        return rms

    def calculate_zero_crossing_rate(self, audio_data):
        """Return the zero-crossing rate (crossings per second) of a chunk."""
        if len(audio_data) == 0:
            return 0

        audio_array = np.frombuffer(audio_data, dtype=np.int16)
        zero_crossings = np.sum(np.diff(np.sign(audio_array)) != 0)
        zcr = zero_crossings / len(audio_array) * self.RATE

        self.zcr_history.append(zcr)
        if len(self.zcr_history) > 30:
            self.zcr_history.pop(0)

        return zcr

    def is_voice_active(self, energy, zcr):
        """Voice-activity decision. Currently uses only the ZCR band.

        2400–12000 crossings/s is treated as the speech range at 16 kHz;
        the energy argument is kept for interface compatibility.
        """
        return 2400 < zcr < 12000

    def save_recording(self, audio_data, filename=None):
        """Write raw PCM to a WAV file; return (success, filename)."""
        if filename is None:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            filename = f"recording_{timestamp}.wav"

        try:
            with wave.open(filename, 'wb') as wf:
                wf.setnchannels(self.CHANNELS)
                wf.setsampwidth(self.audio.get_sample_size(self.FORMAT))
                wf.setframerate(self.RATE)
                wf.writeframes(audio_data)

            print(f"✅ 录音已保存: {filename}")
            return True, filename
        except Exception as e:
            print(f"❌ 保存录音失败: {e}")
            return False, None

    def play_audio(self, filename):
        """Play a WAV file with aplay; input stream is closed during playback."""
        try:
            # Abort any in-progress recording.
            if self.recording:
                self.recording = False
                self.recorded_frames = []

            # Close the input stream so the device is free for playback.
            if self.stream:
                self.stream.stop_stream()
                self.stream.close()
                self.stream = None

            self.is_playing = True
            time.sleep(0.2)

            print(f"🔊 播放: {filename}")
            subprocess.run(['aplay', filename], check=True)
            print("✅ 播放完成")

        except Exception as e:
            print(f"❌ 播放失败: {e}")
        finally:
            self.is_playing = False
            time.sleep(0.2)
            # Reopen the input stream for the next utterance.
            self._setup_audio()

    def update_pre_record_buffer(self, audio_data):
        """Append a chunk to the pre-record ring buffer, dropping the oldest."""
        self.pre_record_buffer.append(audio_data)
        if len(self.pre_record_buffer) > self.pre_record_max_frames:
            self.pre_record_buffer.pop(0)

    def start_recording(self):
        """Begin recording, seeding with the pre-record buffer."""
        print("🎙️ 检测到声音,开始录音...")
        self.recording = True
        self.recorded_frames = []
        self.recorded_frames.extend(self.pre_record_buffer)
        self.pre_record_buffer = []
        self.recording_start_time = time.time()
        self.last_sound_time = time.time()
        self.consecutive_low_zcr_count = 0

    def stop_recording(self):
        """Finish recording and dispatch to AI processing or local playback."""
        if len(self.recorded_frames) > 0:
            audio_data = b''.join(self.recorded_frames)
            duration = len(audio_data) / (self.RATE * 2)  # 16-bit mono

            print(f"📝 录音完成,时长: {duration:.2f}秒")

            if self.enable_ai_chat:
                self.process_with_ai(audio_data)
            else:
                # Plain recording mode: save, then play back.
                success, filename = self.save_recording(audio_data)
                if success and filename:
                    print("=" * 50)
                    print("🔊 播放刚才录制的音频...")
                    self.play_audio(filename)
                    print("=" * 50)

        self.recording = False
        self.recorded_frames = []
        self.recording_start_time = None
        self.last_sound_time = None

    def process_with_ai(self, audio_data):
        """Hand the recording to a background AI-processing thread."""
        if self.is_processing_ai:
            print("⏳ AI正在处理中,请稍候...")
            return

        self.is_processing_ai = True

        ai_thread = threading.Thread(target=self._ai_processing_thread, args=(audio_data,))
        ai_thread.daemon = True
        ai_thread.start()

    def _heartbeat_thread(self):
        """Periodically send silence so the Doubao session stays alive.

        NOTE(review): this thread and _ai_processing_thread both recv() on the
        same websocket from separate event loops — responses can be consumed
        by the wrong side; confirm whether heartbeats during processing are safe.
        """
        while self.running and self.doubao_client and self.doubao_client.ws:
            current_time = time.time()
            if current_time - self.last_heartbeat_time >= self.heartbeat_interval:
                try:
                    # Run the coroutine on a private event loop for this thread.
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                    try:
                        loop.run_until_complete(self.doubao_client.send_silence_data())
                        self.last_heartbeat_time = current_time
                    except Exception as e:
                        print(f"❌ 心跳失败: {e}")
                        # A failed heartbeat likely means the connection is gone.
                        break
                    finally:
                        loop.close()
                except Exception as e:
                    print(f"❌ 心跳线程异常: {e}")
                    break

            time.sleep(1.0)

        print("📡 心跳线程结束")

    def _ai_processing_thread(self, audio_data):
        """Background thread: connect, recognize, synthesize, play the reply."""
        try:
            print("🤖 开始AI处理...")
            print("🧠 正在进行语音识别...")

            # Private event loop for this worker thread.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

            try:
                self.doubao_client = DoubaoClient()
                loop.run_until_complete(self.doubao_client.connect())

                # Start the keep-alive thread.
                self.last_heartbeat_time = time.time()
                self.heartbeat_thread = threading.Thread(target=self._heartbeat_thread)
                self.heartbeat_thread.daemon = True
                self.heartbeat_thread.start()
                print("💓 心跳线程已启动")

                # ASR + TTS round trip.
                recognized_text, tts_audio = loop.run_until_complete(
                    self.doubao_client.process_audio(audio_data)
                )

                if recognized_text:
                    print(f"🗣️ 你说: {recognized_text}")

                if tts_audio:
                    # Persist the TTS reply (24 kHz mono, 16-bit) and play it.
                    tts_filename = "ai_response.wav"
                    with wave.open(tts_filename, 'wb') as wav_file:
                        wav_file.setnchannels(1)
                        wav_file.setsampwidth(2)
                        wav_file.setframerate(24000)
                        wav_file.writeframes(tts_audio)

                    print("🎵 AI回复生成完成")
                    print("=" * 50)
                    print("🔊 播放AI回复...")
                    self.play_audio(tts_filename)
                    print("=" * 50)
                else:
                    print("❌ 未收到AI回复")

                # Leave the connection up briefly so heartbeats keep flowing.
                print("⏳ 等待5秒后关闭连接...")
                time.sleep(5)

            except Exception as e:
                print(f"❌ AI处理失败: {e}")
            finally:
                # Dropping the reference signals the daemon heartbeat loop.
                if self.heartbeat_thread and self.heartbeat_thread.is_alive():
                    print("🛑 停止心跳线程")
                    self.heartbeat_thread = None

                if self.doubao_client:
                    loop.run_until_complete(self.doubao_client.close())

                loop.close()

        except Exception as e:
            print(f"❌ AI处理线程失败: {e}")
        finally:
            self.is_processing_ai = False

    def run(self):
        """Main loop: listen, detect voice activity, record, dispatch."""
        if not self.stream:
            print("❌ 音频设备未初始化")
            return

        self.running = True

        if self.enable_ai_chat:
            print("🤖 语音聊天AI助手")
            print("=" * 50)
            print("🎯 功能特点:")
            print("- 🎙️ 智能语音检测")
            print("- 🧠 豆包AI语音识别")
            print("- 🗣️ AI智能回复")
            print("- 🔊 TTS语音播放")
            print("- 🔄 实时对话")
            print("=" * 50)
            print("📖 使用说明:")
            print("- 说话自动录音")
            print("- 静音2秒结束录音")
            print("- AI自动识别并回复")
            print("- 按 Ctrl+C 退出")
            print("=" * 50)
        else:
            print("🎙️ 智能录音系统")
            print("=" * 50)
            print("📖 使用说明:")
            print("- 说话自动录音")
            print("- 静音2秒结束录音")
            print("- 录音完成后自动播放")
            print("- 按 Ctrl+C 退出")
            print("=" * 50)

        try:
            while self.running:
                # Skip capture while a reply is playing or being generated.
                if self.is_playing or self.is_processing_ai:
                    status = "🤖 AI处理中..."
                    print(f"\r{status}", end='', flush=True)
                    time.sleep(0.1)
                    continue

                data = self.stream.read(self.CHUNK_SIZE, exception_on_overflow=False)

                if len(data) == 0:
                    continue

                energy = self.calculate_energy(data)
                zcr = self.calculate_zero_crossing_rate(data)

                if self.recording:
                    # Recording mode.
                    self.recorded_frames.append(data)
                    recording_duration = time.time() - self.recording_start_time

                    if self.is_voice_active(energy, zcr):
                        self.last_sound_time = time.time()
                        self.consecutive_low_zcr_count = 0
                    else:
                        self.consecutive_low_zcr_count += 1

                    # Decide whether to end the recording.
                    should_stop = False

                    # ZCR-based silence: N consecutive non-voice chunks.
                    if self.consecutive_low_zcr_count >= self.low_zcr_threshold_count:
                        should_stop = True

                    # Time-based silence.
                    if not should_stop and time.time() - self.last_sound_time > self.silence_threshold:
                        should_stop = True

                    if should_stop and recording_duration >= self.min_recording_time:
                        print(f"\n🔇 检测到静音,结束录音")
                        self.stop_recording()

                    # Absolute cap on recording length.
                    if recording_duration > self.max_recording_time:
                        print(f"\n⏰ 达到最大录音时间")
                        self.stop_recording()

                    # Status line.
                    is_voice = self.is_voice_active(energy, zcr)
                    zcr_count = f"{self.consecutive_low_zcr_count}/{self.low_zcr_threshold_count}"
                    status = f"录音中... {recording_duration:.1f}s | ZCR: {zcr:.0f} | 语音: {is_voice} | 静音计数: {zcr_count}"
                    print(f"\r{status}", end='', flush=True)

                else:
                    # Listening mode: keep the pre-record buffer warm.
                    self.update_pre_record_buffer(data)

                    if self.is_voice_active(energy, zcr):
                        self.start_recording()
                    else:
                        is_voice = self.is_voice_active(energy, zcr)
                        buffer_usage = len(self.pre_record_buffer) / self.pre_record_max_frames * 100
                        status = f"监听中... ZCR: {zcr:.0f} | 语音: {is_voice} | 缓冲: {buffer_usage:.0f}%"
                        print(f"\r{status}", end='', flush=True)

                time.sleep(0.01)

        except KeyboardInterrupt:
            print("\n👋 退出")
        except Exception as e:
            print(f"❌ 错误: {e}")
        finally:
            self.stop()

    def stop(self):
        """Shut everything down: heartbeat, recording, streams, AI connection."""
        self.running = False

        # Dropping the reference signals the daemon heartbeat loop.
        if self.heartbeat_thread and self.heartbeat_thread.is_alive():
            print("🛑 停止心跳线程")
            self.heartbeat_thread = None

        if self.recording:
            self.stop_recording()

        if self.stream:
            self.stream.stop_stream()
            self.stream.close()

        if self.audio:
            self.audio.terminate()

        # Close the AI connection on a throwaway event loop; best effort.
        if self.doubao_client and self.doubao_client.ws:
            try:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                loop.run_until_complete(self.doubao_client.close())
                loop.close()
            except Exception:
                pass


def main():
    """CLI entry point."""
    import argparse

    parser = argparse.ArgumentParser(description='语音聊天AI助手')
    parser.add_argument('--no-ai', action='store_true', help='禁用AI功能,仅录音')
    args = parser.parse_args()

    enable_ai = not args.no_ai

    if enable_ai:
        print("🚀 语音聊天AI助手")
    else:
        print("🚀 智能录音系统")

    print("=" * 50)

    recorder = VoiceChatRecorder(enable_ai_chat=enable_ai)

    print("✅ 系统初始化成功")
    print("=" * 50)

    recorder.run()


if __name__ == "__main__":
    main()