#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
豆包音频处理模块 - 基于原始代码的测试版本
直接使用原始豆包代码的核心逻辑
"""
import asyncio
import gzip
import json
import os
import struct
import time
import traceback
import uuid
import wave
import zlib
from typing import Dict, Any, Optional

import websockets
# Protocol constants copied verbatim from the original Doubao sample code.
# Each value is a 4-bit field; generate_header() packs two of them per byte.

# High nibble of header byte 0.
PROTOCOL_VERSION = 0b0001

# Message types (high nibble of header byte 1).
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# Message-type-specific flags (low nibble of header byte 1).
NO_SEQUENCE = 0b0000
POS_SEQUENCE = 0b0001
MSG_WITH_EVENT = 0b0100

# Serialization methods (high nibble of header byte 2).
NO_SERIALIZATION = 0b0000
JSON = 0b0001

# Compression types (low nibble of header byte 2).
GZIP = 0b0001
def generate_header(
    version=PROTOCOL_VERSION,
    message_type=CLIENT_FULL_REQUEST,
    message_type_specific_flags=MSG_WITH_EVENT,
    serial_method=JSON,
    compression_type=GZIP,
    reserved_data=0x00,
    extension_header=bytes()
):
    """Build the binary frame header (copied from the original Doubao code).

    The header is four fixed bytes followed by the optional extension:

    * byte 0: ``version`` (high nibble) | header size in 4-byte words (low nibble)
    * byte 1: ``message_type`` (high nibble) | ``message_type_specific_flags``
    * byte 2: ``serial_method`` (high nibble) | ``compression_type``
    * byte 3: ``reserved_data``

    Args:
        version: Protocol version nibble.
        message_type: Message type nibble.
        message_type_specific_flags: Flag nibble for the message type.
        serial_method: Payload serialization nibble.
        compression_type: Payload compression nibble.
        reserved_data: Value of the reserved fourth byte.
        extension_header: Extra header bytes appended after the fixed part.

    Returns:
        A ``bytearray`` holding the fixed header followed by
        ``extension_header``.

    Raises:
        ValueError: If any packed byte falls outside ``0..255``.
    """
    # Header size counts 4-byte words: one for the fixed part plus the
    # extension. Integer floor division avoids the float round trip.
    header_size = len(extension_header) // 4 + 1
    header = bytearray((
        (version << 4) | header_size,
        (message_type << 4) | message_type_specific_flags,
        (serial_method << 4) | compression_type,
        reserved_data,
    ))
    header.extend(extension_header)
    return header
class DoubaoConfig:
    """Connection settings for the Doubao realtime dialogue endpoint.

    The credentials can be overridden with the environment variables
    ``DOUBAO_APP_ID``, ``DOUBAO_ACCESS_KEY`` and ``DOUBAO_APP_KEY``; when a
    variable is not set, the value previously hard-coded here is used.
    """

    def __init__(self):
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        # SECURITY NOTE(review): the fallback literals below are credentials
        # committed to source control. They are kept only so existing callers
        # keep working; prefer the environment variables and rotate these keys.
        self.app_id = os.environ.get("DOUBAO_APP_ID", "8718217928")
        self.access_key = os.environ.get(
            "DOUBAO_ACCESS_KEY", "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
        )
        self.app_key = os.environ.get("DOUBAO_APP_KEY", "PlgvMymc7f3tQnJ6")
        self.resource_id = "volc.speech.dialog"

    def get_headers(self) -> Dict[str, str]:
        """Return the HTTP headers sent with the WebSocket handshake.

        A fresh random ``X-Api-Connect-Id`` is generated on every call.
        """
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
class DoubaoClient:
    """Doubao realtime-dialogue WebSocket client, based on the original sample.

    The client opens the socket, performs the StartConnection / StartSession
    handshake, sends one block of audio and dumps whatever the server answers
    to files in the current working directory.
    """

    def __init__(self, config: "DoubaoConfig"):
        """Store *config* and create a fresh random session id."""
        self.config = config
        self.session_id = str(uuid.uuid4())
        self.ws = None
        self.log_id = ""

    # ------------------------------------------------------------------
    # Connection handshake
    # ------------------------------------------------------------------

    async def connect(self) -> None:
        """Open the WebSocket and send StartConnection and StartSession.

        Raises:
            Exception: Any error from the connection or handshake is printed
                and re-raised unchanged.
        """
        print(f"连接豆包服务器: {self.config.base_url}")

        try:
            self.ws = await websockets.connect(
                self.config.base_url,
                additional_headers=self.config.get_headers(),
                ping_interval=None
            )

            self.log_id = self._extract_log_id(self.ws)
            print(f"连接成功, log_id: {self.log_id}")

            await self._send_start_connection()
            await self._send_start_session()

        except Exception as e:
            print(f"连接失败: {e}")
            raise

    @staticmethod
    def _extract_log_id(ws) -> str:
        """Return the ``X-Tt-Logid`` handshake response header, or ``""``.

        The header container lives in a different place depending on the
        websockets client implementation, so every known location is probed:
        ``ws.response.headers`` (the asyncio client that accepts
        ``additional_headers``), then ``ws.response_headers`` and
        ``ws.headers`` as the original code did.
        """
        candidates = (
            getattr(getattr(ws, "response", None), "headers", None),
            getattr(ws, "response_headers", None),
            getattr(ws, "headers", None),
        )
        for headers in candidates:
            if headers is None:
                continue
            log_id = headers.get("X-Tt-Logid")
            if log_id:
                return log_id
        return ""

    def _build_request(self, event: int, payload: bytes, *,
                       include_session: bool = True,
                       **header_overrides) -> bytearray:
        """Assemble one client frame.

        Layout: header, 4-byte big-endian event number, optionally the
        length-prefixed session id, then the length-prefixed gzip-compressed
        payload.

        Args:
            event: Event number written after the header.
            payload: Uncompressed payload bytes.
            include_session: Whether to write the session id section.
            **header_overrides: Keyword arguments forwarded to
                ``generate_header``.
        """
        frame = bytearray(generate_header(**header_overrides))
        frame.extend(event.to_bytes(4, 'big'))

        if include_session:
            # Length prefix must describe the encoded bytes, not the str.
            session_bytes = self.session_id.encode()
            frame.extend(len(session_bytes).to_bytes(4, 'big'))
            frame.extend(session_bytes)

        compressed = gzip.compress(payload)
        frame.extend(len(compressed).to_bytes(4, 'big'))
        frame.extend(compressed)
        return frame

    async def _send_start_connection(self) -> None:
        """Send the StartConnection request (event 1) and read one reply."""
        print("发送StartConnection请求...")
        request = self._build_request(1, b"{}", include_session=False)

        await self.ws.send(request)
        response = await self.ws.recv()
        print(f"StartConnection响应长度: {len(response)}")

    async def _send_start_session(self) -> None:
        """Send the StartSession request (event 100) and read one reply."""
        print("发送StartSession请求...")
        session_config = {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {
                    "channel": 1,
                    "format": "pcm",
                    "sample_rate": 24000
                },
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 30,
                    "input_mod": "audio",
                },
            },
        }

        request = self._build_request(100, json.dumps(session_config).encode())

        await self.ws.send(request)
        response = await self.ws.recv()
        print(f"StartSession响应长度: {len(response)}")

        # Give the server a moment to finish establishing the session.
        await asyncio.sleep(1.0)

    async def task_request(self, audio: bytes) -> None:
        """Send *audio* as one audio-only frame (event 200).

        The payload is gzip-compressed and the header marks it as carrying no
        serialization, as in the original Doubao ``task_request``.
        """
        task_request = self._build_request(
            200,
            audio,
            message_type=CLIENT_AUDIO_ONLY_REQUEST,
            serial_method=NO_SERIALIZATION,
        )
        await self.ws.send(task_request)

    # ------------------------------------------------------------------
    # Response parsing helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_server_frame(frame) -> Optional[Dict[str, Any]]:
        """Split a received frame into its header fields and the body.

        Returns ``None`` when *frame* is not binary or is shorter than the
        four fixed header bytes. Otherwise returns a dict with ``version``,
        ``header_size`` (in 4-byte words), ``message_type``, ``flags``,
        ``serialization``, ``compression`` and ``body`` (everything after the
        header).
        """
        if not isinstance(frame, (bytes, bytearray)) or len(frame) < 4:
            return None
        header_size = frame[0] & 0x0f
        return {
            "version": frame[0] >> 4,
            "header_size": header_size,
            "message_type": frame[1] >> 4,
            "flags": frame[1] & 0x0f,
            "serialization": frame[2] >> 4,
            "compression": frame[2] & 0x0f,
            "body": bytes(frame[header_size * 4:]),
        }

    @staticmethod
    def _split_event_body(body: bytes) -> tuple:
        """Parse ``event | session-id length | session id | size | payload``.

        Like the original parser, this assumes the body starts with an event
        number. Returns ``(event, session_id, payload_size, payload)``; fields
        that the body is too short to contain are ``None``.
        """
        event = session_id = payload_size = payload = None
        if len(body) >= 4:
            event = int.from_bytes(body[:4], 'big')
        if len(body) >= 8:
            session_end = 8 + int.from_bytes(body[4:8], 'big')
            if len(body) >= session_end:
                session_id = body[8:session_end].decode()
                if len(body) >= session_end + 4:
                    payload_size = int.from_bytes(
                        body[session_end:session_end + 4], 'big')
                    payload_start = session_end + 4
                    payload = body[payload_start:payload_start + payload_size]
        return event, session_id, payload_size, payload

    @staticmethod
    def _load_test_audio() -> bytes:
        """Read the recorded test audio, or return silence if unreadable."""
        # NOTE(review): 160000 frames is 10 s only at a 16000 Hz sample rate;
        # the file's actual rate is printed but not used for the limit.
        max_frames = 160000
        try:
            with wave.open("recording_20250920_135137.wav", 'rb') as wf:
                total_frames = wf.getnframes()
                frames_to_read = min(total_frames, max_frames)
                small_audio = wf.readframes(frames_to_read)
                print(f"读取真实音频数据: {len(small_audio)} 字节")
                print(f"音频参数: 采样率={wf.getframerate()}, 通道数={wf.getnchannels()}, 采样宽度={wf.getsampwidth()}")
                print(f"总帧数: {total_frames}, 读取帧数: {frames_to_read}")
                return small_audio
        except (OSError, EOFError, wave.Error) as e:
            print(f"读取音频文件失败: {e}")
            # Fall back to 3200 zero bytes when the recording is unusable.
            return b'\x00' * 3200

    def _process_full_response(self, frame: Dict[str, Any]) -> bool:
        """Handle a SERVER_FULL_RESPONSE body.

        Saves the raw payload to ``response_audio.raw`` and tries to read it
        as JSON, honouring the gzip flag from the frame header.

        Returns:
            ``True`` when the JSON payload contains ``asr_task_id``, meaning a
            follow-up audio response should be awaited.
        """
        event, session_id, payload_size, payload_data = self._split_event_body(
            frame["body"])
        if event is None:
            return False
        print(f"Event: {event}")
        if session_id is None:
            return False
        print(f"Session ID: {session_id}")
        if payload_data is None:
            return False
        print(f"Payload size: {payload_size}")
        if not payload_data:
            return False

        print(f"收到数据: {len(payload_data)} 字节")
        with open('response_audio.raw', 'wb') as f:
            f.write(payload_data)
        print("音频数据已保存到 response_audio.raw")

        try:
            decoded = payload_data
            if frame["compression"] == GZIP:
                # The header says the payload is gzip-compressed.
                decoded = gzip.decompress(payload_data)
            json_data = json.loads(decoded.decode('utf-8'))
        except (OSError, EOFError, ValueError, zlib.error) as json_e:
            # Not JSON; the raw bytes were already saved above.
            print(f"解析JSON失败: {json_e}")
            return False

        print(f"JSON数据: {json_data}")
        return isinstance(json_data, dict) and 'asr_task_id' in json_data

    def _process_first_response(self, response) -> bool:
        """Report the first server frame; return whether audio should follow."""
        frame = self._parse_server_frame(response)
        if frame is None:
            return False
        print(f"响应协议信息: version={frame['version']}, header_size={frame['header_size']}, message_type={frame['message_type']}, flags={frame['flags']}")

        message_type = frame["message_type"]
        if message_type == SERVER_FULL_RESPONSE:
            print("收到SERVER_FULL_RESPONSE!")
            return self._process_full_response(frame)
        if message_type == SERVER_ACK:
            print("收到SERVER_ACK响应!")
        elif message_type == SERVER_ERROR_RESPONSE:
            print("收到错误响应")
            body = frame["body"]
            if len(body) >= 4:
                # The error code is the first word after the header.
                error_code = int.from_bytes(body[:4], 'big')
                print(f"错误代码: {error_code}")
        return False

    @staticmethod
    def _save_tts_audio(audio_data: bytes) -> None:
        """Write the TTS payload to disk and try to turn it into a WAV file.

        The raw bytes always go to ``tts_response_audio.raw``. If they gunzip
        cleanly, the result is saved raw and as ``tts_response.wav``;
        otherwise the raw bytes are also written to
        ``tts_response_audio_original.raw``.
        """
        with open('tts_response_audio.raw', 'wb') as f:
            f.write(audio_data)
        print(f"TTS音频数据已保存到 tts_response_audio.raw")

        try:
            decompressed_audio = gzip.decompress(audio_data)
            print(f"解压缩后音频数据大小: {len(decompressed_audio)}")
            with open('tts_response_audio_decompressed.raw', 'wb') as f:
                f.write(decompressed_audio)
            print("解压缩的音频数据已保存")

            # WAV parameters match the session's TTS audio_config
            # (24000 Hz, mono); 16-bit samples as in the original code.
            sample_rate = 24000
            channels = 1
            sampwidth = 2

            with wave.open('tts_response.wav', 'wb') as wav_file:
                wav_file.setnchannels(channels)
                wav_file.setsampwidth(sampwidth)
                wav_file.setframerate(sample_rate)
                wav_file.writeframes(decompressed_audio)

            print("已创建WAV文件: tts_response.wav")
            print(f"音频参数: {sample_rate}Hz, {channels}通道, {sampwidth*8}-bit")

        except (OSError, EOFError, zlib.error, wave.Error) as audio_e:
            print(f"音频数据处理失败: {audio_e}")
            with open('tts_response_audio_original.raw', 'wb') as f:
                f.write(audio_data)

    def _process_audio_response(self, audio_response) -> None:
        """Report the follow-up frame and save any audio it carries."""
        frame = self._parse_server_frame(audio_response)
        if frame is None:
            return
        print(f"音频响应协议信息: version={frame['version']}, header_size={frame['header_size']}, message_type={frame['message_type']}, flags={frame['flags']}")

        if frame["message_type"] == SERVER_FULL_RESPONSE:
            audio_event, _session, audio_data_size, audio_data = (
                self._split_event_body(frame["body"]))
            if audio_data is None:
                return
            print(f"音频Event: {audio_event}")
            print(f"音频数据大小: {audio_data_size}")
            if audio_data_size > 0:
                self._save_tts_audio(audio_data)

        elif frame["message_type"] == SERVER_ACK:
            print("收到SERVER_ACK音频响应")
            # The whole post-header body is dumped unparsed.
            audio_payload = frame["body"]
            print(f"音频payload长度: {len(audio_payload)}")
            with open('tts_response_ack.raw', 'wb') as f:
                f.write(audio_payload)

    # ------------------------------------------------------------------
    # Test entry point
    # ------------------------------------------------------------------

    async def test_audio_request(self) -> None:
        """Send one block of audio and dump the server's reply to disk.

        Timeouts and closed connections are reported and swallowed; any other
        error while sending or receiving is printed and re-raised. Errors
        while interpreting a received frame are printed with a traceback and
        do not propagate.
        """
        print("测试音频请求...")

        small_audio = self._load_test_audio()
        print(f"音频数据大小: {len(small_audio)}")

        try:
            print(f"发送完整的音频数据块...")
            await self.task_request(small_audio)
            print(f"音频数据块发送成功")

            print("等待语音识别响应...")
            response = await asyncio.wait_for(self.ws.recv(), timeout=15.0)
            print(f"收到响应,长度: {len(response)}")

            try:
                expect_audio = self._process_first_response(response)
            except Exception as e:
                print(f"解析响应失败: {e}")
                traceback.print_exc()
                return

            if not expect_audio:
                return

            print("语音识别任务开始,继续等待音频响应...")
            try:
                # Network errors other than the timeout propagate to the
                # outer handlers instead of being reported as JSON failures.
                audio_response = await asyncio.wait_for(
                    self.ws.recv(), timeout=20.0)
            except asyncio.TimeoutError:
                print("等待音频响应超时")
                return
            print(f"收到音频响应,长度: {len(audio_response)}")

            try:
                self._process_audio_response(audio_response)
            except Exception as e:
                print(f"解析响应失败: {e}")
                traceback.print_exc()

        except asyncio.TimeoutError:
            print("等待响应超时")
        except websockets.exceptions.ConnectionClosed as e:
            print(f"连接关闭: {e}")
        except Exception as e:
            print(f"发送音频请求失败: {e}")
            raise

    async def close(self) -> None:
        """Close the WebSocket if one was opened; close errors are reported."""
        if self.ws:
            try:
                await self.ws.close()
            except Exception as e:
                # Best-effort shutdown: report the problem but do not raise.
                print(f"关闭连接时出错: {e}")
            print("连接已关闭")
async def main():
    """Run the manual test: connect, send the test audio, always close.

    Any exception from connecting or from the audio request is printed with
    its traceback instead of propagating.
    """
    config = DoubaoConfig()
    client = DoubaoClient(config)

    try:
        await client.connect()
        await client.test_audio_request()
    except Exception as e:
        print(f"测试失败: {e}")
        traceback.print_exc()
    finally:
        # Release the socket even when the test failed part-way through.
        await client.close()
# Run the manual test only when the file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())