doubao
This commit is contained in:
parent
e6aa7f7be8
commit
53d53e4555
BIN
doubao/__pycache__/audio_manager.cpython-312.pyc
Normal file
BIN
doubao/__pycache__/audio_manager.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
doubao/__pycache__/protocol.cpython-312.pyc
Normal file
BIN
doubao/__pycache__/protocol.cpython-312.pyc
Normal file
Binary file not shown.
BIN
doubao/__pycache__/realtime_dialog_client.cpython-312.pyc
Normal file
BIN
doubao/__pycache__/realtime_dialog_client.cpython-312.pyc
Normal file
Binary file not shown.
@ -8,11 +8,10 @@ import time
|
||||
import uuid
|
||||
import wave
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
import pyaudio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import config
|
||||
import pyaudio
|
||||
from realtime_dialog_client import RealtimeDialogClient
|
||||
|
||||
|
||||
@ -96,6 +95,14 @@ class DialogSession:
|
||||
self.is_user_querying = False
|
||||
self.is_sending_chat_tts_text = False
|
||||
self.audio_buffer = b''
|
||||
self.is_playing_audio = False # 是否正在播放音频
|
||||
self.audio_queue_lock = threading.Lock() # 音频队列锁
|
||||
self.is_recording_paused = False # 录音是否被暂停
|
||||
self.should_send_silence = False # 是否需要发送静音数据
|
||||
self.silence_send_count = 0 # 需要发送的静音数据数量
|
||||
self.pre_pause_time = 0 # 预暂停时间
|
||||
self.last_recording_state = False # 上次录音状态
|
||||
self.say_hello_completed = False # say hello 是否已完成
|
||||
|
||||
signal.signal(signal.SIGINT, self._keyboard_signal)
|
||||
self.audio_queue = queue.Queue()
|
||||
@ -105,7 +112,9 @@ class DialogSession:
|
||||
AudioConfig(**config.output_audio_config)
|
||||
)
|
||||
# 初始化音频队列和输出流
|
||||
print(f"输出音频配置: {config.output_audio_config}")
|
||||
self.output_stream = self.audio_device.open_output_stream()
|
||||
print("音频输出流已打开")
|
||||
# 启动播放线程
|
||||
self.is_recording = True
|
||||
self.is_playing = True
|
||||
@ -115,36 +124,123 @@ class DialogSession:
|
||||
|
||||
def _audio_player_thread(self):
|
||||
"""音频播放线程"""
|
||||
audio_playing_timeout = 1.0 # 1秒没有音频数据认为播放结束
|
||||
queue_check_interval = 0.1 # 每100ms检查一次队列状态
|
||||
|
||||
while self.is_playing:
|
||||
try:
|
||||
# 从队列获取音频数据
|
||||
audio_data = self.audio_queue.get(timeout=1.0)
|
||||
audio_data = self.audio_queue.get(timeout=queue_check_interval)
|
||||
if audio_data is not None:
|
||||
with self.audio_queue_lock:
|
||||
# 第三重保险:播放开始时最终确认暂停状态
|
||||
if not hasattr(self, 'last_audio_time') or not self.is_playing_audio:
|
||||
# 从非播放状态进入播放状态
|
||||
self.is_playing_audio = True
|
||||
# 确保录音已暂停
|
||||
if not self.is_recording_paused:
|
||||
self.is_recording_paused = True
|
||||
print("播放开始,最终确认暂停录音")
|
||||
|
||||
# 更新最后音频时间
|
||||
self.last_audio_time = time.time()
|
||||
|
||||
# 播放音频数据
|
||||
self.output_stream.write(audio_data)
|
||||
|
||||
except queue.Empty:
|
||||
# 队列为空时等待一小段时间
|
||||
time.sleep(0.1)
|
||||
# 队列为空,检查是否超时
|
||||
current_time = time.time()
|
||||
with self.audio_queue_lock:
|
||||
if self.is_playing_audio:
|
||||
if hasattr(self, 'last_audio_time') and current_time - self.last_audio_time > audio_playing_timeout:
|
||||
# 超过1秒没有新音频,认为播放结束
|
||||
self.is_playing_audio = False
|
||||
self.is_recording_paused = False
|
||||
# 标记 say hello 完成
|
||||
if hasattr(self, 'say_hello_completed') and not self.say_hello_completed:
|
||||
self.say_hello_completed = True
|
||||
print("say hello 音频播放完成")
|
||||
print("音频播放超时,恢复录音")
|
||||
# 直接发送静音数据,而不是在协程中发送
|
||||
try:
|
||||
silence_data = b'\x00' * config.input_audio_config["chunk"]
|
||||
# 使用同步方式发送静音数据
|
||||
# 这里我们设置一个标志,让主循环处理
|
||||
self.silence_send_count = 2 # 播放超时时发送2组静音数据
|
||||
self.should_send_silence = True
|
||||
except Exception as e:
|
||||
print(f"准备静音数据失败: {e}")
|
||||
elif self.audio_queue.empty():
|
||||
# 队列为空,但还没超时,继续等待
|
||||
pass
|
||||
time.sleep(0.01)
|
||||
except Exception as e:
|
||||
print(f"音频播放错误: {e}")
|
||||
with self.audio_queue_lock:
|
||||
self.is_playing_audio = False
|
||||
self.is_recording_paused = False
|
||||
time.sleep(0.1)
|
||||
|
||||
# 移除了静音检测函数,避免干扰正常的音频处理
|
||||
|
||||
async def _send_silence_on_playback_end(self):
|
||||
"""播放结束时发送静音数据"""
|
||||
try:
|
||||
silence_data = b'\x00' * config.input_audio_config["chunk"]
|
||||
await self.client.task_request(silence_data)
|
||||
print("播放结束,已发送静音数据")
|
||||
except Exception as e:
|
||||
print(f"发送静音数据失败: {e}")
|
||||
|
||||
def _check_and_restore_recording(self):
|
||||
"""检查并恢复录音状态"""
|
||||
with self.audio_queue_lock:
|
||||
if self.is_recording_paused and self.audio_queue.empty():
|
||||
# 如果队列为空且录音被暂停,恢复录音
|
||||
self.is_recording_paused = False
|
||||
self.is_playing_audio = False
|
||||
print("音频队列为空,自动恢复录音")
|
||||
return True
|
||||
return False
|
||||
|
||||
def handle_server_response(self, response: Dict[str, Any]) -> None:
|
||||
if response == {}:
|
||||
if not response or response == {}:
|
||||
return
|
||||
"""处理服务器响应"""
|
||||
if response['message_type'] == 'SERVER_ACK' and isinstance(response.get('payload_msg'), bytes):
|
||||
# print(f"\n接收到音频数据: {len(response['payload_msg'])} 字节")
|
||||
message_type = response.get('message_type')
|
||||
if message_type == 'SERVER_ACK' and isinstance(response.get('payload_msg'), bytes):
|
||||
if self.is_sending_chat_tts_text:
|
||||
return
|
||||
audio_data = response['payload_msg']
|
||||
|
||||
# 第二重保险:接收到音频数据时确认暂停状态
|
||||
with self.audio_queue_lock:
|
||||
was_not_playing = not self.is_playing_audio
|
||||
if was_not_playing:
|
||||
# 第一批音频数据到达,确保录音已暂停
|
||||
self.is_playing_audio = True
|
||||
if not self.is_recording_paused:
|
||||
self.is_recording_paused = True
|
||||
print("接收到首批音频数据,立即暂停录音")
|
||||
else:
|
||||
print("接收到音频数据,录音已暂停")
|
||||
|
||||
# 立即发送静音数据,确保管道清理
|
||||
self.silence_send_count = 3 # 音频数据到达时发送3组静音数据
|
||||
self.should_send_silence = True
|
||||
print("服务器收到音频数据,立即清理录音管道")
|
||||
|
||||
if not self.is_audio_file_input:
|
||||
self.audio_queue.put(audio_data)
|
||||
self.audio_buffer += audio_data
|
||||
elif response['message_type'] == 'SERVER_FULL_RESPONSE':
|
||||
elif message_type == 'SERVER_FULL_RESPONSE':
|
||||
print(f"服务器响应: {response}")
|
||||
event = response.get('event')
|
||||
payload_msg = response.get('payload_msg', {})
|
||||
|
||||
# 第一重保险:服务器开始响应时立即预暂停录音
|
||||
if event in [450, 359, 152, 153]: # 这些事件表示服务器开始或结束响应
|
||||
if event == 450:
|
||||
print(f"清空缓存音频: {response['session_id']}")
|
||||
while not self.audio_queue.empty():
|
||||
@ -153,6 +249,24 @@ class DialogSession:
|
||||
except queue.Empty:
|
||||
continue
|
||||
self.is_user_querying = True
|
||||
print("服务器准备接收用户输入")
|
||||
|
||||
# 预暂停录音,防止即将到来的音频回声
|
||||
with self.audio_queue_lock:
|
||||
if not self.is_recording_paused:
|
||||
self.is_recording_paused = True
|
||||
self.is_playing_audio = True # 同时设置播放状态,双重保险
|
||||
self.pre_pause_time = time.time()
|
||||
print("服务器开始响应,预暂停录音防止回声")
|
||||
|
||||
# 立即发送静音数据清理管道,防止前1-2秒回声
|
||||
print("预暂停期间立即发送静音数据清理管道")
|
||||
# 设置批量静音发送,确保管道完全清理
|
||||
self.silence_send_count = 8 # 增加到8组,确保彻底清理
|
||||
self.should_send_silence = True
|
||||
|
||||
# 强制重置录音状态
|
||||
self.last_recording_state = True # 标记为已暂停
|
||||
|
||||
if event == 350 and self.is_sending_chat_tts_text and payload_msg.get("tts_type") in ["chat_tts_text", "external_rag"]:
|
||||
while not self.audio_queue.empty():
|
||||
@ -164,11 +278,22 @@ class DialogSession:
|
||||
|
||||
if event == 459:
|
||||
self.is_user_querying = False
|
||||
if random.randint(0, 100000)%1 == 0:
|
||||
self.is_sending_chat_tts_text = True
|
||||
asyncio.create_task(self.trigger_chat_tts_text())
|
||||
asyncio.create_task(self.trigger_chat_rag_text())
|
||||
elif response['message_type'] == 'SERVER_ERROR':
|
||||
# 服务器完成响应,立即恢复录音
|
||||
with self.audio_queue_lock:
|
||||
was_paused = self.is_recording_paused
|
||||
self.is_recording_paused = False
|
||||
self.is_playing_audio = False
|
||||
if was_paused:
|
||||
print("服务器响应完成,立即恢复录音")
|
||||
# 设置标志发送静音数据
|
||||
self.silence_send_count = 2 # 响应完成时发送2组静音数据
|
||||
self.should_send_silence = True
|
||||
print("服务器完成响应,等待用户输入")
|
||||
#if random.randint(0, 100000)%1 == 0:
|
||||
# self.is_sending_chat_tts_text = True
|
||||
#asyncio.create_task(self.trigger_chat_tts_text())
|
||||
#asyncio.create_task(self.trigger_chat_rag_text())
|
||||
elif message_type == 'SERVER_ERROR':
|
||||
print(f"服务器错误: {response['payload_msg']}")
|
||||
raise Exception("服务器错误")
|
||||
|
||||
@ -220,7 +345,21 @@ class DialogSession:
|
||||
if not self.say_hello_over_event.is_set():
|
||||
print(f"receive tts sayhello ended event")
|
||||
self.say_hello_over_event.set()
|
||||
|
||||
# 对于音频模式,say hello 音频播放即将开始
|
||||
# 确保录音保持暂停状态
|
||||
if self.mod == "audio":
|
||||
with self.audio_queue_lock:
|
||||
self.is_recording_paused = True
|
||||
self.is_playing_audio = True
|
||||
print("say hello 音频即将开始,确保录音暂停")
|
||||
|
||||
if self.mod == "text":
|
||||
# 文本模式下 say hello 完成,恢复录音状态
|
||||
with self.audio_queue_lock:
|
||||
if self.is_recording_paused:
|
||||
self.is_recording_paused = False
|
||||
print("文本模式:say hello 完成,恢复录音")
|
||||
print("请输入内容:")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
@ -235,6 +374,28 @@ class DialogSession:
|
||||
await self.process_audio_file_input(self.audio_file_path)
|
||||
|
||||
async def process_text_input(self) -> None:
|
||||
# 程序启动后先静音2秒,确保系统稳定
|
||||
print("文本模式:程序启动,先静音2秒确保系统稳定...")
|
||||
with self.audio_queue_lock:
|
||||
self.is_recording_paused = True
|
||||
self.is_playing_audio = True # 标记正在播放
|
||||
|
||||
# 发送2秒静音数据,确保管道清理
|
||||
silence_data = b'\x00' * config.input_audio_config["chunk"]
|
||||
for i in range(20): # 2秒 = 20 * 100ms
|
||||
await self.client.task_request(silence_data)
|
||||
await asyncio.sleep(0.1)
|
||||
if i % 10 == 0: # 每秒打印一次进度
|
||||
print(f"文本模式:静音中... {i//10 + 1}/2秒")
|
||||
|
||||
print("文本模式:静音完成,准备 say hello")
|
||||
|
||||
# say hello 前确保录音仍处于暂停状态
|
||||
with self.audio_queue_lock:
|
||||
self.is_recording_paused = True
|
||||
self.is_playing_audio = True # 标记正在播放
|
||||
print("文本模式:准备 say hello,确保录音暂停")
|
||||
|
||||
await self.client.say_hello()
|
||||
await self.say_hello_over_event.wait()
|
||||
|
||||
@ -310,20 +471,131 @@ class DialogSession:
|
||||
await self.client.task_request(silence_data)
|
||||
|
||||
async def process_microphone_input(self) -> None:
|
||||
await self.client.say_hello()
|
||||
await self.say_hello_over_event.wait()
|
||||
await self.client.chat_text_query("你好,我也叫豆包")
|
||||
|
||||
"""处理麦克风输入"""
|
||||
stream = self.audio_device.open_input_stream()
|
||||
print("已打开麦克风,请讲话...")
|
||||
print("音频处理已启动,播放时将发送静音数据避免回声")
|
||||
|
||||
# 程序启动后先静音2秒,确保系统稳定
|
||||
print("程序启动,先静音2秒确保系统稳定...")
|
||||
with self.audio_queue_lock:
|
||||
self.is_recording_paused = True
|
||||
self.is_playing_audio = True # 标记正在播放
|
||||
|
||||
# 发送2秒静音数据,确保管道清理
|
||||
silence_data = b'\x00' * config.input_audio_config["chunk"]
|
||||
for i in range(20): # 2秒 = 20 * 100ms
|
||||
await self.client.task_request(silence_data)
|
||||
await asyncio.sleep(0.1)
|
||||
if i % 10 == 0: # 每秒打印一次进度
|
||||
print(f"静音中... {i//10 + 1}/2秒")
|
||||
|
||||
print("静音完成,准备 say hello")
|
||||
|
||||
# say hello 前确保录音仍处于暂停状态
|
||||
with self.audio_queue_lock:
|
||||
self.is_recording_paused = True
|
||||
self.is_playing_audio = True # 标记正在播放
|
||||
print("准备 say hello,确保录音暂停")
|
||||
|
||||
await self.client.say_hello()
|
||||
await self.say_hello_over_event.wait()
|
||||
|
||||
# 注意:不立即恢复录音状态,等待音频实际播放完成
|
||||
# 录音状态将由音频播放线程在播放超时后自动恢复
|
||||
print("say hello 请求完成,等待音频播放结束...")
|
||||
|
||||
# 创建静音数据
|
||||
silence_data = b'\x00' * config.input_audio_config["chunk"]
|
||||
last_silence_time = time.time()
|
||||
|
||||
# say hello 期间的特殊处理:确保完全静音
|
||||
say_hello_silence_sent = False
|
||||
|
||||
while self.is_recording:
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
# say hello 期间强制静音处理
|
||||
with self.audio_queue_lock:
|
||||
is_currently_playing = self.is_playing_audio
|
||||
|
||||
if is_currently_playing or not self.say_hello_completed:
|
||||
# 如果正在播放或者 say hello 未完成,发送静音数据
|
||||
if current_time - last_silence_time > 0.05: # 每50ms发送一次
|
||||
await self.client.task_request(silence_data)
|
||||
last_silence_time = current_time
|
||||
if not self.say_hello_completed and not is_currently_playing:
|
||||
print("say hello 期间发送静音数据")
|
||||
await asyncio.sleep(0.01)
|
||||
continue
|
||||
|
||||
# 检查是否需要发送静音数据(由播放线程触发)- 最高优先级
|
||||
if self.should_send_silence:
|
||||
with self.audio_queue_lock:
|
||||
self.should_send_silence = False
|
||||
# 获取需要发送的静音数据数量
|
||||
count = self.silence_send_count
|
||||
self.silence_send_count = 0
|
||||
|
||||
# 批量发送静音数据
|
||||
if count > 1:
|
||||
print(f"立即清理录音管道,批量发送{count}组静音数据")
|
||||
for i in range(count):
|
||||
await self.client.task_request(silence_data)
|
||||
await asyncio.sleep(0.005) # 短暂间隔确保发送成功
|
||||
else:
|
||||
await self.client.task_request(silence_data)
|
||||
print("立即清理录音管道,发送静音数据")
|
||||
|
||||
last_silence_time = current_time
|
||||
await asyncio.sleep(0.01)
|
||||
continue
|
||||
|
||||
# 检查录音是否被暂停
|
||||
with self.audio_queue_lock:
|
||||
should_pause_recording = self.is_recording_paused
|
||||
# 检查是否刚刚进入暂停状态
|
||||
just_paused = should_pause_recording and hasattr(self, 'last_recording_state') and self.last_recording_state != should_pause_recording
|
||||
self.last_recording_state = should_pause_recording
|
||||
|
||||
if should_pause_recording:
|
||||
# 播放期间:完全停止录音,只发送静音数据
|
||||
if just_paused or current_time - last_silence_time > 0.1: # 刚暂停或每100ms发送一次静音数据
|
||||
await self.client.task_request(silence_data)
|
||||
last_silence_time = current_time
|
||||
if just_paused:
|
||||
print("刚进入暂停状态,立即发送静音数据清理管道")
|
||||
# 每5秒打印一次状态,避免过多日志
|
||||
elif not hasattr(self, 'last_silence_log_time') or current_time - self.last_silence_log_time > 5:
|
||||
print("正在播放音频,发送静音数据中...")
|
||||
self.last_silence_log_time = current_time
|
||||
await asyncio.sleep(0.01)
|
||||
continue
|
||||
|
||||
# 非播放期间:正常录音
|
||||
last_silence_time = current_time
|
||||
|
||||
# 添加exception_on_overflow=False参数来忽略溢出错误
|
||||
audio_data = stream.read(config.input_audio_config["chunk"], exception_on_overflow=False)
|
||||
|
||||
# 在发送前再次检查是否应该发送静音数据(最后一道防线)
|
||||
with self.audio_queue_lock:
|
||||
if self.is_recording_paused or self.is_playing_audio:
|
||||
# 如果处于暂停状态,丢弃这个音频数据并发送静音
|
||||
save_input_pcm_to_wav(silence_data, "input.pcm") # 保存静音数据用于调试
|
||||
await self.client.task_request(silence_data)
|
||||
# 每50次打印一次日志,避免过多输出
|
||||
if not hasattr(self, 'pause_discard_count') or self.pause_discard_count % 50 == 0:
|
||||
print(f"暂停期间丢弃音频数据,发送静音数据 (次数: {getattr(self, 'pause_discard_count', 0) + 1})")
|
||||
self.pause_discard_count = getattr(self, 'pause_discard_count', 0) + 1
|
||||
await asyncio.sleep(0.01)
|
||||
continue
|
||||
|
||||
# 直接发送所有音频数据,不进行静音检测
|
||||
save_input_pcm_to_wav(audio_data, "input.pcm")
|
||||
await self.client.task_request(audio_data)
|
||||
|
||||
await asyncio.sleep(0.01) # 避免CPU过度使用
|
||||
except Exception as e:
|
||||
print(f"读取麦克风数据出错: {e}")
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
import uuid
|
||||
|
||||
import pyaudio
|
||||
|
||||
# 配置信息
|
||||
ws_connect_config = {
|
||||
"base_url": "wss://openspeech.bytedance.com/api/v3/realtime/dialogue",
|
||||
"headers": {
|
||||
"X-Api-App-ID": "",
|
||||
"X-Api-Access-Key": "",
|
||||
"X-Api-App-ID": "8718217928",
|
||||
"X-Api-Access-Key": "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc",
|
||||
"X-Api-Resource-Id": "volc.speech.dialog", # 固定值
|
||||
"X-Api-App-Key": "PlgvMymc7f3tQnJ6", # 固定值
|
||||
"X-Api-Connect-Id": str(uuid.uuid4()),
|
||||
|
||||
BIN
doubao/input.pcm
Normal file
BIN
doubao/input.pcm
Normal file
Binary file not shown.
BIN
doubao/output.pcm
Normal file
BIN
doubao/output.pcm
Normal file
Binary file not shown.
@ -22,12 +22,19 @@ class RealtimeDialogClient:
|
||||
async def connect(self) -> None:
|
||||
"""建立WebSocket连接"""
|
||||
print(f"url: {self.config['base_url']}, headers: {self.config['headers']}")
|
||||
# For older websockets versions, use additional_headers instead of extra_headers
|
||||
self.ws = await websockets.connect(
|
||||
self.config['base_url'],
|
||||
extra_headers=self.config['headers'],
|
||||
additional_headers=self.config['headers'],
|
||||
ping_interval=None
|
||||
)
|
||||
# In older websockets versions, response headers are accessed differently
|
||||
if hasattr(self.ws, 'response_headers'):
|
||||
self.logid = self.ws.response_headers.get("X-Tt-Logid")
|
||||
elif hasattr(self.ws, 'headers'):
|
||||
self.logid = self.ws.headers.get("X-Tt-Logid")
|
||||
else:
|
||||
self.logid = "unknown"
|
||||
print(f"dialog server response logid: {self.logid}")
|
||||
|
||||
# StartConnection request
|
||||
|
||||
Loading…
Reference in New Issue
Block a user