This commit is contained in:
朱潮 2025-09-20 15:44:46 +08:00
parent 43879961a2
commit dbdeeeefcb
10 changed files with 0 additions and 3460 deletions

470
doubao.py
View File

@ -1,470 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
豆包音频处理模块
简化版WebSocket API支持音频文件上传和返回音频处理
"""
import asyncio
import gzip
import json
import uuid
import wave
import struct
import time
import os
from typing import Dict, Any, Optional
import websockets
class DoubaoConfig:
    """Connection settings for the Doubao realtime dialogue service.

    Credentials fall back to the values previously hard-coded here, but can
    be overridden via environment variables so secrets need not live in
    source control.
    """

    def __init__(self):
        # WebSocket endpoint of the realtime dialogue API.
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        # SECURITY NOTE: these embedded credentials should be rotated and
        # removed from the repository; environment variables take precedence.
        self.app_id = os.environ.get("DOUBAO_APP_ID", "8718217928")
        self.access_key = os.environ.get("DOUBAO_ACCESS_KEY", "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc")
        self.app_key = os.environ.get("DOUBAO_APP_KEY", "PlgvMymc7f3tQnJ6")
        self.resource_id = "volc.speech.dialog"

    def get_headers(self) -> Dict[str, str]:
        """Build the HTTP headers required by the WebSocket handshake."""
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            # Fresh id per connection attempt, for server-side tracing.
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
class DoubaoProtocol:
    """Binary framing helpers for the Doubao realtime dialogue protocol.

    Frame layout (big-endian):
        byte 0: protocol version (high nibble) | header size in 4-byte words
        byte 1: message type (high nibble) | message-type-specific flags
        byte 2: serialization method (high nibble) | compression method
        byte 3: reserved
    followed by an optional 4-byte event id, a length-prefixed session id
    and a length-prefixed payload.
    """

    # Protocol constants (all 4-bit fields).
    PROTOCOL_VERSION = 0b0001
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ACK = 0b1011
    SERVER_ERROR_RESPONSE = 0b1111
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    MSG_WITH_EVENT = 0b0100
    JSON = 0b0001
    NO_SERIALIZATION = 0b0000
    GZIP = 0b0001

    @classmethod
    def generate_header(cls, message_type=CLIENT_FULL_REQUEST,
                        message_type_specific_flags=MSG_WITH_EVENT,
                        serial_method=JSON, compression_type=GZIP) -> bytes:
        """Build the fixed 4-byte protocol header (header size = 1 word)."""
        header = bytearray()
        header.append((cls.PROTOCOL_VERSION << 4) | 1)  # version + header_size
        header.append((message_type << 4) | message_type_specific_flags)
        header.append((serial_method << 4) | compression_type)
        header.append(0x00)  # reserved
        return bytes(header)

    @classmethod
    def parse_response(cls, res: bytes) -> Dict[str, Any]:
        """Parse one server frame into a dict.

        Fixes vs. the original: bounds checks on every length field (short
        frames report an 'error' key instead of raising IndexError), the
        payload is clamped to its declared size, and the session id is
        decoded to text instead of str(bytes) which produced "b'...'".
        """
        if isinstance(res, str):
            return {}
        result: Dict[str, Any] = {}
        if len(res) < 4:
            result['error'] = 'Frame shorter than protocol header'
            return result
        header_size = res[0] & 0x0f
        message_type = res[1] >> 4
        message_type_specific_flags = res[1] & 0x0f
        serialization_method = res[2] >> 4
        message_compression = res[2] & 0x0f
        payload = res[header_size * 4:]
        if message_type in (cls.SERVER_FULL_RESPONSE, cls.SERVER_ACK):
            result['message_type'] = ('SERVER_ACK' if message_type == cls.SERVER_ACK
                                      else 'SERVER_FULL_RESPONSE')
            if message_type_specific_flags & cls.MSG_WITH_EVENT:
                if len(payload) < 4:
                    result['error'] = 'Payload too short for event id'
                    return result
                result['event'] = int.from_bytes(payload[:4], "big", signed=False)
                payload = payload[4:]
            if len(payload) < 4:
                result['error'] = 'Payload too short for session_id size'
                return result
            session_id_size = int.from_bytes(payload[:4], "big", signed=True)
            if session_id_size < 0 or session_id_size > len(payload) - 4:
                result['error'] = f'Invalid session_id size: {session_id_size}'
                return result
            result['session_id'] = payload[4:4 + session_id_size].decode('utf-8', errors='replace')
            payload = payload[4 + session_id_size:]
            if len(payload) < 4:
                result['error'] = 'Payload too short for payload size'
                return result
            payload_size = int.from_bytes(payload[:4], "big", signed=False)
            result['payload_size'] = payload_size
            payload_msg = payload[4:4 + payload_size]  # clamp to declared size
            if payload_msg:
                if message_compression == cls.GZIP:
                    payload_msg = gzip.decompress(payload_msg)
                if serialization_method == cls.JSON:
                    payload_msg = json.loads(payload_msg.decode("utf-8"))
                result['payload_msg'] = payload_msg
        elif message_type == cls.SERVER_ERROR_RESPONSE:
            if len(payload) < 8:
                result['error'] = 'Error frame payload too short'
                return result
            result['code'] = int.from_bytes(payload[:4], "big", signed=False)
            result['payload_size'] = int.from_bytes(payload[4:8], "big", signed=False)
        return result
class AudioProcessor:
    """WAV read/write helpers."""

    @staticmethod
    def read_wav_file(file_path: str) -> tuple:
        """Read a WAV file and return (raw PCM frames, parameter dict)."""
        with wave.open(file_path, 'rb') as wf:
            params = {
                'channels': wf.getnchannels(),
                'sampwidth': wf.getsampwidth(),
                'framerate': wf.getframerate(),
                'nframes': wf.getnframes(),
            }
            frames = wf.readframes(params['nframes'])
        return frames, params

    @staticmethod
    def create_wav_file(audio_data: bytes, output_path: str,
                        sample_rate: int = 24000, channels: int = 1,
                        sampwidth: int = 2) -> None:
        """Write raw PCM bytes to a WAV file (defaults suit Raspberry-Pi playback)."""
        with wave.open(output_path, 'wb') as wf:
            wf.setnchannels(channels)
            wf.setsampwidth(sampwidth)
            wf.setframerate(sample_rate)
            wf.writeframes(audio_data)
class DoubaoClient:
    """WebSocket client for the Doubao realtime dialogue API.

    Lifecycle: connect() opens the socket and runs the
    StartConnection/StartSession handshake, send_audio_file() streams PCM
    chunks and collects the returned TTS audio, close() tears the session
    and connection down.
    """

    def __init__(self, config: DoubaoConfig):
        self.config = config
        # One session id per client instance, echoed in every session frame.
        self.session_id = str(uuid.uuid4())
        self.ws = None    # websockets connection, set by connect()
        self.log_id = ""  # server-side trace id from the handshake response

    async def connect(self) -> None:
        """Open the WebSocket connection and perform the start handshake."""
        print(f"连接豆包服务器: {self.config.base_url}")
        self.ws = await websockets.connect(
            self.config.base_url,
            additional_headers=self.config.get_headers(),
            ping_interval=None
        )
        # Grab the trace id; the attribute name differs between
        # websockets library versions.
        if hasattr(self.ws, 'response_headers'):
            self.log_id = self.ws.response_headers.get("X-Tt-Logid")
        elif hasattr(self.ws, 'headers'):
            self.log_id = self.ws.headers.get("X-Tt-Logid")
        print(f"连接成功, log_id: {self.log_id}")
        # Send the StartConnection request
        await self._send_start_connection()
        # Send the StartSession request
        await self._send_start_session()

    async def _send_start_connection(self) -> None:
        """Send the StartConnection frame and print the parsed reply."""
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(1).to_bytes(4, 'big'))  # event id 1 — presumably StartConnection
        payload_bytes = b"{}"  # empty JSON payload
        payload_bytes = gzip.compress(payload_bytes)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)
        await self.ws.send(request)
        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"StartConnection响应: {parsed_response}")

    async def _send_start_session(self) -> None:
        """Send the StartSession frame carrying the ASR/TTS/dialog configuration."""
        session_config = {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {
                    "channel": 1,
                    "format": "pcm",
                    "sample_rate": 24000
                },
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 10,
                    "input_mod": "audio_file",  # audio-file input mode
                },
            },
        }
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(100).to_bytes(4, 'big'))  # event id 100 — presumably StartSession
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())
        payload_bytes = json.dumps(session_config).encode()
        payload_bytes = gzip.compress(payload_bytes)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)
        await self.ws.send(request)
        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"StartSession响应: {parsed_response}")

    async def send_audio_file(self, file_path: str) -> bytes:
        """Stream a WAV file in ~200 ms chunks; return accumulated reply audio bytes."""
        print(f"处理音频文件: {file_path}")
        # Read the source audio
        audio_data, audio_info = AudioProcessor.read_wav_file(file_path)
        print(f"音频参数: {audio_info}")
        # Chunk size for ~200 ms: rate * channels * bytes-per-sample * 0.2
        chunk_size = int(audio_info['framerate'] * audio_info['channels'] *
                         audio_info['sampwidth'] * 0.2)
        # Send the audio in chunks
        total_chunks = (len(audio_data) + chunk_size - 1) // chunk_size
        print(f"开始分块发送音频,共 {total_chunks} 块")
        received_audio = b""
        error_count = 0
        for i in range(0, len(audio_data), chunk_size):
            chunk = audio_data[i:i + chunk_size]
            is_last = (i + chunk_size >= len(audio_data))
            # Send one audio chunk
            await self._send_audio_chunk(chunk, is_last)
            # Try to receive a response for this chunk
            try:
                response = await asyncio.wait_for(self.ws.recv(), timeout=2.0)
                parsed_response = DoubaoProtocol.parse_response(response)
                # Error frame from the server? Tolerate up to 3 before giving up.
                if 'code' in parsed_response and parsed_response['code'] != 0:
                    print(f"服务器返回错误: {parsed_response}")
                    error_count += 1
                    if error_count > 3:
                        raise Exception(f"服务器连续返回错误: {parsed_response}")
                    continue
                # Binary ACK payloads carry TTS audio; accumulate them.
                if (parsed_response.get('message_type') == 'SERVER_ACK' and
                    isinstance(parsed_response.get('payload_msg'), bytes)):
                    audio_chunk = parsed_response['payload_msg']
                    received_audio += audio_chunk
                    print(f"接收到音频数据块,大小: {len(audio_chunk)} 字节")
                # Session-state events; 152/153 treated as session end.
                # NOTE(review): event-id semantics inferred from usage here —
                # confirm against the Doubao protocol documentation.
                event = parsed_response.get('event')
                if event in [359, 152, 153]:
                    print(f"会话事件: {event}")
                    if event in [152, 153]:
                        print("检测到会话结束事件")
                        break
            except asyncio.TimeoutError:
                print("等待响应超时,继续发送")
            # Throttle to roughly simulate realtime streaming
            await asyncio.sleep(0.05)
        print("音频文件发送完成")
        return received_audio

    async def _send_audio_chunk(self, audio_data: bytes, is_last: bool = False) -> None:
        """Send one gzipped audio-only frame. `is_last` is currently unused."""
        request = bytearray(
            DoubaoProtocol.generate_header(
                message_type=DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST,
                message_type_specific_flags=DoubaoProtocol.NO_SEQUENCE,
                serial_method=DoubaoProtocol.NO_SERIALIZATION,  # raw audio, no serialization
                compression_type=DoubaoProtocol.GZIP
            )
        )
        request.extend(int(200).to_bytes(4, 'big'))  # event id 200 — audio task
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())
        # Compress the audio payload
        compressed_audio = gzip.compress(audio_data)
        request.extend(len(compressed_audio).to_bytes(4, 'big'))  # payload size(4 bytes)
        request.extend(compressed_audio)
        await self.ws.send(request)

    async def close(self) -> None:
        """Finish the session and connection, then close the socket."""
        if self.ws:
            try:
                # Send FinishSession
                await self._send_finish_session()
                # Send FinishConnection
                await self._send_finish_connection()
            except Exception as e:
                print(f"关闭会话时出错: {e}")
            finally:
                # Make sure the WebSocket itself gets closed regardless.
                try:
                    await self.ws.close()
                except:
                    pass
                print("连接已关闭")

    async def _send_finish_session(self) -> None:
        """Send the FinishSession frame (no reply is awaited)."""
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(102).to_bytes(4, 'big'))  # event id 102 — presumably FinishSession
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())
        payload_bytes = b"{}"
        payload_bytes = gzip.compress(payload_bytes)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)
        await self.ws.send(request)

    async def _send_finish_connection(self) -> None:
        """Send the FinishConnection frame and print the parsed reply."""
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(2).to_bytes(4, 'big'))  # event id 2 — presumably FinishConnection
        payload_bytes = b"{}"
        payload_bytes = gzip.compress(payload_bytes)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)
        await self.ws.send(request)
        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"FinishConnection响应: {parsed_response}")
class DoubaoProcessor:
    """High-level wrapper: sends one WAV file to Doubao and saves the reply audio."""

    def __init__(self):
        self.config = DoubaoConfig()
        self.client = DoubaoClient(self.config)

    async def process_audio_file(self, input_file: str, output_file: str = None) -> str:
        """Process one audio file end-to-end.

        Args:
            input_file: path of the WAV file to send.
            output_file: where to store the reply audio; auto-named when None.

        Returns:
            Path of the output WAV file.

        Raises:
            FileNotFoundError: if input_file does not exist.
        """
        if not os.path.exists(input_file):
            raise FileNotFoundError(f"音频文件不存在: {input_file}")
        if output_file is None:
            # Timestamped default, e.g. doubao_response_20250920_154446.wav
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            output_file = f"doubao_response_{timestamp}.wav"
        try:
            await self.client.connect()
            # Give the session a moment to settle before streaming.
            await asyncio.sleep(0.5)
            received_audio = await self.client.send_audio_file(input_file)
            if not received_audio:
                print("警告: 未接收到音频响应")
            else:
                print(f"总共接收到音频数据: {len(received_audio)} 字节")
                # Persist as 24 kHz / mono / 16-bit WAV for playback on the Pi.
                AudioProcessor.create_wav_file(
                    received_audio,
                    output_file,
                    sample_rate=24000,
                    channels=1,
                    sampwidth=2
                )
                print(f"响应音频已保存到: {output_file}")
                file_size = os.path.getsize(output_file)
                print(f"输出文件大小: {file_size} 字节")
            return output_file
        except Exception as e:
            print(f"处理音频文件时出错: {e}")
            import traceback
            traceback.print_exc()
            raise
        finally:
            await self.client.close()
async def main():
    """Manual test entry point: parse CLI args and run one file through Doubao."""
    import argparse
    parser = argparse.ArgumentParser(description="豆包音频处理测试")
    parser.add_argument("--input", type=str, required=True, help="输入音频文件路径")
    parser.add_argument("--output", type=str, help="输出音频文件路径")
    args = parser.parse_args()
    processor = DoubaoProcessor()
    try:
        output_file = await processor.process_audio_file(args.input, args.output)
    except Exception as e:
        print(f"处理失败: {e}")
    else:
        print(f"处理完成,输出文件: {output_file}")


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -1,540 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
豆包音频处理模块 - 调试版本
添加更多调试信息和错误处理
"""
import asyncio
import gzip
import json
import uuid
import wave
import struct
import time
import os
from typing import Dict, Any, Optional
import websockets
class DoubaoConfig:
    """Doubao realtime dialogue service settings (debug build)."""

    def __init__(self):
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        self.app_id = "8718217928"
        self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
        self.app_key = "PlgvMymc7f3tQnJ6"
        self.resource_id = "volc.speech.dialog"

    def get_headers(self) -> Dict[str, str]:
        """Assemble the handshake headers, adding a fresh per-connection id."""
        headers = {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
        }
        headers["X-Api-Connect-Id"] = str(uuid.uuid4())
        return headers
