config
This commit is contained in:
parent
43879961a2
commit
dbdeeeefcb
470
doubao.py
470
doubao.py
@ -1,470 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
豆包音频处理模块
|
||||
简化版WebSocket API,支持音频文件上传和返回音频处理
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import gzip
|
||||
import json
|
||||
import uuid
|
||||
import wave
|
||||
import struct
|
||||
import time
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
import websockets
|
||||
|
||||
|
||||
class DoubaoConfig:
    """Connection settings for the Doubao realtime dialogue service.

    SECURITY FIX: the original hard-coded the credentials in source.  They
    are now read from environment variables, falling back to the previous
    literals so existing callers keep working unchanged.  NOTE(review):
    these keys were committed in plain text and should be rotated.
    """

    def __init__(self):
        # WebSocket endpoint of the realtime dialogue API.
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        # Credentials: prefer environment variables over the baked-in values.
        self.app_id = os.environ.get("DOUBAO_APP_ID", "8718217928")
        self.access_key = os.environ.get(
            "DOUBAO_ACCESS_KEY", "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc")
        self.app_key = os.environ.get("DOUBAO_APP_KEY", "PlgvMymc7f3tQnJ6")
        self.resource_id = "volc.speech.dialog"

    def get_headers(self) -> Dict[str, str]:
        """Build the handshake headers required by the WebSocket API.

        A fresh connect id (UUID4) is generated on every call.
        """
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
|
||||
|
||||
|
||||
class DoubaoProtocol:
    """Binary framing for the Doubao realtime dialogue protocol.

    Frame layout: a 4-byte header (version/size, type/flags,
    serialization/compression, reserved), then an optional 4-byte event id,
    a length-prefixed session id and a length-prefixed payload.
    """

    # Protocol constants (all 4-bit fields).
    PROTOCOL_VERSION = 0b0001
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ACK = 0b1011
    SERVER_ERROR_RESPONSE = 0b1111

    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    MSG_WITH_EVENT = 0b0100
    JSON = 0b0001
    NO_SERIALIZATION = 0b0000
    GZIP = 0b0001

    @classmethod
    def generate_header(cls, message_type=CLIENT_FULL_REQUEST,
                        message_type_specific_flags=MSG_WITH_EVENT,
                        serial_method=JSON, compression_type=GZIP) -> bytes:
        """Build the fixed 4-byte protocol header."""
        header = bytearray()
        # version + header_size (header size is counted in 4-byte units).
        header.append((cls.PROTOCOL_VERSION << 4) | 1)
        header.append((message_type << 4) | message_type_specific_flags)
        header.append((serial_method << 4) | compression_type)
        header.append(0x00)  # reserved
        return bytes(header)

    @classmethod
    def parse_response(cls, res: bytes) -> Dict[str, Any]:
        """Parse a server frame into a dict; text frames yield ``{}``.

        Fixes vs. the original: the session id is decoded to text instead
        of ``str(bytes)`` (which produced ``"b'...'"``), and the payload is
        sliced to the declared payload size instead of taking the whole
        remainder of the frame.
        """
        if isinstance(res, str):
            return {}

        header_size = res[0] & 0x0f
        message_type = res[1] >> 4
        message_type_specific_flags = res[1] & 0x0f
        serialization_method = res[2] >> 4
        message_compression = res[2] & 0x0f

        payload = res[header_size * 4:]
        result = {}

        if message_type in (cls.SERVER_FULL_RESPONSE, cls.SERVER_ACK):
            result['message_type'] = ('SERVER_ACK' if message_type == cls.SERVER_ACK
                                      else 'SERVER_FULL_RESPONSE')

            start = 0
            if message_type_specific_flags & cls.MSG_WITH_EVENT:
                result['event'] = int.from_bytes(payload[:4], "big", signed=False)
                start += 4

            payload = payload[start:]
            session_id_size = int.from_bytes(payload[:4], "big", signed=True)
            # BUG FIX: str(bytes) rendered "b'...'"; decode to text instead.
            result['session_id'] = payload[4:session_id_size + 4].decode("utf-8", "replace")
            payload = payload[4 + session_id_size:]

            payload_size = int.from_bytes(payload[:4], "big", signed=False)
            result['payload_size'] = payload_size
            # BUG FIX: honour the declared size; fall back to the remainder
            # when the server declares zero (original behaviour).
            payload_msg = payload[4:4 + payload_size] if payload_size else payload[4:]

            if payload_msg:
                if message_compression == cls.GZIP:
                    payload_msg = gzip.decompress(payload_msg)
                if serialization_method == cls.JSON:
                    payload_msg = json.loads(str(payload_msg, "utf-8"))
                result['payload_msg'] = payload_msg

        elif message_type == cls.SERVER_ERROR_RESPONSE:
            result['code'] = int.from_bytes(payload[:4], "big", signed=False)
            result['payload_size'] = int.from_bytes(payload[4:8], "big", signed=False)

        return result
|
||||
|
||||
|
||||
class AudioProcessor:
    """Helpers for reading and writing WAV files."""

    @staticmethod
    def read_wav_file(file_path: str) -> tuple:
        """Read a WAV file; return ``(raw frames, parameter dict)``."""
        with wave.open(file_path, 'rb') as handle:
            params = {
                'channels': handle.getnchannels(),
                'sampwidth': handle.getsampwidth(),
                'framerate': handle.getframerate(),
                'nframes': handle.getnframes(),
            }
            frames = handle.readframes(params['nframes'])
        return frames, params

    @staticmethod
    def create_wav_file(audio_data: bytes, output_path: str,
                        sample_rate: int = 24000, channels: int = 1,
                        sampwidth: int = 2) -> None:
        """Wrap raw PCM frames in a WAV container (Raspberry-Pi friendly defaults)."""
        with wave.open(output_path, 'wb') as handle:
            handle.setnchannels(channels)
            handle.setsampwidth(sampwidth)
            handle.setframerate(sample_rate)
            handle.writeframes(audio_data)
|
||||
|
||||
|
||||
class DoubaoClient:
    """Async WebSocket client for the Doubao realtime dialogue service.

    Lifecycle: ``await connect()`` (handshake + session setup), then
    ``await send_audio_file(path)`` to stream audio up and collect the TTS
    reply, then ``await close()``.
    """

    def __init__(self, config: DoubaoConfig):
        self.config = config
        # One random session id per client instance.
        self.session_id = str(uuid.uuid4())
        self.ws = None    # live websockets connection, set by connect()
        self.log_id = ""  # server-side trace id from the handshake response

    async def connect(self) -> None:
        """Open the WebSocket and perform StartConnection + StartSession."""
        print(f"连接豆包服务器: {self.config.base_url}")

        self.ws = await websockets.connect(
            self.config.base_url,
            additional_headers=self.config.get_headers(),
            ping_interval=None
        )

        # The attribute exposing response headers differs between
        # websockets versions, so probe both spellings.
        if hasattr(self.ws, 'response_headers'):
            self.log_id = self.ws.response_headers.get("X-Tt-Logid")
        elif hasattr(self.ws, 'headers'):
            self.log_id = self.ws.headers.get("X-Tt-Logid")

        print(f"连接成功, log_id: {self.log_id}")

        await self._send_start_connection()
        await self._send_start_session()

    async def _send_start_connection(self) -> None:
        """Send the StartConnection frame (event 1) and log the reply."""
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(1).to_bytes(4, 'big'))  # event 1 = StartConnection

        payload_bytes = gzip.compress(b"{}")  # empty JSON payload
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)

        await self.ws.send(request)
        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"StartConnection响应: {parsed_response}")

    async def _send_start_session(self) -> None:
        """Send the StartSession frame (event 100) carrying the session config."""
        session_config = {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {
                    "channel": 1,
                    "format": "pcm",
                    "sample_rate": 24000
                },
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 10,
                    "input_mod": "audio_file",  # audio-file input mode
                },
            },
        }

        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(100).to_bytes(4, 'big'))  # event 100 = StartSession
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())

        payload_bytes = gzip.compress(json.dumps(session_config).encode())
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)

        await self.ws.send(request)
        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"StartSession响应: {parsed_response}")

    async def send_audio_file(self, file_path: str) -> bytes:
        """Stream a WAV file to the server in ~200 ms chunks.

        Returns the concatenated raw audio received back from the server.
        """
        print(f"处理音频文件: {file_path}")

        audio_data, audio_info = AudioProcessor.read_wav_file(file_path)
        print(f"音频参数: {audio_info}")

        # Bytes per 200 ms of audio.
        chunk_size = int(audio_info['framerate'] * audio_info['channels'] *
                         audio_info['sampwidth'] * 0.2)

        total_chunks = (len(audio_data) + chunk_size - 1) // chunk_size
        print(f"开始分块发送音频,共 {total_chunks} 块")

        received_audio = b""
        error_count = 0

        for i in range(0, len(audio_data), chunk_size):
            chunk = audio_data[i:i + chunk_size]
            is_last = (i + chunk_size >= len(audio_data))

            # NOTE(review): is_last is currently ignored by
            # _send_audio_chunk; confirm whether the protocol requires a
            # distinct flag on the final chunk.
            await self._send_audio_chunk(chunk, is_last)

            try:
                response = await asyncio.wait_for(self.ws.recv(), timeout=2.0)
                parsed_response = DoubaoProtocol.parse_response(response)

                # Abort after several server errors in a row.
                if 'code' in parsed_response and parsed_response['code'] != 0:
                    print(f"服务器返回错误: {parsed_response}")
                    error_count += 1
                    if error_count > 3:
                        raise Exception(f"服务器连续返回错误: {parsed_response}")
                    continue

                # SERVER_ACK frames with a bytes payload carry TTS audio.
                if (parsed_response.get('message_type') == 'SERVER_ACK' and
                        isinstance(parsed_response.get('payload_msg'), bytes)):
                    audio_chunk = parsed_response['payload_msg']
                    received_audio += audio_chunk
                    print(f"接收到音频数据块,大小: {len(audio_chunk)} 字节")

                # Session-state events; 152/153 mean the session ended.
                event = parsed_response.get('event')
                if event in [359, 152, 153]:
                    print(f"会话事件: {event}")
                    if event in [152, 153]:
                        print("检测到会话结束事件")
                        break

            except asyncio.TimeoutError:
                print("等待响应超时,继续发送")

            # Pace the upload roughly like a live microphone stream.
            await asyncio.sleep(0.05)

        print("音频文件发送完成")
        return received_audio

    async def _send_audio_chunk(self, audio_data: bytes, is_last: bool = False) -> None:
        """Send one gzip-compressed audio-only frame (event 200)."""
        request = bytearray(
            DoubaoProtocol.generate_header(
                message_type=DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST,
                message_type_specific_flags=DoubaoProtocol.NO_SEQUENCE,
                serial_method=DoubaoProtocol.NO_SERIALIZATION,  # raw audio, no serialization
                compression_type=DoubaoProtocol.GZIP
            )
        )
        request.extend(int(200).to_bytes(4, 'big'))  # event 200 = audio data
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())

        compressed_audio = gzip.compress(audio_data)
        request.extend(len(compressed_audio).to_bytes(4, 'big'))  # payload size (4 bytes)
        request.extend(compressed_audio)

        await self.ws.send(request)

    async def close(self) -> None:
        """Finish the session/connection and close the WebSocket."""
        if self.ws:
            try:
                await self._send_finish_session()
                await self._send_finish_connection()
            except Exception as e:
                print(f"关闭会话时出错: {e}")
            finally:
                # BUG FIX: this was a bare `except:`, which also swallowed
                # KeyboardInterrupt/CancelledError; only ignore ordinary errors.
                try:
                    await self.ws.close()
                except Exception:
                    pass
                print("连接已关闭")

    async def _send_finish_session(self) -> None:
        """Send the FinishSession frame (event 102); no reply is awaited."""
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(102).to_bytes(4, 'big'))
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())

        payload_bytes = gzip.compress(b"{}")
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)

        await self.ws.send(request)

    async def _send_finish_connection(self) -> None:
        """Send the FinishConnection frame (event 2) and log the reply."""
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(2).to_bytes(4, 'big'))

        payload_bytes = gzip.compress(b"{}")
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)

        await self.ws.send(request)
        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"FinishConnection响应: {parsed_response}")
|
||||
|
||||
|
||||
class DoubaoProcessor:
    """High-level entry point: send a WAV file to Doubao, save the TTS reply."""

    def __init__(self):
        # Config and client are created eagerly; no network I/O happens here.
        self.config = DoubaoConfig()
        self.client = DoubaoClient(self.config)

    async def process_audio_file(self, input_file: str, output_file: str = None) -> str:
        """Process one audio file end to end.

        Args:
            input_file: path of the input audio file (WAV).
            output_file: path for the output audio file; auto-generated
                from a timestamp when None.

        Returns:
            The output audio file path (returned even if no audio was received).

        Raises:
            FileNotFoundError: if *input_file* does not exist.
        """
        if not os.path.exists(input_file):
            raise FileNotFoundError(f"音频文件不存在: {input_file}")

        # Generate a timestamped output name when none was given.
        if output_file is None:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            output_file = f"doubao_response_{timestamp}.wav"

        try:
            # Connect to the Doubao server (handshake + session setup).
            await self.client.connect()

            # Short pause so the session is fully established server-side.
            await asyncio.sleep(0.5)

            # Stream the file up and collect the TTS audio coming back.
            received_audio = await self.client.send_audio_file(input_file)

            if received_audio:
                print(f"总共接收到音频数据: {len(received_audio)} 字节")
                # Wrap the raw PCM in a WAV container (playable on a Raspberry Pi).
                AudioProcessor.create_wav_file(
                    received_audio,
                    output_file,
                    sample_rate=24000,  # sample rate of Doubao's returned audio
                    channels=1,
                    sampwidth=2  # 16-bit
                )
                print(f"响应音频已保存到: {output_file}")

                # Report the resulting file size.
                file_size = os.path.getsize(output_file)
                print(f"输出文件大小: {file_size} 字节")

            else:
                print("警告: 未接收到音频响应")

            return output_file

        except Exception as e:
            # Log with a traceback, then re-raise for the caller to handle.
            print(f"处理音频文件时出错: {e}")
            import traceback
            traceback.print_exc()
            raise
        finally:
            # Always tear the connection down, even on failure.
            await self.client.close()
|
||||
|
||||
|
||||
async def main():
    """CLI entry point for a quick end-to-end test."""
    import argparse

    parser = argparse.ArgumentParser(description="豆包音频处理测试")
    parser.add_argument("--input", type=str, required=True, help="输入音频文件路径")
    parser.add_argument("--output", type=str, help="输出音频文件路径")
    args = parser.parse_args()

    processor = DoubaoProcessor()
    try:
        result_path = await processor.process_audio_file(args.input, args.output)
        print(f"处理完成,输出文件: {result_path}")
    except Exception as e:
        print(f"处理失败: {e}")


if __name__ == "__main__":
    asyncio.run(main())
