207 lines
7.9 KiB
Python
207 lines
7.9 KiB
Python
import gzip
|
|
import json
|
|
import uuid
|
|
import logging
|
|
from typing import Dict, Any, Optional
|
|
|
|
import websockets
|
|
|
|
from services import realtime_voice_protocol as protocol
|
|
from utils.settings import (
|
|
VOLCENGINE_REALTIME_URL,
|
|
VOLCENGINE_APP_ID,
|
|
VOLCENGINE_ACCESS_KEY,
|
|
VOLCENGINE_RESOURCE_ID,
|
|
VOLCENGINE_APP_KEY,
|
|
VOLCENGINE_DEFAULT_SPEAKER,
|
|
VOLCENGINE_TTS_SAMPLE_RATE,
|
|
)
|
|
|
|
# Module-level logger used by RealtimeDialogClient; it is bound to the
# fixed logger name 'app' rather than to this module's own name.
logger = logging.getLogger('app')
|
|
|
|
|
|
class RealtimeDialogClient:
    """Volcengine Realtime Dialogue API WebSocket client.

    Wraps one WebSocket connection carrying one dialogue session. Every
    outgoing message is a binary frame laid out as::

        header | event id (4 bytes, big-endian)
               | [session id length (4 bytes, big-endian) | session id (UTF-8)]
               | payload length (4 bytes, big-endian) | payload (gzip)

    The header bytes come from ``protocol.generate_header`` and incoming
    frames are decoded with ``protocol.parse_response``.

    Typical call order: ``connect()`` -> send/receive methods ->
    ``finish_session()`` -> ``finish_connection()`` -> ``close()``.
    """

    # Event ids written into outgoing frames.
    _EVENT_START_CONNECTION = 1
    _EVENT_FINISH_CONNECTION = 2
    _EVENT_START_SESSION = 100
    _EVENT_FINISH_SESSION = 102
    _EVENT_AUDIO = 200
    _EVENT_TTS_TEXT = 500
    _EVENT_TEXT_QUERY = 501
    _EVENT_RAG_TEXT = 502

    def __init__(
        self,
        session_id: str,
        speaker: Optional[str] = None,
        system_role: Optional[str] = None,
        speaking_style: Optional[str] = None,
        bot_name: Optional[str] = None,
        recv_timeout: int = 60,
        input_mod: str = "audio",
    ) -> None:
        """Store the session configuration; no network I/O happens here.

        Args:
            session_id: Identifier embedded in every session-scoped frame.
            speaker: TTS speaker; falls back to ``VOLCENGINE_DEFAULT_SPEAKER``
                when omitted or empty.
            system_role: Sent as ``dialog.system_role`` (empty string if unset).
            speaking_style: Sent as ``dialog.speaking_style`` (empty if unset).
            bot_name: Sent as ``dialog.bot_name`` (empty string if unset).
            recv_timeout: Sent as ``dialog.extra.recv_timeout``.
            input_mod: Sent as ``dialog.extra.input_mod``.
        """
        self.session_id = session_id
        self.speaker = speaker or VOLCENGINE_DEFAULT_SPEAKER
        self.system_role = system_role or ""
        self.speaking_style = speaking_style or ""
        self.bot_name = bot_name or ""
        self.recv_timeout = recv_timeout
        self.input_mod = input_mod
        self.logid = ""  # filled from the "X-Tt-Logid" response header in connect()
        self.ws = None  # set by connect(), reset by close()
        self._connect_id = str(uuid.uuid4())

    def _build_headers(self) -> Dict[str, str]:
        """Return the HTTP headers sent with the WebSocket handshake."""
        return {
            "X-Api-App-ID": VOLCENGINE_APP_ID,
            "X-Api-Access-Key": VOLCENGINE_ACCESS_KEY,
            "X-Api-Resource-Id": VOLCENGINE_RESOURCE_ID,
            "X-Api-App-Key": VOLCENGINE_APP_KEY,
            "X-Api-Connect-Id": self._connect_id,
        }

    def _build_session_params(self) -> Dict[str, Any]:
        """Return the JSON-serialisable payload of the StartSession event."""
        return {
            "asr": {
                "extra": {
                    "end_smooth_window_ms": 1500,
                },
            },
            "tts": {
                "speaker": self.speaker,
                "audio_config": {
                    "channel": 1,
                    "format": "pcm",
                    "sample_rate": VOLCENGINE_TTS_SAMPLE_RATE,
                },
            },
            "dialog": {
                "bot_name": self.bot_name,
                "system_role": self.system_role,
                "speaking_style": self.speaking_style,
                "extra": {
                    "strict_audit": False,
                    "recv_timeout": self.recv_timeout,
                    "input_mod": self.input_mod,
                    "enable_volc_websearch": False,
                    "enable_music": False,
                    "model": "1.2.1.1",
                },
            },
        }

    @staticmethod
    def _pack_frame(
        header,
        event_id: int,
        payload: bytes,
        session_id: Optional[str] = None,
    ) -> bytearray:
        """Assemble one binary frame from already-encoded parts.

        Args:
            header: Header bytes (as produced by ``protocol.generate_header``).
            event_id: Event number, written as a 4-byte big-endian integer.
            payload: Final payload bytes (already compressed by the caller).
            session_id: When not ``None``, a length-prefixed UTF-8 session id
                is inserted between the event id and the payload.

        Returns:
            The complete frame as a ``bytearray``.
        """
        frame = bytearray(header)
        frame.extend(int(event_id).to_bytes(4, 'big'))
        if session_id is not None:
            # The length prefix must count encoded bytes, not characters,
            # otherwise a non-ASCII session id produces a corrupt frame.
            session_bytes = session_id.encode("utf-8")
            frame.extend(len(session_bytes).to_bytes(4, 'big'))
            frame.extend(session_bytes)
        frame.extend(len(payload).to_bytes(4, 'big'))
        frame.extend(payload)
        return frame

    def _build_event_request(self, event_id: int, payload: dict, with_session: bool = True) -> bytearray:
        """Build a frame whose payload is gzip-compressed JSON.

        Args:
            event_id: Event number to send.
            payload: Dictionary serialised with ``json.dumps``.
            with_session: Include the session id section (connection-level
                events are sent without it).

        Returns:
            The frame ready to be passed to the WebSocket ``send``.
        """
        body = gzip.compress(json.dumps(payload).encode("utf-8"))
        return self._pack_frame(
            protocol.generate_header(),
            event_id,
            body,
            self.session_id if with_session else None,
        )

    def _require_ws(self):
        """Return the open WebSocket or raise if ``connect()`` has not run.

        Raises:
            ConnectionError: If there is no active WebSocket.
        """
        if self.ws is None:
            raise ConnectionError(
                "Volcengine Realtime WebSocket is not connected; call connect() first"
            )
        return self.ws

    async def connect(self) -> None:
        """Open the WebSocket, then send StartConnection (1) and StartSession (100).

        The response to each of the two events is read and logged. If anything
        fails after the socket has been opened, the socket is closed again
        before the error propagates.

        Raises:
            ConnectionError: If the server rejects the WebSocket handshake
                with an HTTP error status.
        """
        logger.info("Connecting to Volcengine Realtime API: %s", VOLCENGINE_REALTIME_URL)
        headers = self._build_headers()
        try:
            self.ws = await websockets.connect(
                VOLCENGINE_REALTIME_URL,
                additional_headers=headers,
                ping_interval=None,
                proxy=None,
            )
        except websockets.exceptions.InvalidStatus as e:
            body = ""
            if e.response and e.response.body:
                body = e.response.body.decode("utf-8", errors="replace")
            raise ConnectionError(
                f"Volcengine Realtime API rejected connection: HTTP {e.response.status_code} - {body}"
            ) from e

        try:
            self.logid = self.ws.response.headers.get("X-Tt-Logid", "")
            logger.info("Volcengine Realtime connected, logid: %s", self.logid)

            # StartConnection (event 1): connection-level, no session id.
            await self.ws.send(
                self._build_event_request(self._EVENT_START_CONNECTION, {}, with_session=False)
            )
            response = await self.ws.recv()
            logger.info("StartConnection response: %s", protocol.parse_response(response))

            # StartSession (event 100): carries the session configuration.
            await self.ws.send(
                self._build_event_request(self._EVENT_START_SESSION, self._build_session_params())
            )
            response = await self.ws.recv()
            logger.info("StartSession response: %s", protocol.parse_response(response))
        except Exception:
            # Do not leak the socket when the handshake fails part-way.
            try:
                await self.close()
            except Exception:
                logger.warning(
                    "Failed to close WebSocket after connect error", exc_info=True
                )
            raise

    async def send_audio(self, audio: bytes) -> None:
        """Send audio data (event 200).

        The audio bytes are gzip-compressed and sent with an audio-only,
        non-serialised header.

        Raises:
            ConnectionError: If the client is not connected.
        """
        ws = self._require_ws()
        header = protocol.generate_header(
            message_type=protocol.CLIENT_AUDIO_ONLY_REQUEST,
            serial_method=protocol.NO_SERIALIZATION,
        )
        await ws.send(
            self._pack_frame(header, self._EVENT_AUDIO, gzip.compress(audio), self.session_id)
        )

    async def chat_text_query(self, content: str) -> None:
        """Send text query (event 501).

        Raises:
            ConnectionError: If the client is not connected.
        """
        ws = self._require_ws()
        await ws.send(self._build_event_request(self._EVENT_TEXT_QUERY, {"content": content}))

    async def chat_tts_text(self, content: str, start: bool = True, end: bool = True) -> None:
        """Send TTS text for synthesis (event 500).

        Args:
            content: Text to synthesise.
            start: Sent as the payload's ``start`` flag.
            end: Sent as the payload's ``end`` flag.

        Raises:
            ConnectionError: If the client is not connected.
        """
        ws = self._require_ws()
        await ws.send(
            self._build_event_request(
                self._EVENT_TTS_TEXT, {"start": start, "end": end, "content": content}
            )
        )

    async def chat_rag_text(self, external_rag: str) -> None:
        """Inject external RAG result (event 502).

        Raises:
            ConnectionError: If the client is not connected.
        """
        ws = self._require_ws()
        await ws.send(
            self._build_event_request(self._EVENT_RAG_TEXT, {"external_rag": external_rag})
        )

    async def receive_response(self) -> Dict[str, Any]:
        """Receive one frame and return it decoded by ``protocol.parse_response``.

        Raises:
            Exception: Wraps any failure while receiving or parsing; the
                original error is attached as ``__cause__``.
        """
        try:
            response = await self._require_ws().recv()
            return protocol.parse_response(response)
        except Exception as e:
            raise Exception(f"Failed to receive message: {e}") from e

    async def finish_session(self) -> None:
        """Finish session (event 102).

        Raises:
            ConnectionError: If the client is not connected.
        """
        ws = self._require_ws()
        await ws.send(self._build_event_request(self._EVENT_FINISH_SESSION, {}))

    async def finish_connection(self) -> None:
        """Finish connection (event 2) and log the server's response.

        Raises:
            ConnectionError: If the client is not connected.
        """
        ws = self._require_ws()
        await ws.send(
            self._build_event_request(self._EVENT_FINISH_CONNECTION, {}, with_session=False)
        )
        response = await ws.recv()
        logger.info("FinishConnection response: %s", protocol.parse_response(response))

    async def close(self) -> None:
        """Close the WebSocket if one is open; a no-op otherwise."""
        if self.ws is not None:
            logger.info("Closing Volcengine Realtime WebSocket connection")
            try:
                await self.ws.close()
            finally:
                # Always drop the handle so a failed close is not retried
                # on a dead socket and later calls fail clearly.
                self.ws = None
|