class DoubaoProtocol:
    """Doubao binary protocol framing (debug build with defensive parsing).

    Frame layout (big-endian): 4-byte header (version/size, type/flags,
    serialization/compression, reserved), then an optional 4-byte event id,
    a length-prefixed session id and a length-prefixed payload.
    """

    # Protocol constants (all 4-bit fields).
    PROTOCOL_VERSION = 0b0001
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ACK = 0b1011
    SERVER_ERROR_RESPONSE = 0b1111
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    MSG_WITH_EVENT = 0b0100
    JSON = 0b0001
    NO_SERIALIZATION = 0b0000
    GZIP = 0b0001

    @classmethod
    def generate_header(cls, message_type=CLIENT_FULL_REQUEST,
                        message_type_specific_flags=MSG_WITH_EVENT,
                        serial_method=JSON, compression_type=GZIP) -> bytes:
        """Build the fixed 4-byte protocol header (header size = 1 word)."""
        header = bytearray()
        header.append((cls.PROTOCOL_VERSION << 4) | 1)  # version + header_size
        header.append((message_type << 4) | message_type_specific_flags)
        header.append((serial_method << 4) | compression_type)
        header.append(0x00)  # reserved
        return bytes(header)

    @classmethod
    def parse_response(cls, res: bytes) -> Dict[str, Any]:
        """Parse one server frame; never raises, reports problems via result keys.

        Fixes vs. the original: the session id is decoded to text instead of
        str(bytes) (which produced "b'...'"), and the bare except is narrowed
        to Exception.
        """
        if isinstance(res, str):
            return {}
        result: Dict[str, Any] = {}
        try:
            header_size = res[0] & 0x0f
            message_type = res[1] >> 4
            message_type_specific_flags = res[1] & 0x0f
            serialization_method = res[2] >> 4
            message_compression = res[2] & 0x0f
            payload = res[header_size * 4:]
            if message_type in (cls.SERVER_FULL_RESPONSE, cls.SERVER_ACK):
                result['message_type'] = ('SERVER_ACK'
                                          if message_type == cls.SERVER_ACK
                                          else 'SERVER_FULL_RESPONSE')
                if message_type_specific_flags & cls.MSG_WITH_EVENT:
                    result['event'] = int.from_bytes(payload[:4], "big", signed=False)
                    payload = payload[4:]
                if len(payload) < 4:
                    result['error'] = 'Payload too short for session_id'
                    return result
                session_id_size = int.from_bytes(payload[:4], "big", signed=True)
                if session_id_size < 0 or session_id_size > len(payload) - 4:
                    result['error'] = f'Invalid session_id size: {session_id_size}'
                    return result
                result['session_id'] = payload[4:session_id_size + 4].decode('utf-8', errors='replace')
                payload = payload[4 + session_id_size:]
                if len(payload) < 4:
                    result['error'] = 'Payload too short for payload_size'
                    return result
                payload_size = int.from_bytes(payload[:4], "big", signed=False)
                result['payload_size'] = payload_size
                if len(payload) >= 4 + payload_size:
                    payload_msg = payload[4:4 + payload_size]
                    if payload_msg:
                        if message_compression == cls.GZIP:
                            try:
                                payload_msg = gzip.decompress(payload_msg)
                            except Exception as e:
                                result['decompress_error'] = str(e)
                                return result
                        if serialization_method == cls.JSON:
                            try:
                                payload_msg = json.loads(str(payload_msg, "utf-8"))
                            except Exception as e:
                                # Keep the raw text when JSON parsing fails.
                                result['json_error'] = str(e)
                                payload_msg = str(payload_msg, "utf-8")
                        elif serialization_method != cls.NO_SERIALIZATION:
                            payload_msg = str(payload_msg, "utf-8")
                        result['payload_msg'] = payload_msg
            elif message_type == cls.SERVER_ERROR_RESPONSE:
                if len(payload) >= 8:
                    result['code'] = int.from_bytes(payload[:4], "big", signed=False)
                    payload_size = int.from_bytes(payload[4:8], "big", signed=False)
                    result['payload_size'] = payload_size
                    if len(payload) >= 8 + payload_size:
                        payload_msg = payload[8:8 + payload_size]
                        if payload_msg and message_compression == cls.GZIP:
                            try:
                                payload_msg = gzip.decompress(payload_msg)
                            except Exception:
                                pass  # keep the compressed bytes on failure
                        result['payload_msg'] = payload_msg
        except Exception as e:
            result['parse_error'] = str(e)
        return result
class AudioProcessor:
    """Helpers for reading and writing WAV audio."""

    @staticmethod
    def read_wav_file(file_path: str) -> tuple:
        """Return (raw PCM frames, dict of WAV parameters) for *file_path*."""
        with wave.open(file_path, 'rb') as wav:
            frame_count = wav.getnframes()
            info = dict(
                channels=wav.getnchannels(),
                sampwidth=wav.getsampwidth(),
                framerate=wav.getframerate(),
                nframes=frame_count,
            )
            return wav.readframes(frame_count), info

    @staticmethod
    def create_wav_file(audio_data: bytes, output_path: str,
                        sample_rate: int = 24000, channels: int = 1,
                        sampwidth: int = 2) -> None:
        """Write raw PCM bytes to *output_path* as a WAV file (Pi-friendly defaults)."""
        with wave.open(output_path, 'wb') as wav:
            wav.setnchannels(channels)
            wav.setsampwidth(sampwidth)
            wav.setframerate(sample_rate)
            wav.writeframes(audio_data)
class DoubaoClient:
    """WebSocket client for the Doubao realtime dialogue API (debug build).

    Same flow as the base client, with extra logging, longer timeouts and
    error checks after each handshake step.
    """

    def __init__(self, config: DoubaoConfig):
        self.config = config
        # One session id per client instance, echoed in every session frame.
        self.session_id = str(uuid.uuid4())
        self.ws = None    # websockets connection, set by connect()
        self.log_id = ""  # server-side trace id from the handshake response

    async def connect(self) -> None:
        """Open the WebSocket connection and perform the start handshake."""
        print(f"连接豆包服务器: {self.config.base_url}")
        try:
            self.ws = await websockets.connect(
                self.config.base_url,
                additional_headers=self.config.get_headers(),
                ping_interval=None
            )
            # Grab the trace id; attribute name differs across websockets versions.
            if hasattr(self.ws, 'response_headers'):
                self.log_id = self.ws.response_headers.get("X-Tt-Logid")
            elif hasattr(self.ws, 'headers'):
                self.log_id = self.ws.headers.get("X-Tt-Logid")
            print(f"连接成功, log_id: {self.log_id}")
            # Send the StartConnection request
            await self._send_start_connection()
            # Send the StartSession request
            await self._send_start_session()
        except Exception as e:
            print(f"连接失败: {e}")
            raise

    async def _send_start_connection(self) -> None:
        """Send the StartConnection frame; raise if the reply fails to parse."""
        print("发送StartConnection请求...")
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(1).to_bytes(4, 'big'))  # event id 1 — presumably StartConnection
        payload_bytes = b"{}"  # empty JSON payload
        payload_bytes = gzip.compress(payload_bytes)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)
        await self.ws.send(request)
        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"StartConnection响应: {parsed_response}")
        # Surface parse-level errors immediately
        if 'error' in parsed_response:
            raise Exception(f"StartConnection解析错误: {parsed_response['error']}")

    async def _send_start_session(self) -> None:
        """Send the StartSession frame with ASR/TTS/dialog config; raise on parse error."""
        print("发送StartSession请求...")
        session_config = {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": "zh_female_vv_jupiter_bigtts",
                "audio_config": {
                    "channel": 1,
                    "format": "pcm",
                    "sample_rate": 24000
                },
            },
            "dialog": {
                "bot_name": "豆包",
                "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
                "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
                "location": {"city": "北京"},
                "extra": {
                    "strict_audit": False,
                    "audit_response": "支持客户自定义安全审核回复话术。",
                    "recv_timeout": 30,  # longer timeout than the base client
                    "input_mod": "audio_file",  # audio-file input mode
                },
            },
        }
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(100).to_bytes(4, 'big'))  # event id 100 — presumably StartSession
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())
        payload_bytes = json.dumps(session_config).encode()
        payload_bytes = gzip.compress(payload_bytes)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)
        await self.ws.send(request)
        response = await self.ws.recv()
        parsed_response = DoubaoProtocol.parse_response(response)
        print(f"StartSession响应: {parsed_response}")
        # Surface parse-level errors immediately
        if 'error' in parsed_response:
            raise Exception(f"StartSession解析错误: {parsed_response['error']}")
        # Give the session time to fully establish before streaming.
        await asyncio.sleep(1.0)

    async def send_audio_file(self, file_path: str) -> bytes:
        """Stream a WAV file in ~50 ms chunks; return accumulated reply audio bytes."""
        print(f"处理音频文件: {file_path}")
        # Read the source audio
        audio_data, audio_info = AudioProcessor.read_wav_file(file_path)
        print(f"音频参数: {audio_info}")
        # Smaller ~50 ms chunks than the base client, to avoid oversized frames.
        chunk_size = int(audio_info['framerate'] * audio_info['channels'] *
                         audio_info['sampwidth'] * 0.05)  # 50ms
        # Send the audio in chunks
        total_chunks = (len(audio_data) + chunk_size - 1) // chunk_size
        print(f"开始分块发送音频,共 {total_chunks} 块")
        received_audio = b""
        error_count = 0
        session_active = True
        for i in range(0, len(audio_data), chunk_size):
            if not session_active:
                print("会话已结束,停止发送")
                break
            chunk = audio_data[i:i + chunk_size]
            is_last = (i + chunk_size >= len(audio_data))
            # Send one audio chunk
            await self._send_audio_chunk(chunk, is_last)
            # Try to receive a response for this chunk
            try:
                response = await asyncio.wait_for(self.ws.recv(), timeout=3.0)
                parsed_response = DoubaoProtocol.parse_response(response)
                print(f"响应 {i//chunk_size + 1}/{total_chunks}: {parsed_response}")
                # Error frame from the server? Tolerate up to 3 before giving up.
                if 'code' in parsed_response and parsed_response['code'] != 0:
                    print(f"服务器返回错误: {parsed_response}")
                    error_count += 1
                    if error_count > 3:
                        raise Exception(f"服务器连续返回错误: {parsed_response}")
                    continue
                # Binary ACK payloads carry TTS audio; accumulate them.
                if (parsed_response.get('message_type') == 'SERVER_ACK' and
                    isinstance(parsed_response.get('payload_msg'), bytes)):
                    audio_chunk = parsed_response['payload_msg']
                    received_audio += audio_chunk
                    print(f"接收到音频数据块,大小: {len(audio_chunk)} 字节")
                # Session-state events; 152/153 treated as session end.
                # NOTE(review): event-id semantics inferred from usage here —
                # confirm against the Doubao protocol documentation.
                event = parsed_response.get('event')
                if event in [359, 152, 153]:
                    print(f"会话事件: {event}")
                    if event in [152, 153]:
                        print("检测到会话结束事件")
                        session_active = False
                        break
            except asyncio.TimeoutError:
                print("等待响应超时,继续发送")
            # Throttle to roughly simulate realtime streaming
            await asyncio.sleep(0.1)
        print("音频文件发送完成")
        return received_audio

    async def _send_audio_chunk(self, audio_data: bytes, is_last: bool = False) -> None:
        """Send one gzipped audio-only frame, logging sizes. `is_last` is unused."""
        request = bytearray(
            DoubaoProtocol.generate_header(
                message_type=DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST,
                message_type_specific_flags=DoubaoProtocol.NO_SEQUENCE,
                serial_method=DoubaoProtocol.NO_SERIALIZATION,  # raw audio, no serialization
                compression_type=DoubaoProtocol.GZIP
            )
        )
        request.extend(int(200).to_bytes(4, 'big'))  # event id 200 — audio task
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())
        # Compress the audio payload
        compressed_audio = gzip.compress(audio_data)
        payload_size = len(compressed_audio)
        request.extend(payload_size.to_bytes(4, 'big'))  # payload size(4 bytes)
        request.extend(compressed_audio)
        print(f"发送音频块 - 原始大小: {len(audio_data)}, 压缩后大小: {payload_size}, 总请求数据大小: {len(request)}")
        await self.ws.send(request)

    async def close(self) -> None:
        """Finish the session and connection, then close the socket."""
        if self.ws:
            try:
                # Send FinishSession
                await self._send_finish_session()
                # Send FinishConnection
                await self._send_finish_connection()
            except Exception as e:
                print(f"关闭会话时出错: {e}")
            finally:
                # Make sure the WebSocket itself gets closed regardless.
                try:
                    await self.ws.close()
                except:
                    pass
                print("连接已关闭")

    async def _send_finish_session(self) -> None:
        """Send the FinishSession frame (no reply is awaited)."""
        print("发送FinishSession请求...")
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(102).to_bytes(4, 'big'))  # event id 102 — presumably FinishSession
        request.extend(len(self.session_id).to_bytes(4, 'big'))
        request.extend(self.session_id.encode())
        payload_bytes = b"{}"
        payload_bytes = gzip.compress(payload_bytes)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)
        await self.ws.send(request)

    async def _send_finish_connection(self) -> None:
        """Send the FinishConnection frame; wait up to 5 s for the reply."""
        print("发送FinishConnection请求...")
        request = bytearray(DoubaoProtocol.generate_header())
        request.extend(int(2).to_bytes(4, 'big'))  # event id 2 — presumably FinishConnection
        payload_bytes = b"{}"
        payload_bytes = gzip.compress(payload_bytes)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))
        request.extend(payload_bytes)
        await self.ws.send(request)
        try:
            response = await asyncio.wait_for(self.ws.recv(), timeout=5.0)
            parsed_response = DoubaoProtocol.parse_response(response)
            print(f"FinishConnection响应: {parsed_response}")
        except asyncio.TimeoutError:
            print("FinishConnection响应超时")