|
||||
540
doubao_debug.py
540
doubao_debug.py
@ -1,540 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
豆包音频处理模块 - 调试版本
|
||||
添加更多调试信息和错误处理
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import gzip
|
||||
import json
|
||||
import uuid
|
||||
import wave
|
||||
import struct
|
||||
import time
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
import websockets
|
||||
|
||||
|
||||
class DoubaoConfig:
    """Connection settings for the Doubao realtime dialogue service.

    SECURITY FIX: the original hard-coded the credentials in source.  They
    are now read from environment variables, falling back to the previous
    literals so existing callers keep working unchanged.  NOTE(review):
    these keys were committed in plain text and should be rotated.
    """

    def __init__(self):
        # WebSocket endpoint of the realtime dialogue API.
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        # Credentials: prefer environment variables over the baked-in values.
        self.app_id = os.environ.get("DOUBAO_APP_ID", "8718217928")
        self.access_key = os.environ.get(
            "DOUBAO_ACCESS_KEY", "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc")
        self.app_key = os.environ.get("DOUBAO_APP_KEY", "PlgvMymc7f3tQnJ6")
        self.resource_id = "volc.speech.dialog"

    def get_headers(self) -> Dict[str, str]:
        """Build the handshake headers required by the WebSocket API.

        A fresh connect id (UUID4) is generated on every call.
        """
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
|
||||
|
||||
|
||||
class DoubaoProtocol:
    """Binary framing for the Doubao realtime dialogue protocol (debug build).

    Same wire format as the release module, plus defensive length checks:
    malformed frames never raise — the problem is reported in the returned
    dict under ``error`` / ``decompress_error`` / ``json_error`` /
    ``parse_error`` keys.
    """

    # Protocol constants (all 4-bit fields).
    PROTOCOL_VERSION = 0b0001
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ACK = 0b1011
    SERVER_ERROR_RESPONSE = 0b1111

    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    MSG_WITH_EVENT = 0b0100
    JSON = 0b0001
    NO_SERIALIZATION = 0b0000
    GZIP = 0b0001

    @classmethod
    def generate_header(cls, message_type=CLIENT_FULL_REQUEST,
                        message_type_specific_flags=MSG_WITH_EVENT,
                        serial_method=JSON, compression_type=GZIP) -> bytes:
        """Build the fixed 4-byte protocol header."""
        header = bytearray()
        # version + header_size (header size is counted in 4-byte units).
        header.append((cls.PROTOCOL_VERSION << 4) | 1)
        header.append((message_type << 4) | message_type_specific_flags)
        header.append((serial_method << 4) | compression_type)
        header.append(0x00)  # reserved
        return bytes(header)

    @classmethod
    def parse_response(cls, res: bytes) -> Dict[str, Any]:
        """Parse a server frame into a dict; text frames yield ``{}``."""
        if isinstance(res, str):
            return {}

        # BUG FIX: `result` must exist before the try block — previously a
        # frame shorter than 3 bytes raised NameError inside the except
        # handler because `result` was first assigned mid-try.
        result = {}
        try:
            header_size = res[0] & 0x0f
            message_type = res[1] >> 4
            message_type_specific_flags = res[1] & 0x0f
            serialization_method = res[2] >> 4
            message_compression = res[2] & 0x0f

            payload = res[header_size * 4:]

            if message_type == cls.SERVER_FULL_RESPONSE or message_type == cls.SERVER_ACK:
                result['message_type'] = 'SERVER_FULL_RESPONSE'
                if message_type == cls.SERVER_ACK:
                    result['message_type'] = 'SERVER_ACK'

                start = 0
                if message_type_specific_flags & cls.MSG_WITH_EVENT:
                    result['event'] = int.from_bytes(payload[:4], "big", signed=False)
                    start += 4

                payload = payload[start:]
                if len(payload) < 4:
                    result['error'] = 'Payload too short for session_id'
                    return result

                session_id_size = int.from_bytes(payload[:4], "big", signed=True)
                if session_id_size < 0 or session_id_size > len(payload) - 4:
                    result['error'] = f'Invalid session_id size: {session_id_size}'
                    return result

                # BUG FIX: str(bytes) rendered "b'...'"; decode to text instead.
                result['session_id'] = payload[4:session_id_size + 4].decode("utf-8", "replace")
                payload = payload[4 + session_id_size:]

                if len(payload) < 4:
                    result['error'] = 'Payload too short for payload_size'
                    return result

                payload_size = int.from_bytes(payload[:4], "big", signed=False)
                result['payload_size'] = payload_size

                if len(payload) >= 4 + payload_size:
                    payload_msg = payload[4:4 + payload_size]
                    if payload_msg:
                        if message_compression == cls.GZIP:
                            try:
                                payload_msg = gzip.decompress(payload_msg)
                            except Exception as e:
                                result['decompress_error'] = str(e)
                                return result

                        if serialization_method == cls.JSON:
                            try:
                                payload_msg = json.loads(str(payload_msg, "utf-8"))
                            except Exception as e:
                                # Keep the raw text when JSON decoding fails.
                                result['json_error'] = str(e)
                                payload_msg = str(payload_msg, "utf-8")
                        elif serialization_method != cls.NO_SERIALIZATION:
                            payload_msg = str(payload_msg, "utf-8")
                        result['payload_msg'] = payload_msg

            elif message_type == cls.SERVER_ERROR_RESPONSE:
                if len(payload) >= 8:
                    code = int.from_bytes(payload[:4], "big", signed=False)
                    result['code'] = code
                    payload_size = int.from_bytes(payload[4:8], "big", signed=False)
                    result['payload_size'] = payload_size

                    if len(payload) >= 8 + payload_size:
                        payload_msg = payload[8:8 + payload_size]
                        if payload_msg and message_compression == cls.GZIP:
                            # Best-effort decompress only. BUG FIX: was a bare
                            # `except:`, which also swallowed KeyboardInterrupt.
                            try:
                                payload_msg = gzip.decompress(payload_msg)
                            except Exception:
                                pass
                        result['payload_msg'] = payload_msg

        except Exception as e:
            result['parse_error'] = str(e)

        return result
|
||||
|
||||
|
||||
class AudioProcessor:
    """Helpers for reading and writing WAV files."""

    @staticmethod
    def read_wav_file(file_path: str) -> tuple:
        """Read a WAV file; return ``(raw frames, parameter dict)``."""
        with wave.open(file_path, 'rb') as handle:
            params = {
                'channels': handle.getnchannels(),
                'sampwidth': handle.getsampwidth(),
                'framerate': handle.getframerate(),
                'nframes': handle.getnframes(),
            }
            frames = handle.readframes(params['nframes'])
        return frames, params

    @staticmethod
    def create_wav_file(audio_data: bytes, output_path: str,
                        sample_rate: int = 24000, channels: int = 1,
                        sampwidth: int = 2) -> None:
        """Wrap raw PCM frames in a WAV container (Raspberry-Pi friendly defaults)."""
        with wave.open(output_path, 'wb') as handle:
            handle.setnchannels(channels)
            handle.setsampwidth(sampwidth)
            handle.setframerate(sample_rate)
            handle.writeframes(audio_data)
|
||||
|
||||
|
||||
class DoubaoClient:
    """Async WebSocket client for the Doubao realtime dialogue service.

    Debug build: verbose per-frame logging, longer timeouts, smaller audio
    chunks, and handshake replies are checked for parse errors.
    """

    def __init__(self, config: DoubaoConfig):
        self.config = config
        # One random session id per client instance.
        self.session_id = str(uuid.uuid4())
        self.ws = None    # live websockets connection, set by connect()
        self.log_id = ""  # server-side trace id from the handshake response

    async def connect(self) -> None:
        """Open the WebSocket and perform StartConnection + StartSession."""
        print(f"连接豆包服务器: {self.config.base_url}")

        try:
            self.ws = await websockets.connect(
                self.config.base_url,
                additional_headers=self.config.get_headers(),
                ping_interval=None
            )

            # The attribute exposing response headers differs between
            # websockets versions, so probe both spellings.
            if hasattr(self.ws, 'response_headers'):
                self.log_id = self.ws.response_headers.get("X-Tt-Logid")
            elif hasattr(self.ws, 'headers'):
                self.log_id = self.ws.headers.get("X-Tt-Logid")

            print(f"连接成功, log_id: {self.log_id}")

            await self._send_start_connection()
            await self._send_start_session()

        except Exception as e:
            print(f"连接失败: {e}")
            raise

    async def _send_start_connection(self) -> None:
        """Send the StartConnection frame (event 1) and verify the reply."""
        print("发送StartConnection请求...")
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(1).to_bytes(4, 'big'))  # event 1 = StartConnection

        payload_bytes = gzip.compress(b"{}")  # empty JSON payload
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)

        await self.ws.send(request)
        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"StartConnection响应: {parsed_response}")

        # Fail fast on malformed handshake replies.
        if 'error' in parsed_response:
            raise Exception(f"StartConnection解析错误: {parsed_response['error']}")

    async def _send_start_session(self) -> None:
        """Send the StartSession frame (event 100) carrying the session config."""
        print("发送StartSession请求...")
        session_config = {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {
                    "channel": 1,
                    "format": "pcm",
                    "sample_rate": 24000
                },
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 30,  # longer timeout for debugging
                    "input_mod": "audio_file",  # audio-file input mode
                },
            },
        }

        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(100).to_bytes(4, 'big'))  # event 100 = StartSession
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())

        payload_bytes = gzip.compress(json.dumps(session_config).encode())
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)

        await self.ws.send(request)
        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"StartSession响应: {parsed_response}")

        if 'error' in parsed_response:
            raise Exception(f"StartSession解析错误: {parsed_response['error']}")

        # Give the server a moment to fully establish the session.
        await asyncio.sleep(1.0)

    async def send_audio_file(self, file_path: str) -> bytes:
        """Stream a WAV file in ~50 ms chunks; return the audio received back."""
        print(f"处理音频文件: {file_path}")

        audio_data, audio_info = AudioProcessor.read_wav_file(file_path)
        print(f"音频参数: {audio_info}")

        # Bytes per 50 ms — smaller chunks than the release build to avoid
        # oversized frames.
        chunk_size = int(audio_info['framerate'] * audio_info['channels'] *
                         audio_info['sampwidth'] * 0.05)  # 50ms

        total_chunks = (len(audio_data) + chunk_size - 1) // chunk_size
        print(f"开始分块发送音频,共 {total_chunks} 块")

        received_audio = b""
        error_count = 0
        session_active = True

        for i in range(0, len(audio_data), chunk_size):
            if not session_active:
                print("会话已结束,停止发送")
                break

            chunk = audio_data[i:i + chunk_size]
            is_last = (i + chunk_size >= len(audio_data))

            # NOTE(review): is_last is currently ignored by
            # _send_audio_chunk; confirm whether the protocol requires a
            # distinct flag on the final chunk.
            await self._send_audio_chunk(chunk, is_last)

            try:
                response = await asyncio.wait_for(self.ws.recv(), timeout=3.0)
                parsed_response = DoubaoProtocol.parse_response(response)

                print(f"响应 {i//chunk_size + 1}/{total_chunks}: {parsed_response}")

                # Abort after several server errors in a row.
                if 'code' in parsed_response and parsed_response['code'] != 0:
                    print(f"服务器返回错误: {parsed_response}")
                    error_count += 1
                    if error_count > 3:
                        raise Exception(f"服务器连续返回错误: {parsed_response}")
                    continue

                # SERVER_ACK frames with a bytes payload carry TTS audio.
                if (parsed_response.get('message_type') == 'SERVER_ACK' and
                        isinstance(parsed_response.get('payload_msg'), bytes)):
                    audio_chunk = parsed_response['payload_msg']
                    received_audio += audio_chunk
                    print(f"接收到音频数据块,大小: {len(audio_chunk)} 字节")

                # Session-state events; 152/153 mean the session ended.
                event = parsed_response.get('event')
                if event in [359, 152, 153]:
                    print(f"会话事件: {event}")
                    if event in [152, 153]:
                        print("检测到会话结束事件")
                        session_active = False
                        break

            except asyncio.TimeoutError:
                print("等待响应超时,继续发送")

            # Pace the upload roughly like a live microphone stream.
            await asyncio.sleep(0.1)

        print("音频文件发送完成")
        return received_audio

    async def _send_audio_chunk(self, audio_data: bytes, is_last: bool = False) -> None:
        """Send one gzip-compressed audio-only frame (event 200)."""
        request = bytearray(
            DoubaoProtocol.generate_header(
                message_type=DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST,
                message_type_specific_flags=DoubaoProtocol.NO_SEQUENCE,
                serial_method=DoubaoProtocol.NO_SERIALIZATION,  # raw audio, no serialization
                compression_type=DoubaoProtocol.GZIP
            )
        )
        request.extend(int(200).to_bytes(4, 'big'))  # event 200 = audio data
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())

        compressed_audio = gzip.compress(audio_data)
        payload_size = len(compressed_audio)
        request.extend(payload_size.to_bytes(4, 'big'))  # payload size (4 bytes)
        request.extend(compressed_audio)

        print(f"发送音频块 - 原始大小: {len(audio_data)}, 压缩后大小: {payload_size}, 总请求数据大小: {len(request)}")

        await self.ws.send(request)

    async def close(self) -> None:
        """Finish the session/connection and close the WebSocket."""
        if self.ws:
            try:
                await self._send_finish_session()
                await self._send_finish_connection()
            except Exception as e:
                print(f"关闭会话时出错: {e}")
            finally:
                # BUG FIX: this was a bare `except:`, which also swallowed
                # KeyboardInterrupt/CancelledError; only ignore ordinary errors.
                try:
                    await self.ws.close()
                except Exception:
                    pass
                print("连接已关闭")

    async def _send_finish_session(self) -> None:
        """Send the FinishSession frame (event 102); no reply is awaited."""
        print("发送FinishSession请求...")
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(102).to_bytes(4, 'big'))
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())

        payload_bytes = gzip.compress(b"{}")
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)

        await self.ws.send(request)

    async def _send_finish_connection(self) -> None:
        """Send the FinishConnection frame (event 2); wait briefly for a reply."""
        print("发送FinishConnection请求...")
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(2).to_bytes(4, 'big'))

        payload_bytes = gzip.compress(b"{}")
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)

        await self.ws.send(request)

        try:
            response = await asyncio.wait_for(self.ws.recv(), timeout=5.0)
            parsed_response = DoubaoProtocol.parse_response(response)
            print(f"FinishConnection响应: {parsed_response}")
        except asyncio.TimeoutError:
            print("FinishConnection响应超时")
