# Listing metadata: 774 lines, 36 KiB, Python
import asyncio
|
||
import queue
|
||
import random
|
||
import signal
|
||
import sys
|
||
import threading
|
||
import time
|
||
import uuid
|
||
import wave
|
||
from dataclasses import dataclass
|
||
from typing import Any, Dict, Optional
|
||
|
||
import config
|
||
import sounddevice as sd
|
||
import numpy as np
|
||
from realtime_dialog_client import RealtimeDialogClient
|
||
|
||
|
||
@dataclass()
class AudioConfig:
    """Parameters describing one audio stream.

    Built from the plain dicts in the ``config`` module via
    ``AudioConfig(**cfg)``; field order is therefore also the positional
    constructor order.
    """

    # Encoding name, e.g. "pcm" or "pcm_s16le".
    format: str
    # Sample-width identifier kept as a string (e.g. "int16").
    bit_size: str
    # Number of interleaved channels.
    channels: int
    # Sampling rate in Hz.
    sample_rate: int
    # Block size handed to the audio backend, in frames.
    chunk: int
|
||
|
||
|
||
class AudioDeviceManager:
    """Own the microphone input stream and the speaker output stream.

    Microphone blocks are pushed by a sounddevice callback into a bounded
    queue and pulled with :meth:`read_audio_data`.  Server audio is played
    with :meth:`play_audio`, which writes to the stream opened by
    :meth:`open_output_stream` when one is available.
    """

    # Sample formats accepted for the output stream; anything else in the
    # config's ``bit_size`` falls back to int16 (the previous hard-coded value).
    _SUPPORTED_OUTPUT_DTYPES = ('float32', 'int32', 'int16', 'int8', 'uint8')

    def __init__(self, input_config: "AudioConfig", output_config: "AudioConfig"):
        self.input_config = input_config
        self.output_config = output_config
        self.input_stream = None
        self.output_stream = None
        self.audio_queue = None  # created by open_input_stream()
        self.recording = False

    @staticmethod
    def _resolve_dtype(bit_size: Any, default: str = 'int16') -> np.dtype:
        """Map a config ``bit_size`` value to a numpy dtype.

        Returns ``np.dtype(bit_size)`` when it names one of the supported
        native formats (e.g. ``"int16"``, ``"float32"``), otherwise
        ``np.dtype(default)``.
        """
        resolved = None
        if bit_size is not None:  # np.dtype(None) would silently mean float64
            try:
                resolved = np.dtype(bit_size)
            except (TypeError, ValueError):
                resolved = None
        if (resolved is not None and resolved.isnative
                and resolved.name in AudioDeviceManager._SUPPORTED_OUTPUT_DTYPES):
            return resolved
        return np.dtype(default)

    @staticmethod
    def _safe_call(func) -> None:
        """Run one cleanup step, printing (not raising) any failure."""
        try:
            func()
        except Exception as e:
            print(f"清理音频设备失败: {e}")

    def open_input_stream(self):
        """Open and start the microphone stream (16-bit PCM, default device).

        Returns the started stream, or ``None`` if opening/starting failed.
        """
        stream = None
        try:
            self.audio_queue = queue.Queue(maxsize=100)  # bounded: drop data rather than grow

            def audio_callback(indata, frames, time_info, status):
                """sounddevice callback: enqueue each captured block as raw bytes."""
                if status:
                    print(f"音频流状态: {status}")
                if self.recording and self.audio_queue is not None:
                    try:
                        self.audio_queue.put_nowait(indata.tobytes())
                    except queue.Full:
                        print("警告: 音频队列已满,丢弃数据")

            stream = sd.InputStream(
                samplerate=self.input_config.sample_rate,
                channels=self.input_config.channels,
                dtype='int16',  # the rest of the pipeline assumes 16-bit PCM input
                blocksize=self.input_config.chunk,
                callback=audio_callback,
                device=None  # default device
            )
            self.input_stream = stream
            stream.start()
            self.recording = True
            return stream
        except Exception as e:
            print(f"打开输入流失败: {e}")
            if stream is not None:
                # Don't leave a constructed-but-unstarted stream behind.
                self._safe_call(stream.close)
                self.input_stream = None
            return None

    def open_output_stream(self):
        """Open and start the speaker stream.

        The sample format comes from ``output_config.bit_size`` (int16 when
        unrecognised).  Returns the started stream, or ``None`` on failure.
        """
        stream = None
        try:
            stream = sd.OutputStream(
                samplerate=self.output_config.sample_rate,
                channels=self.output_config.channels,
                dtype=self._resolve_dtype(self.output_config.bit_size).name,
                blocksize=self.output_config.chunk,
                device=None  # default device
            )
            stream.start()
        except Exception as e:
            print(f"打开输出流失败: {e}")
            if stream is not None:
                self._safe_call(stream.close)
            self.output_stream = None
            return None
        self.output_stream = stream
        return stream

    def play_audio(self, audio_data: bytes) -> None:
        """Play raw PCM bytes, blocking until they are handed to the device.

        The bytes are interpreted with the output dtype and channel count;
        a trailing partial frame is dropped.  Uses the open output stream
        when present, otherwise falls back to ``sd.play`` + ``sd.wait``.
        Errors are printed, not raised.
        """
        try:
            dtype = self._resolve_dtype(self.output_config.bit_size)
            channels = self.output_config.channels
            frame_bytes = dtype.itemsize * channels
            usable = len(audio_data) - len(audio_data) % frame_bytes
            if usable == 0:
                return
            samples = np.frombuffer(audio_data, dtype=dtype, count=usable // dtype.itemsize)
            samples = samples.reshape(-1, channels)

            stream = self.output_stream  # local copy: cleanup() may reset the attribute
            if stream is not None:
                stream.write(samples)
            else:
                sd.play(samples, samplerate=self.output_config.sample_rate)
                sd.wait()
        except Exception as e:
            print(f"音频播放失败: {e}")

    def read_audio_data(self, frames: int) -> bytes:
        """Return the next captured microphone block.

        Waits up to 100 ms for data.  When not recording, or when nothing
        arrives in time, returns silence sized for *frames* 16-bit frames
        across all input channels.
        """
        try:
            silence = b'\x00' * (frames * 2 * self.input_config.channels)
            if not self.recording or self.audio_queue is None:
                return silence
            try:
                return self.audio_queue.get(timeout=0.1)
            except queue.Empty:
                return silence
        except Exception as e:
            print(f"读取音频数据失败: {e}")
            return b'\x00' * (frames * 2)

    def stop_recording(self):
        """Stop enqueuing microphone data (the stream itself stays open)."""
        self.recording = False

    def cleanup(self) -> None:
        """Stop recording and close both streams; safe to call more than once.

        Each step is attempted even if an earlier one fails.
        """
        self.recording = False
        for attr in ('input_stream', 'output_stream'):
            stream = getattr(self, attr)
            if stream is None:
                continue
            setattr(self, attr, None)
            self._safe_call(stream.stop)
            self._safe_call(stream.close)
        self._safe_call(lambda: sd.stop())  # stop any sd.play() fallback playback
|
||
|
||
|
||
class DialogSession:
    """Run one realtime dialog session on top of ``RealtimeDialogClient``.

    Modes (``self.mod``): ``"audio"`` streams the microphone, ``"text"``
    forwards lines typed on stdin, and ``"audio_file"`` (selected whenever
    ``audio_file_path`` is non-empty) streams a WAV file.  Outside file mode
    a daemon thread plays server audio from ``audio_queue``.  A set of flags
    shared with the asyncio side decides when microphone audio is replaced
    by silence, to keep playback from echoing back to the server.

    Locking rule: ``audio_queue_lock`` is a ``threading.Lock`` shared with
    the player thread.  Never hold it across an ``await``.  A second
    coroutine on the same event-loop thread that tries to take it would
    block the whole loop, while the holder can never resume to release it.
    """

    is_audio_file_input: bool
    mod: str

    def __init__(self, ws_config: Dict[str, Any], output_audio_format: str = "pcm", audio_file_path: str = "",
                 mod: str = "audio", recv_timeout: int = 10):
        self.audio_file_path = audio_file_path
        self.recv_timeout = recv_timeout
        self.is_audio_file_input = self.audio_file_path != ""
        if self.is_audio_file_input:
            mod = 'audio_file'
        else:
            self.say_hello_over_event = asyncio.Event()
        self.mod = mod

        self.session_id = str(uuid.uuid4())
        self.client = RealtimeDialogClient(config=ws_config, session_id=self.session_id,
                                           output_audio_format=output_audio_format, mod=mod,
                                           recv_timeout=recv_timeout)
        if output_audio_format == "pcm_s16le":
            # Make local playback settings match the requested server format.
            config.output_audio_config["format"] = "pcm_s16le"
            config.output_audio_config["bit_size"] = "int16"

        self.is_running = True
        self.is_session_finished = False
        self.is_user_querying = False
        self.is_sending_chat_tts_text = False
        # Every received audio chunk, saved to output.pcm at the end of start().
        # bytearray keeps appends amortised O(1); bytes += would copy the whole buffer each time.
        self.audio_buffer = bytearray()

        # --- playback / echo-suppression state (guarded by audio_queue_lock) ---
        self.is_playing_audio = False
        self.audio_queue_lock = threading.Lock()
        self.is_recording_paused = False
        self.should_send_silence = False  # request for the mic loop to send silence_send_count chunks
        self.silence_send_count = 0
        self.pre_pause_time = 0  # written on pre-pause; not read anywhere in this class
        self.last_recording_state = False
        self.say_hello_completed = False
        self.input_stream_paused = False
        self.force_silence_mode = False
        self.echo_suppression_start_time = 0

        # Bookkeeping that was previously created lazily and probed with hasattr().
        self.last_audio_time: Optional[float] = None
        self.last_silence_debug_time: Optional[float] = None
        self.last_silence_log_time: Optional[float] = None
        self.pause_discard_count = 0

        signal.signal(signal.SIGINT, self._keyboard_signal)
        # Server audio chunks waiting for the player thread.
        self.audio_queue = queue.Queue()
        if not self.is_audio_file_input:
            self.audio_device = AudioDeviceManager(
                AudioConfig(**config.input_audio_config),
                AudioConfig(**config.output_audio_config)
            )
            print(f"输出音频配置: {config.output_audio_config}")
            output_stream = self.audio_device.open_output_stream()
            if output_stream:
                print("音频输出流已打开")
                self.output_stream = output_stream
            else:
                print("警告:音频输出流打开失败,将使用直接播放模式")
            self.is_recording = True
            self.is_playing = True
            self.player_thread = threading.Thread(target=self._audio_player_thread)
            self.player_thread.daemon = True
            self.player_thread.start()

    def _audio_player_thread(self):
        """Player thread: play queued server audio until ``self.is_playing`` is cleared.

        The first chunk of a playback burst is preceded by three silent
        chunks of the same length.  Playback counts as finished after one
        second with no new chunk (see ``_on_player_idle``).
        """
        audio_playing_timeout = 1.0  # seconds without audio => playback finished
        queue_check_interval = 0.1  # poll interval for the playback queue

        while self.is_playing:
            try:
                audio_data = self.audio_queue.get(timeout=queue_check_interval)
                if audio_data is None:
                    continue

                with self.audio_queue_lock:
                    # Third safeguard: confirm recording is paused as playback starts.
                    was_not_playing = not self.is_playing_audio
                    if self.last_audio_time is None or was_not_playing:
                        self.is_playing_audio = True
                        if not self.is_recording_paused:
                            self.is_recording_paused = True
                            print("播放开始,最终确认暂停录音")
                    self.last_audio_time = time.time()

                if was_not_playing:
                    print("播放开始前,额外发送静音数据清理管道")
                    silence = b'\x00' * len(audio_data)
                    for _ in range(3):
                        self.audio_device.play_audio(silence)
                        time.sleep(0.1)

                try:
                    self.audio_device.play_audio(audio_data)
                except Exception as e:
                    print(f"音频播放错误: {e}")

            except queue.Empty:
                self._on_player_idle(audio_playing_timeout)
                time.sleep(0.01)
            except Exception as e:
                print(f"音频播放错误: {e}")
                with self.audio_queue_lock:
                    self.is_playing_audio = False
                    self.is_recording_paused = False
                time.sleep(0.1)

    def _on_player_idle(self, audio_playing_timeout: float) -> None:
        """Handle one empty poll of the playback queue.

        If playback was active and no chunk arrived for more than
        *audio_playing_timeout* seconds, clear the pause/echo flags, mark the
        greeting as played, and ask the microphone loop to send 2 silence chunks.
        """
        current_time = time.time()
        with self.audio_queue_lock:
            if not self.is_playing_audio or self.last_audio_time is None:
                return
            if current_time - self.last_audio_time <= audio_playing_timeout:
                return
            self.is_playing_audio = False
            self.is_recording_paused = False
            self.force_silence_mode = False
            self.input_stream_paused = False
            if not self.say_hello_completed:
                self.say_hello_completed = True
                print("say hello 音频播放完成")
            print("音频播放超时,恢复录音")
            # Handled by process_microphone_input on its next iteration.
            self.silence_send_count = 2
            self.should_send_silence = True

    async def _send_silence_on_playback_end(self):
        """Send one chunk of silence to the server (not called within this class)."""
        try:
            silence_data = b'\x00' * config.input_audio_config["chunk"]
            await self.client.task_request(silence_data)
            print("播放结束,已发送静音数据")
        except Exception as e:
            print(f"发送静音数据失败: {e}")

    def _check_and_restore_recording(self):
        """Un-pause recording if the playback queue is empty; return whether it did."""
        with self.audio_queue_lock:
            if self.is_recording_paused and self.audio_queue.empty():
                self.is_recording_paused = False
                self.is_playing_audio = False
                print("音频队列为空,自动恢复录音")
                return True
        return False

    def _drain_audio_queue(self) -> None:
        """Discard every chunk currently waiting for playback."""
        while not self.audio_queue.empty():
            try:
                self.audio_queue.get_nowait()
            except queue.Empty:
                continue

    def handle_server_response(self, response: Dict[str, Any]) -> None:
        """Dispatch one parsed server message.

        * ``SERVER_ACK`` with a bytes payload: an audio chunk.  It is queued
          for playback (except in file mode) and appended to ``audio_buffer``.
        * ``SERVER_FULL_RESPONSE``: an event message (see ``_handle_event``).
        * ``SERVER_ERROR``: printed, then ``Exception("服务器错误")`` is raised.

        Empty or missing responses are ignored.
        """
        if not response:
            return

        message_type = response.get('message_type')
        if message_type == 'SERVER_ACK' and isinstance(response.get('payload_msg'), bytes):
            self._handle_audio_chunk(response['payload_msg'])
        elif message_type == 'SERVER_FULL_RESPONSE':
            self._handle_event(response)
        elif message_type == 'SERVER_ERROR':
            print(f"服务器错误: {response.get('payload_msg')}")
            raise Exception("服务器错误")

    def _handle_audio_chunk(self, audio_data: bytes) -> None:
        """Pause recording on the first chunk of a burst, then queue/store the chunk."""
        if self.is_sending_chat_tts_text:
            return

        # Second safeguard: audio is arriving, so make sure recording is paused.
        with self.audio_queue_lock:
            if not self.is_playing_audio:
                self.is_playing_audio = True
                if not self.is_recording_paused:
                    self.is_recording_paused = True
                    print("接收到首批音频数据,立即暂停录音")
                else:
                    print("接收到音频数据,录音已暂停")
                # Ask the mic loop to flush the pipeline with 3 silence chunks.
                self.silence_send_count = 3
                self.should_send_silence = True
                print("服务器收到音频数据,立即清理录音管道")

        if not self.is_audio_file_input:
            self.audio_queue.put(audio_data)
        self.audio_buffer += audio_data

    def _handle_event(self, response: Dict[str, Any]) -> None:
        """Handle a ``SERVER_FULL_RESPONSE`` event message."""
        print(f"服务器响应: {response}")
        event = response.get('event')
        payload_msg = response.get('payload_msg')
        if not isinstance(payload_msg, dict):
            payload_msg = {}

        # First safeguard: pre-pause recording on these events.  152/153 end
        # the session and 359 ends TTS (see receive_loop).
        # NOTE(review): on 450 this code also drops queued playback, sets
        # is_user_querying and logs that the server is ready for user input.
        # If 450 actually marks the start of the user's speech, muting the
        # mic here (plus the 3 s echo window) may clip the user. Confirm
        # against the server's event specification.
        if event in (450, 359, 152, 153):
            if event == 450:
                print(f"清空缓存音频: {response.get('session_id')}")
                self._drain_audio_queue()
                self.is_user_querying = True
                print("服务器准备接收用户输入")

            with self.audio_queue_lock:
                if not self.is_recording_paused:
                    self.is_recording_paused = True
                    self.is_playing_audio = True
                    self.pre_pause_time = time.time() - 2.0
                    self.force_silence_mode = True
                    self.echo_suppression_start_time = time.time()  # starts the 3 s echo window
                    print("服务器开始响应,预暂停录音防止回声")
                    print("预暂停期间立即发送静音数据清理管道")
                    self.silence_send_count = 20
                    self.should_send_silence = True
                    self.last_recording_state = True
                    self.input_stream_paused = True

        if event == 350 and self.is_sending_chat_tts_text and payload_msg.get("tts_type") in ["chat_tts_text", "external_rag"]:
            self._drain_audio_queue()
            self.is_sending_chat_tts_text = False

        if event == 459:
            self.is_user_querying = False
            with self.audio_queue_lock:
                was_paused = self.is_recording_paused
                self.is_recording_paused = False
                self.is_playing_audio = False
                self.force_silence_mode = False
                self.input_stream_paused = False
                if was_paused:
                    print("服务器响应完成,立即恢复录音")
                    self.silence_send_count = 2
                    self.should_send_silence = True
            print("服务器完成响应,等待用户输入")
            # Demo hook (disabled): inject filler speech plus an external RAG answer.
            # if random.randint(0, 100000) % 1 == 0:
            #     self.is_sending_chat_tts_text = True
            #     asyncio.create_task(self.trigger_chat_tts_text())
            #     asyncio.create_task(self.trigger_chat_rag_text())

    async def trigger_chat_tts_text(self):
        """Send a fixed filler sentence as a two-part ChatTTSText request."""
        print("hit ChatTTSText event, start sending...")
        await self.client.chat_tts_text(
            is_user_querying=self.is_user_querying,
            start=True,
            end=False,
            content="这是查询到外部数据之前的安抚话术。",
        )
        await self.client.chat_tts_text(
            is_user_querying=self.is_user_querying,
            start=False,
            end=True,
            content="",
        )

    async def trigger_chat_rag_text(self):
        """After a simulated 5 s lookup, send a canned external-RAG result."""
        # The 5 s sleep stands in for the external lookup and lets the filler speech finish first.
        await asyncio.sleep(5)
        print("hit ChatRAGText event, start sending...")
        # The payload is a JSON array, so newlines inside the string value are
        # escaped (\\n); raw control characters are invalid in JSON strings.
        await self.client.chat_rag_text(self.is_user_querying, external_rag='[{"title":"北京天气","content":"今天北京整体以晴到多云为主,但西部和北部地带可能会出现分散性雷阵雨,特别是午后至傍晚时段需注意突发降雨。\\n💨 风况与湿度\\n风力较弱,一般为 2–3 级南风或西南风\\n白天湿度较高,早晚略凉爽"}]')

    def _keyboard_signal(self, sig, frame):
        """SIGINT handler: request a graceful stop."""
        print("receive keyboard Ctrl+C")
        self.stop()

    def stop(self):
        """Clear the running flags polled by every loop in this session."""
        self.is_recording = False
        self.is_playing = False
        self.is_running = False

    async def receive_loop(self):
        """Receive and dispatch server messages until the session ends.

        Always calls ``stop()`` and marks the session finished on exit.
        """
        try:
            while True:
                response = await self.client.receive_server_response()
                self.handle_server_response(response)
                event = response.get('event')
                if event in (152, 153):
                    print(f"receive session finished event: {event}")
                    self.is_session_finished = True
                    break
                if event == 359:
                    if self.is_audio_file_input:
                        print("receive tts ended event")
                        self.is_session_finished = True
                        break
                    self._on_tts_ended()
        except asyncio.CancelledError:
            print("接收任务已取消")
        except Exception as e:
            print(f"接收消息错误: {e}")
        finally:
            self.stop()
            self.is_session_finished = True

    def _on_tts_ended(self) -> None:
        """Handle event 359 in microphone/text mode."""
        if not self.say_hello_over_event.is_set():
            print("receive tts sayhello ended event")
            self.say_hello_over_event.set()
            if self.mod == "audio":
                # Greeting audio is about to play: keep recording paused.
                with self.audio_queue_lock:
                    self.is_recording_paused = True
                    self.is_playing_audio = True
                print("say hello 音频即将开始,确保录音暂停")
        if self.mod == "text":
            with self.audio_queue_lock:
                if self.is_recording_paused:
                    self.is_recording_paused = False
                    print("文本模式:say hello 完成,恢复录音")
            print("请输入内容:")

    async def process_audio_file(self) -> None:
        """Stream ``self.audio_file_path`` to the server."""
        await self.process_audio_file_input(self.audio_file_path)

    async def _warm_up_and_say_hello(self, log_prefix: str = "") -> None:
        """Send ~2 s of silence with recording paused, then request the greeting and wait for it.

        *log_prefix* is prepended to every progress message.
        """
        print(f"{log_prefix}程序启动,先静音2秒确保系统稳定...")
        with self.audio_queue_lock:
            self.is_recording_paused = True
            self.is_playing_audio = True

        silence_data = b'\x00' * config.input_audio_config["chunk"]
        for i in range(20):  # 20 x 100 ms = 2 s
            await self.client.task_request(silence_data)
            await asyncio.sleep(0.1)
            if i % 10 == 0:
                print(f"{log_prefix}静音中... {i//10 + 1}/2秒")
        print(f"{log_prefix}静音完成,准备 say hello")

        with self.audio_queue_lock:
            self.is_recording_paused = True
            self.is_playing_audio = True
        print(f"{log_prefix}准备 say hello,确保录音暂停")

        await self.client.say_hello()
        await self.say_hello_over_event.wait()

    async def process_text_input(self) -> None:
        """Text mode: warm up, greet, then forward non-empty stdin lines as text queries."""
        await self._warm_up_and_say_hello("文本模式:")
        try:
            input_queue = queue.Queue()
            input_thread = threading.Thread(target=self.input_listener, args=(input_queue,), daemon=True)
            input_thread.start()
            while self.is_running:
                try:
                    input_str = input_queue.get_nowait()
                    if input_str is None:  # stdin closed
                        print("Input channel closed")
                        break
                    if input_str:
                        await self.client.chat_text_query(input_str)
                except queue.Empty:
                    await asyncio.sleep(0.1)
                except Exception as e:
                    print(f"Main loop error: {e}")
                    break
        finally:
            print("exit text input")

    def input_listener(self, input_queue: queue.Queue) -> None:
        """Thread body: push stripped stdin lines to *input_queue*; ``None`` marks EOF/error."""
        print("Start listening for input")
        try:
            while True:
                line = sys.stdin.readline()
                if not line:
                    input_queue.put(None)
                    break
                input_queue.put(line.strip())
        except Exception as e:
            print(f"Input listener error: {e}")
            input_queue.put(None)

    async def process_audio_file_input(self, audio_file_path: str) -> None:
        """Stream a WAV file in ``chunk``-frame pieces, paced at real-time speed."""
        with wave.open(audio_file_path, 'rb') as wf:
            chunk_size = config.input_audio_config["chunk"]
            framerate = wf.getframerate()
            sleep_seconds = chunk_size / framerate  # duration of one chunk
            print(f"开始处理音频文件: {audio_file_path}")

            while True:
                audio_data = wf.readframes(chunk_size)
                if not audio_data:
                    break
                await self.client.task_request(audio_data)
                await asyncio.sleep(sleep_seconds)

            print("音频文件处理完成,等待服务器响应...")

    async def process_silence_audio(self) -> None:
        """Send a single 320-byte chunk of silence."""
        silence_data = b'\x00' * 320
        await self.client.task_request(silence_data)

    def _log_force_silence(self, current_time: float) -> None:
        """Print which conditions keep the microphone muted, at most every 2 s."""
        if self.last_silence_debug_time is not None and current_time - self.last_silence_debug_time <= 2:
            return
        mode_desc = []
        if self.force_silence_mode:
            mode_desc.append("强制静音")
        if self.is_playing_audio:
            mode_desc.append("播放中")
        if not self.say_hello_completed:
            mode_desc.append("say_hello")
        if self.echo_suppression_start_time > 0 and current_time - self.echo_suppression_start_time < 3.0:
            mode_desc.append("回声抑制")
        print(f"强制静音模式: {', '.join(mode_desc)}")
        self.last_silence_debug_time = current_time

    async def process_microphone_input(self) -> None:
        """Microphone mode: stream mic blocks, substituting silence while playback/echo flags are set."""
        self.audio_device.open_input_stream()
        print("已打开麦克风,请讲话...")
        print("音频处理已启动,播放时将发送静音数据避免回声")

        await self._warm_up_and_say_hello()

        # Recording is resumed later by the player thread once greeting playback times out.
        print("say hello 请求完成,等待音频播放结束...")

        chunk_frames = config.input_audio_config["chunk"]
        # NOTE(review): this is chunk bytes, i.e. half a 16-bit chunk; kept as in the original.
        silence_data = b'\x00' * chunk_frames
        last_silence_time = time.time()
        loop = asyncio.get_running_loop()

        while self.is_recording:
            try:
                current_time = time.time()

                # Force-silence check, including the 3 s echo-suppression window.
                with self.audio_queue_lock:
                    should_force_silence = (self.force_silence_mode or
                                            (self.echo_suppression_start_time > 0 and
                                             current_time - self.echo_suppression_start_time < 3.0) or
                                            self.is_playing_audio or
                                            not self.say_hello_completed)
                if should_force_silence:
                    if current_time - last_silence_time > 0.05:  # one silence chunk every 50 ms
                        await self.client.task_request(silence_data)
                        last_silence_time = current_time
                    self._log_force_silence(current_time)
                    await asyncio.sleep(0.01)
                    continue

                # Silence requested by the player thread / message handler.
                if self.should_send_silence:
                    with self.audio_queue_lock:
                        self.should_send_silence = False
                        count = self.silence_send_count
                        self.silence_send_count = 0
                    if count > 1:
                        print(f"立即清理录音管道,批量发送{count}组静音数据")
                        for _ in range(count):
                            await self.client.task_request(silence_data)
                            await asyncio.sleep(0.005)
                    else:
                        await self.client.task_request(silence_data)
                        print("立即清理录音管道,发送静音数据")
                    last_silence_time = current_time
                    await asyncio.sleep(0.01)
                    continue

                with self.audio_queue_lock:
                    should_pause_recording = self.is_recording_paused
                    just_paused = should_pause_recording and not self.last_recording_state
                    self.last_recording_state = should_pause_recording

                if should_pause_recording:
                    # While paused only silence is sent: immediately on entry, then every 100 ms.
                    if just_paused or current_time - last_silence_time > 0.1:
                        await self.client.task_request(silence_data)
                        last_silence_time = current_time
                        if just_paused:
                            print("刚进入暂停状态,立即发送静音数据清理管道")
                        elif self.last_silence_log_time is None or current_time - self.last_silence_log_time > 5:
                            print("正在播放音频,发送静音数据中...")
                            self.last_silence_log_time = current_time
                    await asyncio.sleep(0.01)
                    continue

                last_silence_time = current_time

                # read_audio_data may block up to 100 ms; keep it off the event loop.
                audio_data = await loop.run_in_executor(None, self.audio_device.read_audio_data, chunk_frames)

                # Last line of defence: state may have changed while reading.
                # Read flags under the lock, but release it before any await (see class docstring).
                with self.audio_queue_lock:
                    discard = self.is_recording_paused or self.is_playing_audio
                if discard:
                    save_input_pcm_to_wav(silence_data, "input.pcm")
                    await self.client.task_request(silence_data)
                    if self.pause_discard_count % 50 == 0:
                        print(f"暂停期间丢弃音频数据,发送静音数据 (次数: {self.pause_discard_count + 1})")
                    self.pause_discard_count += 1
                    await asyncio.sleep(0.01)
                    continue

                # NOTE(review): overwrites input.pcm with just this chunk on every iteration.
                save_input_pcm_to_wav(audio_data, "input.pcm")
                await self.client.task_request(audio_data)
                await asyncio.sleep(0.01)
            except Exception as e:
                print(f"读取麦克风数据出错: {e}")
                await asyncio.sleep(0.1)

    async def start(self) -> None:
        """Connect, run the mode-specific tasks until stopped, then close the session.

        Received audio is written to ``output.pcm``.  Audio devices are
        released on exit (except in file mode).
        """
        # The event loop keeps only weak references to tasks; hold strong ones
        # so a running task cannot be garbage-collected mid-execution.
        background_tasks = []
        try:
            await self.client.connect()

            if self.mod == "text":
                background_tasks.append(asyncio.create_task(self.process_text_input()))
                background_tasks.append(asyncio.create_task(self.receive_loop()))
                while self.is_running:
                    await asyncio.sleep(0.1)
            elif self.is_audio_file_input:
                background_tasks.append(asyncio.create_task(self.process_audio_file()))
                await self.receive_loop()
            else:
                background_tasks.append(asyncio.create_task(self.process_microphone_input()))
                background_tasks.append(asyncio.create_task(self.receive_loop()))
                while self.is_running:
                    await asyncio.sleep(0.1)

            await self.client.finish_session()
            while not self.is_session_finished:
                await asyncio.sleep(0.1)
            await self.client.finish_connection()
            await asyncio.sleep(0.1)
            await self.client.close()
            print(f"dialog request logid: {self.client.logid}, chat mod: {self.mod}")
            save_output_to_file(bytes(self.audio_buffer), "output.pcm")
        except Exception as e:
            print(f"会话错误: {e}")
        finally:
            if not self.is_audio_file_input:
                self.audio_device.stop_recording()
                self.audio_device.cleanup()
|
||
|
||
|
||
def save_input_pcm_to_wav(pcm_data: bytes, filename: str) -> None:
    """Write raw 16-bit PCM to *filename* wrapped in a WAV header.

    Channel count and sample rate come from ``config.input_audio_config``;
    the sample width is fixed at 2 bytes.  An existing file is overwritten.
    """
    with wave.open(filename, 'wb') as wav_out:
        audio_cfg = config.input_audio_config
        # (nchannels, sampwidth, framerate, nframes, comptype, compname);
        # nframes is patched by writeframes().
        wav_out.setparams((audio_cfg["channels"], 2, audio_cfg["sample_rate"],
                           0, 'NONE', 'not compressed'))
        wav_out.writeframes(pcm_data)
|
||
|
||
|
||
def save_output_to_file(audio_data: bytes, filename: str) -> None:
    """Dump raw PCM bytes to *filename* verbatim (no header).

    Empty input is reported and nothing is written.  Write failures are
    reported instead of raised.
    """
    if audio_data:
        try:
            with open(filename, 'wb') as out_file:
                out_file.write(audio_data)
        except OSError as err:  # IOError is an alias of OSError
            print(f"Failed to save pcm file: {err}")
    else:
        print("No audio data to save.")
|