class DoubaoProcessor:
    """High-level wrapper: sends one WAV file to Doubao and saves the reply audio."""

    def __init__(self):
        self.config = DoubaoConfig()
        self.client = DoubaoClient(self.config)

    async def process_audio_file(self, input_file: str, output_file: str = None) -> str:
        """Process one audio file end-to-end.

        Args:
            input_file: path of the WAV file to send.
            output_file: where to store the reply audio; auto-named when None.

        Returns:
            Path of the output WAV file.

        Raises:
            FileNotFoundError: if input_file does not exist.
        """
        if not os.path.exists(input_file):
            raise FileNotFoundError(f"音频文件不存在: {input_file}")
        # Default output name is timestamped, e.g. doubao_response_20250920_154446.wav
        if output_file is None:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            output_file = f"doubao_response_{timestamp}.wav"
        try:
            # Connect to the Doubao server (handshake included)
            await self.client.connect()
            # Stream the file and collect the reply audio
            received_audio = await self.client.send_audio_file(input_file)
            if received_audio:
                print(f"总共接收到音频数据: {len(received_audio)} 字节")
                # Persist as 24 kHz / mono / 16-bit WAV for playback on the Pi.
                AudioProcessor.create_wav_file(
                    received_audio,
                    output_file,
                    sample_rate=24000,  # sample rate of the returned audio
                    channels=1,
                    sampwidth=2  # 16-bit
                )
                print(f"响应音频已保存到: {output_file}")
                # Report the file size for a quick sanity check
                file_size = os.path.getsize(output_file)
                print(f"输出文件大小: {file_size} 字节")
            else:
                print("警告: 未接收到音频响应")
            return output_file
        except Exception as e:
            print(f"处理音频文件时出错: {e}")
            import traceback
            traceback.print_exc()
            raise
        finally:
            # Always tear the session/connection down.
            await self.client.close()
async def main():
    """Manual test entry point: parse CLI args and run one file through Doubao."""
    import argparse
    parser = argparse.ArgumentParser(description="豆包音频处理测试")
    parser.add_argument("--input", type=str, required=True, help="输入音频文件路径")
    parser.add_argument("--output", type=str, help="输出音频文件路径")
    args = parser.parse_args()
    processor = DoubaoProcessor()
    try:
        output_file = await processor.process_audio_file(args.input, args.output)
    except Exception as e:
        print(f"处理失败: {e}")
    else:
        print(f"处理完成,输出文件: {output_file}")


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -1,425 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
豆包音频处理模块 - 简化测试版本
专门测试完整的音频上传和TTS音频下载流程
"""
import asyncio
import gzip
import json
import uuid
import wave
import struct
import time
import os
from typing import Dict, Any, Optional
import websockets
# 直接复制原始豆包代码的协议常量
# --- Doubao binary protocol constants (copied from the reference client) ---
PROTOCOL_VERSION = 0b0001

# Message types (high nibble of header byte 1).
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# Message-type-specific flags (low nibble of header byte 1).
NO_SEQUENCE = 0b0000
POS_SEQUENCE = 0b0001
MSG_WITH_EVENT = 0b0100

# Serialization / compression nibbles (header byte 2).
NO_SERIALIZATION = 0b0000
JSON = 0b0001
GZIP = 0b0001


def generate_header(
    version=PROTOCOL_VERSION,
    message_type=CLIENT_FULL_REQUEST,
    message_type_specific_flags=MSG_WITH_EVENT,
    serial_method=JSON,
    compression_type=GZIP,
    reserved_data=0x00,
    extension_header=bytes()
):
    """Build the 4-byte protocol header plus optional extension bytes.

    Header size is counted in 4-byte words, so a bare header reports size 1.
    """
    size_words = len(extension_header) // 4 + 1
    out = bytearray((
        (version << 4) | size_words,
        (message_type << 4) | message_type_specific_flags,
        (serial_method << 4) | compression_type,
        reserved_data,
    ))
    out.extend(extension_header)
    return out
class DoubaoConfig:
    """Doubao realtime dialogue service settings (simplified test build)."""

    def __init__(self):
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        self.app_id = "8718217928"
        self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
        self.app_key = "PlgvMymc7f3tQnJ6"
        self.resource_id = "volc.speech.dialog"

    def get_headers(self) -> Dict[str, str]:
        """Return the handshake headers with a unique per-connection id."""
        connect_id = str(uuid.uuid4())
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            "X-Api-Connect-Id": connect_id,
        }
class DoubaoClient:
    """Doubao client - based on the original reference code."""

    def __init__(self, config: DoubaoConfig):
        self.config = config
        # One session id per client instance, echoed in every session frame.
        self.session_id = str(uuid.uuid4())
        self.ws = None    # websockets connection, set by connect()
        self.log_id = ""  # server-side trace id from the handshake response
async def connect(self) -> None:
    """Open the WebSocket connection and perform the start handshake."""
    print(f"连接豆包服务器: {self.config.base_url}")
    try:
        self.ws = await websockets.connect(
            self.config.base_url,
            additional_headers=self.config.get_headers(),
            ping_interval=None
        )
        # Grab the trace id; attribute name differs across websockets versions.
        if hasattr(self.ws, 'response_headers'):
            self.log_id = self.ws.response_headers.get("X-Tt-Logid")
        elif hasattr(self.ws, 'headers'):
            self.log_id = self.ws.headers.get("X-Tt-Logid")
        print(f"连接成功, log_id: {self.log_id}")
        # Send the StartConnection request
        await self._send_start_connection()
        # Send the StartSession request
        await self._send_start_session()
    except Exception as e:
        print(f"连接失败: {e}")
        raise
async def _send_start_connection(self) -> None:
    """Send the StartConnection frame; only logs the raw reply length."""
    print("发送StartConnection请求...")
    request = bytearray(generate_header())
    request.extend(int(1).to_bytes(4, 'big'))  # event id 1 — presumably StartConnection
    payload_bytes = b"{}"  # empty JSON payload
    payload_bytes = gzip.compress(payload_bytes)
    request.extend(len(payload_bytes).to_bytes(4, 'big'))
    request.extend(payload_bytes)
    await self.ws.send(request)
    response = await self.ws.recv()
    print(f"StartConnection响应长度: {len(response)}")
async def _send_start_session(self) -> None:
    """Send the StartSession frame carrying the ASR/TTS/dialog configuration."""
    print("发送StartSession请求...")
    session_config = {
        "asr": {
            "extra": {
                "end_smooth_window_ms": 1500,
            },
        },
        "tts": {
            "speaker": "zh_female_vv_jupiter_bigtts",
            "audio_config": {
                "channel": 1,
                "format": "pcm",
                "sample_rate": 24000
            },
        },
        "dialog": {
            "bot_name": "豆包",
            "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
            "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
            "location": {"city": "北京"},
            "extra": {
                "strict_audit": False,
                "audit_response": "支持客户自定义安全审核回复话术。",
                "recv_timeout": 30,
                "input_mod": "audio",  # NOTE: "audio" here, not "audio_file" as in the other variants
            },
        },
    }
    request = bytearray(generate_header())
    request.extend(int(100).to_bytes(4, 'big'))  # event id 100 — presumably StartSession
    request.extend(len(self.session_id).to_bytes(4, 'big'))
    request.extend(self.session_id.encode())
    payload_bytes = json.dumps(session_config).encode()
    payload_bytes = gzip.compress(payload_bytes)
    request.extend(len(payload_bytes).to_bytes(4, 'big'))
    request.extend(payload_bytes)
    await self.ws.send(request)
    response = await self.ws.recv()
    print(f"StartSession响应长度: {len(response)}")
    # Give the session time to fully establish before streaming.
    await asyncio.sleep(1.0)
async def task_request(self, audio: bytes) -> None:
    """Send one audio-only frame with a gzipped PCM payload (reference-code copy)."""
    task_request = bytearray(
        generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST,
                        serial_method=NO_SERIALIZATION))
    task_request.extend(int(200).to_bytes(4, 'big'))  # event id 200 — audio task
    task_request.extend((len(self.session_id)).to_bytes(4, 'big'))
    task_request.extend(str.encode(self.session_id))
    # Audio payload is gzip-compressed and length-prefixed.
    payload_bytes = gzip.compress(audio)
    task_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size(4 bytes)
    task_request.extend(payload_bytes)
    await self.ws.send(task_request)
async def test_full_dialog(self) -> None:
    """Run one full dialog turn against the service.

    Sends ~5 s of a local recording, parses the ASR response, then keeps
    reading until TTS audio (SERVER_ACK) arrives and writes it out as WAV.
    Side effects: reads ``recording_20250920_135137.wav`` and writes
    ``tts_response_*.wav`` / ``*.raw`` debug files into the CWD.
    """
    print("开始完整对话测试...")
    # Read the real recording used as the user's utterance.
    try:
        import wave
        with wave.open("recording_20250920_135137.wav", 'rb') as wf:
            # Read only the leading chunk of audio.
            total_frames = wf.getnframes()
            frames_to_read = min(total_frames, 80000)  # ~5 s (assumes 16 kHz — TODO confirm)
            audio_data = wf.readframes(frames_to_read)
            print(f"读取真实音频数据: {len(audio_data)} 字节")
            print(f"音频参数: 采样率={wf.getframerate()}, 通道数={wf.getnchannels()}, 采样宽度={wf.getsampwidth()}")
    except Exception as e:
        print(f"读取音频文件失败: {e}")
        return
    print(f"音频数据大小: {len(audio_data)}")
    try:
        # Send the audio in a single audio-only frame.
        print("发送音频数据...")
        await self.task_request(audio_data)
        print("音频数据发送成功")
        # Wait for the speech-recognition response.
        print("等待语音识别响应...")
        response = await asyncio.wait_for(self.ws.recv(), timeout=15.0)
        print(f"收到ASR响应长度: {len(response)}")
        # Parse the ASR response header (two 4-bit fields per byte).
        if len(response) >= 4:
            protocol_version = response[0] >> 4
            header_size = response[0] & 0x0f
            message_type = response[1] >> 4
            flags = response[1] & 0x0f
            print(f"ASR响应协议: version={protocol_version}, header_size={header_size}, message_type={message_type}, flags={flags}")
            if message_type == 9:  # SERVER_FULL_RESPONSE
                payload_start = header_size * 4  # header size is in 4-byte words
                payload = response[payload_start:]
                if len(payload) >= 4:
                    event = int.from_bytes(payload[:4], 'big')
                    print(f"ASR Event: {event}")
                    if len(payload) >= 8:
                        # Length-prefixed session id, then length-prefixed payload.
                        session_id_len = int.from_bytes(payload[4:8], 'big')
                        if len(payload) >= 8 + session_id_len:
                            session_id = payload[8:8+session_id_len].decode()
                            print(f"Session ID: {session_id}")
                            if len(payload) >= 12 + session_id_len:
                                payload_size = int.from_bytes(payload[8+session_id_len:12+session_id_len], 'big')
                                payload_data = payload[12+session_id_len:12+session_id_len+payload_size]
                                print(f"Payload size: {payload_size}")
                                # Decode the ASR JSON result, if any.
                                try:
                                    asr_result = json.loads(payload_data.decode('utf-8'))
                                    print(f"ASR结果: {asr_result}")
                                    # Extract the recognized text when present.
                                    if 'results' in asr_result and asr_result['results']:
                                        text = asr_result['results'][0].get('text', '')
                                        print(f"识别文本: {text}")
                                except Exception as e:
                                    print(f"解析ASR结果失败: {e}")
        # Keep reading frames until TTS audio arrives (or limits are hit).
        print("开始持续等待TTS音频响应...")
        response_count = 0
        max_responses = 10
        while response_count < max_responses:
            try:
                print(f"等待第 {response_count + 1} 个响应...")
                tts_response = await asyncio.wait_for(self.ws.recv(), timeout=30.0)
                print(f"收到响应 {response_count + 1},长度: {len(tts_response)}")
                # Parse this frame's header.
                if len(tts_response) >= 4:
                    tts_version = tts_response[0] >> 4
                    tts_header_size = tts_response[0] & 0x0f
                    tts_message_type = tts_response[1] >> 4
                    tts_flags = tts_response[1] & 0x0f
                    print(f"响应协议: version={tts_version}, header_size={tts_header_size}, message_type={tts_message_type}, flags={tts_flags}")
                    if tts_message_type == 11:  # SERVER_ACK (carries TTS audio)
                        tts_payload_start = tts_header_size * 4
                        tts_payload = tts_response[tts_payload_start:]
                        if len(tts_payload) >= 12:
                            tts_event = int.from_bytes(tts_payload[:4], 'big')
                            tts_session_len = int.from_bytes(tts_payload[4:8], 'big')
                            tts_session = tts_payload[8:8+tts_session_len].decode()
                            tts_audio_size = int.from_bytes(tts_payload[8+tts_session_len:12+tts_session_len], 'big')
                            tts_audio_data = tts_payload[12+tts_session_len:12+tts_session_len+tts_audio_size]
                            print(f"Event: {tts_event}")
                            print(f"音频数据大小: {tts_audio_size}")
                            if tts_audio_size > 0:
                                print("找到TTS音频数据")
                                # The audio payload is gzip-compressed PCM.
                                try:
                                    decompressed_tts = gzip.decompress(tts_audio_data)
                                    print(f"解压缩后TTS音频大小: {len(decompressed_tts)}")
                                    # Wrap the raw PCM into a playable WAV file
                                    # (format requested in StartSession: 24 kHz mono 16-bit).
                                    sample_rate = 24000
                                    channels = 1
                                    sampwidth = 2
                                    with wave.open(f'tts_response_{response_count}.wav', 'wb') as wav_file:
                                        wav_file.setnchannels(channels)
                                        wav_file.setsampwidth(sampwidth)
                                        wav_file.setframerate(sample_rate)
                                        wav_file.writeframes(decompressed_tts)
                                    print(f"成功创建TTS WAV文件: tts_response_{response_count}.wav")
                                    print(f"音频参数: {sample_rate}Hz, {channels}通道, {sampwidth*8}-bit")
                                    # Report the resulting file's size/duration.
                                    if os.path.exists(f'tts_response_{response_count}.wav'):
                                        file_size = os.path.getsize(f'tts_response_{response_count}.wav')
                                        duration = file_size / (sample_rate * channels * sampwidth)
                                        print(f"WAV文件大小: {file_size} 字节")
                                        print(f"音频时长: {duration:.2f}秒")
                                    # Audio obtained — stop waiting.
                                    break
                                except Exception as tts_e:
                                    print(f"TTS音频解压缩失败: {tts_e}")
                                    # Keep the raw bytes for offline inspection.
                                    with open(f'tts_response_audio_{response_count}.raw', 'wb') as f:
                                        f.write(tts_audio_data)
                                    print(f"原始TTS音频数据已保存到 tts_response_audio_{response_count}.raw")
                    elif tts_message_type == 9:  # SERVER_FULL_RESPONSE (JSON events)
                        tts_payload_start = tts_header_size * 4
                        tts_payload = tts_response[tts_payload_start:]
                        if len(tts_payload) >= 4:
                            event = int.from_bytes(tts_payload[:4], 'big')
                            print(f"Event: {event}")
                            if event in [451, 359]:  # presumably ASR result / TTS end — TODO confirm event ids
                                # Parse the length-prefixed session id + payload.
                                if len(tts_payload) >= 8:
                                    session_id_len = int.from_bytes(tts_payload[4:8], 'big')
                                    if len(tts_payload) >= 8 + session_id_len:
                                        session_id = tts_payload[8:8+session_id_len].decode()
                                        if len(tts_payload) >= 12 + session_id_len:
                                            payload_size = int.from_bytes(tts_payload[8+session_id_len:12+session_id_len], 'big')
                                            payload_data = tts_payload[12+session_id_len:12+session_id_len+payload_size]
                                            try:
                                                json_data = json.loads(payload_data.decode('utf-8'))
                                                print(f"JSON数据: {json_data}")
                                                # ASR result: show recognized text.
                                                if 'results' in json_data:
                                                    text = json_data['results'][0].get('text', '')
                                                    print(f"识别文本: {text}")
                                                # TTS end marker: stop the loop.
                                                if event == 359:
                                                    print("TTS响应结束")
                                                    break
                                            except Exception as e:
                                                print(f"解析JSON失败: {e}")
                                                # Keep undecodable payload for debugging.
                                                with open(f'tts_response_{response_count}.raw', 'wb') as f:
                                                    f.write(payload_data)
                    # Always keep the complete frame for debugging.
                    with open(f'tts_response_full_{response_count}.raw', 'wb') as f:
                        f.write(tts_response)
                    print(f"完整响应已保存到 tts_response_full_{response_count}.raw")
                response_count += 1
            except asyncio.TimeoutError:
                print(f"等待第 {response_count + 1} 个响应超时")
                break
            except websockets.exceptions.ConnectionClosed:
                print("连接已关闭")
                break
        print(f"共收到 {response_count} 个响应")
    except asyncio.TimeoutError:
        print("等待响应超时")
    except websockets.exceptions.ConnectionClosed as e:
        print(f"连接关闭: {e}")
    except Exception as e:
        print(f"测试失败: {e}")
        import traceback
        traceback.print_exc()
async def close(self) -> None:
    """Close the WebSocket connection, if any, swallowing network errors.

    Best-effort shutdown: the peer may already have dropped the link.
    FIX: the original used a bare ``except:``, which also swallows
    KeyboardInterrupt/SystemExit; narrowed to ``Exception``.
    """
    if self.ws:
        try:
            await self.ws.close()
        except Exception:
            pass
    print("连接已关闭")
async def main():
    """Smoke test: connect, run one full dialog turn, always close."""
    client = DoubaoClient(DoubaoConfig())
    try:
        await client.connect()
        await client.test_full_dialog()
    except Exception as e:
        print(f"测试失败: {e}")
        import traceback
        traceback.print_exc()
    finally:
        await client.close()


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -1,412 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
豆包音频处理模块 - 基于原始代码的测试版本
直接使用原始豆包代码的核心逻辑
"""
import asyncio
import gzip
import json
import uuid
import wave
import struct
import time
import os
from typing import Dict, Any, Optional
import websockets
# Protocol constants copied from the original Doubao sample code.
PROTOCOL_VERSION = 0b0001
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111
NO_SEQUENCE = 0b0000
POS_SEQUENCE = 0b0001
MSG_WITH_EVENT = 0b0100
NO_SERIALIZATION = 0b0000
JSON = 0b0001
GZIP = 0b0001


