# Local-Voice/doubao_debug.py
# 2025-09-20 14:35:54 +08:00
#
# 540 lines
# 20 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
豆包音频处理模块 - 调试版本
添加更多调试信息和错误处理
"""
import asyncio
import gzip
import json
import uuid
import wave
import struct
import time
import os
from typing import Dict, Any, Optional
import websockets
class DoubaoConfig:
    """Endpoint and credential settings for the Doubao dialogue WebSocket API.

    Each setting is resolved in this order: explicit constructor argument,
    then environment variable, then the built-in fallback.

    Environment variables: ``DOUBAO_BASE_URL``, ``DOUBAO_APP_ID``,
    ``DOUBAO_ACCESS_KEY``, ``DOUBAO_APP_KEY``, ``DOUBAO_RESOURCE_ID``.
    """

    _DEFAULT_BASE_URL = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
    _DEFAULT_RESOURCE_ID = "volc.speech.dialog"
    # SECURITY NOTE(review): the three literals below are credentials committed
    # to source. They remain only as last-resort fallbacks so that a bare
    # ``DoubaoConfig()`` keeps working; rotate them and supply real values via
    # constructor arguments or the environment instead.
    _DEFAULT_APP_ID = "8718217928"
    _DEFAULT_ACCESS_KEY = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
    _DEFAULT_APP_KEY = "PlgvMymc7f3tQnJ6"

    def __init__(self, app_id: Optional[str] = None,
                 access_key: Optional[str] = None,
                 app_key: Optional[str] = None,
                 resource_id: Optional[str] = None,
                 base_url: Optional[str] = None):
        """Resolve every setting (argument > environment > fallback).

        Args:
            app_id: Value sent as ``X-Api-App-ID``.
            access_key: Value sent as ``X-Api-Access-Key``.
            app_key: Value sent as ``X-Api-App-Key``.
            resource_id: Value sent as ``X-Api-Resource-Id``.
            base_url: WebSocket endpoint URL.
        """
        env = os.environ
        self.base_url = base_url or env.get("DOUBAO_BASE_URL") or self._DEFAULT_BASE_URL
        self.app_id = app_id or env.get("DOUBAO_APP_ID") or self._DEFAULT_APP_ID
        self.access_key = access_key or env.get("DOUBAO_ACCESS_KEY") or self._DEFAULT_ACCESS_KEY
        self.app_key = app_key or env.get("DOUBAO_APP_KEY") or self._DEFAULT_APP_KEY
        self.resource_id = resource_id or env.get("DOUBAO_RESOURCE_ID") or self._DEFAULT_RESOURCE_ID

    def get_headers(self) -> Dict[str, str]:
        """Return the handshake headers; a fresh connect id is generated per call."""
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
class DoubaoProtocol:
    """Builds and parses the binary frames exchanged with the Doubao server.

    A frame starts with a header of ``header_size * 4`` bytes whose first three
    bytes each pack two 4-bit fields (version/header size, message type/flags,
    serialization/compression); the remainder is the payload.
    """

    # Protocol constants (4-bit field values).
    PROTOCOL_VERSION = 0b0001
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ACK = 0b1011
    SERVER_ERROR_RESPONSE = 0b1111
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    MSG_WITH_EVENT = 0b0100
    JSON = 0b0001
    NO_SERIALIZATION = 0b0000
    GZIP = 0b0001

    @classmethod
    def generate_header(cls, message_type=CLIENT_FULL_REQUEST,
                        message_type_specific_flags=MSG_WITH_EVENT,
                        serial_method=JSON, compression_type=GZIP) -> bytes:
        """Return the 4-byte frame header for the given field values."""
        return bytes([
            (cls.PROTOCOL_VERSION << 4) | 1,  # version + header size (1 word)
            (message_type << 4) | message_type_specific_flags,
            (serial_method << 4) | compression_type,
            0x00,  # reserved
        ])

    @classmethod
    def parse_response(cls, res: bytes) -> Dict[str, Any]:
        """Parse one server frame into a dict; never raises.

        Args:
            res: Raw frame. A ``str`` (text frame) yields an empty dict.

        Returns:
            A dict that may contain ``message_type``, ``event``, ``session_id``,
            ``payload_size``, ``payload_msg`` and ``code``. Problems are
            reported through the keys ``error``, ``decompress_error``,
            ``json_error`` or ``parse_error`` instead of exceptions.
        """
        if isinstance(res, str):
            return {}
        # Bound before the try block so the except handler can always record
        # the failure, even when the header itself is truncated.
        result: Dict[str, Any] = {}
        try:
            header_size = res[0] & 0x0f
            message_type = res[1] >> 4
            flags = res[1] & 0x0f
            serialization_method = res[2] >> 4
            message_compression = res[2] & 0x0f
            payload = res[header_size * 4:]
            if message_type in (cls.SERVER_FULL_RESPONSE, cls.SERVER_ACK):
                cls._parse_message(result, payload, message_type, flags,
                                   serialization_method, message_compression)
            elif message_type == cls.SERVER_ERROR_RESPONSE:
                cls._parse_error(result, payload, message_compression)
        except Exception as e:
            result['parse_error'] = str(e)
        return result

    @classmethod
    def _parse_message(cls, result: Dict[str, Any], payload: bytes,
                       message_type: int, flags: int,
                       serialization_method: int,
                       message_compression: int) -> None:
        """Fill *result* from a full-response / ACK payload."""
        if message_type == cls.SERVER_ACK:
            result['message_type'] = 'SERVER_ACK'
        else:
            result['message_type'] = 'SERVER_FULL_RESPONSE'

        # The 4-byte event id is present only when the flag says so.
        if flags & cls.MSG_WITH_EVENT:
            result['event'] = int.from_bytes(payload[:4], "big", signed=False)
            payload = payload[4:]

        if len(payload) < 4:
            result['error'] = 'Payload too short for session_id'
            return
        session_id_size = int.from_bytes(payload[:4], "big", signed=True)
        if session_id_size < 0 or session_id_size > len(payload) - 4:
            result['error'] = f'Invalid session_id size: {session_id_size}'
            return
        # Decode the id; str() on bytes would produce the "b'...'" repr.
        session_id = bytes(payload[4:4 + session_id_size])
        result['session_id'] = session_id.decode("utf-8", errors="replace")
        payload = payload[4 + session_id_size:]

        if len(payload) < 4:
            result['error'] = 'Payload too short for payload_size'
            return
        payload_size = int.from_bytes(payload[:4], "big", signed=False)
        result['payload_size'] = payload_size
        if len(payload) < 4 + payload_size:
            return
        payload_msg = payload[4:4 + payload_size]
        if not payload_msg:
            return

        if message_compression == cls.GZIP:
            try:
                payload_msg = gzip.decompress(payload_msg)
            except Exception as e:
                result['decompress_error'] = str(e)
                return
        if serialization_method == cls.JSON:
            try:
                payload_msg = json.loads(str(payload_msg, "utf-8"))
            except Exception as e:
                result['json_error'] = str(e)
                # Keep the text for debugging; tolerate invalid UTF-8 here.
                payload_msg = str(payload_msg, "utf-8", errors="replace")
        elif serialization_method != cls.NO_SERIALIZATION:
            payload_msg = str(payload_msg, "utf-8")
        result['payload_msg'] = payload_msg

    @classmethod
    def _parse_error(cls, result: Dict[str, Any], payload: bytes,
                     message_compression: int) -> None:
        """Fill *result* from an error-response payload (code + raw message)."""
        if len(payload) < 8:
            return
        result['code'] = int.from_bytes(payload[:4], "big", signed=False)
        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
        result['payload_size'] = payload_size
        if len(payload) < 8 + payload_size:
            return
        payload_msg = payload[8:8 + payload_size]
        if payload_msg and message_compression == cls.GZIP:
            try:
                payload_msg = gzip.decompress(payload_msg)
            except Exception as e:
                # Best effort: keep the raw bytes but record why decoding failed.
                result['decompress_error'] = str(e)
        result['payload_msg'] = payload_msg
class AudioProcessor:
    """Static helpers for reading and writing WAV files."""

    @staticmethod
    def read_wav_file(file_path: str) -> tuple:
        """Load a WAV file.

        Returns:
            ``(frames, info)`` where *frames* is the raw frame bytes and *info*
            is a dict with keys ``channels``, ``sampwidth``, ``framerate`` and
            ``nframes``.
        """
        with wave.open(file_path, 'rb') as wav_in:
            params = wav_in.getparams()
            # Read every frame the header declares.
            frames = wav_in.readframes(params.nframes)
        info = {
            'channels': params.nchannels,
            'sampwidth': params.sampwidth,
            'framerate': params.framerate,
            'nframes': params.nframes
        }
        return frames, info

    @staticmethod
    def create_wav_file(audio_data: bytes, output_path: str,
                        sample_rate: int = 24000, channels: int = 1,
                        sampwidth: int = 2) -> None:
        """Write raw PCM bytes to *output_path* as a WAV file.

        Args:
            audio_data: Raw frame bytes to store.
            output_path: Destination file path.
            sample_rate: Frames per second written to the header.
            channels: Number of channels written to the header.
            sampwidth: Sample width in bytes written to the header.
        """
        with wave.open(output_path, 'wb') as wav_out:
            wav_out.setframerate(sample_rate)
            wav_out.setsampwidth(sampwidth)
            wav_out.setnchannels(channels)
            wav_out.writeframes(audio_data)
class DoubaoClient:
    """WebSocket client that drives one Doubao dialogue session.

    Typical use: ``connect()``, then ``send_audio_file()``, then ``close()``.
    """

    # Event ids written into outgoing frames by this client.
    _EVT_START_CONNECTION = 1
    _EVT_FINISH_CONNECTION = 2
    _EVT_START_SESSION = 100
    _EVT_FINISH_SESSION = 102
    _EVT_AUDIO_CHUNK = 200
    # Incoming event ids this client reacts to.
    _SESSION_END_EVENTS = (152, 153)
    _STATUS_EVENTS = (359, 152, 153)
    # NOTE(review): 359 is presumably the "TTS ended" event of the vendor
    # protocol; it is used only to stop waiting for more reply audio - confirm.
    _REPLY_END_EVENT = 359

    def __init__(self, config: DoubaoConfig):
        """Store *config* and allocate a fresh session id; no I/O happens here."""
        self.config = config
        self.session_id = str(uuid.uuid4())
        self.ws = None
        self.log_id = ""

    async def connect(self) -> None:
        """Open the WebSocket, then send StartConnection and StartSession.

        Raises:
            Exception: Any connection or handshake failure is printed and
                re-raised.
        """
        print(f"连接豆包服务器: {self.config.base_url}")
        try:
            self.ws = await websockets.connect(
                self.config.base_url,
                additional_headers=self.config.get_headers(),
                ping_interval=None
            )
            self.log_id = self._read_log_id()
            print(f"连接成功, log_id: {self.log_id}")
            await self._send_start_connection()
            await self._send_start_session()
        except Exception as e:
            print(f"连接失败: {e}")
            raise

    def _read_log_id(self) -> str:
        """Return the ``X-Tt-Logid`` handshake response header, or ``""``."""
        headers = None
        response = getattr(self.ws, 'response', None)
        if response is not None and hasattr(response, 'headers'):
            # Connection objects that expose the handshake as ``.response``.
            headers = response.headers
        elif hasattr(self.ws, 'response_headers'):
            headers = self.ws.response_headers
        elif hasattr(self.ws, 'headers'):
            headers = self.ws.headers
        if headers is None:
            return ""
        return headers.get("X-Tt-Logid") or ""

    def _build_request(self, event: int, compressed_payload: bytes, *,
                       with_session: bool = True,
                       header: Optional[bytes] = None) -> bytearray:
        """Assemble one frame: header, event id, optional session id, payload.

        Args:
            event: Event id written as a 4-byte big-endian integer.
            compressed_payload: Already gzip-compressed payload bytes.
            with_session: Whether to include the length-prefixed session id.
            header: Frame header; defaults to ``generate_header()``.
        """
        if header is None:
            header = DoubaoProtocol.generate_header()
        request = bytearray(header)
        request.extend(int(event).to_bytes(4, 'big'))
        if with_session:
            # Length prefix counts encoded bytes, not characters.
            session_bytes = self.session_id.encode()
            request.extend(len(session_bytes).to_bytes(4, 'big'))
            request.extend(session_bytes)
        request.extend(len(compressed_payload).to_bytes(4, 'big'))
        request.extend(compressed_payload)
        return request

    @staticmethod
    def _raise_on_failure(stage: str, parsed_response: Dict[str, Any]) -> None:
        """Raise if a handshake response reports a parse or server error."""
        if 'error' in parsed_response:
            raise RuntimeError(f"{stage}解析错误: {parsed_response['error']}")
        if 'parse_error' in parsed_response:
            raise RuntimeError(f"{stage}解析错误: {parsed_response['parse_error']}")
        # Server error frames carry ``code`` rather than ``error``.
        if parsed_response.get('code', 0) != 0:
            raise RuntimeError(f"{stage}服务器错误: {parsed_response}")

    async def _send_start_connection(self) -> None:
        """Send StartConnection and validate the reply."""
        print("发送StartConnection请求...")
        request = self._build_request(self._EVT_START_CONNECTION,
                                      gzip.compress(b"{}"),
                                      with_session=False)
        await self.ws.send(request)
        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"StartConnection响应: {parsed_response}")
        self._raise_on_failure("StartConnection", parsed_response)

    async def _send_start_session(self) -> None:
        """Send StartSession with the session configuration and validate the reply."""
        print("发送StartSession请求...")
        session_config = {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {
                    "channel": 1,
                    "format": "pcm",
                    "sample_rate": 24000
                },
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 30,  # longer receive timeout
                    "input_mod": "audio_file",  # audio-file input mode
                },
            },
        }
        payload_bytes = gzip.compress(json.dumps(session_config).encode())
        request = self._build_request(self._EVT_START_SESSION, payload_bytes)
        await self.ws.send(request)
        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"StartSession响应: {parsed_response}")
        self._raise_on_failure("StartSession", parsed_response)
        # Give the session a moment to be fully established.
        await asyncio.sleep(1.0)

    def _consume_response(self, parsed_response: Dict[str, Any],
                          received: bytearray) -> Optional[int]:
        """Append any audio in *parsed_response* to *received*; return its event id."""
        if (parsed_response.get('message_type') == 'SERVER_ACK' and
                isinstance(parsed_response.get('payload_msg'), bytes)):
            audio_chunk = parsed_response['payload_msg']
            received.extend(audio_chunk)
            print(f"接收到音频数据块,大小: {len(audio_chunk)} 字节")
        event = parsed_response.get('event')
        if event in self._STATUS_EVENTS:
            print(f"会话事件: {event}")
        return event

    async def send_audio_file(self, file_path: str,
                              drain_timeout: float = 10.0) -> bytes:
        """Stream a WAV file in ~50 ms chunks and collect the reply audio.

        After every chunk one response is awaited (3 s timeout). Once all
        chunks are sent, responses keep being read until the session ends, the
        reply-end event arrives, or nothing arrives for *drain_timeout*
        seconds, so audio that arrives after the upload is not lost.

        Args:
            file_path: Path of the WAV file to send.
            drain_timeout: Idle seconds to wait for further responses after
                the last chunk; ``0`` or less disables the extra wait.

        Returns:
            All reply audio bytes received, concatenated.

        Raises:
            Exception: If the server returns more than three error responses
                while chunks are being sent.
        """
        print(f"处理音频文件: {file_path}")
        audio_data, audio_info = AudioProcessor.read_wav_file(file_path)
        print(f"音频参数: {audio_info}")

        # ~50 ms per chunk, rounded down to whole frames and never zero so the
        # range() step below is always valid.
        frame_size = audio_info['channels'] * audio_info['sampwidth']
        chunk_size = int(audio_info['framerate'] * frame_size * 0.05)
        if frame_size > 0:
            chunk_size -= chunk_size % frame_size
        chunk_size = max(chunk_size, frame_size, 1)

        total_chunks = (len(audio_data) + chunk_size - 1) // chunk_size
        print(f"开始分块发送音频,共 {total_chunks}")

        received = bytearray()  # extend() avoids quadratic bytes concatenation
        error_count = 0
        session_active = True
        offsets = range(0, len(audio_data), chunk_size)
        for index, offset in enumerate(offsets, start=1):
            chunk = audio_data[offset:offset + chunk_size]
            is_last = (offset + chunk_size >= len(audio_data))
            await self._send_audio_chunk(chunk, is_last)

            try:
                response = await asyncio.wait_for(self.ws.recv(), timeout=3.0)
            except asyncio.TimeoutError:
                print("等待响应超时,继续发送")
            else:
                parsed_response = DoubaoProtocol.parse_response(response)
                print(f"响应 {index}/{total_chunks}: {parsed_response}")
                if 'code' in parsed_response and parsed_response['code'] != 0:
                    print(f"服务器返回错误: {parsed_response}")
                    error_count += 1
                    if error_count > 3:
                        raise Exception(f"服务器连续返回错误: {parsed_response}")
                    continue
                event = self._consume_response(parsed_response, received)
                if event in self._SESSION_END_EVENTS:
                    print("检测到会话结束事件")
                    session_active = False
                    break
            # Pace the upload to roughly mimic real-time streaming.
            await asyncio.sleep(0.1)

        print("音频文件发送完成")
        if session_active and drain_timeout > 0:
            await self._drain_responses(received, drain_timeout)
        return bytes(received)

    async def _drain_responses(self, received: bytearray, timeout: float) -> None:
        """Read remaining responses into *received* until the reply is over."""
        while True:
            try:
                response = await asyncio.wait_for(self.ws.recv(), timeout=timeout)
            except asyncio.TimeoutError:
                print("等待剩余响应超时,停止接收")
                return
            except websockets.ConnectionClosed:
                # Keep whatever audio was already collected.
                print("连接已被关闭,停止接收")
                return
            parsed_response = DoubaoProtocol.parse_response(response)
            if 'code' in parsed_response and parsed_response['code'] != 0:
                print(f"服务器返回错误: {parsed_response}")
                return
            event = self._consume_response(parsed_response, received)
            if event in self._SESSION_END_EVENTS:
                print("检测到会话结束事件")
                return
            if event == self._REPLY_END_EVENT:
                return

    async def _send_audio_chunk(self, audio_data: bytes, is_last: bool = False) -> None:
        """Send one gzip-compressed audio chunk.

        Args:
            audio_data: Raw audio bytes for this chunk.
            is_last: Accepted for interface compatibility; not used when
                building the frame.
        """
        # The frame carries an event id, so the header must announce it with
        # MSG_WITH_EVENT; otherwise the receiver would read the event field as
        # the session-id length.
        header = DoubaoProtocol.generate_header(
            message_type=DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST,
            message_type_specific_flags=DoubaoProtocol.MSG_WITH_EVENT,
            serial_method=DoubaoProtocol.NO_SERIALIZATION,  # raw audio, no serialization
            compression_type=DoubaoProtocol.GZIP
        )
        compressed_audio = gzip.compress(audio_data)
        payload_size = len(compressed_audio)
        request = self._build_request(self._EVT_AUDIO_CHUNK, compressed_audio,
                                      header=header)
        print(f"发送音频块 - 原始大小: {len(audio_data)}, 压缩后大小: {payload_size}, 总请求数据大小: {len(request)}")
        await self.ws.send(request)

    async def close(self) -> None:
        """Finish the session and connection, then close the socket (best effort)."""
        if not self.ws:
            return
        try:
            await self._send_finish_session()
            await self._send_finish_connection()
        except Exception as e:
            print(f"关闭会话时出错: {e}")
        finally:
            # Always close the socket, even if the finish requests failed.
            try:
                await self.ws.close()
            except Exception as e:
                print(f"关闭WebSocket时出错: {e}")
            # Drop the handle so a repeated close() is a no-op.
            self.ws = None
            print("连接已关闭")

    async def _send_finish_session(self) -> None:
        """Send FinishSession; no reply is awaited."""
        print("发送FinishSession请求...")
        request = self._build_request(self._EVT_FINISH_SESSION,
                                      gzip.compress(b"{}"))
        await self.ws.send(request)

    async def _send_finish_connection(self) -> None:
        """Send FinishConnection and wait up to 5 s for a reply."""
        print("发送FinishConnection请求...")
        request = self._build_request(self._EVT_FINISH_CONNECTION,
                                      gzip.compress(b"{}"),
                                      with_session=False)
        await self.ws.send(request)
        try:
            response = await asyncio.wait_for(self.ws.recv(), timeout=5.0)
            parsed_response = DoubaoProtocol.parse_response(response)
            print(f"FinishConnection响应: {parsed_response}")
        except asyncio.TimeoutError:
            print("FinishConnection响应超时")
class DoubaoProcessor:
    """High-level wrapper: send one audio file to Doubao and save the reply."""

    def __init__(self):
        """Create the default configuration and a client bound to it."""
        self.config = DoubaoConfig()
        self.client = DoubaoClient(self.config)

    async def process_audio_file(self, input_file: str, output_file: str = None) -> str:
        """Send *input_file* to the service and store the reply as a WAV file.

        Args:
            input_file: Path of the audio file to send.
            output_file: Path of the WAV file to write; when ``None`` a
                timestamped name is generated.

        Returns:
            The output file path (returned even when no audio was received).

        Raises:
            FileNotFoundError: If *input_file* does not exist.
            Exception: Any failure while talking to the service is printed
                with a traceback and re-raised.
        """
        if not os.path.exists(input_file):
            raise FileNotFoundError(f"音频文件不存在: {input_file}")

        # Fall back to a timestamped output name.
        if output_file is None:
            stamp = time.strftime("%Y%m%d_%H%M%S")
            output_file = f"doubao_response_{stamp}.wav"

        try:
            await self.client.connect()
            reply_audio = await self.client.send_audio_file(input_file)

            if not reply_audio:
                print("警告: 未接收到音频响应")
                return output_file

            print(f"总共接收到音频数据: {len(reply_audio)} 字节")
            # Store the reply as 24 kHz, mono, 16-bit WAV.
            AudioProcessor.create_wav_file(
                reply_audio,
                output_file,
                sample_rate=24000,
                channels=1,
                sampwidth=2
            )
            print(f"响应音频已保存到: {output_file}")
            saved_bytes = os.path.getsize(output_file)
            print(f"输出文件大小: {saved_bytes} 字节")
            return output_file
        except Exception as exc:
            print(f"处理音频文件时出错: {exc}")
            import traceback
            traceback.print_exc()
            raise
        finally:
            # Runs on success and on failure alike.
            await self.client.close()
async def main():
    """Command-line test entry: process one audio file and report the result.

    Arguments (parsed from ``sys.argv``):
        --input: Path of the input audio file (required).
        --output: Path of the output audio file (optional).

    Raises:
        SystemExit: With status 1 when processing fails, so shells and scripts
            can detect the failure instead of seeing exit status 0.
    """
    import argparse
    parser = argparse.ArgumentParser(description="豆包音频处理测试")
    parser.add_argument("--input", type=str, required=True, help="输入音频文件路径")
    parser.add_argument("--output", type=str, help="输出音频文件路径")
    args = parser.parse_args()

    processor = DoubaoProcessor()
    try:
        output_file = await processor.process_audio_file(args.input, args.output)
    except Exception as e:
        print(f"处理失败: {e}")
        # Report the failure through the exit status as well as the message.
        raise SystemExit(1)
    print(f"处理完成,输出文件: {output_file}")


if __name__ == "__main__":
    asyncio.run(main())