#coding=utf-8 """ requires Python 3.6 or later pip install asyncio pip install websockets """ import asyncio import base64 import gzip import hmac import json import logging import os import time import uuid import wave from enum import Enum from hashlib import sha256 from io import BytesIO from typing import List from urllib.parse import urlparse import websockets appid = "8718217928" # 项目的 appid token = "ynJMX-5ix1FsJvswC9KTNlGUdubcchqc" # 项目的 token cluster = "volcengine_input_common" # 请求的集群 audio_path = "recording_20250920_161438.wav" # 本地音频路径 audio_format = "wav" # wav 或者 mp3,根据实际音频格式设置 PROTOCOL_VERSION = 0b0001 DEFAULT_HEADER_SIZE = 0b0001 PROTOCOL_VERSION_BITS = 4 HEADER_BITS = 4 MESSAGE_TYPE_BITS = 4 MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4 MESSAGE_SERIALIZATION_BITS = 4 MESSAGE_COMPRESSION_BITS = 4 RESERVED_BITS = 8 # Message Type: CLIENT_FULL_REQUEST = 0b0001 CLIENT_AUDIO_ONLY_REQUEST = 0b0010 SERVER_FULL_RESPONSE = 0b1001 SERVER_ACK = 0b1011 SERVER_ERROR_RESPONSE = 0b1111 # Message Type Specific Flags NO_SEQUENCE = 0b0000 # no check sequence POS_SEQUENCE = 0b0001 NEG_SEQUENCE = 0b0010 NEG_SEQUENCE_1 = 0b0011 # Message Serialization NO_SERIALIZATION = 0b0000 JSON = 0b0001 THRIFT = 0b0011 CUSTOM_TYPE = 0b1111 # Message Compression NO_COMPRESSION = 0b0000 GZIP = 0b0001 CUSTOM_COMPRESSION = 0b1111 def generate_header( version=PROTOCOL_VERSION, message_type=CLIENT_FULL_REQUEST, message_type_specific_flags=NO_SEQUENCE, serial_method=JSON, compression_type=GZIP, reserved_data=0x00, extension_header=bytes() ): """ protocol_version(4 bits), header_size(4 bits), message_type(4 bits), message_type_specific_flags(4 bits) serialization_method(4 bits) message_compression(4 bits) reserved (8bits) 保留字段 header_extensions 扩展头(大小等于 8 * 4 * (header_size - 1) ) """ header = bytearray() header_size = int(len(extension_header) / 4) + 1 header.append((version << 4) | header_size) header.append((message_type << 4) | message_type_specific_flags) header.append((serial_method << 4) | compression_type) header.append(reserved_data) header.extend(extension_header) return header def generate_full_default_header(): return generate_header() def generate_audio_default_header(): return generate_header( message_type=CLIENT_AUDIO_ONLY_REQUEST ) def generate_last_audio_default_header(): return generate_header( message_type=CLIENT_AUDIO_ONLY_REQUEST, message_type_specific_flags=NEG_SEQUENCE ) def parse_response(res): """ protocol_version(4 bits), header_size(4 bits), message_type(4 bits), message_type_specific_flags(4 bits) serialization_method(4 bits) message_compression(4 bits) reserved (8bits) 保留字段 header_extensions 扩展头(大小等于 8 * 4 * (header_size - 1) ) payload 类似与http 请求体 """ protocol_version = res[0] >> 4 header_size = res[0] & 0x0f message_type = res[1] >> 4 message_type_specific_flags = res[1] & 0x0f serialization_method = res[2] >> 4 message_compression = res[2] & 0x0f reserved = res[3] header_extensions = res[4:header_size * 4] payload = res[header_size * 4:] result = {} payload_msg = None payload_size = 0 if message_type == SERVER_FULL_RESPONSE: payload_size = int.from_bytes(payload[:4], "big", signed=True) payload_msg = payload[4:] elif message_type == SERVER_ACK: seq = int.from_bytes(payload[:4], "big", signed=True) result['seq'] = seq if len(payload) >= 8: payload_size = int.from_bytes(payload[4:8], "big", signed=False) payload_msg = payload[8:] elif message_type == SERVER_ERROR_RESPONSE: code = int.from_bytes(payload[:4], "big", signed=False) result['code'] = code payload_size = int.from_bytes(payload[4:8], "big", signed=False) payload_msg = payload[8:] if payload_msg is None: return result if message_compression == GZIP: payload_msg = gzip.decompress(payload_msg) if serialization_method == JSON: payload_msg = json.loads(str(payload_msg, "utf-8")) elif serialization_method != NO_SERIALIZATION: payload_msg = str(payload_msg, "utf-8") result['payload_msg'] = payload_msg result['payload_size'] = payload_size return result def read_wav_info(data: bytes = None) -> (int, int, int, int, int): with BytesIO(data) as _f: wave_fp = wave.open(_f, 'rb') nchannels, sampwidth, framerate, nframes = wave_fp.getparams()[:4] wave_bytes = wave_fp.readframes(nframes) return nchannels, sampwidth, framerate, nframes, len(wave_bytes) class AudioType(Enum): LOCAL = 1 # 使用本地音频文件 class AsrWsClient: def __init__(self, audio_path, cluster, **kwargs): """ :param config: config """ self.audio_path = audio_path self.cluster = cluster self.success_code = 1000 # success code, default is 1000 self.seg_duration = int(kwargs.get("seg_duration", 15000)) self.nbest = int(kwargs.get("nbest", 1)) self.appid = kwargs.get("appid", "") self.token = kwargs.get("token", "") self.ws_url = kwargs.get("ws_url", "wss://openspeech.bytedance.com/api/v2/asr") self.uid = kwargs.get("uid", "streaming_asr_demo") self.workflow = kwargs.get("workflow", "audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate") self.show_language = kwargs.get("show_language", False) self.show_utterances = kwargs.get("show_utterances", False) self.result_type = kwargs.get("result_type", "full") self.format = kwargs.get("format", "wav") self.rate = kwargs.get("sample_rate", 16000) self.language = kwargs.get("language", "zh-CN") self.bits = kwargs.get("bits", 16) self.channel = kwargs.get("channel", 1) self.codec = kwargs.get("codec", "raw") self.audio_type = kwargs.get("audio_type", AudioType.LOCAL) self.secret = kwargs.get("secret", "access_secret") self.auth_method = kwargs.get("auth_method", "token") self.mp3_seg_size = int(kwargs.get("mp3_seg_size", 10000)) def construct_request(self, reqid): req = { 'app': { 'appid': self.appid, 'cluster': self.cluster, 'token': self.token, }, 'user': { 'uid': self.uid }, 'request': { 'reqid': reqid, 'nbest': self.nbest, 'workflow': self.workflow, 'show_language': self.show_language, 'show_utterances': self.show_utterances, 'result_type': self.result_type, "sequence": 1 }, 'audio': { 'format': self.format, 'rate': self.rate, 'language': self.language, 'bits': self.bits, 'channel': self.channel, 'codec': self.codec } } return req @staticmethod def slice_data(data: bytes, chunk_size: int) -> (list, bool): """ slice data :param data: wav data :param chunk_size: the segment size in one request :return: segment data, last flag """ data_len = len(data) offset = 0 while offset + chunk_size < data_len: yield data[offset: offset + chunk_size], False offset += chunk_size else: yield data[offset: data_len], True def _real_processor(self, request_params: dict) -> dict: pass def token_auth(self): return {'Authorization': 'Bearer; {}'.format(self.token)} def signature_auth(self, data): header_dicts = { 'Custom': 'auth_custom', } url_parse = urlparse(self.ws_url) input_str = 'GET {} HTTP/1.1\n'.format(url_parse.path) auth_headers = 'Custom' for header in auth_headers.split(','): input_str += '{}\n'.format(header_dicts[header]) input_data = bytearray(input_str, 'utf-8') input_data += data mac = base64.urlsafe_b64encode( hmac.new(self.secret.encode('utf-8'), input_data, digestmod=sha256).digest()) header_dicts['Authorization'] = 'HMAC256; access_token="{}"; mac="{}"; h="{}"'.format(self.token, str(mac, 'utf-8'), auth_headers) return header_dicts async def segment_data_processor(self, wav_data: bytes, segment_size: int): reqid = str(uuid.uuid4()) # 构建 full client request,并序列化压缩 request_params = self.construct_request(reqid) payload_bytes = str.encode(json.dumps(request_params)) payload_bytes = gzip.compress(payload_bytes) full_client_request = bytearray(generate_full_default_header()) full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) full_client_request.extend(payload_bytes) # payload additional_headers = None if self.auth_method == "token": additional_headers = self.token_auth() elif self.auth_method == "signature": additional_headers = self.signature_auth(full_client_request) connection_kwargs = {"max_size": 1000000000} if additional_headers: connection_kwargs["additional_headers"] = additional_headers async with websockets.connect(self.ws_url, **connection_kwargs) as ws: # 发送 full client request await ws.send(full_client_request) res = await ws.recv() result = parse_response(res) if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code: return result for seq, (chunk, last) in enumerate(AsrWsClient.slice_data(wav_data, segment_size), 1): # if no compression, comment this line payload_bytes = gzip.compress(chunk) audio_only_request = bytearray(generate_audio_default_header()) if last: audio_only_request = bytearray(generate_last_audio_default_header()) audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) audio_only_request.extend(payload_bytes) # payload # 发送 audio-only client request await ws.send(audio_only_request) res = await ws.recv() result = parse_response(res) if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code: return result return result async def execute(self): with open(self.audio_path, mode="rb") as _f: data = _f.read() audio_data = bytes(data) if self.format == "mp3": segment_size = self.mp3_seg_size return await self.segment_data_processor(audio_data, segment_size) if self.format != "wav": raise Exception("format should in wav or mp3") nchannels, sampwidth, framerate, nframes, wav_len = read_wav_info( audio_data) size_per_sec = nchannels * sampwidth * framerate segment_size = int(size_per_sec * self.seg_duration / 1000) return await self.segment_data_processor(audio_data, segment_size) def execute_one(audio_item, cluster, **kwargs): """ :param audio_item: {"id": xxx, "path": "xxx"} :param cluster:集群名称 :return: """ assert 'id' in audio_item assert 'path' in audio_item audio_id = audio_item['id'] audio_path = audio_item['path'] audio_type = AudioType.LOCAL asr_http_client = AsrWsClient( audio_path=audio_path, cluster=cluster, audio_type=audio_type, **kwargs ) result = asyncio.run(asr_http_client.execute()) return {"id": audio_id, "path": audio_path, "result": result} def test_one(): result = execute_one( { 'id': 1, 'path': audio_path }, cluster=cluster, appid=appid, token=token, format=audio_format, ) print(result) if __name__ == '__main__': test_one()