192 lines
6.7 KiB
Python
192 lines
6.7 KiB
Python
import asyncio
|
|
import base64
|
|
import json
|
|
import logging
|
|
from typing import Optional
|
|
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|
from pydantic import BaseModel
|
|
|
|
from services.voice_session_manager import VoiceSession
|
|
from utils.settings import VOICE_DEFAULT_MODE
|
|
|
|
# Module logger; all voice endpoints below log through the shared 'app' logger.
logger = logging.getLogger('app')

# Router that the HTTP broadcast endpoint and the realtime WebSocket endpoint
# in this module are registered on.
router = APIRouter()

# Global message queue for the broadcast feature.
# Maps a queue key (see _get_queue_key) to the list of messages waiting for
# that bot/user pair. voice_broadcast appends to it; lite-mode sessions drain
# it through their get_pending_message callback, and the list is reset when a
# lite-mode session starts. This is a plain in-process dict, so its contents
# are not shared with other processes and are lost on restart.
_pending_messages: dict[str, list[str]] = {}
|
|
|
|
|
|
def _get_queue_key(bot_id: str, user_identifier: str) -> str:
|
|
return f"{bot_id}_{user_identifier}"
|
|
|
|
|
|
class BroadcastRequest(BaseModel):
    """Request body accepted by the voice broadcast endpoint."""

    # Bot half of the queue key the message is filed under.
    bot_id: str
    # User half of the queue key the message is filed under.
    user_identifier: str
    # Text appended to the pending-message queue for that bot/user pair.
    message: str
|
|
|
|
|
|
@router.post("/api/v3/voice/broadcast")
async def voice_broadcast(req: BroadcastRequest):
    """Queue a message to be spoken by an active voice session.

    The message is appended to the pending list for the request's bot/user
    pair; the list is created on first use. The response only confirms that
    the message was queued, not that any session has consumed it.
    """
    queue_key = _get_queue_key(req.bot_id, req.user_identifier)

    pending = _pending_messages.get(queue_key)
    if pending is None:
        # First message for this bot/user pair: create its queue.
        pending = []
        _pending_messages[queue_key] = pending
    pending.append(req.message)

    return {"success": True, "queued": True}
