qwen_agent/services/streaming_asr_client.py
2026-03-21 23:50:51 +08:00

251 lines
7.9 KiB
Python

import gzip
import json
import logging
import struct
import uuid
from typing import AsyncGenerator, Optional, Tuple

import websockets

from utils.settings import (
    VOLCENGINE_ACCESS_KEY,
    VOLCENGINE_APP_ID,
)
# WebSocket endpoint of the Volcengine bigmodel streaming ASR service (v3/sauc).
VOLCENGINE_ASR_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async"

logger = logging.getLogger('app')

# --- Binary frame protocol constants (v3/sauc) --------------------------------
# Header byte 0: high nibble = protocol version, low nibble = header size,
# counted in 4-byte words.
PROTOCOL_VERSION = 0x1
HEADER_SIZE = 0x1

# Header byte 1, high nibble: message type.
FULL_CLIENT_REQUEST = 0x1
AUDIO_ONLY_REQUEST = 0x2
FULL_SERVER_RESPONSE = 0x9
SERVER_ERROR_RESPONSE = 0xF

# Header byte 1, low nibble: message-type-specific flags. Bit 0x1 means a
# sequence number follows the header; bit 0x2 marks the last packet.
POS_SEQUENCE = 0x1
NEG_SEQUENCE = 0x2
NEG_WITH_SEQUENCE = 0x3

# Header byte 2: high nibble = serialization method, low nibble = compression.
JSON_SERIAL = 0x1
GZIP_COMPRESS = 0x1
def _build_header(msg_type: int, flags: int = POS_SEQUENCE,
                  serial: int = JSON_SERIAL, compress: int = GZIP_COMPRESS) -> bytearray:
    """Return a fresh 4-byte v3/sauc frame header.

    Each of the first three bytes packs a (high nibble, low nibble) pair:
      byte 0: protocol version | header size
      byte 1: message type     | message-type-specific flags
      byte 2: serialization    | compression
    Byte 3 is reserved and always zero.
    """
    nibble_pairs = (
        (PROTOCOL_VERSION, HEADER_SIZE),
        (msg_type, flags),
        (serial, compress),
    )
    packed = [(high << 4) | low for high, low in nibble_pairs]
    packed.append(0x00)  # reserved byte
    return bytearray(packed)
