Local-Voice/doubao_test.py
2025-09-20 14:35:54 +08:00

303 lines
11 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
豆包音频处理模块 - 协议测试版本
专门测试协议格式问题
"""
import asyncio
import gzip
import json
import uuid
import wave
import struct
import time
import os
from typing import Dict, Any, Optional
import websockets
class DoubaoConfig:
    """Endpoint and credentials for the Doubao realtime-dialogue WebSocket API.

    Each setting is resolved in this order:

    1. the explicit keyword argument, if not ``None``;
    2. the matching environment variable (``DOUBAO_BASE_URL``,
       ``DOUBAO_APP_ID``, ``DOUBAO_ACCESS_KEY``, ``DOUBAO_APP_KEY``,
       ``DOUBAO_RESOURCE_ID``), if set to a non-empty value;
    3. the class-level default below.
    """

    # NOTE(security): the defaults below appear to be real credentials that
    # were committed to source control. They are kept only so existing
    # `DoubaoConfig()` callers keep working; rotate them and provide the real
    # values through the environment (or keyword arguments) instead.
    DEFAULT_BASE_URL = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
    DEFAULT_APP_ID = "8718217928"
    DEFAULT_ACCESS_KEY = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
    DEFAULT_APP_KEY = "PlgvMymc7f3tQnJ6"
    DEFAULT_RESOURCE_ID = "volc.speech.dialog"

    def __init__(
        self,
        *,
        base_url: Optional[str] = None,
        app_id: Optional[str] = None,
        access_key: Optional[str] = None,
        app_key: Optional[str] = None,
        resource_id: Optional[str] = None,
    ):
        self.base_url = self._resolve(base_url, "DOUBAO_BASE_URL", self.DEFAULT_BASE_URL)
        self.app_id = self._resolve(app_id, "DOUBAO_APP_ID", self.DEFAULT_APP_ID)
        self.access_key = self._resolve(access_key, "DOUBAO_ACCESS_KEY", self.DEFAULT_ACCESS_KEY)
        self.app_key = self._resolve(app_key, "DOUBAO_APP_KEY", self.DEFAULT_APP_KEY)
        self.resource_id = self._resolve(resource_id, "DOUBAO_RESOURCE_ID", self.DEFAULT_RESOURCE_ID)

    @staticmethod
    def _resolve(value: Optional[str], env_var: str, default: str) -> str:
        """Return *value* if given, else the non-empty *env_var*, else *default*."""
        if value is not None:
            return value
        return os.environ.get(env_var) or default

    def get_headers(self) -> Dict[str, str]:
        """Build the authentication headers for the WebSocket handshake.

        A fresh random ``X-Api-Connect-Id`` (UUID4) is generated on every call.
        """
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
class DoubaoProtocol:
    """Constants and header builder for the Doubao binary WebSocket protocol.

    The 4-byte base header is made of 4-bit fields packed two per byte:
    byte 0 = version | header size, byte 1 = message type | type-specific
    flags, byte 2 = serialization | compression, byte 3 = reserved.
    """

    # Protocol version (high nibble of byte 0).
    PROTOCOL_VERSION = 0b0001

    # Message types (high nibble of byte 1).
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ACK = 0b1011
    SERVER_ERROR_RESPONSE = 0b1111

    # Message-type-specific flags (low nibble of byte 1).
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    MSG_WITH_EVENT = 0b0100

    # Serialization methods (high nibble of byte 2).
    JSON = 0b0001
    NO_SERIALIZATION = 0b0000

    # Compression types (low nibble of byte 2).
    GZIP = 0b0001

    @classmethod
    def generate_header(cls, message_type=CLIENT_FULL_REQUEST,
                        message_type_specific_flags=MSG_WITH_EVENT,
                        serial_method=JSON, compression_type=GZIP) -> bytes:
        """Return the 4-byte base header for an outgoing frame.

        Defaults describe a full client request that carries an event id,
        with a gzip-compressed JSON payload.
        """
        # The header-size field counts 4-byte words; this header is one word.
        size_in_words = 1
        return bytes((
            cls.PROTOCOL_VERSION << 4 | size_in_words,
            message_type << 4 | message_type_specific_flags,
            serial_method << 4 | compression_type,
            0x00,  # reserved
        ))
class DoubaoClient:
    """Minimal client for probing the Doubao realtime-dialogue binary protocol.

    ``connect`` opens the WebSocket and performs the StartConnection (event 1)
    and StartSession (event 100) handshake. ``test_audio_request`` then sends
    one audio-only frame (event 200) and prints a breakdown of both the
    request and the server's first reply.
    """

    def __init__(self, config: "DoubaoConfig"):
        self.config = config
        # Random per-client session id, embedded in every session-scoped
        # frame (StartSession and the audio frame).
        self.session_id = str(uuid.uuid4())
        self.ws = None
        # Value of the X-Tt-Logid handshake response header ("" if absent).
        self.log_id = ""

    # ------------------------------------------------------------------
    # Frame helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_frame(header: bytes, event: int, payload: bytes,
                     session_id: Optional[str] = None) -> bytes:
        """Assemble one request frame.

        Layout: ``header | event | [len(sid) | sid] | len(payload) | payload``
        with every integer encoded as 4-byte big-endian. The session-id length
        is the length of its UTF-8 encoding (the bytes actually written).
        """
        frame = bytearray(header)
        frame += event.to_bytes(4, "big")
        if session_id is not None:
            sid = session_id.encode("utf-8")
            frame += len(sid).to_bytes(4, "big")
            frame += sid
        frame += len(payload).to_bytes(4, "big")
        frame += payload
        return bytes(frame)

    @staticmethod
    def _extract_log_id(ws) -> str:
        """Return the X-Tt-Logid header of the handshake response, or "".

        websockets exposes handshake response headers as ``ws.response.headers``
        in its asyncio implementation and as ``ws.response_headers`` in the
        legacy one; both are tried, plus a ``ws.headers`` fallback.
        """
        candidates = (
            getattr(getattr(ws, "response", None), "headers", None),
            getattr(ws, "response_headers", None),
            getattr(ws, "headers", None),
        )
        for headers in candidates:
            if headers is not None:
                return headers.get("X-Tt-Logid") or ""
        return ""

    @staticmethod
    def parse_response(response) -> Dict[str, Any]:
        """Decode the header of a server frame and, for error frames, the body.

        The returned dict always contains ``protocol_version``,
        ``header_size`` (in 4-byte words), ``message_type``,
        ``message_type_specific_flags``, ``serialization_method`` and
        ``message_compression``.

        For message type 15 (SERVER_ERROR_RESPONSE) whose body holds at least
        8 bytes it also contains ``code`` and ``payload_size``. If the whole
        payload is present it further contains:

        * ``payload_raw``: the payload bytes as received;
        * ``payload``: the bytes after gunzip (when the compression nibble is 1
          and decompression succeeds), otherwise the raw bytes;
        * ``payload_decompressed``: only when gunzip succeeded;
        * ``payload_json``: only when ``payload`` decodes as UTF-8 JSON.

        Other message types only get the header fields.

        Raises:
            TypeError: if *response* is a text (``str``) frame.
            ValueError: if *response* is shorter than the 4-byte base header.
        """
        import zlib  # only needed for zlib.error raised by gzip.decompress

        if isinstance(response, str):
            raise TypeError(f"expected a binary frame, got text: {response!r}")
        if len(response) < 4:
            raise ValueError(f"frame too short for a protocol header: {len(response)} bytes")

        header_size = response[0] & 0x0F
        info: Dict[str, Any] = {
            "protocol_version": response[0] >> 4,
            "header_size": header_size,
            "message_type": response[1] >> 4,
            "message_type_specific_flags": response[1] & 0x0F,
            "serialization_method": response[2] >> 4,
            "message_compression": response[2] & 0x0F,
        }

        body = bytes(response[header_size * 4:])
        # Error frame body: code (4B) | payload_size (4B) | payload
        if info["message_type"] != 15 or len(body) < 8:
            return info
        payload_size = int.from_bytes(body[4:8], "big")
        info["code"] = int.from_bytes(body[:4], "big")
        info["payload_size"] = payload_size
        if len(body) < 8 + payload_size:
            return info  # truncated frame: payload not fully available

        raw = body[8:8 + payload_size]
        info["payload_raw"] = raw
        data = raw
        if info["message_compression"] == 1:  # 1 = GZIP
            try:
                data = gzip.decompress(raw)
            except (OSError, EOFError, zlib.error):
                pass  # flagged as gzip but not valid gzip: keep the raw bytes
            else:
                info["payload_decompressed"] = data
        info["payload"] = data
        try:
            info["payload_json"] = json.loads(data.decode("utf-8"))
        except ValueError:  # covers UnicodeDecodeError and JSONDecodeError
            pass
        return info

    @classmethod
    def _print_response(cls, response) -> None:
        """Print the fields decoded by ``parse_response`` for *response*."""
        info = cls.parse_response(response)
        print("响应协议信息:")
        print(f" - version={info['protocol_version']}")
        print(f" - header_size={info['header_size']}")
        print(f" - message_type={info['message_type']} (15=SERVER_ERROR_RESPONSE)")
        print(f" - message_type_specific_flags={info['message_type_specific_flags']}")
        print(f" - serialization_method={info['serialization_method']}")
        print(f" - message_compression={info['message_compression']}")
        if "code" not in info:
            return
        print(f" - 错误代码: {info['code']}")
        print(f" - payload大小: {info['payload_size']}")
        if "payload_raw" not in info:
            return
        print(f" - payload长度: {len(info['payload_raw'])}")
        if "payload_decompressed" in info:
            print(f" - 解压缩后长度: {len(info['payload_decompressed'])}")
        if "payload_json" in info:
            print(f" - 错误信息: {info['payload_json']}")
        else:
            print(f" - 原始payload: {info['payload']}")

    @classmethod
    def _raise_if_error(cls, response, stage: str) -> None:
        """Raise RuntimeError if *response* is a SERVER_ERROR_RESPONSE frame.

        Text frames, frames too short to hold a header, and all non-error
        message types are accepted silently.
        """
        if isinstance(response, str) or len(response) < 4:
            return
        info = cls.parse_response(response)
        if info["message_type"] != 15:  # 15 = SERVER_ERROR_RESPONSE
            return
        detail = info.get("payload_json", info.get("payload"))
        raise RuntimeError(f"{stage}返回错误响应: code={info.get('code')}, 详情: {detail}")

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    async def connect(self) -> None:
        """Open the WebSocket and run the StartConnection/StartSession handshake.

        Raises:
            RuntimeError: if the server answers a handshake step with an
                error frame.
            Exception: any connection or handshake error; it is printed and
                re-raised.
        """
        print(f"连接豆包服务器: {self.config.base_url}")
        try:
            self.ws = await websockets.connect(
                self.config.base_url,
                additional_headers=self.config.get_headers(),
                # None disables websockets' automatic keepalive pings.
                ping_interval=None,
            )
            self.log_id = self._extract_log_id(self.ws)
            print(f"连接成功, log_id: {self.log_id}")
            await self._send_start_connection()
            await self._send_start_session()
        except Exception as e:
            print(f"连接失败: {e}")
            raise

    async def _send_start_connection(self) -> None:
        """Send StartConnection (event 1, gzip'd ``{}``) and check the reply."""
        print("发送StartConnection请求...")
        request = self._build_frame(
            DoubaoProtocol.generate_header(), 1, gzip.compress(b"{}"))
        await self.ws.send(request)
        response = await self.ws.recv()
        print(f"StartConnection响应长度: {len(response)}")
        self._raise_if_error(response, "StartConnection")

    async def _send_start_session(self) -> None:
        """Send StartSession (event 100) with the ASR/TTS/dialog config and check the reply."""
        print("发送StartSession请求...")
        session_config = {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {
                    "channel": 1,
                    "format": "pcm",
                    "sample_rate": 24000
                },
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 30,
                    "input_mod": "audio",
                },
            },
        }
        request = self._build_frame(
            DoubaoProtocol.generate_header(),
            100,
            gzip.compress(json.dumps(session_config).encode()),
            session_id=self.session_id,
        )
        await self.ws.send(request)
        response = await self.ws.recv()
        print(f"StartSession响应长度: {len(response)}")
        self._raise_if_error(response, "StartSession")
        # Give the server a moment to finish setting up the session.
        await asyncio.sleep(1.0)

    async def test_audio_request(self, audio: Optional[bytes] = None,
                                 timeout: float = 3.0) -> None:
        """Send one audio-only frame (event 200) and print the server's reply.

        Args:
            audio: Raw audio bytes to send. Defaults to 3200 zero bytes
                (silence), the chunk size used by the original Doubao sample.
            timeout: Seconds to wait for the first server frame.

        Raises:
            asyncio.TimeoutError: if no frame arrives within *timeout*.
            Exception: any send/receive error; it is printed and re-raised.
                Failures while decoding the reply are printed, not raised.
        """
        print("测试音频请求格式...")
        small_audio = b'\x00' * 3200 if audio is None else audio

        # This frame carries an event id (200), just like StartConnection and
        # StartSession, so its flags nibble must be MSG_WITH_EVENT too, which is
        # the value generate_header() uses for those frames. The previous
        # hand-built header sent NO_SEQUENCE here, which does not announce the
        # event field that follows the header.
        header = DoubaoProtocol.generate_header(
            message_type=DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST,
            message_type_specific_flags=DoubaoProtocol.MSG_WITH_EVENT,
            serial_method=DoubaoProtocol.NO_SERIALIZATION,
            compression_type=DoubaoProtocol.GZIP,
        )
        compressed_audio = gzip.compress(small_audio)
        request = self._build_frame(header, 200, compressed_audio,
                                    session_id=self.session_id)
        sid_len = len(self.session_id.encode("utf-8"))

        print("测试请求详细信息:")
        print(f" - 音频原始大小: {len(small_audio)}")
        print(f" - 音频压缩后大小: {len(compressed_audio)}")
        print(f" - Session ID: {self.session_id} (长度: {sid_len})")
        print(f" - 总请求大小: {len(request)}")
        print(f" - 头部字节: {request[:4].hex()}")
        print(f" - 消息类型: {int.from_bytes(request[4:8], 'big')}")
        print(f" - Session ID长度: {int.from_bytes(request[8:12], 'big')}")
        print(f" - Payload size: {int.from_bytes(request[12 + sid_len:16 + sid_len], 'big')}")

        try:
            await self.ws.send(request)
            print("请求发送成功")
            response = await asyncio.wait_for(self.ws.recv(), timeout=timeout)
            print(f"收到响应,长度: {len(response)}")
            try:
                self._print_response(response)
            except (TypeError, ValueError) as e:
                # Decoding problems are diagnostics, not test failures.
                print(f"解析响应失败: {e}")
                import traceback
                traceback.print_exc()
        except Exception as e:
            print(f"发送测试请求失败: {e}")
            raise

    async def close(self) -> None:
        """Close the WebSocket if it is open.

        Errors raised while closing are printed, not propagated. ``self.ws``
        is reset so a second call is a no-op.
        """
        if self.ws is not None:
            try:
                await self.ws.close()
            except Exception as e:
                print(f"关闭连接时出错: {e}")
            finally:
                self.ws = None
        print("连接已关闭")
async def main():
    """Run the protocol probe once: connect, send one audio frame, then disconnect.

    Any failure during the probe is printed with its traceback instead of
    propagating; the connection is closed in every case.
    """
    client = DoubaoClient(DoubaoConfig())
    try:
        await client.connect()
        await client.test_audio_request()
    except Exception as exc:
        import traceback
        print(f"测试失败: {exc}")
        traceback.print_exc()
    finally:
        await client.close()


if __name__ == "__main__":
    # Script entry point: drive the async probe on a fresh event loop.
    asyncio.run(main())