251 lines
7.9 KiB
Python
251 lines
7.9 KiB
Python
import gzip
import json
import logging
import struct
import uuid
import zlib
from typing import AsyncGenerator, Optional, Tuple

import websockets

from utils.settings import (
    VOLCENGINE_ACCESS_KEY,
    VOLCENGINE_APP_ID,
)
|
|
|
|
# WebSocket endpoint the ASR client connects to.
VOLCENGINE_ASR_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async"
logger = logging.getLogger("app")

# ---------------------------------------------------------------------------
# Binary protocol constants (v3/sauc).
# Each value is a 4-bit field; _build_header() packs two of them per byte.
# ---------------------------------------------------------------------------

# Byte 0: protocol version (high nibble) and header size (low nibble).
# The response parser multiplies the header-size field by 4 to get bytes.
PROTOCOL_VERSION = 0x1
HEADER_SIZE = 0x1

# Byte 1, high nibble: message type.
FULL_CLIENT_REQUEST = 0x1
AUDIO_ONLY_REQUEST = 0x2
FULL_SERVER_RESPONSE = 0x9
SERVER_ERROR_RESPONSE = 0xF

# Byte 1, low nibble: message-type-specific flags.
POS_SEQUENCE = 0x1
NEG_SEQUENCE = 0x2
NEG_WITH_SEQUENCE = 0x3

# Byte 2: serialization method (high nibble) and compression (low nibble).
JSON_SERIAL = 0x1
GZIP_COMPRESS = 0x1
|
|
|
|
|
|
def _build_header(msg_type: int, flags: int = POS_SEQUENCE,
                  serial: int = JSON_SERIAL, compress: int = GZIP_COMPRESS) -> bytearray:
    """Assemble the fixed 4-byte frame header.

    Each byte holds two 4-bit fields (high nibble | low nibble):

    * byte 0: ``PROTOCOL_VERSION`` | ``HEADER_SIZE``
    * byte 1: ``msg_type`` | ``flags``
    * byte 2: ``serial`` | ``compress``
    * byte 3: always ``0x00``

    Args:
        msg_type: Message type nibble.
        flags: Message-type-specific flags nibble.
        serial: Serialization method nibble.
        compress: Compression method nibble.

    Returns:
        A new 4-byte ``bytearray``.

    Raises:
        ValueError: If a combined byte falls outside ``0..255``.
    """
    return bytearray((
        (PROTOCOL_VERSION << 4) | HEADER_SIZE,
        (msg_type << 4) | flags,
        (serial << 4) | compress,
        0x00,
    ))