|
|
|
|
|
|
@router.websocket("/api/v3/voice/realtime")
async def voice_realtime(websocket: WebSocket):
    """
    WebSocket endpoint for voice realtime dialogue.

    Client sends:
    - {"type": "start", "bot_id": "xxx", "session_id": "xxx", "user_identifier": "xxx"}
      (optional keys read by this handler: "voice_mode", "sample_rate")
    - {"type": "audio", "data": "<base64 pcm audio>"}
    - {"type": "text", "content": "text input"}
    - {"type": "stop"}

    Server sends:
    - {"type": "audio", "data": "<base64 pcm audio>"}
    - {"type": "asr_text", "text": "recognized text"}
    - {"type": "agent_stream", "text": "incremental text chunk"}
    - {"type": "agent_result", "text": "agent answer"}
    - {"type": "llm_text", "text": "polished answer"}
    - {"type": "status", "status": "ready|listening|thinking|speaking|idle"}
    - {"type": "error", "message": "..."}

    Malformed client frames (invalid JSON, a JSON value that is not an
    object, undecodable base64 audio) are answered with an "error" message
    and do not terminate the connection. The loop ends on a "stop" message,
    on client disconnect, or on an unexpected error; in every case any live
    session is stopped before the handler returns.
    """
    await websocket.accept()

    # The currently active voice session, or None before "start" / after "stop".
    session = None

    async def send_json(data: dict):
        """Send one JSON message to the client, best-effort.

        Failures are logged and suppressed so that a session callback firing
        after the socket has closed cannot crash the session.
        """
        try:
            await websocket.send_text(json.dumps(data, ensure_ascii=False))
        except Exception as e:
            logger.debug(f"Failed to send message to voice client: {e}")

    async def on_audio(audio_data: bytes):
        """Forward TTS audio to frontend"""
        try:
            encoded = base64.b64encode(audio_data).decode('ascii')
            await send_json({"type": "audio", "data": encoded})
        except Exception as e:
            logger.error(f"Error sending audio to client: {e}")

    async def on_asr_text(text: str):
        """Forward recognized speech text to frontend"""
        await send_json({"type": "asr_text", "text": text})

    async def on_agent_result(text: str):
        """Forward the agent answer to frontend"""
        await send_json({"type": "agent_result", "text": text})

    async def on_agent_stream(text: str):
        """Forward streaming agent text chunks to frontend"""
        await send_json({"type": "agent_stream", "text": text})

    async def on_llm_text(text: str):
        """Forward the polished answer text to frontend"""
        await send_json({"type": "llm_text", "text": text})

    async def on_status(status: str):
        """Forward session status changes to frontend"""
        await send_json({"type": "status", "status": status})

    async def on_error(message: str):
        """Forward session error messages to frontend"""
        await send_json({"type": "error", "message": message})

    try:
        while True:
            raw = await websocket.receive_text()
            try:
                msg = json.loads(raw)
            except json.JSONDecodeError:
                await send_json({"type": "error", "message": "Invalid JSON"})
                continue

            # Valid JSON that is not an object (list, string, number, null)
            # has no .get(); reject it instead of letting AttributeError
            # tear down the whole connection.
            if not isinstance(msg, dict):
                await send_json({"type": "error", "message": "Message must be a JSON object"})
                continue

            msg_type = msg.get("type", "")

            if msg_type == "start":
                # Initialize voice session, replacing any previous one.
                if session:
                    await session.stop()
                    # Drop the reference so that, if validation below fails,
                    # later audio/text frames are not routed to a stopped session.
                    session = None

                bot_id = msg.get("bot_id", "")
                if not bot_id:
                    await send_json({"type": "error", "message": "bot_id is required"})
                    continue

                voice_mode = msg.get("voice_mode") or VOICE_DEFAULT_MODE
                client_sample_rate = msg.get("sample_rate", 24000)

                session_kwargs = dict(
                    bot_id=bot_id,
                    session_id=msg.get("session_id"),
                    user_identifier=msg.get("user_identifier"),
                    sample_rate=client_sample_rate,
                    on_audio=on_audio,
                    on_asr_text=on_asr_text,
                    on_agent_result=on_agent_result,
                    on_agent_stream=on_agent_stream,
                    on_llm_text=on_llm_text,
                    on_status=on_status,
                    on_error=on_error,
                )

                if voice_mode == "lite":
                    # Imported only when lite mode is actually requested.
                    from services.voice_lite_session import VoiceLiteSession

                    # Create callback for broadcast messages
                    queue_key = _get_queue_key(bot_id, msg.get("user_identifier", ""))

                    async def get_pending_message() -> Optional[str]:
                        """Pop the oldest queued broadcast message, if any."""
                        msgs = _pending_messages.get(queue_key, [])
                        return msgs.pop(0) if msgs else None

                    session_kwargs["get_pending_message"] = get_pending_message
                    session = VoiceLiteSession(**session_kwargs)
                    logger.info(f"[Voice] Using lite mode for bot_id={bot_id}")
                else:
                    session = VoiceSession(**session_kwargs)

                try:
                    await session.start()
                    # Clear old messages on new session connection
                    if voice_mode == "lite":
                        _pending_messages[queue_key] = []
                except Exception as e:
                    logger.error(f"Failed to start voice session: {e}", exc_info=True)
                    await send_json({"type": "error", "message": f"Failed to connect: {str(e)}"})
                    session = None

            elif msg_type == "audio":
                if not session:
                    await send_json({"type": "error", "message": "Session not started"})
                    continue
                audio_b64 = msg.get("data", "")
                if audio_b64:
                    try:
                        audio_bytes = base64.b64decode(audio_b64)
                    except (ValueError, TypeError):
                        # binascii.Error (bad padding / alphabet) is a
                        # ValueError; TypeError covers a non-string "data".
                        # One bad frame should not end the session.
                        await send_json({"type": "error", "message": "Invalid audio data"})
                        continue
                    await session.handle_audio(audio_bytes)

            elif msg_type == "text":
                if not session:
                    await send_json({"type": "error", "message": "Session not started"})
                    continue
                content = msg.get("content", "")
                if content:
                    await session.handle_text(content)

            elif msg_type == "stop":
                if session:
                    await session.stop()
                    session = None
                break

    except WebSocketDisconnect:
        logger.info("Voice WebSocket disconnected")
    except Exception as e:
        logger.error(f"Voice WebSocket error: {e}", exc_info=True)
    finally:
        # Make sure a still-active session is released however the loop ended.
        if session:
            try:
                await session.stop()
            except Exception as e:
                logger.warning(f"Error stopping voice session during cleanup: {e}")
|