470 lines
17 KiB
Python
470 lines
17 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
豆包音频处理模块
|
||
简化版WebSocket API,支持音频文件上传和返回音频处理
|
||
"""
|
||
|
||
import asyncio
|
||
import gzip
|
||
import json
|
||
import uuid
|
||
import wave
|
||
import struct
|
||
import time
|
||
import os
|
||
from typing import Dict, Any, Optional
|
||
import websockets
|
||
|
||
|
||
class DoubaoConfig:
    """Connection settings and credentials for the Doubao dialogue endpoint.

    Every setting is resolved in this order:

    1. the explicit constructor argument, when it is not ``None``;
    2. the matching ``DOUBAO_*`` environment variable, when set and non-empty;
    3. the built-in default this module has always shipped with.

    Calling ``DoubaoConfig()`` with no arguments and no environment overrides
    therefore yields exactly the previous values.
    """

    # SECURITY NOTE(review): the credential defaults below are committed in
    # source. They are kept only so existing callers keep working; rotate them
    # and supply real values via constructor arguments or the environment.
    _SETTINGS = {
        "base_url": (
            "DOUBAO_BASE_URL",
            "wss://openspeech.bytedance.com/api/v3/realtime/dialogue",
        ),
        "app_id": ("DOUBAO_APP_ID", "8718217928"),
        "access_key": ("DOUBAO_ACCESS_KEY", "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"),
        "app_key": ("DOUBAO_APP_KEY", "PlgvMymc7f3tQnJ6"),
        "resource_id": ("DOUBAO_RESOURCE_ID", "volc.speech.dialog"),
    }

    def __init__(self,
                 base_url: Optional[str] = None,
                 app_id: Optional[str] = None,
                 access_key: Optional[str] = None,
                 app_key: Optional[str] = None,
                 resource_id: Optional[str] = None):
        """Resolve every setting from argument, environment, or default.

        Args:
            base_url: WebSocket endpoint URL.
            app_id: Value sent as the ``X-Api-App-ID`` header.
            access_key: Value sent as the ``X-Api-Access-Key`` header.
            app_key: Value sent as the ``X-Api-App-Key`` header.
            resource_id: Value sent as the ``X-Api-Resource-Id`` header.
        """
        self.base_url = self._resolve("base_url", base_url)
        self.app_id = self._resolve("app_id", app_id)
        self.access_key = self._resolve("access_key", access_key)
        self.app_key = self._resolve("app_key", app_key)
        self.resource_id = self._resolve("resource_id", resource_id)

    @classmethod
    def _resolve(cls, name: str, explicit: Optional[str]) -> str:
        """Return the effective value of setting *name*."""
        if explicit is not None:
            return explicit
        env_name, default = cls._SETTINGS[name]
        # An empty environment variable is treated the same as an unset one.
        return os.environ.get(env_name) or default

    def get_headers(self) -> Dict[str, str]:
        """Build the HTTP headers for the WebSocket handshake.

        Returns:
            A new dict on every call; ``X-Api-Connect-Id`` is a freshly
            generated UUID4 string each time.
        """
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
|
||
|
||
|
||
class DoubaoProtocol:
    """Encode request headers and decode server frames of the binary protocol.

    A frame starts with a 4-byte header made of 4-bit fields:

    * byte 0: protocol version | header size (in 4-byte units)
    * byte 1: message type | message-type-specific flags
    * byte 2: serialization method | compression method
    * byte 3: reserved
    """

    # Protocol version (high nibble of byte 0).
    PROTOCOL_VERSION = 0b0001

    # Message types (high nibble of byte 1).
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ACK = 0b1011
    SERVER_ERROR_RESPONSE = 0b1111

    # Message-type-specific flags (low nibble of byte 1).
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    MSG_WITH_EVENT = 0b0100

    # Serialization methods (high nibble of byte 2).
    JSON = 0b0001
    NO_SERIALIZATION = 0b0000

    # Compression methods (low nibble of byte 2).
    GZIP = 0b0001

    @classmethod
    def generate_header(cls, message_type=CLIENT_FULL_REQUEST,
                        message_type_specific_flags=MSG_WITH_EVENT,
                        serial_method=JSON, compression_type=GZIP) -> bytes:
        """Build the 4-byte frame header.

        Args:
            message_type: 4-bit message type.
            message_type_specific_flags: 4-bit flags for that message type.
            serial_method: 4-bit payload serialization method.
            compression_type: 4-bit payload compression method.

        Returns:
            Exactly four bytes; the header-size nibble is fixed at 1 and the
            last byte is the reserved zero byte.
        """
        return bytes((
            (cls.PROTOCOL_VERSION << 4) | 1,  # version + header size
            (message_type << 4) | message_type_specific_flags,
            (serial_method << 4) | compression_type,
            0x00,  # reserved
        ))

    @classmethod
    def _decode_payload(cls, raw, serialization_method: int, compression: int):
        """Undo compression and serialization as announced by the header."""
        data = bytes(raw)
        if compression == cls.GZIP:
            data = gzip.decompress(data)
        if serialization_method == cls.JSON:
            return json.loads(data.decode("utf-8"))
        return data

    @classmethod
    def parse_response(cls, res: bytes) -> Dict[str, Any]:
        """Decode one frame received from the server.

        Args:
            res: The raw WebSocket message.

        Returns:
            ``{}`` for text messages, for frames shorter than the 4-byte
            header, and for unrecognised message types. Otherwise a dict with:

            * full response / ACK frames: ``message_type``, optionally
              ``event``, ``session_id`` (decoded text), ``payload_size`` and,
              when the payload is non-empty, ``payload_msg`` (parsed JSON or
              raw ``bytes``, depending on the header);
            * error frames: ``code``, ``payload_size`` and, when present,
              ``payload_msg``.

        Raises:
            OSError, EOFError, ValueError: when the payload of a full response
                or ACK frame cannot be decompressed or parsed. Error frames
                never raise for this; their payload falls back to raw bytes.
        """
        if isinstance(res, str) or len(res) < 4:
            return {}

        header_size = res[0] & 0x0F
        message_type = res[1] >> 4
        flags = res[1] & 0x0F
        serialization_method = res[2] >> 4
        message_compression = res[2] & 0x0F

        payload = res[header_size * 4:]
        result: Dict[str, Any] = {}

        if message_type in (cls.SERVER_FULL_RESPONSE, cls.SERVER_ACK):
            if message_type == cls.SERVER_ACK:
                result['message_type'] = 'SERVER_ACK'
            else:
                result['message_type'] = 'SERVER_FULL_RESPONSE'

            # The event id is only present when the header flag announces it.
            if flags & cls.MSG_WITH_EVENT:
                result['event'] = int.from_bytes(payload[:4], "big", signed=False)
                payload = payload[4:]

            # Lengths are unsigned; a signed read could produce a negative
            # size and silently mis-slice the rest of the frame.
            session_id_size = int.from_bytes(payload[:4], "big", signed=False)
            session_id = bytes(payload[4:4 + session_id_size])
            # Decode rather than str(): str(b"abc") would give "b'abc'".
            result['session_id'] = session_id.decode("utf-8", errors="replace")
            payload = payload[4 + session_id_size:]

            result['payload_size'] = int.from_bytes(payload[:4], "big", signed=False)
            payload_msg = payload[4:]
            if payload_msg:
                result['payload_msg'] = cls._decode_payload(
                    payload_msg, serialization_method, message_compression)

        elif message_type == cls.SERVER_ERROR_RESPONSE:
            result['code'] = int.from_bytes(payload[:4], "big", signed=False)
            result['payload_size'] = int.from_bytes(payload[4:8], "big", signed=False)
            payload_msg = payload[8:]
            if payload_msg:
                import zlib  # only needed to name the decompression error

                # Best effort: an undecodable message must not hide the code.
                try:
                    result['payload_msg'] = cls._decode_payload(
                        payload_msg, serialization_method, message_compression)
                except (OSError, EOFError, ValueError, zlib.error):
                    result['payload_msg'] = bytes(payload_msg)

        return result
|
||
|
||
|
||
class AudioProcessor:
    """Small helpers for reading and writing WAV files via ``wave``."""

    @staticmethod
    def read_wav_file(file_path: str) -> tuple:
        """Load every frame of a WAV file together with its format.

        Args:
            file_path: Path of the WAV file to read.

        Returns:
            A ``(audio_data, info)`` pair: the raw frame bytes and a dict with
            the keys ``channels``, ``sampwidth``, ``framerate`` and
            ``nframes``.
        """
        reader = wave.open(file_path, 'rb')
        with reader:
            info = {
                'channels': reader.getnchannels(),
                'sampwidth': reader.getsampwidth(),
                'framerate': reader.getframerate(),
                'nframes': reader.getnframes(),
            }
            frames = reader.readframes(info['nframes'])
        return frames, info

    @staticmethod
    def create_wav_file(audio_data: bytes, output_path: str,
                        sample_rate: int = 24000, channels: int = 1,
                        sampwidth: int = 2) -> None:
        """Wrap raw PCM bytes in a WAV container and write it to disk.

        Args:
            audio_data: Raw frame bytes to store.
            output_path: Destination file path.
            sample_rate: Frames per second recorded in the WAV header.
            channels: Channel count recorded in the WAV header.
            sampwidth: Bytes per sample recorded in the WAV header.
        """
        writer = wave.open(output_path, 'wb')
        with writer:
            # The frame count of 0 is a placeholder; ``wave`` fills in the
            # real value from the amount of data actually written.
            writer.setparams(
                (channels, sampwidth, sample_rate, 0, 'NONE', 'not compressed'))
            writer.writeframes(audio_data)
|
||
|
||
|
||
class DoubaoClient:
    """WebSocket client that drives one Doubao realtime dialogue session.

    Typical use is ``await connect()``, ``await send_audio_file(path)`` and
    finally ``await close()``. Frames are built and parsed with
    ``DoubaoProtocol``.
    """

    # Event ids written into outgoing frames.
    _EVENT_START_CONNECTION = 1
    _EVENT_FINISH_CONNECTION = 2
    _EVENT_START_SESSION = 100
    _EVENT_FINISH_SESSION = 102
    _EVENT_AUDIO_CHUNK = 200

    # Incoming events that mean the session is over.
    _SESSION_END_EVENTS = (152, 153)
    # NOTE(review): 359 is understood to be the vendor's "TTS ended" event,
    # i.e. the last frame of a spoken reply - confirm against the event table.
    _TTS_ENDED_EVENT = 359

    # Seconds to wait for a reply after each audio chunk.
    _CHUNK_RESPONSE_TIMEOUT = 2.0
    # Seconds to wait for the FinishConnection reply while shutting down.
    _FINISH_RESPONSE_TIMEOUT = 5.0

    def __init__(self, config: DoubaoConfig):
        """Store the configuration and allocate a fresh session id.

        Args:
            config: Endpoint URL and credentials used by ``connect``.
        """
        self.config = config
        self.session_id = str(uuid.uuid4())
        self.ws = None  # set by connect(), cleared by close()
        self.log_id = ""

    async def connect(self) -> None:
        """Open the WebSocket, then send StartConnection and StartSession."""
        print(f"连接豆包服务器: {self.config.base_url}")

        self.ws = await websockets.connect(
            self.config.base_url,
            additional_headers=self.config.get_headers(),
            ping_interval=None,
        )

        self.log_id = self._read_log_id()
        print(f"连接成功, log_id: {self.log_id}")

        await self._send_start_connection()
        await self._send_start_session()

    def _read_log_id(self) -> str:
        """Return the ``X-Tt-Logid`` handshake response header, or ``""``.

        Different ``websockets`` versions expose the handshake response
        headers under different attributes, so each known location is tried.
        """
        candidates = (
            getattr(self.ws, 'response_headers', None),
            getattr(getattr(self.ws, 'response', None), 'headers', None),
            getattr(self.ws, 'headers', None),
        )
        for headers in candidates:
            if headers is None:
                continue
            log_id = headers.get("X-Tt-Logid")
            if log_id:
                return log_id
        return ""

    def _build_request(self, event: int, payload: bytes, *,
                       with_session: bool,
                       header: Optional[bytes] = None) -> bytearray:
        """Assemble one frame: header, event id, optional session id, payload.

        Args:
            event: Event id written right after the header.
            payload: Uncompressed payload; it is gzip-compressed here.
            with_session: Whether to include the length-prefixed session id.
            header: Pre-built header; defaults to ``generate_header()``.
        """
        if header is None:
            header = DoubaoProtocol.generate_header()
        request = bytearray(header)
        request.extend(event.to_bytes(4, 'big'))

        if with_session:
            # Length must be that of the encoded bytes, not of the str.
            session_bytes = self.session_id.encode()
            request.extend(len(session_bytes).to_bytes(4, 'big'))
            request.extend(session_bytes)

        body = gzip.compress(payload)
        request.extend(len(body).to_bytes(4, 'big'))
        request.extend(body)
        return request

    async def _send_start_connection(self) -> None:
        """Send StartConnection and print the server's reply."""
        request = self._build_request(
            self._EVENT_START_CONNECTION, b"{}", with_session=False)
        await self.ws.send(request)

        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"StartConnection响应: {parsed_response}")

    async def _send_start_session(self) -> None:
        """Send StartSession with the session settings and print the reply."""
        session_config = {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {
                    "channel": 1,
                    "format": "pcm",
                    "sample_rate": 24000
                },
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 10,
                    "input_mod": "audio_file",  # audio-file input mode
                },
            },
        }

        request = self._build_request(
            self._EVENT_START_SESSION,
            json.dumps(session_config).encode(),
            with_session=True,
        )
        await self.ws.send(request)

        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"StartSession响应: {parsed_response}")

    def _collect_response(self, parsed: Dict[str, Any],
                          received: bytearray) -> bool:
        """Store reply audio from *parsed* and report whether the session ended.

        Args:
            parsed: Result of ``DoubaoProtocol.parse_response``.
            received: Buffer that collected audio bytes are appended to.

        Returns:
            ``True`` when the frame carries a session-end event.
        """
        payload = parsed.get('payload_msg')
        if parsed.get('message_type') == 'SERVER_ACK' and isinstance(payload, bytes):
            received.extend(payload)
            print(f"接收到音频数据块,大小: {len(payload)} 字节")

        event = parsed.get('event')
        if event == self._TTS_ENDED_EVENT or event in self._SESSION_END_EVENTS:
            print(f"会话事件: {event}")
        if event in self._SESSION_END_EVENTS:
            print("检测到会话结束事件")
            return True
        return False

    async def _drain_responses(self, received: bytearray,
                               idle_timeout: float) -> None:
        """Keep reading replies after the upload until the answer is complete.

        Reading stops on a session-end event, on the TTS-ended event, on an
        error frame, or when nothing arrives for *idle_timeout* seconds.
        """
        while True:
            try:
                response = await asyncio.wait_for(
                    self.ws.recv(), timeout=idle_timeout)
            except asyncio.TimeoutError:
                print("等待剩余响应超时,停止接收")
                return

            parsed = DoubaoProtocol.parse_response(response)
            if parsed.get('code', 0) != 0:
                print(f"服务器返回错误: {parsed}")
                return
            if self._collect_response(parsed, received):
                return
            if parsed.get('event') == self._TTS_ENDED_EVENT:
                return

    async def send_audio_file(self, file_path: str,
                              drain_timeout: float = 10.0) -> bytes:
        """Upload a WAV file in ~200 ms chunks and collect the reply audio.

        After each chunk one reply is awaited for up to two seconds. Once the
        whole file has been sent, replies keep being read until the answer is
        complete, because reply audio can arrive after the upload has
        finished.

        Args:
            file_path: Path of the WAV file to upload.
            drain_timeout: Maximum idle time in seconds while waiting for
                further replies after the upload. ``0`` or less disables the
                post-upload read.

        Returns:
            All audio bytes received from the server, possibly empty.

        Raises:
            RuntimeError: when more than three error frames are received
                during the upload.
        """
        print(f"处理音频文件: {file_path}")

        audio_data, audio_info = AudioProcessor.read_wav_file(file_path)
        print(f"音频参数: {audio_info}")

        # 200 ms of audio per chunk, rounded down to whole frames and never
        # below one frame so the range() step below cannot be zero.
        frame_bytes = audio_info['channels'] * audio_info['sampwidth']
        chunk_size = int(audio_info['framerate'] * frame_bytes * 0.2)
        chunk_size = max(chunk_size - chunk_size % frame_bytes, frame_bytes)

        total_bytes = len(audio_data)
        total_chunks = (total_bytes + chunk_size - 1) // chunk_size
        print(f"开始分块发送音频,共 {total_chunks} 块")

        received = bytearray()  # amortised O(1) appends, unlike bytes +=
        error_count = 0
        session_ended = False

        for offset in range(0, total_bytes, chunk_size):
            chunk = audio_data[offset:offset + chunk_size]
            is_last = offset + chunk_size >= total_bytes
            await self._send_audio_chunk(chunk, is_last)

            try:
                response = await asyncio.wait_for(
                    self.ws.recv(), timeout=self._CHUNK_RESPONSE_TIMEOUT)
            except asyncio.TimeoutError:
                print("等待响应超时,继续发送")
            else:
                parsed_response = DoubaoProtocol.parse_response(response)
                if parsed_response.get('code', 0) != 0:
                    print(f"服务器返回错误: {parsed_response}")
                    error_count += 1
                    if error_count > 3:
                        raise RuntimeError(f"服务器连续返回错误: {parsed_response}")
                elif self._collect_response(parsed_response, received):
                    session_ended = True
                    break

            # Pace the upload a little, as a live stream would.
            await asyncio.sleep(0.05)

        print("音频文件发送完成")

        if not session_ended and drain_timeout > 0:
            await self._drain_responses(received, drain_timeout)

        return bytes(received)

    async def _send_audio_chunk(self, audio_data: bytes, is_last: bool = False) -> None:
        """Send one gzip-compressed chunk of raw audio.

        Args:
            audio_data: Raw audio bytes for this chunk.
            is_last: Accepted for interface compatibility; it does not affect
                the frame that is sent.
        """
        header = DoubaoProtocol.generate_header(
            message_type=DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST,
            # The frame body starts with an event id, so the header must say
            # so, exactly like every other request this client sends.
            message_type_specific_flags=DoubaoProtocol.MSG_WITH_EVENT,
            serial_method=DoubaoProtocol.NO_SERIALIZATION,  # raw audio
            compression_type=DoubaoProtocol.GZIP,
        )
        request = self._build_request(
            self._EVENT_AUDIO_CHUNK, audio_data, with_session=True, header=header)
        await self.ws.send(request)

    async def close(self) -> None:
        """Finish the session and connection, then close the WebSocket.

        Protocol-level failures are reported and do not prevent the socket
        from being closed. Calling this without an open connection is a no-op.
        """
        if not self.ws:
            return

        try:
            await self._send_finish_session()
            await self._send_finish_connection()
        except Exception as e:
            print(f"关闭会话时出错: {e}")
        finally:
            # Always close the socket. Only ordinary errors are suppressed so
            # that task cancellation still propagates.
            try:
                await self.ws.close()
            except Exception as e:
                print(f"关闭WebSocket时出错: {e}")
            self.ws = None
            print("连接已关闭")

    async def _send_finish_session(self) -> None:
        """Send FinishSession; the reply is not awaited here."""
        request = self._build_request(
            self._EVENT_FINISH_SESSION, b"{}", with_session=True)
        await self.ws.send(request)

    async def _send_finish_connection(self) -> None:
        """Send FinishConnection and print the next frame from the server.

        Raises:
            asyncio.TimeoutError: when no frame arrives in time, so shutdown
                cannot hang forever on an unresponsive server.
        """
        request = self._build_request(
            self._EVENT_FINISH_CONNECTION, b"{}", with_session=False)
        await self.ws.send(request)

        response = await asyncio.wait_for(
            self.ws.recv(), timeout=self._FINISH_RESPONSE_TIMEOUT)
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"FinishConnection响应: {parsed_response}")
|
||
|
||
|
||
class DoubaoProcessor:
    """High-level entry point: send one audio file, save the spoken reply."""

    def __init__(self):
        """Create a default configuration and a client bound to it."""
        self.config = DoubaoConfig()
        self.client = DoubaoClient(self.config)

    async def process_audio_file(self, input_file: str, output_file: str = None) -> str:
        """Send *input_file* to the service and store the reply as a WAV file.

        Args:
            input_file: Path of the audio file to upload.
            output_file: Destination path; when ``None`` a timestamped name of
                the form ``doubao_response_YYYYmmdd_HHMMSS.wav`` is used.

        Returns:
            The output path. It is returned even when no audio was received,
            in which case nothing is written to that path.

        Raises:
            FileNotFoundError: when *input_file* does not exist.
            Exception: any failure while talking to the service is printed
                with its traceback and re-raised; the client is closed either
                way.
        """
        if not os.path.exists(input_file):
            raise FileNotFoundError(f"音频文件不存在: {input_file}")

        # Fall back to a timestamped file name in the working directory.
        if output_file is None:
            stamp = time.strftime("%Y%m%d_%H%M%S")
            output_file = f"doubao_response_{stamp}.wav"

        try:
            await self.client.connect()

            # Short pause before streaming so the session can settle.
            await asyncio.sleep(0.5)

            reply_pcm = await self.client.send_audio_file(input_file)

            if not reply_pcm:
                print("警告: 未接收到音频响应")
                return output_file

            print(f"总共接收到音频数据: {len(reply_pcm)} 字节")
            # Reply is stored as 24 kHz, mono, 16-bit PCM in a WAV container.
            AudioProcessor.create_wav_file(
                reply_pcm,
                output_file,
                sample_rate=24000,
                channels=1,
                sampwidth=2,
            )
            print(f"响应音频已保存到: {output_file}")

            size_on_disk = os.path.getsize(output_file)
            print(f"输出文件大小: {size_on_disk} 字节")
            return output_file

        except Exception as exc:
            print(f"处理音频文件时出错: {exc}")
            import traceback
            traceback.print_exc()
            raise
        finally:
            await self.client.close()
|
||
|
||
|
||
async def main():
    """Command-line driver: process ``--input`` and report the result.

    Failures are printed rather than propagated.
    """
    import argparse

    cli = argparse.ArgumentParser(description="豆包音频处理测试")
    cli.add_argument("--input", type=str, required=True, help="输入音频文件路径")
    cli.add_argument("--output", type=str, help="输出音频文件路径")
    options = cli.parse_args()

    processor = DoubaoProcessor()
    try:
        result_path = await processor.process_audio_file(options.input, options.output)
        print(f"处理完成,输出文件: {result_path}")
    except Exception as exc:
        print(f"处理失败: {exc}")
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run the async command-line driver to completion.
    asyncio.run(main())