def generate_header(
    version=PROTOCOL_VERSION,
    message_type=CLIENT_FULL_REQUEST,
    message_type_specific_flags=MSG_WITH_EVENT,
    serial_method=JSON,
    compression_type=GZIP,
    reserved_data=0x00,
    extension_header=bytes()
):
    """Build the binary protocol header (from the original Doubao sample).

    Layout: byte0 = version|header_size, byte1 = type|flags,
    byte2 = serialization|compression, byte3 = reserved, then the optional
    extension header. ``header_size`` counts 4-byte words including byte0-3.

    Returns a ``bytearray`` so callers can keep extending it in place.
    """
    header = bytearray()
    # FIX: integer division instead of int(len(...) / 4) — same result for
    # word-aligned extensions, without the float round-trip.
    header_size = len(extension_header) // 4 + 1
    header.append((version << 4) | header_size)
    header.append((message_type << 4) | message_type_specific_flags)
    header.append((serial_method << 4) | compression_type)
    header.append(reserved_data)
    header.extend(extension_header)
    return header
class DoubaoConfig:
    """Connection settings and credentials for the Doubao realtime dialog API."""

    def __init__(self):
        # WebSocket endpoint plus the console-issued credentials.
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        self.app_id = "8718217928"
        self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
        self.app_key = "PlgvMymc7f3tQnJ6"
        self.resource_id = "volc.speech.dialog"

    def get_headers(self) -> Dict[str, str]:
        """Build the auth headers; a fresh Connect-Id is generated per call."""
        connect_id = str(uuid.uuid4())
        return {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
            "X-Api-Connect-Id": connect_id,
        }
class DoubaoClient:
    """Doubao client — based on the original sample code."""

    def __init__(self, config: DoubaoConfig):
        # Endpoint + credential configuration.
        self.config = config
        # One fresh session id per client instance.
        self.session_id = str(uuid.uuid4())
        # WebSocket handle; set by connect().
        self.ws = None
        # Server-side request log id (for support/debugging).
        self.log_id = ""
async def connect(self) -> None:
    """Open the WebSocket and perform the StartConnection/StartSession handshake.

    Raises: re-raises any connection/handshake failure after logging it.
    """
    print(f"连接豆包服务器: {self.config.base_url}")
    try:
        self.ws = await websockets.connect(
            self.config.base_url,
            additional_headers=self.config.get_headers(),
            ping_interval=None
        )
        # Grab the server log id; attribute name differs across
        # websockets library versions, hence the two hasattr probes.
        if hasattr(self.ws, 'response_headers'):
            self.log_id = self.ws.response_headers.get("X-Tt-Logid")
        elif hasattr(self.ws, 'headers'):
            self.log_id = self.ws.headers.get("X-Tt-Logid")
        print(f"连接成功, log_id: {self.log_id}")
        # Protocol handshake: StartConnection first, then StartSession.
        await self._send_start_connection()
        await self._send_start_session()
    except Exception as e:
        print(f"连接失败: {e}")
        raise
async def _send_start_connection(self) -> None:
    """Send the StartConnection frame (event 1) and wait for the ack."""
    print("发送StartConnection请求...")
    body = gzip.compress(b"{}")
    frame = bytearray(generate_header())
    frame.extend(int(1).to_bytes(4, 'big'))   # event 1 = StartConnection
    frame.extend(len(body).to_bytes(4, 'big'))
    frame.extend(body)
    await self.ws.send(frame)
    ack = await self.ws.recv()
    print(f"StartConnection响应长度: {len(ack)}")
async def _send_start_session(self) -> None:
    """Send the StartSession request (event 100) and wait for the ack.

    The gzip-compressed JSON payload configures ASR endpointing, the TTS
    voice/PCM format, and the dialog persona.
    """
    print("发送StartSession请求...")
    session_config = {
        "asr": {
            "extra": {
                "end_smooth_window_ms": 1500,
            },
        },
        "tts": {
            "speaker": "zh_female_vv_jupiter_bigtts",
            "audio_config": {
                "channel": 1,
                "format": "pcm",
                "sample_rate": 24000
            },
        },
        "dialog": {
            "bot_name": "豆包",
            "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
            "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
            "location": {"city": "北京"},
            "extra": {
                "strict_audit": False,
                "audit_response": "支持客户自定义安全审核回复话术。",
                "recv_timeout": 30,
                "input_mod": "audio",
            },
        },
    }
    # Frame layout: header | event | len(session_id) | session_id | len | payload.
    request = bytearray(generate_header())
    request.extend(int(100).to_bytes(4, 'big'))
    request.extend(len(self.session_id).to_bytes(4, 'big'))
    request.extend(self.session_id.encode())
    payload_bytes = json.dumps(session_config).encode()
    payload_bytes = gzip.compress(payload_bytes)
    request.extend(len(payload_bytes).to_bytes(4, 'big'))
    request.extend(payload_bytes)
    await self.ws.send(request)
    response = await self.ws.recv()
    print(f"StartSession响应长度: {len(response)}")
    # Brief pause so the session is fully established server-side.
    await asyncio.sleep(1.0)
async def task_request(self, audio: bytes) -> None:
    """Audio-only task request (event 200), copied from the original Doubao sample.

    Sends the gzip-compressed audio with an uncompressed, unserialized header.
    """
    task_request = bytearray(
        generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST,
                        serial_method=NO_SERIALIZATION))
    task_request.extend(int(200).to_bytes(4, 'big'))
    task_request.extend((len(self.session_id)).to_bytes(4, 'big'))
    task_request.extend(str.encode(self.session_id))
    payload_bytes = gzip.compress(audio)
    task_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size(4 bytes)
    task_request.extend(payload_bytes)
    await self.ws.send(task_request)
async def test_audio_request(self) -> None:
    """Send up to 10 s of a local recording and dump whatever comes back.

    Parses SERVER_FULL_RESPONSE / SERVER_ACK / SERVER_ERROR_RESPONSE frames
    and saves any audio payloads to .raw/.wav files for inspection.
    """
    print("测试音频请求...")
    # Read the real recording; fall back to silence if it is missing.
    try:
        import wave
        with wave.open("recording_20250920_135137.wav", 'rb') as wf:
            # Up to 160000 frames (~10 s at 16 kHz).
            total_frames = wf.getnframes()
            frames_to_read = min(total_frames, 160000)  # at most ~10 s
            small_audio = wf.readframes(frames_to_read)
            print(f"读取真实音频数据: {len(small_audio)} 字节")
            print(f"音频参数: 采样率={wf.getframerate()}, 通道数={wf.getnchannels()}, 采样宽度={wf.getsampwidth()}")
            print(f"总帧数: {total_frames}, 读取帧数: {frames_to_read}")
    except Exception as e:
        print(f"读取音频文件失败: {e}")
        # Fallback: a chunk of silence.
        small_audio = b'\x00' * 3200
    print(f"音频数据大小: {len(small_audio)}")
    try:
        # Send the whole buffer as a single audio-only frame.
        print(f"发送完整的音频数据块...")
        await self.task_request(small_audio)
        print(f"音频数据块发送成功")
        print("等待语音识别响应...")
        # Recognition can take a while; use a generous timeout.
        response = await asyncio.wait_for(self.ws.recv(), timeout=15.0)
        print(f"收到响应,长度: {len(response)}")
        # Parse the response frame.
        try:
            if len(response) >= 4:
                protocol_version = response[0] >> 4
                header_size = response[0] & 0x0f
                message_type = response[1] >> 4
                message_type_specific_flags = response[1] & 0x0f
                print(f"响应协议信息: version={protocol_version}, header_size={header_size}, message_type={message_type}, flags={message_type_specific_flags}")
                # Payload starts after header_size 4-byte words.
                payload_start = header_size * 4
                payload = response[payload_start:]
                if message_type == 9:  # SERVER_FULL_RESPONSE
                    print("收到SERVER_FULL_RESPONSE")
                    if len(payload) >= 4:
                        # Event code.
                        event = int.from_bytes(payload[:4], 'big')
                        print(f"Event: {event}")
                        # Length-prefixed session id.
                        if len(payload) >= 8:
                            session_id_len = int.from_bytes(payload[4:8], 'big')
                            if len(payload) >= 8 + session_id_len:
                                session_id = payload[8:8+session_id_len].decode()
                                print(f"Session ID: {session_id}")
                                # Length-prefixed payload data.
                                if len(payload) >= 12 + session_id_len:
                                    payload_size = int.from_bytes(payload[8+session_id_len:12+session_id_len], 'big')
                                    payload_data = payload[12+session_id_len:12+session_id_len+payload_size]
                                    print(f"Payload size: {payload_size}")
                                    # Save any data payload to a file.
                                    if len(payload_data) > 0:
                                        print(f"收到数据: {len(payload_data)} 字节")
                                        with open('response_audio.raw', 'wb') as f:
                                            f.write(payload_data)
                                        print("音频数据已保存到 response_audio.raw")
                                        # Try interpreting it as JSON first.
                                        try:
                                            import json
                                            json_data = json.loads(payload_data.decode('utf-8'))
                                            print(f"JSON数据: {json_data}")
                                            # ASR task started — wait for the audio reply.
                                            if 'asr_task_id' in json_data:
                                                print("语音识别任务开始,继续等待音频响应...")
                                                try:
                                                    audio_response = await asyncio.wait_for(self.ws.recv(), timeout=20.0)
                                                    print(f"收到音频响应,长度: {len(audio_response)}")
                                                    # Parse the audio response frame.
                                                    if len(audio_response) >= 4:
                                                        audio_version = audio_response[0] >> 4
                                                        audio_header_size = audio_response[0] & 0x0f
                                                        audio_message_type = audio_response[1] >> 4
                                                        audio_flags = audio_response[1] & 0x0f
                                                        print(f"音频响应协议信息: version={audio_version}, header_size={audio_header_size}, message_type={audio_message_type}, flags={audio_flags}")
                                                        if audio_message_type == 9:  # SERVER_FULL_RESPONSE (TTS audio)
                                                            audio_payload_start = audio_header_size * 4
                                                            audio_payload = audio_response[audio_payload_start:]
                                                            if len(audio_payload) >= 12:
                                                                # event | len(session) | session | len(audio) | audio
                                                                audio_event = int.from_bytes(audio_payload[:4], 'big')
                                                                audio_session_len = int.from_bytes(audio_payload[4:8], 'big')
                                                                audio_session = audio_payload[8:8+audio_session_len].decode()
                                                                audio_data_size = int.from_bytes(audio_payload[8+audio_session_len:12+audio_session_len], 'big')
                                                                audio_data = audio_payload[12+audio_session_len:12+audio_session_len+audio_data_size]
                                                                print(f"音频Event: {audio_event}")
                                                                print(f"音频数据大小: {audio_data_size}")
                                                                if audio_data_size > 0:
                                                                    # Keep the raw bytes.
                                                                    with open('tts_response_audio.raw', 'wb') as f:
                                                                        f.write(audio_data)
                                                                    print(f"TTS音频数据已保存到 tts_response_audio.raw")
                                                                    # The audio may be gzip-compressed PCM; try decompressing.
                                                                    try:
                                                                        import gzip
                                                                        decompressed_audio = gzip.decompress(audio_data)
                                                                        print(f"解压缩后音频数据大小: {len(decompressed_audio)}")
                                                                        with open('tts_response_audio_decompressed.raw', 'wb') as f:
                                                                            f.write(decompressed_audio)
                                                                        print("解压缩的音频数据已保存")
                                                                        # Wrap the PCM in a WAV container for playback.
                                                                        import wave
                                                                        import struct
                                                                        # Doubao returns 24000 Hz, 16-bit, mono audio.
                                                                        sample_rate = 24000
                                                                        channels = 1
                                                                        sampwidth = 2  # 16-bit = 2 bytes
                                                                        with wave.open('tts_response.wav', 'wb') as wav_file:
                                                                            wav_file.setnchannels(channels)
                                                                            wav_file.setsampwidth(sampwidth)
                                                                            wav_file.setframerate(sample_rate)
                                                                            wav_file.writeframes(decompressed_audio)
                                                                        print("已创建WAV文件: tts_response.wav")
                                                                        print(f"音频参数: {sample_rate}Hz, {channels}通道, {sampwidth*8}-bit")
                                                                    except Exception as audio_e:
                                                                        print(f"音频数据处理失败: {audio_e}")
                                                                        # Decompression failed — keep the raw bytes.
                                                                        with open('tts_response_audio_original.raw', 'wb') as f:
                                                                            f.write(audio_data)
                                                        elif audio_message_type == 11:  # SERVER_ACK
                                                            print("收到SERVER_ACK音频响应")
                                                            # SERVER_ACK-shaped audio payload.
                                                            audio_payload_start = audio_header_size * 4
                                                            audio_payload = audio_response[audio_payload_start:]
                                                            print(f"音频payload长度: {len(audio_payload)}")
                                                            with open('tts_response_ack.raw', 'wb') as f:
                                                                f.write(audio_payload)
                                                except asyncio.TimeoutError:
                                                    print("等待音频响应超时")
                                        except Exception as json_e:
                                            print(f"解析JSON失败: {json_e}")
                                            # Not JSON — likely audio; save as-is.
                                            with open('response_audio.raw', 'wb') as f:
                                                f.write(payload_data)
                elif message_type == 11:  # SERVER_ACK
                    print("收到SERVER_ACK响应")
                elif message_type == 15:  # SERVER_ERROR_RESPONSE
                    print("收到错误响应")
                    if len(response) > 8:
                        error_code = int.from_bytes(response[4:8], 'big')
                        print(f"错误代码: {error_code}")
        except Exception as e:
            print(f"解析响应失败: {e}")
            import traceback
            traceback.print_exc()
    except asyncio.TimeoutError:
        print("等待响应超时")
    except websockets.exceptions.ConnectionClosed as e:
        print(f"连接关闭: {e}")
    except Exception as e:
        print(f"发送音频请求失败: {e}")
        raise
