# Local-Voice/doubao/audio_manager.py
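"""Audio I/O and dialog session management for the realtime voice dialog demo.

This module wraps PyAudio input/output streams (AudioDeviceManager) and drives a
realtime dialog WebSocket session (DialogSession) in microphone, audio-file, or
text mode. Echo is suppressed by pausing the microphone and sending silence
frames whenever server audio is being received or played back.
"""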

import asyncio
import queue
import random
import signal
import sys
import threading
import time
import uuid
import wave
from dataclasses import dataclass
from typing import Any, Dict, Optional

import config
import pyaudio
from realtime_dialog_client import RealtimeDialogClient

@dataclass
class AudioConfig:
    """Audio configuration data class."""
    format: str
    bit_size: int
    channels: int
    sample_rate: int
    chunk: int
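
# Illustrative example only: the values actually used at runtime come from
# config.input_audio_config / config.output_audio_config in the config module;
# the numbers below are typical speech-capture settings, not the project's real ones.
#
#   mic_config = AudioConfig(
#       format="pcm",              # raw PCM
#       bit_size=pyaudio.paInt16,  # 16-bit samples
#       channels=1,                # mono
#       sample_rate=16000,         # 16 kHz
#       chunk=1600,                # 100 ms of audio per read at 16 kHz
#   )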

class AudioDeviceManager:
    """Audio device manager that handles audio input and output streams."""

    def __init__(self, input_config: AudioConfig, output_config: AudioConfig):
        self.input_config = input_config
        self.output_config = output_config
        self.pyaudio = pyaudio.PyAudio()
        self.input_stream: Optional[pyaudio.Stream] = None
        self.output_stream: Optional[pyaudio.Stream] = None

    def open_input_stream(self) -> pyaudio.Stream:
        """Open the audio input (microphone) stream."""
        self.input_stream = self.pyaudio.open(
            format=self.input_config.bit_size,
            channels=self.input_config.channels,
            rate=self.input_config.sample_rate,
            input=True,
            frames_per_buffer=self.input_config.chunk
        )
        return self.input_stream

    def open_output_stream(self) -> pyaudio.Stream:
        """Open the audio output (playback) stream."""
        self.output_stream = self.pyaudio.open(
            format=self.output_config.bit_size,
            channels=self.output_config.channels,
            rate=self.output_config.sample_rate,
            output=True,
            frames_per_buffer=self.output_config.chunk
        )
        return self.output_stream

    def cleanup(self) -> None:
        """Release audio device resources."""
        for stream in [self.input_stream, self.output_stream]:
            if stream:
                stream.stop_stream()
                stream.close()
        self.pyaudio.terminate()

class DialogSession:
    """Dialog session manager."""
    is_audio_file_input: bool
    mod: str

    def __init__(self, ws_config: Dict[str, Any], output_audio_format: str = "pcm", audio_file_path: str = "",
                 mod: str = "audio", recv_timeout: int = 10):
        self.audio_file_path = audio_file_path
        self.recv_timeout = recv_timeout
        self.is_audio_file_input = self.audio_file_path != ""
        if self.is_audio_file_input:
            mod = 'audio_file'
        else:
            self.say_hello_over_event = asyncio.Event()
        self.mod = mod
        self.session_id = str(uuid.uuid4())
        self.client = RealtimeDialogClient(config=ws_config, session_id=self.session_id,
                                           output_audio_format=output_audio_format, mod=mod, recv_timeout=recv_timeout)
        if output_audio_format == "pcm_s16le":
            config.output_audio_config["format"] = "pcm_s16le"
            config.output_audio_config["bit_size"] = pyaudio.paInt16

        self.is_running = True
        self.is_session_finished = False
        self.is_user_querying = False
        self.is_sending_chat_tts_text = False
        self.audio_buffer = b''
        self.is_playing_audio = False              # whether TTS audio is currently being played
        self.audio_queue_lock = threading.Lock()   # lock guarding the playback/recording state
        self.is_recording_paused = False           # whether microphone recording is paused
        self.should_send_silence = False           # whether silence frames should be sent
        self.silence_send_count = 0                # number of silence frames to send
        self.pre_pause_time = 0                    # timestamp of the pre-emptive pause
        self.last_recording_state = False          # previous recording-pause state
        self.say_hello_completed = False           # whether the "say hello" greeting has finished playing

        signal.signal(signal.SIGINT, self._keyboard_signal)
        self.audio_queue = queue.Queue()
        if not self.is_audio_file_input:
            self.audio_device = AudioDeviceManager(
                AudioConfig(**config.input_audio_config),
                AudioConfig(**config.output_audio_config)
            )
            # Initialize the audio queue and output stream
            print(f"Output audio config: {config.output_audio_config}")
            self.output_stream = self.audio_device.open_output_stream()
            print("Audio output stream opened")
            # Start the playback thread
            self.is_recording = True
            self.is_playing = True
            self.player_thread = threading.Thread(target=self._audio_player_thread)
            self.player_thread.daemon = True
            self.player_thread.start()
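
    # Echo-suppression overview (as implemented in this class): recording is paused and
    # silence frames are sent to the server through three overlapping safeguards --
    #   1) handle_server_response pre-pauses recording when the server starts responding
    #      (SERVER_FULL_RESPONSE events),
    #   2) it confirms the pause when the first SERVER_ACK audio chunk arrives, and
    #   3) _audio_player_thread re-confirms the pause when playback actually starts and
    #      resumes recording about 1 s after the last audio chunk.
    # process_microphone_input performs the actual silence sending in its own loop.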

    def _audio_player_thread(self):
        """Audio playback thread."""
        audio_playing_timeout = 1.0    # playback is considered finished after 1 s without audio data
        queue_check_interval = 0.1     # check the queue every 100 ms
        while self.is_playing:
            try:
                # Fetch audio data from the queue
                audio_data = self.audio_queue.get(timeout=queue_check_interval)
                if audio_data is not None:
                    with self.audio_queue_lock:
                        # Third safeguard: re-confirm the pause state when playback starts
                        if not hasattr(self, 'last_audio_time') or not self.is_playing_audio:
                            # Transition from idle to playing
                            self.is_playing_audio = True
                            # Make sure recording is paused
                            if not self.is_recording_paused:
                                self.is_recording_paused = True
                                print("Playback started, re-confirming recording pause")
                        # Update the timestamp of the last audio chunk
                        self.last_audio_time = time.time()
                    # Play the audio data
                    self.output_stream.write(audio_data)
            except queue.Empty:
                # Queue is empty -- check whether playback has timed out
                current_time = time.time()
                with self.audio_queue_lock:
                    if self.is_playing_audio:
                        if hasattr(self, 'last_audio_time') and current_time - self.last_audio_time > audio_playing_timeout:
                            # No new audio for more than 1 s: playback is considered finished
                            self.is_playing_audio = False
                            self.is_recording_paused = False
                            # Mark the "say hello" greeting as finished
                            if hasattr(self, 'say_hello_completed') and not self.say_hello_completed:
                                self.say_hello_completed = True
                                print("Say-hello audio playback finished")
                            print("Audio playback timed out, resuming recording")
                            # Send silence directly instead of from a coroutine here:
                            try:
                                silence_data = b'\x00' * config.input_audio_config["chunk"]
                                # Send the silence synchronously: set flags and let the main loop do the sending
                                self.silence_send_count = 2    # send 2 silence frames when playback times out
                                self.should_send_silence = True
                            except Exception as e:
                                print(f"Failed to prepare silence data: {e}")
                    elif self.audio_queue.empty():
                        # Queue empty but not timed out yet -- keep waiting
                        pass
                time.sleep(0.01)
            except Exception as e:
                print(f"Audio playback error: {e}")
                with self.audio_queue_lock:
                    self.is_playing_audio = False
                    self.is_recording_paused = False
                time.sleep(0.1)

    # The silence-detection helper was removed to avoid interfering with normal audio processing.

    async def _send_silence_on_playback_end(self):
        """Send silence data when playback ends."""
        try:
            silence_data = b'\x00' * config.input_audio_config["chunk"]
            await self.client.task_request(silence_data)
            print("Playback ended, silence data sent")
        except Exception as e:
            print(f"Failed to send silence data: {e}")

    def _check_and_restore_recording(self):
        """Check the playback queue and restore the recording state if appropriate."""
        with self.audio_queue_lock:
            if self.is_recording_paused and self.audio_queue.empty():
                # Queue drained while recording is paused: resume recording
                self.is_recording_paused = False
                self.is_playing_audio = False
                print("Audio queue empty, automatically resuming recording")
                return True
        return False
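
    # Server event codes, as inferred from the handlers below (the authoritative list
    # lives in the realtime dialog protocol, not in this file):
    #   450       -- server starts handling a user query (buffered audio is cleared)
    #   350       -- TTS type info (used to detect chat_tts_text / external_rag audio)
    #   359       -- TTS finished (also marks the end of the "say hello" greeting)
    #   459       -- server finished its response; recording resumes
    #   152 / 153 -- session finished events (receive_loop exits)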

    def handle_server_response(self, response: Dict[str, Any]) -> None:
        """Handle a server response."""
        if not response or response == {}:
            return
        message_type = response.get('message_type')
        if message_type == 'SERVER_ACK' and isinstance(response.get('payload_msg'), bytes):
            if self.is_sending_chat_tts_text:
                return
            audio_data = response['payload_msg']
            # Second safeguard: confirm the pause state as soon as audio data arrives
            with self.audio_queue_lock:
                was_not_playing = not self.is_playing_audio
                if was_not_playing:
                    # First audio chunk arrived -- make sure recording is paused
                    self.is_playing_audio = True
                    if not self.is_recording_paused:
                        self.is_recording_paused = True
                        print("First audio chunk received, pausing recording immediately")
                    else:
                        print("Audio data received, recording already paused")
                    # Send silence immediately to flush the recording pipeline
                    self.silence_send_count = 3    # send 3 silence frames when audio data arrives
                    self.should_send_silence = True
                    print("Server audio received, flushing the recording pipeline immediately")
            if not self.is_audio_file_input:
                self.audio_queue.put(audio_data)
            self.audio_buffer += audio_data
        elif message_type == 'SERVER_FULL_RESPONSE':
            print(f"Server response: {response}")
            event = response.get('event')
            payload_msg = response.get('payload_msg', {})
            # First safeguard: pre-emptively pause recording as soon as the server starts responding
            if event in [450, 359, 152, 153]:    # these events mark the start or end of a server response
                if event == 450:
                    print(f"Clearing buffered audio: {response['session_id']}")
                    while not self.audio_queue.empty():
                        try:
                            self.audio_queue.get_nowait()
                        except queue.Empty:
                            continue
                    self.is_user_querying = True
                    print("Server is ready to receive user input")
                # Pre-emptively pause recording to prevent echo from the incoming audio
                with self.audio_queue_lock:
                    if not self.is_recording_paused:
                        self.is_recording_paused = True
                        self.is_playing_audio = True    # also mark playback active, as a double safeguard
                        self.pre_pause_time = time.time()
                        print("Server started responding, pre-pausing recording to prevent echo")
                        # Send silence immediately to flush the pipeline and prevent echo in the first 1-2 seconds
                        print("Sending silence immediately during the pre-pause to flush the pipeline")
                        # Schedule a batch of silence frames so the pipeline is fully flushed
                        self.silence_send_count = 8    # increased to 8 frames to flush thoroughly
                        self.should_send_silence = True
                        # Force-reset the recording state
                        self.last_recording_state = True    # mark as paused
            if event == 350 and self.is_sending_chat_tts_text and payload_msg.get("tts_type") in ["chat_tts_text", "external_rag"]:
                while not self.audio_queue.empty():
                    try:
                        self.audio_queue.get_nowait()
                    except queue.Empty:
                        continue
                self.is_sending_chat_tts_text = False
            if event == 459:
                self.is_user_querying = False
                # Server finished responding -- resume recording right away
                with self.audio_queue_lock:
                    was_paused = self.is_recording_paused
                    self.is_recording_paused = False
                    self.is_playing_audio = False
                    if was_paused:
                        print("Server response finished, resuming recording immediately")
                        # Flag the main loop to send silence frames
                        self.silence_send_count = 2    # send 2 silence frames when the response finishes
                        self.should_send_silence = True
                print("Server finished responding, waiting for user input")
                # if random.randint(0, 100000) % 1 == 0:
                #     self.is_sending_chat_tts_text = True
                #     asyncio.create_task(self.trigger_chat_tts_text())
                #     asyncio.create_task(self.trigger_chat_rag_text())
        elif message_type == 'SERVER_ERROR':
            print(f"Server error: {response['payload_msg']}")
            raise Exception("server error")

    async def trigger_chat_tts_text(self):
        """Send a ChatTTSText request (triggered probabilistically)."""
        print("hit ChatTTSText event, start sending...")
        await self.client.chat_tts_text(
            is_user_querying=self.is_user_querying,
            start=True,
            end=False,
            content="这是查询到外部数据之前的安抚话术。",
        )
        await self.client.chat_tts_text(
            is_user_querying=self.is_user_querying,
            start=False,
            end=True,
            content="",
        )

    async def trigger_chat_rag_text(self):
        # Simulate the latency of querying an external RAG source; sleep 5 s so the
        # reassurance message above can finish playing first.
        await asyncio.sleep(5)
        print("hit ChatRAGText event, start sending...")
        await self.client.chat_rag_text(self.is_user_querying, external_rag='[{"title":"北京天气","content":"今天北京整体以晴到多云为主,但西部和北部地带可能会出现分散性雷阵雨,特别是午后至傍晚时段需注意突发降雨。\n💨 风况与湿度\n风力较弱,一般为 2~3 级南风或西南风\n白天湿度较高,早晚略凉爽"}]')

    def _keyboard_signal(self, sig, frame):
        print("receive keyboard Ctrl+C")
        self.stop()

    def stop(self):
        self.is_recording = False
        self.is_playing = False
        self.is_running = False

    async def receive_loop(self):
        try:
            while True:
                response = await self.client.receive_server_response()
                self.handle_server_response(response)
                if 'event' in response and (response['event'] == 152 or response['event'] == 153):
                    print(f"receive session finished event: {response['event']}")
                    self.is_session_finished = True
                    break
                if 'event' in response and response['event'] == 359:
                    if self.is_audio_file_input:
                        print("receive tts ended event")
                        self.is_session_finished = True
                        break
                    else:
                        if not self.say_hello_over_event.is_set():
                            print("receive tts sayhello ended event")
                            self.say_hello_over_event.set()
                            # In audio mode the say-hello audio is about to start playing,
                            # so make sure recording stays paused.
                            if self.mod == "audio":
                                with self.audio_queue_lock:
                                    self.is_recording_paused = True
                                    self.is_playing_audio = True
                                print("Say-hello audio about to start, keeping recording paused")
                            if self.mod == "text":
                                # In text mode say hello is done -- restore the recording state
                                with self.audio_queue_lock:
                                    if self.is_recording_paused:
                                        self.is_recording_paused = False
                                        print("Text mode: say hello finished, resuming recording")
                                print("Please enter your input:")
        except asyncio.CancelledError:
            print("Receive task cancelled")
        except Exception as e:
            print(f"Error receiving message: {e}")
        finally:
            self.stop()
            self.is_session_finished = True

    async def process_audio_file(self) -> None:
        await self.process_audio_file_input(self.audio_file_path)

    async def process_text_input(self) -> None:
        """Main logic for text mode: handle text input and WebSocket communication."""
        # Stay silent for 2 seconds after startup so the system can stabilize
        print("Text mode: program started; staying silent for 2 seconds to let the system stabilize...")
        with self.audio_queue_lock:
            self.is_recording_paused = True
            self.is_playing_audio = True    # mark playback as active
        # Send 2 seconds of silence to flush the pipeline
        silence_data = b'\x00' * config.input_audio_config["chunk"]
        for i in range(20):    # 2 s = 20 * 100 ms
            await self.client.task_request(silence_data)
            await asyncio.sleep(0.1)
            if i % 10 == 0:    # print progress once per second
                print(f"Text mode: staying silent... {i // 10 + 1}/2 s")
        print("Text mode: silence finished, preparing say hello")
        # Keep recording paused before say hello
        with self.audio_queue_lock:
            self.is_recording_paused = True
            self.is_playing_audio = True    # mark playback as active
        print("Text mode: preparing say hello, keeping recording paused")
        await self.client.say_hello()
        await self.say_hello_over_event.wait()
        # Make sure the input loop is always left cleanly
        try:
            # Start the input listener thread
            input_queue = queue.Queue()
            input_thread = threading.Thread(target=self.input_listener, args=(input_queue,), daemon=True)
            input_thread.start()
            # Main loop: handle input until the session ends
            while self.is_running:
                try:
                    # Check for input (non-blocking)
                    input_str = input_queue.get_nowait()
                    if input_str is None:
                        # Input stream closed
                        print("Input channel closed")
                        break
                    if input_str:
                        # Send the text query
                        await self.client.chat_text_query(input_str)
                except queue.Empty:
                    # No input -- sleep briefly
                    await asyncio.sleep(0.1)
                except Exception as e:
                    print(f"Main loop error: {e}")
                    break
        finally:
            print("exit text input")

    def input_listener(self, input_queue: queue.Queue) -> None:
        """Listen to standard input in a separate thread."""
        print("Start listening for input")
        try:
            while True:
                # Read from stdin (blocking)
                line = sys.stdin.readline()
                if not line:
                    # Input stream closed
                    input_queue.put(None)
                    break
                input_str = line.strip()
                input_queue.put(input_str)
        except Exception as e:
            print(f"Input listener error: {e}")
            input_queue.put(None)

    async def process_audio_file_input(self, audio_file_path: str) -> None:
        # Read the WAV file
        with wave.open(audio_file_path, 'rb') as wf:
            chunk_size = config.input_audio_config["chunk"]
            framerate = wf.getframerate()    # sample rate, e.g. 16000 Hz
            # duration = chunk size (frames) / sample rate (frames per second)
            sleep_seconds = chunk_size / framerate
            print(f"Processing audio file: {audio_file_path}")
            # Read and send the audio data chunk by chunk
            while True:
                audio_data = wf.readframes(chunk_size)
                if not audio_data:
                    break    # end of file
                await self.client.task_request(audio_data)
                # Sleep for the duration of one chunk to simulate real-time input
                await asyncio.sleep(sleep_seconds)
        print("Audio file processed, waiting for server response...")

    async def process_silence_audio(self) -> None:
        """Send a single silence frame."""
        # 320 zero bytes = 160 samples of 16-bit audio, i.e. 10 ms at 16 kHz mono
        # (assuming those input settings in config.input_audio_config)
        silence_data = b'\x00' * 320
        await self.client.task_request(silence_data)
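
    # The microphone loop below checks, in priority order: (1) say-hello / playback in
    # progress -> send silence every 50 ms, (2) a pending should_send_silence request from
    # the playback thread or response handler -> send the requested batch, (3) recording
    # paused -> send silence every 100 ms, and only then (4) read and forward real
    # microphone audio, double-checking the pause flags right before sending.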

    async def process_microphone_input(self) -> None:
        """Handle microphone input."""
        stream = self.audio_device.open_input_stream()
        print("Microphone opened, please speak...")
        print("Audio processing started; silence will be sent during playback to avoid echo")

        # Stay silent for 2 seconds after startup so the system can stabilize
        print("Program started; staying silent for 2 seconds to let the system stabilize...")
        with self.audio_queue_lock:
            self.is_recording_paused = True
            self.is_playing_audio = True    # mark playback as active
        # Send 2 seconds of silence to flush the pipeline
        silence_data = b'\x00' * config.input_audio_config["chunk"]
        for i in range(20):    # 2 s = 20 * 100 ms
            await self.client.task_request(silence_data)
            await asyncio.sleep(0.1)
            if i % 10 == 0:    # print progress once per second
                print(f"Staying silent... {i // 10 + 1}/2 s")
        print("Silence finished, preparing say hello")

        # Keep recording paused before say hello
        with self.audio_queue_lock:
            self.is_recording_paused = True
            self.is_playing_audio = True    # mark playback as active
        print("Preparing say hello, keeping recording paused")
        await self.client.say_hello()
        await self.say_hello_over_event.wait()
        # Note: do not resume recording here; wait until the audio has actually finished
        # playing. The playback thread resumes recording after its playback timeout.
        print("Say-hello request finished, waiting for audio playback to end...")

        # Prepare a silence frame
        silence_data = b'\x00' * config.input_audio_config["chunk"]
        last_silence_time = time.time()
        # Special handling during say hello: keep the input fully silent
        say_hello_silence_sent = False
        while self.is_recording:
            try:
                current_time = time.time()

                # Force silence while say hello is playing
                with self.audio_queue_lock:
                    is_currently_playing = self.is_playing_audio
                if is_currently_playing or not self.say_hello_completed:
                    # Playing, or say hello not finished yet: send silence
                    if current_time - last_silence_time > 0.05:    # send every 50 ms
                        await self.client.task_request(silence_data)
                        last_silence_time = current_time
                        if not self.say_hello_completed and not is_currently_playing:
                            print("Sending silence during say hello")
                    await asyncio.sleep(0.01)
                    continue

                # Check whether the playback thread asked for silence -- highest priority
                if self.should_send_silence:
                    with self.audio_queue_lock:
                        self.should_send_silence = False
                        # Number of silence frames to send
                        count = self.silence_send_count
                        self.silence_send_count = 0
                    # Send the silence frames as a batch
                    if count > 1:
                        print(f"Flushing the recording pipeline, sending a batch of {count} silence frames")
                        for i in range(count):
                            await self.client.task_request(silence_data)
                            await asyncio.sleep(0.005)    # short gap so each frame is sent reliably
                    else:
                        await self.client.task_request(silence_data)
                        print("Flushing the recording pipeline, silence frame sent")
                    last_silence_time = current_time
                    await asyncio.sleep(0.01)
                    continue

                # Check whether recording is paused
                with self.audio_queue_lock:
                    should_pause_recording = self.is_recording_paused
                    # Check whether we have just entered the paused state
                    just_paused = should_pause_recording and hasattr(self, 'last_recording_state') and self.last_recording_state != should_pause_recording
                    self.last_recording_state = should_pause_recording

                if should_pause_recording:
                    # While playing: stop recording entirely and send only silence
                    if just_paused or current_time - last_silence_time > 0.1:    # just paused, or every 100 ms
                        await self.client.task_request(silence_data)
                        last_silence_time = current_time
                        if just_paused:
                            print("Just entered the paused state, sending silence immediately to flush the pipeline")
                        # Log at most once every 5 seconds to avoid flooding the console
                        elif not hasattr(self, 'last_silence_log_time') or current_time - self.last_silence_log_time > 5:
                            print("Playing audio, sending silence...")
                            self.last_silence_log_time = current_time
                    await asyncio.sleep(0.01)
                    continue

                # Not playing: record normally
                last_silence_time = current_time
                # exception_on_overflow=False silences input-overflow errors
                audio_data = stream.read(config.input_audio_config["chunk"], exception_on_overflow=False)

                # Final check before sending: fall back to silence if we got paused in the meantime
                with self.audio_queue_lock:
                    if self.is_recording_paused or self.is_playing_audio:
                        # Paused: drop this chunk and send silence instead
                        save_input_pcm_to_wav(silence_data, "input.pcm")    # save the silence for debugging
                        await self.client.task_request(silence_data)
                        # Log only every 50th drop to avoid flooding the console
                        if not hasattr(self, 'pause_discard_count') or self.pause_discard_count % 50 == 0:
                            print(f"Dropping microphone data while paused, sending silence (count: {getattr(self, 'pause_discard_count', 0) + 1})")
                        self.pause_discard_count = getattr(self, 'pause_discard_count', 0) + 1
                        await asyncio.sleep(0.01)
                        continue

                # Send all captured audio directly, without silence detection
                save_input_pcm_to_wav(audio_data, "input.pcm")
                await self.client.task_request(audio_data)
                await asyncio.sleep(0.01)    # avoid excessive CPU usage
            except Exception as e:
                print(f"Error reading microphone data: {e}")
                await asyncio.sleep(0.1)    # give the system time to recover

    async def start(self) -> None:
        """Start the dialog session."""
        try:
            await self.client.connect()
            if self.mod == "text":
                asyncio.create_task(self.process_text_input())
                asyncio.create_task(self.receive_loop())
                while self.is_running:
                    await asyncio.sleep(0.1)
            else:
                if self.is_audio_file_input:
                    asyncio.create_task(self.process_audio_file())
                    await self.receive_loop()
                else:
                    asyncio.create_task(self.process_microphone_input())
                    asyncio.create_task(self.receive_loop())
                    while self.is_running:
                        await asyncio.sleep(0.1)
            await self.client.finish_session()
            while not self.is_session_finished:
                await asyncio.sleep(0.1)
            await self.client.finish_connection()
            await asyncio.sleep(0.1)
            await self.client.close()
            print(f"dialog request logid: {self.client.logid}, chat mod: {self.mod}")
            save_output_to_file(self.audio_buffer, "output.pcm")
        except Exception as e:
            print(f"Session error: {e}")
        finally:
            if not self.is_audio_file_input:
                self.audio_device.cleanup()


def save_input_pcm_to_wav(pcm_data: bytes, filename: str) -> None:
    """Save PCM data as a WAV file.

    Note: the file is opened in 'wb' mode, so each call overwrites the previous contents.
    """
    with wave.open(filename, 'wb') as wf:
        wf.setnchannels(config.input_audio_config["channels"])
        wf.setsampwidth(2)    # paInt16 = 2 bytes per sample
        wf.setframerate(config.input_audio_config["sample_rate"])
        wf.writeframes(pcm_data)


def save_output_to_file(audio_data: bytes, filename: str) -> None:
    """Save raw PCM audio data to a file."""
    if not audio_data:
        print("No audio data to save.")
        return
    try:
        with open(filename, 'wb') as f:
            f.write(audio_data)
    except IOError as e:
        print(f"Failed to save pcm file: {e}")
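

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the real entry point and
    # WebSocket settings live in the accompanying launcher/config scripts, and the
    # `ws_connect_config` attribute name below is an assumption.
    ws_config = getattr(config, "ws_connect_config", {})
    session = DialogSession(ws_config=ws_config, output_audio_format="pcm", mod="audio")
    asyncio.run(session.start())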