303 lines
11 KiB
Python
303 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
豆包音频处理模块 - 协议测试版本
|
|
专门测试协议格式问题
|
|
"""
|
|
|
|
import asyncio
|
|
import gzip
|
|
import json
|
|
import uuid
|
|
import wave
|
|
import struct
|
|
import time
|
|
import os
|
|
from typing import Dict, Any, Optional
|
|
import websockets
|
|
|
|
|
|
class DoubaoConfig:
    """Connection settings for the Doubao realtime dialogue WebSocket API.

    Each setting is resolved in this order of precedence:
    1. the matching keyword argument, if not None;
    2. the matching ``DOUBAO_*`` environment variable;
    3. the built-in default.

    ``DoubaoConfig()`` with no arguments and no environment variables set
    behaves exactly like the original hard-coded configuration.
    """

    # NOTE(review): the defaults below are real credentials committed to source.
    # Prefer supplying DOUBAO_APP_ID / DOUBAO_ACCESS_KEY / DOUBAO_APP_KEY via the
    # environment, and rotate/remove these literals.
    _DEFAULT_BASE_URL = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
    _DEFAULT_APP_ID = "8718217928"
    _DEFAULT_ACCESS_KEY = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
    _DEFAULT_APP_KEY = "PlgvMymc7f3tQnJ6"
    _DEFAULT_RESOURCE_ID = "volc.speech.dialog"

    def __init__(self, *, base_url: Optional[str] = None,
                 app_id: Optional[str] = None,
                 access_key: Optional[str] = None,
                 app_key: Optional[str] = None,
                 resource_id: Optional[str] = None):
        self.base_url = self._resolve(base_url, "DOUBAO_BASE_URL", self._DEFAULT_BASE_URL)
        self.app_id = self._resolve(app_id, "DOUBAO_APP_ID", self._DEFAULT_APP_ID)
        self.access_key = self._resolve(access_key, "DOUBAO_ACCESS_KEY", self._DEFAULT_ACCESS_KEY)
        self.app_key = self._resolve(app_key, "DOUBAO_APP_KEY", self._DEFAULT_APP_KEY)
        self.resource_id = self._resolve(resource_id, "DOUBAO_RESOURCE_ID", self._DEFAULT_RESOURCE_ID)

    @staticmethod
    def _resolve(explicit: Optional[str], env_var: str, default: str) -> str:
        """Return *explicit* if given, else the environment variable, else *default*."""
        if explicit is not None:
            return explicit
        return os.environ.get(env_var, default)

    def get_headers(self) -> Dict[str, str]:
        """Return the HTTP headers for the WebSocket handshake.

        A fresh random ``X-Api-Connect-Id`` is generated on every call.
        """
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
|
|
|
|
|
|
class DoubaoProtocol:
    """Constants and header builder for the Doubao binary frame protocol.

    The 4-byte header packs eight 4-bit fields:
        byte 0: protocol version (high nibble) | header size in 4-byte words (low)
        byte 1: message type (high)            | message-type-specific flags (low)
        byte 2: serialization method (high)    | compression type (low)
        byte 3: reserved (0x00)
    """

    # Protocol version (high nibble of byte 0).
    PROTOCOL_VERSION = 0b0001

    # Message types (high nibble of byte 1).
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ACK = 0b1011
    SERVER_ERROR_RESPONSE = 0b1111

    # Message-type-specific flags (low nibble of byte 1).
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    MSG_WITH_EVENT = 0b0100

    # Serialization methods (high nibble of byte 2).
    JSON = 0b0001
    NO_SERIALIZATION = 0b0000

    # Compression types (low nibble of byte 2).
    GZIP = 0b0001

    # Header length in 4-byte words; generate_header always emits 4 bytes.
    _HEADER_WORDS = 0b0001

    @classmethod
    def generate_header(cls, message_type=CLIENT_FULL_REQUEST,
                        message_type_specific_flags=MSG_WITH_EVENT,
                        serial_method=JSON, compression_type=GZIP) -> bytes:
        """Build the 4-byte frame header.

        Args:
            message_type: one of the message-type constants.
            message_type_specific_flags: e.g. ``MSG_WITH_EVENT`` when the frame
                carries an event id after the header.
            serial_method: payload serialization (``JSON`` / ``NO_SERIALIZATION``).
            compression_type: payload compression (``GZIP``).

        Returns:
            The header as 4 bytes.

        Raises:
            ValueError: if any field does not fit in 4 bits (0-15). Without this
                check, an oversized value would silently spill into the
                neighbouring nibble.
        """
        fields = {
            "message_type": message_type,
            "message_type_specific_flags": message_type_specific_flags,
            "serial_method": serial_method,
            "compression_type": compression_type,
        }
        for field_name, value in fields.items():
            if not 0 <= value <= 0x0F:
                raise ValueError(f"{field_name} must fit in 4 bits (0-15), got {value!r}")

        return bytes((
            (cls.PROTOCOL_VERSION << 4) | cls._HEADER_WORDS,
            (message_type << 4) | message_type_specific_flags,
            (serial_method << 4) | compression_type,
            0x00,  # reserved
        ))
|
|
|
|
|
|
class DoubaoClient:
    """Minimal client for the Doubao realtime dialogue API, used to probe frame formats.

    Every client frame built here has the layout:
        4-byte header | 4-byte big-endian event id |
        [4-byte session-id length | session-id bytes] |
        4-byte payload length | gzip-compressed payload
    The session-id part is omitted for StartConnection.
    """

    # Event ids sent by this client.
    _EVENT_START_CONNECTION = 1
    _EVENT_START_SESSION = 100
    _EVENT_TASK_REQUEST = 200  # audio upload ("task request")

    def __init__(self, config: "DoubaoConfig"):
        self.config = config
        # Random per-client session id, sent with StartSession and audio frames.
        self.session_id = str(uuid.uuid4())
        self.ws = None  # WebSocket connection, set by connect()
        self.log_id = ""  # value of the X-Tt-Logid handshake response header

    @staticmethod
    def _build_frame(header: bytes, event: int, payload: bytes,
                     session_id: Optional[str] = None) -> bytes:
        """Return one client frame: header, event id, optional session id, gzip payload."""
        frame = bytearray(header)
        frame += event.to_bytes(4, 'big')
        if session_id is not None:
            # The length prefix must be the *encoded* byte count, not len(str).
            sid = session_id.encode()
            frame += len(sid).to_bytes(4, 'big')
            frame += sid
        compressed = gzip.compress(payload)
        frame += len(compressed).to_bytes(4, 'big')
        frame += compressed
        return bytes(frame)

    async def connect(self) -> None:
        """Open the WebSocket, then perform the StartConnection / StartSession handshake.

        Raises:
            Exception: any connection or handshake error is printed and re-raised.
        """
        print(f"连接豆包服务器: {self.config.base_url}")

        try:
            self.ws = await websockets.connect(
                self.config.base_url,
                additional_headers=self.config.get_headers(),
                ping_interval=None,
            )
            self.log_id = self._read_log_id()
            print(f"连接成功, log_id: {self.log_id}")

            await self._send_start_connection()
            await self._send_start_session()
        except Exception as e:
            print(f"连接失败: {e}")
            raise

    def _read_log_id(self) -> str:
        """Return the X-Tt-Logid header from the handshake response, or "" if absent."""
        # The legacy websockets client exposes `response_headers`; the newer
        # asyncio client exposes the handshake response as `response.headers`.
        headers = getattr(self.ws, 'response_headers', None)
        if headers is None:
            headers = getattr(getattr(self.ws, 'response', None), 'headers', None)
        if headers is None:
            headers = getattr(self.ws, 'headers', None)
        if headers is None:
            return ""
        return headers.get("X-Tt-Logid") or ""

    async def _send_start_connection(self) -> None:
        """Send StartConnection (event 1, empty JSON payload, no session id)."""
        print("发送StartConnection请求...")
        frame = self._build_frame(
            DoubaoProtocol.generate_header(), self._EVENT_START_CONNECTION, b"{}")

        await self.ws.send(frame)
        response = await self.ws.recv()
        print(f"StartConnection响应长度: {len(response)}")
        if self._is_error_frame(response):
            self._describe_response(response)

    async def _send_start_session(self) -> None:
        """Send StartSession (event 100) with the ASR/TTS/dialog session configuration."""
        print("发送StartSession请求...")
        session_config = {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {
                    "channel": 1,
                    "format": "pcm",
                    "sample_rate": 24000
                },
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 30,
                    "input_mod": "audio",
                },
            },
        }

        frame = self._build_frame(
            DoubaoProtocol.generate_header(),
            self._EVENT_START_SESSION,
            json.dumps(session_config).encode(),
            self.session_id,
        )

        await self.ws.send(frame)
        response = await self.ws.recv()
        print(f"StartSession响应长度: {len(response)}")
        if self._is_error_frame(response):
            self._describe_response(response)

        # Pause briefly so the session is fully set up before audio is sent.
        await asyncio.sleep(1.0)

    async def test_audio_request(self, audio: Optional[bytes] = None,
                                 timeout: float = 3.0) -> None:
        """Send one audio task-request frame and print the server's reply.

        Args:
            audio: raw audio bytes to upload. Defaults to 3200 bytes of silence,
                which the original comments describe as the reference client's
                chunk size.
            timeout: seconds to wait for a reply.

        Raises:
            Exception: send/receive failures (including a reply timeout) are
                printed and re-raised.
        """
        print("测试音频请求格式...")
        if audio is None:
            audio = b'\x00' * 3200

        # The frame carries an event id (200) after the header, so the header
        # must announce it with MSG_WITH_EVENT, like every other event-bearing
        # request in this client. The previous NO_SEQUENCE flag did not
        # announce the event field (presumably the format error being probed).
        header = DoubaoProtocol.generate_header(
            message_type=DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST,
            message_type_specific_flags=DoubaoProtocol.MSG_WITH_EVENT,
            serial_method=DoubaoProtocol.NO_SERIALIZATION,
            compression_type=DoubaoProtocol.GZIP,
        )
        frame = self._build_frame(header, self._EVENT_TASK_REQUEST, audio, self.session_id)

        # Read sizes back from the frame itself so the dump shows what is on the wire.
        sid_len = int.from_bytes(frame[8:12], 'big')
        payload_size = int.from_bytes(frame[12 + sid_len:16 + sid_len], 'big')

        print("测试请求详细信息:")
        print(f" - 音频原始大小: {len(audio)}")
        print(f" - 音频压缩后大小: {payload_size}")
        print(f" - Session ID: {self.session_id} (长度: {len(self.session_id)})")
        print(f" - 总请求大小: {len(frame)}")
        print(f" - 头部字节: {frame[:4].hex()}")
        print(f" - 消息类型: {int.from_bytes(frame[4:8], 'big')}")  # bytes 4-8 hold the event id
        print(f" - Session ID长度: {sid_len}")
        print(f" - Payload size: {payload_size}")

        try:
            await self.ws.send(frame)
            print("请求发送成功")

            response = await asyncio.wait_for(self.ws.recv(), timeout=timeout)
            print(f"收到响应,长度: {len(response)}")
            self._describe_response(response)
        except Exception as e:
            print(f"发送测试请求失败: {e}")
            raise

    @staticmethod
    def _is_error_frame(response) -> bool:
        """Return True if *response* is a binary frame whose message type is SERVER_ERROR_RESPONSE."""
        return (isinstance(response, (bytes, bytearray))
                and len(response) >= 2
                and response[1] >> 4 == DoubaoProtocol.SERVER_ERROR_RESPONSE)

    @staticmethod
    def _describe_response(response) -> None:
        """Print a server frame's header fields and, for error frames, the decoded error.

        Parsing problems are printed with a traceback instead of being raised.
        """
        import traceback
        import zlib

        try:
            protocol_version = response[0] >> 4
            header_size = response[0] & 0x0f
            message_type = response[1] >> 4
            message_type_specific_flags = response[1] & 0x0f
            serialization_method = response[2] >> 4
            message_compression = response[2] & 0x0f

            print("响应协议信息:")
            print(f" - version={protocol_version}")
            print(f" - header_size={header_size}")
            print(f" - message_type={message_type} (15=SERVER_ERROR_RESPONSE)")
            print(f" - message_type_specific_flags={message_type_specific_flags}")
            print(f" - serialization_method={serialization_method}")
            print(f" - message_compression={message_compression}")

            # header_size counts 4-byte words.
            payload = response[header_size * 4:]
            if message_type != DoubaoProtocol.SERVER_ERROR_RESPONSE or len(payload) < 8:
                return

            # Error payload: 4-byte error code, 4-byte size, then the message.
            code = int.from_bytes(payload[:4], "big", signed=False)
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            print(f" - 错误代码: {code}")
            print(f" - payload大小: {payload_size}")

            if len(payload) < 8 + payload_size:
                return
            payload_msg = payload[8:8 + payload_size]
            print(f" - payload长度: {len(payload_msg)}")

            if message_compression == DoubaoProtocol.GZIP:
                try:
                    payload_msg = gzip.decompress(payload_msg)
                    print(f" - 解压缩后长度: {len(payload_msg)}")
                except (OSError, EOFError, zlib.error):
                    pass  # not actually gzip data; the raw bytes are shown below

            try:
                error_msg = json.loads(payload_msg.decode('utf-8'))
                print(f" - 错误信息: {error_msg}")
            except ValueError:  # covers UnicodeDecodeError and JSONDecodeError
                print(f" - 原始payload: {payload_msg}")
        except Exception as e:
            print(f"解析响应失败: {e}")
            traceback.print_exc()

    async def close(self) -> None:
        """Close the WebSocket if open; best-effort, never raises ordinary errors."""
        if self.ws:
            try:
                await self.ws.close()
            except Exception as e:
                # The connection may already be broken; report it and carry on.
                # (A bare except here would also swallow task cancellation.)
                print(f"关闭连接时出错: {e}")
            self.ws = None
            print("连接已关闭")
|
|
|
|
|
|
async def main():
    """Run the protocol smoke test: connect, send one audio frame, always close.

    Any failure is reported with a traceback rather than propagated.
    """
    client = DoubaoClient(DoubaoConfig())

    try:
        await client.connect()
        await client.test_audio_request()
    except Exception as exc:
        import traceback

        print(f"测试失败: {exc}")
        traceback.print_exc()
    finally:
        # Runs on success and failure alike so the socket is never leaked.
        await client.close()
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the async smoke test only when executed as a script, not on import.
    asyncio.run(main())