Local-Voice/doubao_original_test.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Doubao audio processing module - test version based on the original code.
Uses the core logic of the original Doubao sample directly.
"""
import asyncio
import gzip
import json
import uuid
import wave
from typing import Dict

import websockets

# Protocol constants copied directly from the original Doubao sample code.
PROTOCOL_VERSION = 0b0001
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111
NO_SEQUENCE = 0b0000
POS_SEQUENCE = 0b0001
MSG_WITH_EVENT = 0b0100
NO_SERIALIZATION = 0b0000
JSON = 0b0001
GZIP = 0b0001


def generate_header(
    version=PROTOCOL_VERSION,
    message_type=CLIENT_FULL_REQUEST,
    message_type_specific_flags=MSG_WITH_EVENT,
    serial_method=JSON,
    compression_type=GZIP,
    reserved_data=0x00,
    extension_header=bytes()
):
    """generate_header function copied directly from the original Doubao sample code."""
    header = bytearray()
    header_size = int(len(extension_header) / 4) + 1
    header.append((version << 4) | header_size)
    header.append((message_type << 4) | message_type_specific_flags)
    header.append((serial_method << 4) | compression_type)
    header.append(reserved_data)
    header.extend(extension_header)
    return header
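
# Illustrative note (not part of the original sample): with the default arguments,
# generate_header() packs into exactly four bytes:
#   byte 0: 0x11  -> version=0b0001, header size=1 (i.e. 4 bytes)
#   byte 1: 0x14  -> CLIENT_FULL_REQUEST, MSG_WITH_EVENT
#   byte 2: 0x11  -> JSON serialization, GZIP compression
#   byte 3: 0x00  -> reserved
# so bytes(generate_header()) == b'\x11\x14\x11\x00'.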


class DoubaoConfig:
    """Doubao connection configuration."""

    def __init__(self):
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        self.app_id = "8718217928"
        self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
        self.app_key = "PlgvMymc7f3tQnJ6"
        self.resource_id = "volc.speech.dialog"

    def get_headers(self) -> Dict[str, str]:
        """Build the WebSocket handshake headers."""
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
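

# Illustrative sketch (not in the original sample): the credentials above could be
# overridden from environment variables instead of being hard-coded. The variable
# names DOUBAO_APP_ID, DOUBAO_ACCESS_KEY and DOUBAO_APP_KEY are hypothetical.
def config_from_env() -> DoubaoConfig:
    import os  # local import; the rest of the module does not need os

    config = DoubaoConfig()
    config.app_id = os.environ.get("DOUBAO_APP_ID", config.app_id)
    config.access_key = os.environ.get("DOUBAO_ACCESS_KEY", config.access_key)
    config.app_key = os.environ.get("DOUBAO_APP_KEY", config.app_key)
    return config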


class DoubaoClient:
    """Doubao client - based on the original sample code."""

    def __init__(self, config: DoubaoConfig):
        self.config = config
        self.session_id = str(uuid.uuid4())
        self.ws = None
        self.log_id = ""

    async def connect(self) -> None:
        """Open the WebSocket connection and start the dialogue session."""
        print(f"Connecting to Doubao server: {self.config.base_url}")
        try:
            self.ws = await websockets.connect(
                self.config.base_url,
                additional_headers=self.config.get_headers(),
                ping_interval=None
            )
            # Grab the log id from the handshake response headers
            # (the attribute name differs between websockets versions).
            if hasattr(self.ws, 'response_headers'):
                self.log_id = self.ws.response_headers.get("X-Tt-Logid")
            elif hasattr(self.ws, 'headers'):
                self.log_id = self.ws.headers.get("X-Tt-Logid")
            print(f"Connected, log_id: {self.log_id}")
            # Send the StartConnection request
            await self._send_start_connection()
            # Send the StartSession request
            await self._send_start_session()
        except Exception as e:
            print(f"Connection failed: {e}")
            raise

    async def _send_start_connection(self) -> None:
        """Send the StartConnection request (event 1, empty JSON payload)."""
        print("Sending StartConnection request...")
        request = bytearray(generate_header())
        request.extend(int(1).to_bytes(4, 'big'))               # event id
        payload_bytes = b"{}"
        payload_bytes = gzip.compress(payload_bytes)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))   # payload size (4 bytes)
        request.extend(payload_bytes)
        await self.ws.send(request)
        response = await self.ws.recv()
        print(f"StartConnection response length: {len(response)}")

    async def _send_start_session(self) -> None:
        """Send the StartSession request (event 100) with the ASR/TTS/dialog config."""
        print("Sending StartSession request...")
        session_config = {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {
                    "channel": 1,
                    "format": "pcm",
                    "sample_rate": 24000
                },
            },
            "dialog": {
                # The Chinese persona/prompt strings below are sent to the service as-is.
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 30,
                    "input_mod": "audio",
                },
            },
        }
        request = bytearray(generate_header())
        request.extend(int(100).to_bytes(4, 'big'))              # event id
        request.extend(len(self.session_id).to_bytes(4, 'big'))  # session id size (4 bytes)
        request.extend(self.session_id.encode())
        payload_bytes = json.dumps(session_config).encode()
        payload_bytes = gzip.compress(payload_bytes)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))    # payload size (4 bytes)
        request.extend(payload_bytes)
        await self.ws.send(request)
        response = await self.ws.recv()
        print(f"StartSession response length: {len(response)}")
        # Wait a moment to make sure the session is fully established.
        await asyncio.sleep(1.0)
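
    # Illustrative helper (not in the original sample): _send_start_session above and
    # task_request below both build frames with the same layout --
    # header | event (4 B) | session-id size (4 B) | session id | payload size (4 B) | payload.
    # A sketch of that shared pattern, assuming the payload is already serialized:
    def _build_event_frame(self, event: int, payload: bytes,
                           message_type: int = CLIENT_FULL_REQUEST,
                           serial_method: int = JSON) -> bytearray:
        frame = bytearray(generate_header(message_type=message_type,
                                          serial_method=serial_method))
        frame.extend(event.to_bytes(4, 'big'))                   # event id
        frame.extend(len(self.session_id).to_bytes(4, 'big'))    # session id size
        frame.extend(self.session_id.encode())                   # session id
        compressed = gzip.compress(payload)                      # GZIP, matching the header flag
        frame.extend(len(compressed).to_bytes(4, 'big'))         # payload size
        frame.extend(compressed)                                 # payload
        return frame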

    async def task_request(self, audio: bytes) -> None:
        """task_request method copied directly from the original Doubao sample code."""
        task_request = bytearray(
            generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST,
                            serial_method=NO_SERIALIZATION))
        task_request.extend(int(200).to_bytes(4, 'big'))               # event id
        task_request.extend((len(self.session_id)).to_bytes(4, 'big'))
        task_request.extend(str.encode(self.session_id))
        payload_bytes = gzip.compress(audio)
        task_request.extend((len(payload_bytes)).to_bytes(4, 'big'))   # payload size (4 bytes)
        task_request.extend(payload_bytes)
        await self.ws.send(task_request)
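
    # Illustrative sketch (not in the original sample): rather than sending one large
    # buffer, the same task_request frame can be sent repeatedly with small chunks to
    # approximate real-time streaming. The 100 ms pacing and the 3200-byte chunk size
    # (100 ms of 16 kHz, 16-bit mono PCM) are assumptions, not values required by the API.
    async def stream_audio_chunks(self, audio: bytes, chunk_size: int = 3200) -> None:
        for offset in range(0, len(audio), chunk_size):
            await self.task_request(audio[offset:offset + chunk_size])
            await asyncio.sleep(0.1)  # pace the chunks roughly in real time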

    async def test_audio_request(self) -> None:
        """Send a test audio request and inspect the server responses."""
        print("Testing audio request...")
        # Read a real recording from disk.
        try:
            with wave.open("recording_20250920_135137.wav", 'rb') as wf:
                # Read up to the first 10 seconds: 16000 Hz * 10 s = 160000 frames.
                total_frames = wf.getnframes()
                frames_to_read = min(total_frames, 160000)  # at most 10 seconds
                small_audio = wf.readframes(frames_to_read)
                print(f"Read real audio data: {len(small_audio)} bytes")
                print(f"Audio parameters: sample_rate={wf.getframerate()}, channels={wf.getnchannels()}, sample_width={wf.getsampwidth()}")
                print(f"Total frames: {total_frames}, frames read: {frames_to_read}")
        except Exception as e:
            print(f"Failed to read audio file: {e}")
            # Fall back to silence if the file cannot be read.
            small_audio = b'\x00' * 3200
        print(f"Audio data size: {len(small_audio)}")
        try:
            # Send the complete audio buffer in one frame.
            print("Sending the complete audio buffer...")
            await self.task_request(small_audio)
            print("Audio buffer sent successfully")
            print("Waiting for the speech-recognition response...")
            # Allow a longer timeout; speech recognition can take a while.
            response = await asyncio.wait_for(self.ws.recv(), timeout=15.0)
            print(f"Received response, length: {len(response)}")
            # Parse the response.
            try:
                if len(response) >= 4:
                    protocol_version = response[0] >> 4
                    header_size = response[0] & 0x0f
                    message_type = response[1] >> 4
                    message_type_specific_flags = response[1] & 0x0f
                    print(f"Response protocol info: version={protocol_version}, header_size={header_size}, message_type={message_type}, flags={message_type_specific_flags}")
                    # Parse the payload.
                    payload_start = header_size * 4
                    payload = response[payload_start:]
                    if message_type == SERVER_FULL_RESPONSE:
                        print("Received SERVER_FULL_RESPONSE")
                        if len(payload) >= 4:
                            # Parse the event id.
                            event = int.from_bytes(payload[:4], 'big')
                            print(f"Event: {event}")
                            # Parse the session id.
                            if len(payload) >= 8:
                                session_id_len = int.from_bytes(payload[4:8], 'big')
                                if len(payload) >= 8 + session_id_len:
                                    session_id = payload[8:8 + session_id_len].decode()
                                    print(f"Session ID: {session_id}")
                                    # Parse the payload size and data.
                                    if len(payload) >= 12 + session_id_len:
                                        payload_size = int.from_bytes(payload[8 + session_id_len:12 + session_id_len], 'big')
                                        payload_data = payload[12 + session_id_len:12 + session_id_len + payload_size]
                                        print(f"Payload size: {payload_size}")
                                        # If any data came back, save it to a file.
                                        if len(payload_data) > 0:
                                            print(f"Received data: {len(payload_data)} bytes")
                                            # Save the raw payload.
                                            with open('response_audio.raw', 'wb') as f:
                                                f.write(payload_data)
                                            print("Audio data saved to response_audio.raw")
                                            # Try to parse the payload as JSON.
                                            try:
                                                json_data = json.loads(payload_data.decode('utf-8'))
                                                print(f"JSON data: {json_data}")
                                                # If this marks the start of an ASR task, keep waiting for the audio response.
                                                if 'asr_task_id' in json_data:
                                                    print("ASR task started, waiting for the audio response...")
                                                    try:
                                                        # Wait for the audio response.
                                                        audio_response = await asyncio.wait_for(self.ws.recv(), timeout=20.0)
                                                        print(f"Received audio response, length: {len(audio_response)}")
                                                        # Parse the audio response.
                                                        if len(audio_response) >= 4:
                                                            audio_version = audio_response[0] >> 4
                                                            audio_header_size = audio_response[0] & 0x0f
                                                            audio_message_type = audio_response[1] >> 4
                                                            audio_flags = audio_response[1] & 0x0f
                                                            print(f"Audio response protocol info: version={audio_version}, header_size={audio_header_size}, message_type={audio_message_type}, flags={audio_flags}")
                                                            if audio_message_type == SERVER_FULL_RESPONSE:  # may carry TTS audio
                                                                audio_payload_start = audio_header_size * 4
                                                                audio_payload = audio_response[audio_payload_start:]
                                                                if len(audio_payload) >= 12:
                                                                    # Parse event, session id and data.
                                                                    audio_event = int.from_bytes(audio_payload[:4], 'big')
                                                                    audio_session_len = int.from_bytes(audio_payload[4:8], 'big')
                                                                    audio_session = audio_payload[8:8 + audio_session_len].decode()
                                                                    audio_data_size = int.from_bytes(audio_payload[8 + audio_session_len:12 + audio_session_len], 'big')
                                                                    audio_data = audio_payload[12 + audio_session_len:12 + audio_session_len + audio_data_size]
                                                                    print(f"Audio event: {audio_event}")
                                                                    print(f"Audio data size: {audio_data_size}")
                                                                    if audio_data_size > 0:
                                                                        # Save the raw audio data.
                                                                        with open('tts_response_audio.raw', 'wb') as f:
                                                                            f.write(audio_data)
                                                                        print("TTS audio data saved to tts_response_audio.raw")
                                                                        # The data may be JSON or GZIP-compressed audio; try decompressing first.
                                                                        try:
                                                                            decompressed_audio = gzip.decompress(audio_data)
                                                                            print(f"Decompressed audio data size: {len(decompressed_audio)}")
                                                                            with open('tts_response_audio_decompressed.raw', 'wb') as f:
                                                                                f.write(decompressed_audio)
                                                                            print("Decompressed audio data saved")
                                                                            # Build a WAV file so the Raspberry Pi can play it back.
                                                                            # Doubao returns 24000 Hz, 16-bit, mono audio.
                                                                            sample_rate = 24000
                                                                            channels = 1
                                                                            sampwidth = 2  # 16-bit = 2 bytes
                                                                            with wave.open('tts_response.wav', 'wb') as wav_file:
                                                                                wav_file.setnchannels(channels)
                                                                                wav_file.setsampwidth(sampwidth)
                                                                                wav_file.setframerate(sample_rate)
                                                                                wav_file.writeframes(decompressed_audio)
                                                                            print("Created WAV file: tts_response.wav")
                                                                            print(f"Audio parameters: {sample_rate} Hz, {channels} channel(s), {sampwidth * 8}-bit")
                                                                        except Exception as audio_e:
                                                                            print(f"Audio data processing failed: {audio_e}")
                                                                            # If decompression fails, save the raw data directly.
                                                                            with open('tts_response_audio_original.raw', 'wb') as f:
                                                                                f.write(audio_data)
                                                            elif audio_message_type == SERVER_ACK:
                                                                print("Received SERVER_ACK audio response")
                                                                # Handle the SERVER_ACK-style audio response.
                                                                audio_payload_start = audio_header_size * 4
                                                                audio_payload = audio_response[audio_payload_start:]
                                                                print(f"Audio payload length: {len(audio_payload)}")
                                                                with open('tts_response_ack.raw', 'wb') as f:
                                                                    f.write(audio_payload)
                                                    except asyncio.TimeoutError:
                                                        print("Timed out waiting for the audio response")
                                            except Exception as json_e:
                                                print(f"Failed to parse JSON: {json_e}")
                                                # Not JSON; it may be audio data, so save it directly (same file as above).
                                                with open('response_audio.raw', 'wb') as f:
                                                    f.write(payload_data)
                    elif message_type == SERVER_ACK:
                        print("Received SERVER_ACK response")
                    elif message_type == SERVER_ERROR_RESPONSE:
                        print("Received error response")
                        if len(response) > 8:
                            error_code = int.from_bytes(response[4:8], 'big')
                            print(f"Error code: {error_code}")
            except Exception as e:
                print(f"Failed to parse response: {e}")
                import traceback
                traceback.print_exc()
        except asyncio.TimeoutError:
            print("Timed out waiting for a response")
        except websockets.exceptions.ConnectionClosed as e:
            print(f"Connection closed: {e}")
        except Exception as e:
            print(f"Failed to send audio request: {e}")
            raise
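
    # Illustrative sketch (not in the original sample): the inline parsing in
    # test_audio_request can be summarized as one helper that splits a
    # SERVER_FULL_RESPONSE frame into its parts
    # (header | event 4 B | session-id size 4 B | session id | payload size 4 B | payload).
    # SERVER_ACK and SERVER_ERROR_RESPONSE frames use different layouts and are not handled.
    @staticmethod
    def parse_full_response(frame: bytes) -> dict:
        header_size = frame[0] & 0x0f
        message_type = frame[1] >> 4
        body = frame[header_size * 4:]
        event = int.from_bytes(body[:4], 'big')
        session_len = int.from_bytes(body[4:8], 'big')
        session_id = body[8:8 + session_len].decode()
        payload_size = int.from_bytes(body[8 + session_len:12 + session_len], 'big')
        payload = body[12 + session_len:12 + session_len + payload_size]
        return {
            "message_type": message_type,
            "event": event,
            "session_id": session_id,
            "payload": payload,
        }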

    async def close(self) -> None:
        """Close the WebSocket connection."""
        if self.ws:
            try:
                await self.ws.close()
            except Exception:
                pass
        print("Connection closed")


async def main():
    """Test entry point."""
    config = DoubaoConfig()
    client = DoubaoClient(config)
    try:
        await client.connect()
        await client.test_audio_request()
    except Exception as e:
        print(f"Test failed: {e}")
        import traceback
        traceback.print_exc()
    finally:
        await client.close()


if __name__ == "__main__":
    asyncio.run(main())