Local-Voice/doubao_final_test.py
2025-09-20 14:35:54 +08:00

425 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
豆包音频处理模块 - 简化测试版本
专门测试完整的音频上传和TTS音频下载流程
"""
import asyncio
import gzip
import json
import uuid
import wave
import struct
import time
import os
from typing import Dict, Any, Optional
import websockets
# Protocol constants, copied from the original Doubao sample code.
# Each value is a 4-bit field packed into the frame header by generate_header().

# Protocol version (high nibble of header byte 0).
PROTOCOL_VERSION = 0b0001

# Message types (high nibble of header byte 1).
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# Message-type-specific flags (low nibble of header byte 1).
NO_SEQUENCE = 0b0000
POS_SEQUENCE = 0b0001
MSG_WITH_EVENT = 0b0100

# Payload serialization methods (high nibble of header byte 2).
NO_SERIALIZATION = 0b0000
JSON = 0b0001

# Payload compression types (low nibble of header byte 2).
GZIP = 0b0001


def generate_header(
    version=PROTOCOL_VERSION,
    message_type=CLIENT_FULL_REQUEST,
    message_type_specific_flags=MSG_WITH_EVENT,
    serial_method=JSON,
    compression_type=GZIP,
    reserved_data=0x00,
    extension_header=bytes(),
):
    """Build a binary frame header (adapted from the original Doubao sample code).

    Layout, one byte each, followed by ``extension_header``:
      byte 0: version (high nibble) | header size in 4-byte words (low nibble)
      byte 1: message type (high nibble) | type-specific flags (low nibble)
      byte 2: serialization method (high nibble) | compression type (low nibble)
      byte 3: reserved

    Args:
        version: protocol version, 0-15.
        message_type: message type, 0-15.
        message_type_specific_flags: flags for the message type, 0-15.
        serial_method: payload serialization method, 0-15.
        compression_type: payload compression type, 0-15.
        reserved_data: value of the reserved byte, 0-255.
        extension_header: extra header bytes; length must be a multiple of 4.

    Returns:
        A new ``bytearray`` holding the header.

    Raises:
        ValueError: a nibble field is outside 0-15, ``reserved_data`` is
            outside 0-255, ``extension_header`` is not a whole number of
            4-byte words, or the total header size does not fit in 4 bits.
    """
    nibble_fields = {
        "version": version,
        "message_type": message_type,
        "message_type_specific_flags": message_type_specific_flags,
        "serial_method": serial_method,
        "compression_type": compression_type,
    }
    for field_name, value in nibble_fields.items():
        # Values above 15 would spill into the neighbouring field of the same byte.
        if not 0 <= value <= 0x0F:
            raise ValueError(f"{field_name} must fit in 4 bits, got {value}")
    if len(extension_header) % 4:
        raise ValueError(
            "extension_header length must be a multiple of 4, "
            f"got {len(extension_header)}"
        )
    # The size field counts 4-byte words, including the fixed 4-byte base header.
    header_size = len(extension_header) // 4 + 1
    if header_size > 0x0F:
        raise ValueError(f"header size {header_size} words does not fit in 4 bits")

    header = bytearray()
    header.append((version << 4) | header_size)
    header.append((message_type << 4) | message_type_specific_flags)
    header.append((serial_method << 4) | compression_type)
    header.append(reserved_data)  # bytearray.append rejects values outside 0-255
    header.extend(extension_header)
    return header
class DoubaoConfig:
    """Connection settings for the Doubao realtime dialogue WebSocket API.

    Each credential is resolved in this order:
      1. the constructor argument;
      2. a non-empty environment variable (``DOUBAO_APP_ID``,
         ``DOUBAO_ACCESS_KEY``, ``DOUBAO_APP_KEY``);
      3. the built-in fallback below.
    """

    # NOTE(security): these fallback credentials are committed in plain text.
    # Rotate them and supply real values via the environment. They remain
    # here only so existing ``DoubaoConfig()`` callers keep working unchanged.
    _DEFAULT_APP_ID = "8718217928"
    _DEFAULT_ACCESS_KEY = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
    _DEFAULT_APP_KEY = "PlgvMymc7f3tQnJ6"

    def __init__(
        self,
        app_id: Optional[str] = None,
        access_key: Optional[str] = None,
        app_key: Optional[str] = None,
    ):
        """Resolve the endpoint and credentials; all arguments are optional overrides."""
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        self.app_id = self._resolve(app_id, "DOUBAO_APP_ID", self._DEFAULT_APP_ID)
        self.access_key = self._resolve(
            access_key, "DOUBAO_ACCESS_KEY", self._DEFAULT_ACCESS_KEY
        )
        self.app_key = self._resolve(app_key, "DOUBAO_APP_KEY", self._DEFAULT_APP_KEY)
        self.resource_id = "volc.speech.dialog"

    @staticmethod
    def _resolve(explicit: Optional[str], env_name: str, fallback: str) -> str:
        """Return *explicit* if given, else the env var *env_name* if non-empty, else *fallback*."""
        if explicit is not None:
            return explicit
        return os.environ.get(env_name) or fallback

    def get_headers(self) -> Dict[str, str]:
        """Return the WebSocket handshake headers.

        A new random ``X-Api-Connect-Id`` is generated on every call.
        """
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
class DoubaoClient:
    """Minimal client for the Doubao realtime dialogue WebSocket protocol.

    Based on the original Doubao sample code. ``main()`` uses it as:
    ``connect()`` (handshake), then ``test_full_dialog()`` (one round trip),
    then ``close()``.

    Every client request built here has the layout (big-endian u32 integers)::

        header | event | [session id length | session id] |
        payload size | gzip(payload)
    """

    # Event ids used by this test script.
    _EVENT_START_CONNECTION = 1
    _EVENT_START_SESSION = 100
    _EVENT_TASK_REQUEST = 200
    _EVENT_TTS_ENDED = 359  # treated as "TTS reply finished", as in the original script

    # TTS output format requested in StartSession and used to write the WAV file.
    _TTS_SAMPLE_RATE = 24000
    _TTS_CHANNELS = 1
    # NOTE(review): 16-bit samples are assumed for "format": "pcm". Confirm
    # against the API docs; the service may return 32-bit float PCM instead.
    _TTS_SAMPLE_WIDTH = 2

    def __init__(self, config: "DoubaoConfig"):
        """Store *config* and generate a fresh session id; performs no I/O."""
        self.config = config
        self.session_id = str(uuid.uuid4())
        self.ws = None
        self.log_id = ""

    async def connect(self) -> None:
        """Open the WebSocket, then send StartConnection and StartSession.

        Raises:
            Exception: any connection or handshake error is printed and re-raised.
        """
        print(f"连接豆包服务器: {self.config.base_url}")
        try:
            self.ws = await websockets.connect(
                self.config.base_url,
                additional_headers=self.config.get_headers(),
                ping_interval=None,
            )
            self.log_id = self._read_log_id(self.ws)
            print(f"连接成功, log_id: {self.log_id}")
            await self._send_start_connection()
            await self._send_start_session()
        except Exception as e:
            print(f"连接失败: {e}")
            raise

    @staticmethod
    def _read_log_id(ws) -> str:
        """Return the server's ``X-Tt-Logid`` handshake header, or "" if unavailable.

        The websockets client used with ``additional_headers`` exposes the
        handshake response as ``ws.response.headers``. Older clients exposed
        ``ws.response_headers``, so both are tried.
        """
        headers = getattr(getattr(ws, "response", None), "headers", None)
        if headers is None:
            headers = getattr(ws, "response_headers", None)
        if headers is None:
            headers = getattr(ws, "headers", None)
        if headers is None:
            return ""
        return headers.get("X-Tt-Logid") or ""

    def _build_request(
        self, header: bytearray, event: int, payload: bytes, *, with_session: bool
    ) -> bytearray:
        """Assemble one client frame; *payload* is gzip-compressed here.

        Args:
            header: output of ``generate_header``.
            event: event id written as a big-endian u32.
            payload: uncompressed payload bytes.
            with_session: include the length-prefixed UTF-8 session id.
        """
        request = bytearray(header)
        request.extend(event.to_bytes(4, "big"))
        if with_session:
            # Length prefix must count encoded bytes, not characters.
            session_bytes = self.session_id.encode("utf-8")
            request.extend(len(session_bytes).to_bytes(4, "big"))
            request.extend(session_bytes)
        compressed = gzip.compress(payload)
        request.extend(len(compressed).to_bytes(4, "big"))
        request.extend(compressed)
        return request

    @staticmethod
    def _parse_server_frame(frame) -> Optional[Dict[str, Any]]:
        """Split one binary server frame into header fields and payload.

        Layouts handled (big-endian u32 integers):
          error frame (SERVER_ERROR_RESPONSE):
              header | error code | payload size | payload
          other frames:
              header | [event, only if flags has MSG_WITH_EVENT] |
              session id length | session id | payload size | payload

        NOTE(review): sequence-number flags are not handled; this assumes the
        server does not set them on these frames.

        Returns:
            None for text frames or frames shorter than 4 bytes. Otherwise a
            dict with:
              - the header nibbles: ``version``, ``header_size``,
                ``message_type``, ``flags``, ``serialization``, ``compression``;
              - ``event``, ``session_id`` and ``code`` (None when absent);
              - ``payload``: bytes, gunzipped when the header declares GZIP;
              - ``decompress_error``: the exception if gunzip failed, else
                None. On failure ``payload`` keeps the raw bytes.
            Truncated frames yield whatever fields were parsed before the end.
        """
        import zlib  # gzip.decompress raises zlib.error on corrupt deflate data

        if isinstance(frame, str) or len(frame) < 4:
            return None
        info = {
            "version": frame[0] >> 4,
            "header_size": frame[0] & 0x0F,
            "message_type": frame[1] >> 4,
            "flags": frame[1] & 0x0F,
            "serialization": frame[2] >> 4,
            "compression": frame[2] & 0x0F,
            "event": None,
            "session_id": None,
            "code": None,
            "payload": b"",
            "decompress_error": None,
        }
        body = bytes(frame[info["header_size"] * 4:])
        pos = 0
        if info["message_type"] == SERVER_ERROR_RESPONSE:
            if len(body) < 4:
                return info
            info["code"] = int.from_bytes(body[:4], "big")
            pos = 4
        else:
            if info["flags"] & MSG_WITH_EVENT:
                if len(body) < 4:
                    return info
                info["event"] = int.from_bytes(body[:4], "big")
                pos = 4
            if len(body) < pos + 4:
                return info
            session_len = int.from_bytes(body[pos:pos + 4], "big")
            pos += 4
            info["session_id"] = body[pos:pos + session_len].decode("utf-8", errors="replace")
            pos += session_len
        if len(body) < pos + 4:
            return info
        payload_size = int.from_bytes(body[pos:pos + 4], "big")
        payload = body[pos + 4:pos + 4 + payload_size]
        # Honour the header's compression field instead of guessing.
        if info["compression"] == GZIP and payload:
            try:
                payload = gzip.decompress(payload)
            except (OSError, EOFError, zlib.error) as exc:
                info["decompress_error"] = exc
        info["payload"] = payload
        return info

    def _raise_if_error(self, frame, stage: str) -> None:
        """Raise RuntimeError if *frame* is a server error frame.

        *stage* names the request in the error message.
        """
        info = self._parse_server_frame(frame)
        if info is not None and info["message_type"] == SERVER_ERROR_RESPONSE:
            detail = info["payload"].decode("utf-8", errors="replace")
            raise RuntimeError(f"{stage}失败: code={info['code']}, {detail}")

    async def _send_start_connection(self) -> None:
        """Send StartConnection (event 1, empty JSON object) and check the reply."""
        print("发送StartConnection请求...")
        request = self._build_request(
            generate_header(), self._EVENT_START_CONNECTION, b"{}", with_session=False
        )
        await self.ws.send(request)
        response = await self.ws.recv()
        print(f"StartConnection响应长度: {len(response)}")
        self._raise_if_error(response, "StartConnection")

    async def _send_start_session(self) -> None:
        """Send StartSession (event 100) with the ASR/TTS/dialog config and check the reply."""
        print("发送StartSession请求...")
        session_config = {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {
                    "channel": self._TTS_CHANNELS,
                    "format": "pcm",
                    "sample_rate": self._TTS_SAMPLE_RATE,
                },
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 30,
                    "input_mod": "audio",
                },
            },
        }
        request = self._build_request(
            generate_header(),
            self._EVENT_START_SESSION,
            json.dumps(session_config).encode(),
            with_session=True,
        )
        await self.ws.send(request)
        response = await self.ws.recv()
        print(f"StartSession响应长度: {len(response)}")
        self._raise_if_error(response, "StartSession")
        # Pause briefly, as the original did, before sending audio.
        await asyncio.sleep(1.0)

    async def task_request(self, audio: bytes) -> None:
        """Send one chunk of input audio (event 200) as a gzip-compressed, audio-only frame."""
        header = generate_header(
            message_type=CLIENT_AUDIO_ONLY_REQUEST, serial_method=NO_SERIALIZATION
        )
        request = self._build_request(
            header, self._EVENT_TASK_REQUEST, audio, with_session=True
        )
        await self.ws.send(request)

    @staticmethod
    def _load_wav_frames(path: str, max_seconds: float) -> Optional[bytes]:
        """Read up to *max_seconds* of raw frames from a WAV file; return None on failure."""
        try:
            with wave.open(path, "rb") as wf:
                rate = wf.getframerate()
                # Derive the frame count from the real sample rate (the old
                # fixed 80000 frames was 5 s only at 16 kHz).
                frames_to_read = min(wf.getnframes(), int(rate * max_seconds))
                audio = wf.readframes(frames_to_read)
                print(f"读取真实音频数据: {len(audio)} 字节")
                print(f"音频参数: 采样率={rate}, 通道数={wf.getnchannels()}, 采样宽度={wf.getsampwidth()}")
        except (OSError, EOFError, wave.Error) as e:
            print(f"读取音频文件失败: {e}")
            return None
        return audio

    def _save_tts_wav(self, pcm: bytes, path: str) -> None:
        """Write collected TTS PCM to *path* as a WAV file and print its details."""
        rate = self._TTS_SAMPLE_RATE
        channels = self._TTS_CHANNELS
        sampwidth = self._TTS_SAMPLE_WIDTH
        with wave.open(path, "wb") as wav_file:
            wav_file.setnchannels(channels)
            wav_file.setsampwidth(sampwidth)
            wav_file.setframerate(rate)
            wav_file.writeframes(pcm)
        print(f"成功创建TTS WAV文件: {path}")
        print(f"音频参数: {rate}Hz, {channels}通道, {sampwidth*8}-bit")
        print(f"WAV文件大小: {os.path.getsize(path)} 字节")
        # Duration from the PCM length only; the file size includes the WAV header.
        duration = len(pcm) / (rate * channels * sampwidth)
        print(f"音频时长: {duration:.2f}")

    @staticmethod
    def _print_json_payload(payload: bytes, index: int) -> None:
        """Print a JSON payload and any recognised text.

        If the payload is not valid UTF-8 JSON, it is saved to
        ``tts_response_<index>.raw`` instead.
        """
        if not payload:
            return
        try:
            data = json.loads(payload.decode("utf-8"))
        except ValueError as e:  # covers UnicodeDecodeError and JSONDecodeError
            print(f"解析JSON失败: {e}")
            with open(f"tts_response_{index}.raw", "wb") as f:
                f.write(payload)
            return
        print(f"JSON数据: {data}")
        results = data.get("results") if isinstance(data, dict) else None
        if isinstance(results, list) and results and isinstance(results[0], dict):
            print(f"识别文本: {results[0].get('text', '')}")

    def _handle_frame(self, frame, index: int, tts_chunks: list) -> bool:
        """Log one server frame and collect TTS audio from it.

        The whole frame is also dumped to ``tts_response_full_<index>.raw``
        for debugging.

        Returns:
            True when reading should stop: a server error frame or the
            TTS-ended event.
        """
        raw_name = f"tts_response_full_{index}.raw"
        with open(raw_name, "wb") as f:
            f.write(frame.encode("utf-8") if isinstance(frame, str) else frame)
        print(f"完整响应已保存到 {raw_name}")

        info = self._parse_server_frame(frame)
        if info is None:
            return False
        print(
            f"响应协议: version={info['version']}, header_size={info['header_size']}, "
            f"message_type={info['message_type']}, flags={info['flags']}"
        )
        message_type = info["message_type"]
        if message_type == SERVER_ERROR_RESPONSE:
            detail = info["payload"].decode("utf-8", errors="replace")
            print(f"服务器错误: code={info['code']}, {detail}")
            return True
        if info["event"] is not None:
            print(f"Event: {info['event']}")

        if message_type == SERVER_ACK:
            # SERVER_ACK frames carry TTS audio chunks.
            audio = info["payload"]
            print(f"音频数据大小: {len(audio)}")
            if info["decompress_error"] is not None:
                print(f"TTS音频解压缩失败: {info['decompress_error']}")
                raw_audio = f"tts_response_audio_{index}.raw"
                with open(raw_audio, "wb") as f:
                    f.write(audio)
                print(f"原始TTS音频数据已保存到 {raw_audio}")
            elif audio:
                print("找到TTS音频数据")
                tts_chunks.append(audio)
        elif message_type == SERVER_FULL_RESPONSE:
            self._print_json_payload(info["payload"], index)
            if info["event"] == self._EVENT_TTS_ENDED:
                print("TTS响应结束")
                return True
        return False

    async def test_full_dialog(
        self,
        wav_path: str = "recording_20250920_135137.wav",
        max_seconds: float = 5.0,
        max_responses: int = 200,
    ) -> None:
        """Run one dialog turn: upload recorded speech, then collect ASR and TTS output.

        Args:
            wav_path: WAV file whose raw frames are sent without resampling.
            max_seconds: maximum amount of audio read from ``wav_path``.
            max_responses: maximum number of server frames read. TTS audio may
                be streamed over many frames, so the default is well above
                the original limit of 10.

        Frames are read until one of these happens:
          - the TTS-ended event arrives;
          - the server sends an error frame;
          - a read times out (15 s for the first frame, 30 s afterwards);
          - the connection closes;
          - ``max_responses`` frames have been read.

        All TTS audio chunks are concatenated into ``tts_response.wav``.
        Failures are printed rather than raised, except errors writing that
        file.
        """
        print("开始完整对话测试...")
        audio_data = self._load_wav_frames(wav_path, max_seconds)
        if audio_data is None:
            return
        print(f"音频数据大小: {len(audio_data)}")

        tts_chunks = []
        response_count = 0
        try:
            print("发送音频数据...")
            await self.task_request(audio_data)
            print("音频数据发送成功")
            print("等待语音识别响应...")
            while response_count < max_responses:
                timeout = 15.0 if response_count == 0 else 30.0
                print(f"等待第 {response_count + 1} 个响应...")
                try:
                    frame = await asyncio.wait_for(self.ws.recv(), timeout=timeout)
                except asyncio.TimeoutError:
                    print(f"等待第 {response_count + 1} 个响应超时")
                    break
                except websockets.exceptions.ConnectionClosed:
                    print("连接已关闭")
                    break
                print(f"收到响应 {response_count + 1},长度: {len(frame)}")
                should_stop = self._handle_frame(frame, response_count, tts_chunks)
                response_count += 1
                if should_stop:
                    break
            print(f"共收到 {response_count} 个响应")
        except websockets.exceptions.ConnectionClosed as e:
            print(f"连接关闭: {e}")
        except Exception as e:
            print(f"测试失败: {e}")
            import traceback
            traceback.print_exc()

        # Save whatever audio arrived, even if the loop ended early.
        if tts_chunks:
            self._save_tts_wav(b"".join(tts_chunks), "tts_response.wav")

    async def close(self) -> None:
        """Close the WebSocket if one is open.

        Close errors are printed, not raised. Cancellation still propagates,
        unlike with a bare ``except:``.
        """
        if self.ws is not None:
            try:
                await self.ws.close()
            except Exception as e:
                print(f"关闭连接时出错: {e}")
            finally:
                self.ws = None
        print("连接已关闭")
async def main():
    """Script entry point: connect, run the dialog test, and always close the client.

    Any exception is reported with a traceback instead of propagating.
    """
    client = DoubaoClient(DoubaoConfig())
    try:
        await client.connect()
        await client.test_full_dialog()
    except Exception as err:
        # Report, then fall through so the connection is still closed below.
        print(f"测试失败: {err}")
        import traceback
        traceback.print_exc()
    finally:
        await client.close()
if __name__ == "__main__":
    # Run the end-to-end test only when executed as a script, not on import.
    asyncio.run(main())