config
This commit is contained in:
parent 9108fd4582
commit 43879961a2
voice_chat.py: 840 lines added (Normal file)
@@ -0,0 +1,840 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Voice-interactive chat system - integrated with Doubao AI
Energy-detection-based recording + Doubao speech recognition + TTS replies
"""

import sys
import os
import time
import threading
import asyncio
import subprocess
import wave
import struct
import json
import gzip
import uuid
from typing import Dict, Any, Optional
import pyaudio
import numpy as np
import websockets

# Doubao protocol constants
PROTOCOL_VERSION = 0b0001
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111
NO_SEQUENCE = 0b0000
MSG_WITH_EVENT = 0b0100
NO_SERIALIZATION = 0b0000
JSON = 0b0001
GZIP = 0b0001
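
# The 4-byte header built by DoubaoClient.generate_header() below packs these
# constants two per byte (high nibble | low nibble):
#   byte 0: PROTOCOL_VERSION | header size in 4-byte words (always 1)
#   byte 1: message type     | message-type-specific flags
#   byte 2: serialization    | compression
#   byte 3: reserved (0x00)
# With the default arguments this yields bytes([0x11, 0x14, 0x11, 0x00]).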

class DoubaoClient:
    """Doubao audio processing client"""

    def __init__(self):
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        self.app_id = "8718217928"
        self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
        self.app_key = "PlgvMymc7f3tQnJ6"
        self.resource_id = "volc.speech.dialog"
        self.session_id = str(uuid.uuid4())
        self.ws = None
        self.log_id = ""

    def get_headers(self) -> Dict[str, str]:
        """Build the request headers"""
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }

    def generate_header(self, message_type=CLIENT_FULL_REQUEST,
                        message_type_specific_flags=MSG_WITH_EVENT,
                        serial_method=JSON, compression_type=GZIP) -> bytes:
        """Build the 4-byte protocol header"""
        header = bytearray()
        header.append((PROTOCOL_VERSION << 4) | 1)  # version + header_size
        header.append((message_type << 4) | message_type_specific_flags)
        header.append((serial_method << 4) | compression_type)
        header.append(0x00)  # reserved
        return bytes(header)

    async def connect(self) -> None:
        """Establish the WebSocket connection"""
        print("🔗 Connecting to the Doubao server...")
        try:
            self.ws = await websockets.connect(
                self.base_url,
                additional_headers=self.get_headers(),
                ping_interval=None
            )

            # Grab the log id from the handshake response headers
            if hasattr(self.ws, 'response_headers'):
                self.log_id = self.ws.response_headers.get("X-Tt-Logid")
            elif hasattr(self.ws, 'headers'):
                self.log_id = self.ws.headers.get("X-Tt-Logid")

            print(f"✅ Connected, log_id: {self.log_id}")

            # Send the StartConnection request
            await self._send_start_connection()

            # Send the StartSession request
            await self._send_start_session()

        except Exception as e:
            print(f"❌ Connection failed: {e}")
            raise

    def parse_response(self, response):
        """Parse a server response frame"""
        if len(response) < 4:
            return None

        protocol_version = response[0] >> 4
        header_size = response[0] & 0x0f
        message_type = response[1] >> 4
        flags = response[1] & 0x0f

        payload_start = header_size * 4
        payload = response[payload_start:]

        result = {
            'protocol_version': protocol_version,
            'header_size': header_size,
            'message_type': message_type,
            'flags': flags,
            'payload': payload,
            'payload_size': len(payload)
        }

        # Parse the payload
        if len(payload) >= 4:
            result['event'] = int.from_bytes(payload[:4], 'big')

            if len(payload) >= 8:
                session_id_len = int.from_bytes(payload[4:8], 'big')
                if len(payload) >= 8 + session_id_len:
                    result['session_id'] = payload[8:8+session_id_len].decode()

                    if len(payload) >= 12 + session_id_len:
                        data_size = int.from_bytes(payload[8+session_id_len:12+session_id_len], 'big')
                        result['data_size'] = data_size
                        result['data'] = payload[12+session_id_len:12+session_id_len+data_size]

                        # Try to parse the data as JSON
                        try:
                            result['json_data'] = json.loads(result['data'].decode('utf-8'))
                        except Exception:
                            pass

        return result

    async def _send_start_connection(self) -> None:
        """Send the StartConnection request"""
        request = bytearray(self.generate_header())
        request.extend(int(1).to_bytes(4, 'big'))

        payload_bytes = b"{}"
        payload_bytes = gzip.compress(payload_bytes)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)

        await self.ws.send(request)
        response = await self.ws.recv()

    async def _send_start_session(self) -> None:
        """Send the StartSession request"""
        session_config = {
            "asr": {"extra": {"end_smooth_window_ms": 1500}},
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {"channel": 1, "format": "pcm", "sample_rate": 24000}
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 30,
                    "input_mod": "audio",
                },
            },
        }
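        # Frame assembled below (inferred from this code; integers big-endian):
        #   [4-byte header][event id = 100][len(session_id)][session_id]
        #   [len(gzip(JSON config))][gzip(JSON config)]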

        request = bytearray(self.generate_header())
        request.extend(int(100).to_bytes(4, 'big'))
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())

        payload_bytes = json.dumps(session_config).encode()
        payload_bytes = gzip.compress(payload_bytes)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)

        await self.ws.send(request)
        response = await self.ws.recv()
        await asyncio.sleep(1.0)

    async def process_audio(self, audio_data: bytes) -> tuple[str, bytes]:
        """Process the audio and return (recognized text, TTS audio)"""
        try:
            # Send the audio data - same frame format as doubao_simple.py
            task_request = bytearray(
                self.generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST,
                                     serial_method=NO_SERIALIZATION))
            task_request.extend(int(200).to_bytes(4, 'big'))
            task_request.extend(len(self.session_id).to_bytes(4, 'big'))
            task_request.extend(self.session_id.encode())
            payload_bytes = gzip.compress(audio_data)
            task_request.extend(len(payload_bytes).to_bytes(4, 'big'))
            task_request.extend(payload_bytes)
            await self.ws.send(task_request)
            print("📤 Audio data sent")

            recognized_text = ""
            tts_audio = b""
            response_count = 0

            # Receive responses - same parsing logic as doubao_simple.py
            audio_chunks = []
            max_responses = 30

            while response_count < max_responses:
                try:
                    response = await asyncio.wait_for(self.ws.recv(), timeout=30.0)
                    response_count += 1

                    parsed = self.parse_response(response)
                    if not parsed:
                        continue

                    print(f"📥 Response {response_count}: message_type={parsed['message_type']}, event={parsed.get('event', 'N/A')}, size={parsed['payload_size']}")

                    # Handle the different response types
                    if parsed['message_type'] == 11:  # SERVER_ACK - may contain audio
                        if 'data' in parsed and parsed['data_size'] > 0:
                            audio_chunks.append(parsed['data'])
                            print(f"Collected audio chunk: {parsed['data_size']} bytes")

                    elif parsed['message_type'] == 9:  # SERVER_FULL_RESPONSE
                        event = parsed.get('event', 0)

                        if event == 450:  # ASR started
                            print("🎤 ASR processing started")
                        elif event == 451:  # ASR result
                            if 'json_data' in parsed and 'results' in parsed['json_data']:
                                text = parsed['json_data']['results'][0].get('text', '')
                                recognized_text = text
                                print(f"🧠 Recognition result: {text}")
                        elif event == 459:  # ASR finished
                            print("✅ ASR processing finished")
                        elif event == 350:  # TTS started
                            print("🎵 TTS generation started")
                        elif event == 359:  # TTS finished
                            print("✅ TTS generation finished")
                            break
                        elif event == 550:  # TTS audio data
                            if 'data' in parsed and parsed['data_size'] > 0:
                                # Check whether this is JSON (audio metadata) or actual audio data
                                try:
                                    json.loads(parsed['data'].decode('utf-8'))
                                    print("Received TTS audio metadata")
                                except Exception:
                                    # Not JSON, so presumably audio data
                                    audio_chunks.append(parsed['data'])
                                    print(f"Collected TTS audio chunk: {parsed['data_size']} bytes")

                except asyncio.TimeoutError:
                    print(f"⏰ Timed out waiting for response {response_count + 1}")
                    break
                except websockets.exceptions.ConnectionClosed:
                    print("🔌 Connection closed")
                    break

            print(f"Received {response_count} responses, collected {len(audio_chunks)} audio chunks")

            # Merge the audio chunks
            if audio_chunks:
                tts_audio = b''.join(audio_chunks)
                print(f"Merged audio data: {len(tts_audio)} bytes")

            # Convert the TTS audio format (32-bit float -> 16-bit int)
            if tts_audio:
                # Check whether the data is gzip-compressed
                try:
                    decompressed = gzip.decompress(tts_audio)
                    print(f"Decompressed audio data: {len(decompressed)} bytes")
                    audio_to_write = decompressed
                except Exception:
                    print("Audio data is not gzip-compressed, using it as-is")
                    audio_to_write = tts_audio

                # Make sure the length is a multiple of 4 bytes (32-bit floats)
                if len(audio_to_write) % 4 != 0:
                    print(f"Warning: audio length {len(audio_to_write)} is not a multiple of 4, truncating to the nearest multiple")
                    audio_to_write = audio_to_write[:len(audio_to_write) // 4 * 4]

                # Convert 32-bit floats to 16-bit integers
                float_count = len(audio_to_write) // 4
                int16_data = bytearray(float_count * 2)

                for i in range(float_count):
                    # Read a 32-bit float (little-endian)
                    float_value = struct.unpack('<f', audio_to_write[i*4:i*4+4])[0]

                    # Clamp the float to [-1.0, 1.0]
                    float_value = max(-1.0, min(1.0, float_value))

                    # Convert to a 16-bit integer
                    int16_value = int(float_value * 32767)

                    # Write the 16-bit integer (little-endian)
                    int16_data[i*2:i*2+2] = struct.pack('<h', int16_value)

                tts_audio = bytes(int16_data)
                print(f"✅ Audio conversion finished: {len(tts_audio)} bytes")

            return recognized_text, tts_audio

        except Exception as e:
            print(f"❌ Processing failed: {e}")
            import traceback
            traceback.print_exc()
            return "", b""

    async def send_silence_data(self, duration_ms=100) -> None:
        """Send silence to keep the connection alive"""
        try:
            # Generate silent audio data
            samples = int(16000 * duration_ms / 1000)  # 16 kHz sample rate
            silence_data = bytes(samples * 2)  # 16-bit PCM
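            # With the default duration_ms=100 at 16 kHz this is 1600 samples,
            # i.e. 3200 zero bytes of 16-bit PCM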

            # Send the silence frame
            task_request = bytearray(
                self.generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST,
                                     serial_method=NO_SERIALIZATION))
            task_request.extend(int(200).to_bytes(4, 'big'))
            task_request.extend(len(self.session_id).to_bytes(4, 'big'))
            task_request.extend(self.session_id.encode())
            payload_bytes = gzip.compress(silence_data)
            task_request.extend(len(payload_bytes).to_bytes(4, 'big'))
            task_request.extend(payload_bytes)
            await self.ws.send(task_request)
            print("💓 Sent heartbeat data to keep the connection alive")

            # Handle the response loosely (do not wait for a full response)
            try:
                response = await asyncio.wait_for(self.ws.recv(), timeout=5.0)
                # Just confirm that a response arrived; ignore its content
            except asyncio.TimeoutError:
                print("⚠️ Heartbeat response timed out")
            except websockets.exceptions.ConnectionClosed:
                print("❌ Connection closed during heartbeat")
                raise

        except Exception as e:
            print(f"❌ Failed to send heartbeat data: {e}")

    async def close(self) -> None:
        """Close the connection"""
        if self.ws:
            try:
                await self.ws.close()
            except Exception:
                pass
            print("🔌 Connection closed")

class VoiceChatRecorder:
    """Voice chat recording system"""

    def __init__(self, enable_ai_chat=True):
        # Audio parameters
        self.FORMAT = pyaudio.paInt16
        self.CHANNELS = 1
        self.RATE = 16000
        self.CHUNK_SIZE = 1024

        # Energy detection parameters
        self.energy_threshold = 500
        self.silence_threshold = 2.0
        self.min_recording_time = 1.0
        self.max_recording_time = 20.0

        # State variables
        self.audio = None
        self.stream = None
        self.running = False
        self.recording = False
        self.recorded_frames = []
        self.recording_start_time = None
        self.last_sound_time = None
        self.energy_history = []
        self.zcr_history = []

        # AI chat features
        self.enable_ai_chat = enable_ai_chat
        self.doubao_client = None
        self.is_processing_ai = False
        self.heartbeat_thread = None
        self.last_heartbeat_time = time.time()
        self.heartbeat_interval = 10.0  # send a heartbeat every 10 seconds

        # Pre-recording buffer
        self.pre_record_buffer = []
        self.pre_record_max_frames = int(2.0 * self.RATE / self.CHUNK_SIZE)
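        # With RATE=16000 and CHUNK_SIZE=1024 this keeps int(31.25) = 31 chunks,
        # i.e. roughly the last 2 seconds of audio heard before recording starts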

        # Playback state
        self.is_playing = False

        # ZCR detection parameters
        self.consecutive_low_zcr_count = 0
        self.low_zcr_threshold_count = 15
        self.voice_activity_history = []

        self._setup_audio()

    def _setup_audio(self):
        """Set up the audio device"""
        try:
            self.audio = pyaudio.PyAudio()
            self.stream = self.audio.open(
                format=self.FORMAT,
                channels=self.CHANNELS,
                rate=self.RATE,
                input=True,
                frames_per_buffer=self.CHUNK_SIZE
            )
            print("✅ Audio device initialized")
        except Exception as e:
            print(f"❌ Audio device initialization failed: {e}")

    def generate_silence_audio(self, duration_ms=100):
        """Generate silent audio data"""
        # Generate silence of the requested duration (16-bit PCM, all zeros)
        samples = int(self.RATE * duration_ms / 1000)
        silence_data = bytes(samples * 2)  # 16 bits = 2 bytes per sample
        return silence_data

    def calculate_energy(self, audio_data):
        """Compute the audio energy (RMS)"""
        if len(audio_data) == 0:
            return 0

        # Cast to float64 before squaring so int16 samples cannot overflow
        audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float64)
        rms = np.sqrt(np.mean(audio_array ** 2))

        if not self.recording:
            self.energy_history.append(rms)
            if len(self.energy_history) > 50:
                self.energy_history.pop(0)

        return rms

    def calculate_zero_crossing_rate(self, audio_data):
        """Compute the zero-crossing rate"""
        if len(audio_data) == 0:
            return 0

        audio_array = np.frombuffer(audio_data, dtype=np.int16)
        zero_crossings = np.sum(np.diff(np.sign(audio_array)) != 0)
        zcr = zero_crossings / len(audio_array) * self.RATE
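        # zcr is expressed in sign changes per second; for example, 200 sign
        # changes in a 1024-sample chunk at 16 kHz gives 200 / 1024 * 16000 = 3125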

        self.zcr_history.append(zcr)
        if len(self.zcr_history) > 30:
            self.zcr_history.pop(0)

        return zcr

    def is_voice_active(self, energy, zcr):
        """Voice activity detection based on the zero-crossing rate"""
        # Typical speech ZCR range at a 16000 Hz sample rate
        zcr_condition = 2400 < zcr < 12000
        return zcr_condition

    def save_recording(self, audio_data, filename=None):
        """Save the recording"""
        if filename is None:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            filename = f"recording_{timestamp}.wav"

        try:
            with wave.open(filename, 'wb') as wf:
                wf.setnchannels(self.CHANNELS)
                wf.setsampwidth(self.audio.get_sample_size(self.FORMAT))
                wf.setframerate(self.RATE)
                wf.writeframes(audio_data)

            print(f"✅ Recording saved: {filename}")
            return True, filename
        except Exception as e:
            print(f"❌ Failed to save the recording: {e}")
            return False, None

    def play_audio(self, filename):
        """Play an audio file"""
        try:
            # Stop any recording in progress
            if self.recording:
                self.recording = False
                self.recorded_frames = []

            # Close the input stream
            if self.stream:
                self.stream.stop_stream()
                self.stream.close()
                self.stream = None

            self.is_playing = True
            time.sleep(0.2)

            # Use the system player
            print(f"🔊 Playing: {filename}")
            subprocess.run(['aplay', filename], check=True)
            print("✅ Playback finished")

        except Exception as e:
            print(f"❌ Playback failed: {e}")
        finally:
            self.is_playing = False
            time.sleep(0.2)
            self._setup_audio()

    def update_pre_record_buffer(self, audio_data):
        """Update the pre-recording buffer"""
        self.pre_record_buffer.append(audio_data)
        if len(self.pre_record_buffer) > self.pre_record_max_frames:
            self.pre_record_buffer.pop(0)

    def start_recording(self):
        """Start recording"""
        print("🎙️ Sound detected, recording started...")
        self.recording = True
        self.recorded_frames = []
        self.recorded_frames.extend(self.pre_record_buffer)
        self.pre_record_buffer = []
        self.recording_start_time = time.time()
        self.last_sound_time = time.time()
        self.consecutive_low_zcr_count = 0

    def stop_recording(self):
        """Stop recording"""
        if len(self.recorded_frames) > 0:
            audio_data = b''.join(self.recorded_frames)
            duration = len(audio_data) / (self.RATE * 2)

            print(f"📝 Recording finished, duration: {duration:.2f}s")

            if self.enable_ai_chat:
                # AI chat mode
                self.process_with_ai(audio_data)
            else:
                # Plain recording mode
                success, filename = self.save_recording(audio_data)
                if success and filename:
                    print("=" * 50)
                    print("🔊 Playing back the recording...")
                    self.play_audio(filename)
                    print("=" * 50)

        self.recording = False
        self.recorded_frames = []
        self.recording_start_time = None
        self.last_sound_time = None

    def process_with_ai(self, audio_data):
        """Process the recording with the AI"""
        if self.is_processing_ai:
            print("⏳ The AI is still processing, please wait...")
            return

        self.is_processing_ai = True

        # Handle the AI work in a separate thread
        ai_thread = threading.Thread(target=self._ai_processing_thread, args=(audio_data,))
        ai_thread.daemon = True
        ai_thread.start()

    def _heartbeat_thread(self):
        """Heartbeat thread - periodically sends silence to keep the connection alive"""
        while self.running and self.doubao_client and self.doubao_client.ws:
            current_time = time.time()
            if current_time - self.last_heartbeat_time >= self.heartbeat_interval:
                try:
                    # Send the heartbeat data asynchronously
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                    try:
                        loop.run_until_complete(self.doubao_client.send_silence_data())
                        self.last_heartbeat_time = current_time
                    except Exception as e:
                        print(f"❌ Heartbeat failed: {e}")
                        # If the heartbeat fails we may need to reconnect
                        break
                    finally:
                        loop.close()
                except Exception as e:
                    print(f"❌ Heartbeat thread error: {e}")
                    break

            # Sleep for a while
            time.sleep(1.0)

        print("📡 Heartbeat thread finished")

    def _ai_processing_thread(self, audio_data):
        """AI processing thread"""
        try:
            print("🤖 Starting AI processing...")
            print("🧠 Running speech recognition...")

            # Run the async work on a dedicated event loop
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

            try:
                # Connect to Doubao
                self.doubao_client = DoubaoClient()
                loop.run_until_complete(self.doubao_client.connect())

                # Start the heartbeat thread
                self.last_heartbeat_time = time.time()
                self.heartbeat_thread = threading.Thread(target=self._heartbeat_thread)
                self.heartbeat_thread.daemon = True
                self.heartbeat_thread.start()
                print("💓 Heartbeat thread started")

                # Speech recognition and TTS reply
                recognized_text, tts_audio = loop.run_until_complete(
                    self.doubao_client.process_audio(audio_data)
                )

                if recognized_text:
                    print(f"🗣️ You said: {recognized_text}")

                if tts_audio:
                    # Save the TTS audio
                    tts_filename = "ai_response.wav"
                    with wave.open(tts_filename, 'wb') as wav_file:
                        wav_file.setnchannels(1)
                        wav_file.setsampwidth(2)
                        wav_file.setframerate(24000)
                        wav_file.writeframes(tts_audio)

                    print("🎵 AI reply generated")
                    print("=" * 50)
                    print("🔊 Playing the AI reply...")
                    self.play_audio(tts_filename)
                    print("=" * 50)
                else:
                    print("❌ No AI reply received")

                # Wait a while before closing so the heartbeat keeps working
                print("⏳ Waiting 5 seconds before closing the connection...")
                time.sleep(5)

            except Exception as e:
                print(f"❌ AI processing failed: {e}")
            finally:
                # Stop the heartbeat thread
                if self.heartbeat_thread and self.heartbeat_thread.is_alive():
                    print("🛑 Stopping the heartbeat thread")
                    self.heartbeat_thread = None

                # Close the connection
                if self.doubao_client:
                    loop.run_until_complete(self.doubao_client.close())

                loop.close()

        except Exception as e:
            print(f"❌ AI processing thread failed: {e}")
        finally:
            self.is_processing_ai = False

    def run(self):
        """Run the voice chat system"""
        if not self.stream:
            print("❌ Audio device not initialized")
            return

        self.running = True

        if self.enable_ai_chat:
            print("🤖 Voice chat AI assistant")
            print("=" * 50)
            print("🎯 Features:")
            print("- 🎙️ Smart voice detection")
            print("- 🧠 Doubao AI speech recognition")
            print("- 🗣️ Intelligent AI replies")
            print("- 🔊 TTS speech playback")
            print("- 🔄 Real-time conversation")
            print("=" * 50)
            print("📖 How to use:")
            print("- Recording starts automatically when you speak")
            print("- Recording stops after 2 seconds of silence")
            print("- The AI recognizes your speech and replies")
            print("- Press Ctrl+C to quit")
            print("=" * 50)
        else:
            print("🎙️ Smart recording system")
            print("=" * 50)
            print("📖 How to use:")
            print("- Recording starts automatically when you speak")
            print("- Recording stops after 2 seconds of silence")
            print("- The recording is played back automatically")
            print("- Press Ctrl+C to quit")
            print("=" * 50)

        try:
            while self.running:
                # Skip audio processing while the AI reply is playing or being processed
                if self.is_playing or self.is_processing_ai:
                    status = "🤖 AI processing..."
                    print(f"\r{status}", end='', flush=True)
                    time.sleep(0.1)
                    continue

                # Read audio data
                data = self.stream.read(self.CHUNK_SIZE, exception_on_overflow=False)

                if len(data) == 0:
                    continue

                # Compute energy and ZCR
                energy = self.calculate_energy(data)
                zcr = self.calculate_zero_crossing_rate(data)

                if self.recording:
                    # Recording mode
                    self.recorded_frames.append(data)
                    recording_duration = time.time() - self.recording_start_time

                    # Check for voice activity
                    if self.is_voice_active(energy, zcr):
                        self.last_sound_time = time.time()
                        self.consecutive_low_zcr_count = 0
                    else:
                        self.consecutive_low_zcr_count += 1

                    # Decide whether recording should stop
                    should_stop = False

                    # ZCR-based silence detection
                    if self.consecutive_low_zcr_count >= self.low_zcr_threshold_count:
                        should_stop = True
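                        # 15 consecutive non-voice chunks of 1024 samples at
                        # 16 kHz corresponds to roughly 0.96 s of silence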

                    # Time-based silence detection
                    if not should_stop and time.time() - self.last_sound_time > self.silence_threshold:
                        should_stop = True

                    # Actually stop the recording
                    if should_stop and recording_duration >= self.min_recording_time:
                        print("\n🔇 Silence detected, stopping the recording")
                        self.stop_recording()

                    # Check the maximum recording time
                    if recording_duration > self.max_recording_time:
                        print("\n⏰ Maximum recording time reached")
                        self.stop_recording()

                    # Show the recording status
                    is_voice = self.is_voice_active(energy, zcr)
                    zcr_count = f"{self.consecutive_low_zcr_count}/{self.low_zcr_threshold_count}"
                    status = f"Recording... {recording_duration:.1f}s | ZCR: {zcr:.0f} | voice: {is_voice} | silence count: {zcr_count}"
                    print(f"\r{status}", end='', flush=True)

                else:
                    # Listening mode
                    self.update_pre_record_buffer(data)

                    if self.is_voice_active(energy, zcr):
                        # Sound detected, start recording
                        self.start_recording()
                    else:
                        # Show the listening status
                        is_voice = self.is_voice_active(energy, zcr)
                        buffer_usage = len(self.pre_record_buffer) / self.pre_record_max_frames * 100
                        status = f"Listening... ZCR: {zcr:.0f} | voice: {is_voice} | buffer: {buffer_usage:.0f}%"
                        print(f"\r{status}", end='', flush=True)

                time.sleep(0.01)

        except KeyboardInterrupt:
            print("\n👋 Exiting")
        except Exception as e:
            print(f"❌ Error: {e}")
        finally:
            self.stop()

    def stop(self):
        """Stop the system"""
        self.running = False

        # Stop the heartbeat thread
        if self.heartbeat_thread and self.heartbeat_thread.is_alive():
            print("🛑 Stopping the heartbeat thread")
            self.heartbeat_thread = None

        if self.recording:
            self.stop_recording()

        if self.stream:
            self.stream.stop_stream()
            self.stream.close()

        if self.audio:
            self.audio.terminate()

        # Close the AI connection
        if self.doubao_client and self.doubao_client.ws:
            try:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                loop.run_until_complete(self.doubao_client.close())
                loop.close()
            except Exception:
                pass

def main():
    """Main entry point"""
    import argparse

    parser = argparse.ArgumentParser(description='Voice chat AI assistant')
    parser.add_argument('--no-ai', action='store_true', help='Disable the AI features and only record')
    args = parser.parse_args()

    enable_ai = not args.no_ai

    if enable_ai:
        print("🚀 Voice chat AI assistant")
    else:
        print("🚀 Smart recording system")

    print("=" * 50)

    # Create the voice chat system
    recorder = VoiceChatRecorder(enable_ai_chat=enable_ai)

    print("✅ System initialized")
    print("=" * 50)

    # Start running
    recorder.run()

if __name__ == "__main__":
    main()