async def close(self) -> None:
    """Close the WebSocket connection, if any, swallowing network errors.

    FIX: the original used a bare ``except:``, which also swallows
    KeyboardInterrupt/SystemExit; narrowed to ``Exception``.
    """
    if self.ws:
        try:
            await self.ws.close()
        except Exception:
            pass
    print("连接已关闭")
async def main():
    """Smoke test: connect, run the audio-request test, always close."""
    client = DoubaoClient(DoubaoConfig())
    try:
        await client.connect()
        await client.test_audio_request()
    except Exception as e:
        print(f"测试失败: {e}")
        import traceback
        traceback.print_exc()
    finally:
        await client.close()


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -1,303 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
豆包音频处理模块 - 协议测试版本
专门测试协议格式问题
"""
import asyncio
import gzip
import json
import uuid
import wave
import struct
import time
import os
from typing import Dict, Any, Optional
import websockets
class DoubaoConfig:
    """Endpoint and credentials for the Doubao realtime dialog service."""

    def __init__(self):
        # Console-issued credentials plus the dialog endpoint.
        self.base_url = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
        self.app_id = "8718217928"
        self.access_key = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc"
        self.app_key = "PlgvMymc7f3tQnJ6"
        self.resource_id = "volc.speech.dialog"

    def get_headers(self) -> Dict[str, str]:
        """Return auth headers with a unique Connect-Id for this attempt."""
        headers = {
            "X-Api-App-ID": self.app_id,
            "X-Api-Access-Key": self.access_key,
            "X-Api-Resource-Id": self.resource_id,
            "X-Api-App-Key": self.app_key,
        }
        headers["X-Api-Connect-Id"] = str(uuid.uuid4())
        return headers
class DoubaoProtocol:
    """Binary-protocol constants and header builder for the Doubao dialog API."""

    # 4-bit protocol fields, packed two per header byte.
    PROTOCOL_VERSION = 0b0001
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ACK = 0b1011
    SERVER_ERROR_RESPONSE = 0b1111
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    MSG_WITH_EVENT = 0b0100
    JSON = 0b0001
    NO_SERIALIZATION = 0b0000
    GZIP = 0b0001

    @classmethod
    def generate_header(cls, message_type=CLIENT_FULL_REQUEST,
                        message_type_specific_flags=MSG_WITH_EVENT,
                        serial_method=JSON, compression_type=GZIP) -> bytes:
        """Pack the fixed 4-byte header: version|size, type|flags,
        serialization|compression, reserved (header size fixed at 1 word)."""
        packed = (
            (cls.PROTOCOL_VERSION << 4) | 1,
            (message_type << 4) | message_type_specific_flags,
            (serial_method << 4) | compression_type,
            0x00,
        )
        return bytes(packed)
class DoubaoClient:
    """Doubao client."""

    def __init__(self, config: DoubaoConfig):
        # Endpoint + credential configuration.
        self.config = config
        # Fresh session id per client instance.
        self.session_id = str(uuid.uuid4())
        # WebSocket handle; set by connect().
        self.ws = None
        # Server-side request log id (for debugging).
        self.log_id = ""
async def connect(self) -> None:
    """Open the WebSocket and perform the StartConnection/StartSession handshake.

    Raises: re-raises any connection/handshake failure after logging it.
    """
    print(f"连接豆包服务器: {self.config.base_url}")
    try:
        self.ws = await websockets.connect(
            self.config.base_url,
            additional_headers=self.config.get_headers(),
            ping_interval=None
        )
        # Grab the server log id; attribute name differs across
        # websockets library versions, hence the two hasattr probes.
        if hasattr(self.ws, 'response_headers'):
            self.log_id = self.ws.response_headers.get("X-Tt-Logid")
        elif hasattr(self.ws, 'headers'):
            self.log_id = self.ws.headers.get("X-Tt-Logid")
        print(f"连接成功, log_id: {self.log_id}")
        # Protocol handshake: StartConnection first, then StartSession.
        await self._send_start_connection()
        await self._send_start_session()
    except Exception as e:
        print(f"连接失败: {e}")
        raise
async def _send_start_connection(self) -> None:
    """Send the StartConnection frame (event 1) and wait for the ack."""
    print("发送StartConnection请求...")
    body = gzip.compress(b"{}")
    frame = bytearray(DoubaoProtocol.generate_header())
    frame.extend(int(1).to_bytes(4, 'big'))   # event 1 = StartConnection
    frame.extend(len(body).to_bytes(4, 'big'))
    frame.extend(body)
    await self.ws.send(frame)
    ack = await self.ws.recv()
    print(f"StartConnection响应长度: {len(ack)}")
async def _send_start_session(self) -> None:
    """Send the StartSession request (event 100) and wait for the ack.

    The gzip-compressed JSON payload configures ASR endpointing, the TTS
    voice/PCM format, and the dialog persona.
    """
    print("发送StartSession请求...")
    session_config = {
        "asr": {
            "extra": {
                "end_smooth_window_ms": 1500,
            },
        },
        "tts": {
            "speaker": "zh_female_vv_jupiter_bigtts",
            "audio_config": {
                "channel": 1,
                "format": "pcm",
                "sample_rate": 24000
            },
        },
        "dialog": {
            "bot_name": "豆包",
            "system_role": "你使用活泼灵动的女声,性格开朗,热爱生活。",
            "speaking_style": "你的说话风格简洁明了,语速适中,语调自然。",
            "location": {"city": "北京"},
            "extra": {
                "strict_audit": False,
                "audit_response": "支持客户自定义安全审核回复话术。",
                "recv_timeout": 30,
                "input_mod": "audio",
            },
        },
    }
    # Frame layout: header | event | len(session_id) | session_id | len | payload.
    request = bytearray(DoubaoProtocol.generate_header())
    request.extend(int(100).to_bytes(4, 'big'))
    request.extend(len(self.session_id).to_bytes(4, 'big'))
    request.extend(self.session_id.encode())
    payload_bytes = json.dumps(session_config).encode()
    payload_bytes = gzip.compress(payload_bytes)
    request.extend(len(payload_bytes).to_bytes(4, 'big'))
    request.extend(payload_bytes)
    await self.ws.send(request)
    response = await self.ws.recv()
    print(f"StartSession响应长度: {len(response)}")
    # Brief pause so the session is fully established server-side.
    await asyncio.sleep(1.0)
async def test_audio_request(self) -> None:
    """Send one hand-built audio-only frame and dump the server's reply.

    Intended to diagnose protocol-format problems: the request is assembled
    byte-by-byte (no helper padding) and error responses are fully decoded.
    FIX: the two bare ``except:`` clauses are narrowed to ``except Exception:``
    so KeyboardInterrupt/SystemExit are no longer swallowed.
    """
    print("测试音频请求格式...")
    # Silent audio chunk, same size the original Doubao demo streams.
    small_audio = b'\x00' * 3200  # chunk size used by the original sample
    # Build the request exactly like the original sample, with no padding.
    header = bytearray()
    header.append((DoubaoProtocol.PROTOCOL_VERSION << 4) | 1)  # version + header_size
    header.append((DoubaoProtocol.CLIENT_AUDIO_ONLY_REQUEST << 4) | DoubaoProtocol.NO_SEQUENCE)
    header.append((DoubaoProtocol.NO_SERIALIZATION << 4) | DoubaoProtocol.GZIP)
    header.append(0x00)  # reserved
    request = bytearray(header)
    # Event 200 = task request.
    request.extend(int(200).to_bytes(4, 'big'))
    # Length-prefixed session id.
    request.extend(len(self.session_id).to_bytes(4, 'big'))
    request.extend(self.session_id.encode())
    # Gzip-compressed audio, length-prefixed.
    compressed_audio = gzip.compress(small_audio)
    request.extend(len(compressed_audio).to_bytes(4, 'big'))
    request.extend(compressed_audio)
    print(f"测试请求详细信息:")
    print(f" - 音频原始大小: {len(small_audio)}")
    print(f" - 音频压缩后大小: {len(compressed_audio)}")
    print(f" - Session ID: {self.session_id} (长度: {len(self.session_id)})")
    print(f" - 总请求大小: {len(request)}")
    print(f" - 头部字节: {request[:4].hex()}")
    print(f" - 消息类型: {int.from_bytes(request[4:8], 'big')}")
    print(f" - Session ID长度: {int.from_bytes(request[8:12], 'big')}")
    print(f" - Payload size: {int.from_bytes(request[12+len(self.session_id):16+len(self.session_id)], 'big')}")
    try:
        await self.ws.send(request)
        print("请求发送成功")
        # Wait for the server's verdict.
        response = await asyncio.wait_for(self.ws.recv(), timeout=3.0)
        print(f"收到响应,长度: {len(response)}")
        # Decode the response frame.
        try:
            protocol_version = response[0] >> 4
            header_size = response[0] & 0x0f
            message_type = response[1] >> 4
            message_type_specific_flags = response[1] & 0x0f
            serialization_method = response[2] >> 4
            message_compression = response[2] & 0x0f
            print(f"响应协议信息:")
            print(f" - version={protocol_version}")
            print(f" - header_size={header_size}")
            print(f" - message_type={message_type} (15=SERVER_ERROR_RESPONSE)")
            print(f" - message_type_specific_flags={message_type_specific_flags}")
            print(f" - serialization_method={serialization_method}")
            print(f" - message_compression={message_compression}")
            # Payload follows header_size 4-byte words.
            payload = response[header_size * 4:]
            if message_type == 15:  # SERVER_ERROR_RESPONSE
                if len(payload) >= 8:
                    code = int.from_bytes(payload[:4], "big", signed=False)
                    payload_size = int.from_bytes(payload[4:8], "big", signed=False)
                    print(f" - 错误代码: {code}")
                    print(f" - payload大小: {payload_size}")
                    if len(payload) >= 8 + payload_size:
                        payload_msg = payload[8:8 + payload_size]
                        print(f" - payload长度: {len(payload_msg)}")
                        if message_compression == 1:  # GZIP
                            try:
                                payload_msg = gzip.decompress(payload_msg)
                                print(f" - 解压缩后长度: {len(payload_msg)}")
                            except Exception:
                                # Not actually compressed — fall through with raw bytes.
                                pass
                        try:
                            error_msg = json.loads(payload_msg.decode('utf-8'))
                            print(f" - 错误信息: {error_msg}")
                        except Exception:
                            print(f" - 原始payload: {payload_msg}")
        except Exception as e:
            print(f"解析响应失败: {e}")
            import traceback
            traceback.print_exc()
    except Exception as e:
        print(f"发送测试请求失败: {e}")
        raise
async def close(self) -> None:
    """Close the WebSocket connection, if any, swallowing network errors.

    FIX: the original used a bare ``except:``, which also swallows
    KeyboardInterrupt/SystemExit; narrowed to ``Exception``.
    """
    if self.ws:
        try:
            await self.ws.close()
        except Exception:
            pass
    print("连接已关闭")
async def main():
    """Smoke test: connect, run the protocol-format test, always close."""
    client = DoubaoClient(DoubaoConfig())
    try:
        await client.connect()
        await client.test_audio_request()
    except Exception as e:
        print(f"测试失败: {e}")
        import traceback
        traceback.print_exc()
    finally:
        await client.close()


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -1,127 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
语音识别使用示例
演示如何使用 speech_recognizer 模块
"""
import os
import asyncio
from speech_recognizer import SpeechRecognizer
async def example_recognize_file():
    """Example 1: recognize a single audio file and print every result."""
    print("=== 示例1识别单个音频文件 ===")
    # Build the recognizer (replace the placeholders with real credentials).
    recognizer = SpeechRecognizer(
        app_key="your_app_key",      # replace with the real app_key
        access_key="your_access_key"  # replace with the real access_key
    )
    # Expect a recording produced by the companion recording script.
    audio_file = "recording_20240101_120000.wav"
    if not os.path.exists(audio_file):
        print(f"音频文件不存在: {audio_file}")
        print("请先运行 enhanced_wake_and_record.py 录制一个音频文件")
        return
    try:
        results = await recognizer.recognize_file(audio_file)
        print(f"识别结果(共{len(results)}个):")
        for idx, item in enumerate(results, start=1):
            print(f"结果 {idx}:")
            print(f" 文本: {item.text}")
            print(f" 置信度: {item.confidence}")
            print(f" 最终结果: {item.is_final}")
            print("-" * 40)
    except Exception as e:
        print(f"识别失败: {e}")