|
||||
|
||||
|
||||
class DoubaoProcessor:
    """High-level entry point: send a WAV file to Doubao, save the TTS reply."""

    def __init__(self):
        # Config and client are created eagerly; no network I/O happens here.
        self.config = DoubaoConfig()
        self.client = DoubaoClient(self.config)

    async def process_audio_file(self, input_file: str, output_file: str = None) -> str:
        """Process one audio file end to end.

        Args:
            input_file: path of the input audio file (WAV).
            output_file: path for the output audio file; auto-generated
                from a timestamp when None.

        Returns:
            The output audio file path (returned even if no audio was received).

        Raises:
            FileNotFoundError: if *input_file* does not exist.
        """
        if not os.path.exists(input_file):
            raise FileNotFoundError(f"音频文件不存在: {input_file}")

        # Generate a timestamped output name when none was given.
        if output_file is None:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            output_file = f"doubao_response_{timestamp}.wav"

        try:
            # Connect to the Doubao server (handshake + session setup).
            await self.client.connect()

            # Stream the file up and collect the TTS audio coming back.
            received_audio = await self.client.send_audio_file(input_file)

            if received_audio:
                print(f"总共接收到音频数据: {len(received_audio)} 字节")
                # Wrap the raw PCM in a WAV container (playable on a Raspberry Pi).
                AudioProcessor.create_wav_file(
                    received_audio,
                    output_file,
                    sample_rate=24000,  # sample rate of Doubao's returned audio
                    channels=1,
                    sampwidth=2  # 16-bit
                )
                print(f"响应音频已保存到: {output_file}")

                # Report the resulting file size.
                file_size = os.path.getsize(output_file)
                print(f"输出文件大小: {file_size} 字节")

            else:
                print("警告: 未接收到音频响应")

            return output_file

        except Exception as e:
            # Log with a traceback, then re-raise for the caller to handle.
            print(f"处理音频文件时出错: {e}")
            import traceback
            traceback.print_exc()
            raise
        finally:
            # Always tear the connection down, even on failure.
            await self.client.close()
|
||||
|
||||
|
||||
async def main():
    """CLI entry point for a quick end-to-end test."""
    import argparse

    parser = argparse.ArgumentParser(description="豆包音频处理测试")
    parser.add_argument("--input", type=str, required=True, help="输入音频文件路径")
    parser.add_argument("--output", type=str, help="输出音频文件路径")
    args = parser.parse_args()

    processor = DoubaoProcessor()
    try:
        result_path = await processor.process_audio_file(args.input, args.output)
        print(f"处理完成,输出文件: {result_path}")
    except Exception as e:
        print(f"处理失败: {e}")


if __name__ == "__main__":
    asyncio.run(main())