class StreamingASRClient:
    """Volcengine v3/sauc/bigmodel streaming ASR client.

    Lifecycle, as implemented below: ``connect()`` opens the WebSocket and sends
    the configuration frame, ``send_audio()`` streams audio chunks,
    ``send_finish()`` sends the final negative-sequence frame,
    ``receive_results()`` yields recognized text, and ``close()`` tears the
    connection down.
    """

    def __init__(self, uid: str = "voice_lite"):
        self._uid = uid
        # Open WebSocket connection; None until connect() has been called.
        self._ws = None
        # Sequence number for the next outgoing frame (the config frame uses 1).
        self._seq = 1

    def _build_config(self) -> dict:
        """Return the JSON config sent in the initial full_client_request."""
        return {
            "user": {
                "uid": self._uid,
            },
            "audio": {
                "format": "pcm",
                "codec": "raw",
                "rate": 24000,
                "bits": 16,
                "channel": 1,
            },
            "request": {
                "model_name": "bigmodel",
                "enable_itn": True,
                "enable_punc": True,
                "enable_ddc": True,
                "show_utterances": True,
                "enable_nonstream": False,
            },
        }

    def _build_auth_headers(self) -> dict:
        """Return the handshake headers; a new random connect id is generated per call."""
        return {
            "X-Api-Resource-Id": "volc.seedasr.sauc.duration",
            "X-Api-Connect-Id": str(uuid.uuid4()),
            "X-Api-Access-Key": VOLCENGINE_ACCESS_KEY,
            "X-Api-App-Key": VOLCENGINE_APP_ID,
        }

    @staticmethod
    def _build_frame(msg_type: int, flags: int, seq: int, payload: bytes) -> bytes:
        """Assemble header + signed big-endian sequence + unsigned payload size + payload."""
        frame = _build_header(msg_type, flags, JSON_SERIAL, GZIP_COMPRESS)
        # '>' uses standard sizes with no alignment padding, so this equals
        # pack('>i', seq) + pack('>I', len(payload)).
        frame.extend(struct.pack('>iI', seq, len(payload)))
        frame.extend(payload)
        return bytes(frame)

    async def connect(self) -> None:
        """Connect to ASR WebSocket and send initial full_client_request.

        Opens the WebSocket with the auth headers. Sends the gzip-compressed JSON
        config as sequence 1, then waits for the server's reply.

        Raises:
            ConnectionError: if the reply cannot be parsed or carries a non-zero
                error code. If the handshake fails for any reason, the WebSocket
                is closed and ``_ws`` is reset before the exception propagates.
        """
        headers = self._build_auth_headers()
        # Do not log the headers themselves: they contain the access key.
        logger.info(
            f"[ASR] Connecting to {VOLCENGINE_ASR_URL} "
            f"(connect_id={headers['X-Api-Connect-Id']})"
        )
        self._ws = await websockets.connect(
            VOLCENGINE_ASR_URL,
            additional_headers=headers,
            ping_interval=None,
            proxy=None,
        )
        logger.info(f"[ASR] Connected to {VOLCENGINE_ASR_URL}")
        try:
            # Send full_client_request with config (seq=1)
            self._seq = 1
            config_bytes = gzip.compress(json.dumps(self._build_config()).encode())
            frame = self._build_frame(FULL_CLIENT_REQUEST, POS_SEQUENCE, self._seq, config_bytes)
            self._seq += 1
            await self._ws.send(frame)
            # Wait for server ack
            resp = await self._ws.recv()
            parsed = self._parse_response(resp)
            if parsed is None:
                raise ConnectionError(f"[ASR] Unparseable config response: {resp[:200]!r}")
            if parsed.get("code", 0) != 0:
                raise ConnectionError(f"[ASR] Server rejected config: {parsed}")
        except Exception:
            # Don't leak the socket when the handshake fails.
            try:
                await self.close()
            except Exception:
                logger.warning("[ASR] Error closing connection after failed handshake", exc_info=True)
            raise
        logger.info("[ASR] Config accepted, ready for audio")

    async def send_audio(self, chunk: bytes) -> None:
        """Send an audio chunk to ASR with sequence number.

        The chunk is gzip-compressed and sent as an audio-only frame carrying the
        current positive sequence number, which is then incremented. Does nothing
        when not connected.
        """
        if not self._ws:
            return
        frame = self._build_frame(AUDIO_ONLY_REQUEST, POS_SEQUENCE, self._seq, gzip.compress(chunk))
        self._seq += 1
        await self._ws.send(frame)

    async def send_finish(self) -> None:
        """Send last audio frame with negative sequence to signal end.

        The frame carries an empty gzip payload and ``-self._seq`` as its
        sequence number. Does nothing when not connected.
        """
        if not self._ws:
            return
        payload = gzip.compress(b'')
        await self._ws.send(
            self._build_frame(AUDIO_ONLY_REQUEST, NEG_WITH_SEQUENCE, -self._seq, payload)
        )

    async def receive_results(self) -> AsyncGenerator[Tuple[str, bool], None]:
        """Yield (text, is_last) tuples from ASR responses.

        Text WebSocket frames and frames that fail to parse are skipped. Frames
        without extractable text are not yielded, even when flagged as last.
        Iteration stops in three cases:
          - after the frame flagged as last;
          - on the first server error, which is logged as a warning and not raised;
          - when the connection closes.
        """
        if not self._ws:
            return
        try:
            async for message in self._ws:
                if isinstance(message, str):
                    logger.info(f"[ASR] Received text message: {message[:200]}")
                    continue
                parsed = self._parse_response(message)
                logger.info(f"[ASR] Received binary ({len(message)} bytes), parsed: {parsed}")
                if parsed is None:
                    continue
                code = parsed.get("code", 0)
                if code != 0:
                    logger.warning(f"[ASR] Server error: {parsed}")
                    return
                is_last = parsed.get("is_last", False)
                payload_msg = parsed.get("payload_msg")
                if payload_msg and isinstance(payload_msg, dict):
                    text = self._extract_text(payload_msg)
                    if text:
                        yield (text, is_last)
                if is_last:
                    return
        except websockets.exceptions.ConnectionClosed:
            logger.info("[ASR] Connection closed")

    @staticmethod
    def _extract_text(payload: dict) -> str:
        """Extract recognized text from a decoded response payload.

        Concatenates the non-empty ``result.utterances[*].text`` entries when any
        utterances are present (the config enables ``show_utterances``).
        Otherwise falls back to ``result.text``. Returns "" when nothing usable
        is found.
        """
        result = payload.get("result")
        if not result or not isinstance(result, dict):
            return ""
        utterances = result.get("utterances", [])
        if utterances:
            return "".join(utt.get("text") or "" for utt in utterances)
        text = result.get("text", "")
        return text if isinstance(text, str) else ""

    def _parse_response(self, data: bytes) -> Optional[dict]:
        """Parse a binary ASR response frame into a dict.

        Returns ``None`` for frames that cannot be parsed:
          - text (``str``) frames;
          - frames shorter than 4 bytes;
          - frames whose sequence/code/size fields are truncated.

        Otherwise returns a dict with:
          - ``code``: 0 unless the frame is a server error response;
          - ``is_last``: the last-packet flag;
          - ``sequence``: present when the frame carries one;
          - ``payload_msg``: present when the payload decodes as gzip/JSON.
        Decode failures are logged and leave ``payload_msg`` unset.
        """
        if not isinstance(data, (bytes, bytearray)) or len(data) < 4:
            return None
        header_size = data[0] & 0x0f
        msg_type = data[1] >> 4
        msg_flags = data[1] & 0x0f
        serial_method = data[2] >> 4
        compression = data[2] & 0x0f
        # Header size is expressed in 4-byte words.
        payload = data[header_size * 4:]
        result = {"code": 0, "is_last": False}

        try:
            if msg_flags & 0x01:  # has sequence
                result["sequence"] = struct.unpack('>i', payload[:4])[0]
                payload = payload[4:]
            if msg_flags & 0x02:  # is last package
                result["is_last"] = True

            if msg_type == SERVER_ERROR_RESPONSE:
                result["code"], payload_size = struct.unpack('>iI', payload[:8])
                payload = payload[8:]
            elif msg_type == FULL_SERVER_RESPONSE:
                (payload_size,) = struct.unpack('>I', payload[:4])
                payload = payload[4:]
            else:
                return result
        except struct.error:
            logger.warning(f"[ASR] Truncated response frame ({len(data)} bytes)")
            return None

        # Honor the declared size so trailing bytes are not fed to gzip/JSON.
        payload = payload[:payload_size]
        if not payload:
            return result
        if compression == GZIP_COMPRESS:
            try:
                payload = gzip.decompress(payload)
            except Exception:  # BadGzipFile/OSError, EOFError, zlib.error
                logger.warning("[ASR] Failed to decompress response payload")
                return result
        if serial_method == JSON_SERIAL:
            try:
                result["payload_msg"] = json.loads(payload.decode('utf-8'))
            except ValueError:  # UnicodeDecodeError and JSONDecodeError
                logger.warning("[ASR] Failed to decode JSON response payload")
        return result

    async def close(self) -> None:
        """Close the WebSocket if open; safe to call repeatedly.

        ``_ws`` is cleared before awaiting ``close()`` so the client is reset
        even if closing raises.
        """
        ws, self._ws = self._ws, None
        if ws:
            logger.info("[ASR] Closing connection")
            await ws.close()