朱潮 2025-09-20 14:35:54 +08:00
parent 9f7858a30e
commit 9108fd4582
7 changed files with 2675 additions and 0 deletions

doubao.py (new file, 470 lines)

@@ -0,0 +1,470 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Doubao audio processing module.
A simplified WebSocket API client that uploads an audio file and handles the returned audio.
"""
import asyncio
import gzip
import json
import uuid
import wave
import struct
import time
import os
from typing import Dict, Any, Optional
import websockets
class DoubaoConfig:
"""豆包配置"""
def __init__(self):
self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
self.app_id = "8718217928"
self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
self.app_key = "PlgvMymc7f3tQnJ6"
self.resource_id = "volc.speech.dialog"
def get_headers(self) -> Dict[str, str]:
"""获取请求头"""
return {
"X-Api-App-ID": self.app_id,
"X-Api-Access-Key": self.access_key,
"X-Api-Resource-Id": self.resource_id,
"X-Api-App-Key": self.app_key,
"X-Api-Connect-Id": str(uuid.uuid4()),
}
class DoubaoProtocol:
"""豆包协议处理"""
# 协议常量
PROTOCOL_VERSION = 0b0001
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111
NO_SEQUENCE = 0b0000
POS_SEQUENCE = 0b0001
MSG_WITH_EVENT = 0b0100
JSON = 0b0001
NO_SERIALIZATION = 0b0000
GZIP = 0b0001
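# Request frame layout, as built by the send methods below (a reading of this module's own code):
#   4-byte header: [version | header_size] [message_type | flags] [serialization | compression] [reserved]
#   4-byte big-endian event ID (present when MSG_WITH_EVENT is set)
#   4-byte big-endian session ID length + session ID bytes (session-scoped events only)
#   4-byte big-endian payload size + payload (gzip-compressed when GZIP is set)
# Event IDs sent by this client: 1 = StartConnection, 2 = FinishConnection,
# 100 = StartSession, 102 = FinishSession, 200 = audio task request.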
@classmethod
def generate_header(cls, message_type=CLIENT_FULL_REQUEST,
message_type_specific_flags=MSG_WITH_EVENT,
serial_method=JSON, compression_type=GZIP) -> bytes:
"""生成协议头"""
header = bytearray()
header.append((cls.PROTOCOL_VERSION << 4) | 1) # version + header_size
header.append((message_type << 4) | message_type_specific_flags)
header.append((serial_method << 4) | compression_type)
header.append(0x00) # reserved
return bytes(header)
@classmethod
def parse_response(cls, res: bytes) -> Dict[str, Any]:
"""解析响应"""
if isinstance(res, str):
return {}
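# Response header layout (4 bytes, mirrored from the parsing below): byte 0 holds the protocol
# version (high nibble) and the header size in 4-byte words (low nibble); byte 1 holds the
# message type and flags; byte 2 holds the serialization method and compression; byte 3 is reserved.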
protocol_version = res[0] >> 4
header_size = res[0] & 0x0f
message_type = res[1] >> 4
message_type_specific_flags = res[1] & 0x0f
serialization_method = res[2] >> 4
message_compression = res[2] & 0x0f
payload = res[header_size * 4:]
result = {}
if message_type == cls.SERVER_FULL_RESPONSE or message_type == cls.SERVER_ACK:
result['message_type'] = 'SERVER_FULL_RESPONSE'
if message_type == cls.SERVER_ACK:
result['message_type'] = 'SERVER_ACK'
start = 0
if message_type_specific_flags & cls.MSG_WITH_EVENT:
result['event'] = int.from_bytes(payload[:4], "big", signed=False)
start += 4
payload = payload[start:]
session_id_size = int.from_bytes(payload[:4], "big", signed=True)
session_id = payload[4:session_id_size+4]
result['session_id'] = session_id.decode("utf-8", errors="ignore")
payload = payload[4 + session_id_size:]
payload_size = int.from_bytes(payload[:4], "big", signed=False)
payload_msg = payload[4:]
result['payload_size'] = payload_size
if payload_msg:
if message_compression == cls.GZIP:
payload_msg = gzip.decompress(payload_msg)
if serialization_method == cls.JSON:
payload_msg = json.loads(str(payload_msg, "utf-8"))
result['payload_msg'] = payload_msg
elif message_type == cls.SERVER_ERROR_RESPONSE:
code = int.from_bytes(payload[:4], "big", signed=False)
result['code'] = code
payload_size = int.from_bytes(payload[4:8], "big", signed=False)
result['payload_size'] = payload_size
return result
class AudioProcessor:
"""音频处理器"""
@staticmethod
def read_wav_file(file_path: str) -> tuple:
"""读取WAV文件返回音频数据和参数"""
with wave.open(file_path, 'rb') as wf:
# 获取音频参数
channels = wf.getnchannels()
sampwidth = wf.getsampwidth()
framerate = wf.getframerate()
nframes = wf.getnframes()
# 读取音频数据
audio_data = wf.readframes(nframes)
return audio_data, {
'channels': channels,
'sampwidth': sampwidth,
'framerate': framerate,
'nframes': nframes
}
@staticmethod
def create_wav_file(audio_data: bytes, output_path: str,
sample_rate: int = 24000, channels: int = 1,
sampwidth: int = 2) -> None:
"""创建WAV文件适配树莓派播放"""
with wave.open(output_path, 'wb') as wf:
wf.setnchannels(channels)
wf.setsampwidth(sampwidth)
wf.setframerate(sample_rate)
wf.writeframes(audio_data)
class DoubaoClient:
"""豆包客户端"""
def __init__(self, config: DoubaoConfig):
self.config = config
self.session_id = str(uuid.uuid4())
self.ws = None
self.log_id = ""
async def connect(self) -> None:
"""建立WebSocket连接"""
print(f"连接豆包服务器: {self.config.base_url}")
self.ws = await websockets.connect(
self.config.base_url,
additional_headers=self.config.get_headers(),
ping_interval=None
)
# 获取log_id
if hasattr(self.ws, 'response_headers'):
self.log_id = self.ws.response_headers.get("X-Tt-Logid")
elif hasattr(self.ws, 'headers'):
self.log_id = self.ws.headers.get("X-Tt-Logid")
print(f"连接成功, log_id: {self.log_id}")
# 发送StartConnection请求
await self._send_start_connection()
# 发送StartSession请求
await self._send_start_session()
async def _send_start_connection(self) -> None:
"""发送StartConnection请求"""
request = bytearray(DoubaoProtocol.generate_header())
request.extend(int(1).to_bytes(4, 'big'))
payload_bytes = b"{}"
payload_bytes = gzip.compress(payload_bytes)
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
await self.ws.send(request)
response = await self.ws.recv()
parsed_response = DoubaoProtocol.parse_response(response)
print(f"StartConnection响应: {parsed_response}")
async def _send_start_session(self) -> None:
"""发送StartSession请求"""
session_config = {
"asr": {
"extra": {
"end_smooth_window_ms": 1500,
},
},
"tts": {
"speaker": "zh_female_vv_jupiter_bigtts",
"audio_config": {
"channel": 1,
"format": "pcm",
"sample_rate": 24000
},
},
"dialog": {
"bot_name": "豆包",
"system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
"speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
"location": {"city": "北京"},
"extra": {
"strict_audit": False,
"audit_response": "支持客户自定义安全审核回复话术。",
"recv_timeout": 10,
"input_mod": "audio_file", # 使用音频文件模式
},
},
}
request = bytearray(DoubaoProtocol.generate_header())
request.extend(int(100).to_bytes(4, 'big'))
request.extend(len(self.session_id).to_bytes(4, 'big'))
request.extend(self.session_id.encode())
payload_bytes = json.dumps(session_config).encode()
payload_bytes = gzip.compress(payload_bytes)
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
await self.ws.send(request)
response = await self.ws.recv()
parsed_response = DoubaoProtocol.parse_response(response)
print(f"StartSession响应: {parsed_response}")
async def send_audio_file(self, file_path: str) -> bytes:
"""发送音频文件并返回响应音频"""
print(f"处理音频文件: {file_path}")
# 读取音频文件
audio_data, audio_info = AudioProcessor.read_wav_file(file_path)
print(f"音频参数: {audio_info}")
# 计算分块大小200ms
chunk_size = int(audio_info['framerate'] * audio_info['channels'] *
audio_info['sampwidth'] * 0.2)
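# Example (assuming 16 kHz, mono, 16-bit input): 16000 * 1 * 2 * 0.2 = 6400 bytes per 200 ms chunk.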
# 分块发送音频数据
total_chunks = (len(audio_data) + chunk_size - 1) // chunk_size
print(f"开始分块发送音频,共 {total_chunks}")
received_audio = b""
error_count = 0
for i in range(0, len(audio_data), chunk_size):
chunk = audio_data[i:i + chunk_size]
is_last = (i + chunk_size >= len(audio_data))
# 发送音频块
await self._send_audio_chunk(chunk, is_last)
# 接收响应
try:
response = await asyncio.wait_for(self.ws.recv(), timeout=2.0)
parsed_response = DoubaoProtocol.parse_response(response)
# 检查是否是错误响应
if 'code' in parsed_response and parsed_response['code'] != 0:
print(f"服务器返回错误: {parsed_response}")
error_count += 1
if error_count > 3:
raise Exception(f"服务器连续返回错误: {parsed_response}")
continue
# 处理音频响应
if (parsed_response.get('message_type') == 'SERVER_ACK' and
isinstance(parsed_response.get('payload_msg'), bytes)):
audio_chunk = parsed_response['payload_msg']
received_audio += audio_chunk
print(f"接收到音频数据块,大小: {len(audio_chunk)} 字节")
# 检查会话状态
event = parsed_response.get('event')
if event in [359, 152, 153]: # 这些事件表示会话相关状态
print(f"会话事件: {event}")
if event in [152, 153]: # 会话结束
print("检测到会话结束事件")
break
except asyncio.TimeoutError:
print("等待响应超时,继续发送")
# 模拟实时发送的延迟
await asyncio.sleep(0.05)
print("音频文件发送完成")
return received_audio
async def _send_audio_chunk(self, audio_data: bytes, is_last: bool = False) -> None:
"""发送音频块"""
request = bytearray(
DoubaoProtocol.generate_header(
message_type=DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST,
message_type_specific_flags=DoubaoProtocol.NO_SEQUENCE,
serial_method=DoubaoProtocol.NO_SERIALIZATION, # 音频数据不需要序列化
compression_type=DoubaoProtocol.GZIP
)
)
request.extend(int(200).to_bytes(4, 'big'))
request.extend(len(self.session_id).to_bytes(4, 'big'))
request.extend(self.session_id.encode())
# 压缩音频数据
compressed_audio = gzip.compress(audio_data)
request.extend(len(compressed_audio).to_bytes(4, 'big')) # payload size(4 bytes)
request.extend(compressed_audio)
await self.ws.send(request)
async def close(self) -> None:
"""关闭连接"""
if self.ws:
try:
# 发送FinishSession
await self._send_finish_session()
# 发送FinishConnection
await self._send_finish_connection()
except Exception as e:
print(f"关闭会话时出错: {e}")
finally:
# 确保WebSocket连接关闭
try:
await self.ws.close()
except:
pass
print("连接已关闭")
async def _send_finish_session(self) -> None:
"""发送FinishSession请求"""
request = bytearray(DoubaoProtocol.generate_header())
request.extend(int(102).to_bytes(4, 'big'))
request.extend(len(self.session_id).to_bytes(4, 'big'))
request.extend(self.session_id.encode())
payload_bytes = b"{}"
payload_bytes = gzip.compress(payload_bytes)
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
await self.ws.send(request)
async def _send_finish_connection(self) -> None:
"""发送FinishConnection请求"""
request = bytearray(DoubaoProtocol.generate_header())
request.extend(int(2).to_bytes(4, 'big'))
payload_bytes = b"{}"
payload_bytes = gzip.compress(payload_bytes)
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
await self.ws.send(request)
response = await self.ws.recv()
parsed_response = DoubaoProtocol.parse_response(response)
print(f"FinishConnection响应: {parsed_response}")
class DoubaoProcessor:
"""豆包音频处理器"""
def __init__(self):
self.config = DoubaoConfig()
self.client = DoubaoClient(self.config)
async def process_audio_file(self, input_file: str, output_file: str = None) -> str:
"""处理音频文件
Args:
input_file: 输入音频文件路径
output_file: 输出音频文件路径如果为None则自动生成
Returns:
输出音频文件路径
"""
if not os.path.exists(input_file):
raise FileNotFoundError(f"音频文件不存在: {input_file}")
# 生成输出文件名
if output_file is None:
timestamp = time.strftime("%Y%m%d_%H%M%S")
output_file = f"doubao_response_{timestamp}.wav"
try:
# 连接豆包服务器
await self.client.connect()
# 等待一会确保会话建立
await asyncio.sleep(0.5)
# 发送音频文件并获取响应
received_audio = await self.client.send_audio_file(input_file)
if received_audio:
print(f"总共接收到音频数据: {len(received_audio)} 字节")
# 转换为WAV格式保存适配树莓派播放
AudioProcessor.create_wav_file(
received_audio,
output_file,
sample_rate=24000, # 豆包返回的音频采样率
channels=1,
sampwidth=2 # 16-bit
)
print(f"响应音频已保存到: {output_file}")
# 显示文件信息
file_size = os.path.getsize(output_file)
print(f"输出文件大小: {file_size} 字节")
else:
print("警告: 未接收到音频响应")
return output_file
except Exception as e:
print(f"处理音频文件时出错: {e}")
import traceback
traceback.print_exc()
raise
finally:
await self.client.close()
async def main():
"""测试函数"""
import argparse
parser = argparse.ArgumentParser(description="豆包音频处理测试")
parser.add_argument("--input", type=str, required=True, help="输入音频文件路径")
parser.add_argument("--output", type=str, help="输出音频文件路径")
args = parser.parse_args()
processor = DoubaoProcessor()
try:
output_file = await processor.process_audio_file(args.input, args.output)
print(f"处理完成,输出文件: {output_file}")
except Exception as e:
print(f"处理失败: {e}")
if __name__ == "__main__":
asyncio.run(main())

doubao_debug.py (new file, 540 lines)

@@ -0,0 +1,540 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Doubao audio processing module - debug version.
Adds more debug output and error handling.
"""
import asyncio
import gzip
import json
import uuid
import wave
import struct
import time
import os
from typing import Dict, Any, Optional
import websockets
class DoubaoConfig:
"""豆包配置"""
def __init__(self):
self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
self.app_id = "8718217928"
self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
self.app_key = "PlgvMymc7f3tQnJ6"
self.resource_id = "volc.speech.dialog"
def get_headers(self) -> Dict[str, str]:
"""获取请求头"""
return {
"X-Api-App-ID": self.app_id,
"X-Api-Access-Key": self.access_key,
"X-Api-Resource-Id": self.resource_id,
"X-Api-App-Key": self.app_key,
"X-Api-Connect-Id": str(uuid.uuid4()),
}
class DoubaoProtocol:
"""豆包协议处理"""
# 协议常量
PROTOCOL_VERSION = 0b0001
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111
NO_SEQUENCE = 0b0000
POS_SEQUENCE = 0b0001
MSG_WITH_EVENT = 0b0100
JSON = 0b0001
NO_SERIALIZATION = 0b0000
GZIP = 0b0001
@classmethod
def generate_header(cls, message_type=CLIENT_FULL_REQUEST,
message_type_specific_flags=MSG_WITH_EVENT,
serial_method=JSON, compression_type=GZIP) -> bytes:
"""生成协议头"""
header = bytearray()
header.append((cls.PROTOCOL_VERSION << 4) | 1) # version + header_size
header.append((message_type << 4) | message_type_specific_flags)
header.append((serial_method << 4) | compression_type)
header.append(0x00) # reserved
return bytes(header)
@classmethod
def parse_response(cls, res: bytes) -> Dict[str, Any]:
"""解析响应"""
if isinstance(res, str):
return {}
try:
protocol_version = res[0] >> 4
header_size = res[0] & 0x0f
message_type = res[1] >> 4
message_type_specific_flags = res[1] & 0x0f
serialization_method = res[2] >> 4
message_compression = res[2] & 0x0f
payload = res[header_size * 4:]
result = {}
if message_type == cls.SERVER_FULL_RESPONSE or message_type == cls.SERVER_ACK:
result['message_type'] = 'SERVER_FULL_RESPONSE'
if message_type == cls.SERVER_ACK:
result['message_type'] = 'SERVER_ACK'
start = 0
if message_type_specific_flags & cls.MSG_WITH_EVENT:
result['event'] = int.from_bytes(payload[:4], "big", signed=False)
start += 4
payload = payload[start:]
if len(payload) < 4:
result['error'] = 'Payload too short for session_id'
return result
session_id_size = int.from_bytes(payload[:4], "big", signed=True)
if session_id_size < 0 or session_id_size > len(payload) - 4:
result['error'] = f'Invalid session_id size: {session_id_size}'
return result
session_id = payload[4:session_id_size+4]
result['session_id'] = session_id.decode("utf-8", errors="ignore")
payload = payload[4 + session_id_size:]
if len(payload) < 4:
result['error'] = 'Payload too short for payload_size'
return result
payload_size = int.from_bytes(payload[:4], "big", signed=False)
result['payload_size'] = payload_size
if len(payload) >= 4 + payload_size:
payload_msg = payload[4:4 + payload_size]
if payload_msg:
if message_compression == cls.GZIP:
try:
payload_msg = gzip.decompress(payload_msg)
except Exception as e:
result['decompress_error'] = str(e)
return result
if serialization_method == cls.JSON:
try:
payload_msg = json.loads(str(payload_msg, "utf-8"))
except Exception as e:
result['json_error'] = str(e)
payload_msg = str(payload_msg, "utf-8")
elif serialization_method != cls.NO_SERIALIZATION:
payload_msg = str(payload_msg, "utf-8")
result['payload_msg'] = payload_msg
elif message_type == cls.SERVER_ERROR_RESPONSE:
if len(payload) >= 8:
code = int.from_bytes(payload[:4], "big", signed=False)
result['code'] = code
payload_size = int.from_bytes(payload[4:8], "big", signed=False)
result['payload_size'] = payload_size
if len(payload) >= 8 + payload_size:
payload_msg = payload[8:8 + payload_size]
if payload_msg and message_compression == cls.GZIP:
try:
payload_msg = gzip.decompress(payload_msg)
except:
pass
result['payload_msg'] = payload_msg
except Exception as e:
result['parse_error'] = str(e)
return result
class AudioProcessor:
"""音频处理器"""
@staticmethod
def read_wav_file(file_path: str) -> tuple:
"""读取WAV文件返回音频数据和参数"""
with wave.open(file_path, 'rb') as wf:
# 获取音频参数
channels = wf.getnchannels()
sampwidth = wf.getsampwidth()
framerate = wf.getframerate()
nframes = wf.getnframes()
# 读取音频数据
audio_data = wf.readframes(nframes)
return audio_data, {
'channels': channels,
'sampwidth': sampwidth,
'framerate': framerate,
'nframes': nframes
}
@staticmethod
def create_wav_file(audio_data: bytes, output_path: str,
sample_rate: int = 24000, channels: int = 1,
sampwidth: int = 2) -> None:
"""创建WAV文件适配树莓派播放"""
with wave.open(output_path, 'wb') as wf:
wf.setnchannels(channels)
wf.setsampwidth(sampwidth)
wf.setframerate(sample_rate)
wf.writeframes(audio_data)
class DoubaoClient:
"""豆包客户端"""
def __init__(self, config: DoubaoConfig):
self.config = config
self.session_id = str(uuid.uuid4())
self.ws = None
self.log_id = ""
async def connect(self) -> None:
"""建立WebSocket连接"""
print(f"连接豆包服务器: {self.config.base_url}")
try:
self.ws = await websockets.connect(
self.config.base_url,
additional_headers=self.config.get_headers(),
ping_interval=None
)
# 获取log_id
if hasattr(self.ws, 'response_headers'):
self.log_id = self.ws.response_headers.get("X-Tt-Logid")
elif hasattr(self.ws, 'headers'):
self.log_id = self.ws.headers.get("X-Tt-Logid")
print(f"连接成功, log_id: {self.log_id}")
# 发送StartConnection请求
await self._send_start_connection()
# 发送StartSession请求
await self._send_start_session()
except Exception as e:
print(f"连接失败: {e}")
raise
async def _send_start_connection(self) -> None:
"""发送StartConnection请求"""
print("发送StartConnection请求...")
request = bytearray(DoubaoProtocol.generate_header())
request.extend(int(1).to_bytes(4, 'big'))
payload_bytes = b"{}"
payload_bytes = gzip.compress(payload_bytes)
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
await self.ws.send(request)
response = await self.ws.recv()
parsed_response = DoubaoProtocol.parse_response(response)
print(f"StartConnection响应: {parsed_response}")
# 检查是否有错误
if 'error' in parsed_response:
raise Exception(f"StartConnection解析错误: {parsed_response['error']}")
async def _send_start_session(self) -> None:
"""发送StartSession请求"""
print("发送StartSession请求...")
session_config = {
"asr": {
"extra": {
"end_smooth_window_ms": 1500,
},
},
"tts": {
"speaker": "zh_female_vv_jupiter_bigtts",
"audio_config": {
"channel": 1,
"format": "pcm",
"sample_rate": 24000
},
},
"dialog": {
"bot_name": "豆包",
"system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
"speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
"location": {"city": "北京"},
"extra": {
"strict_audit": False,
"audit_response": "支持客户自定义安全审核回复话术。",
"recv_timeout": 30, # 增加超时时间
"input_mod": "audio_file", # 使用音频文件模式
},
},
}
request = bytearray(DoubaoProtocol.generate_header())
request.extend(int(100).to_bytes(4, 'big'))
request.extend(len(self.session_id).to_bytes(4, 'big'))
request.extend(self.session_id.encode())
payload_bytes = json.dumps(session_config).encode()
payload_bytes = gzip.compress(payload_bytes)
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
await self.ws.send(request)
response = await self.ws.recv()
parsed_response = DoubaoProtocol.parse_response(response)
print(f"StartSession响应: {parsed_response}")
# 检查是否有错误
if 'error' in parsed_response:
raise Exception(f"StartSession解析错误: {parsed_response['error']}")
# 等待一会确保会话完全建立
await asyncio.sleep(1.0)
async def send_audio_file(self, file_path: str) -> bytes:
"""发送音频文件并返回响应音频"""
print(f"处理音频文件: {file_path}")
# 读取音频文件
audio_data, audio_info = AudioProcessor.read_wav_file(file_path)
print(f"音频参数: {audio_info}")
# 计算分块大小减小到50ms避免数据块过大
chunk_size = int(audio_info['framerate'] * audio_info['channels'] *
audio_info['sampwidth'] * 0.05) # 50ms
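# Example (assuming 16 kHz, mono, 16-bit input): 16000 * 1 * 2 * 0.05 = 1600 bytes per 50 ms chunk.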
# 分块发送音频数据
total_chunks = (len(audio_data) + chunk_size - 1) // chunk_size
print(f"开始分块发送音频,共 {total_chunks}")
received_audio = b""
error_count = 0
session_active = True
for i in range(0, len(audio_data), chunk_size):
if not session_active:
print("会话已结束,停止发送")
break
chunk = audio_data[i:i + chunk_size]
is_last = (i + chunk_size >= len(audio_data))
# 发送音频块
await self._send_audio_chunk(chunk, is_last)
# 接收响应
try:
response = await asyncio.wait_for(self.ws.recv(), timeout=3.0)
parsed_response = DoubaoProtocol.parse_response(response)
print(f"响应 {i//chunk_size + 1}/{total_chunks}: {parsed_response}")
# 检查是否是错误响应
if 'code' in parsed_response and parsed_response['code'] != 0:
print(f"服务器返回错误: {parsed_response}")
error_count += 1
if error_count > 3:
raise Exception(f"服务器连续返回错误: {parsed_response}")
continue
# 处理音频响应
if (parsed_response.get('message_type') == 'SERVER_ACK' and
isinstance(parsed_response.get('payload_msg'), bytes)):
audio_chunk = parsed_response['payload_msg']
received_audio += audio_chunk
print(f"接收到音频数据块,大小: {len(audio_chunk)} 字节")
# 检查会话状态
event = parsed_response.get('event')
if event in [359, 152, 153]: # 这些事件表示会话相关状态
print(f"会话事件: {event}")
if event in [152, 153]: # 会话结束
print("检测到会话结束事件")
session_active = False
break
except asyncio.TimeoutError:
print("等待响应超时,继续发送")
# 模拟实时发送的延迟
await asyncio.sleep(0.1)
print("音频文件发送完成")
return received_audio
async def _send_audio_chunk(self, audio_data: bytes, is_last: bool = False) -> None:
"""发送音频块"""
request = bytearray(
DoubaoProtocol.generate_header(
message_type=DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST,
message_type_specific_flags=DoubaoProtocol.NO_SEQUENCE,
serial_method=DoubaoProtocol.NO_SERIALIZATION, # 音频数据不需要序列化
compression_type=DoubaoProtocol.GZIP
)
)
request.extend(int(200).to_bytes(4, 'big'))
request.extend(len(self.session_id).to_bytes(4, 'big'))
request.extend(self.session_id.encode())
# 压缩音频数据
compressed_audio = gzip.compress(audio_data)
payload_size = len(compressed_audio)
request.extend(payload_size.to_bytes(4, 'big')) # payload size(4 bytes)
request.extend(compressed_audio)
print(f"发送音频块 - 原始大小: {len(audio_data)}, 压缩后大小: {payload_size}, 总请求数据大小: {len(request)}")
await self.ws.send(request)
async def close(self) -> None:
"""关闭连接"""
if self.ws:
try:
# 发送FinishSession
await self._send_finish_session()
# 发送FinishConnection
await self._send_finish_connection()
except Exception as e:
print(f"关闭会话时出错: {e}")
finally:
# 确保WebSocket连接关闭
try:
await self.ws.close()
except:
pass
print("连接已关闭")
async def _send_finish_session(self) -> None:
"""发送FinishSession请求"""
print("发送FinishSession请求...")
request = bytearray(DoubaoProtocol.generate_header())
request.extend(int(102).to_bytes(4, 'big'))
request.extend(len(self.session_id).to_bytes(4, 'big'))
request.extend(self.session_id.encode())
payload_bytes = b"{}"
payload_bytes = gzip.compress(payload_bytes)
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
await self.ws.send(request)
async def _send_finish_connection(self) -> None:
"""发送FinishConnection请求"""
print("发送FinishConnection请求...")
request = bytearray(DoubaoProtocol.generate_header())
request.extend(int(2).to_bytes(4, 'big'))
payload_bytes = b"{}"
payload_bytes = gzip.compress(payload_bytes)
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
await self.ws.send(request)
try:
response = await asyncio.wait_for(self.ws.recv(), timeout=5.0)
parsed_response = DoubaoProtocol.parse_response(response)
print(f"FinishConnection响应: {parsed_response}")
except asyncio.TimeoutError:
print("FinishConnection响应超时")
class DoubaoProcessor:
"""豆包音频处理器"""
def __init__(self):
self.config = DoubaoConfig()
self.client = DoubaoClient(self.config)
async def process_audio_file(self, input_file: str, output_file: str = None) -> str:
"""处理音频文件
Args:
input_file: 输入音频文件路径
output_file: 输出音频文件路径如果为None则自动生成
Returns:
输出音频文件路径
"""
if not os.path.exists(input_file):
raise FileNotFoundError(f"音频文件不存在: {input_file}")
# 生成输出文件名
if output_file is None:
timestamp = time.strftime("%Y%m%d_%H%M%S")
output_file = f"doubao_response_{timestamp}.wav"
try:
# 连接豆包服务器
await self.client.connect()
# 发送音频文件并获取响应
received_audio = await self.client.send_audio_file(input_file)
if received_audio:
print(f"总共接收到音频数据: {len(received_audio)} 字节")
# 转换为WAV格式保存适配树莓派播放
AudioProcessor.create_wav_file(
received_audio,
output_file,
sample_rate=24000, # 豆包返回的音频采样率
channels=1,
sampwidth=2 # 16-bit
)
print(f"响应音频已保存到: {output_file}")
# 显示文件信息
file_size = os.path.getsize(output_file)
print(f"输出文件大小: {file_size} 字节")
else:
print("警告: 未接收到音频响应")
return output_file
except Exception as e:
print(f"处理音频文件时出错: {e}")
import traceback
traceback.print_exc()
raise
finally:
await self.client.close()
async def main():
"""测试函数"""
import argparse
parser = argparse.ArgumentParser(description="豆包音频处理测试")
parser.add_argument("--input", type=str, required=True, help="输入音频文件路径")
parser.add_argument("--output", type=str, help="输出音频文件路径")
args = parser.parse_args()
processor = DoubaoProcessor()
try:
output_file = await processor.process_audio_file(args.input, args.output)
print(f"处理完成,输出文件: {output_file}")
except Exception as e:
print(f"处理失败: {e}")
if __name__ == "__main__":
asyncio.run(main())

doubao_final_test.py (new file, 425 lines)

@@ -0,0 +1,425 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Doubao audio processing module - simplified test version.
Dedicated to testing the complete audio upload and TTS audio download flow.
"""
import asyncio
import gzip
import json
import uuid
import wave
import struct
import time
import os
from typing import Dict, Any, Optional
import websockets
# 直接复制原始豆包代码的协议常量
PROTOCOL_VERSION = 0b0001
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111
NO_SEQUENCE = 0b0000
POS_SEQUENCE = 0b0001
MSG_WITH_EVENT = 0b0100
NO_SERIALIZATION = 0b0000
JSON = 0b0001
GZIP = 0b0001
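# Event IDs referenced in this script: client-side 1 = StartConnection, 100 = StartSession,
# 200 = audio task request; server-side 451 = ASR result, 359 = TTS end (see test_full_dialog below).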
def generate_header(
version=PROTOCOL_VERSION,
message_type=CLIENT_FULL_REQUEST,
message_type_specific_flags=MSG_WITH_EVENT,
serial_method=JSON,
compression_type=GZIP,
reserved_data=0x00,
extension_header=bytes()
):
"""直接复制原始豆包代码的generate_header函数"""
header = bytearray()
header_size = int(len(extension_header) / 4) + 1
header.append((version << 4) | header_size)
header.append((message_type << 4) | message_type_specific_flags)
header.append((serial_method << 4) | compression_type)
header.append(reserved_data)
header.extend(extension_header)
return header
class DoubaoConfig:
"""豆包配置"""
def __init__(self):
self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
self.app_id = "8718217928"
self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
self.app_key = "PlgvMymc7f3tQnJ6"
self.resource_id = "volc.speech.dialog"
def get_headers(self) -> Dict[str, str]:
"""获取请求头"""
return {
"X-Api-App-ID": self.app_id,
"X-Api-Access-Key": self.access_key,
"X-Api-Resource-Id": self.resource_id,
"X-Api-App-Key": self.app_key,
"X-Api-Connect-Id": str(uuid.uuid4()),
}
class DoubaoClient:
"""豆包客户端 - 基于原始代码"""
def __init__(self, config: DoubaoConfig):
self.config = config
self.session_id = str(uuid.uuid4())
self.ws = None
self.log_id = ""
async def connect(self) -> None:
"""建立WebSocket连接"""
print(f"连接豆包服务器: {self.config.base_url}")
try:
self.ws = await websockets.connect(
self.config.base_url,
additional_headers=self.config.get_headers(),
ping_interval=None
)
# 获取log_id
if hasattr(self.ws, 'response_headers'):
self.log_id = self.ws.response_headers.get("X-Tt-Logid")
elif hasattr(self.ws, 'headers'):
self.log_id = self.ws.headers.get("X-Tt-Logid")
print(f"连接成功, log_id: {self.log_id}")
# 发送StartConnection请求
await self._send_start_connection()
# 发送StartSession请求
await self._send_start_session()
except Exception as e:
print(f"连接失败: {e}")
raise
async def _send_start_connection(self) -> None:
"""发送StartConnection请求"""
print("发送StartConnection请求...")
request = bytearray(generate_header())
request.extend(int(1).to_bytes(4, 'big'))
payload_bytes = b"{}"
payload_bytes = gzip.compress(payload_bytes)
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
await self.ws.send(request)
response = await self.ws.recv()
print(f"StartConnection响应长度: {len(response)}")
async def _send_start_session(self) -> None:
"""发送StartSession请求"""
print("发送StartSession请求...")
session_config = {
"asr": {
"extra": {
"end_smooth_window_ms": 1500,
},
},
"tts": {
"speaker": "zh_female_vv_jupiter_bigtts",
"audio_config": {
"channel": 1,
"format": "pcm",
"sample_rate": 24000
},
},
"dialog": {
"bot_name": "豆包",
"system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
"speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
"location": {"city": "北京"},
"extra": {
"strict_audit": False,
"audit_response": "支持客户自定义安全审核回复话术。",
"recv_timeout": 30,
"input_mod": "audio",
},
},
}
request = bytearray(generate_header())
request.extend(int(100).to_bytes(4, 'big'))
request.extend(len(self.session_id).to_bytes(4, 'big'))
request.extend(self.session_id.encode())
payload_bytes = json.dumps(session_config).encode()
payload_bytes = gzip.compress(payload_bytes)
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
await self.ws.send(request)
response = await self.ws.recv()
print(f"StartSession响应长度: {len(response)}")
# 等待一会确保会话完全建立
await asyncio.sleep(1.0)
async def task_request(self, audio: bytes) -> None:
"""直接复制原始豆包代码的task_request方法"""
task_request = bytearray(
generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST,
serial_method=NO_SERIALIZATION))
task_request.extend(int(200).to_bytes(4, 'big'))
task_request.extend((len(self.session_id)).to_bytes(4, 'big'))
task_request.extend(str.encode(self.session_id))
payload_bytes = gzip.compress(audio)
task_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes)
task_request.extend(payload_bytes)
await self.ws.send(task_request)
async def test_full_dialog(self) -> None:
"""测试完整对话流程"""
print("开始完整对话测试...")
# 读取真实的录音文件
try:
import wave
with wave.open("recording_20250920_135137.wav", 'rb') as wf:
# 读取前5秒的音频数据
total_frames = wf.getnframes()
frames_to_read = min(total_frames, 80000) # 5秒
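# Note: 80000 frames equal 5 seconds only for a 16 kHz recording (16000 frames/s * 5 s = 80000 frames).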
audio_data = wf.readframes(frames_to_read)
print(f"读取真实音频数据: {len(audio_data)} 字节")
print(f"音频参数: 采样率={wf.getframerate()}, 通道数={wf.getnchannels()}, 采样宽度={wf.getsampwidth()}")
except Exception as e:
print(f"读取音频文件失败: {e}")
return
print(f"音频数据大小: {len(audio_data)}")
try:
# 发送音频数据
print("发送音频数据...")
await self.task_request(audio_data)
print("音频数据发送成功")
# 等待语音识别响应
print("等待语音识别响应...")
response = await asyncio.wait_for(self.ws.recv(), timeout=15.0)
print(f"收到ASR响应长度: {len(response)}")
# 解析ASR响应
if len(response) >= 4:
protocol_version = response[0] >> 4
header_size = response[0] & 0x0f
message_type = response[1] >> 4
flags = response[1] & 0x0f
print(f"ASR响应协议: version={protocol_version}, header_size={header_size}, message_type={message_type}, flags={flags}")
if message_type == 9: # SERVER_FULL_RESPONSE
payload_start = header_size * 4
payload = response[payload_start:]
if len(payload) >= 4:
event = int.from_bytes(payload[:4], 'big')
print(f"ASR Event: {event}")
if len(payload) >= 8:
session_id_len = int.from_bytes(payload[4:8], 'big')
if len(payload) >= 8 + session_id_len:
session_id = payload[8:8+session_id_len].decode()
print(f"Session ID: {session_id}")
if len(payload) >= 12 + session_id_len:
payload_size = int.from_bytes(payload[8+session_id_len:12+session_id_len], 'big')
payload_data = payload[12+session_id_len:12+session_id_len+payload_size]
print(f"Payload size: {payload_size}")
# 解析ASR结果
try:
asr_result = json.loads(payload_data.decode('utf-8'))
print(f"ASR结果: {asr_result}")
# 如果有识别结果,提取文本
if 'results' in asr_result and asr_result['results']:
text = asr_result['results'][0].get('text', '')
print(f"识别文本: {text}")
except Exception as e:
print(f"解析ASR结果失败: {e}")
# 持续等待TTS音频响应
print("开始持续等待TTS音频响应...")
response_count = 0
max_responses = 10
while response_count < max_responses:
try:
print(f"等待第 {response_count + 1} 个响应...")
tts_response = await asyncio.wait_for(self.ws.recv(), timeout=30.0)
print(f"收到响应 {response_count + 1},长度: {len(tts_response)}")
# 解析响应
if len(tts_response) >= 4:
tts_version = tts_response[0] >> 4
tts_header_size = tts_response[0] & 0x0f
tts_message_type = tts_response[1] >> 4
tts_flags = tts_response[1] & 0x0f
print(f"响应协议: version={tts_version}, header_size={tts_header_size}, message_type={tts_message_type}, flags={tts_flags}")
if tts_message_type == 11: # SERVER_ACK (包含TTS音频)
tts_payload_start = tts_header_size * 4
tts_payload = tts_response[tts_payload_start:]
if len(tts_payload) >= 12:
tts_event = int.from_bytes(tts_payload[:4], 'big')
tts_session_len = int.from_bytes(tts_payload[4:8], 'big')
tts_session = tts_payload[8:8+tts_session_len].decode()
tts_audio_size = int.from_bytes(tts_payload[8+tts_session_len:12+tts_session_len], 'big')
tts_audio_data = tts_payload[12+tts_session_len:12+tts_session_len+tts_audio_size]
print(f"Event: {tts_event}")
print(f"音频数据大小: {tts_audio_size}")
if tts_audio_size > 0:
print("找到TTS音频数据")
# 尝试解压缩TTS音频
try:
decompressed_tts = gzip.decompress(tts_audio_data)
print(f"解压缩后TTS音频大小: {len(decompressed_tts)}")
# 创建WAV文件
sample_rate = 24000
channels = 1
sampwidth = 2
with wave.open(f'tts_response_{response_count}.wav', 'wb') as wav_file:
wav_file.setnchannels(channels)
wav_file.setsampwidth(sampwidth)
wav_file.setframerate(sample_rate)
wav_file.writeframes(decompressed_tts)
print(f"成功创建TTS WAV文件: tts_response_{response_count}.wav")
print(f"音频参数: {sample_rate}Hz, {channels}通道, {sampwidth*8}-bit")
# 显示文件信息
if os.path.exists(f'tts_response_{response_count}.wav'):
file_size = os.path.getsize(f'tts_response_{response_count}.wav')
duration = file_size / (sample_rate * channels * sampwidth)
print(f"WAV文件大小: {file_size} 字节")
print(f"音频时长: {duration:.2f}")
# 成功获取音频,退出循环
break
except Exception as tts_e:
print(f"TTS音频解压缩失败: {tts_e}")
# 保存原始数据
with open(f'tts_response_audio_{response_count}.raw', 'wb') as f:
f.write(tts_audio_data)
print(f"原始TTS音频数据已保存到 tts_response_audio_{response_count}.raw")
elif tts_message_type == 9: # SERVER_FULL_RESPONSE
tts_payload_start = tts_header_size * 4
tts_payload = tts_response[tts_payload_start:]
if len(tts_payload) >= 4:
event = int.from_bytes(tts_payload[:4], 'big')
print(f"Event: {event}")
if event in [451, 359]: # ASR结果或TTS结束
# 解析payload
if len(tts_payload) >= 8:
session_id_len = int.from_bytes(tts_payload[4:8], 'big')
if len(tts_payload) >= 8 + session_id_len:
session_id = tts_payload[8:8+session_id_len].decode()
if len(tts_payload) >= 12 + session_id_len:
payload_size = int.from_bytes(tts_payload[8+session_id_len:12+session_id_len], 'big')
payload_data = tts_payload[12+session_id_len:12+session_id_len+payload_size]
try:
json_data = json.loads(payload_data.decode('utf-8'))
print(f"JSON数据: {json_data}")
# 如果是ASR结果
if 'results' in json_data:
text = json_data['results'][0].get('text', '')
print(f"识别文本: {text}")
# 如果是TTS结束标记
if event == 359:
print("TTS响应结束")
break
except Exception as e:
print(f"解析JSON失败: {e}")
# 保存原始数据
with open(f'tts_response_{response_count}.raw', 'wb') as f:
f.write(payload_data)
# 保存完整响应用于调试
with open(f'tts_response_full_{response_count}.raw', 'wb') as f:
f.write(tts_response)
print(f"完整响应已保存到 tts_response_full_{response_count}.raw")
response_count += 1
except asyncio.TimeoutError:
print(f"等待第 {response_count + 1} 个响应超时")
break
except websockets.exceptions.ConnectionClosed:
print("连接已关闭")
break
print(f"共收到 {response_count} 个响应")
except asyncio.TimeoutError:
print("等待响应超时")
except websockets.exceptions.ConnectionClosed as e:
print(f"连接关闭: {e}")
except Exception as e:
print(f"测试失败: {e}")
import traceback
traceback.print_exc()
async def close(self) -> None:
"""关闭连接"""
if self.ws:
try:
await self.ws.close()
except:
pass
print("连接已关闭")
async def main():
"""测试函数"""
config = DoubaoConfig()
client = DoubaoClient(config)
try:
await client.connect()
await client.test_full_dialog()
except Exception as e:
print(f"测试失败: {e}")
import traceback
traceback.print_exc()
finally:
await client.close()
if __name__ == "__main__":
asyncio.run(main())

doubao_original_test.py (new file, 412 lines)

@@ -0,0 +1,412 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Doubao audio processing module - test version based on the original code.
Directly reuses the core logic of the original Doubao sample code.
"""
import asyncio
import gzip
import json
import uuid
import wave
import struct
import time
import os
from typing import Dict, Any, Optional
import websockets
# 直接复制原始豆包代码的协议常量
PROTOCOL_VERSION = 0b0001
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111
NO_SEQUENCE = 0b0000
POS_SEQUENCE = 0b0001
MSG_WITH_EVENT = 0b0100
NO_SERIALIZATION = 0b0000
JSON = 0b0001
GZIP = 0b0001
def generate_header(
version=PROTOCOL_VERSION,
message_type=CLIENT_FULL_REQUEST,
message_type_specific_flags=MSG_WITH_EVENT,
serial_method=JSON,
compression_type=GZIP,
reserved_data=0x00,
extension_header=bytes()
):
"""直接复制原始豆包代码的generate_header函数"""
header = bytearray()
header_size = int(len(extension_header) / 4) + 1
header.append((version << 4) | header_size)
header.append((message_type << 4) | message_type_specific_flags)
header.append((serial_method << 4) | compression_type)
header.append(reserved_data)
header.extend(extension_header)
return header
class DoubaoConfig:
"""豆包配置"""
def __init__(self):
self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
self.app_id = "8718217928"
self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
self.app_key = "PlgvMymc7f3tQnJ6"
self.resource_id = "volc.speech.dialog"
def get_headers(self) -> Dict[str, str]:
"""获取请求头"""
return {
"X-Api-App-ID": self.app_id,
"X-Api-Access-Key": self.access_key,
"X-Api-Resource-Id": self.resource_id,
"X-Api-App-Key": self.app_key,
"X-Api-Connect-Id": str(uuid.uuid4()),
}
class DoubaoClient:
"""豆包客户端 - 基于原始代码"""
def __init__(self, config: DoubaoConfig):
self.config = config
self.session_id = str(uuid.uuid4())
self.ws = None
self.log_id = ""
async def connect(self) -> None:
"""建立WebSocket连接"""
print(f"连接豆包服务器: {self.config.base_url}")
try:
self.ws = await websockets.connect(
self.config.base_url,
additional_headers=self.config.get_headers(),
ping_interval=None
)
# 获取log_id
if hasattr(self.ws, 'response_headers'):
self.log_id = self.ws.response_headers.get("X-Tt-Logid")
elif hasattr(self.ws, 'headers'):
self.log_id = self.ws.headers.get("X-Tt-Logid")
print(f"连接成功, log_id: {self.log_id}")
# 发送StartConnection请求
await self._send_start_connection()
# 发送StartSession请求
await self._send_start_session()
except Exception as e:
print(f"连接失败: {e}")
raise
async def _send_start_connection(self) -> None:
"""发送StartConnection请求"""
print("发送StartConnection请求...")
request = bytearray(generate_header())
request.extend(int(1).to_bytes(4, 'big'))
payload_bytes = b"{}"
payload_bytes = gzip.compress(payload_bytes)
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
await self.ws.send(request)
response = await self.ws.recv()
print(f"StartConnection响应长度: {len(response)}")
async def _send_start_session(self) -> None:
"""发送StartSession请求"""
print("发送StartSession请求...")
session_config = {
"asr": {
"extra": {
"end_smooth_window_ms": 1500,
},
},
"tts": {
"speaker": "zh_female_vv_jupiter_bigtts",
"audio_config": {
"channel": 1,
"format": "pcm",
"sample_rate": 24000
},
},
"dialog": {
"bot_name": "豆包",
"system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
"speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
"location": {"city": "北京"},
"extra": {
"strict_audit": False,
"audit_response": "支持客户自定义安全审核回复话术。",
"recv_timeout": 30,
"input_mod": "audio",
},
},
}
request = bytearray(generate_header())
request.extend(int(100).to_bytes(4, 'big'))
request.extend(len(self.session_id).to_bytes(4, 'big'))
request.extend(self.session_id.encode())
payload_bytes = json.dumps(session_config).encode()
payload_bytes = gzip.compress(payload_bytes)
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
await self.ws.send(request)
response = await self.ws.recv()
print(f"StartSession响应长度: {len(response)}")
# 等待一会确保会话完全建立
await asyncio.sleep(1.0)
async def task_request(self, audio: bytes) -> None:
"""直接复制原始豆包代码的task_request方法"""
task_request = bytearray(
generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST,
serial_method=NO_SERIALIZATION))
task_request.extend(int(200).to_bytes(4, 'big'))
task_request.extend((len(self.session_id)).to_bytes(4, 'big'))
task_request.extend(str.encode(self.session_id))
payload_bytes = gzip.compress(audio)
task_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes)
task_request.extend(payload_bytes)
await self.ws.send(task_request)
async def test_audio_request(self) -> None:
"""测试音频请求"""
print("测试音频请求...")
# 读取真实的录音文件
try:
import wave
with wave.open("recording_20250920_135137.wav", 'rb') as wf:
# 读取前10秒的音频数据16000采样率 * 10秒 = 160000帧
total_frames = wf.getnframes()
frames_to_read = min(total_frames, 160000) # 最多10秒
small_audio = wf.readframes(frames_to_read)
print(f"读取真实音频数据: {len(small_audio)} 字节")
print(f"音频参数: 采样率={wf.getframerate()}, 通道数={wf.getnchannels()}, 采样宽度={wf.getsampwidth()}")
print(f"总帧数: {total_frames}, 读取帧数: {frames_to_read}")
except Exception as e:
print(f"读取音频文件失败: {e}")
# 如果读取失败,使用静音数据
small_audio = b'\x00' * 3200
print(f"音频数据大小: {len(small_audio)}")
try:
# 发送完整的音频数据块
print(f"发送完整的音频数据块...")
await self.task_request(small_audio)
print(f"音频数据块发送成功")
print("等待语音识别响应...")
# 等待更长时间的响应(语音识别可能需要更长时间)
response = await asyncio.wait_for(self.ws.recv(), timeout=15.0)
print(f"收到响应,长度: {len(response)}")
# 解析响应
try:
if len(response) >= 4:
protocol_version = response[0] >> 4
header_size = response[0] & 0x0f
message_type = response[1] >> 4
message_type_specific_flags = response[1] & 0x0f
print(f"响应协议信息: version={protocol_version}, header_size={header_size}, message_type={message_type}, flags={message_type_specific_flags}")
# 解析payload
payload_start = header_size * 4
payload = response[payload_start:]
if message_type == 9: # SERVER_FULL_RESPONSE
print("收到SERVER_FULL_RESPONSE")
if len(payload) >= 4:
# 解析event
event = int.from_bytes(payload[:4], 'big')
print(f"Event: {event}")
# 解析session_id
if len(payload) >= 8:
session_id_len = int.from_bytes(payload[4:8], 'big')
if len(payload) >= 8 + session_id_len:
session_id = payload[8:8+session_id_len].decode()
print(f"Session ID: {session_id}")
# 解析payload size和data
if len(payload) >= 12 + session_id_len:
payload_size = int.from_bytes(payload[8+session_id_len:12+session_id_len], 'big')
payload_data = payload[12+session_id_len:12+session_id_len+payload_size]
print(f"Payload size: {payload_size}")
# 如果包含音频数据,保存到文件
if len(payload_data) > 0:
print(f"收到数据: {len(payload_data)} 字节")
# 保存原始音频数据
with open('response_audio.raw', 'wb') as f:
f.write(payload_data)
print("音频数据已保存到 response_audio.raw")
# 尝试解析JSON数据
try:
import json
json_data = json.loads(payload_data.decode('utf-8'))
print(f"JSON数据: {json_data}")
# 如果是语音识别任务开始,继续等待音频响应
if 'asr_task_id' in json_data:
print("语音识别任务开始,继续等待音频响应...")
try:
# 等待音频响应
audio_response = await asyncio.wait_for(self.ws.recv(), timeout=20.0)
print(f"收到音频响应,长度: {len(audio_response)}")
# 解析音频响应
if len(audio_response) >= 4:
audio_version = audio_response[0] >> 4
audio_header_size = audio_response[0] & 0x0f
audio_message_type = audio_response[1] >> 4
audio_flags = audio_response[1] & 0x0f
print(f"音频响应协议信息: version={audio_version}, header_size={audio_header_size}, message_type={audio_message_type}, flags={audio_flags}")
if audio_message_type == 9: # SERVER_FULL_RESPONSE (包含TTS音频)
audio_payload_start = audio_header_size * 4
audio_payload = audio_response[audio_payload_start:]
if len(audio_payload) >= 12:
# 解析event和session_id
audio_event = int.from_bytes(audio_payload[:4], 'big')
audio_session_len = int.from_bytes(audio_payload[4:8], 'big')
audio_session = audio_payload[8:8+audio_session_len].decode()
audio_data_size = int.from_bytes(audio_payload[8+audio_session_len:12+audio_session_len], 'big')
audio_data = audio_payload[12+audio_session_len:12+audio_session_len+audio_data_size]
print(f"音频Event: {audio_event}")
print(f"音频数据大小: {audio_data_size}")
if audio_data_size > 0:
# 保存原始音频数据
with open('tts_response_audio.raw', 'wb') as f:
f.write(audio_data)
print(f"TTS音频数据已保存到 tts_response_audio.raw")
# 尝试解析音频数据可能是JSON或GZIP压缩的音频
try:
# 首先尝试解压缩
import gzip
decompressed_audio = gzip.decompress(audio_data)
print(f"解压缩后音频数据大小: {len(decompressed_audio)}")
with open('tts_response_audio_decompressed.raw', 'wb') as f:
f.write(decompressed_audio)
print("解压缩的音频数据已保存")
# 创建WAV文件供树莓派播放
import wave
import struct
# 豆包返回的音频是24000Hz, 16-bit, 单声道
sample_rate = 24000
channels = 1
sampwidth = 2 # 16-bit = 2 bytes
with wave.open('tts_response.wav', 'wb') as wav_file:
wav_file.setnchannels(channels)
wav_file.setsampwidth(sampwidth)
wav_file.setframerate(sample_rate)
wav_file.writeframes(decompressed_audio)
print("已创建WAV文件: tts_response.wav")
print(f"音频参数: {sample_rate}Hz, {channels}通道, {sampwidth*8}-bit")
except Exception as audio_e:
print(f"音频数据处理失败: {audio_e}")
# 如果解压缩失败,直接保存原始数据
with open('tts_response_audio_original.raw', 'wb') as f:
f.write(audio_data)
elif audio_message_type == 11: # SERVER_ACK
print("收到SERVER_ACK音频响应")
# 处理SERVER_ACK格式的音频响应
audio_payload_start = audio_header_size * 4
audio_payload = audio_response[audio_payload_start:]
print(f"音频payload长度: {len(audio_payload)}")
with open('tts_response_ack.raw', 'wb') as f:
f.write(audio_payload)
except asyncio.TimeoutError:
print("等待音频响应超时")
except Exception as json_e:
print(f"解析JSON失败: {json_e}")
# 如果不是JSON可能是音频数据直接保存
with open('response_audio.raw', 'wb') as f:
f.write(payload_data)
elif message_type == 11: # SERVER_ACK
print("收到SERVER_ACK响应")
elif message_type == 15: # SERVER_ERROR_RESPONSE
print("收到错误响应")
if len(response) > 8:
error_code = int.from_bytes(response[4:8], 'big')
print(f"错误代码: {error_code}")
except Exception as e:
print(f"解析响应失败: {e}")
import traceback
traceback.print_exc()
except asyncio.TimeoutError:
print("等待响应超时")
except websockets.exceptions.ConnectionClosed as e:
print(f"连接关闭: {e}")
except Exception as e:
print(f"发送音频请求失败: {e}")
raise
async def close(self) -> None:
"""关闭连接"""
if self.ws:
try:
await self.ws.close()
except:
pass
print("连接已关闭")
async def main():
"""测试函数"""
config = DoubaoConfig()
client = DoubaoClient(config)
try:
await client.connect()
await client.test_audio_request()
except Exception as e:
print(f"测试失败: {e}")
import traceback
traceback.print_exc()
finally:
await client.close()
if __name__ == "__main__":
asyncio.run(main())

doubao_simple.py (new file, 412 lines)

@@ -0,0 +1,412 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Doubao audio processing module - final simplified version.
Implements the complete flow of audio file upload and TTS audio download.
"""
import asyncio
import gzip
import json
import uuid
import wave
import time
import os
from typing import Dict, Any, Optional
import websockets
# 协议常量
PROTOCOL_VERSION = 0b0001
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111
NO_SEQUENCE = 0b0000
MSG_WITH_EVENT = 0b0100
NO_SERIALIZATION = 0b0000
JSON = 0b0001
GZIP = 0b0001
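# Event IDs referenced in this script: client-side 1 = StartConnection, 100 = StartSession,
# 200 = audio task request; server-side 350 = TTS start, 359 = TTS end, 451 = ASR result,
# 550 = TTS audio metadata/data (see process_audio_file below).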
def generate_header(
version=PROTOCOL_VERSION,
message_type=CLIENT_FULL_REQUEST,
message_type_specific_flags=MSG_WITH_EVENT,
serial_method=JSON,
compression_type=GZIP,
reserved_data=0x00,
extension_header=bytes()
):
"""生成协议头"""
header = bytearray()
header_size = int(len(extension_header) / 4) + 1
header.append((version << 4) | header_size)
header.append((message_type << 4) | message_type_specific_flags)
header.append((serial_method << 4) | compression_type)
header.append(reserved_data)
header.extend(extension_header)
return header
class DoubaoClient:
"""豆包客户端"""
def __init__(self):
self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
self.app_id = "8718217928"
self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
self.app_key = "PlgvMymc7f3tQnJ6"
self.resource_id = "volc.speech.dialog"
self.session_id = str(uuid.uuid4())
self.ws = None
self.log_id = ""
def get_headers(self) -> Dict[str, str]:
"""获取请求头"""
return {
"X-Api-App-ID": self.app_id,
"X-Api-Access-Key": self.access_key,
"X-Api-Resource-Id": self.resource_id,
"X-Api-App-Key": self.app_key,
"X-Api-Connect-Id": str(uuid.uuid4()),
}
async def connect(self) -> None:
"""建立WebSocket连接"""
print(f"连接豆包服务器: {self.base_url}")
try:
self.ws = await websockets.connect(
self.base_url,
additional_headers=self.get_headers(),
ping_interval=None
)
# 获取log_id
if hasattr(self.ws, 'response_headers'):
self.log_id = self.ws.response_headers.get("X-Tt-Logid")
elif hasattr(self.ws, 'headers'):
self.log_id = self.ws.headers.get("X-Tt-Logid")
print(f"连接成功, log_id: {self.log_id}")
# 发送StartConnection请求
await self._send_start_connection()
# 发送StartSession请求
await self._send_start_session()
except Exception as e:
print(f"连接失败: {e}")
raise
async def _send_start_connection(self) -> None:
"""发送StartConnection请求"""
print("发送StartConnection请求...")
request = bytearray(generate_header())
request.extend(int(1).to_bytes(4, 'big'))
payload_bytes = b"{}"
payload_bytes = gzip.compress(payload_bytes)
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
await self.ws.send(request)
response = await self.ws.recv()
print(f"StartConnection响应长度: {len(response)}")
async def _send_start_session(self) -> None:
"""发送StartSession请求"""
print("发送StartSession请求...")
session_config = {
"asr": {
"extra": {
"end_smooth_window_ms": 1500,
},
},
"tts": {
"speaker": "zh_female_vv_jupiter_bigtts",
"audio_config": {
"channel": 1,
"format": "pcm",
"sample_rate": 24000
},
},
"dialog": {
"bot_name": "豆包",
"system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
"speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
"location": {"city": "北京"},
"extra": {
"strict_audit": False,
"audit_response": "支持客户自定义安全审核回复话术。",
"recv_timeout": 30,
"input_mod": "audio",
},
},
}
request = bytearray(generate_header())
request.extend(int(100).to_bytes(4, 'big'))
request.extend(len(self.session_id).to_bytes(4, 'big'))
request.extend(self.session_id.encode())
payload_bytes = json.dumps(session_config).encode()
payload_bytes = gzip.compress(payload_bytes)
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
await self.ws.send(request)
response = await self.ws.recv()
print(f"StartSession响应长度: {len(response)}")
# 等待一会确保会话完全建立
await asyncio.sleep(1.0)
async def task_request(self, audio: bytes) -> None:
"""发送音频数据"""
task_request = bytearray(
generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST,
serial_method=NO_SERIALIZATION))
task_request.extend(int(200).to_bytes(4, 'big'))
task_request.extend((len(self.session_id)).to_bytes(4, 'big'))
task_request.extend(str.encode(self.session_id))
payload_bytes = gzip.compress(audio)
task_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
task_request.extend(payload_bytes)
await self.ws.send(task_request)
def parse_response(self, response):
"""解析响应"""
if len(response) < 4:
return None
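# Expected frame layout (mirrored from the parsing below): 4-byte header, 4-byte event,
# 4-byte session ID length, session ID, 4-byte data size, data.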
protocol_version = response[0] >> 4
header_size = response[0] & 0x0f
message_type = response[1] >> 4
flags = response[1] & 0x0f
payload_start = header_size * 4
payload = response[payload_start:]
result = {
'protocol_version': protocol_version,
'header_size': header_size,
'message_type': message_type,
'flags': flags,
'payload': payload,
'payload_size': len(payload)
}
# 解析payload
if len(payload) >= 4:
result['event'] = int.from_bytes(payload[:4], 'big')
if len(payload) >= 8:
session_id_len = int.from_bytes(payload[4:8], 'big')
if len(payload) >= 8 + session_id_len:
result['session_id'] = payload[8:8+session_id_len].decode()
if len(payload) >= 12 + session_id_len:
data_size = int.from_bytes(payload[8+session_id_len:12+session_id_len], 'big')
result['data_size'] = data_size
result['data'] = payload[12+session_id_len:12+session_id_len+data_size]
# 尝试解析JSON数据
try:
result['json_data'] = json.loads(result['data'].decode('utf-8'))
except:
pass
return result
async def process_audio_file(self, input_file: str, output_file: str) -> bool:
"""处理音频文件上传并获得TTS响应"""
print(f"开始处理音频文件: {input_file}")
try:
# 读取输入音频文件
with wave.open(input_file, 'rb') as wf:
audio_data = wf.readframes(wf.getnframes())
print(f"读取音频数据: {len(audio_data)} 字节")
print(f"音频参数: {wf.getframerate()}Hz, {wf.getnchannels()}通道, {wf.getsampwidth()*8}-bit")
# 发送音频数据
print("发送音频数据...")
await self.task_request(audio_data)
print("音频数据发送成功")
# 接收响应序列
print("开始接收响应...")
audio_chunks = []
response_count = 0
max_responses = 20
while response_count < max_responses:
try:
response = await asyncio.wait_for(self.ws.recv(), timeout=30.0)
response_count += 1
parsed = self.parse_response(response)
if not parsed:
continue
print(f"响应 {response_count}: message_type={parsed['message_type']}, event={parsed.get('event', 'N/A')}, size={parsed['payload_size']}")
# 处理不同类型的响应
if parsed['message_type'] == 11: # SERVER_ACK - 可能包含音频
if 'data' in parsed and parsed['data_size'] > 0:
audio_chunks.append(parsed['data'])
print(f"收集到音频块: {parsed['data_size']} 字节")
elif parsed['message_type'] == 9: # SERVER_FULL_RESPONSE
event = parsed.get('event', 0)
if event == 350: # TTS开始
print("TTS音频生成开始")
elif event == 359: # TTS结束
print("TTS音频生成结束")
break
elif event == 451: # ASR结果
if 'json_data' in parsed and 'results' in parsed['json_data']:
text = parsed['json_data']['results'][0].get('text', '')
print(f"语音识别结果: {text}")
elif event == 550: # TTS音频数据
if 'data' in parsed and parsed['data_size'] > 0:
# 检查是否是JSON音频元数据还是实际音频数据
try:
json.loads(parsed['data'].decode('utf-8'))
print("收到TTS音频元数据")
except:
# 不是JSON可能是音频数据
audio_chunks.append(parsed['data'])
print(f"收集到TTS音频块: {parsed['data_size']} 字节")
except asyncio.TimeoutError:
print(f"等待响应 {response_count + 1} 超时")
break
except websockets.exceptions.ConnectionClosed:
print("连接已关闭")
break
print(f"共收到 {response_count} 个响应,收集到 {len(audio_chunks)} 个音频块")
# 合并音频数据
if audio_chunks:
combined_audio = b''.join(audio_chunks)
print(f"合并后的音频数据: {len(combined_audio)} 字节")
# 检查是否是GZIP压缩数据
try:
decompressed = gzip.decompress(combined_audio)
print(f"解压缩后音频数据: {len(decompressed)} 字节")
audio_to_write = decompressed
except:
print("音频数据不是GZIP压缩格式直接使用原始数据")
audio_to_write = combined_audio
# 创建输出WAV文件
try:
# 豆包返回的音频是32位浮点格式需要转换为16位整数
import struct
# 检查音频数据长度是否是4的倍数32位浮点
if len(audio_to_write) % 4 != 0:
print(f"警告:音频数据长度 {len(audio_to_write)} 不是4的倍数截断到最近的倍数")
audio_to_write = audio_to_write[:len(audio_to_write) // 4 * 4]
# 将32位浮点转换为16位整数
float_count = len(audio_to_write) // 4
int16_data = bytearray(float_count * 2)
for i in range(float_count):
# 读取32位浮点数小端序
float_value = struct.unpack('<f', audio_to_write[i*4:i*4+4])[0]
# 将浮点数限制在[-1.0, 1.0]范围内
float_value = max(-1.0, min(1.0, float_value))
# 转换为16位整数
int16_value = int(float_value * 32767)
# 写入16位整数小端序
int16_data[i*2:i*2+2] = struct.pack('<h', int16_value)
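# Example: a float sample of 0.5 maps to int(0.5 * 32767) = 16383; -1.0 maps to -32767.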
# 创建WAV文件
with wave.open(output_file, 'wb') as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(24000)
wav_file.writeframes(int16_data)
print(f"成功创建输出文件: {output_file}")
print(f"音频转换: {float_count} 个32位浮点样本 -> {len(int16_data)//2} 个16位整数样本")
# 显示文件信息
if os.path.exists(output_file):
file_size = os.path.getsize(output_file)
duration = file_size / (24000 * 1 * 2)
print(f"输出文件大小: {file_size} 字节,时长: {duration:.2f}")
return True
except Exception as e:
print(f"创建WAV文件失败: {e}")
# 保存原始数据
with open(output_file + '.raw', 'wb') as f:
f.write(audio_to_write)
print(f"原始音频数据已保存到: {output_file}.raw")
return False
else:
print("未收到音频数据")
return False
except Exception as e:
print(f"处理音频文件失败: {e}")
import traceback
traceback.print_exc()
return False
async def close(self) -> None:
"""关闭连接"""
if self.ws:
try:
await self.ws.close()
except:
pass
print("连接已关闭")
async def main():
"""主函数"""
client = DoubaoClient()
try:
await client.connect()
# 处理录音文件
input_file = "recording_20250920_135137.wav"
output_file = "tts_output.wav"
success = await client.process_audio_file(input_file, output_file)
if success:
print("音频处理成功!")
else:
print("音频处理失败")
except Exception as e:
print(f"程序失败: {e}")
import traceback
traceback.print_exc()
finally:
await client.close()
if __name__ == "__main__":
asyncio.run(main())

doubao_test.py (new file, 303 lines)

@@ -0,0 +1,303 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Doubao audio processing module - protocol test version.
Dedicated to testing protocol format issues.
"""
import asyncio
import gzip
import json
import uuid
import wave
import struct
import time
import os
from typing import Dict, Any, Optional
import websockets
class DoubaoConfig:
"""豆包配置"""
def __init__(self):
self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
self.app_id = "8718217928"
self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
self.app_key = "PlgvMymc7f3tQnJ6"
self.resource_id = "volc.speech.dialog"
def get_headers(self) -> Dict[str, str]:
"""获取请求头"""
return {
"X-Api-App-ID": self.app_id,
"X-Api-Access-Key": self.access_key,
"X-Api-Resource-Id": self.resource_id,
"X-Api-App-Key": self.app_key,
"X-Api-Connect-Id": str(uuid.uuid4()),
}
class DoubaoProtocol:
"""豆包协议处理"""
# 协议常量
PROTOCOL_VERSION = 0b0001
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111
NO_SEQUENCE = 0b0000
POS_SEQUENCE = 0b0001
MSG_WITH_EVENT = 0b0100
JSON = 0b0001
NO_SERIALIZATION = 0b0000
GZIP = 0b0001
@classmethod
def generate_header(cls, message_type=CLIENT_FULL_REQUEST,
message_type_specific_flags=MSG_WITH_EVENT,
serial_method=JSON, compression_type=GZIP) -> bytes:
"""生成协议头"""
header = bytearray()
header.append((cls.PROTOCOL_VERSION << 4) | 1) # version + header_size
header.append((message_type << 4) | message_type_specific_flags)
header.append((serial_method << 4) | compression_type)
header.append(0x00) # reserved
return bytes(header)
class DoubaoClient:
"""豆包客户端"""
def __init__(self, config: DoubaoConfig):
self.config = config
self.session_id = str(uuid.uuid4())
self.ws = None
self.log_id = ""
async def connect(self) -> None:
"""建立WebSocket连接"""
print(f"连接豆包服务器: {self.config.base_url}")
try:
self.ws = await websockets.connect(
self.config.base_url,
additional_headers=self.config.get_headers(),
ping_interval=None
)
            # 获取log_id
            if hasattr(self.ws, 'response_headers'):
                self.log_id = self.ws.response_headers.get("X-Tt-Logid", "")
            elif getattr(self.ws, 'response', None) is not None:
                # newer websockets clients expose the handshake headers via response.headers
                self.log_id = self.ws.response.headers.get("X-Tt-Logid", "")
            elif hasattr(self.ws, 'headers'):
                self.log_id = self.ws.headers.get("X-Tt-Logid", "")
print(f"连接成功, log_id: {self.log_id}")
# 发送StartConnection请求
await self._send_start_connection()
# 发送StartSession请求
await self._send_start_session()
except Exception as e:
print(f"连接失败: {e}")
raise
async def _send_start_connection(self) -> None:
"""发送StartConnection请求"""
print("发送StartConnection请求...")
request = bytearray(DoubaoProtocol.generate_header())
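        # A 4-byte big-endian event ID follows the header; 1 is used here for the
        # StartConnection event (100 and 200 appear below for StartSession and the audio task).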
request.extend(int(1).to_bytes(4, 'big'))
payload_bytes = b"{}"
payload_bytes = gzip.compress(payload_bytes)
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
await self.ws.send(request)
response = await self.ws.recv()
print(f"StartConnection响应长度: {len(response)}")
async def _send_start_session(self) -> None:
"""发送StartSession请求"""
print("发送StartSession请求...")
session_config = {
"asr": {
"extra": {
"end_smooth_window_ms": 1500,
},
},
"tts": {
"speaker": "zh_female_vv_jupiter_bigtts",
"audio_config": {
"channel": 1,
"format": "pcm",
"sample_rate": 24000
},
},
"dialog": {
"bot_name": "豆包",
"system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
"speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
"location": {"city": "北京"},
"extra": {
"strict_audit": False,
"audit_response": "支持客户自定义安全审核回复话术。",
"recv_timeout": 30,
"input_mod": "audio",
},
},
}
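        # StartSession frame layout, as assembled below:
        #   [4-byte header][4-byte event = 100][4-byte session_id length][session_id]
        #   [4-byte payload length][gzip-compressed JSON session_config]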
request = bytearray(DoubaoProtocol.generate_header())
request.extend(int(100).to_bytes(4, 'big'))
request.extend(len(self.session_id).to_bytes(4, 'big'))
request.extend(self.session_id.encode())
payload_bytes = json.dumps(session_config).encode()
payload_bytes = gzip.compress(payload_bytes)
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
await self.ws.send(request)
response = await self.ws.recv()
print(f"StartSession响应长度: {len(response)}")
# 等待一会确保会话完全建立
await asyncio.sleep(1.0)
async def test_audio_request(self) -> None:
"""测试音频请求格式"""
print("测试音频请求格式...")
# 创建音频数据(静音)- 使用原始豆包代码的chunk大小
small_audio = b'\x00' * 3200 # 原始豆包代码中的chunk大小
# 完全按照原始豆包代码的格式构建请求,不进行任何填充
header = bytearray()
header.append((DoubaoProtocol.PROTOCOL_VERSION << 4) | 1) # version + header_size
header.append((DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST << 4) | DoubaoProtocol.NO_SEQUENCE)
header.append((DoubaoProtocol.NO_SERIALIZATION << 4) | DoubaoProtocol.GZIP)
header.append(0x00) # reserved
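        # Audio-only frame layout (same framing as StartSession, but the payload is
        # gzip-compressed audio bytes rather than JSON):
        #   [4-byte header][4-byte event = 200][4-byte session_id length][session_id]
        #   [4-byte payload length][gzip-compressed audio]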
request = bytearray(header)
        # 添加事件ID这里的200表示音频任务请求事件而不是协议头中的消息类型字段
request.extend(int(200).to_bytes(4, 'big'))
# 添加session_id
request.extend(len(self.session_id).to_bytes(4, 'big'))
request.extend(self.session_id.encode())
# 压缩音频数据
compressed_audio = gzip.compress(small_audio)
# 添加payload size
request.extend(len(compressed_audio).to_bytes(4, 'big'))
# 添加压缩后的音频数据
request.extend(compressed_audio)
print(f"测试请求详细信息:")
print(f" - 音频原始大小: {len(small_audio)}")
print(f" - 音频压缩后大小: {len(compressed_audio)}")
print(f" - Session ID: {self.session_id} (长度: {len(self.session_id)})")
print(f" - 总请求大小: {len(request)}")
print(f" - 头部字节: {request[:4].hex()}")
print(f" - 消息类型: {int.from_bytes(request[4:8], 'big')}")
print(f" - Session ID长度: {int.from_bytes(request[8:12], 'big')}")
print(f" - Payload size: {int.from_bytes(request[12+len(self.session_id):16+len(self.session_id)], 'big')}")
try:
await self.ws.send(request)
print("请求发送成功")
# 等待响应
response = await asyncio.wait_for(self.ws.recv(), timeout=3.0)
print(f"收到响应,长度: {len(response)}")
# 尝试解析响应
try:
protocol_version = response[0] >> 4
header_size = response[0] & 0x0f
message_type = response[1] >> 4
message_type_specific_flags = response[1] & 0x0f
serialization_method = response[2] >> 4
message_compression = response[2] & 0x0f
print(f"响应协议信息:")
print(f" - version={protocol_version}")
print(f" - header_size={header_size}")
print(f" - message_type={message_type} (15=SERVER_ERROR_RESPONSE)")
print(f" - message_type_specific_flags={message_type_specific_flags}")
print(f" - serialization_method={serialization_method}")
print(f" - message_compression={message_compression}")
# 解析payload
payload = response[header_size * 4:]
if message_type == 15: # SERVER_ERROR_RESPONSE
if len(payload) >= 8:
code = int.from_bytes(payload[:4], "big", signed=False)
payload_size = int.from_bytes(payload[4:8], "big", signed=False)
print(f" - 错误代码: {code}")
print(f" - payload大小: {payload_size}")
if len(payload) >= 8 + payload_size:
payload_msg = payload[8:8 + payload_size]
print(f" - payload长度: {len(payload_msg)}")
if message_compression == 1: # GZIP
try:
payload_msg = gzip.decompress(payload_msg)
print(f" - 解压缩后长度: {len(payload_msg)}")
except:
pass
try:
error_msg = json.loads(payload_msg.decode('utf-8'))
print(f" - 错误信息: {error_msg}")
except:
print(f" - 原始payload: {payload_msg}")
except Exception as e:
print(f"解析响应失败: {e}")
import traceback
traceback.print_exc()
except Exception as e:
print(f"发送测试请求失败: {e}")
raise
async def close(self) -> None:
"""关闭连接"""
if self.ws:
try:
await self.ws.close()
            except Exception:
                pass  # ignore errors raised while the socket is shutting down
print("连接已关闭")
async def main():
"""测试函数"""
config = DoubaoConfig()
client = DoubaoClient(config)
try:
await client.connect()
await client.test_audio_request()
except Exception as e:
print(f"测试失败: {e}")
import traceback
traceback.print_exc()
finally:
await client.close()
if __name__ == "__main__":
asyncio.run(main())

113
test_doubao.py Normal file
View File

@ -0,0 +1,113 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
豆包音频处理模块 - 验证脚本
验证完整的音频处理流程
"""
import asyncio
import subprocess
import os
from doubao_simple import DoubaoClient
async def test_complete_workflow():
"""测试完整的工作流程"""
print("=== 豆包音频处理模块验证 ===")
# 检查输入文件
input_file = "recording_20250920_135137.wav"
if not os.path.exists(input_file):
print(f"❌ 输入文件不存在: {input_file}")
return False
print(f"✅ 输入文件存在: {input_file}")
# 检查文件信息
try:
result = subprocess.run(['file', input_file], capture_output=True, text=True)
print(f"📁 输入文件格式: {result.stdout.strip()}")
    except Exception:
        pass  # the `file` utility may be missing; skip the format check in that case
# 初始化客户端
client = DoubaoClient()
try:
# 连接服务器
print("🔌 连接豆包服务器...")
await client.connect()
print("✅ 连接成功")
# 处理音频文件
output_file = "tts_output.wav"
print(f"🎵 处理音频文件: {input_file} -> {output_file}")
success = await client.process_audio_file(input_file, output_file)
if success:
print("✅ 音频处理成功!")
# 检查输出文件
if os.path.exists(output_file):
result = subprocess.run(['file', output_file], capture_output=True, text=True)
print(f"📁 输出文件格式: {result.stdout.strip()}")
# 获取文件大小
file_size = os.path.getsize(output_file)
print(f"📊 输出文件大小: {file_size:,} 字节")
# 测试播放
print("🔊 测试播放输出文件...")
try:
subprocess.run(['aplay', output_file], timeout=10, check=True)
print("✅ 播放成功")
except subprocess.TimeoutExpired:
print("✅ 播放完成(超时是正常的)")
except subprocess.CalledProcessError as e:
print(f"⚠️ 播放出现问题: {e}")
except FileNotFoundError:
print("⚠️ aplay命令不存在跳过播放测试")
return True
else:
print("❌ 输出文件未生成")
return False
else:
print("❌ 音频处理失败")
return False
except Exception as e:
print(f"❌ 测试失败: {e}")
import traceback
traceback.print_exc()
return False
finally:
try:
await client.close()
        except Exception:
            pass  # ignore errors from close() during cleanup
def main():
"""主函数"""
print("开始验证豆包音频处理模块...")
success = asyncio.run(test_complete_workflow())
if success:
print("\n🎉 验证完成!豆包音频处理模块工作正常。")
print("\n📋 功能总结:")
print(" ✅ WebSocket连接建立")
print(" ✅ 音频文件上传")
print(" ✅ 语音识别")
print(" ✅ TTS音频生成")
print(" ✅ 音频格式转换Float32 -> Int16")
print(" ✅ WAV文件生成")
print(" ✅ 树莓派兼容播放")
else:
print("\n❌ 验证失败,请检查错误信息。")
return success
if __name__ == "__main__":
main()