|
||||
@ -1,425 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
豆包音频处理模块 - 简化测试版本
|
||||
专门测试完整的音频上传和TTS音频下载流程
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import gzip
|
||||
import json
|
||||
import uuid
|
||||
import wave
|
||||
import struct
|
||||
import time
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
import websockets
|
||||
|
||||
|
||||
# 直接复制原始豆包代码的协议常量
|
||||
PROTOCOL_VERSION = 0b0001
|
||||
CLIENT_FULL_REQUEST = 0b0001
|
||||
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
|
||||
SERVER_FULL_RESPONSE = 0b1001
|
||||
SERVER_ACK = 0b1011
|
||||
SERVER_ERROR_RESPONSE = 0b1111
|
||||
|
||||
NO_SEQUENCE = 0b0000
|
||||
POS_SEQUENCE = 0b0001
|
||||
MSG_WITH_EVENT = 0b0100
|
||||
|
||||
NO_SERIALIZATION = 0b0000
|
||||
JSON = 0b0001
|
||||
GZIP = 0b0001
|
||||
|
||||
|
||||
def generate_header(
|
||||
version=PROTOCOL_VERSION,
|
||||
message_type=CLIENT_FULL_REQUEST,
|
||||
message_type_specific_flags=MSG_WITH_EVENT,
|
||||
serial_method=JSON,
|
||||
compression_type=GZIP,
|
||||
reserved_data=0x00,
|
||||
extension_header=bytes()
|
||||
):
|
||||
"""直接复制原始豆包代码的generate_header函数"""
|
||||
header = bytearray()
|
||||
header_size = int(len(extension_header) / 4) + 1
|
||||
header.append((version << 4) | header_size)
|
||||
header.append((message_type << 4) | message_type_specific_flags)
|
||||
header.append((serial_method << 4) | compression_type)
|
||||
header.append(reserved_data)
|
||||
header.extend(extension_header)
|
||||
return header
|
||||
|
||||
|
||||
class DoubaoConfig:
|
||||
"""豆包配置"""
|
||||
|
||||
def __init__(self):
|
||||
self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
|
||||
self.app_id = "8718217928"
|
||||
self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
|
||||
self.app_key = "PlgvMymc7f3tQnJ6"
|
||||
self.resource_id = "volc.speech.dialog"
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
"""获取请求头"""
|
||||
return {
|
||||
"X-Api-App-ID": self.app_id,
|
||||
"X-Api-Access-Key": self.access_key,
|
||||
"X-Api-Resource-Id": self.resource_id,
|
||||
"X-Api-App-Key": self.app_key,
|
||||
"X-Api-Connect-Id": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
|
||||
class DoubaoClient:
|
||||
"""豆包客户端 - 基于原始代码"""
|
||||
|
||||
def __init__(self, config: DoubaoConfig):
|
||||
self.config = config
|
||||
self.session_id = str(uuid.uuid4())
|
||||
self.ws = None
|
||||
self.log_id = ""
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""建立WebSocket连接"""
|
||||
print(f"连接豆包服务器: {self.config.base_url}")
|
||||
|
||||
try:
|
||||
self.ws = await websockets.connect(
|
||||
self.config.base_url,
|
||||
additional_headers=self.config.get_headers(),
|
||||
ping_interval=None
|
||||
)
|
||||
|
||||
# 获取log_id
|
||||
if hasattr(self.ws, 'response_headers'):
|
||||
self.log_id = self.ws.response_headers.get("X-Tt-Logid")
|
||||
elif hasattr(self.ws, 'headers'):
|
||||
self.log_id = self.ws.headers.get("X-Tt-Logid")
|
||||
|
||||
print(f"连接成功, log_id: {self.log_id}")
|
||||
|
||||
# 发送StartConnection请求
|
||||
await self._send_start_connection()
|
||||
|
||||
# 发送StartSession请求
|
||||
await self._send_start_session()
|
||||
|
||||
except Exception as e:
|
||||
print(f"连接失败: {e}")
|
||||
raise
|
||||
|
||||
async def _send_start_connection(self) -> None:
|
||||
"""发送StartConnection请求"""
|
||||
print("发送StartConnection请求...")
|
||||
request = bytearray(generate_header())
|
||||
request.extend(int(1).to_bytes(4, 'big'))
|
||||
|
||||
payload_bytes = b"{}"
|
||||
payload_bytes = gzip.compress(payload_bytes)
|
||||
request.extend(len(payload_bytes).to_bytes(4, 'big'))
|
||||
request.extend(payload_bytes)
|
||||
|
||||
await self.ws.send(request)
|
||||
response = await self.ws.recv()
|
||||
print(f"StartConnection响应长度: {len(response)}")
|
||||
|
||||
async def _send_start_session(self) -> None:
|
||||
"""发送StartSession请求"""
|
||||
print("发送StartSession请求...")
|
||||
session_config = {
|
||||
"asr": {
|
||||
"extra": {
|
||||
"end_smooth_window_ms": 1500,
|
||||
},
|
||||
},
|
||||
"tts": {
|
||||
"speaker": "zh_female_vv_jupiter_bigtts",
|
||||
"audio_config": {
|
||||
"channel": 1,
|
||||
"format": "pcm",
|
||||
"sample_rate": 24000
|
||||
},
|
||||
},
|
||||
"dialog": {
|
||||
"bot_name": "豆包",
|
||||
"system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
|
||||
"speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
|
||||
"location": {"city": "北京"},
|
||||
"extra": {
|
||||
"strict_audit": False,
|
||||
"audit_response": "支持客户自定义安全审核回复话术。",
|
||||
"recv_timeout": 30,
|
||||
"input_mod": "audio",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
request = bytearray(generate_header())
|
||||
request.extend(int(100).to_bytes(4, 'big'))
|
||||
request.extend(len(self.session_id).to_bytes(4, 'big'))
|
||||
request.extend(self.session_id.encode())
|
||||
|
||||
payload_bytes = json.dumps(session_config).encode()
|
||||
payload_bytes = gzip.compress(payload_bytes)
|
||||
request.extend(len(payload_bytes).to_bytes(4, 'big'))
|
||||
request.extend(payload_bytes)
|
||||
|
||||
await self.ws.send(request)
|
||||
response = await self.ws.recv()
|
||||
print(f"StartSession响应长度: {len(response)}")
|
||||
|
||||
# 等待一会确保会话完全建立
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
async def task_request(self, audio: bytes) -> None:
|
||||
"""直接复制原始豆包代码的task_request方法"""
|
||||
task_request = bytearray(
|
||||
generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST,
|
||||
serial_method=NO_SERIALIZATION))
|
||||
task_request.extend(int(200).to_bytes(4, 'big'))
|
||||
task_request.extend((len(self.session_id)).to_bytes(4, 'big'))
|
||||
task_request.extend(str.encode(self.session_id))
|
||||
payload_bytes = gzip.compress(audio)
|
||||
task_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes)
|
||||
task_request.extend(payload_bytes)
|
||||
await self.ws.send(task_request)
|
||||
|
||||
async def test_full_dialog(self) -> None:
|
||||
"""测试完整对话流程"""
|
||||
print("开始完整对话测试...")
|
||||
|
||||
# 读取真实的录音文件
|
||||
try:
|
||||
import wave
|
||||
with wave.open("recording_20250920_135137.wav", 'rb') as wf:
|
||||
# 读取前5秒的音频数据
|
||||
total_frames = wf.getnframes()
|
||||
frames_to_read = min(total_frames, 80000) # 5秒
|
||||
audio_data = wf.readframes(frames_to_read)
|
||||
print(f"读取真实音频数据: {len(audio_data)} 字节")
|
||||
print(f"音频参数: 采样率={wf.getframerate()}, 通道数={wf.getnchannels()}, 采样宽度={wf.getsampwidth()}")
|
||||
except Exception as e:
|
||||
print(f"读取音频文件失败: {e}")
|
||||
return
|
||||
|
||||
print(f"音频数据大小: {len(audio_data)}")
|
||||
|
||||
try:
|
||||
# 发送音频数据
|
||||
print("发送音频数据...")
|
||||
await self.task_request(audio_data)
|
||||
print("音频数据发送成功")
|
||||
|
||||
# 等待语音识别响应
|
||||
print("等待语音识别响应...")
|
||||
response = await asyncio.wait_for(self.ws.recv(), timeout=15.0)
|
||||
print(f"收到ASR响应,长度: {len(response)}")
|
||||
|
||||
# 解析ASR响应
|
||||
if len(response) >= 4:
|
||||
protocol_version = response[0] >> 4
|
||||
header_size = response[0] & 0x0f
|
||||
message_type = response[1] >> 4
|
||||
flags = response[1] & 0x0f
|
||||
print(f"ASR响应协议: version={protocol_version}, header_size={header_size}, message_type={message_type}, flags={flags}")
|
||||
|
||||
if message_type == 9: # SERVER_FULL_RESPONSE
|
||||
payload_start = header_size * 4
|
||||
payload = response[payload_start:]
|
||||
|
||||
if len(payload) >= 4:
|
||||
event = int.from_bytes(payload[:4], 'big')
|
||||
print(f"ASR Event: {event}")
|
||||
|
||||
if len(payload) >= 8:
|
||||
session_id_len = int.from_bytes(payload[4:8], 'big')
|
||||
if len(payload) >= 8 + session_id_len:
|
||||
session_id = payload[8:8+session_id_len].decode()
|
||||
print(f"Session ID: {session_id}")
|
||||
|
||||
if len(payload) >= 12 + session_id_len:
|
||||
payload_size = int.from_bytes(payload[8+session_id_len:12+session_id_len], 'big')
|
||||
payload_data = payload[12+session_id_len:12+session_id_len+payload_size]
|
||||
print(f"Payload size: {payload_size}")
|
||||
|
||||
# 解析ASR结果
|
||||
try:
|
||||
asr_result = json.loads(payload_data.decode('utf-8'))
|
||||
print(f"ASR结果: {asr_result}")
|
||||
|
||||
# 如果有识别结果,提取文本
|
||||
if 'results' in asr_result and asr_result['results']:
|
||||
text = asr_result['results'][0].get('text', '')
|
||||
print(f"识别文本: {text}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析ASR结果失败: {e}")
|
||||
|
||||
# 持续等待TTS音频响应
|
||||
print("开始持续等待TTS音频响应...")
|
||||
response_count = 0
|
||||
max_responses = 10
|
||||
|
||||
while response_count < max_responses:
|
||||
try:
|
||||
print(f"等待第 {response_count + 1} 个响应...")
|
||||
tts_response = await asyncio.wait_for(self.ws.recv(), timeout=30.0)
|
||||
print(f"收到响应 {response_count + 1},长度: {len(tts_response)}")
|
||||
|
||||
# 解析响应
|
||||
if len(tts_response) >= 4:
|
||||
tts_version = tts_response[0] >> 4
|
||||
tts_header_size = tts_response[0] & 0x0f
|
||||
tts_message_type = tts_response[1] >> 4
|
||||
tts_flags = tts_response[1] & 0x0f
|
||||
print(f"响应协议: version={tts_version}, header_size={tts_header_size}, message_type={tts_message_type}, flags={tts_flags}")
|
||||
|
||||
if tts_message_type == 11: # SERVER_ACK (包含TTS音频)
|
||||
tts_payload_start = tts_header_size * 4
|
||||
tts_payload = tts_response[tts_payload_start:]
|
||||
|
||||
if len(tts_payload) >= 12:
|
||||
tts_event = int.from_bytes(tts_payload[:4], 'big')
|
||||
tts_session_len = int.from_bytes(tts_payload[4:8], 'big')
|
||||
tts_session = tts_payload[8:8+tts_session_len].decode()
|
||||
tts_audio_size = int.from_bytes(tts_payload[8+tts_session_len:12+tts_session_len], 'big')
|
||||
tts_audio_data = tts_payload[12+tts_session_len:12+tts_session_len+tts_audio_size]
|
||||
|
||||
print(f"Event: {tts_event}")
|
||||
print(f"音频数据大小: {tts_audio_size}")
|
||||
|
||||
if tts_audio_size > 0:
|
||||
print("找到TTS音频数据!")
|
||||
# 尝试解压缩TTS音频
|
||||
try:
|
||||
decompressed_tts = gzip.decompress(tts_audio_data)
|
||||
print(f"解压缩后TTS音频大小: {len(decompressed_tts)}")
|
||||
|
||||
# 创建WAV文件
|
||||
sample_rate = 24000
|
||||
channels = 1
|
||||
sampwidth = 2
|
||||
|
||||
with wave.open(f'tts_response_{response_count}.wav', 'wb') as wav_file:
|
||||
wav_file.setnchannels(channels)
|
||||
wav_file.setsampwidth(sampwidth)
|
||||
wav_file.setframerate(sample_rate)
|
||||
wav_file.writeframes(decompressed_tts)
|
||||
|
||||
print(f"成功创建TTS WAV文件: tts_response_{response_count}.wav")
|
||||
print(f"音频参数: {sample_rate}Hz, {channels}通道, {sampwidth*8}-bit")
|
||||
|
||||
# 显示文件信息
|
||||
if os.path.exists(f'tts_response_{response_count}.wav'):
|
||||
file_size = os.path.getsize(f'tts_response_{response_count}.wav')
|
||||
duration = file_size / (sample_rate * channels * sampwidth)
|
||||
print(f"WAV文件大小: {file_size} 字节")
|
||||
print(f"音频时长: {duration:.2f} 秒")
|
||||
|
||||
# 成功获取音频,退出循环
|
||||
break
|
||||
|
||||
except Exception as tts_e:
|
||||
print(f"TTS音频解压缩失败: {tts_e}")
|
||||
# 保存原始数据
|
||||
with open(f'tts_response_audio_{response_count}.raw', 'wb') as f:
|
||||
f.write(tts_audio_data)
|
||||
print(f"原始TTS音频数据已保存到 tts_response_audio_{response_count}.raw")
|
||||
|
||||
elif tts_message_type == 9: # SERVER_FULL_RESPONSE
|
||||
tts_payload_start = tts_header_size * 4
|
||||
tts_payload = tts_response[tts_payload_start:]
|
||||
|
||||
if len(tts_payload) >= 4:
|
||||
event = int.from_bytes(tts_payload[:4], 'big')
|
||||
print(f"Event: {event}")
|
||||
|
||||
if event in [451, 359]: # ASR结果或TTS结束
|
||||
# 解析payload
|
||||
if len(tts_payload) >= 8:
|
||||
session_id_len = int.from_bytes(tts_payload[4:8], 'big')
|
||||
if len(tts_payload) >= 8 + session_id_len:
|
||||
session_id = tts_payload[8:8+session_id_len].decode()
|
||||
if len(tts_payload) >= 12 + session_id_len:
|
||||
payload_size = int.from_bytes(tts_payload[8+session_id_len:12+session_id_len], 'big')
|
||||
payload_data = tts_payload[12+session_id_len:12+session_id_len+payload_size]
|
||||
|
||||
try:
|
||||
json_data = json.loads(payload_data.decode('utf-8'))
|
||||
print(f"JSON数据: {json_data}")
|
||||
|
||||
# 如果是ASR结果
|
||||
if 'results' in json_data:
|
||||
text = json_data['results'][0].get('text', '')
|
||||
print(f"识别文本: {text}")
|
||||
|
||||
# 如果是TTS结束标记
|
||||
if event == 359:
|
||||
print("TTS响应结束")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析JSON失败: {e}")
|
||||
# 保存原始数据
|
||||
with open(f'tts_response_{response_count}.raw', 'wb') as f:
|
||||
f.write(payload_data)
|
||||
|
||||
# 保存完整响应用于调试
|
||||
with open(f'tts_response_full_{response_count}.raw', 'wb') as f:
|
||||
f.write(tts_response)
|
||||
print(f"完整响应已保存到 tts_response_full_{response_count}.raw")
|
||||
|
||||
response_count += 1
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
print(f"等待第 {response_count + 1} 个响应超时")
|
||||
break
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
print("连接已关闭")
|
||||
break
|
||||
|
||||
print(f"共收到 {response_count} 个响应")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
print("等待响应超时")
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
print(f"连接关闭: {e}")
|
||||
except Exception as e:
|
||||
print(f"测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭连接"""
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.close()
|
||||
except:
|
||||
pass
|
||||
print("连接已关闭")
|
||||
|
||||
|
||||
async def main():
|
||||
"""测试函数"""
|
||||
config = DoubaoConfig()
|
||||
client = DoubaoClient(config)
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
await client.test_full_dialog()
|
||||
except Exception as e:
|
||||
print(f"测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@ -1,412 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
豆包音频处理模块 - 基于原始代码的测试版本
|
||||
直接使用原始豆包代码的核心逻辑
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import gzip
|
||||
import json
|
||||
import uuid
|
||||
import wave
|
||||
import struct
|
||||
import time
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
import websockets
|
||||
|
||||
|
||||
# 直接复制原始豆包代码的协议常量
|
||||
PROTOCOL_VERSION = 0b0001
|
||||
CLIENT_FULL_REQUEST = 0b0001
|
||||
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
|
||||
SERVER_FULL_RESPONSE = 0b1001
|
||||
SERVER_ACK = 0b1011
|
||||
SERVER_ERROR_RESPONSE = 0b1111
|
||||
|
||||
NO_SEQUENCE = 0b0000
|
||||
POS_SEQUENCE = 0b0001
|
||||
MSG_WITH_EVENT = 0b0100
|
||||
|
||||
NO_SERIALIZATION = 0b0000
|
||||
JSON = 0b0001
|
||||
GZIP = 0b0001
|
||||
|
||||
|
||||
def generate_header(
|
||||
version=PROTOCOL_VERSION,
|
||||
message_type=CLIENT_FULL_REQUEST,
|
||||
message_type_specific_flags=MSG_WITH_EVENT,
|
||||
serial_method=JSON,
|
||||
compression_type=GZIP,
|
||||
reserved_data=0x00,
|
||||
extension_header=bytes()
|
||||
):
|
||||
"""直接复制原始豆包代码的generate_header函数"""
|
||||
header = bytearray()
|
||||
header_size = int(len(extension_header) / 4) + 1
|
||||
header.append((version << 4) | header_size)
|
||||
header.append((message_type << 4) | message_type_specific_flags)
|
||||
header.append((serial_method << 4) | compression_type)
|
||||
header.append(reserved_data)
|
||||
header.extend(extension_header)
|
||||
return header
|
||||
|
||||
|
||||
class DoubaoConfig:
|
||||
"""豆包配置"""
|
||||
|
||||
def __init__(self):
|
||||
self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
|
||||
self.app_id = "8718217928"
|
||||
self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
|
||||
self.app_key = "PlgvMymc7f3tQnJ6"
|
||||
self.resource_id = "volc.speech.dialog"
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
"""获取请求头"""
|
||||
return {
|
||||
"X-Api-App-ID": self.app_id,
|
||||
"X-Api-Access-Key": self.access_key,
|
||||
"X-Api-Resource-Id": self.resource_id,
|
||||
"X-Api-App-Key": self.app_key,
|
||||
"X-Api-Connect-Id": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
|
||||
class DoubaoClient:
|
||||
"""豆包客户端 - 基于原始代码"""
|
||||
|
||||
def __init__(self, config: DoubaoConfig):
|
||||
self.config = config
|
||||
self.session_id = str(uuid.uuid4())
|
||||
self.ws = None
|
||||
self.log_id = ""
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""建立WebSocket连接"""
|
||||
print(f"连接豆包服务器: {self.config.base_url}")
|
||||
|
||||
try:
|
||||
self.ws = await websockets.connect(
|
||||
self.config.base_url,
|
||||
additional_headers=self.config.get_headers(),
|
||||
ping_interval=None
|
||||
)
|
||||
|
||||
# 获取log_id
|
||||
if hasattr(self.ws, 'response_headers'):
|
||||
self.log_id = self.ws.response_headers.get("X-Tt-Logid")
|
||||
elif hasattr(self.ws, 'headers'):
|
||||
self.log_id = self.ws.headers.get("X-Tt-Logid")
|
||||
|
||||
print(f"连接成功, log_id: {self.log_id}")
|
||||
|
||||
# 发送StartConnection请求
|
||||
await self._send_start_connection()
|
||||
|
||||
# 发送StartSession请求
|
||||
await self._send_start_session()
|
||||
|
||||
except Exception as e:
|
||||
print(f"连接失败: {e}")
|
||||
raise
|
||||
|
||||
async def _send_start_connection(self) -> None:
|
||||
"""发送StartConnection请求"""
|
||||
print("发送StartConnection请求...")
|
||||
request = bytearray(generate_header())
|
||||
request.extend(int(1).to_bytes(4, 'big'))
|
||||
|
||||
payload_bytes = b"{}"
|
||||
payload_bytes = gzip.compress(payload_bytes)
|
||||
request.extend(len(payload_bytes).to_bytes(4, 'big'))
|
||||
request.extend(payload_bytes)
|
||||
|
||||
await self.ws.send(request)
|
||||
response = await self.ws.recv()
|
||||
print(f"StartConnection响应长度: {len(response)}")
|
||||
|
||||
async def _send_start_session(self) -> None:
|
||||
"""发送StartSession请求"""
|
||||
print("发送StartSession请求...")
|
||||
session_config = {
|
||||
"asr": {
|
||||
"extra": {
|
||||
"end_smooth_window_ms": 1500,
|
||||
},
|
||||
},
|
||||
"tts": {
|
||||
"speaker": "zh_female_vv_jupiter_bigtts",
|
||||
"audio_config": {
|
||||
"channel": 1,
|
||||
"format": "pcm",
|
||||
"sample_rate": 24000
|
||||
},
|
||||
},
|
||||
"dialog": {
|
||||
"bot_name": "豆包",
|
||||
"system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
|
||||
"speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
|
||||
"location": {"city": "北京"},
|
||||
"extra": {
|
||||
"strict_audit": False,
|
||||
"audit_response": "支持客户自定义安全审核回复话术。",
|
||||
"recv_timeout": 30,
|
||||
"input_mod": "audio",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
request = bytearray(generate_header())
|
||||
request.extend(int(100).to_bytes(4, 'big'))
|
||||
request.extend(len(self.session_id).to_bytes(4, 'big'))
|
||||
request.extend(self.session_id.encode())
|
||||
|
||||
payload_bytes = json.dumps(session_config).encode()
|
||||
payload_bytes = gzip.compress(payload_bytes)
|
||||
request.extend(len(payload_bytes).to_bytes(4, 'big'))
|
||||
request.extend(payload_bytes)
|
||||
|
||||
await self.ws.send(request)
|
||||
response = await self.ws.recv()
|
||||
print(f"StartSession响应长度: {len(response)}")
|
||||
|
||||
# 等待一会确保会话完全建立
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
async def task_request(self, audio: bytes) -> None:
|
||||
"""直接复制原始豆包代码的task_request方法"""
|
||||
task_request = bytearray(
|
||||
generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST,
|
||||
serial_method=NO_SERIALIZATION))
|
||||
task_request.extend(int(200).to_bytes(4, 'big'))
|
||||
task_request.extend((len(self.session_id)).to_bytes(4, 'big'))
|
||||
task_request.extend(str.encode(self.session_id))
|
||||
payload_bytes = gzip.compress(audio)
|
||||
task_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes)
|
||||
task_request.extend(payload_bytes)
|
||||
await self.ws.send(task_request)
|
||||
|
||||
async def test_audio_request(self) -> None:
|
||||
"""测试音频请求"""
|
||||
print("测试音频请求...")
|
||||
|
||||
# 读取真实的录音文件
|
||||
try:
|
||||
import wave
|
||||
with wave.open("recording_20250920_135137.wav", 'rb') as wf:
|
||||
# 读取前10秒的音频数据(16000采样率 * 10秒 = 160000帧)
|
||||
total_frames = wf.getnframes()
|
||||
frames_to_read = min(total_frames, 160000) # 最多10秒
|
||||
small_audio = wf.readframes(frames_to_read)
|
||||
print(f"读取真实音频数据: {len(small_audio)} 字节")
|
||||
print(f"音频参数: 采样率={wf.getframerate()}, 通道数={wf.getnchannels()}, 采样宽度={wf.getsampwidth()}")
|
||||
print(f"总帧数: {total_frames}, 读取帧数: {frames_to_read}")
|
||||
except Exception as e:
|
||||
print(f"读取音频文件失败: {e}")
|
||||
# 如果读取失败,使用静音数据
|
||||
small_audio = b'\x00' * 3200
|
||||
|
||||
print(f"音频数据大小: {len(small_audio)}")
|
||||
|
||||
try:
|
||||
# 发送完整的音频数据块
|
||||
print(f"发送完整的音频数据块...")
|
||||
await self.task_request(small_audio)
|
||||
print(f"音频数据块发送成功")
|
||||
|
||||
print("等待语音识别响应...")
|
||||
|
||||
# 等待更长时间的响应(语音识别可能需要更长时间)
|
||||
response = await asyncio.wait_for(self.ws.recv(), timeout=15.0)
|
||||
print(f"收到响应,长度: {len(response)}")
|
||||
|
||||
# 解析响应
|
||||
try:
|
||||
if len(response) >= 4:
|
||||
protocol_version = response[0] >> 4
|
||||
header_size = response[0] & 0x0f
|
||||
message_type = response[1] >> 4
|
||||
message_type_specific_flags = response[1] & 0x0f
|
||||
print(f"响应协议信息: version={protocol_version}, header_size={header_size}, message_type={message_type}, flags={message_type_specific_flags}")
|
||||
|
||||
# 解析payload
|
||||
payload_start = header_size * 4
|
||||
payload = response[payload_start:]
|
||||
|
||||
if message_type == 9: # SERVER_FULL_RESPONSE
|
||||
print("收到SERVER_FULL_RESPONSE!")
|
||||
if len(payload) >= 4:
|
||||
# 解析event
|
||||
event = int.from_bytes(payload[:4], 'big')
|
||||
print(f"Event: {event}")
|
||||
|
||||
# 解析session_id
|
||||
if len(payload) >= 8:
|
||||
session_id_len = int.from_bytes(payload[4:8], 'big')
|
||||
if len(payload) >= 8 + session_id_len:
|
||||
session_id = payload[8:8+session_id_len].decode()
|
||||
print(f"Session ID: {session_id}")
|
||||
|
||||
# 解析payload size和data
|
||||
if len(payload) >= 12 + session_id_len:
|
||||
payload_size = int.from_bytes(payload[8+session_id_len:12+session_id_len], 'big')
|
||||
payload_data = payload[12+session_id_len:12+session_id_len+payload_size]
|
||||
print(f"Payload size: {payload_size}")
|
||||
|
||||
# 如果包含音频数据,保存到文件
|
||||
if len(payload_data) > 0:
|
||||
print(f"收到数据: {len(payload_data)} 字节")
|
||||
# 保存原始音频数据
|
||||
with open('response_audio.raw', 'wb') as f:
|
||||
f.write(payload_data)
|
||||
print("音频数据已保存到 response_audio.raw")
|
||||
|
||||
# 尝试解析JSON数据
|
||||
try:
|
||||
import json
|
||||
json_data = json.loads(payload_data.decode('utf-8'))
|
||||
print(f"JSON数据: {json_data}")
|
||||
|
||||
# 如果是语音识别任务开始,继续等待音频响应
|
||||
if 'asr_task_id' in json_data:
|
||||
print("语音识别任务开始,继续等待音频响应...")
|
||||
try:
|
||||
# 等待音频响应
|
||||
audio_response = await asyncio.wait_for(self.ws.recv(), timeout=20.0)
|
||||
print(f"收到音频响应,长度: {len(audio_response)}")
|
||||
|
||||
# 解析音频响应
|
||||
if len(audio_response) >= 4:
|
||||
audio_version = audio_response[0] >> 4
|
||||
audio_header_size = audio_response[0] & 0x0f
|
||||
audio_message_type = audio_response[1] >> 4
|
||||
audio_flags = audio_response[1] & 0x0f
|
||||
print(f"音频响应协议信息: version={audio_version}, header_size={audio_header_size}, message_type={audio_message_type}, flags={audio_flags}")
|
||||
|
||||
if audio_message_type == 9: # SERVER_FULL_RESPONSE (包含TTS音频)
|
||||
audio_payload_start = audio_header_size * 4
|
||||
audio_payload = audio_response[audio_payload_start:]
|
||||
|
||||
if len(audio_payload) >= 12:
|
||||
# 解析event和session_id
|
||||
audio_event = int.from_bytes(audio_payload[:4], 'big')
|
||||
audio_session_len = int.from_bytes(audio_payload[4:8], 'big')
|
||||
audio_session = audio_payload[8:8+audio_session_len].decode()
|
||||
audio_data_size = int.from_bytes(audio_payload[8+audio_session_len:12+audio_session_len], 'big')
|
||||
audio_data = audio_payload[12+audio_session_len:12+audio_session_len+audio_data_size]
|
||||
|
||||
print(f"音频Event: {audio_event}")
|
||||
print(f"音频数据大小: {audio_data_size}")
|
||||
|
||||
if audio_data_size > 0:
|
||||
# 保存原始音频数据
|
||||
with open('tts_response_audio.raw', 'wb') as f:
|
||||
f.write(audio_data)
|
||||
print(f"TTS音频数据已保存到 tts_response_audio.raw")
|
||||
|
||||
# 尝试解析音频数据(可能是JSON或GZIP压缩的音频)
|
||||
try:
|
||||
# 首先尝试解压缩
|
||||
import gzip
|
||||
decompressed_audio = gzip.decompress(audio_data)
|
||||
print(f"解压缩后音频数据大小: {len(decompressed_audio)}")
|
||||
with open('tts_response_audio_decompressed.raw', 'wb') as f:
|
||||
f.write(decompressed_audio)
|
||||
print("解压缩的音频数据已保存")
|
||||
|
||||
# 创建WAV文件供树莓派播放
|
||||
import wave
|
||||
import struct
|
||||
|
||||
# 豆包返回的音频是24000Hz, 16-bit, 单声道
|
||||
sample_rate = 24000
|
||||
channels = 1
|
||||
sampwidth = 2 # 16-bit = 2 bytes
|
||||
|
||||
with wave.open('tts_response.wav', 'wb') as wav_file:
|
||||
wav_file.setnchannels(channels)
|
||||
wav_file.setsampwidth(sampwidth)
|
||||
wav_file.setframerate(sample_rate)
|
||||
wav_file.writeframes(decompressed_audio)
|
||||
|
||||
print("已创建WAV文件: tts_response.wav")
|
||||
print(f"音频参数: {sample_rate}Hz, {channels}通道, {sampwidth*8}-bit")
|
||||
|
||||
except Exception as audio_e:
|
||||
print(f"音频数据处理失败: {audio_e}")
|
||||
# 如果解压缩失败,直接保存原始数据
|
||||
with open('tts_response_audio_original.raw', 'wb') as f:
|
||||
f.write(audio_data)
|
||||
|
||||
elif audio_message_type == 11: # SERVER_ACK
|
||||
print("收到SERVER_ACK音频响应")
|
||||
# 处理SERVER_ACK格式的音频响应
|
||||
audio_payload_start = audio_header_size * 4
|
||||
audio_payload = audio_response[audio_payload_start:]
|
||||
print(f"音频payload长度: {len(audio_payload)}")
|
||||
with open('tts_response_ack.raw', 'wb') as f:
|
||||
f.write(audio_payload)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
print("等待音频响应超时")
|
||||
|
||||
except Exception as json_e:
|
||||
print(f"解析JSON失败: {json_e}")
|
||||
# 如果不是JSON,可能是音频数据,直接保存
|
||||
with open('response_audio.raw', 'wb') as f:
|
||||
f.write(payload_data)
|
||||
|
||||
elif message_type == 11: # SERVER_ACK
|
||||
print("收到SERVER_ACK响应!")
|
||||
|
||||
elif message_type == 15: # SERVER_ERROR_RESPONSE
|
||||
print("收到错误响应")
|
||||
if len(response) > 8:
|
||||
error_code = int.from_bytes(response[4:8], 'big')
|
||||
print(f"错误代码: {error_code}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析响应失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
print("等待响应超时")
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
print(f"连接关闭: {e}")
|
||||
except Exception as e:
|
||||
print(f"发送音频请求失败: {e}")
|
||||
raise
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭连接"""
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.close()
|
||||
except:
|
||||
pass
|
||||
print("连接已关闭")
|
||||
|
||||
|
||||
async def main():
|
||||
"""测试函数"""
|
||||
config = DoubaoConfig()
|
||||
client = DoubaoClient(config)
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
await client.test_audio_request()
|
||||
except Exception as e:
|
||||
print(f"测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
303
doubao_test.py
303
doubao_test.py
@ -1,303 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
豆包音频处理模块 - 协议测试版本
|
||||
专门测试协议格式问题
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import gzip
|
||||
import json
|
||||
import uuid
|
||||
import wave
|
||||
import struct
|
||||
import time
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
import websockets
|
||||
|
||||
|
||||
class DoubaoConfig:
|
||||
"""豆包配置"""
|
||||
|
||||
def __init__(self):
|
||||
self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
|
||||
self.app_id = "8718217928"
|
||||
self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
|
||||
self.app_key = "PlgvMymc7f3tQnJ6"
|
||||
self.resource_id = "volc.speech.dialog"
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
"""获取请求头"""
|
||||
return {
|
||||
"X-Api-App-ID": self.app_id,
|
||||
"X-Api-Access-Key": self.access_key,
|
||||
"X-Api-Resource-Id": self.resource_id,
|
||||
"X-Api-App-Key": self.app_key,
|
||||
"X-Api-Connect-Id": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
|
||||
class DoubaoProtocol:
|
||||
"""豆包协议处理"""
|
||||
|
||||
# 协议常量
|
||||
PROTOCOL_VERSION = 0b0001
|
||||
CLIENT_FULL_REQUEST = 0b0001
|
||||
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
|
||||
SERVER_FULL_RESPONSE = 0b1001
|
||||
SERVER_ACK = 0b1011
|
||||
SERVER_ERROR_RESPONSE = 0b1111
|
||||
|
||||
NO_SEQUENCE = 0b0000
|
||||
POS_SEQUENCE = 0b0001
|
||||
MSG_WITH_EVENT = 0b0100
|
||||
JSON = 0b0001
|
||||
NO_SERIALIZATION = 0b0000
|
||||
GZIP = 0b0001
|
||||
|
||||
@classmethod
|
||||
def generate_header(cls, message_type=CLIENT_FULL_REQUEST,
|
||||
message_type_specific_flags=MSG_WITH_EVENT,
|
||||
serial_method=JSON, compression_type=GZIP) -> bytes:
|
||||
"""生成协议头"""
|
||||
header = bytearray()
|
||||
header.append((cls.PROTOCOL_VERSION << 4) | 1) # version + header_size
|
||||
header.append((message_type << 4) | message_type_specific_flags)
|
||||
header.append((serial_method << 4) | compression_type)
|
||||
header.append(0x00) # reserved
|
||||
return bytes(header)
|
||||
|
||||
|
||||
class DoubaoClient:
|
||||
"""豆包客户端"""
|
||||
|
||||
def __init__(self, config: DoubaoConfig):
|
||||
self.config = config
|
||||
self.session_id = str(uuid.uuid4())
|
||||
self.ws = None
|
||||
self.log_id = ""
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""建立WebSocket连接"""
|
||||
print(f"连接豆包服务器: {self.config.base_url}")
|
||||
|
||||
try:
|
||||
self.ws = await websockets.connect(
|
||||
self.config.base_url,
|
||||
additional_headers=self.config.get_headers(),
|
||||
ping_interval=None
|
||||
)
|
||||
|
||||
# 获取log_id
|
||||
if hasattr(self.ws, 'response_headers'):
|
||||
self.log_id = self.ws.response_headers.get("X-Tt-Logid")
|
||||
elif hasattr(self.ws, 'headers'):
|
||||
self.log_id = self.ws.headers.get("X-Tt-Logid")
|
||||
|
||||
print(f"连接成功, log_id: {self.log_id}")
|
||||
|
||||
# 发送StartConnection请求
|
||||
await self._send_start_connection()
|
||||
|
||||
# 发送StartSession请求
|
||||
await self._send_start_session()
|
||||
|
||||
except Exception as e:
|
||||
print(f"连接失败: {e}")
|
||||
raise
|
||||
|
||||
async def _send_start_connection(self) -> None:
|
||||
"""发送StartConnection请求"""
|
||||
print("发送StartConnection请求...")
|
||||
request = bytearray(DoubaoProtocol.generate_header())
|
||||
request.extend(int(1).to_bytes(4, 'big'))
|
||||
|
||||
payload_bytes = b"{}"
|
||||
payload_bytes = gzip.compress(payload_bytes)
|
||||
request.extend(len(payload_bytes).to_bytes(4, 'big'))
|
||||
request.extend(payload_bytes)
|
||||
|
||||
await self.ws.send(request)
|
||||
response = await self.ws.recv()
|
||||
print(f"StartConnection响应长度: {len(response)}")
|
||||
|
||||
async def _send_start_session(self) -> None:
|
||||
"""发送StartSession请求"""
|
||||
print("发送StartSession请求...")
|
||||
session_config = {
|
||||
"asr": {
|
||||
"extra": {
|
||||
"end_smooth_window_ms": 1500,
|
||||
},
|
||||
},
|
||||
"tts": {
|
||||
"speaker": "zh_female_vv_jupiter_bigtts",
|
||||
"audio_config": {
|
||||
"channel": 1,
|
||||
"format": "pcm",
|
||||
"sample_rate": 24000
|
||||
},
|
||||
},
|
||||
"dialog": {
|
||||
"bot_name": "豆包",
|
||||
"system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
|
||||
"speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
|
||||
"location": {"city": "北京"},
|
||||
"extra": {
|
||||
"strict_audit": False,
|
||||
"audit_response": "支持客户自定义安全审核回复话术。",
|
||||
"recv_timeout": 30,
|
||||
"input_mod": "audio",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
request = bytearray(DoubaoProtocol.generate_header())
|
||||
request.extend(int(100).to_bytes(4, 'big'))
|
||||
request.extend(len(self.session_id).to_bytes(4, 'big'))
|
||||
request.extend(self.session_id.encode())
|
||||
|
||||
payload_bytes = json.dumps(session_config).encode()
|
||||
payload_bytes = gzip.compress(payload_bytes)
|
||||
request.extend(len(payload_bytes).to_bytes(4, 'big'))
|
||||
request.extend(payload_bytes)
|
||||
|
||||
await self.ws.send(request)
|
||||
response = await self.ws.recv()
|
||||
print(f"StartSession响应长度: {len(response)}")
|
||||
|
||||
# 等待一会确保会话完全建立
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
async def test_audio_request(self) -> None:
    """Probe the server's audio-frame format with a block of silence.

    Builds a CLIENT_AUDIO_ONLY_REQUEST frame exactly like the reference
    Doubao client (no padding), sends it, then dumps and parses the first
    response frame for debugging.

    Raises:
        Exception: re-raised when sending the request fails.
    """
    print("测试音频请求格式...")

    # Silence payload; 3200 bytes is the chunk size used by the original
    # Doubao reference client.
    small_audio = b'\x00' * 3200

    # Build the 4-byte binary header by hand, mirroring the reference
    # client byte-for-byte, with no padding anywhere.
    header = bytearray()
    header.append((DoubaoProtocol.PROTOCOL_VERSION << 4) | 1)  # version + header_size
    header.append((DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST << 4) | DoubaoProtocol.NO_SEQUENCE)
    header.append((DoubaoProtocol.NO_SERIALIZATION << 4) | DoubaoProtocol.GZIP)
    header.append(0x00)  # reserved

    request = bytearray(header)

    # Event id 200 = task request.
    request.extend(int(200).to_bytes(4, 'big'))

    # Length-prefixed session id.
    request.extend(len(self.session_id).to_bytes(4, 'big'))
    request.extend(self.session_id.encode())

    # Gzip the audio and append it length-prefixed.
    compressed_audio = gzip.compress(small_audio)
    request.extend(len(compressed_audio).to_bytes(4, 'big'))
    request.extend(compressed_audio)

    print(f"测试请求详细信息:")
    print(f" - 音频原始大小: {len(small_audio)}")
    print(f" - 音频压缩后大小: {len(compressed_audio)}")
    print(f" - Session ID: {self.session_id} (长度: {len(self.session_id)})")
    print(f" - 总请求大小: {len(request)}")
    print(f" - 头部字节: {request[:4].hex()}")
    print(f" - 消息类型: {int.from_bytes(request[4:8], 'big')}")
    print(f" - Session ID长度: {int.from_bytes(request[8:12], 'big')}")
    print(f" - Payload size: {int.from_bytes(request[12+len(self.session_id):16+len(self.session_id)], 'big')}")

    try:
        await self.ws.send(request)
        print("请求发送成功")

        # Wait for the first response frame.
        response = await asyncio.wait_for(self.ws.recv(), timeout=3.0)
        print(f"收到响应,长度: {len(response)}")

        # Best-effort parse of the response for diagnostics.
        try:
            protocol_version = response[0] >> 4
            header_size = response[0] & 0x0f
            message_type = response[1] >> 4
            message_type_specific_flags = response[1] & 0x0f
            serialization_method = response[2] >> 4
            message_compression = response[2] & 0x0f

            print(f"响应协议信息:")
            print(f" - version={protocol_version}")
            print(f" - header_size={header_size}")
            print(f" - message_type={message_type} (15=SERVER_ERROR_RESPONSE)")
            print(f" - message_type_specific_flags={message_type_specific_flags}")
            print(f" - serialization_method={serialization_method}")
            print(f" - message_compression={message_compression}")

            # header_size is counted in 4-byte words.
            payload = response[header_size * 4:]
            if message_type == 15:  # SERVER_ERROR_RESPONSE
                if len(payload) >= 8:
                    code = int.from_bytes(payload[:4], "big", signed=False)
                    payload_size = int.from_bytes(payload[4:8], "big", signed=False)
                    print(f" - 错误代码: {code}")
                    print(f" - payload大小: {payload_size}")

                    if len(payload) >= 8 + payload_size:
                        payload_msg = payload[8:8 + payload_size]
                        print(f" - payload长度: {len(payload_msg)}")

                        if message_compression == 1:  # GZIP
                            try:
                                payload_msg = gzip.decompress(payload_msg)
                                print(f" - 解压缩后长度: {len(payload_msg)}")
                            except Exception:
                                # Best effort: fall through with the raw bytes.
                                # (Was a bare except, which would also swallow
                                # KeyboardInterrupt/SystemExit.)
                                pass

                        try:
                            error_msg = json.loads(payload_msg.decode('utf-8'))
                            print(f" - 错误信息: {error_msg}")
                        except (UnicodeDecodeError, ValueError):
                            # Not valid UTF-8 JSON — show the raw payload.
                            print(f" - 原始payload: {payload_msg}")

        except Exception as e:
            print(f"解析响应失败: {e}")
            import traceback
            traceback.print_exc()

    except Exception as e:
        print(f"发送测试请求失败: {e}")
        raise
async def close(self) -> None:
    """Close the WebSocket connection, ignoring shutdown errors.

    Safe to call when no connection was ever opened.
    """
    if self.ws:
        try:
            await self.ws.close()
        except Exception:
            # Best-effort shutdown: a close failure on an already-broken
            # socket is not actionable.  (Was a bare ``except:``, which
            # would also swallow KeyboardInterrupt/SystemExit.)
            pass
    print("连接已关闭")
async def main():
    """Smoke test: connect to the Doubao service and probe the audio format."""
    client = DoubaoClient(DoubaoConfig())
    try:
        await client.connect()
        await client.test_audio_request()
    except Exception as e:
        print(f"测试失败: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Always release the connection, even on failure.
        await client.close()


if __name__ == "__main__":
    asyncio.run(main())
@ -1,127 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
语音识别使用示例
|
||||
演示如何使用 speech_recognizer 模块
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from speech_recognizer import SpeechRecognizer
|
||||
|
||||
async def example_recognize_file():
    """Example: recognize a single audio file with SpeechRecognizer."""
    print("=== 示例1:识别单个音频文件 ===")

    # Replace the placeholders with real credentials before running.
    recognizer = SpeechRecognizer(
        app_key="your_app_key",
        access_key="your_access_key"
    )

    audio_file = "recording_20240101_120000.wav"
    if not os.path.exists(audio_file):
        print(f"音频文件不存在: {audio_file}")
        print("请先运行 enhanced_wake_and_record.py 录制一个音频文件")
        return

    try:
        results = await recognizer.recognize_file(audio_file)

        print(f"识别结果(共{len(results)}个):")
        for index, result in enumerate(results, start=1):
            print(f"结果 {index}:")
            print(f" 文本: {result.text}")
            print(f" 置信度: {result.confidence}")
            print(f" 最终结果: {result.is_final}")
            print("-" * 40)
    except Exception as e:
        print(f"识别失败: {e}")
async def example_recognize_latest():
    """Example: recognize the most recent recording on disk."""
    print("\n=== 示例2:识别最新的录音文件 ===")

    # Replace the placeholders with real credentials before running.
    recognizer = SpeechRecognizer(
        app_key="your_app_key",
        access_key="your_access_key"
    )

    try:
        result = await recognizer.recognize_latest_recording()
        if not result:
            print("未找到录音文件或识别失败")
            return
        print("识别结果:")
        print(f" 文本: {result.text}")
        print(f" 置信度: {result.confidence}")
        print(f" 最终结果: {result.is_final}")
    except Exception as e:
        print(f"识别失败: {e}")
async def example_batch_recognition():
    """Example: batch-recognize recorded WAV files in the current directory.

    Processes at most the first five ``recording_*.wav`` files, printing
    the final transcript for each, with a 1-second pause between requests
    to avoid hammering the service.
    """
    print("\n=== 示例3:批量识别录音文件 ===")

    # Replace the placeholders with real credentials before running.
    recognizer = SpeechRecognizer(
        app_key="your_app_key",
        access_key="your_access_key"
    )

    # Collect all recording files in the working directory.
    recording_files = [f for f in os.listdir(".") if f.startswith('recording_') and f.endswith('.wav')]

    if not recording_files:
        print("未找到录音文件")
        return

    print(f"找到 {len(recording_files)} 个录音文件")

    for filename in recording_files[:5]:  # only process the first 5 files
        # Fix: the interpolation was corrupted to a literal "(unknown)";
        # the loop variable is clearly what was meant here.
        print(f"\n处理文件: {filename}")
        try:
            results = await recognizer.recognize_file(filename)

            if results:
                final_result = results[-1]  # last result is the final one
                print(f"识别结果: {final_result.text}")
            else:
                print("识别失败")

        except Exception as e:
            print(f"处理失败: {e}")

        # Pause between requests so we don't flood the service.
        await asyncio.sleep(1)
async def main():
    """Entry point: run all recognition examples in sequence.

    Skips everything (with setup instructions) when the API credentials
    have not been provided via environment variables.
    """
    print("🚀 语音识别使用示例")
    print("=" * 50)

    # Fix: the original condition was
    #   not os.getenv("SAUC_APP_KEY") and "your_app_key" in "your_app_key"
    # whose second half is a tautology; check both required variables.
    if not (os.getenv("SAUC_APP_KEY") and os.getenv("SAUC_ACCESS_KEY")):
        print("⚠️ 请先设置 SAUC_APP_KEY 和 SAUC_ACCESS_KEY 环境变量")
        print("或者在代码中填入实际的 app_key 和 access_key")
        print("示例:")
        print("export SAUC_APP_KEY='your_app_key'")
        print("export SAUC_ACCESS_KEY='your_access_key'")
        return

    # Run the examples one after another.
    await example_recognize_file()
    await example_recognize_latest()
    await example_batch_recognition()


if __name__ == "__main__":
    asyncio.run(main())
@ -1,15 +0,0 @@
|
||||
# README
|
||||
|
||||
**asr tob 相关client demo**
|
||||
|
||||
# Notice
|
||||
python version: python 3.x
|
||||
|
||||
替换代码中的key为真实数据:
|
||||
"app_key": "xxxxxxx",
|
||||
"access_key": "xxxxxxxxxxxxxxxx"
|
||||
使用示例:
|
||||
python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav
|
||||
|
||||
|
||||
|
||||
@ -1,523 +0,0 @@
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
import struct
|
||||
import gzip
|
||||
import uuid
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Optional, List, Dict, Any, Tuple, AsyncGenerator
|
||||
|
||||
# Logging setup: INFO level, mirrored to run.log and the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('run.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Sample rate (Hz) used when transcoding audio for the ASR service.
DEFAULT_SAMPLE_RATE = 16000
class ProtocolVersion:
    """Wire-protocol version identifiers (upper nibble of header byte 0)."""
    V1 = 0b0001
class MessageType:
    """Frame message types (upper nibble of header byte 1)."""
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ERROR_RESPONSE = 0b1111
class MessageTypeSpecificFlags:
    """Sequence-number flags (lower nibble of header byte 1)."""
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    NEG_SEQUENCE = 0b0010
    NEG_WITH_SEQUENCE = 0b0011
class SerializationType:
    """Payload serialization methods (upper nibble of header byte 2)."""
    NO_SERIALIZATION = 0b0000
    JSON = 0b0001
class CompressionType:
    """Payload compression methods (lower nibble of header byte 2)."""
    GZIP = 0b0001
class Config:
    """Authentication credentials for the ASR service.

    Fill in the app id and access token obtained from the console.
    """

    def __init__(self):
        self.auth = {
            "app_key": "xxxxxxx",
            "access_key": "xxxxxxxxxxxx"
        }

    @property
    def app_key(self) -> str:
        """The application key used in the X-Api-App-Key header."""
        return self.auth["app_key"]

    @property
    def access_key(self) -> str:
        """The access key used in the X-Api-Access-Key header."""
        return self.auth["access_key"]

# Module-level singleton consumed by RequestBuilder.new_auth_headers().
config = Config()
class CommonUtils:
    """Stateless helpers for compression and WAV-file handling."""

    @staticmethod
    def gzip_compress(data: bytes) -> bytes:
        """Gzip-compress *data* and return the compressed bytes."""
        return gzip.compress(data)

    @staticmethod
    def gzip_decompress(data: bytes) -> bytes:
        """Inverse of :meth:`gzip_compress`."""
        return gzip.decompress(data)

    @staticmethod
    def judge_wav(data: bytes) -> bool:
        """Return True when *data* looks like a RIFF/WAVE file (>= 44 bytes)."""
        return len(data) >= 44 and data[:4] == b'RIFF' and data[8:12] == b'WAVE'

    @staticmethod
    def convert_wav_with_path(audio_path: str, sample_rate: int = DEFAULT_SAMPLE_RATE) -> bytes:
        """Transcode *audio_path* to 16-bit mono WAV via ffmpeg; return the bytes.

        The source file is removed afterwards on a best-effort basis.

        Raises:
            RuntimeError: when ffmpeg fails.
        """
        command = [
            "ffmpeg", "-v", "quiet", "-y", "-i", audio_path,
            "-acodec", "pcm_s16le", "-ac", "1", "-ar", str(sample_rate),
            "-f", "wav", "-"
        ]
        try:
            proc = subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as e:
            logger.error(f"FFmpeg conversion failed: {e.stderr.decode()}")
            raise RuntimeError(f"Audio conversion failed: {e.stderr.decode()}")

        # Best-effort cleanup of the original file.
        try:
            os.remove(audio_path)
        except OSError as e:
            logger.warning(f"Failed to remove original file: {e}")

        return proc.stdout

    @staticmethod
    def read_wav_info(data: bytes) -> Tuple[int, int, int, int, bytes]:
        """Parse a WAV byte string.

        Returns:
            (channels, sample_width_bytes, sample_rate, frame_count, pcm_data).

        Raises:
            ValueError: malformed header or missing ``data`` sub-chunk.
        """
        if len(data) < 44:
            raise ValueError("Invalid WAV file: too short")
        if data[:4] != b'RIFF':
            raise ValueError("Invalid WAV file: not RIFF format")
        if data[8:12] != b'WAVE':
            raise ValueError("Invalid WAV file: not WAVE format")

        # Fixed-offset reads from the canonical fmt sub-chunk.
        audio_format, num_channels = struct.unpack_from('<HH', data, 20)
        (sample_rate,) = struct.unpack_from('<I', data, 24)
        (bits_per_sample,) = struct.unpack_from('<H', data, 34)

        # Walk sub-chunks until the 'data' chunk shows up.
        offset = 36
        while offset < len(data) - 8:
            chunk_tag = data[offset:offset + 4]
            (chunk_len,) = struct.unpack_from('<I', data, offset + 4)
            if chunk_tag == b'data':
                pcm = data[offset + 8:offset + 8 + chunk_len]
                bytes_per_sample = bits_per_sample // 8
                frames = chunk_len // (num_channels * bytes_per_sample)
                return (num_channels, bytes_per_sample, sample_rate, frames, pcm)
            offset += 8 + chunk_len

        raise ValueError("Invalid WAV file: no data subchunk found")
class AsrRequestHeader:
    """Fluent builder for the 4-byte binary frame header."""

    def __init__(self):
        # Defaults describe a full client request: JSON payload, gzipped,
        # positive sequence numbering.
        self.message_type = MessageType.CLIENT_FULL_REQUEST
        self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE
        self.serialization_type = SerializationType.JSON
        self.compression_type = CompressionType.GZIP
        self.reserved_data = bytes([0x00])

    def with_message_type(self, message_type: int) -> 'AsrRequestHeader':
        """Override the message type; returns self for chaining."""
        self.message_type = message_type
        return self

    def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader':
        """Override the sequence flags; returns self for chaining."""
        self.message_type_specific_flags = flags
        return self

    def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader':
        """Override the serialization method; returns self for chaining."""
        self.serialization_type = serialization_type
        return self

    def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader':
        """Override the compression method; returns self for chaining."""
        self.compression_type = compression_type
        return self

    def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader':
        """Override the reserved trailing byte(s); returns self for chaining."""
        self.reserved_data = reserved_data
        return self

    def to_bytes(self) -> bytes:
        """Serialize to the wire format.

        Byte 0: version (high nibble) | header size in words (low nibble);
        byte 1: message type | flags; byte 2: serialization | compression;
        remaining byte(s): reserved.
        """
        packed = bytes((
            (ProtocolVersion.V1 << 4) | 1,
            (self.message_type << 4) | self.message_type_specific_flags,
            (self.serialization_type << 4) | self.compression_type,
        ))
        return packed + self.reserved_data

    @staticmethod
    def default_header() -> 'AsrRequestHeader':
        """Return a fresh header populated with the default field values."""
        return AsrRequestHeader()
class RequestBuilder:
    """Builds the binary request frames sent to the ASR service."""

    @staticmethod
    def new_auth_headers() -> Dict[str, str]:
        """Return HTTP headers carrying the credentials plus a fresh request id."""
        reqid = str(uuid.uuid4())
        return {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Request-Id": reqid,
            "X-Api-Access-Key": config.access_key,
            "X-Api-App-Key": config.app_key
        }

    @staticmethod
    def new_full_client_request(seq: int) -> bytes:  # seq supplied by caller
        """Build the initial full-client frame (user/audio/request config).

        Frame layout: 4-byte header | big-endian signed seq |
        big-endian payload size | gzipped JSON payload.
        """
        header = AsrRequestHeader.default_header() \
            .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)

        payload = {
            "user": {
                "uid": "demo_uid"
            },
            "audio": {
                # Expected input: 16 kHz, 16-bit, mono, raw-codec WAV.
                "format": "wav",
                "codec": "raw",
                "rate": 16000,
                "bits": 16,
                "channel": 1
            },
            "request": {
                "model_name": "bigmodel",
                "enable_itn": True,
                "enable_punc": True,
                "enable_ddc": True,
                "show_utterances": True,
                "enable_nonstream": False
            }
        }

        payload_bytes = json.dumps(payload).encode('utf-8')
        compressed_payload = CommonUtils.gzip_compress(payload_bytes)
        payload_size = len(compressed_payload)

        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))  # the caller-supplied seq
        request.extend(struct.pack('>I', payload_size))
        request.extend(compressed_payload)

        return bytes(request)

    @staticmethod
    def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
        """Build an audio-only frame carrying one gzipped audio segment.

        The final segment is flagged NEG_WITH_SEQUENCE and carries a
        negated sequence number, which is how the server detects the end
        of the audio stream.
        """
        header = AsrRequestHeader.default_header()
        if is_last:  # the last package gets special treatment
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE)
            seq = -seq  # negative seq marks end-of-stream
        else:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST)

        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))

        compressed_segment = CommonUtils.gzip_compress(segment)
        request.extend(struct.pack('>I', len(compressed_segment)))
        request.extend(compressed_segment)

        return bytes(request)
class AsrResponse:
    """Parsed server frame: status code, event, sequencing info and payload."""

    def __init__(self):
        # Attribute creation order matters: to_dict() mirrors it.
        self.code = 0
        self.event = 0
        self.is_last_package = False
        self.payload_sequence = 0
        self.payload_size = 0
        self.payload_msg = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a snapshot of all response fields, in declaration order."""
        return dict(vars(self))
class ResponseParser:
    """Decodes binary server frames into :class:`AsrResponse` objects."""

    @staticmethod
    def parse_response(msg: bytes) -> AsrResponse:
        """Parse one binary frame.

        Decompression or deserialization failures are logged and the
        partially-filled response is returned rather than raising.
        """
        response = AsrResponse()

        # Header byte 0 low nibble = header size in 4-byte words;
        # byte 1 = message type | flags; byte 2 = serialization | compression.
        header_size = msg[0] & 0x0f
        message_type = msg[1] >> 4
        message_type_specific_flags = msg[1] & 0x0f
        serialization_method = msg[2] >> 4
        message_compression = msg[2] & 0x0f

        payload = msg[header_size*4:]

        # Flag bits: 0x01 carries a sequence number, 0x02 marks the final
        # package, 0x04 carries an event id.
        if message_type_specific_flags & 0x01:
            response.payload_sequence = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]
        if message_type_specific_flags & 0x02:
            response.is_last_package = True
        if message_type_specific_flags & 0x04:
            response.event = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]

        # Full responses prefix a payload size; error responses prefix an
        # error code plus payload size.
        if message_type == MessageType.SERVER_FULL_RESPONSE:
            response.payload_size = struct.unpack('>I', payload[:4])[0]
            payload = payload[4:]
        elif message_type == MessageType.SERVER_ERROR_RESPONSE:
            response.code = struct.unpack('>i', payload[:4])[0]
            response.payload_size = struct.unpack('>I', payload[4:8])[0]
            payload = payload[8:]

        if not payload:
            return response

        # Decompress the payload when the header says it is gzipped.
        if message_compression == CompressionType.GZIP:
            try:
                payload = CommonUtils.gzip_decompress(payload)
            except Exception as e:
                logger.error(f"Failed to decompress payload: {e}")
                return response

        # Deserialize JSON payloads; other serializations are left as None.
        try:
            if serialization_method == SerializationType.JSON:
                response.payload_msg = json.loads(payload.decode('utf-8'))
        except Exception as e:
            logger.error(f"Failed to parse payload: {e}")

        return response
class AsrWsClient:
    """Async WebSocket client that streams a local audio file to the ASR service.

    Intended usage::

        async with AsrWsClient(url) as client:
            async for response in client.execute(path):
                ...

    Audio is sent in fixed-duration segments (simulating a live stream)
    while responses are consumed concurrently.
    """

    def __init__(self, url: str, segment_duration: int = 200):
        # Sequence number for outgoing frames; reset to 1 per execution.
        self.seq = 1
        self.url = url
        # Milliseconds of audio carried by each packet.
        self.segment_duration = segment_duration
        self.conn = None
        # aiohttp session owning the websocket; created in __aenter__.
        self.session = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Close the websocket first, then the owning session.
        if self.conn and not self.conn.closed:
            await self.conn.close()
        if self.session and not self.session.closed:
            await self.session.close()

    async def read_audio_data(self, file_path: str) -> bytes:
        """Read *file_path*, transcoding to WAV via ffmpeg when needed."""
        try:
            with open(file_path, 'rb') as f:
                content = f.read()

            if not CommonUtils.judge_wav(content):
                logger.info("Converting audio to WAV format...")
                content = CommonUtils.convert_wav_with_path(file_path, DEFAULT_SAMPLE_RATE)

            return content
        except Exception as e:
            logger.error(f"Failed to read audio data: {e}")
            raise

    def get_segment_size(self, content: bytes) -> int:
        """Return the PCM byte count per packet for ``segment_duration`` ms."""
        try:
            channel_num, samp_width, frame_rate, _, _ = CommonUtils.read_wav_info(content)[:5]
            size_per_sec = channel_num * samp_width * frame_rate
            segment_size = size_per_sec * self.segment_duration // 1000
            return segment_size
        except Exception as e:
            logger.error(f"Failed to calculate segment size: {e}")
            raise

    async def create_connection(self) -> None:
        """Open the websocket using fresh authentication headers."""
        headers = RequestBuilder.new_auth_headers()
        try:
            self.conn = await self.session.ws_connect(
                self.url,
                headers=headers
            )
            logger.info(f"Connected to {self.url}")
        except Exception as e:
            logger.error(f"Failed to connect to WebSocket: {e}")
            raise

    async def send_full_client_request(self) -> None:
        """Send the configuration frame and log the server's first reply."""
        request = RequestBuilder.new_full_client_request(self.seq)
        self.seq += 1  # increment after the frame is built
        try:
            await self.conn.send_bytes(request)
            logger.info(f"Sent full client request with seq: {self.seq-1}")

            msg = await self.conn.receive()
            if msg.type == aiohttp.WSMsgType.BINARY:
                response = ResponseParser.parse_response(msg.data)
                logger.info(f"Received response: {response.to_dict()}")
            else:
                logger.error(f"Unexpected message type: {msg.type}")
        except Exception as e:
            logger.error(f"Failed to send full client request: {e}")
            raise

    async def send_messages(self, segment_size: int, content: bytes) -> AsyncGenerator[None, None]:
        """Send audio segments at a real-time pace, yielding after each one."""
        audio_segments = self.split_audio(content, segment_size)
        total_segments = len(audio_segments)

        for i, segment in enumerate(audio_segments):
            is_last = (i == total_segments - 1)
            request = RequestBuilder.new_audio_only_request(
                self.seq,
                segment,
                is_last=is_last
            )
            await self.conn.send_bytes(request)
            logger.info(f"Sent audio segment with seq: {self.seq} (last: {is_last})")

            # The last frame reuses the current seq (negated inside the
            # builder), so only bump it for non-final segments.
            if not is_last:
                self.seq += 1

            # Pace the stream like a live feed, one packet per duration.
            await asyncio.sleep(self.segment_duration / 1000)
            # Yield control so the receive side can run.
            yield

    async def recv_messages(self) -> AsyncGenerator[AsrResponse, None]:
        """Yield parsed responses until the last package or an error occurs."""
        try:
            async for msg in self.conn:
                if msg.type == aiohttp.WSMsgType.BINARY:
                    response = ResponseParser.parse_response(msg.data)
                    yield response

                    # Stop on the final package or any non-zero error code.
                    if response.is_last_package or response.code != 0:
                        break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(f"WebSocket error: {msg.data}")
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    logger.info("WebSocket connection closed")
                    break
        except Exception as e:
            logger.error(f"Error receiving messages: {e}")
            raise

    async def start_audio_stream(self, segment_size: int, content: bytes) -> AsyncGenerator[AsrResponse, None]:
        """Run sender and receiver concurrently, yielding receiver output."""
        async def sender():
            # Drive the sending generator to completion in the background.
            async for _ in self.send_messages(segment_size, content):
                pass

        # Launch the sender while this coroutine consumes responses.
        sender_task = asyncio.create_task(sender())

        try:
            async for response in self.recv_messages():
                yield response
        finally:
            # Always tear the sender down, even on early exit.
            sender_task.cancel()
            try:
                await sender_task
            except asyncio.CancelledError:
                pass

    @staticmethod
    def split_audio(data: bytes, segment_size: int) -> List[bytes]:
        """Split *data* into consecutive chunks of at most *segment_size* bytes."""
        if segment_size <= 0:
            return []

        segments = []
        for i in range(0, len(data), segment_size):
            end = i + segment_size
            if end > len(data):
                end = len(data)
            segments.append(data[i:end])
        return segments

    async def execute(self, file_path: str) -> AsyncGenerator[AsrResponse, None]:
        """Full pipeline: read file, connect, configure, stream, yield responses.

        Raises:
            ValueError: when *file_path* or the client URL is empty.
        """
        if not file_path:
            raise ValueError("File path is empty")

        if not self.url:
            raise ValueError("URL is empty")

        self.seq = 1

        try:
            # 1. Read (and possibly transcode) the audio file.
            content = await self.read_audio_data(file_path)

            # 2. Work out the per-packet byte count.
            segment_size = self.get_segment_size(content)

            # 3. Open the WebSocket connection.
            await self.create_connection()

            # 4. Send the full client (configuration) request.
            await self.send_full_client_request()

            # 5. Stream the audio and yield server responses.
            async for response in self.start_audio_stream(segment_size, content):
                yield response

        except Exception as e:
            logger.error(f"Error in ASR execution: {e}")
            raise
        finally:
            if self.conn:
                await self.conn.close()
async def main():
    """CLI entry point: stream a local audio file to the ASR service."""
    import argparse

    parser = argparse.ArgumentParser(description="ASR WebSocket Client")
    parser.add_argument("--file", type=str, required=True, help="Audio file path")

    # Available endpoints:
    #   wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
    #   wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async
    #   wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream
    parser.add_argument("--url", type=str, default="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream",
                        help="WebSocket URL")
    parser.add_argument("--seg-duration", type=int, default=200,
                        help="Audio duration(ms) per packet, default:200")

    args = parser.parse_args()

    # async with guarantees the aiohttp session and websocket are closed.
    async with AsrWsClient(args.url, args.seg_duration) as client:
        try:
            async for response in client.execute(args.file):
                logger.info(f"Received response: {json.dumps(response.to_dict(), indent=2, ensure_ascii=False)}")
        except Exception as e:
            logger.error(f"ASR processing failed: {e}")


if __name__ == "__main__":
    asyncio.run(main())


# Usage:
#   python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav
@ -1,532 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
语音识别模块
|
||||
基于 SAUC API 为录音文件提供语音识别功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import struct
|
||||
import gzip
|
||||
import uuid
|
||||
from typing import Optional, List, Dict, Any, AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Logging setup: INFO level to the default (console) handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Sample rate (Hz) declared in the full-client request payload.
DEFAULT_SAMPLE_RATE = 16000
class ProtocolVersion:
    """Wire-protocol version identifiers (upper nibble of header byte 0)."""
    V1 = 0b0001
class MessageType:
    """Frame message types (upper nibble of header byte 1)."""
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ERROR_RESPONSE = 0b1111
class MessageTypeSpecificFlags:
    """Sequence-number flags (lower nibble of header byte 1)."""
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    NEG_SEQUENCE = 0b0010
    NEG_WITH_SEQUENCE = 0b0011
class SerializationType:
    """Payload serialization methods (upper nibble of header byte 2)."""
    NO_SERIALIZATION = 0b0000
    JSON = 0b0001
class CompressionType:
    """Payload compression methods (lower nibble of header byte 2)."""
    GZIP = 0b0001
@dataclass
class RecognitionResult:
    """Result of one speech-recognition pass."""
    text: str            # recognized transcript
    confidence: float    # confidence score reported by the service
    is_final: bool       # True for a final (non-partial) result
    # Segment boundaries in seconds, when the service provides them —
    # presumably relative to the start of the audio; TODO confirm.
    start_time: Optional[float] = None
    end_time: Optional[float] = None
class AudioUtils:
    """Stateless helpers for compression and WAV parsing."""

    @staticmethod
    def gzip_compress(data: bytes) -> bytes:
        """Gzip-compress *data* and return the compressed bytes."""
        return gzip.compress(data)

    @staticmethod
    def gzip_decompress(data: bytes) -> bytes:
        """Inverse of :meth:`gzip_compress`."""
        return gzip.decompress(data)

    @staticmethod
    def is_wav_file(data: bytes) -> bool:
        """Return True when *data* looks like a RIFF/WAVE file (>= 44 bytes)."""
        return len(data) >= 44 and data[:4] == b'RIFF' and data[8:12] == b'WAVE'

    @staticmethod
    def read_wav_info(data: bytes) -> tuple:
        """Parse a WAV byte string.

        Returns:
            (channels, sample_width_bytes, sample_rate, frame_count, pcm_data).

        Raises:
            ValueError: malformed header or missing ``data`` sub-chunk.
        """
        if len(data) < 44:
            raise ValueError("Invalid WAV file: too short")
        if data[:4] != b'RIFF':
            raise ValueError("Invalid WAV file: not RIFF format")
        if data[8:12] != b'WAVE':
            raise ValueError("Invalid WAV file: not WAVE format")

        # Fixed-offset reads from the canonical fmt sub-chunk.
        audio_format, num_channels = struct.unpack_from('<HH', data, 20)
        (sample_rate,) = struct.unpack_from('<I', data, 24)
        (bits_per_sample,) = struct.unpack_from('<H', data, 34)

        # Walk sub-chunks until the 'data' chunk shows up.
        offset = 36
        while offset < len(data) - 8:
            chunk_tag = data[offset:offset + 4]
            (chunk_len,) = struct.unpack_from('<I', data, offset + 4)
            if chunk_tag == b'data':
                pcm = data[offset + 8:offset + 8 + chunk_len]
                bytes_per_sample = bits_per_sample // 8
                frames = chunk_len // (num_channels * bytes_per_sample)
                return (num_channels, bytes_per_sample, sample_rate, frames, pcm)
            offset += 8 + chunk_len

        raise ValueError("Invalid WAV file: no data subchunk found")
class AsrConfig:
    """Holds SAUC API credentials.

    Explicit arguments win; otherwise values fall back to the
    SAUC_APP_KEY / SAUC_ACCESS_KEY environment variables, then to
    placeholder strings.
    """

    def __init__(self, app_key: str = None, access_key: str = None):
        resolved_app = app_key or os.getenv("SAUC_APP_KEY", "your_app_key")
        resolved_access = access_key or os.getenv("SAUC_ACCESS_KEY", "your_access_key")
        self.auth = {"app_key": resolved_app, "access_key": resolved_access}

    @property
    def app_key(self) -> str:
        """The application key used in the X-Api-App-Key header."""
        return self.auth["app_key"]

    @property
    def access_key(self) -> str:
        """The access key used in the X-Api-Access-Key header."""
        return self.auth["access_key"]
class AsrRequestHeader:
    """Fluent builder for the 4-byte binary frame header."""

    def __init__(self):
        # Defaults describe a full client request: JSON payload, gzipped,
        # positive sequence numbering.
        self.message_type = MessageType.CLIENT_FULL_REQUEST
        self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE
        self.serialization_type = SerializationType.JSON
        self.compression_type = CompressionType.GZIP
        self.reserved_data = bytes([0x00])

    def with_message_type(self, message_type: int) -> 'AsrRequestHeader':
        """Override the message type; returns self for chaining."""
        self.message_type = message_type
        return self

    def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader':
        """Override the sequence flags; returns self for chaining."""
        self.message_type_specific_flags = flags
        return self

    def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader':
        """Override the serialization method; returns self for chaining."""
        self.serialization_type = serialization_type
        return self

    def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader':
        """Override the compression method; returns self for chaining."""
        self.compression_type = compression_type
        return self

    def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader':
        """Override the reserved trailing byte(s); returns self for chaining."""
        self.reserved_data = reserved_data
        return self

    def to_bytes(self) -> bytes:
        """Serialize: version/size, type/flags, serialization/compression, reserved."""
        header = bytearray()
        header.append((ProtocolVersion.V1 << 4) | 1)
        header.append((self.message_type << 4) | self.message_type_specific_flags)
        header.append((self.serialization_type << 4) | self.compression_type)
        header.extend(self.reserved_data)
        return bytes(header)

    @staticmethod
    def default_header() -> 'AsrRequestHeader':
        """Return a fresh header populated with the default field values."""
        return AsrRequestHeader()
class RequestBuilder:
    """Builds the binary request frames sent to the SAUC service."""

    @staticmethod
    def new_auth_headers(config: AsrConfig) -> Dict[str, str]:
        """Return HTTP headers carrying the credentials plus a fresh request id."""
        reqid = str(uuid.uuid4())
        return {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Request-Id": reqid,
            "X-Api-Access-Key": config.access_key,
            "X-Api-App-Key": config.app_key
        }

    @staticmethod
    def new_full_client_request(seq: int) -> bytes:
        """Build the initial full-client frame (user/audio/request config).

        Frame layout: 4-byte header | big-endian signed seq |
        big-endian unsigned payload size | gzipped JSON payload.
        """
        header = AsrRequestHeader.default_header() \
            .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)

        payload = {
            "user": {
                "uid": "local_voice_user"
            },
            "audio": {
                # Expected input: 16 kHz, 16-bit, mono, raw-codec WAV.
                "format": "wav",
                "codec": "raw",
                "rate": 16000,
                "bits": 16,
                "channel": 1
            },
            "request": {
                "model_name": "bigmodel",
                "enable_itn": True,
                "enable_punc": True,
                "enable_ddc": True,
                "show_utterances": True,
                "enable_nonstream": False
            }
        }

        payload_bytes = json.dumps(payload).encode('utf-8')
        compressed_payload = AudioUtils.gzip_compress(payload_bytes)
        payload_size = len(compressed_payload)

        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))
        # Fix: was struct.pack('>U', ...) — 'U' is not a valid struct
        # format character and raises struct.error; '>I' (big-endian
        # unsigned int) matches the protocol and the sibling demo client.
        request.extend(struct.pack('>I', payload_size))
        request.extend(compressed_payload)

        return bytes(request)

    @staticmethod
    def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
        """Build an audio-only frame carrying one gzipped audio segment.

        The final segment is flagged NEG_WITH_SEQUENCE and carries a
        negated sequence number, which is how the server detects the end
        of the audio stream.
        """
        header = AsrRequestHeader.default_header()
        if is_last:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE)
            seq = -seq  # negative seq marks end-of-stream
        else:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST)

        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))

        compressed_segment = AudioUtils.gzip_compress(segment)
        # Fix: was struct.pack('>U', ...) — invalid format char (see above).
        request.extend(struct.pack('>I', len(compressed_segment)))
        request.extend(compressed_segment)

        return bytes(request)
