425 lines
19 KiB
Python
425 lines
19 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
豆包音频处理模块 - 简化测试版本
|
||
专门测试完整的音频上传和TTS音频下载流程
|
||
"""
|
||
|
||
import asyncio
|
||
import gzip
|
||
import json
|
||
import uuid
|
||
import wave
|
||
import struct
|
||
import time
|
||
import os
|
||
from typing import Dict, Any, Optional
|
||
import websockets
|
||
|
||
|
||
# Protocol constants copied from the original Doubao demo code.
# Each value is a 4-bit nibble that generate_header() packs into the frame header.

# Protocol version: high nibble of header byte 0.
PROTOCOL_VERSION = 0b0001

# Message types: high nibble of header byte 1.
# Client -> server:
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
# Server -> client:
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# Message-type-specific flags: low nibble of header byte 1.
NO_SEQUENCE = 0b0000
POS_SEQUENCE = 0b0001
MSG_WITH_EVENT = 0b0100

# Serialization method: high nibble of header byte 2.
NO_SERIALIZATION = 0b0000
JSON = 0b0001

# Compression type: low nibble of header byte 2.
# It shares the value 1 with JSON only because the two live in different nibbles.
GZIP = 0b0001
|
||
|
||
|
||
def generate_header(
    version=PROTOCOL_VERSION,
    message_type=CLIENT_FULL_REQUEST,
    message_type_specific_flags=MSG_WITH_EVENT,
    serial_method=JSON,
    compression_type=GZIP,
    reserved_data=0x00,
    extension_header=bytes(),
):
    """Build a Doubao protocol frame header (copied from the original demo).

    Each of the four fixed bytes packs two values, high nibble first:
      byte 0: protocol version   | header size in 4-byte words
      byte 1: message type       | message-type-specific flags
      byte 2: serialization      | compression type
      byte 3: reserved byte
    ``extension_header`` is appended after the fixed bytes unchanged.

    Returns:
        bytearray: the fixed header followed by the extension bytes.
    """
    # The size counts 4-byte words: one for the fixed part plus the extension words.
    size_in_words = len(extension_header) // 4 + 1
    fixed_bytes = (
        (version << 4) | size_in_words,
        (message_type << 4) | message_type_specific_flags,
        (serial_method << 4) | compression_type,
        reserved_data,
    )
    header = bytearray(fixed_bytes)
    header.extend(extension_header)
    return header
|
||
|
||
|
||
class DoubaoConfig:
    """Connection settings for the Doubao realtime-dialogue WebSocket API.

    Every field can be overridden with an environment variable (DOUBAO_BASE_URL,
    DOUBAO_APP_ID, DOUBAO_ACCESS_KEY, DOUBAO_APP_KEY, DOUBAO_RESOURCE_ID).
    The built-in values are only a fallback, so the script keeps working unchanged.
    """

    def __init__(self):
        # SECURITY(review): credentials are committed to source here. Rotate them and
        # drop these fallbacks once every caller provides the environment variables.
        self.base_url = self._setting(
            "DOUBAO_BASE_URL", "wss://openspeech.bytedance.com/api/v3/realtime/dialogue")
        self.app_id = self._setting("DOUBAO_APP_ID", "8718217928")
        self.access_key = self._setting("DOUBAO_ACCESS_KEY", "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc")
        self.app_key = self._setting("DOUBAO_APP_KEY", "PlgvMymc7f3tQnJ6")
        self.resource_id = self._setting("DOUBAO_RESOURCE_ID", "volc.speech.dialog")

    @staticmethod
    def _setting(env_name: str, default: str) -> str:
        """Return environment variable ``env_name`` if set and non-empty, else ``default``."""
        return os.environ.get(env_name) or default

    def get_headers(self) -> Dict[str, str]:
        """Build the authentication headers for the WebSocket handshake.

        A fresh random X-Api-Connect-Id is generated on every call.
        """
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
|
||
|
||
|
||
class DoubaoClient:
    """Minimal Doubao realtime-dialogue client built on the original Doubao demo code.

    Opens the WebSocket and performs the StartConnection / StartSession handshake.
    It then uploads recorded audio as one audio-only frame and dumps the server
    replies (ASR results and TTS audio) to files in the current working directory.
    """

    def __init__(self, config: "DoubaoConfig"):
        self.config = config
        # One session id per client; it is embedded in every session-scoped frame.
        self.session_id = str(uuid.uuid4())
        self.ws = None
        # Server trace id from the X-Tt-Logid handshake response header; "" until known.
        self.log_id = ""

    async def connect(self) -> None:
        """Open the WebSocket and run the StartConnection / StartSession handshake.

        Raises:
            Exception: any connection or handshake failure is printed and re-raised.
        """
        print(f"连接豆包服务器: {self.config.base_url}")
        try:
            self.ws = await websockets.connect(
                self.config.base_url,
                additional_headers=self.config.get_headers(),
                ping_interval=None,
            )
            self.log_id = self._read_log_id()
            print(f"连接成功, log_id: {self.log_id}")

            await self._send_start_connection()
            await self._send_start_session()
        except Exception as e:
            print(f"连接失败: {e}")
            raise

    def _read_log_id(self) -> str:
        """Return the X-Tt-Logid handshake response header, or "" when unavailable.

        Several attribute locations are tried in turn:
        - the legacy websockets client's ``response_headers``;
        - ``response.headers``, used by the newer asyncio client that accepts
          ``additional_headers``;
        - ``headers``, which the original code also checked.
        """
        headers = getattr(self.ws, "response_headers", None)
        if headers is None:
            headers = getattr(getattr(self.ws, "response", None), "headers", None)
        if headers is None:
            headers = getattr(self.ws, "headers", None)
        if headers is None:
            return ""
        return headers.get("X-Tt-Logid") or ""

    def _build_frame(self, event: int, payload: bytes, *, with_session: bool = True,
                     **header_kwargs) -> bytearray:
        """Serialize one client frame.

        Layout: header | event id (4 bytes, big-endian) |
        [session id length (4 bytes) | session id] |
        gzip payload length (4 bytes) | gzip payload.

        Args:
            event: protocol event id (1 StartConnection, 100 StartSession, 200 audio).
            payload: uncompressed payload bytes; they are always gzip-compressed.
            with_session: embed ``self.session_id`` after the event id.
            **header_kwargs: forwarded to ``generate_header``.
        """
        frame = bytearray(generate_header(**header_kwargs))
        frame.extend(event.to_bytes(4, 'big'))
        if with_session:
            session_bytes = self.session_id.encode()
            # The length prefix must count encoded bytes, not str characters.
            frame.extend(len(session_bytes).to_bytes(4, 'big'))
            frame.extend(session_bytes)
        compressed = gzip.compress(payload)
        frame.extend(len(compressed).to_bytes(4, 'big'))
        frame.extend(compressed)
        return frame

    async def _send_start_connection(self) -> None:
        """Send StartConnection (event 1, empty JSON object) and wait for one reply."""
        print("发送StartConnection请求...")
        await self.ws.send(self._build_frame(1, b"{}", with_session=False))
        response = await self.ws.recv()
        print(f"StartConnection响应长度: {len(response)}")

    async def _send_start_session(self) -> None:
        """Send StartSession (event 100) with the ASR/TTS/dialog config; wait for one reply."""
        print("发送StartSession请求...")
        session_config = {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {
                    "channel": 1,
                    "format": "pcm",
                    "sample_rate": 24000
                },
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 30,
                    "input_mod": "audio",
                },
            },
        }

        await self.ws.send(self._build_frame(100, json.dumps(session_config).encode()))
        response = await self.ws.recv()
        print(f"StartSession响应长度: {len(response)}")

        # Give the server a moment to finish setting up the session.
        await asyncio.sleep(1.0)

    async def task_request(self, audio: bytes) -> None:
        """Send ``audio`` as one gzip-compressed audio-only frame (event 200) for this session."""
        frame = self._build_frame(
            200,
            audio,
            message_type=CLIENT_AUDIO_ONLY_REQUEST,
            serial_method=NO_SERIALIZATION,
        )
        await self.ws.send(frame)

    @staticmethod
    def _parse_frame(data) -> Optional[Dict[str, Any]]:
        """Decode a binary server frame into its header fields and payload.

        Returns None for text frames and for frames shorter than the 4-byte
        fixed header.

        The returned dict has these keys:
        - header fields: version, header_size, message_type, flags,
          serialization, compression;
        - frame fields: event, session_id, code, payload_size, payload;
        - decompress_error.

        The payload is gunzipped when the header's compression nibble is GZIP.
        If that fails, the raw bytes are kept and ``decompress_error`` holds the
        exception.
        """
        import zlib

        if isinstance(data, str) or len(data) < 4:
            return None
        header_size = data[0] & 0x0f
        message_type = data[1] >> 4
        flags = data[1] & 0x0f
        frame = {
            "version": data[0] >> 4,
            "header_size": header_size,
            "message_type": message_type,
            "flags": flags,
            "serialization": data[2] >> 4,
            "compression": data[2] & 0x0f,
            "event": None,
            "session_id": None,
            "code": None,
            "payload_size": 0,
            "payload": b"",
            "decompress_error": None,
        }
        body = bytes(data[header_size * 4:])
        raw = b""
        if message_type == SERVER_ERROR_RESPONSE:
            # NOTE(review): layout taken from the Doubao demo parser; confirm.
            # Layout: error code (4) | payload size (4) | payload.
            if len(body) >= 8:
                frame["code"] = int.from_bytes(body[:4], 'big')
                frame["payload_size"] = int.from_bytes(body[4:8], 'big')
                raw = body[8:8 + frame["payload_size"]]
        elif message_type in (SERVER_FULL_RESPONSE, SERVER_ACK):
            # Layout: [event (4)] | session id length (4) | session id | payload size (4) | payload.
            offset = 0
            if flags & MSG_WITH_EVENT and len(body) >= 4:
                frame["event"] = int.from_bytes(body[:4], 'big')
                offset = 4
            if len(body) >= offset + 4:
                session_len = int.from_bytes(body[offset:offset + 4], 'big')
                offset += 4
                frame["session_id"] = body[offset:offset + session_len].decode('utf-8', errors='replace')
                offset += session_len
            if len(body) >= offset + 4:
                frame["payload_size"] = int.from_bytes(body[offset:offset + 4], 'big')
                offset += 4
                raw = body[offset:offset + frame["payload_size"]]

        # Decompress only when the header says so; payloads may also arrive uncompressed.
        if raw and frame["compression"] == GZIP:
            try:
                raw = gzip.decompress(raw)
            except (OSError, EOFError, zlib.error) as err:
                frame["decompress_error"] = err
        frame["payload"] = raw
        return frame

    @staticmethod
    def _load_recording(path: str, seconds: float = 5.0) -> Optional[bytes]:
        """Read up to ``seconds`` of PCM frames from the WAV file at ``path``.

        Returns None (after printing the error) when the file cannot be read.
        """
        try:
            with wave.open(path, 'rb') as wf:
                # Derive the frame count from the file's own rate. The original
                # hard-coded 80000 frames, which is 5 s only at 16 kHz.
                frames_to_read = min(wf.getnframes(), int(wf.getframerate() * seconds))
                audio_data = wf.readframes(frames_to_read)
                print(f"读取真实音频数据: {len(audio_data)} 字节")
                print(f"音频参数: 采样率={wf.getframerate()}, 通道数={wf.getnchannels()}, 采样宽度={wf.getsampwidth()}")
        except (OSError, EOFError, wave.Error) as e:
            print(f"读取音频文件失败: {e}")
            return None
        return audio_data

    def _report_asr_response(self, response) -> None:
        """Print the header fields, event, session id and JSON body of the first reply."""
        frame = self._parse_frame(response)
        if frame is None:
            return
        print(f"ASR响应协议: version={frame['version']}, header_size={frame['header_size']}, message_type={frame['message_type']}, flags={frame['flags']}")
        if frame["message_type"] != SERVER_FULL_RESPONSE:
            return

        print(f"ASR Event: {frame['event']}")
        if frame["session_id"] is not None:
            print(f"Session ID: {frame['session_id']}")
        print(f"Payload size: {frame['payload_size']}")
        try:
            asr_result = json.loads(frame["payload"].decode('utf-8'))
            print(f"ASR结果: {asr_result}")
            if isinstance(asr_result, dict) and asr_result.get('results'):
                text = asr_result['results'][0].get('text', '')
                print(f"识别文本: {text}")
        except (ValueError, LookupError, AttributeError, TypeError) as e:
            print(f"解析ASR结果失败: {e}")

    @staticmethod
    def _print_json_payload(payload: bytes, index: int) -> None:
        """Print a JSON payload (and any recognized text).

        If the payload is not valid JSON, it is saved to ``tts_response_<index>.raw``.
        """
        try:
            json_data = json.loads(payload.decode('utf-8'))
            print(f"JSON数据: {json_data}")
            if isinstance(json_data, dict) and 'results' in json_data:
                text = json_data['results'][0].get('text', '')
                print(f"识别文本: {text}")
        except (ValueError, LookupError, AttributeError, TypeError) as e:
            print(f"解析JSON失败: {e}")
            with open(f'tts_response_{index}.raw', 'wb') as f:
                f.write(payload)

    @staticmethod
    def _save_tts_audio(frame: Dict[str, Any], index: int, sample_rate: int = 24000,
                        channels: int = 1, sampwidth: int = 2) -> bool:
        """Write one TTS audio payload to ``tts_response_<index>.wav``.

        The 24 kHz mono defaults mirror the StartSession ``audio_config``. The
        16-bit sample width is this script's assumption.
        NOTE(review): confirm the server's "pcm" sample format before trusting sampwidth=2.

        Returns:
            True once the WAV file has been written. False when the payload could
            not be decompressed; the raw bytes are then saved to
            ``tts_response_audio_<index>.raw`` instead.
        """
        if frame["decompress_error"] is not None:
            print(f"TTS音频解压缩失败: {frame['decompress_error']}")
            raw_path = f'tts_response_audio_{index}.raw'
            with open(raw_path, 'wb') as f:
                f.write(frame["payload"])
            print(f"原始TTS音频数据已保存到 {raw_path}")
            return False

        pcm = frame["payload"]
        print(f"TTS音频大小: {len(pcm)}")
        wav_path = f'tts_response_{index}.wav'
        with wave.open(wav_path, 'wb') as wav_file:
            wav_file.setnchannels(channels)
            wav_file.setsampwidth(sampwidth)
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(pcm)

        print(f"成功创建TTS WAV文件: {wav_path}")
        print(f"音频参数: {sample_rate}Hz, {channels}通道, {sampwidth*8}-bit")
        # Duration comes from the PCM length; the file size would also count the WAV header.
        duration = len(pcm) / (sample_rate * channels * sampwidth)
        print(f"WAV文件大小: {os.path.getsize(wav_path)} 字节")
        print(f"音频时长: {duration:.2f} 秒")
        return True

    def _handle_tts_response(self, data, index: int) -> bool:
        """Process one reply received while waiting for TTS; return True to stop waiting.

        Handling by frame type:
        - SERVER_ACK frames with a non-empty payload are treated as TTS audio and saved.
        - SERVER_FULL_RESPONSE frames with event 451 (ASR result) or 359 (TTS
          finished) have their JSON printed. Event 359 always ends the wait,
          even if its JSON cannot be parsed.
        """
        frame = self._parse_frame(data)
        if frame is None:
            return False
        print(f"响应协议: version={frame['version']}, header_size={frame['header_size']}, message_type={frame['message_type']}, flags={frame['flags']}")

        if frame["message_type"] == SERVER_ACK:
            print(f"Event: {frame['event']}")
            print(f"音频数据大小: {frame['payload_size']}")
            if frame["payload_size"] > 0:
                print("找到TTS音频数据!")
                # As in the original test, stop after the first saved audio chunk.
                # NOTE(review): TTS audio may span several ACK frames.
                return self._save_tts_audio(frame, index)
        elif frame["message_type"] == SERVER_FULL_RESPONSE:
            print(f"Event: {frame['event']}")
            if frame["event"] in (451, 359):  # ASR result or TTS finished
                self._print_json_payload(frame["payload"], index)
                if frame["event"] == 359:
                    print("TTS响应结束")
                    return True
        elif frame["message_type"] == SERVER_ERROR_RESPONSE:
            print(f"服务器错误: code={frame['code']}, payload={frame['payload']!r}")
        return False

    async def _collect_tts_responses(self, max_responses: int = 10) -> int:
        """Receive up to ``max_responses`` replies after the ASR reply.

        The loop stops early on any of:
        - a 30 s receive timeout;
        - a closed connection;
        - the first saved TTS audio chunk;
        - the TTS-finished event.

        Every reply that does not stop the loop is dumped to
        ``tts_response_full_<n>.raw``.

        Returns:
            The number of replies processed without stopping the loop.
        """
        print("开始持续等待TTS音频响应...")
        response_count = 0
        while response_count < max_responses:
            print(f"等待第 {response_count + 1} 个响应...")
            try:
                tts_response = await asyncio.wait_for(self.ws.recv(), timeout=30.0)
            except asyncio.TimeoutError:
                print(f"等待第 {response_count + 1} 个响应超时")
                break
            except websockets.exceptions.ConnectionClosed:
                print("连接已关闭")
                break

            print(f"收到响应 {response_count + 1},长度: {len(tts_response)}")
            if self._handle_tts_response(tts_response, response_count):
                break

            # Keep the full reply for offline debugging; text frames are stored UTF-8 encoded.
            dump_path = f'tts_response_full_{response_count}.raw'
            with open(dump_path, 'wb') as f:
                f.write(tts_response.encode() if isinstance(tts_response, str) else tts_response)
            print(f"完整响应已保存到 {dump_path}")
            response_count += 1

        print(f"共收到 {response_count} 个响应")
        return response_count

    async def test_full_dialog(self) -> None:
        """Run one dialog turn: upload a recording, report the ASR reply, then collect TTS.

        Reads up to 5 s of ``recording_20250920_135137.wav`` from the current
        directory. Errors during the exchange are printed rather than raised.
        """
        print("开始完整对话测试...")

        audio_data = self._load_recording("recording_20250920_135137.wav")
        if audio_data is None:
            return
        print(f"音频数据大小: {len(audio_data)}")

        try:
            print("发送音频数据...")
            await self.task_request(audio_data)
            print("音频数据发送成功")

            print("等待语音识别响应...")
            response = await asyncio.wait_for(self.ws.recv(), timeout=15.0)
            print(f"收到ASR响应,长度: {len(response)}")
            self._report_asr_response(response)

            await self._collect_tts_responses()
        except asyncio.TimeoutError:
            print("等待响应超时")
        except websockets.exceptions.ConnectionClosed as e:
            print(f"连接关闭: {e}")
        except Exception as e:
            print(f"测试失败: {e}")
            import traceback
            traceback.print_exc()

    async def close(self) -> None:
        """Close the WebSocket if it was opened; failures while closing are reported, not raised."""
        if self.ws is not None:
            try:
                await self.ws.close()
            except Exception as e:
                # Best-effort close, but keep the reason visible. A bare except
                # would also swallow asyncio.CancelledError.
                print(f"关闭连接时出错: {e}")
            self.ws = None
            print("连接已关闭")
|
||
|
||
|
||
async def main():
    """Run the end-to-end dialog test once, always closing the client afterwards."""
    client = DoubaoClient(DoubaoConfig())
    try:
        await client.connect()
        await client.test_full_dialog()
    except Exception as exc:
        import traceback
        print(f"测试失败: {exc}")
        traceback.print_exc()
    finally:
        await client.close()
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: drive the async test on a fresh event loop.
    asyncio.run(main())