async def example_recognize_latest():
    """Example 2: recognize the most recent recording, if one exists."""
    print("\n=== 示例2识别最新的录音文件 ===")
    # Build the recognizer (replace the placeholders with real credentials).
    recognizer = SpeechRecognizer(
        app_key="your_app_key",      # replace with the real app_key
        access_key="your_access_key"  # replace with the real access_key
    )
    try:
        result = await recognizer.recognize_latest_recording()
        if not result:
            print("未找到录音文件或识别失败")
            return
        print("识别结果:")
        print(f" 文本: {result.text}")
        print(f" 置信度: {result.confidence}")
        print(f" 最终结果: {result.is_final}")
    except Exception as e:
        print(f"识别失败: {e}")
async def example_batch_recognition():
    """Example 3: batch-recognize up to five recorded WAV files.

    FIX 1: the per-file print had lost its ``{filename}`` placeholder and
    printed a literal instead of the file being processed.
    FIX 2: the file list is sorted so "the first 5" is deterministic
    (``os.listdir`` order is arbitrary).
    """
    print("\n=== 示例3批量识别录音文件 ===")
    # Build the recognizer (replace the placeholders with real credentials).
    recognizer = SpeechRecognizer(
        app_key="your_app_key",      # replace with the real app_key
        access_key="your_access_key"  # replace with the real access_key
    )
    # Collect all recordings in the current directory.
    recording_files = sorted(
        f for f in os.listdir(".") if f.startswith('recording_') and f.endswith('.wav')
    )
    if not recording_files:
        print("未找到录音文件")
        return
    print(f"找到 {len(recording_files)} 个录音文件")
    for filename in recording_files[:5]:  # only process the first 5 files
        print(f"\n处理文件: {filename}")
        try:
            results = await recognizer.recognize_file(filename)
            if results:
                final_result = results[-1]  # the last result is the final one
                print(f"识别结果: {final_result.text}")
            else:
                print("识别失败")
        except Exception as e:
            print(f"处理失败: {e}")
        # Small delay so requests are not sent too frequently.
        await asyncio.sleep(1)
async def main():
    """Entry point: check credentials are configured, then run the examples.

    FIX: the original guard ended with ``"your_app_key" in "your_app_key"``,
    a tautology that is always True, so only the env-var check ever
    mattered; the condition is simplified accordingly (same behavior).
    """
    print("🚀 语音识别使用示例")
    print("=" * 50)
    # Require SAUC_APP_KEY (and ACCESS_KEY) via env, or edited-in credentials.
    if not os.getenv("SAUC_APP_KEY"):
        print("⚠️ 请先设置 SAUC_APP_KEY 和 SAUC_ACCESS_KEY 环境变量")
        print("或者在代码中填入实际的 app_key 和 access_key")
        print("示例:")
        print("export SAUC_APP_KEY='your_app_key'")
        print("export SAUC_ACCESS_KEY='your_access_key'")
        return
    # Run the three examples in sequence.
    await example_recognize_file()
    await example_recognize_latest()
    await example_batch_recognition()


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -1,15 +0,0 @@
# README
**asr tob 相关client demo**
# Notice
python version: python 3.x
替换代码中的key为真实数据:
"app_key": "xxxxxxx",
"access_key": "xxxxxxxxxxxxxxxx"
使用示例:
python3 sauc_websocket_demo.py --file /path/to/your_audio.wav

View File

@ -1,523 +0,0 @@
import asyncio
import aiohttp
import json
import struct
import gzip
import uuid
import logging
import os
import subprocess
from typing import Optional, List, Dict, Any, Tuple, AsyncGenerator
# Logging: mirror everything at INFO+ to run.log and stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('run.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
# Core constants.
DEFAULT_SAMPLE_RATE = 16000


class ProtocolVersion:
    V1 = 0b0001


class MessageType:
    CLIENT_FULL_REQUEST = 0b0001
    CLIENT_AUDIO_ONLY_REQUEST = 0b0010
    SERVER_FULL_RESPONSE = 0b1001
    SERVER_ERROR_RESPONSE = 0b1111


class MessageTypeSpecificFlags:
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    NEG_SEQUENCE = 0b0010
    NEG_WITH_SEQUENCE = 0b0011


class SerializationType:
    NO_SERIALIZATION = 0b0000
    JSON = 0b0001


class CompressionType:
    GZIP = 0b0001


class Config:
    """Console credentials for the ASR service."""

    def __init__(self):
        # Fill in the app id / access token obtained from the console.
        self.auth = {
            "app_key": "xxxxxxx",
            "access_key": "xxxxxxxxxxxx"
        }

    @property
    def app_key(self) -> str:
        return self.auth["app_key"]

    @property
    def access_key(self) -> str:
        return self.auth["access_key"]


config = Config()


class CommonUtils:
    """Helpers for gzip framing and WAV file handling."""

    @staticmethod
    def gzip_compress(data: bytes) -> bytes:
        return gzip.compress(data)

    @staticmethod
    def gzip_decompress(data: bytes) -> bytes:
        return gzip.decompress(data)

    @staticmethod
    def judge_wav(data: bytes) -> bool:
        """True when *data* looks like a RIFF/WAVE file (at least 44 bytes)."""
        return len(data) >= 44 and data[:4] == b'RIFF' and data[8:12] == b'WAVE'

    @staticmethod
    def convert_wav_with_path(audio_path: str, sample_rate: int = DEFAULT_SAMPLE_RATE) -> bytes:
        """Transcode *audio_path* to 16-bit mono WAV at *sample_rate* via ffmpeg.

        The source file is removed on success (best effort).
        Raises RuntimeError when ffmpeg fails.
        """
        cmd = [
            "ffmpeg", "-v", "quiet", "-y", "-i", audio_path,
            "-acodec", "pcm_s16le", "-ac", "1", "-ar", str(sample_rate),
            "-f", "wav", "-"
        ]
        try:
            result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as e:
            logger.error(f"FFmpeg conversion failed: {e.stderr.decode()}")
            raise RuntimeError(f"Audio conversion failed: {e.stderr.decode()}")
        # Best-effort removal of the original file.
        try:
            os.remove(audio_path)
        except OSError as e:
            logger.warning(f"Failed to remove original file: {e}")
        return result.stdout

    @staticmethod
    def read_wav_info(data: bytes) -> Tuple[int, int, int, int, bytes]:
        """Parse a WAV blob → (channels, sample_width, rate, frames, pcm_bytes)."""
        if len(data) < 44:
            raise ValueError("Invalid WAV file: too short")
        if data[:4] != b'RIFF':
            raise ValueError("Invalid WAV file: not RIFF format")
        if data[8:12] != b'WAVE':
            raise ValueError("Invalid WAV file: not WAVE format")
        # fmt sub-chunk fields at their canonical offsets.
        audio_format = struct.unpack('<H', data[20:22])[0]
        num_channels = struct.unpack('<H', data[22:24])[0]
        sample_rate = struct.unpack('<I', data[24:28])[0]
        bits_per_sample = struct.unpack('<H', data[34:36])[0]
        # Walk the remaining sub-chunks until 'data' is found.
        pos = 36
        while pos < len(data) - 8:
            chunk_tag = data[pos:pos + 4]
            chunk_size = struct.unpack('<I', data[pos + 4:pos + 8])[0]
            if chunk_tag == b'data':
                pcm = data[pos + 8:pos + 8 + chunk_size]
                sample_width = bits_per_sample // 8
                frames = chunk_size // (num_channels * sample_width)
                return (num_channels, sample_width, sample_rate, frames, pcm)
            pos += 8 + chunk_size
        raise ValueError("Invalid WAV file: no data subchunk found")
class AsrRequestHeader:
    """Fluent builder for the fixed 4-byte ASR request header."""

    def __init__(self):
        # Defaults describe a gzip-compressed, JSON full client request.
        self.message_type = MessageType.CLIENT_FULL_REQUEST
        self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE
        self.serialization_type = SerializationType.JSON
        self.compression_type = CompressionType.GZIP
        self.reserved_data = bytes([0x00])

    def with_message_type(self, message_type: int) -> 'AsrRequestHeader':
        """Set the message type; returns self for chaining."""
        self.message_type = message_type
        return self

    def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader':
        """Set the type-specific flags; returns self for chaining."""
        self.message_type_specific_flags = flags
        return self

    def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader':
        """Set the payload serialization; returns self for chaining."""
        self.serialization_type = serialization_type
        return self

    def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader':
        """Set the payload compression; returns self for chaining."""
        self.compression_type = compression_type
        return self

    def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader':
        """Replace the reserved trailer bytes; returns self for chaining."""
        self.reserved_data = reserved_data
        return self

    def to_bytes(self) -> bytes:
        """Pack the header: version|size, type|flags, serial|compression, reserved."""
        packed = bytearray((
            (ProtocolVersion.V1 << 4) | 1,
            (self.message_type << 4) | self.message_type_specific_flags,
            (self.serialization_type << 4) | self.compression_type,
        ))
        packed.extend(self.reserved_data)
        return bytes(packed)

    @staticmethod
    def default_header() -> 'AsrRequestHeader':
        """A header with all default field values."""
        return AsrRequestHeader()
class RequestBuilder:
    """Builds the auth headers and binary client frames for the ASR stream."""

    @staticmethod
    def new_auth_headers() -> Dict[str, str]:
        # Fresh request id per connection attempt.
        reqid = str(uuid.uuid4())
        return {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Request-Id": reqid,
            "X-Api-Access-Key": config.access_key,
            "X-Api-App-Key": config.app_key
        }

    @staticmethod
    def new_full_client_request(seq: int) -> bytes:  # seq supplied by the caller
        """First frame: request/audio metadata as gzip-compressed JSON."""
        header = AsrRequestHeader.default_header() \
            .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        payload = {
            "user": {
                "uid": "demo_uid"
            },
            "audio": {
                "format": "wav",
                "codec": "raw",
                "rate": 16000,
                "bits": 16,
                "channel": 1
            },
            "request": {
                "model_name": "bigmodel",
                "enable_itn": True,
                "enable_punc": True,
                "enable_ddc": True,
                "show_utterances": True,
                "enable_nonstream": False
            }
        }
        payload_bytes = json.dumps(payload).encode('utf-8')
        compressed_payload = CommonUtils.gzip_compress(payload_bytes)
        payload_size = len(compressed_payload)
        # Frame layout: header | signed seq | payload size | payload.
        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))  # use the caller-supplied seq
        request.extend(struct.pack('>I', payload_size))
        request.extend(compressed_payload)
        return bytes(request)

    @staticmethod
    def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
        """Audio frame: gzip-compressed segment; the final frame is marked
        with NEG_WITH_SEQUENCE flags and a negated sequence number."""
        header = AsrRequestHeader.default_header()
        if is_last:  # the last packet is special-cased
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE)
            seq = -seq  # negative seq signals end-of-stream
        else:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST)
        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))
        compressed_segment = CommonUtils.gzip_compress(segment)
        request.extend(struct.pack('>I', len(compressed_segment)))
        request.extend(compressed_segment)
        return bytes(request)