class AsrResponse:
    """Container for one parsed ASR server response frame."""

    def __init__(self):
        # Error code (0 = success) and protocol event number.
        self.code = 0
        self.event = 0
        # True when the server marks this frame as the final package.
        self.is_last_package = False
        # Sequence number and declared size of the payload.
        self.payload_sequence = 0
        self.payload_size = 0
        # Decoded JSON payload, or None when absent or unparsable.
        self.payload_msg = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the response fields as a plain dictionary."""
        field_names = ("code", "event", "is_last_package",
                       "payload_sequence", "payload_size", "payload_msg")
        return {name: getattr(self, name) for name in field_names}
|
||||
|
||||
class ResponseParser:
    """Parses binary server frames into AsrResponse objects."""

    @staticmethod
    def parse_response(msg: bytes) -> AsrResponse:
        """Parse one server frame.

        Frame layout: 4-byte header (header size expressed in 32-bit
        words), optional int32 sequence/event fields selected by flag
        bits, an optional payload-size / error-code prefix depending on
        the message type, then the (possibly gzip-compressed) payload.

        Args:
            msg: Raw binary frame received from the server.

        Returns:
            A populated AsrResponse; decompression/JSON errors are logged
            and leave ``payload_msg`` as None.
        """
        response = AsrResponse()

        # Header fields live in the nibbles of the first three bytes.
        header_size = msg[0] & 0x0f
        message_type = msg[1] >> 4
        message_type_specific_flags = msg[1] & 0x0f
        serialization_method = msg[2] >> 4
        message_compression = msg[2] & 0x0f

        # header_size counts 4-byte words, so the payload starts here.
        payload = msg[header_size*4:]

        # Flag bit 0: an int32 payload sequence precedes the payload.
        if message_type_specific_flags & 0x01:
            response.payload_sequence = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]
        # Flag bit 1: this frame is the last package.
        if message_type_specific_flags & 0x02:
            response.is_last_package = True
        # Flag bit 2: an int32 event code precedes the payload.
        if message_type_specific_flags & 0x04:
            response.event = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]

        # Message-type-specific prefix before the payload body.
        if message_type == MessageType.SERVER_FULL_RESPONSE:
            # BUGFIX: '>U' is not a valid struct format character and
            # raises struct.error; the payload size is a big-endian
            # unsigned 32-bit integer, i.e. '>I'.
            response.payload_size = struct.unpack('>I', payload[:4])[0]
            payload = payload[4:]
        elif message_type == MessageType.SERVER_ERROR_RESPONSE:
            response.code = struct.unpack('>i', payload[:4])[0]
            # BUGFIX: '>U' -> '>I' (see above).
            response.payload_size = struct.unpack('>I', payload[4:8])[0]
            payload = payload[8:]

        if not payload:
            return response

        # Decompress when the frame declares gzip compression.
        if message_compression == CompressionType.GZIP:
            try:
                payload = AudioUtils.gzip_decompress(payload)
            except Exception as e:
                logger.error(f"Failed to decompress payload: {e}")
                return response

        # Decode the JSON payload; failures are logged, not raised.
        try:
            if serialization_method == SerializationType.JSON:
                response.payload_msg = json.loads(payload.decode('utf-8'))
        except Exception as e:
            logger.error(f"Failed to parse payload: {e}")

        return response
|
||||
|
||||
class SpeechRecognizer:
    """Streaming speech recognizer for the ByteDance SAUC big-model ASR API.

    Opens a WebSocket session, sends a full client request followed by
    gzip-compressed audio segments, and collects recognition results.
    """

    def __init__(self, app_key: str = None, access_key: str = None,
                 url: str = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream"):
        # Credentials are wrapped in an AsrConfig; `seq` is the protocol
        # sequence number, starting at 1 and advanced as requests are sent.
        self.config = AsrConfig(app_key, access_key)
        self.url = url
        self.seq = 1

    async def recognize_file(self, file_path: str) -> List[RecognitionResult]:
        """Recognize a WAV audio file.

        Args:
            file_path: Path to a WAV file. A sample rate other than
                DEFAULT_SAMPLE_RATE is accepted with a warning only.

        Returns:
            A list of RecognitionResult objects, one per response carrying
            text, plus a synthesized final result when the server never
            marked one as final.

        Raises:
            FileNotFoundError: If `file_path` does not exist.
            ValueError: If the file content is not WAV.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        results = []

        try:
            async with aiohttp.ClientSession() as session:
                # Read the whole audio file into memory.
                with open(file_path, 'rb') as f:
                    content = f.read()

                if not AudioUtils.is_wav_file(content):
                    raise ValueError("Audio file must be in WAV format")

                # Extract audio parameters; only the sample rate and the
                # raw PCM payload are used below.
                try:
                    _, _, sample_rate, _, audio_data = AudioUtils.read_wav_info(content)
                    if sample_rate != DEFAULT_SAMPLE_RATE:
                        logger.warning(f"Sample rate {sample_rate} != {DEFAULT_SAMPLE_RATE}, may affect recognition accuracy")
                except Exception as e:
                    logger.error(f"Failed to read audio info: {e}")
                    raise

                # Segment size for 200 ms of audio.
                segment_size = 1 * 2 * DEFAULT_SAMPLE_RATE * 200 // 1000  # channel * bytes_per_sample * sample_rate * duration_ms / 1000

                # Open the WebSocket connection with fresh auth headers.
                headers = RequestBuilder.new_auth_headers(self.config)
                async with session.ws_connect(self.url, headers=headers) as ws:

                    # Send the full client request describing user/audio/request.
                    request = RequestBuilder.new_full_client_request(self.seq)
                    self.seq += 1
                    await ws.send_bytes(request)

                    # Receive the initial handshake response.
                    msg = await ws.receive()
                    if msg.type == aiohttp.WSMsgType.BINARY:
                        response = ResponseParser.parse_response(msg.data)
                        logger.info(f"Initial response: {response.to_dict()}")

                    # Send the audio data in segments.
                    audio_segments = self._split_audio(audio_data, segment_size)
                    total_segments = len(audio_segments)

                    for i, segment in enumerate(audio_segments):
                        is_last = (i == total_segments - 1)
                        request = RequestBuilder.new_audio_only_request(
                            self.seq,
                            segment,
                            is_last=is_last
                        )
                        await ws.send_bytes(request)
                        logger.info(f"Sent audio segment {i+1}/{total_segments}")

                        # The last segment's sequence number is negated by
                        # the request builder, so `seq` is only advanced for
                        # non-final segments.
                        if not is_last:
                            self.seq += 1

                        # Small delay to simulate a real-time stream.
                        await asyncio.sleep(0.1)

                    # Receive recognition results until the final package,
                    # an error code, or the connection ends.
                    final_text = ""
                    while True:
                        msg = await ws.receive()
                        if msg.type == aiohttp.WSMsgType.BINARY:
                            response = ResponseParser.parse_response(msg.data)

                            if response.payload_msg and 'text' in response.payload_msg:
                                text = response.payload_msg['text']
                                if text:
                                    final_text += text

                                    result = RecognitionResult(
                                        text=text,
                                        confidence=0.9,  # default confidence (server provides none)
                                        is_final=response.is_last_package
                                    )
                                    results.append(result)

                                    logger.info(f"Recognized: {text}")

                            if response.is_last_package or response.code != 0:
                                break

                        elif msg.type == aiohttp.WSMsgType.ERROR:
                            logger.error(f"WebSocket error: {msg.data}")
                            break
                        elif msg.type == aiohttp.WSMsgType.CLOSED:
                            logger.info("WebSocket connection closed")
                            break

                    # If no result was marked final, synthesize one result
                    # containing all accumulated text.
                    if final_text and not any(r.is_final for r in results):
                        final_result = RecognitionResult(
                            text=final_text,
                            confidence=0.9,
                            is_final=True
                        )
                        results.append(final_result)

                    return results

        except Exception as e:
            logger.error(f"Speech recognition failed: {e}")
            raise

    def _split_audio(self, data: bytes, segment_size: int) -> List[bytes]:
        """Split raw audio bytes into consecutive `segment_size` chunks.

        The last chunk may be shorter; a non-positive `segment_size`
        yields an empty list.
        """
        if segment_size <= 0:
            return []

        segments = []
        for i in range(0, len(data), segment_size):
            end = i + segment_size
            if end > len(data):
                end = len(data)
            segments.append(data[i:end])
        return segments

    async def recognize_latest_recording(self, directory: str = ".") -> Optional[RecognitionResult]:
        """Recognize the newest `recording_*.wav` file in `directory`.

        Returns:
            The last final result (or, failing that, the last result), or
            None when no recording exists or recognition fails.
        """
        # Collect candidate recording files by name pattern.
        recording_files = [f for f in os.listdir(directory) if f.startswith('recording_') and f.endswith('.wav')]

        if not recording_files:
            logger.warning("No recording files found")
            return None

        # Sort by filename, newest first (names embed a timestamp).
        recording_files.sort(reverse=True)
        latest_file = recording_files[0]
        latest_path = os.path.join(directory, latest_file)

        logger.info(f"Recognizing latest recording: {latest_file}")

        try:
            results = await self.recognize_file(latest_path)
            if results:
                # Prefer results explicitly marked final.
                final_results = [r for r in results if r.is_final]
                if final_results:
                    return final_results[-1]
                else:
                    # No result marked final: fall back to the last one.
                    return results[-1]
        except Exception as e:
            logger.error(f"Failed to recognize latest recording: {e}")

        return None