|
|
|
|
|
|
class StreamingASRClient:
    """Volcengine v3/sauc/bigmodel streaming ASR client.

    Every outgoing frame has the layout::

        4-byte header | int32 sequence (big-endian) | uint32 payload length | payload

    where the payload is gzip-compressed. The methods suggest the call
    order ``connect()`` -> ``send_audio()``* -> ``send_finish()`` with
    ``receive_results()`` consumed alongside, then ``close()``.
    """

    # Header names whose values are credentials and must never reach the logs.
    _SECRET_HEADERS = frozenset({"X-Api-Access-Key", "X-Api-App-Key"})

    def __init__(self, uid: str = "voice_lite"):
        """Create an unconnected client.

        Args:
            uid: User identifier placed in the ``user.uid`` config field.
        """
        self._uid = uid
        self._ws = None  # live WebSocket connection, or None when disconnected
        self._seq = 1  # sequence number of the next outgoing frame

    def _build_config(self) -> dict:
        """Return the session configuration sent in the initial request."""
        return {
            "user": {
                "uid": self._uid,
            },
            "audio": {
                "format": "pcm",
                "codec": "raw",
                "rate": 24000,
                "bits": 16,
                "channel": 1,
            },
            "request": {
                "model_name": "bigmodel",
                "enable_itn": True,
                "enable_punc": True,
                "enable_ddc": True,
                "show_utterances": True,
                "enable_nonstream": False,
            },
        }

    def _build_auth_headers(self) -> dict:
        """Return the HTTP headers for the WebSocket upgrade request.

        A fresh random ``X-Api-Connect-Id`` is generated on every call.
        """
        return {
            "X-Api-Resource-Id": "volc.seedasr.sauc.duration",
            "X-Api-Connect-Id": str(uuid.uuid4()),
            "X-Api-Access-Key": VOLCENGINE_ACCESS_KEY,
            "X-Api-App-Key": VOLCENGINE_APP_ID,
        }

    @classmethod
    def _redact_headers(cls, headers: dict) -> dict:
        """Return a copy of *headers* that is safe to log (secrets masked)."""
        return {
            name: "***" if name in cls._SECRET_HEADERS else value
            for name, value in headers.items()
        }

    @staticmethod
    def _pack_frame(msg_type: int, flags: int, seq: int, payload: bytes) -> bytes:
        """Serialize one client frame: header, sequence, payload length, payload.

        Args:
            msg_type: Message type nibble for the header.
            flags: Flags nibble for the header.
            seq: Signed sequence number (negative marks the final frame).
            payload: Already gzip-compressed payload bytes.
        """
        header = _build_header(msg_type, flags, JSON_SERIAL, GZIP_COMPRESS)
        return bytes(header) + struct.pack(">iI", seq, len(payload)) + payload

    async def connect(self) -> None:
        """Connect to the ASR WebSocket and send the initial full_client_request.

        Raises:
            ConnectionError: If the server answers the config frame with a
                non-zero code.

        Any exception raised during the handshake (including the one above)
        closes the socket and resets the client to the disconnected state
        before propagating.
        """
        headers = self._build_auth_headers()
        # Never log the raw headers: they carry the API credentials.
        logger.info(
            "[ASR] Connecting to %s with headers: %s",
            VOLCENGINE_ASR_URL,
            self._redact_headers(headers),
        )
        ws = await websockets.connect(
            VOLCENGINE_ASR_URL,
            additional_headers=headers,
            ping_interval=None,
            proxy=None,
        )
        self._ws = ws
        logger.info("[ASR] Connected to %s", VOLCENGINE_ASR_URL)

        try:
            # The config frame always carries sequence number 1.
            self._seq = 1
            config_bytes = gzip.compress(json.dumps(self._build_config()).encode())
            frame = self._pack_frame(
                FULL_CLIENT_REQUEST, POS_SEQUENCE, self._seq, config_bytes
            )
            self._seq += 1
            await ws.send(frame)

            # Wait for the server acknowledgement.
            parsed = self._parse_response(await ws.recv())
            if parsed and parsed.get("code", 0) != 0:
                raise ConnectionError(f"[ASR] Server rejected config: {parsed}")
        except BaseException:
            # Do not leak a half-initialised connection; re-raise after cleanup.
            self._ws = None
            await ws.close()
            raise
        logger.info("[ASR] Config accepted, ready for audio")

    async def send_audio(self, chunk: bytes) -> None:
        """Send one audio chunk with the next positive sequence number.

        Does nothing when the client is not connected.
        """
        if self._ws is None:
            return
        frame = self._pack_frame(
            AUDIO_ONLY_REQUEST, POS_SEQUENCE, self._seq, gzip.compress(chunk)
        )
        self._seq += 1
        await self._ws.send(frame)

    async def send_finish(self) -> None:
        """Send an empty final audio frame with a negated sequence number.

        Does nothing when the client is not connected.
        """
        if self._ws is None:
            return
        frame = self._pack_frame(
            AUDIO_ONLY_REQUEST, NEG_WITH_SEQUENCE, -self._seq, gzip.compress(b"")
        )
        await self._ws.send(frame)

    async def receive_results(self) -> AsyncGenerator[Tuple[str, bool], None]:
        """Yield ``(text, is_last)`` tuples from ASR responses.

        The generator ends when a frame flagged as last is seen, when the
        server reports a non-zero code, or when the connection closes.
        Text frames and unparseable binary frames are skipped.
        """
        if self._ws is None:
            return
        try:
            async for message in self._ws:
                if isinstance(message, str):
                    logger.info("[ASR] Received text message: %s", message[:200])
                    continue

                parsed = self._parse_response(message)
                logger.info(
                    "[ASR] Received binary (%d bytes), parsed: %s",
                    len(message),
                    parsed,
                )
                if parsed is None:
                    continue

                if parsed.get("code", 0) != 0:
                    logger.warning("[ASR] Server error: %s", parsed)
                    return

                is_last = parsed.get("is_last", False)
                payload_msg = parsed.get("payload_msg")
                if payload_msg and isinstance(payload_msg, dict):
                    text = self._extract_text(payload_msg)
                    if text:
                        yield (text, is_last)

                if is_last:
                    return
        except websockets.exceptions.ConnectionClosed:
            logger.info("[ASR] Connection closed")

    @staticmethod
    def _extract_text(payload: dict) -> str:
        """Extract recognized text from a decoded response payload.

        The texts of ``result.utterances`` are concatenated when present.
        If that yields nothing, ``result.text`` is used. Malformed entries
        (non-dict utterances, non-string text) are ignored instead of raising.

        Returns:
            The recognized text, or ``""`` when none is available.
        """
        result = payload.get("result")
        if not result or not isinstance(result, dict):
            return ""

        # Prefer utterances (requested via show_utterances=True).
        utterances = result.get("utterances")
        if isinstance(utterances, (list, tuple)):
            joined = "".join(
                utt["text"]
                for utt in utterances
                if isinstance(utt, dict) and isinstance(utt.get("text"), str)
            )
            if joined:
                return joined

        # Fall back to the flat result text.
        text = result.get("text", "")
        return text if isinstance(text, str) else ""

    def _parse_response(self, data: bytes) -> Optional[dict]:
        """Parse a binary ASR response frame into a dict.

        Args:
            data: Raw frame bytes as received from the WebSocket.

        Returns:
            ``None`` when *data* is not bytes, is shorter than a header, or
            is truncated inside its fixed-size fields. Otherwise a dict with
            ``code`` (0 unless an error frame), ``is_last``, optionally
            ``sequence``, and ``payload_msg`` when the payload decoded as
            JSON.
        """
        if not isinstance(data, (bytes, bytearray)) or len(data) < 4:
            return None

        header_size = data[0] & 0x0F  # counted in 4-byte units
        msg_type = data[1] >> 4
        msg_flags = data[1] & 0x0F
        serial_method = data[2] >> 4
        compression = data[2] & 0x0F
        body = data[header_size * 4:]

        result = {"code": 0, "is_last": False}
        offset = 0
        try:
            if msg_flags & 0x01:  # frame carries a sequence number
                (result["sequence"],) = struct.unpack_from(">i", body, offset)
                offset += 4
            if msg_flags & 0x02:  # last package of the stream
                result["is_last"] = True

            # The declared payload size is read only to step past it; the
            # rest of the frame is treated as the payload, as before.
            if msg_type == SERVER_ERROR_RESPONSE:
                result["code"], _declared_size = struct.unpack_from(">iI", body, offset)
                offset += 8
            elif msg_type == FULL_SERVER_RESPONSE:
                (_declared_size,) = struct.unpack_from(">I", body, offset)
                offset += 4
            else:
                return result
        except struct.error:
            logger.warning("[ASR] Dropping truncated frame (%d bytes)", len(data))
            return None

        payload = body[offset:]
        if not payload:
            return result

        if compression == GZIP_COMPRESS:
            try:
                payload = gzip.decompress(payload)
            except (OSError, EOFError, zlib.error) as exc:
                logger.warning("[ASR] Failed to decompress payload: %s", exc)
                return result

        if serial_method == JSON_SERIAL:
            try:
                result["payload_msg"] = json.loads(payload.decode("utf-8"))
            except ValueError as exc:  # covers bad UTF-8 and bad JSON
                logger.warning("[ASR] Failed to decode JSON payload: %s", exc)

        return result

    async def close(self) -> None:
        """Close the connection if one is open; safe to call repeatedly."""
        # Detach first so the client is disconnected even if close() raises.
        ws, self._ws = self._ws, None
        if ws is not None:
            logger.info("[ASR] Closing connection")
            await ws.close()
|