class AsrResponse:
    """Mutable container for one parsed server message."""

    def __init__(self):
        self.code = 0              # non-zero signals a server-side error
        self.event = 0             # optional event code carried by the frame
        self.is_last_package = False
        self.payload_sequence = 0  # sequence number echoed by the server
        self.payload_size = 0      # declared size of the payload block
        self.payload_msg = None    # decoded JSON payload, when present

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of the response fields."""
        return dict(zip(
            ("code", "event", "is_last_package",
             "payload_sequence", "payload_size", "payload_msg"),
            (self.code, self.event, self.is_last_package,
             self.payload_sequence, self.payload_size, self.payload_msg),
        ))
class ResponseParser:
    """Decodes binary server frames into AsrResponse objects."""

    @staticmethod
    def parse_response(msg: bytes) -> AsrResponse:
        """Parse one binary WebSocket frame from the server.

        Layout: a header whose size is given in 4-byte words, then optional
        sequence/event fields selected by the flags nibble, an optional
        size/code block selected by the message type, and finally the
        (possibly gzipped, possibly JSON) payload.
        """
        response = AsrResponse()
        header_size = msg[0] & 0x0f
        message_type = msg[1] >> 4
        message_type_specific_flags = msg[1] & 0x0f
        serialization_method = msg[2] >> 4
        message_compression = msg[2] & 0x0f
        payload = msg[header_size*4:]  # header size is counted in 4-byte units
        # Parse message_type_specific_flags.
        if message_type_specific_flags & 0x01:
            # Frame carries a big-endian signed sequence number.
            response.payload_sequence = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]
        if message_type_specific_flags & 0x02:
            # This is the final frame of the session.
            response.is_last_package = True
        if message_type_specific_flags & 0x04:
            # Frame carries a big-endian signed event code.
            response.event = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]
        # Parse message_type: full responses carry a size field, error
        # responses carry an error code followed by a size field.
        if message_type == MessageType.SERVER_FULL_RESPONSE:
            response.payload_size = struct.unpack('>I', payload[:4])[0]
            payload = payload[4:]
        elif message_type == MessageType.SERVER_ERROR_RESPONSE:
            response.code = struct.unpack('>i', payload[:4])[0]
            response.payload_size = struct.unpack('>I', payload[4:8])[0]
            payload = payload[8:]
        if not payload:
            return response
        # Decompress when the frame advertises gzip; on failure return the
        # partially-filled response rather than raising.
        if message_compression == CompressionType.GZIP:
            try:
                payload = CommonUtils.gzip_decompress(payload)
            except Exception as e:
                logger.error(f"Failed to decompress payload: {e}")
                return response
        # Deserialize JSON payloads; other serializations leave payload_msg None.
        try:
            if serialization_method == SerializationType.JSON:
                response.payload_msg = json.loads(payload.decode('utf-8'))
        except Exception as e:
            logger.error(f"Failed to parse payload: {e}")
        return response
class AsrWsClient:
    """Async WebSocket client that streams a WAV file to the ASR service.

    Intended usage::

        async with AsrWsClient(url) as client:
            async for resp in client.execute(path):
                ...

    The client owns one aiohttp session and one WebSocket connection;
    ``segment_duration`` is the amount of audio (ms) carried per packet.
    """

    def __init__(self, url: str, segment_duration: int = 200):
        self.seq = 1  # packet sequence number, reset at the start of execute()
        self.url = url
        self.segment_duration = segment_duration  # ms of audio per packet
        self.conn = None     # WebSocket connection, created lazily
        self.session = None  # aiohttp session reference (created in __aenter__)

    async def __aenter__(self):
        """Create the aiohttp session when entering the async context."""
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Close the WebSocket connection and the session on exit."""
        if self.conn and not self.conn.closed:
            await self.conn.close()
        if self.session and not self.session.closed:
            await self.session.close()

    async def read_audio_data(self, file_path: str) -> bytes:
        """Read the audio file, converting it to WAV first when needed."""
        try:
            with open(file_path, 'rb') as f:
                content = f.read()
            if not CommonUtils.judge_wav(content):
                logger.info("Converting audio to WAV format...")
                content = CommonUtils.convert_wav_with_path(file_path, DEFAULT_SAMPLE_RATE)
            return content
        except Exception as e:
            logger.error(f"Failed to read audio data: {e}")
            raise

    def get_segment_size(self, content: bytes) -> int:
        """Compute bytes-per-packet from the WAV's channels/width/rate.

        segment_size = bytes-per-second * segment_duration / 1000.
        """
        try:
            channel_num, samp_width, frame_rate, _, _ = CommonUtils.read_wav_info(content)[:5]
            size_per_sec = channel_num * samp_width * frame_rate
            segment_size = size_per_sec * self.segment_duration // 1000
            return segment_size
        except Exception as e:
            logger.error(f"Failed to calculate segment size: {e}")
            raise

    async def create_connection(self) -> None:
        """Open the WebSocket connection using fresh authentication headers."""
        headers = RequestBuilder.new_auth_headers()
        try:
            self.conn = await self.session.ws_connect(  # use the session owned by this client
                self.url,
                headers=headers
            )
            logger.info(f"Connected to {self.url}")
        except Exception as e:
            logger.error(f"Failed to connect to WebSocket: {e}")
            raise

    async def send_full_client_request(self) -> None:
        """Send the initial configuration packet and log the server's reply."""
        request = RequestBuilder.new_full_client_request(self.seq)
        self.seq += 1  # increment after the packet is built
        try:
            await self.conn.send_bytes(request)
            logger.info(f"Sent full client request with seq: {self.seq-1}")
            msg = await self.conn.receive()
            if msg.type == aiohttp.WSMsgType.BINARY:
                response = ResponseParser.parse_response(msg.data)
                logger.info(f"Received response: {response.to_dict()}")
            else:
                logger.error(f"Unexpected message type: {msg.type}")
        except Exception as e:
            logger.error(f"Failed to send full client request: {e}")
            raise

    async def send_messages(self, segment_size: int, content: bytes) -> AsyncGenerator[None, None]:
        """Send the audio as fixed-size packets, paced like a live stream.

        Yields after each packet so the caller can interleave receiving.
        The final packet is marked is_last and does not advance ``seq``.
        """
        audio_segments = self.split_audio(content, segment_size)
        total_segments = len(audio_segments)
        for i, segment in enumerate(audio_segments):
            is_last = (i == total_segments - 1)
            request = RequestBuilder.new_audio_only_request(
                self.seq,
                segment,
                is_last=is_last
            )
            await self.conn.send_bytes(request)
            logger.info(f"Sent audio segment with seq: {self.seq} (last: {is_last})")
            if not is_last:
                self.seq += 1
            await asyncio.sleep(self.segment_duration / 1000)  # pace packets to simulate a real-time stream
            # Yield control so the receive loop can run.
            yield

    async def recv_messages(self) -> AsyncGenerator[AsrResponse, None]:
        """Yield parsed responses until the last package, an error code, or close."""
        try:
            async for msg in self.conn:
                if msg.type == aiohttp.WSMsgType.BINARY:
                    response = ResponseParser.parse_response(msg.data)
                    yield response
                    if response.is_last_package or response.code != 0:
                        break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(f"WebSocket error: {msg.data}")
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    logger.info("WebSocket connection closed")
                    break
        except Exception as e:
            logger.error(f"Error receiving messages: {e}")
            raise

    async def start_audio_stream(self, segment_size: int, content: bytes) -> AsyncGenerator[AsrResponse, None]:
        """Run the sender as a background task while yielding received responses."""
        async def sender():
            # Drain the sending generator; each step sends one packet.
            async for _ in self.send_messages(segment_size, content):
                pass
        # Start sending concurrently with receiving.
        sender_task = asyncio.create_task(sender())
        try:
            async for response in self.recv_messages():
                yield response
        finally:
            # Stop the sender if receiving ended first (error or last package).
            sender_task.cancel()
            try:
                await sender_task
            except asyncio.CancelledError:
                pass

    @staticmethod
    def split_audio(data: bytes, segment_size: int) -> List[bytes]:
        """Split data into segment_size chunks; the last chunk may be shorter."""
        if segment_size <= 0:
            return []
        segments = []
        for i in range(0, len(data), segment_size):
            end = i + segment_size
            if end > len(data):
                end = len(data)
            segments.append(data[i:end])
        return segments

    async def execute(self, file_path: str) -> AsyncGenerator[AsrResponse, None]:
        """Full pipeline: read, connect, configure, stream, yield responses.

        Raises ValueError for an empty file path or URL; always closes the
        WebSocket connection when the generator finishes.
        """
        if not file_path:
            raise ValueError("File path is empty")
        if not self.url:
            raise ValueError("URL is empty")
        self.seq = 1
        try:
            # 1. Read (and if needed convert) the audio file.
            content = await self.read_audio_data(file_path)
            # 2. Compute the per-packet segment size.
            segment_size = self.get_segment_size(content)
            # 3. Open the WebSocket connection.
            await self.create_connection()
            # 4. Send the full client (configuration) request.
            await self.send_full_client_request()
            # 5. Stream the audio and yield responses as they arrive.
            async for response in self.start_audio_stream(segment_size, content):
                yield response
        except Exception as e:
            logger.error(f"Error in ASR execution: {e}")
            raise
        finally:
            if self.conn:
                await self.conn.close()
def _build_arg_parser():
    """Assemble the CLI parser (extracted for readability)."""
    import argparse
    parser = argparse.ArgumentParser(description="ASR WebSocket Client")
    parser.add_argument("--file", type=str, required=True, help="Audio file path")
    # Alternative endpoints:
    #   wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
    #   wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async
    #   wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream
    parser.add_argument("--url", type=str, default="wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream",
                        help="WebSocket URL")
    parser.add_argument("--seg-duration", type=int, default=200,
                        help="Audio duration(ms) per packet, default:200")
    return parser


async def main():
    """CLI entry point: stream one audio file through the ASR service."""
    args = _build_arg_parser().parse_args()
    # The async context manager owns the aiohttp session lifecycle.
    async with AsrWsClient(args.url, args.seg_duration) as client:
        try:
            async for response in client.execute(args.file):
                logger.info(f"Received response: {json.dumps(response.to_dict(), indent=2, ensure_ascii=False)}")
        except Exception as e:
            logger.error(f"ASR processing failed: {e}")


if __name__ == "__main__":
    asyncio.run(main())

# Usage:
# python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav

View File