|
||||
|
||||
async def main():
    """CLI test driver: recognize a given file or the latest recording."""
    import argparse

    parser = argparse.ArgumentParser(description="语音识别测试")
    parser.add_argument("--file", type=str, help="音频文件路径")
    parser.add_argument("--latest", action="store_true", help="识别最新的录音文件")
    parser.add_argument("--app-key", type=str, help="SAUC App Key")
    parser.add_argument("--access-key", type=str, help="SAUC Access Key")
    args = parser.parse_args()

    recognizer = SpeechRecognizer(app_key=args.app_key, access_key=args.access_key)

    try:
        if args.latest:
            outcome = await recognizer.recognize_latest_recording()
            if outcome:
                print(f"识别结果: {outcome.text}")
                print(f"置信度: {outcome.confidence}")
                print(f"最终结果: {outcome.is_final}")
            else:
                print("未能识别到语音内容")
        elif args.file:
            all_results = await recognizer.recognize_file(args.file)
            for idx, item in enumerate(all_results):
                print(f"结果 {idx+1}: {item.text}")
                print(f"置信度: {item.confidence}")
                print(f"最终结果: {item.is_final}")
                print("-" * 40)
        else:
            print("请指定 --file 或 --latest 参数")
    except Exception as e:
        print(f"识别失败: {e}")
