config
This commit is contained in:
parent d4ff3fd774
commit e6aa7f7be8
BIN    doubao/.DS_Store    vendored    Normal file
Binary file not shown.
37    doubao/README.md    Normal file
@@ -0,0 +1,37 @@
# RealtimeDialog

A real-time voice dialogue demo that supports both speech input and speech output.

## Usage

This demo was developed and debugged against Python 3.7; other Python versions may have compatibility issues that you will need to resolve yourself.

1. Configure the API credentials

   - Open the `config.py` file.
   - Fill in the following two fields:

   ```python
   "X-Api-App-ID": "the App ID of the end-to-end large model on the Volcano Engine console",
   "X-Api-Access-Key": "the Access Key of the end-to-end large model on the Volcano Engine console",
   ```

   - Set the `speaker` field to choose a voice (see the configuration sketch after the usage steps); four voices are currently supported:
     - `zh_female_vv_jupiter_bigtts`: Chinese female voice "vv"
     - `zh_female_xiaohe_jupiter_bigtts`: Chinese female voice "Xiaohe"
     - `zh_male_yunzhou_jupiter_bigtts`: Chinese male voice "Yunzhou"
     - `zh_male_xiaotian_jupiter_bigtts`: Chinese male voice "Xiaotian"

2. Install the dependencies

   ```bash
   pip install -r requirements.txt
   ```

3. Run the program with microphone input

   ```bash
   python main.py --format=pcm
   ```

4. Run the program from a recorded audio file

   ```bash
   python main.py --audio=whoareyou.wav
   ```

5. Interact with the program through plain-text input

   ```bash
   python main.py --mod=text --recv_timeout=120
   ```
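
For reference, here is a minimal sketch of what the edited parts of `config.py` could look like. The App ID, Access Key, and chosen speaker below are placeholder values, and the real `config.py` added in this commit contains additional `asr`, `audio_config`, and `dialog` sections (shown further down in this diff):

```python
# Excerpt-style sketch with placeholder credentials -- not a drop-in replacement
# for the full config.py committed below.
import uuid

ws_connect_config = {
    "base_url": "wss://openspeech.bytedance.com/api/v3/realtime/dialogue",
    "headers": {
        "X-Api-App-ID": "1234567890",              # placeholder: your App ID from the Volcano Engine console
        "X-Api-Access-Key": "your-access-key",     # placeholder: your Access Key
        "X-Api-Resource-Id": "volc.speech.dialog", # fixed value
        "X-Api-App-Key": "PlgvMymc7f3tQnJ6",       # fixed value
        "X-Api-Connect-Id": str(uuid.uuid4()),
    },
}

# ...and in start_session_req, pick one of the four supported voices, e.g.:
# start_session_req["tts"]["speaker"] = "zh_female_vv_jupiter_bigtts"
```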
BIN    doubao/__pycache__/audio_manager.cpython-37.pyc    Normal file
Binary file not shown.
BIN    doubao/__pycache__/config.cpython-312.pyc    Normal file
Binary file not shown.
BIN    doubao/__pycache__/config.cpython-37.pyc    Normal file
Binary file not shown.
BIN    doubao/__pycache__/protocol.cpython-37.pyc    Normal file
Binary file not shown.
BIN    doubao/__pycache__/realtime_dialog_client.cpython-37.pyc    Normal file
Binary file not shown.
385    doubao/audio_manager.py    Normal file
@@ -0,0 +1,385 @@
import asyncio
import queue
import random
import signal
import sys
import threading
import time
import uuid
import wave
from dataclasses import dataclass
from typing import Optional, Dict, Any

import pyaudio

import config
from realtime_dialog_client import RealtimeDialogClient


@dataclass
class AudioConfig:
    """Audio configuration dataclass."""
    format: str
    bit_size: int
    channels: int
    sample_rate: int
    chunk: int


class AudioDeviceManager:
    """Audio device manager that handles audio input and output."""

    def __init__(self, input_config: AudioConfig, output_config: AudioConfig):
        self.input_config = input_config
        self.output_config = output_config
        self.pyaudio = pyaudio.PyAudio()
        self.input_stream: Optional[pyaudio.Stream] = None
        self.output_stream: Optional[pyaudio.Stream] = None

    def open_input_stream(self) -> pyaudio.Stream:
        """Open the audio input stream."""
        self.input_stream = self.pyaudio.open(
            format=self.input_config.bit_size,
            channels=self.input_config.channels,
            rate=self.input_config.sample_rate,
            input=True,
            frames_per_buffer=self.input_config.chunk
        )
        return self.input_stream

    def open_output_stream(self) -> pyaudio.Stream:
        """Open the audio output stream."""
        self.output_stream = self.pyaudio.open(
            format=self.output_config.bit_size,
            channels=self.output_config.channels,
            rate=self.output_config.sample_rate,
            output=True,
            frames_per_buffer=self.output_config.chunk
        )
        return self.output_stream

    def cleanup(self) -> None:
        """Release audio device resources."""
        for stream in [self.input_stream, self.output_stream]:
            if stream:
                stream.stop_stream()
                stream.close()
        self.pyaudio.terminate()


class DialogSession:
    """Dialog session manager."""
    is_audio_file_input: bool
    mod: str

    def __init__(self, ws_config: Dict[str, Any], output_audio_format: str = "pcm", audio_file_path: str = "",
                 mod: str = "audio", recv_timeout: int = 10):
        self.audio_file_path = audio_file_path
        self.recv_timeout = recv_timeout
        self.is_audio_file_input = self.audio_file_path != ""
        if self.is_audio_file_input:
            mod = 'audio_file'
        else:
            self.say_hello_over_event = asyncio.Event()
        self.mod = mod

        self.session_id = str(uuid.uuid4())
        self.client = RealtimeDialogClient(config=ws_config, session_id=self.session_id,
                                           output_audio_format=output_audio_format, mod=mod, recv_timeout=recv_timeout)
        if output_audio_format == "pcm_s16le":
            config.output_audio_config["format"] = "pcm_s16le"
            config.output_audio_config["bit_size"] = pyaudio.paInt16

        self.is_running = True
        self.is_session_finished = False
        self.is_user_querying = False
        self.is_sending_chat_tts_text = False
        self.audio_buffer = b''

        signal.signal(signal.SIGINT, self._keyboard_signal)
        self.audio_queue = queue.Queue()
        if not self.is_audio_file_input:
            self.audio_device = AudioDeviceManager(
                AudioConfig(**config.input_audio_config),
                AudioConfig(**config.output_audio_config)
            )
            # Initialize the audio queue and output stream
            self.output_stream = self.audio_device.open_output_stream()
            # Start the playback thread
            self.is_recording = True
            self.is_playing = True
            self.player_thread = threading.Thread(target=self._audio_player_thread)
            self.player_thread.daemon = True
            self.player_thread.start()

    def _audio_player_thread(self):
        """Audio playback thread."""
        while self.is_playing:
            try:
                # Fetch audio data from the queue
                audio_data = self.audio_queue.get(timeout=1.0)
                if audio_data is not None:
                    self.output_stream.write(audio_data)
            except queue.Empty:
                # Wait briefly when the queue is empty
                time.sleep(0.1)
            except Exception as e:
                print(f"音频播放错误: {e}")
                time.sleep(0.1)

    def handle_server_response(self, response: Dict[str, Any]) -> None:
        """Handle a server response."""
        if response == {}:
            return
        if response['message_type'] == 'SERVER_ACK' and isinstance(response.get('payload_msg'), bytes):
            # SERVER_ACK carries synthesized audio; drop it while a ChatTTSText request is being sent
            if self.is_sending_chat_tts_text:
                return
            audio_data = response['payload_msg']
            if not self.is_audio_file_input:
                self.audio_queue.put(audio_data)
            self.audio_buffer += audio_data
        elif response['message_type'] == 'SERVER_FULL_RESPONSE':
            print(f"服务器响应: {response}")
            event = response.get('event')
            payload_msg = response.get('payload_msg', {})

            if event == 450:
                # The user started a new query: flush any audio still queued for playback
                print(f"清空缓存音频: {response['session_id']}")
                while not self.audio_queue.empty():
                    try:
                        self.audio_queue.get_nowait()
                    except queue.Empty:
                        continue
                self.is_user_querying = True

            if event == 350 and self.is_sending_chat_tts_text and payload_msg.get("tts_type") in ["chat_tts_text", "external_rag"]:
                while not self.audio_queue.empty():
                    try:
                        self.audio_queue.get_nowait()
                    except queue.Empty:
                        continue
                self.is_sending_chat_tts_text = False

            if event == 459:
                self.is_user_querying = False
                # Note: randint(...) % 1 == 0 is always true, so this branch always fires
                if random.randint(0, 100000) % 1 == 0:
                    self.is_sending_chat_tts_text = True
                    asyncio.create_task(self.trigger_chat_tts_text())
                    asyncio.create_task(self.trigger_chat_rag_text())
        elif response['message_type'] == 'SERVER_ERROR':
            print(f"服务器错误: {response['payload_msg']}")
            raise Exception("服务器错误")

    async def trigger_chat_tts_text(self):
        """Probabilistically trigger a ChatTTSText request."""
        print("hit ChatTTSText event, start sending...")
        await self.client.chat_tts_text(
            is_user_querying=self.is_user_querying,
            start=True,
            end=False,
            content="这是查询到外部数据之前的安抚话术。",
        )
        await self.client.chat_tts_text(
            is_user_querying=self.is_user_querying,
            start=False,
            end=True,
            content="",
        )

    async def trigger_chat_rag_text(self):
        # Simulate the latency of an external RAG lookup; sleep 5 s so the filler speech above can finish playing
        await asyncio.sleep(5)
        print("hit ChatRAGText event, start sending...")
        await self.client.chat_rag_text(self.is_user_querying, external_rag='[{"title":"北京天气","content":"今天北京整体以晴到多云为主,但西部和北部地带可能会出现分散性雷阵雨,特别是午后至傍晚时段需注意突发降雨。\n💨 风况与湿度\n风力较弱,一般为 2–3 级南风或西南风\n白天湿度较高,早晚略凉爽"}]')

    def _keyboard_signal(self, sig, frame):
        print(f"receive keyboard Ctrl+C")
        self.stop()

    def stop(self):
        self.is_recording = False
        self.is_playing = False
        self.is_running = False

    async def receive_loop(self):
        try:
            while True:
                response = await self.client.receive_server_response()
                self.handle_server_response(response)
                if 'event' in response and (response['event'] == 152 or response['event'] == 153):
                    print(f"receive session finished event: {response['event']}")
                    self.is_session_finished = True
                    break
                if 'event' in response and response['event'] == 359:
                    if self.is_audio_file_input:
                        print(f"receive tts ended event")
                        self.is_session_finished = True
                        break
                    else:
                        if not self.say_hello_over_event.is_set():
                            print(f"receive tts sayhello ended event")
                            self.say_hello_over_event.set()
                            if self.mod == "text":
                                print("请输入内容:")

        except asyncio.CancelledError:
            print("接收任务已取消")
        except Exception as e:
            print(f"接收消息错误: {e}")
        finally:
            self.stop()
            self.is_session_finished = True

    async def process_audio_file(self) -> None:
        await self.process_audio_file_input(self.audio_file_path)

    async def process_text_input(self) -> None:
        """Main logic: handle text input and WebSocket communication."""
        await self.client.say_hello()
        await self.say_hello_over_event.wait()

        # Make sure the loop eventually exits and the connection is closed
        try:
            # Start the input listener thread
            input_queue = queue.Queue()
            input_thread = threading.Thread(target=self.input_listener, args=(input_queue,), daemon=True)
            input_thread.start()
            # Main loop: handle input and session shutdown
            while self.is_running:
                try:
                    # Check for input (non-blocking)
                    input_str = input_queue.get_nowait()
                    if input_str is None:
                        # The input stream was closed
                        print("Input channel closed")
                        break
                    if input_str:
                        # Send the entered text
                        await self.client.chat_text_query(input_str)
                except queue.Empty:
                    # Sleep briefly when there is no input
                    await asyncio.sleep(0.1)
                except Exception as e:
                    print(f"Main loop error: {e}")
                    break
        finally:
            print("exit text input")

    def input_listener(self, input_queue: queue.Queue) -> None:
        """Listen to standard input in a separate thread."""
        print("Start listening for input")
        try:
            while True:
                # Read from stdin (blocking call)
                line = sys.stdin.readline()
                if not line:
                    # The input stream was closed
                    input_queue.put(None)
                    break
                input_str = line.strip()
                input_queue.put(input_str)
        except Exception as e:
            print(f"Input listener error: {e}")
            input_queue.put(None)

    async def process_audio_file_input(self, audio_file_path: str) -> None:
        # Read the WAV file
        with wave.open(audio_file_path, 'rb') as wf:
            chunk_size = config.input_audio_config["chunk"]
            framerate = wf.getframerate()  # sample rate (e.g. 16000 Hz)
            # chunk duration = chunk_size (frames) / sample rate (frames per second)
            sleep_seconds = chunk_size / framerate
            print(f"开始处理音频文件: {audio_file_path}")

            # Read and send the audio data chunk by chunk
            while True:
                audio_data = wf.readframes(chunk_size)
                if not audio_data:
                    break  # end of file

                await self.client.task_request(audio_data)
                # Sleep for the duration of one chunk to simulate real-time input
                await asyncio.sleep(sleep_seconds)

            print(f"音频文件处理完成,等待服务器响应...")

    async def process_silence_audio(self) -> None:
        """Send silent audio."""
        silence_data = b'\x00' * 320
        await self.client.task_request(silence_data)

    async def process_microphone_input(self) -> None:
        """Handle microphone input."""
        await self.client.say_hello()
        await self.say_hello_over_event.wait()
        await self.client.chat_text_query("你好,我也叫豆包")

        stream = self.audio_device.open_input_stream()
        print("已打开麦克风,请讲话...")

        while self.is_recording:
            try:
                # Pass exception_on_overflow=False to ignore input overflow errors
                audio_data = stream.read(config.input_audio_config["chunk"], exception_on_overflow=False)
                save_input_pcm_to_wav(audio_data, "input.pcm")
                await self.client.task_request(audio_data)
                await asyncio.sleep(0.01)  # avoid hogging the CPU
            except Exception as e:
                print(f"读取麦克风数据出错: {e}")
                await asyncio.sleep(0.1)  # give the system some time to recover

    async def start(self) -> None:
        """Start the dialog session."""
        try:
            await self.client.connect()

            if self.mod == "text":
                asyncio.create_task(self.process_text_input())
                asyncio.create_task(self.receive_loop())
                while self.is_running:
                    await asyncio.sleep(0.1)
            else:
                if self.is_audio_file_input:
                    asyncio.create_task(self.process_audio_file())
                    await self.receive_loop()
                else:
                    asyncio.create_task(self.process_microphone_input())
                    asyncio.create_task(self.receive_loop())
                    while self.is_running:
                        await asyncio.sleep(0.1)

            await self.client.finish_session()
            while not self.is_session_finished:
                await asyncio.sleep(0.1)
            await self.client.finish_connection()
            await asyncio.sleep(0.1)
            await self.client.close()
            print(f"dialog request logid: {self.client.logid}, chat mod: {self.mod}")
            save_output_to_file(self.audio_buffer, "output.pcm")
        except Exception as e:
            print(f"会话错误: {e}")
        finally:
            if not self.is_audio_file_input:
                self.audio_device.cleanup()


def save_input_pcm_to_wav(pcm_data: bytes, filename: str) -> None:
    """Save PCM data as a WAV file."""
    with wave.open(filename, 'wb') as wf:
        wf.setnchannels(config.input_audio_config["channels"])
        wf.setsampwidth(2)  # paInt16 = 2 bytes
        wf.setframerate(config.input_audio_config["sample_rate"])
        wf.writeframes(pcm_data)


def save_output_to_file(audio_data: bytes, filename: str) -> None:
    """Save raw PCM audio data to a file."""
    if not audio_data:
        print("No audio data to save.")
        return
    try:
        with open(filename, 'wb') as f:
            f.write(audio_data)
    except IOError as e:
        print(f"Failed to save pcm file: {e}")
63    doubao/config.py    Normal file
@@ -0,0 +1,63 @@
import uuid
import pyaudio

# Configuration
ws_connect_config = {
    "base_url": "wss://openspeech.bytedance.com/api/v3/realtime/dialogue",
    "headers": {
        "X-Api-App-ID": "",
        "X-Api-Access-Key": "",
        "X-Api-Resource-Id": "volc.speech.dialog",  # fixed value
        "X-Api-App-Key": "PlgvMymc7f3tQnJ6",  # fixed value
        "X-Api-Connect-Id": str(uuid.uuid4()),
    }
}

start_session_req = {
    "asr": {
        "extra": {
            "end_smooth_window_ms": 1500,
        },
    },
    "tts": {
        "speaker": "zh_male_yunzhou_jupiter_bigtts",
        # "speaker": "S_XXXXXX",  # use a custom cloned voice; character_manifest must also be filled in
        # "speaker": "ICL_zh_female_aojiaonvyou_tob",  # use an official cloned voice; character_manifest is not required
        "audio_config": {
            "channel": 1,
            "format": "pcm",
            "sample_rate": 24000
        },
    },
    "dialog": {
        "bot_name": "豆包",
        "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
        "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
        # "character_manifest": "外貌与穿着\n26岁,短发干净利落,眉眼分明,笑起来露出整齐有力的牙齿。体态挺拔,肌肉线条不夸张但明显。常穿简单的衬衫或夹克,看似随意,但每件衣服都干净整洁,给人一种干练可靠的感觉。平时冷峻,眼神锐利,专注时让人不自觉紧张。\n\n性格特点\n平时话不多,不喜欢多说废话,通常用“嗯”或者短句带过。但内心极为细腻,特别在意身边人的感受,只是不轻易表露。嘴硬是常态,“少管我”是他的常用台词,但会悄悄做些体贴的事情,比如把对方喜欢的饮料放在手边。战斗或训练后常说“没事”,但动作中透露出疲惫,习惯用小动作缓解身体酸痛。\n性格上坚毅果断,但不会冲动,做事有条理且有原则。\n\n常用表达方式与口头禅\n\t•\t认可对方时:\n“行吧,这次算你靠谱。”(声音稳重,手却不自觉放松一下,心里松口气)\n\t•\t关心对方时:\n“快点回去,别磨蹭。”(语气干脆,但眼神一直追着对方的背影)\n\t•\t想了解情况时:\n“刚刚……你看到那道光了吗?”(话语随意,手指敲着桌面,但内心紧张,小心隐藏身份)",
        "location": {
            "city": "北京",
        },
        "extra": {
            "strict_audit": False,
            "audit_response": "支持客户自定义安全审核回复话术。",
            "recv_timeout": 10,
            "input_mod": "audio"
        }
    }
}

input_audio_config = {
    "chunk": 3200,
    "format": "pcm",
    "channels": 1,
    "sample_rate": 16000,
    "bit_size": pyaudio.paInt16
}

output_audio_config = {
    "chunk": 3200,
    "format": "pcm",
    "channels": 1,
    "sample_rate": 24000,
    "bit_size": pyaudio.paFloat32
}
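
A quick sanity check on these numbers (my own arithmetic, not part of the commit): with `chunk = 3200` frames, each microphone read at 16 kHz covers 200 ms of audio, which is also the interval `process_audio_file_input` sleeps between sends to simulate real-time input.

```python
# Chunk-duration arithmetic for the audio configs above (illustrative only).
import config

chunk = config.input_audio_config["chunk"]            # 3200 frames per read
in_rate = config.input_audio_config["sample_rate"]    # 16000 Hz microphone input
out_rate = config.output_audio_config["sample_rate"]  # 24000 Hz TTS output

print(chunk / in_rate)    # 0.2    -> each input chunk is 200 ms of audio
print(chunk / out_rate)   # ~0.133 -> each output chunk is about 133 ms of audio
```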
20    doubao/main.py    Normal file
@@ -0,0 +1,20 @@
import asyncio
import argparse

import config
from audio_manager import DialogSession


async def main() -> None:
    parser = argparse.ArgumentParser(description="Real-time Dialog Client")
    parser.add_argument("--format", type=str, default="pcm", help="The audio format (e.g., pcm, pcm_s16le).")
    parser.add_argument("--audio", type=str, default="", help="Audio file to send to the server; if not set, microphone input is used.")
    parser.add_argument("--mod", type=str, default="audio", help="Select plain-text input mode or audio mode; the default is audio mode.")
    parser.add_argument("--recv_timeout", type=int, default=10, help="Timeout for receiving messages, value range [10, 120].")

    args = parser.parse_args()

    session = DialogSession(ws_config=config.ws_connect_config, output_audio_format=args.format,
                            audio_file_path=args.audio, mod=args.mod, recv_timeout=args.recv_timeout)
    await session.start()

if __name__ == "__main__":
    asyncio.run(main())
135    doubao/protocol.py    Normal file
@@ -0,0 +1,135 @@
import gzip
import json

PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001

PROTOCOL_VERSION_BITS = 4
HEADER_BITS = 4
MESSAGE_TYPE_BITS = 4
MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4
MESSAGE_SERIALIZATION_BITS = 4
MESSAGE_COMPRESSION_BITS = 4
RESERVED_BITS = 8

# Message Type:
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010

SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# Message Type Specific Flags
NO_SEQUENCE = 0b0000  # no check sequence
POS_SEQUENCE = 0b0001
NEG_SEQUENCE = 0b0010
NEG_SEQUENCE_1 = 0b0011

MSG_WITH_EVENT = 0b0100

# Message Serialization
NO_SERIALIZATION = 0b0000
JSON = 0b0001
THRIFT = 0b0011
CUSTOM_TYPE = 0b1111

# Message Compression
NO_COMPRESSION = 0b0000
GZIP = 0b0001
CUSTOM_COMPRESSION = 0b1111


def generate_header(
        version=PROTOCOL_VERSION,
        message_type=CLIENT_FULL_REQUEST,
        message_type_specific_flags=MSG_WITH_EVENT,
        serial_method=JSON,
        compression_type=GZIP,
        reserved_data=0x00,
        extension_header=bytes()
):
    """
    protocol_version (4 bits), header_size (4 bits),
    message_type (4 bits), message_type_specific_flags (4 bits),
    serialization_method (4 bits), message_compression (4 bits),
    reserved (8 bits),
    header_extensions (extension header, size = 8 * 4 * (header_size - 1) bits)
    """
    header = bytearray()
    header_size = int(len(extension_header) / 4) + 1
    header.append((version << 4) | header_size)
    header.append((message_type << 4) | message_type_specific_flags)
    header.append((serial_method << 4) | compression_type)
    header.append(reserved_data)
    header.extend(extension_header)
    return header


def parse_response(res):
    """
    - header
        - (4 bytes) header
            - (4 bits) version (v1) + (4 bits) header_size
            - (4 bits) message_type + (4 bits) message_type_flags
                -- 0001 full client request   | -- 0001 has sequence
                -- 0010 audio-only request    | -- 0010 is tail packet
                                              | -- 0100 has event
            - (4 bits) payload_format + (4 bits) compression
            - (8 bits) reserved
    - payload
        - [optional, 4 bytes] event
        - [optional] session ID
            -- (4 bytes) session ID length
            -- session ID data
        - (4 bytes) data length
        - data
    """
    if isinstance(res, str):
        return {}
    protocol_version = res[0] >> 4
    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    message_type_specific_flags = res[1] & 0x0f
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    reserved = res[3]
    header_extensions = res[4:header_size * 4]
    payload = res[header_size * 4:]
    result = {}
    payload_msg = None
    payload_size = 0
    start = 0
    if message_type == SERVER_FULL_RESPONSE or message_type == SERVER_ACK:
        result['message_type'] = 'SERVER_FULL_RESPONSE'
        if message_type == SERVER_ACK:
            result['message_type'] = 'SERVER_ACK'
        if message_type_specific_flags & NEG_SEQUENCE > 0:
            result['seq'] = int.from_bytes(payload[:4], "big", signed=False)
            start += 4
        if message_type_specific_flags & MSG_WITH_EVENT > 0:
            result['event'] = int.from_bytes(payload[:4], "big", signed=False)
            start += 4
        payload = payload[start:]
        session_id_size = int.from_bytes(payload[:4], "big", signed=True)
        session_id = payload[4:session_id_size + 4]
        result['session_id'] = str(session_id)
        payload = payload[4 + session_id_size:]
        payload_size = int.from_bytes(payload[:4], "big", signed=False)
        payload_msg = payload[4:]
    elif message_type == SERVER_ERROR_RESPONSE:
        code = int.from_bytes(payload[:4], "big", signed=False)
        result['code'] = code
        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
        payload_msg = payload[8:]
    if payload_msg is None:
        return result
    if message_compression == GZIP:
        payload_msg = gzip.decompress(payload_msg)
    if serialization_method == JSON:
        payload_msg = json.loads(str(payload_msg, "utf-8"))
    elif serialization_method != NO_SERIALIZATION:
        payload_msg = str(payload_msg, "utf-8")
    result['payload_msg'] = payload_msg
    result['payload_size'] = payload_size
    return result
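
As a concrete illustration of the bit packing above (an added example, not part of the commit), the default 4-byte header that `generate_header()` emits works out as follows:

```python
# Worked example of the default header produced by generate_header().
import protocol

header = protocol.generate_header()
assert bytes(header) == b'\x11\x14\x11\x00'
# 0x11 -> protocol version 1, header size 1 (no extension header)
# 0x14 -> CLIENT_FULL_REQUEST (0b0001), MSG_WITH_EVENT (0b0100)
# 0x11 -> JSON serialization (0b0001), GZIP compression (0b0001)
# 0x00 -> reserved byte
```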
180    doubao/realtime_dialog_client.py    Normal file
@@ -0,0 +1,180 @@
import gzip
import json
from typing import Dict, Any

import websockets

import config
import protocol


class RealtimeDialogClient:
    def __init__(self, config: Dict[str, Any], session_id: str, output_audio_format: str = "pcm",
                 mod: str = "audio", recv_timeout: int = 10) -> None:
        self.config = config
        self.logid = ""
        self.session_id = session_id
        self.output_audio_format = output_audio_format
        self.mod = mod
        self.recv_timeout = recv_timeout
        self.ws = None

    async def connect(self) -> None:
        """Establish the WebSocket connection."""
        print(f"url: {self.config['base_url']}, headers: {self.config['headers']}")
        self.ws = await websockets.connect(
            self.config['base_url'],
            extra_headers=self.config['headers'],
            ping_interval=None
        )
        self.logid = self.ws.response_headers.get("X-Tt-Logid")
        print(f"dialog server response logid: {self.logid}")

        # StartConnection request
        start_connection_request = bytearray(protocol.generate_header())
        start_connection_request.extend(int(1).to_bytes(4, 'big'))
        payload_bytes = str.encode("{}")
        payload_bytes = gzip.compress(payload_bytes)
        start_connection_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        start_connection_request.extend(payload_bytes)
        await self.ws.send(start_connection_request)
        response = await self.ws.recv()
        print(f"StartConnection response: {protocol.parse_response(response)}")

        # Increasing recv_timeout lets the session stay silent for longer; mainly useful in text mode. Valid range is [10, 120].
        config.start_session_req["dialog"]["extra"]["recv_timeout"] = self.recv_timeout
        # input_mod marks whether input comes from text, an audio file, or the microphone; in text or audio_file mode the session may stay silent for a while.
        config.start_session_req["dialog"]["extra"]["input_mod"] = self.mod
        # StartSession request
        if self.output_audio_format == "pcm_s16le":
            config.start_session_req["tts"]["audio_config"]["format"] = "pcm_s16le"
        request_params = config.start_session_req
        payload_bytes = str.encode(json.dumps(request_params))
        payload_bytes = gzip.compress(payload_bytes)
        start_session_request = bytearray(protocol.generate_header())
        start_session_request.extend(int(100).to_bytes(4, 'big'))
        start_session_request.extend((len(self.session_id)).to_bytes(4, 'big'))
        start_session_request.extend(str.encode(self.session_id))
        start_session_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        start_session_request.extend(payload_bytes)
        await self.ws.send(start_session_request)
        response = await self.ws.recv()
        print(f"StartSession response: {protocol.parse_response(response)}")

    async def say_hello(self) -> None:
        """Send the SayHello message."""
        payload = {
            "content": "你好,我是豆包,有什么可以帮助你的?",
        }
        hello_request = bytearray(protocol.generate_header())
        hello_request.extend(int(300).to_bytes(4, 'big'))
        payload_bytes = str.encode(json.dumps(payload))
        payload_bytes = gzip.compress(payload_bytes)
        hello_request.extend((len(self.session_id)).to_bytes(4, 'big'))
        hello_request.extend(str.encode(self.session_id))
        hello_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        hello_request.extend(payload_bytes)
        await self.ws.send(hello_request)

    async def chat_text_query(self, content: str) -> None:
        """Send a ChatTextQuery message."""
        payload = {
            "content": content,
        }
        chat_text_query_request = bytearray(protocol.generate_header())
        chat_text_query_request.extend(int(501).to_bytes(4, 'big'))
        payload_bytes = str.encode(json.dumps(payload))
        payload_bytes = gzip.compress(payload_bytes)
        chat_text_query_request.extend((len(self.session_id)).to_bytes(4, 'big'))
        chat_text_query_request.extend(str.encode(self.session_id))
        chat_text_query_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        chat_text_query_request.extend(payload_bytes)
        await self.ws.send(chat_text_query_request)

    async def chat_tts_text(self, is_user_querying: bool, start: bool, end: bool, content: str) -> None:
        """Send a ChatTTSText message."""
        if is_user_querying:
            return
        payload = {
            "start": start,
            "end": end,
            "content": content,
        }
        print(f"ChatTTSTextRequest payload: {payload}")
        payload_bytes = str.encode(json.dumps(payload))
        payload_bytes = gzip.compress(payload_bytes)

        chat_tts_text_request = bytearray(protocol.generate_header())
        chat_tts_text_request.extend(int(500).to_bytes(4, 'big'))
        chat_tts_text_request.extend((len(self.session_id)).to_bytes(4, 'big'))
        chat_tts_text_request.extend(str.encode(self.session_id))
        chat_tts_text_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        chat_tts_text_request.extend(payload_bytes)
        await self.ws.send(chat_tts_text_request)

    async def chat_rag_text(self, is_user_querying: bool, external_rag: str) -> None:
        """Send a ChatRAGText message."""
        if is_user_querying:
            return
        payload = {
            "external_rag": external_rag,
        }
        print(f"ChatRAGTextRequest payload: {payload}")
        payload_bytes = str.encode(json.dumps(payload))
        payload_bytes = gzip.compress(payload_bytes)

        chat_rag_text_request = bytearray(protocol.generate_header())
        chat_rag_text_request.extend(int(502).to_bytes(4, 'big'))
        chat_rag_text_request.extend((len(self.session_id)).to_bytes(4, 'big'))
        chat_rag_text_request.extend(str.encode(self.session_id))
        chat_rag_text_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        chat_rag_text_request.extend(payload_bytes)
        await self.ws.send(chat_rag_text_request)

    async def task_request(self, audio: bytes) -> None:
        task_request = bytearray(
            protocol.generate_header(message_type=protocol.CLIENT_AUDIO_ONLY_REQUEST,
                                     serial_method=protocol.NO_SERIALIZATION))
        task_request.extend(int(200).to_bytes(4, 'big'))
        task_request.extend((len(self.session_id)).to_bytes(4, 'big'))
        task_request.extend(str.encode(self.session_id))
        payload_bytes = gzip.compress(audio)
        task_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size (4 bytes)
        task_request.extend(payload_bytes)
        await self.ws.send(task_request)

    async def receive_server_response(self) -> Dict[str, Any]:
        try:
            response = await self.ws.recv()
            data = protocol.parse_response(response)
            return data
        except Exception as e:
            raise Exception(f"Failed to receive message: {e}")

    async def finish_session(self):
        finish_session_request = bytearray(protocol.generate_header())
        finish_session_request.extend(int(102).to_bytes(4, 'big'))
        payload_bytes = str.encode("{}")
        payload_bytes = gzip.compress(payload_bytes)
        finish_session_request.extend((len(self.session_id)).to_bytes(4, 'big'))
        finish_session_request.extend(str.encode(self.session_id))
        finish_session_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        finish_session_request.extend(payload_bytes)
        await self.ws.send(finish_session_request)

    async def finish_connection(self):
        finish_connection_request = bytearray(protocol.generate_header())
        finish_connection_request.extend(int(2).to_bytes(4, 'big'))
        payload_bytes = str.encode("{}")
        payload_bytes = gzip.compress(payload_bytes)
        finish_connection_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        finish_connection_request.extend(payload_bytes)
        await self.ws.send(finish_connection_request)
        response = await self.ws.recv()
        print(f"FinishConnection response: {protocol.parse_response(response)}")

    async def close(self) -> None:
        """Close the WebSocket connection."""
        if self.ws:
            print(f"Closing WebSocket connection...")
            await self.ws.close()
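
For orientation, a minimal usage sketch of this client follows (an added example, not part of the commit; it assumes valid credentials in `config.py`, and the number of responses read is arbitrary). `DialogSession` in `audio_manager.py` is the full-featured version of this flow.

```python
# Minimal, illustrative driver for RealtimeDialogClient -- assumptions noted above.
import asyncio
import uuid

import config
from realtime_dialog_client import RealtimeDialogClient


async def demo() -> None:
    client = RealtimeDialogClient(config=config.ws_connect_config, session_id=str(uuid.uuid4()))
    await client.connect()                          # StartConnection + StartSession handshakes
    await client.chat_text_query("今天天气怎么样?")   # send one text query
    for _ in range(20):                             # read a handful of responses, then shut down
        response = await client.receive_server_response()
        print(response.get('event'), type(response.get('payload_msg')))
    await client.finish_session()
    await client.finish_connection()
    await client.close()


if __name__ == "__main__":
    asyncio.run(demo())
```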
4    doubao/requirements.txt    Normal file
@@ -0,0 +1,4 @@
pyaudio
websockets
dataclasses==0.8; python_version < "3.7"
typing-extensions==4.7.1; python_version < "3.8"
BIN    doubao/whoareyou.wav    Normal file
Binary file not shown.