@ -1,532 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
语音识别模块
基于 SAUC API 为录音文件提供语音识别功能
"""
import os
import json
import time
import logging
import asyncio
import aiohttp
import struct
import gzip
import uuid
from typing import Optional, List, Dict, Any, AsyncGenerator
from dataclasses import dataclass
# Logging configuration: INFO level with timestamped, leveled messages.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Constants
DEFAULT_SAMPLE_RATE = 16000  # expected PCM sample rate (Hz); matches the "rate" field sent to the service
class ProtocolVersion:
    """Wire-protocol version identifiers (upper nibble of header byte 0)."""
    V1 = 0x1
class MessageType:
    """Message-type values (upper nibble of header byte 1)."""
    CLIENT_FULL_REQUEST = 0x1        # initial configuration packet
    CLIENT_AUDIO_ONLY_REQUEST = 0x2  # audio payload packet
    SERVER_FULL_RESPONSE = 0x9       # normal server response
    SERVER_ERROR_RESPONSE = 0xF      # server error (carries a code)
class MessageTypeSpecificFlags:
    """Flags (lower nibble of header byte 1) describing the sequence field."""
    NO_SEQUENCE = 0x0        # no sequence number in the frame
    POS_SEQUENCE = 0x1       # positive sequence number follows the header
    NEG_SEQUENCE = 0x2       # negative-sequence marker
    NEG_WITH_SEQUENCE = 0x3  # final packet carrying a negated sequence number
class SerializationType:
    """Payload serialization method (upper nibble of header byte 2)."""
    NO_SERIALIZATION = 0x0
    JSON = 0x1
class CompressionType:
    """Payload compression method (lower nibble of header byte 2)."""
    GZIP = 0x1
@dataclass
class RecognitionResult:
    """One piece of recognized speech returned by the ASR service."""
    text: str                           # recognized text fragment
    confidence: float                   # confidence score (this client fills a fixed default)
    is_final: bool                      # True when this is the final result of the utterance
    start_time: Optional[float] = None  # segment start in seconds, when provided
    end_time: Optional[float] = None    # segment end in seconds, when provided
class AudioUtils:
    """Helpers for gzip compression and minimal WAV parsing."""

    @staticmethod
    def gzip_compress(data: bytes) -> bytes:
        """Compress *data* with gzip."""
        return gzip.compress(data)

    @staticmethod
    def gzip_decompress(data: bytes) -> bytes:
        """Decompress gzip-compressed *data*."""
        return gzip.decompress(data)

    @staticmethod
    def is_wav_file(data: bytes) -> bool:
        """Cheap RIFF/WAVE header check (requires at least 44 bytes)."""
        return len(data) >= 44 and data[:4] == b'RIFF' and data[8:12] == b'WAVE'

    @staticmethod
    def read_wav_info(data: bytes) -> tuple:
        """Parse a WAV blob.

        Returns (channels, bytes_per_sample, sample_rate, frame_count, pcm_bytes).
        Raises ValueError when the RIFF/WAVE structure is malformed.
        NOTE(review): assumes the fmt chunk sits at the canonical fixed
        offset (byte 20) — files with extra chunks before fmt would be
        misread; confirm inputs always have the canonical layout.
        """
        if len(data) < 44:
            raise ValueError("Invalid WAV file: too short")
        if data[:4] != b'RIFF':
            raise ValueError("Invalid WAV file: not RIFF format")
        if data[8:12] != b'WAVE':
            raise ValueError("Invalid WAV file: not WAVE format")
        # fmt sub-chunk fields at their canonical fixed offsets.
        audio_format = struct.unpack('<H', data[20:22])[0]  # PCM format tag (read but unused downstream)
        channels = struct.unpack('<H', data[22:24])[0]
        rate = struct.unpack('<I', data[24:28])[0]
        bits = struct.unpack('<H', data[34:36])[0]
        # Walk the sub-chunks after fmt until the data chunk appears.
        offset = 36
        while offset < len(data) - 8:
            chunk_id = data[offset:offset + 4]
            chunk_len = struct.unpack('<I', data[offset + 4:offset + 8])[0]
            if chunk_id == b'data':
                pcm = data[offset + 8:offset + 8 + chunk_len]
                frame_count = chunk_len // (channels * (bits // 8))
                return (channels, bits // 8, rate, frame_count, pcm)
            offset += 8 + chunk_len
        raise ValueError("Invalid WAV file: no data subchunk found")
class AsrConfig:
    """Credential holder for the ASR service.

    Explicit arguments win; otherwise values fall back to the
    SAUC_APP_KEY / SAUC_ACCESS_KEY environment variables, then to
    placeholder strings.
    """

    def __init__(self, app_key: str = None, access_key: str = None):
        resolved_app = app_key or os.getenv("SAUC_APP_KEY", "your_app_key")
        resolved_access = access_key or os.getenv("SAUC_ACCESS_KEY", "your_access_key")
        self.auth = {"app_key": resolved_app, "access_key": resolved_access}

    @property
    def app_key(self) -> str:
        """The configured application key."""
        return self.auth["app_key"]

    @property
    def access_key(self) -> str:
        """The configured access key."""
        return self.auth["access_key"]
class AsrRequestHeader:
    """Fluent builder for the 4-byte binary header of every ASR request."""

    def __init__(self):
        # Defaults describe a gzip-compressed JSON full-client request.
        self.message_type = MessageType.CLIENT_FULL_REQUEST
        self.message_type_specific_flags = MessageTypeSpecificFlags.POS_SEQUENCE
        self.serialization_type = SerializationType.JSON
        self.compression_type = CompressionType.GZIP
        self.reserved_data = bytes([0x00])

    def with_message_type(self, message_type: int) -> 'AsrRequestHeader':
        """Set the message-type nibble; returns self for chaining."""
        self.message_type = message_type
        return self

    def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader':
        """Set the flags nibble; returns self for chaining."""
        self.message_type_specific_flags = flags
        return self

    def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader':
        """Set the payload-serialization nibble; returns self for chaining."""
        self.serialization_type = serialization_type
        return self

    def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader':
        """Set the payload-compression nibble; returns self for chaining."""
        self.compression_type = compression_type
        return self

    def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader':
        """Replace the trailing reserved byte(s); returns self for chaining."""
        self.reserved_data = reserved_data
        return self

    def to_bytes(self) -> bytes:
        """Pack version/size, type/flags, serialization/compression plus reserved bytes."""
        packed = bytes((
            (ProtocolVersion.V1 << 4) | 1,
            (self.message_type << 4) | self.message_type_specific_flags,
            (self.serialization_type << 4) | self.compression_type,
        ))
        return packed + bytes(self.reserved_data)

    @staticmethod
    def default_header() -> 'AsrRequestHeader':
        """Factory: return a header populated with the default field values."""
        return AsrRequestHeader()
class RequestBuilder:
    """Builds the binary request packets for the SAUC ASR protocol."""

    @staticmethod
    def new_auth_headers(config: AsrConfig) -> Dict[str, str]:
        """Create the authentication HTTP headers for the WebSocket handshake."""
        reqid = str(uuid.uuid4())
        return {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Request-Id": reqid,
            "X-Api-Access-Key": config.access_key,
            "X-Api-App-Key": config.app_key
        }

    @staticmethod
    def new_full_client_request(seq: int) -> bytes:
        """Create the initial full-client (configuration) request.

        Layout: 4-byte header, big-endian signed seq ('>i'), big-endian
        unsigned payload size ('>I'), gzip-compressed JSON payload.
        """
        header = AsrRequestHeader.default_header() \
            .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        payload = {
            "user": {
                "uid": "local_voice_user"
            },
            "audio": {
                "format": "wav",
                "codec": "raw",
                "rate": 16000,
                "bits": 16,
                "channel": 1
            },
            "request": {
                "model_name": "bigmodel",
                "enable_itn": True,
                "enable_punc": True,
                "enable_ddc": True,
                "show_utterances": True,
                "enable_nonstream": False
            }
        }
        payload_bytes = json.dumps(payload).encode('utf-8')
        compressed_payload = AudioUtils.gzip_compress(payload_bytes)
        payload_size = len(compressed_payload)
        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))
        # BUG FIX: 'U' is not a struct format character — pack('>U', ...)
        # raises struct.error at runtime; the unsigned 32-bit size is '>I'.
        request.extend(struct.pack('>I', payload_size))
        request.extend(compressed_payload)
        return bytes(request)

    @staticmethod
    def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes:
        """Create an audio-only request; the final packet carries a negated seq."""
        header = AsrRequestHeader.default_header()
        if is_last:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE)
            seq = -seq
        else:
            header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE)
        header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST)
        request = bytearray()
        request.extend(header.to_bytes())
        request.extend(struct.pack('>i', seq))
        compressed_segment = AudioUtils.gzip_compress(segment)
        # BUG FIX: '>U' → '>I' (unsigned 32-bit big-endian size field).
        request.extend(struct.pack('>I', len(compressed_segment)))
        request.extend(compressed_segment)
        return bytes(request)
class AsrResponse:
    """One decoded server message from the ASR service."""

    def __init__(self):
        self.code = 0              # non-zero indicates a server error
        self.event = 0             # optional event code carried by the frame
        self.is_last_package = False
        self.payload_sequence = 0  # sequence number echoed by the server
        self.payload_size = 0      # declared payload size
        self.payload_msg = None    # decoded JSON payload, if any

    def to_dict(self) -> Dict[str, Any]:
        """Snapshot the response as a plain dictionary."""
        field_names = ("code", "event", "is_last_package",
                       "payload_sequence", "payload_size", "payload_msg")
        return {name: getattr(self, name) for name in field_names}
class ResponseParser:
    """Decodes binary server frames into AsrResponse objects."""

    @staticmethod
    def parse_response(msg: bytes) -> AsrResponse:
        """Parse one binary WebSocket frame from the server.

        Layout: header (size in 4-byte words), optional sequence/event
        fields selected by the flags nibble, an optional size/code block
        selected by the message type, then the (possibly gzipped,
        possibly JSON) payload.
        """
        response = AsrResponse()
        header_size = msg[0] & 0x0f
        message_type = msg[1] >> 4
        message_type_specific_flags = msg[1] & 0x0f
        serialization_method = msg[2] >> 4
        message_compression = msg[2] & 0x0f
        payload = msg[header_size*4:]  # header size is counted in 4-byte units
        # Flags nibble: sequence present / last package / event present.
        if message_type_specific_flags & 0x01:
            response.payload_sequence = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]
        if message_type_specific_flags & 0x02:
            response.is_last_package = True
        if message_type_specific_flags & 0x04:
            response.event = struct.unpack('>i', payload[:4])[0]
            payload = payload[4:]
        # Message type: full responses carry a size field; error responses
        # carry an error code followed by a size field.
        # BUG FIX: 'U' is not a struct format character — unpack('>U', ...)
        # raises struct.error at runtime; unsigned 32-bit big-endian is '>I'.
        if message_type == MessageType.SERVER_FULL_RESPONSE:
            response.payload_size = struct.unpack('>I', payload[:4])[0]
            payload = payload[4:]
        elif message_type == MessageType.SERVER_ERROR_RESPONSE:
            response.code = struct.unpack('>i', payload[:4])[0]
            response.payload_size = struct.unpack('>I', payload[4:8])[0]
            payload = payload[8:]
        if not payload:
            return response
        # Decompress when the frame advertises gzip; on failure return the
        # partially-filled response rather than raising.
        if message_compression == CompressionType.GZIP:
            try:
                payload = AudioUtils.gzip_decompress(payload)
            except Exception as e:
                logger.error(f"Failed to decompress payload: {e}")
                return response
        # Deserialize JSON payloads; other serializations leave payload_msg None.
        try:
            if serialization_method == SerializationType.JSON:
                response.payload_msg = json.loads(payload.decode('utf-8'))
        except Exception as e:
            logger.error(f"Failed to parse payload: {e}")
        return response
class SpeechRecognizer:
    """Drives the SAUC streaming ASR service over a WebSocket connection."""

    def __init__(self, app_key: str = None, access_key: str = None,
                 url: str = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream"):
        # Credentials fall back to environment variables inside AsrConfig.
        self.config = AsrConfig(app_key, access_key)
        self.url = url
        self.seq = 1  # packet sequence number

    async def recognize_file(self, file_path: str) -> List[RecognitionResult]:
        """Recognize speech in a WAV file and return the list of results.

        Raises FileNotFoundError when the file is missing and ValueError
        when it is not a WAV file.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")
        results = []
        try:
            async with aiohttp.ClientSession() as session:
                # Read the audio file.
                with open(file_path, 'rb') as f:
                    content = f.read()
                if not AudioUtils.is_wav_file(content):
                    raise ValueError("Audio file must be in WAV format")
                # Inspect the WAV header; warn on unexpected sample rates.
                try:
                    _, _, sample_rate, _, audio_data = AudioUtils.read_wav_info(content)
                    if sample_rate != DEFAULT_SAMPLE_RATE:
                        logger.warning(f"Sample rate {sample_rate} != {DEFAULT_SAMPLE_RATE}, may affect recognition accuracy")
                except Exception as e:
                    logger.error(f"Failed to read audio info: {e}")
                    raise
                # Segment size for 200 ms of 16 kHz mono 16-bit PCM.
                segment_size = 1 * 2 * DEFAULT_SAMPLE_RATE * 200 // 1000  # channel * bytes_per_sample * sample_rate * duration_ms / 1000
                # Open the WebSocket connection with authentication headers.
                headers = RequestBuilder.new_auth_headers(self.config)
                async with session.ws_connect(self.url, headers=headers) as ws:
                    # Send the full client (configuration) request.
                    request = RequestBuilder.new_full_client_request(self.seq)
                    self.seq += 1
                    await ws.send_bytes(request)
                    # Receive and log the initial response.
                    msg = await ws.receive()
                    if msg.type == aiohttp.WSMsgType.BINARY:
                        response = ResponseParser.parse_response(msg.data)
                        logger.info(f"Initial response: {response.to_dict()}")
                    # Send the audio data segment by segment.
                    audio_segments = self._split_audio(audio_data, segment_size)
                    total_segments = len(audio_segments)
                    for i, segment in enumerate(audio_segments):
                        is_last = (i == total_segments - 1)
                        request = RequestBuilder.new_audio_only_request(
                            self.seq,
                            segment,
                            is_last=is_last
                        )
                        await ws.send_bytes(request)
                        logger.info(f"Sent audio segment {i+1}/{total_segments}")
                        if not is_last:
                            self.seq += 1
                        # Short delay to simulate a real-time stream.
                        await asyncio.sleep(0.1)
                    # Receive recognition results until the last package,
                    # an error code, or the connection ends.
                    final_text = ""
                    while True:
                        msg = await ws.receive()
                        if msg.type == aiohttp.WSMsgType.BINARY:
                            response = ResponseParser.parse_response(msg.data)
                            if response.payload_msg and 'text' in response.payload_msg:
                                text = response.payload_msg['text']
                                if text:
                                    final_text += text
                                    result = RecognitionResult(
                                        text=text,
                                        confidence=0.9,  # fixed default confidence
                                        is_final=response.is_last_package
                                    )
                                    results.append(result)
                                    logger.info(f"Recognized: {text}")
                            if response.is_last_package or response.code != 0:
                                break
                        elif msg.type == aiohttp.WSMsgType.ERROR:
                            logger.error(f"WebSocket error: {msg.data}")
                            break
                        elif msg.type == aiohttp.WSMsgType.CLOSED:
                            logger.info("WebSocket connection closed")
                            break
                    # If no result was marked final, append one aggregating all text.
                    if final_text and not any(r.is_final for r in results):
                        final_result = RecognitionResult(
                            text=final_text,
                            confidence=0.9,
                            is_final=True
                        )
                        results.append(final_result)
                return results
        except Exception as e:
            logger.error(f"Speech recognition failed: {e}")
            raise

    def _split_audio(self, data: bytes, segment_size: int) -> List[bytes]:
        """Split audio bytes into segment_size chunks; the last may be shorter."""
        if segment_size <= 0:
            return []
        segments = []
        for i in range(0, len(data), segment_size):
            end = i + segment_size
            if end > len(data):
                end = len(data)
            segments.append(data[i:end])
        return segments

    async def recognize_latest_recording(self, directory: str = ".") -> Optional[RecognitionResult]:
        """Recognize the newest recording_*.wav in *directory*; None if none found."""
        # Find candidate recording files (names embed a timestamp).
        recording_files = [f for f in os.listdir(directory) if f.startswith('recording_') and f.endswith('.wav')]
        if not recording_files:
            logger.warning("No recording files found")
            return None
        # Sort by name — the embedded timestamps sort lexicographically.
        recording_files.sort(reverse=True)
        latest_file = recording_files[0]
        latest_path = os.path.join(directory, latest_file)
        logger.info(f"Recognizing latest recording: {latest_file}")
        try:
            results = await self.recognize_file(latest_path)
            if results:
                # Prefer results explicitly marked final.
                final_results = [r for r in results if r.is_final]
                if final_results:
                    return final_results[-1]
                else:
                    # Otherwise fall back to the last result seen.
                    return results[-1]
        except Exception as e:
            logger.error(f"Failed to recognize latest recording: {e}")
        return None
async def main():
    """CLI test entry point for the speech recognizer."""
    import argparse
    parser = argparse.ArgumentParser(description="语音识别测试")
    parser.add_argument("--file", type=str, help="音频文件路径")
    parser.add_argument("--latest", action="store_true", help="识别最新的录音文件")
    parser.add_argument("--app-key", type=str, help="SAUC App Key")
    parser.add_argument("--access-key", type=str, help="SAUC Access Key")
    options = parser.parse_args()

    recognizer = SpeechRecognizer(app_key=options.app_key, access_key=options.access_key)

    try:
        if options.latest:
            # Recognize the newest recording_*.wav in the working directory.
            latest = await recognizer.recognize_latest_recording()
            if latest is None:
                print("未能识别到语音内容")
            else:
                print(f"识别结果: {latest.text}")
                print(f"置信度: {latest.confidence}")
                print(f"最终结果: {latest.is_final}")
        elif options.file:
            # Recognize an explicit file and print every partial/final result.
            for index, item in enumerate(await recognizer.recognize_file(options.file)):
                print(f"结果 {index+1}: {item.text}")
                print(f"置信度: {item.confidence}")
                print(f"最终结果: {item.is_final}")
                print("-" * 40)
        else:
            print("请指定 --file 或 --latest 参数")
    except Exception as e:
        print(f"识别失败: {e}")


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -1,113 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
豆包音频处理模块 - 验证脚本
验证完整的音频处理流程
"""
import asyncio
import subprocess
import os
from doubao_simple import DoubaoClient
async def test_complete_workflow():
    """End-to-end verification of the audio pipeline.

    Uploads a known recording, waits for the service to produce a TTS
    reply, then verifies the output WAV exists and tries to play it.
    Returns True on success, False otherwise.
    """
    print("=== 豆包音频处理模块验证 ===")
    # Check the input file exists before doing any network work.
    input_file = "recording_20250920_135137.wav"
    if not os.path.exists(input_file):
        print(f"❌ 输入文件不存在: {input_file}")
        return False
    print(f"✅ 输入文件存在: {input_file}")
    # Best-effort: show the input file type via the `file` utility.
    try:
        result = subprocess.run(['file', input_file], capture_output=True, text=True)
        print(f"📁 输入文件格式: {result.stdout.strip()}")
    except Exception:
        # BUG FIX: was a bare `except:`, which also swallowed
        # SystemExit/KeyboardInterrupt; keep the best-effort behaviour
        # but only for ordinary exceptions.
        pass
    # Initialize the client.
    client = DoubaoClient()
    try:
        # Connect to the server.
        print("🔌 连接豆包服务器...")
        await client.connect()
        print("✅ 连接成功")
        # Process the audio file.
        output_file = "tts_output.wav"
        print(f"🎵 处理音频文件: {input_file} -> {output_file}")
        success = await client.process_audio_file(input_file, output_file)
        if success:
            print("✅ 音频处理成功!")
            # Check the output file.
            if os.path.exists(output_file):
                result = subprocess.run(['file', output_file], capture_output=True, text=True)
                print(f"📁 输出文件格式: {result.stdout.strip()}")
                # Report the output size.
                file_size = os.path.getsize(output_file)
                print(f"📊 输出文件大小: {file_size:,} 字节")
                # Try to play the result (optional; aplay may be absent).
                print("🔊 测试播放输出文件...")
                try:
                    subprocess.run(['aplay', output_file], timeout=10, check=True)
                    print("✅ 播放成功")
                except subprocess.TimeoutExpired:
                    print("✅ 播放完成(超时是正常的)")
                except subprocess.CalledProcessError as e:
                    print(f"⚠️ 播放出现问题: {e}")
                except FileNotFoundError:
                    print("⚠️ aplay命令不存在跳过播放测试")
                return True
            else:
                print("❌ 输出文件未生成")
                return False
        else:
            print("❌ 音频处理失败")
            return False
    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # Always release the connection; ignore close-time errors.
        try:
            await client.close()
        except Exception:
            # BUG FIX: was a bare `except:` — see note above.
            pass
def main():
    """Run the end-to-end verification and print a human-readable summary."""
    print("开始验证豆包音频处理模块...")
    ok = asyncio.run(test_complete_workflow())
    if not ok:
        print("\n❌ 验证失败,请检查错误信息。")
        return ok
    print("\n🎉 验证完成!豆包音频处理模块工作正常。")
    print("\n📋 功能总结:")
    # One line per verified capability.
    for summary_line in (
        " ✅ WebSocket连接建立",
        " ✅ 音频文件上传",
        " ✅ 语音识别",
        " ✅ TTS音频生成",
        " ✅ 音频格式转换Float32 -> Int16",
        " ✅ WAV文件生成",
        " ✅ 树莓派兼容播放",
    ):
        print(summary_line)
    return ok


if __name__ == "__main__":
    main()