|
||||
|
||||
# Script entry point: run the CLI test driver on the asyncio event loop.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
113
test_doubao.py
113
test_doubao.py
@ -1,113 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
豆包音频处理模块 - 验证脚本
|
||||
验证完整的音频处理流程
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
import os
|
||||
from doubao_simple import DoubaoClient
|
||||
|
||||
async def test_complete_workflow():
    """End-to-end check of the Doubao audio processing pipeline.

    Verifies: input file presence, server connection, audio processing,
    output file generation, and a best-effort playback via `aplay`.

    Returns:
        True when the whole workflow succeeded, False otherwise.
    """
    print("=== 豆包音频处理模块验证 ===")

    # Check the input file exists.
    input_file = "recording_20250920_135137.wav"
    if not os.path.exists(input_file):
        print(f"❌ 输入文件不存在: {input_file}")
        return False

    print(f"✅ 输入文件存在: {input_file}")

    # Report the input file type (best-effort; `file` may be unavailable).
    try:
        result = subprocess.run(['file', input_file], capture_output=True, text=True)
        print(f"📁 输入文件格式: {result.stdout.strip()}")
    except Exception:
        # BUGFIX: was a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit; this probe is deliberately optional.
        pass

    # Initialize the client.
    client = DoubaoClient()

    try:
        # Connect to the server.
        print("🔌 连接豆包服务器...")
        await client.connect()
        print("✅ 连接成功")

        # Process the audio file.
        output_file = "tts_output.wav"
        print(f"🎵 处理音频文件: {input_file} -> {output_file}")

        success = await client.process_audio_file(input_file, output_file)

        if success:
            print("✅ 音频处理成功!")

            # Check the output file was produced.
            if os.path.exists(output_file):
                result = subprocess.run(['file', output_file], capture_output=True, text=True)
                print(f"📁 输出文件格式: {result.stdout.strip()}")

                # Report the output file size.
                file_size = os.path.getsize(output_file)
                print(f"📊 输出文件大小: {file_size:,} 字节")

                # Best-effort playback test.
                print("🔊 测试播放输出文件...")
                try:
                    subprocess.run(['aplay', output_file], timeout=10, check=True)
                    print("✅ 播放成功")
                except subprocess.TimeoutExpired:
                    print("✅ 播放完成(超时是正常的)")
                except subprocess.CalledProcessError as e:
                    print(f"⚠️ 播放出现问题: {e}")
                except FileNotFoundError:
                    print("⚠️ aplay命令不存在,跳过播放测试")

                return True
            else:
                print("❌ 输出文件未生成")
                return False
        else:
            print("❌ 音频处理失败")
            return False

    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # Always try to close the client; ignore close-time errors so they
        # don't mask the test outcome.
        try:
            await client.close()
        except Exception:
            # BUGFIX: narrowed from a bare `except:`.
            pass
|
||||
|
||||
def main():
    """Drive the end-to-end verification and print a human-readable summary.

    Returns:
        True when the workflow passed, False otherwise.
    """
    print("开始验证豆包音频处理模块...")

    passed = asyncio.run(test_complete_workflow())

    if not passed:
        print("\n❌ 验证失败,请检查错误信息。")
        return passed

    print("\n🎉 验证完成!豆包音频处理模块工作正常。")
    print("\n📋 功能总结:")
    for feature in (
        "WebSocket连接建立",
        "音频文件上传",
        "语音识别",
        "TTS音频生成",
        "音频格式转换(Float32 -> Int16)",
        "WAV文件生成",
        "树莓派兼容播放",
    ):
        print(f" ✅ {feature}")

    return passed
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in New Issue
Block a user