# Local-Voice/doubao_simple.py — snapshot 2025-09-20 14:35:54 +08:00 (412 lines, 16 KiB)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
豆包音频处理模块 - 最终简化版本
实现音频文件上传和TTS音频下载的完整流程
"""
import asyncio
import gzip
import json
import uuid
import wave
import time
import os
from typing import Dict, Any, Optional
import websockets
# Protocol constants for the binary framing used by the realtime dialogue API.
# Each value is a 4-bit field packed two-per-byte into the frame header.
PROTOCOL_VERSION = 0b0001
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111
NO_SEQUENCE = 0b0000
MSG_WITH_EVENT = 0b0100
NO_SERIALIZATION = 0b0000
JSON = 0b0001
GZIP = 0b0001


def generate_header(
    version=PROTOCOL_VERSION,
    message_type=CLIENT_FULL_REQUEST,
    message_type_specific_flags=MSG_WITH_EVENT,
    serial_method=JSON,
    compression_type=GZIP,
    reserved_data=0x00,
    extension_header=bytes()
):
    """Build the 4-byte protocol header, followed by any extension bytes.

    Byte 0: protocol version (high nibble) | header size in 4-byte units
            (low nibble, extension included).
    Byte 1: message type | message-type-specific flags.
    Byte 2: serialization method | compression type.
    Byte 3: reserved.
    Returns a bytearray.
    """
    # Header size is counted in 4-byte units; the base header itself is one unit.
    size_units = len(extension_header) // 4 + 1
    packed = bytearray((
        (version << 4) | size_units,
        (message_type << 4) | message_type_specific_flags,
        (serial_method << 4) | compression_type,
        reserved_data,
    ))
    packed += extension_header
    return packed
class DoubaoClient:
    """Minimal client for the Doubao (Volcengine) realtime dialogue WebSocket API.

    Flow: connect() performs the StartConnection/StartSession handshake, then
    process_audio_file() uploads a WAV recording and collects the TTS reply
    into an output WAV file.
    """

    def __init__(self):
        # Realtime dialogue endpoint (speech in; ASR + dialog + TTS out).
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        # NOTE(review): credentials are hardcoded in source — they should be
        # loaded from environment variables or a config file, and these values
        # should be rotated since they have been committed.
        self.app_id = "8718217928"
        self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
        self.app_key = "PlgvMymc7f3tQnJ6"
        self.resource_id = "volc.speech.dialog"
        # One session id per client instance; sent in every session-scoped frame.
        self.session_id = str(uuid.uuid4())
        self.ws = None  # websockets connection, set by connect()
        self.log_id = ""  # server trace id ("X-Tt-Logid") from the handshake response

    def get_headers(self) -> Dict[str, str]:
        """Build the HTTP headers for the WebSocket handshake (auth + trace id)."""
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            # Fresh connect id per handshake attempt.
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }

    async def connect(self) -> None:
        """Open the WebSocket and run the StartConnection/StartSession handshake.

        Raises whatever websockets.connect or the handshake sends raise,
        after printing the error.
        """
        print(f"连接豆包服务器: {self.base_url}")
        try:
            self.ws = await websockets.connect(
                self.base_url,
                additional_headers=self.get_headers(),
                ping_interval=None  # disable client pings; server manages liveness
            )
            # Grab the server trace id; the attribute name differs across
            # websockets library versions (response_headers vs headers).
            # NOTE(review): .get() may return None here, leaving log_id = None
            # rather than "" — confirm callers tolerate that.
            if hasattr(self.ws, 'response_headers'):
                self.log_id = self.ws.response_headers.get("X-Tt-Logid")
            elif hasattr(self.ws, 'headers'):
                self.log_id = self.ws.headers.get("X-Tt-Logid")
            print(f"连接成功, log_id: {self.log_id}")
            # Protocol handshake: StartConnection first, then StartSession.
            await self._send_start_connection()
            await self._send_start_session()
        except Exception as e:
            print(f"连接失败: {e}")
            raise

    async def _send_start_connection(self) -> None:
        """Send the StartConnection frame (event 1) and wait for the server ack."""
        print("发送StartConnection请求...")
        request = bytearray(generate_header())
        request.extend(int(1).to_bytes(4, 'big'))  # event id 1 = StartConnection
        payload_bytes = b"{}"  # empty JSON payload, gzip-compressed per header flags
        payload_bytes = gzip.compress(payload_bytes)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)
        await self.ws.send(request)
        # The ack is only logged, not parsed.
        response = await self.ws.recv()
        print(f"StartConnection响应长度: {len(response)}")

    async def _send_start_session(self) -> None:
        """Send the StartSession frame (event 100) carrying ASR/TTS/dialog config."""
        print("发送StartSession请求...")
        session_config = {
            "asr": {
                "extra": {
                    # Silence window (ms) the ASR waits before finalizing a result.
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {
                    "channel": 1,
                    "format": "pcm",
                    "sample_rate": 24000
                },
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 30,
                    "input_mod": "audio",
                },
            },
        }
        # Frame layout: header | event(4) | session_id_len(4) | session_id | payload_len(4) | payload.
        request = bytearray(generate_header())
        request.extend(int(100).to_bytes(4, 'big'))  # event id 100 = StartSession
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())
        payload_bytes = json.dumps(session_config).encode()
        payload_bytes = gzip.compress(payload_bytes)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)
        await self.ws.send(request)
        response = await self.ws.recv()
        print(f"StartSession响应长度: {len(response)}")
        # Brief pause so the session is fully established server-side before audio.
        await asyncio.sleep(1.0)

    async def task_request(self, audio: bytes) -> None:
        """Send one chunk of raw audio (event 200) as an audio-only frame.

        The audio bytes are gzip-compressed; serialization is disabled since
        the payload is binary, not JSON.
        """
        task_request = bytearray(
            generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST,
                            serial_method=NO_SERIALIZATION))
        task_request.extend(int(200).to_bytes(4, 'big'))  # event id 200 = audio TaskRequest
        task_request.extend((len(self.session_id)).to_bytes(4, 'big'))
        task_request.extend(str.encode(self.session_id))
        payload_bytes = gzip.compress(audio)
        task_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        task_request.extend(payload_bytes)
        await self.ws.send(task_request)

    def parse_response(self, response: bytes) -> Optional[Dict[str, Any]]:
        """Decode one binary server frame into a dict of fields.

        Assumed big-endian layout: 4-byte header (size in 4-byte units in the
        low nibble of byte 0), then payload =
        event(4) | session_id_len(4) | session_id | data_size(4) | data.
        Optional fields ('event', 'session_id', 'data_size', 'data',
        'json_data') are added only when enough bytes are present.
        Returns None for frames shorter than a minimal header.
        NOTE(review): this assumes every message type shares the same payload
        layout — confirm against the protocol spec for ACK/error frames.
        """
        if len(response) < 4:
            return None
        protocol_version = response[0] >> 4
        header_size = response[0] & 0x0f  # in 4-byte units, extension included
        message_type = response[1] >> 4
        flags = response[1] & 0x0f
        payload_start = header_size * 4  # skip header + any extension bytes
        payload = response[payload_start:]
        result = {
            'protocol_version': protocol_version,
            'header_size': header_size,
            'message_type': message_type,
            'flags': flags,
            'payload': payload,
            'payload_size': len(payload)
        }
        # Progressively parse optional fields, guarding each offset by length.
        if len(payload) >= 4:
            result['event'] = int.from_bytes(payload[:4], 'big')
            if len(payload) >= 8:
                session_id_len = int.from_bytes(payload[4:8], 'big')
                if len(payload) >= 8 + session_id_len:
                    result['session_id'] = payload[8:8+session_id_len].decode()
                    if len(payload) >= 12 + session_id_len:
                        data_size = int.from_bytes(payload[8+session_id_len:12+session_id_len], 'big')
                        result['data_size'] = data_size
                        result['data'] = payload[12+session_id_len:12+session_id_len+data_size]
                        # Best-effort JSON decode: metadata frames are JSON,
                        # audio frames are raw bytes and will fail here.
                        try:
                            result['json_data'] = json.loads(result['data'].decode('utf-8'))
                        except:
                            pass
        return result

    async def process_audio_file(self, input_file: str, output_file: str) -> bool:
        """Upload a WAV recording and save the TTS reply as a 16-bit WAV.

        Reads all frames from input_file, sends them in one audio frame, then
        collects server responses until the TTS-finished event (359), a
        timeout, or max_responses. Collected audio is assumed to be 24 kHz
        mono 32-bit float PCM and is converted to 16-bit integer PCM.
        Returns True on success, False otherwise.
        """
        print(f"开始处理音频文件: {input_file}")
        try:
            # Read the input recording. The raw frames are sent as-is; the
            # server is expected to accept this recording's sample format.
            with wave.open(input_file, 'rb') as wf:
                audio_data = wf.readframes(wf.getnframes())
                print(f"读取音频数据: {len(audio_data)} 字节")
                print(f"音频参数: {wf.getframerate()}Hz, {wf.getnchannels()}通道, {wf.getsampwidth()*8}-bit")
            # Whole file in a single audio frame (no streaming/chunking).
            print("发送音频数据...")
            await self.task_request(audio_data)
            print("音频数据发送成功")
            # Receive loop: collect audio chunks until TTS end / timeout / cap.
            print("开始接收响应...")
            audio_chunks = []
            response_count = 0
            max_responses = 20  # hard cap to avoid waiting forever
            while response_count < max_responses:
                try:
                    response = await asyncio.wait_for(self.ws.recv(), timeout=30.0)
                    response_count += 1
                    parsed = self.parse_response(response)
                    if not parsed:
                        continue
                    print(f"响应 {response_count}: message_type={parsed['message_type']}, event={parsed.get('event', 'N/A')}, size={parsed['payload_size']}")
                    # Dispatch on message type / event id.
                    if parsed['message_type'] == 11:  # SERVER_ACK — may carry audio
                        if 'data' in parsed and parsed['data_size'] > 0:
                            audio_chunks.append(parsed['data'])
                            print(f"收集到音频块: {parsed['data_size']} 字节")
                    elif parsed['message_type'] == 9:  # SERVER_FULL_RESPONSE
                        event = parsed.get('event', 0)
                        if event == 350:  # TTS started
                            print("TTS音频生成开始")
                        elif event == 359:  # TTS finished — stop receiving
                            print("TTS音频生成结束")
                            break
                        elif event == 451:  # ASR result
                            if 'json_data' in parsed and 'results' in parsed['json_data']:
                                text = parsed['json_data']['results'][0].get('text', '')
                                print(f"语音识别结果: {text}")
                        elif event == 550:  # TTS audio data
                            if 'data' in parsed and parsed['data_size'] > 0:
                                # Distinguish JSON audio metadata from raw audio bytes.
                                try:
                                    json.loads(parsed['data'].decode('utf-8'))
                                    print("收到TTS音频元数据")
                                except:
                                    # Not JSON — treat as an audio chunk.
                                    audio_chunks.append(parsed['data'])
                                    print(f"收集到TTS音频块: {parsed['data_size']} 字节")
                except asyncio.TimeoutError:
                    print(f"等待响应 {response_count + 1} 超时")
                    break
                except websockets.exceptions.ConnectionClosed:
                    print("连接已关闭")
                    break
            print(f"共收到 {response_count} 个响应,收集到 {len(audio_chunks)} 个音频块")
            # Merge whatever audio arrived and write it out.
            if audio_chunks:
                combined_audio = b''.join(audio_chunks)
                print(f"合并后的音频数据: {len(combined_audio)} 字节")
                # Chunks may arrive gzip-compressed; fall back to raw on failure.
                try:
                    decompressed = gzip.decompress(combined_audio)
                    print(f"解压缩后音频数据: {len(decompressed)} 字节")
                    audio_to_write = decompressed
                except:
                    print("音频数据不是GZIP压缩格式直接使用原始数据")
                    audio_to_write = combined_audio
                try:
                    # Server audio is assumed to be 32-bit float PCM; convert
                    # to 16-bit integer PCM for the WAV container.
                    import struct
                    # Truncate to a multiple of 4 bytes (one float32 each).
                    if len(audio_to_write) % 4 != 0:
                        print(f"警告:音频数据长度 {len(audio_to_write)} 不是4的倍数截断到最近的倍数")
                        audio_to_write = audio_to_write[:len(audio_to_write) // 4 * 4]
                    float_count = len(audio_to_write) // 4
                    int16_data = bytearray(float_count * 2)
                    for i in range(float_count):
                        # Little-endian float32 sample.
                        float_value = struct.unpack('<f', audio_to_write[i*4:i*4+4])[0]
                        # Clamp to [-1.0, 1.0] before scaling.
                        float_value = max(-1.0, min(1.0, float_value))
                        int16_value = int(float_value * 32767)
                        # Little-endian int16 sample.
                        int16_data[i*2:i*2+2] = struct.pack('<h', int16_value)
                    # Write the converted samples: 24 kHz mono 16-bit, matching
                    # the audio_config requested in StartSession.
                    with wave.open(output_file, 'wb') as wav_file:
                        wav_file.setnchannels(1)
                        wav_file.setsampwidth(2)
                        wav_file.setframerate(24000)
                        wav_file.writeframes(int16_data)
                    print(f"成功创建输出文件: {output_file}")
                    print(f"音频转换: {float_count} 个32位浮点样本 -> {len(int16_data)//2} 个16位整数样本")
                    # Report size/duration (24000 Hz * 1 channel * 2 bytes/sample).
                    if os.path.exists(output_file):
                        file_size = os.path.getsize(output_file)
                        duration = file_size / (24000 * 1 * 2)
                        print(f"输出文件大小: {file_size} 字节,时长: {duration:.2f}秒")
                    return True
                except Exception as e:
                    print(f"创建WAV文件失败: {e}")
                    # Conversion failed — dump the raw bytes for offline inspection.
                    with open(output_file + '.raw', 'wb') as f:
                        f.write(audio_to_write)
                    print(f"原始音频数据已保存到: {output_file}.raw")
                    return False
            else:
                print("未收到音频数据")
                return False
        except Exception as e:
            print(f"处理音频文件失败: {e}")
            import traceback
            traceback.print_exc()
            return False

    async def close(self) -> None:
        """Close the WebSocket connection, ignoring errors during close."""
        if self.ws:
            try:
                await self.ws.close()
            except:
                pass
        print("连接已关闭")
async def main():
    """Script entry point: connect to Doubao and run one recording through it."""
    client = DoubaoClient()
    try:
        await client.connect()
        # One-shot pipeline: upload this recording, write the TTS reply here.
        input_file = "recording_20250920_135137.wav"
        output_file = "tts_output.wav"
        if await client.process_audio_file(input_file, output_file):
            print("音频处理成功!")
        else:
            print("音频处理失败")
    except Exception as e:
        print(f"程序失败: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Always release the connection, even on failure.
        await client.close()


if __name__ == "__main__":
    asyncio.run(main())