#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
豆包音频处理模块 - 最终简化版本
|
||
实现音频文件上传和TTS音频下载的完整流程
|
||
"""
|
||
|
||
import asyncio
|
||
import gzip
|
||
import json
|
||
import uuid
|
||
import wave
|
||
import time
|
||
import os
|
||
from typing import Dict, Any, Optional
|
||
import websockets
|
||
|
||
|
||
# Binary-protocol constants (all values are 4-bit nibbles packed two per byte).
PROTOCOL_VERSION = 0b0001
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

NO_SEQUENCE = 0b0000
MSG_WITH_EVENT = 0b0100

NO_SERIALIZATION = 0b0000
JSON = 0b0001
GZIP = 0b0001


def generate_header(
    version=PROTOCOL_VERSION,
    message_type=CLIENT_FULL_REQUEST,
    message_type_specific_flags=MSG_WITH_EVENT,
    serial_method=JSON,
    compression_type=GZIP,
    reserved_data=0x00,
    extension_header=bytes()
):
    """Build the fixed 4-byte protocol header plus optional extension bytes.

    Layout (one nibble per field unless noted):
      byte 0: protocol version | header size in 4-byte words
      byte 1: message type | message-type-specific flags
      byte 2: serialization method | compression type
      byte 3: reserved (full byte)

    Returns a ``bytearray`` so callers can keep extending the frame in place.
    """
    # Header size is measured in 4-byte words and includes the fixed 4 bytes.
    size_in_words = len(extension_header) // 4 + 1
    packed = bytearray((
        (version << 4) | size_in_words,
        (message_type << 4) | message_type_specific_flags,
        (serial_method << 4) | compression_type,
        reserved_data,
    ))
    packed += extension_header
    return packed
|
||
|
||
|
||
class DoubaoClient:
    """Client for the Doubao (ByteDance OpenSpeech) realtime dialogue service.

    Holds connection state (``ws``, ``log_id``) and the credentials used for
    the WebSocket handshake.  One ``session_id`` is generated per instance.
    """

    def __init__(self):
        # WebSocket endpoint of the realtime dialogue API.
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        # SECURITY: prefer credentials from the environment so secrets are not
        # baked into source control; fall back to the historical built-in
        # values for backward compatibility.
        self.app_id = os.environ.get("DOUBAO_APP_ID", "8718217928")
        self.access_key = os.environ.get(
            "DOUBAO_ACCESS_KEY", "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc")
        self.app_key = os.environ.get("DOUBAO_APP_KEY", "PlgvMymc7f3tQnJ6")
        self.resource_id = "volc.speech.dialog"
        # Fresh session id for every client instance.
        self.session_id = str(uuid.uuid4())
        self.ws = None      # set by connect()
        self.log_id = ""    # server-side trace id, filled in by connect()
|
||
def get_headers(self) -> Dict[str, str]:
|
||
"""获取请求头"""
|
||
return {
|
||
"X-Api-App-ID": self.app_id,
|
||
"X-Api-Access-Key": self.access_key,
|
||
"X-Api-Resource-Id": self.resource_id,
|
||
"X-Api-App-Key": self.app_key,
|
||
"X-Api-Connect-Id": str(uuid.uuid4()),
|
||
}
|
||
|
||
    async def connect(self) -> None:
        """Open the WebSocket connection and run the two-step handshake.

        On success ``self.ws`` is a live connection and ``self.log_id`` holds
        the server trace id.  Any failure is logged and re-raised.
        """
        print(f"连接豆包服务器: {self.base_url}")

        try:
            self.ws = await websockets.connect(
                self.base_url,
                additional_headers=self.get_headers(),
                ping_interval=None
            )

            # Extract the server-side trace id.  The attribute that exposes
            # response headers differs between websockets library versions,
            # hence the hasattr probing.
            # NOTE(review): .get() may return None here, overwriting the ""
            # default in __init__ with None — confirm callers tolerate that.
            if hasattr(self.ws, 'response_headers'):
                self.log_id = self.ws.response_headers.get("X-Tt-Logid")
            elif hasattr(self.ws, 'headers'):
                self.log_id = self.ws.headers.get("X-Tt-Logid")

            print(f"连接成功, log_id: {self.log_id}")

            # Handshake step 1: StartConnection.
            await self._send_start_connection()

            # Handshake step 2: StartSession.
            await self._send_start_session()

        except Exception as e:
            print(f"连接失败: {e}")
            raise
|
||
|
||
async def _send_start_connection(self) -> None:
|
||
"""发送StartConnection请求"""
|
||
print("发送StartConnection请求...")
|
||
request = bytearray(generate_header())
|
||
request.extend(int(1).to_bytes(4, 'big'))
|
||
|
||
payload_bytes = b"{}"
|
||
payload_bytes = gzip.compress(payload_bytes)
|
||
request.extend(len(payload_bytes).to_bytes(4, 'big'))
|
||
request.extend(payload_bytes)
|
||
|
||
await self.ws.send(request)
|
||
response = await self.ws.recv()
|
||
print(f"StartConnection响应长度: {len(response)}")
|
||
|
||
async def _send_start_session(self) -> None:
|
||
"""发送StartSession请求"""
|
||
print("发送StartSession请求...")
|
||
session_config = {
|
||
"asr": {
|
||
"extra": {
|
||
"end_smooth_window_ms": 1500,
|
||
},
|
||
},
|
||
"tts": {
|
||
"speaker": "zh_female_vv_jupiter_bigtts",
|
||
"audio_config": {
|
||
"channel": 1,
|
||
"format": "pcm",
|
||
"sample_rate": 24000
|
||
},
|
||
},
|
||
"dialog": {
|
||
"bot_name": "豆包",
|
||
"system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
|
||
"speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
|
||
"location": {"city": "北京"},
|
||
"extra": {
|
||
"strict_audit": False,
|
||
"audit_response": "支持客户自定义安全审核回复话术。",
|
||
"recv_timeout": 30,
|
||
"input_mod": "audio",
|
||
},
|
||
},
|
||
}
|
||
|
||
request = bytearray(generate_header())
|
||
request.extend(int(100).to_bytes(4, 'big'))
|
||
request.extend(len(self.session_id).to_bytes(4, 'big'))
|
||
request.extend(self.session_id.encode())
|
||
|
||
payload_bytes = json.dumps(session_config).encode()
|
||
payload_bytes = gzip.compress(payload_bytes)
|
||
request.extend(len(payload_bytes).to_bytes(4, 'big'))
|
||
request.extend(payload_bytes)
|
||
|
||
await self.ws.send(request)
|
||
response = await self.ws.recv()
|
||
print(f"StartSession响应长度: {len(response)}")
|
||
|
||
# 等待一会确保会话完全建立
|
||
await asyncio.sleep(1.0)
|
||
|
||
async def task_request(self, audio: bytes) -> None:
|
||
"""发送音频数据"""
|
||
task_request = bytearray(
|
||
generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST,
|
||
serial_method=NO_SERIALIZATION))
|
||
task_request.extend(int(200).to_bytes(4, 'big'))
|
||
task_request.extend((len(self.session_id)).to_bytes(4, 'big'))
|
||
task_request.extend(str.encode(self.session_id))
|
||
payload_bytes = gzip.compress(audio)
|
||
task_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
|
||
task_request.extend(payload_bytes)
|
||
await self.ws.send(task_request)
|
||
|
||
def parse_response(self, response):
|
||
"""解析响应"""
|
||
if len(response) < 4:
|
||
return None
|
||
|
||
protocol_version = response[0] >> 4
|
||
header_size = response[0] & 0x0f
|
||
message_type = response[1] >> 4
|
||
flags = response[1] & 0x0f
|
||
|
||
payload_start = header_size * 4
|
||
payload = response[payload_start:]
|
||
|
||
result = {
|
||
'protocol_version': protocol_version,
|
||
'header_size': header_size,
|
||
'message_type': message_type,
|
||
'flags': flags,
|
||
'payload': payload,
|
||
'payload_size': len(payload)
|
||
}
|
||
|
||
# 解析payload
|
||
if len(payload) >= 4:
|
||
result['event'] = int.from_bytes(payload[:4], 'big')
|
||
|
||
if len(payload) >= 8:
|
||
session_id_len = int.from_bytes(payload[4:8], 'big')
|
||
if len(payload) >= 8 + session_id_len:
|
||
result['session_id'] = payload[8:8+session_id_len].decode()
|
||
|
||
if len(payload) >= 12 + session_id_len:
|
||
data_size = int.from_bytes(payload[8+session_id_len:12+session_id_len], 'big')
|
||
result['data_size'] = data_size
|
||
result['data'] = payload[12+session_id_len:12+session_id_len+data_size]
|
||
|
||
# 尝试解析JSON数据
|
||
try:
|
||
result['json_data'] = json.loads(result['data'].decode('utf-8'))
|
||
except:
|
||
pass
|
||
|
||
return result
|
||
|
||
    async def process_audio_file(self, input_file: str, output_file: str) -> bool:
        """Upload a WAV file to the dialogue session and save the TTS reply.

        Reads ``input_file`` (PCM WAV), streams its frames to the server,
        collects audio chunks from the response stream, converts them from
        what the original author treats as 32-bit little-endian float PCM to
        16-bit PCM, and writes a mono 24 kHz WAV to ``output_file``.
        Returns True on success, False on any failure (errors are printed,
        never raised).
        """
        print(f"开始处理音频文件: {input_file}")

        try:
            # Read the raw PCM frames from the input WAV file.
            with wave.open(input_file, 'rb') as wf:
                audio_data = wf.readframes(wf.getnframes())
                print(f"读取音频数据: {len(audio_data)} 字节")
                print(f"音频参数: {wf.getframerate()}Hz, {wf.getnchannels()}通道, {wf.getsampwidth()*8}-bit")

            # Send the whole clip as a single audio-only request.
            print("发送音频数据...")
            await self.task_request(audio_data)
            print("音频数据发送成功")

            # Receive loop: collect audio chunks until the TTS-finished
            # event (359), a timeout, a closed connection, or the response
            # cap is hit.
            print("开始接收响应...")
            audio_chunks = []
            response_count = 0
            max_responses = 20

            while response_count < max_responses:
                try:
                    response = await asyncio.wait_for(self.ws.recv(), timeout=30.0)
                    response_count += 1

                    parsed = self.parse_response(response)
                    if not parsed:
                        continue

                    print(f"响应 {response_count}: message_type={parsed['message_type']}, event={parsed.get('event', 'N/A')}, size={parsed['payload_size']}")

                    # Dispatch on message type.
                    if parsed['message_type'] == 11:  # SERVER_ACK - may carry audio
                        if 'data' in parsed and parsed['data_size'] > 0:
                            audio_chunks.append(parsed['data'])
                            print(f"收集到音频块: {parsed['data_size']} 字节")

                    elif parsed['message_type'] == 9:  # SERVER_FULL_RESPONSE
                        event = parsed.get('event', 0)

                        if event == 350:  # TTS generation started
                            print("TTS音频生成开始")
                        elif event == 359:  # TTS generation finished -> stop receiving
                            print("TTS音频生成结束")
                            break
                        elif event == 451:  # ASR (speech recognition) result
                            if 'json_data' in parsed and 'results' in parsed['json_data']:
                                text = parsed['json_data']['results'][0].get('text', '')
                                print(f"语音识别结果: {text}")
                        elif event == 550:  # TTS audio data
                            if 'data' in parsed and parsed['data_size'] > 0:
                                # Distinguish JSON metadata from raw audio:
                                # if it decodes as JSON it is metadata and
                                # is not appended to the audio stream.
                                try:
                                    json.loads(parsed['data'].decode('utf-8'))
                                    print("收到TTS音频元数据")
                                except:
                                    # Not JSON — treat as raw audio bytes.
                                    audio_chunks.append(parsed['data'])
                                    print(f"收集到TTS音频块: {parsed['data_size']} 字节")

                except asyncio.TimeoutError:
                    print(f"等待响应 {response_count + 1} 超时")
                    break
                except websockets.exceptions.ConnectionClosed:
                    print("连接已关闭")
                    break

            print(f"共收到 {response_count} 个响应,收集到 {len(audio_chunks)} 个音频块")

            # Concatenate everything we collected.
            if audio_chunks:
                combined_audio = b''.join(audio_chunks)
                print(f"合并后的音频数据: {len(combined_audio)} 字节")

                # Some server variants gzip the audio; fall back to raw
                # bytes if decompression fails.
                try:
                    decompressed = gzip.decompress(combined_audio)
                    print(f"解压缩后音频数据: {len(decompressed)} 字节")
                    audio_to_write = decompressed
                except:
                    print("音频数据不是GZIP压缩格式,直接使用原始数据")
                    audio_to_write = combined_audio

                # Convert and write the output WAV.
                try:
                    # NOTE(review): the float32 interpretation below is the
                    # original author's assumption about the server format —
                    # confirm against the service docs.
                    import struct

                    # float32 samples are 4 bytes each; truncate any ragged tail.
                    if len(audio_to_write) % 4 != 0:
                        print(f"警告:音频数据长度 {len(audio_to_write)} 不是4的倍数,截断到最近的倍数")
                        audio_to_write = audio_to_write[:len(audio_to_write) // 4 * 4]

                    # Convert 32-bit little-endian floats to 16-bit PCM.
                    float_count = len(audio_to_write) // 4
                    int16_data = bytearray(float_count * 2)

                    for i in range(float_count):
                        # Read one little-endian float32 sample.
                        float_value = struct.unpack('<f', audio_to_write[i*4:i*4+4])[0]

                        # Clamp to [-1.0, 1.0] before scaling.
                        float_value = max(-1.0, min(1.0, float_value))

                        # Scale to the int16 range.
                        int16_value = int(float_value * 32767)

                        # Store as little-endian int16.
                        int16_data[i*2:i*2+2] = struct.pack('<h', int16_value)

                    # Write a mono, 16-bit, 24 kHz WAV file.
                    with wave.open(output_file, 'wb') as wav_file:
                        wav_file.setnchannels(1)
                        wav_file.setsampwidth(2)
                        wav_file.setframerate(24000)
                        wav_file.writeframes(int16_data)

                    print(f"成功创建输出文件: {output_file}")
                    print(f"音频转换: {float_count} 个32位浮点样本 -> {len(int16_data)//2} 个16位整数样本")

                    # Report file size and (approximate) duration.
                    # NOTE(review): duration uses the whole file size, which
                    # includes the WAV header, so it slightly overestimates.
                    if os.path.exists(output_file):
                        file_size = os.path.getsize(output_file)
                        duration = file_size / (24000 * 1 * 2)
                        print(f"输出文件大小: {file_size} 字节,时长: {duration:.2f} 秒")

                    return True

                except Exception as e:
                    print(f"创建WAV文件失败: {e}")
                    # Fall back to dumping the unconverted bytes for debugging.
                    with open(output_file + '.raw', 'wb') as f:
                        f.write(audio_to_write)
                    print(f"原始音频数据已保存到: {output_file}.raw")
                    return False
            else:
                print("未收到音频数据")
                return False

        except Exception as e:
            print(f"处理音频文件失败: {e}")
            import traceback
            traceback.print_exc()
            return False
|
||
|
||
async def close(self) -> None:
|
||
"""关闭连接"""
|
||
if self.ws:
|
||
try:
|
||
await self.ws.close()
|
||
except:
|
||
pass
|
||
print("连接已关闭")
|
||
|
||
|
||
async def main():
    """Entry point: connect, push one recording through the pipeline, clean up."""
    client = DoubaoClient()

    try:
        await client.connect()

        # Input recording and TTS output destination.
        input_file = "recording_20250920_135137.wav"
        output_file = "tts_output.wav"

        if await client.process_audio_file(input_file, output_file):
            print("音频处理成功!")
        else:
            print("音频处理失败")

    except Exception as e:
        print(f"程序失败: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Always tear down the connection, even after a failure.
        await client.close()
|
||
|
||
|
||
# Script entry point: run the async main() on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())