NOTE(review): re-flowed from a whitespace-collapsed copy of this patch. Indentation
and blank context lines in hunks against existing files are inferred, and the
"index" blob ids predate the edits made below -- verify before `git apply`.

diff --git a/fastapi_app.py b/fastapi_app.py
index 756207d..c02f8d4 100644
--- a/fastapi_app.py
+++ b/fastapi_app.py
@@ -81,7 +81,7 @@ from utils.log_util.logger import init_with_fastapi
 logger = logging.getLogger('app')
 
 # Import route modules
-from routes import chat, files, projects, system, skill_manager, database, memory, bot_manager, knowledge_base, payment
+from routes import chat, files, projects, system, skill_manager, database, memory, bot_manager, knowledge_base, payment, voice
 from routes.webdav import wsgidav_app
 
 
@@ -204,6 +204,9 @@ app.include_router(bot_manager.router)
 app.include_router(payment.router)
 app.include_router(memory.router)
 
+# Register the voice dialogue routes
+app.include_router(voice.router)
+
 # 注册文件管理API路由
 app.include_router(file_manager_router)
 
diff --git a/poetry.lock b/poetry.lock
index 1f501f4..2111ea3 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -6983,4 +6983,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.12,<3.15"
-content-hash = "50bb88e7ae6df1bee01b170d6303bd8065ed3e7f899bf6b5e068784e954a40e6"
+content-hash = "1461514ed1f9639f41f43ebb28f2a3fcd2d5a5dde954cd509c0ea7bf181e9bb6"
diff --git a/pyproject.toml b/pyproject.toml
index 08aa50a..8c563d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,7 +38,8 @@ dependencies = [
     "tiktoken (>=0.5.0,<1.0.0)",
     "ragflow-sdk (>=0.23.0,<0.24.0)",
     "httpx (>=0.28.1,<0.29.0)",
-    "wsgidav (>=4.3.3,<5.0.0)"
+    "wsgidav (>=4.3.3,<5.0.0)",
+    "websockets (>=15.0.0,<16.0.0)"
 ]
 
 [tool.poetry.requires-plugins]
diff --git a/routes/bot_manager.py b/routes/bot_manager.py
index fe2c26c..433e119 100644
--- a/routes/bot_manager.py
+++ b/routes/bot_manager.py
@@ -656,6 +656,9 @@ class BotSettingsUpdate(BaseModel):
     skills: Optional[str] = None
     is_published: Optional[bool] = None  # 是否发布到广场
     shell_env: Optional[dict] = None  # 自定义 shell 环境变量
+    voice_speaker: Optional[str] = None  # Voice timbre (speaker)
+    voice_system_role: Optional[str] = None  # System role for voice dialogue
+    voice_speaking_style: Optional[str] = None  # Speaking style for voice dialogue
 
 
 class ModelInfo(BaseModel):
@@ -697,6 +700,9 @@ class BotSettingsResponse(BaseModel):
     is_published: bool = False  # 是否发布到广场
     is_owner: bool = True  # 是否是所有者
     copied_from: Optional[str] = None  # 复制来源的bot id
+    voice_speaker: Optional[str] = None  # Voice timbre (speaker)
+    voice_system_role: Optional[str] = None  # System role for voice dialogue
+    voice_speaking_style: Optional[str] = None  # Speaking style for voice dialogue
     updated_at: str
 
 
@@ -1869,6 +1875,9 @@ async def get_bot_settings(bot_uuid: str, authorization: Optional[str] = Header(
         is_published=is_published if is_published else False,
         is_owner=is_owner,
         copied_from=str(copied_from) if copied_from else None,
+        voice_speaker=settings.get('voice_speaker'),
+        voice_system_role=settings.get('voice_system_role'),
+        voice_speaking_style=settings.get('voice_speaking_style'),
         updated_at=datetime_to_str(updated_at)
     )
 
@@ -1943,6 +1952,12 @@ async def update_bot_settings(
         update_json['skills'] = request.skills
     if request.shell_env is not None:
         update_json['shell_env'] = request.shell_env
+    if request.voice_speaker is not None:
+        update_json['voice_speaker'] = request.voice_speaker
+    if request.voice_system_role is not None:
+        update_json['voice_system_role'] = request.voice_system_role
+    if request.voice_speaking_style is not None:
+        update_json['voice_speaking_style'] = request.voice_speaking_style
 
     # is_published 是表字段,不在 settings JSON 中
     need_update_published = request.is_published is not None
diff --git a/routes/voice.py b/routes/voice.py
new file mode 100644
index 0000000..7031328
--- /dev/null
+++ b/routes/voice.py
@@ -0,0 +1,157 @@
+import asyncio
+import base64
+import json
+import logging
+from typing import Optional
+
+from fastapi import APIRouter, WebSocket, WebSocketDisconnect
+
+from services.voice_session_manager import VoiceSession
+
+logger = logging.getLogger('app')
+
+router = APIRouter()
+
+
+@router.websocket("/api/v3/voice/realtime")
+async def voice_realtime(websocket: WebSocket):
+    """
+    WebSocket endpoint for realtime voice dialogue.
+
+    Client sends:
+    - {"type": "start", "bot_id": "xxx", "session_id": "xxx", "user_identifier": "xxx"}
+    - {"type": "audio", "data": "base64-encoded audio bytes"}
+    - {"type": "text", "content": "text input"}
+    - {"type": "stop"}
+
+    Server sends:
+    - {"type": "audio", "data": "base64-encoded audio bytes"}
+    - {"type": "asr_text", "text": "recognized text"}
+    - {"type": "agent_result", "text": "agent answer"}
+    - {"type": "llm_text", "text": "polished answer"}
+    - {"type": "status", "status": "ready|listening|thinking|speaking|idle"}
+    - {"type": "error", "message": "..."}
+
+    NOTE(review): the socket is accepted without any authentication and
+    ``user_identifier`` is trusted as sent by the client -- confirm intended.
+    """
+    await websocket.accept()
+
+    session: Optional[VoiceSession] = None
+
+    async def send_json(data: dict):
+        """Best-effort send; the client may already have disconnected."""
+        try:
+            await websocket.send_text(json.dumps(data, ensure_ascii=False))
+        except Exception as e:
+            logger.debug(f"Failed to send message to voice client: {e}")
+
+    async def on_audio(audio_data: bytes):
+        """Forward TTS audio to frontend"""
+        try:
+            encoded = base64.b64encode(audio_data).decode('ascii')
+            await send_json({"type": "audio", "data": encoded})
+        except Exception as e:
+            logger.error(f"Error sending audio to client: {e}")
+
+    async def on_asr_text(text: str):
+        await send_json({"type": "asr_text", "text": text})
+
+    async def on_agent_result(text: str):
+        await send_json({"type": "agent_result", "text": text})
+
+    async def on_llm_text(text: str):
+        await send_json({"type": "llm_text", "text": text})
+
+    async def on_status(status: str):
+        await send_json({"type": "status", "status": status})
+
+    async def on_error(message: str):
+        await send_json({"type": "error", "message": message})
+
+    try:
+        while True:
+            raw = await websocket.receive_text()
+            try:
+                msg = json.loads(raw)
+            except json.JSONDecodeError:
+                await send_json({"type": "error", "message": "Invalid JSON"})
+                continue
+            if not isinstance(msg, dict):
+                # Valid JSON that is not an object has no "type" to dispatch on.
+                await send_json({"type": "error", "message": "Message must be a JSON object"})
+                continue
+
+            msg_type = msg.get("type", "")
+
+            if msg_type == "start":
+                # Initialize voice session
+                if session:
+                    await session.stop()
+                    # Drop the stopped session so that a rejected "start"
+                    # below cannot leave a dead session attached.
+                    session = None
+
+                bot_id = msg.get("bot_id", "")
+                if not bot_id:
+                    await send_json({"type": "error", "message": "bot_id is required"})
+                    continue
+
+                session = VoiceSession(
+                    bot_id=bot_id,
+                    session_id=msg.get("session_id"),
+                    user_identifier=msg.get("user_identifier"),
+                    on_audio=on_audio,
+                    on_asr_text=on_asr_text,
+                    on_agent_result=on_agent_result,
+                    on_llm_text=on_llm_text,
+                    on_status=on_status,
+                    on_error=on_error,
+                )
+
+                try:
+                    await session.start()
+                except Exception as e:
+                    logger.error(f"Failed to start voice session: {e}", exc_info=True)
+                    await send_json({"type": "error", "message": f"Failed to connect: {str(e)}"})
+                    session = None
+
+            elif msg_type == "audio":
+                if not session:
+                    await send_json({"type": "error", "message": "Session not started"})
+                    continue
+                audio_b64 = msg.get("data", "")
+                if audio_b64:
+                    try:
+                        audio_bytes = base64.b64decode(audio_b64)
+                    except (TypeError, ValueError):
+                        # binascii.Error is a ValueError; one bad frame must
+                        # not tear down the whole connection.
+                        await send_json({"type": "error", "message": "Invalid audio data"})
+                        continue
+                    await session.handle_audio(audio_bytes)
+
+            elif msg_type == "text":
+                if not session:
+                    await send_json({"type": "error", "message": "Session not started"})
+                    continue
+                content = msg.get("content", "")
+                if content:
+                    await session.handle_text(content)
+
+            elif msg_type == "stop":
+                if session:
+                    await session.stop()
+                    session = None
+                break
+
+    except WebSocketDisconnect:
+        logger.info("Voice WebSocket disconnected")
+    except Exception as e:
+        logger.error(f"Voice WebSocket error: {e}", exc_info=True)
+    finally:
+        if session:
+            try:
+                await session.stop()
+            except Exception as e:
+                logger.warning(f"Error stopping voice session: {e}")
diff --git a/services/realtime_voice_client.py b/services/realtime_voice_client.py
new file mode 100644
index 0000000..b0a7419
--- /dev/null
+++ b/services/realtime_voice_client.py
@@ -0,0 +1,199 @@
+import gzip
+import json
+import uuid
+import logging
+from typing import Dict, Any, Optional
+
+import websockets
+
+from services import realtime_voice_protocol as protocol
+from utils.settings import (
+    VOLCENGINE_REALTIME_URL,
+    VOLCENGINE_APP_ID,
+    VOLCENGINE_ACCESS_KEY,
+    VOLCENGINE_RESOURCE_ID,
+    VOLCENGINE_APP_KEY,
+    VOLCENGINE_DEFAULT_SPEAKER,
+    VOLCENGINE_TTS_SAMPLE_RATE,
+)
+
+logger = logging.getLogger('app')
+
+
+class RealtimeDialogClient:
+    """Volcengine Realtime Dialogue API WebSocket client"""
+
+    def __init__(
+        self,
+        session_id: str,
+        speaker: Optional[str] = None,
+        system_role: Optional[str] = None,
+        speaking_style: Optional[str] = None,
+        bot_name: Optional[str] = None,
+        recv_timeout: int = 60,
+        input_mod: str = "audio",
+    ) -> None:
+        self.session_id = session_id
+        self.speaker = speaker or VOLCENGINE_DEFAULT_SPEAKER
+        self.system_role = system_role or ""
+        self.speaking_style = speaking_style or ""
+        self.bot_name = bot_name or ""
+        self.recv_timeout = recv_timeout
+        self.input_mod = input_mod
+        self.logid = ""
+        self.ws = None
+        self._connect_id = str(uuid.uuid4())
+
+    def _build_headers(self) -> Dict[str, str]:
+        return {
+            "X-Api-App-ID": VOLCENGINE_APP_ID,
+            "X-Api-Access-Key": VOLCENGINE_ACCESS_KEY,
+            "X-Api-Resource-Id": VOLCENGINE_RESOURCE_ID,
+            "X-Api-App-Key": VOLCENGINE_APP_KEY,
+            "X-Api-Connect-Id": self._connect_id,
+        }
+
+    def _build_session_params(self) -> Dict[str, Any]:
+        return {
+            "asr": {
+                "extra": {
+                    "end_smooth_window_ms": 1500,
+                },
+            },
+            "tts": {
+                "speaker": self.speaker,
+                "audio_config": {
+                    "channel": 1,
+                    "format": "pcm",
+                    "sample_rate": VOLCENGINE_TTS_SAMPLE_RATE,
+                },
+            },
+            "dialog": {
+                "bot_name": self.bot_name,
+                "system_role": self.system_role,
+                "speaking_style": self.speaking_style,
+                "extra": {
+                    "strict_audit": False,
+                    "recv_timeout": self.recv_timeout,
+                    "input_mod": self.input_mod,
+                    "enable_volc_websearch": False,
+                    "enable_music": False,
+                    "model": "1.2.1.1"
+                },
+            }
+        }
+
+    def _session_id_field(self) -> bytes:
+        """Return the session id as a length-prefixed wire field.
+
+        The prefix is the length of the encoded bytes, not of the str, so a
+        non-ASCII session id still produces a well-formed frame.
+        """
+        encoded = str.encode(self.session_id)
+        return len(encoded).to_bytes(4, 'big') + encoded
+
+    def _build_event_request(self, event_id: int, payload: dict, with_session: bool = True) -> bytearray:
+        """Build a gzip-compressed JSON event frame"""
+        request = bytearray(protocol.generate_header())
+        request.extend(int(event_id).to_bytes(4, 'big'))
+        if with_session:
+            request.extend(self._session_id_field())
+        payload_bytes = gzip.compress(str.encode(json.dumps(payload)))
+        request.extend((len(payload_bytes)).to_bytes(4, 'big'))
+        request.extend(payload_bytes)
+        return request
+
+    async def connect(self) -> None:
+        """Open the WebSocket, then run StartConnection and StartSession.
+
+        Raises:
+            ConnectionError: if the server rejects the WebSocket upgrade.
+        """
+        logger.info(f"Connecting to Volcengine Realtime API: {VOLCENGINE_REALTIME_URL}")
+        headers = self._build_headers()
+        try:
+            self.ws = await websockets.connect(
+                VOLCENGINE_REALTIME_URL,
+                additional_headers=headers,
+                ping_interval=None,
+                proxy=None,
+            )
+        except websockets.exceptions.InvalidStatus as e:
+            body = ""
+            if e.response and e.response.body:
+                body = e.response.body.decode("utf-8", errors="replace")
+            raise ConnectionError(
+                f"Volcengine Realtime API rejected connection: HTTP {e.response.status_code} - {body}"
+            ) from e
+        self.logid = self.ws.response.headers.get("X-Tt-Logid", "")
+        logger.info(f"Volcengine Realtime connected, logid: {self.logid}")
+
+        try:
+            # StartConnection (event 1)
+            await self.ws.send(self._build_event_request(1, {}, with_session=False))
+            response = await self.ws.recv()
+            logger.info(f"StartConnection response: {protocol.parse_response(response)}")
+
+            # StartSession (event 100)
+            await self.ws.send(self._build_event_request(100, self._build_session_params()))
+            response = await self.ws.recv()
+            logger.info(f"StartSession response: {protocol.parse_response(response)}")
+        except Exception:
+            # Do not leak the socket when the handshake fails part-way.
+            await self.close()
+            raise
+
+    async def send_audio(self, audio: bytes) -> None:
+        """Send audio data (event 200)"""
+        task_request = bytearray(
+            protocol.generate_header(
+                message_type=protocol.CLIENT_AUDIO_ONLY_REQUEST,
+                serial_method=protocol.NO_SERIALIZATION,
+            )
+        )
+        task_request.extend(int(200).to_bytes(4, 'big'))
+        task_request.extend(self._session_id_field())
+        payload_bytes = gzip.compress(audio)
+        task_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
+        task_request.extend(payload_bytes)
+        await self.ws.send(task_request)
+
+    async def chat_text_query(self, content: str) -> None:
+        """Send text query (event 501)"""
+        await self.ws.send(self._build_event_request(501, {"content": content}))
+
+    async def chat_tts_text(self, content: str, start: bool = True, end: bool = True) -> None:
+        """Send TTS text for synthesis (event 500)"""
+        await self.ws.send(
+            self._build_event_request(500, {"start": start, "end": end, "content": content})
+        )
+
+    async def chat_rag_text(self, external_rag: str) -> None:
+        """Inject external RAG result (event 502)"""
+        await self.ws.send(
+            self._build_event_request(502, {"external_rag": external_rag})
+        )
+
+    async def receive_response(self) -> Dict[str, Any]:
+        """Receive and parse the next server frame"""
+        try:
+            response = await self.ws.recv()
+            return protocol.parse_response(response)
+        except Exception as e:
+            raise Exception(f"Failed to receive message: {e}") from e
+
+    async def finish_session(self) -> None:
+        """Finish session (event 102)"""
+        await self.ws.send(self._build_event_request(102, {}))
+
+    async def finish_connection(self) -> None:
+        """Finish connection (event 2)"""
+        await self.ws.send(self._build_event_request(2, {}, with_session=False))
+        response = await self.ws.recv()
+        logger.info(f"FinishConnection response: {protocol.parse_response(response)}")
+
+    async def close(self) -> None:
+        if self.ws:
+            logger.info("Closing Volcengine Realtime WebSocket connection")
+            await self.ws.close()
+            self.ws = None
diff --git a/services/realtime_voice_protocol.py b/services/realtime_voice_protocol.py
new file mode 100644
index 0000000..78ab96d
--- /dev/null
+++ b/services/realtime_voice_protocol.py
@@ -0,0 +1,118 @@
+import gzip
+import json
+
+PROTOCOL_VERSION = 0b0001
+DEFAULT_HEADER_SIZE = 0b0001
+
+PROTOCOL_VERSION_BITS = 4
+HEADER_BITS = 4
+MESSAGE_TYPE_BITS = 4
+MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4
+MESSAGE_SERIALIZATION_BITS = 4
+MESSAGE_COMPRESSION_BITS = 4
+RESERVED_BITS = 8
+
+# Message Type:
+CLIENT_FULL_REQUEST = 0b0001
+CLIENT_AUDIO_ONLY_REQUEST = 0b0010
+
+SERVER_FULL_RESPONSE = 0b1001
+SERVER_ACK = 0b1011
+SERVER_ERROR_RESPONSE = 0b1111
+
+# Message Type Specific Flags
+NO_SEQUENCE = 0b0000
+POS_SEQUENCE = 0b0001
+NEG_SEQUENCE = 0b0010
+NEG_SEQUENCE_1 = 0b0011
+
+MSG_WITH_EVENT = 0b0100
+
+# Message Serialization
+NO_SERIALIZATION = 0b0000
+JSON = 0b0001
+THRIFT = 0b0011
+CUSTOM_TYPE = 0b1111
+
+# Message Compression
+NO_COMPRESSION = 0b0000
+GZIP = 0b0001
+CUSTOM_COMPRESSION = 0b1111
+
+
+def generate_header(
+    version=PROTOCOL_VERSION,
+    message_type=CLIENT_FULL_REQUEST,
+    message_type_specific_flags=MSG_WITH_EVENT,
+    serial_method=JSON,
+    compression_type=GZIP,
+    reserved_data=0x00,
+    extension_header=bytes()
+):
+    """Build the 4-byte frame header, followed by any extension header."""
+    header = bytearray()
+    # Header size is expressed in 4-byte words, including the fixed word.
+    header_size = len(extension_header) // 4 + 1
+    header.append((version << 4) | header_size)
+    header.append((message_type << 4) | message_type_specific_flags)
+    header.append((serial_method << 4) | compression_type)
+    header.append(reserved_data)
+    header.extend(extension_header)
+    return header
+
+
+def parse_response(res):
+    """Decode a server frame into a dict.
+
+    Returns ``{}`` for text frames and for frames shorter than the 4-byte
+    header. Otherwise the dict may carry ``message_type``, ``seq``,
+    ``event``, ``session_id``, ``code``, ``payload_msg`` and
+    ``payload_size``, depending on the frame type and flags.
+    """
+    if isinstance(res, str) or len(res) < 4:
+        return {}
+    header_size = res[0] & 0x0f
+    message_type = res[1] >> 4
+    message_type_specific_flags = res[1] & 0x0f
+    serialization_method = res[2] >> 4
+    message_compression = res[2] & 0x0f
+    payload = res[header_size * 4:]
+    result = {}
+    payload_msg = None
+    payload_size = 0
+    if message_type == SERVER_FULL_RESPONSE or message_type == SERVER_ACK:
+        result['message_type'] = 'SERVER_FULL_RESPONSE'
+        if message_type == SERVER_ACK:
+            result['message_type'] = 'SERVER_ACK'
+        # Optional fields are consumed in order so that a frame carrying both
+        # a sequence number and an event id is not read from the same bytes.
+        if message_type_specific_flags & NEG_SEQUENCE > 0:
+            result['seq'] = int.from_bytes(payload[:4], "big", signed=False)
+            payload = payload[4:]
+        if message_type_specific_flags & MSG_WITH_EVENT > 0:
+            result['event'] = int.from_bytes(payload[:4], "big", signed=False)
+            payload = payload[4:]
+        session_id_size = int.from_bytes(payload[:4], "big", signed=True)
+        session_id = payload[4:session_id_size+4]
+        result['session_id'] = bytes(session_id).decode("utf-8", errors="replace")
+        payload = payload[4 + session_id_size:]
+        payload_size = int.from_bytes(payload[:4], "big", signed=False)
+        payload_msg = payload[4:]
+    elif message_type == SERVER_ERROR_RESPONSE:
+        # Tag error frames so that callers can dispatch on message_type.
+        result['message_type'] = 'SERVER_ERROR'
+        code = int.from_bytes(payload[:4], "big", signed=False)
+        result['code'] = code
+        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
+        payload_msg = payload[8:]
+    if payload_msg is None:
+        return result
+    if message_compression == GZIP:
+        payload_msg = gzip.decompress(payload_msg)
+    if serialization_method == JSON:
+        payload_msg = json.loads(str(payload_msg, "utf-8"))
+    elif serialization_method != NO_SERIALIZATION:
+        payload_msg = str(payload_msg, "utf-8")
+    result['payload_msg'] = payload_msg
+    result['payload_size'] = payload_size
+    return result
diff --git a/services/voice_session_manager.py b/services/voice_session_manager.py
new file mode 100644
index 0000000..0a2ac0c
--- /dev/null
+++ b/services/voice_session_manager.py
@@ -0,0 +1,323 @@
+import asyncio
+import json
+import logging
+import uuid
+from typing import Optional, Callable, Awaitable
+
+from services.realtime_voice_client import RealtimeDialogClient
+
+logger = logging.getLogger('app')
+
+
+class VoiceSession:
+    """Manages a single voice dialogue session lifecycle"""
+
+    def __init__(
+        self,
+        bot_id: str,
+        session_id: Optional[str] = None,
+        user_identifier: Optional[str] = None,
+        on_audio: Optional[Callable[[bytes], Awaitable[None]]] = None,
+        on_asr_text: Optional[Callable[[str], Awaitable[None]]] = None,
+        on_agent_result: Optional[Callable[[str], Awaitable[None]]] = None,
+        on_llm_text: Optional[Callable[[str], Awaitable[None]]] = None,
+        on_status: Optional[Callable[[str], Awaitable[None]]] = None,
+        on_error: Optional[Callable[[str], Awaitable[None]]] = None,
+    ):
+        self.bot_id = bot_id
+        self.session_id = session_id or str(uuid.uuid4())
+        self.user_identifier = user_identifier or ""
+
+        self.realtime_client: Optional[RealtimeDialogClient] = None
+        self._bot_config: dict = {}
+
+        # Callbacks
+        self._on_audio = on_audio
+        self._on_asr_text = on_asr_text
+        self._on_agent_result = on_agent_result
+        self._on_llm_text = on_llm_text
+        self._on_status = on_status
+        self._on_error = on_error
+
+        self._running = False
+        self._is_user_querying = False
+        self._current_asr_text = ""
+        # When True, discard TTS audio from SERVER_ACK (comfort speech period)
+        self._is_sending_chat_tts_text = False
+        self._receive_task: Optional[asyncio.Task] = None
+        # Strong references to in-flight agent tasks (asyncio keeps only weak ones)
+        self._background_tasks: set = set()
+
+    async def start(self) -> None:
+        """Fetch bot config, connect to Volcengine and start receiving responses"""
+        from utils.fastapi_utils import fetch_bot_config_from_db
+
+        self._bot_config = await fetch_bot_config_from_db(self.bot_id, self.user_identifier)
+
+        self.realtime_client = RealtimeDialogClient(
+            session_id=self.session_id,
+            speaker=self._bot_config.get("voice_speaker"),
+            system_role=self._bot_config.get("voice_system_role"),
+            speaking_style=self._bot_config.get("voice_speaking_style"),
+            bot_name=self._bot_config.get("name", ""),
+        )
+
+        await self.realtime_client.connect()
+        self._running = True
+        self._receive_task = asyncio.create_task(self._receive_loop())
+        await self._emit_status("ready")
+
+    async def stop(self) -> None:
+        """Gracefully stop the session"""
+        self._running = False
+        client = self.realtime_client
+        if client is None:
+            # start() never got as far as creating the client.
+            return
+        try:
+            if client.ws:
+                await client.finish_session()
+                await asyncio.sleep(0.5)
+                await client.finish_connection()
+        except Exception as e:
+            logger.warning(f"Error during session cleanup: {e}")
+        finally:
+            if self._receive_task and not self._receive_task.done():
+                self._receive_task.cancel()
+            for task in list(self._background_tasks):
+                task.cancel()
+            try:
+                await client.close()
+            except Exception as e:
+                logger.warning(f"Error closing realtime client: {e}")
+
+    async def handle_audio(self, audio_data: bytes) -> None:
+        """Forward user audio to Volcengine"""
+        if self._running and self.realtime_client.ws:
+            await self.realtime_client.send_audio(audio_data)
+
+    async def handle_text(self, text: str) -> None:
+        """Handle text input - send as text query"""
+        if self._running and self.realtime_client.ws:
+            await self.realtime_client.chat_text_query(text)
+
+    async def _receive_loop(self) -> None:
+        """Continuously receive and dispatch Volcengine responses"""
+        try:
+            while self._running:
+                response = await self.realtime_client.receive_response()
+                if not response:
+                    continue
+                await self._handle_response(response)
+        except asyncio.CancelledError:
+            logger.info(f"Voice session receive loop cancelled: {self.session_id}")
+        except Exception as e:
+            logger.error(f"Voice session receive loop error: {e}", exc_info=True)
+            await self._emit_error(f"Connection error: {str(e)}")
+        finally:
+            self._running = False
+
+    async def _handle_response(self, response: dict) -> None:
+        msg_type = response.get('message_type', '')
+        event = response.get('event')
+        payload_msg = response.get('payload_msg', {})
+
+        if msg_type == 'SERVER_ACK' and isinstance(payload_msg, bytes):
+            # TTS audio data — discard during comfort speech period
+            if self._is_sending_chat_tts_text:
+                return
+            if self._on_audio:
+                await self._on_audio(payload_msg)
+
+        elif msg_type == 'SERVER_FULL_RESPONSE':
+            logger.info(f"[Voice] event={event}, payload_msg={payload_msg if not isinstance(payload_msg, bytes) else f'<{len(payload_msg)} bytes>'}")
+
+            if event == 450:
+                # User started speaking — clear audio, set querying flag, reset ASR accumulator
+                self._is_user_querying = True
+                self._current_asr_text = ""
+                await self._emit_status("listening")
+
+            elif event == 451:
+                # Streaming ASR result — accumulate recognized text
+                if isinstance(payload_msg, dict):
+                    results = payload_msg.get("results", [])
+                    if results and isinstance(results, list) and len(results) > 0:
+                        text = results[0].get("text", "")
+                        if text:
+                            self._current_asr_text = text
+                            logger.debug(f"[Voice] ASR streaming (451): '{text}'")
+
+            elif event == 459:
+                # ASR completed — use accumulated text from event 451
+                self._is_user_querying = False
+                asr_text = self._current_asr_text
+
+                logger.info(f"[Voice] ASR result: '{asr_text}'")
+
+                if self._on_asr_text and asr_text:
+                    await self._on_asr_text(asr_text)
+                await self._emit_status("thinking")
+
+                # Trigger comfort TTS + agent call
+                self._is_sending_chat_tts_text = True
+                task = asyncio.create_task(self._on_asr_text_received(asr_text))
+                # Hold a reference until done so the task cannot be
+                # garbage-collected mid-flight.
+                self._background_tasks.add(task)
+                task.add_done_callback(self._background_tasks.discard)
+
+            elif event == 350:
+                # TTS segment completed
+                tts_type = ""
+                if isinstance(payload_msg, dict):
+                    tts_type = payload_msg.get("tts_type", "")
+                logger.info(f"[Voice] TTS segment done, type={tts_type}, is_sending={self._is_sending_chat_tts_text}")
+
+                # When comfort TTS or RAG TTS finishes, stop discarding audio
+                if self._is_sending_chat_tts_text and tts_type in ("chat_tts_text", "external_rag"):
+                    self._is_sending_chat_tts_text = False
+                    logger.info(f"[Voice] Comfort/RAG TTS done, resuming audio forwarding")
+
+            elif event == 359:
+                # TTS fully completed (all segments done)
+                logger.info(f"[Voice] TTS fully completed")
+                await self._emit_status("idle")
+
+            elif event in (152, 153):
+                logger.info(f"[Voice] Session finished event: {event}")
+                self._running = False
+
+        elif msg_type == 'SERVER_ERROR':
+            error_msg = str(payload_msg) if payload_msg else "Unknown server error"
+            logger.error(f"[Voice] Server error: {error_msg}")
+            await self._emit_error(error_msg)
+
+    async def _on_asr_text_received(self, text: str) -> None:
+        """Called when ASR text is received — send comfort TTS, call agent, inject RAG"""
+        if not text.strip():
+            self._is_sending_chat_tts_text = False
+            return
+
+        try:
+            # 1. Send comfort TTS (real Chinese text, not "...")
+            logger.info(f"[Voice] Sending comfort TTS...")
+            await self.realtime_client.chat_tts_text(
+                content="请稍等,让我查一下。",
+                start=True,
+                end=False,
+            )
+            await self.realtime_client.chat_tts_text(
+                content="",
+                start=False,
+                end=True,
+            )
+
+            # 2. Call v3 agent (this may take a while)
+            logger.info(f"[Voice] Calling v3 agent with text: '{text}'")
+            agent_result = await self._call_v3_agent(text)
+            logger.info(f"[Voice] Agent result ({len(agent_result)} chars): {agent_result[:200]}")
+
+            if self._on_agent_result and agent_result:
+                await self._on_agent_result(agent_result)
+
+            # 3. Inject RAG result so the built-in LLM can polish and TTS it
+            if agent_result:
+                clean_result = self._extract_answer(agent_result)
+                logger.info(f"[Voice] Injecting RAG text ({len(clean_result)} chars): {clean_result[:200]}")
+                await self.realtime_client.chat_rag_text(clean_result)
+            else:
+                logger.warning(f"[Voice] Agent returned empty result, skipping RAG injection")
+                self._is_sending_chat_tts_text = False
+
+        except Exception as e:
+            logger.error(f"[Voice] Error in ASR text callback: {e}", exc_info=True)
+            self._is_sending_chat_tts_text = False
+            await self._emit_error(f"Agent call failed: {str(e)}")
+
+    async def _call_v3_agent(self, user_text: str) -> str:
+        """Call v3 agent API internally (stream=false) to get full reasoning result"""
+        try:
+            from utils.api_models import ChatRequestV3, Message
+            from utils.fastapi_utils import (
+                process_messages,
+                create_project_directory,
+            )
+            from agent.agent_config import AgentConfig
+            from routes.chat import create_agent_and_generate_response
+
+            bot_config = self._bot_config
+            language = bot_config.get("language", "zh")
+
+            messages_raw = [{"role": "user", "content": user_text}]
+            messages_obj = [Message(role="user", content=user_text)]
+
+            request = ChatRequestV3(
+                messages=messages_obj,
+                bot_id=self.bot_id,
+                stream=False,
+                session_id=self.session_id,
+                user_identifier=self.user_identifier,
+            )
+
+            project_dir = create_project_directory(
+                bot_config.get("dataset_ids", []),
+                self.bot_id,
+                bot_config.get("skills", []),
+            )
+
+            processed_messages = process_messages(messages_obj, language)
+
+            config = await AgentConfig.from_v3_request(
+                request,
+                bot_config,
+                project_dir,
+                processed_messages,
+                language,
+            )
+            config.stream = False
+
+            result = await create_agent_and_generate_response(config)
+
+            if hasattr(result, 'choices'):
+                choices = result.choices
+                if choices and len(choices) > 0:
+                    return choices[0].get("message", {}).get("content", "")
+            return ""
+
+        except Exception as e:
+            logger.error(f"[Voice] Error calling v3 agent: {e}", exc_info=True)
+            return ""
+
+    @staticmethod
+    def _extract_answer(agent_result: str) -> str:
+        """Extract the answer portion from agent result, stripping tags like [ANSWER], [THINK] etc."""
+        lines = agent_result.split('\n')
+        answer_lines = []
+        in_answer = False
+        for line in lines:
+            stripped = line.strip()
+            if stripped.startswith('[ANSWER]'):
+                in_answer = True
+                rest = stripped[len('[ANSWER]'):].strip()
+                if rest:
+                    answer_lines.append(rest)
+                continue
+            if stripped.startswith('['):
+                # Any other "[TAG]" line closes the answer section.
+                in_answer = False
+                continue
+            if in_answer:
+                answer_lines.append(line)
+
+        if answer_lines:
+            return '\n'.join(answer_lines).strip()
+        return agent_result.strip()
+
+    async def _emit_status(self, status: str) -> None:
+        if self._on_status:
+            await self._on_status(status)
+
+    async def _emit_error(self, message: str) -> None:
+        if self._on_error:
+            await self._on_error(message)
diff --git a/utils/fastapi_utils.py b/utils/fastapi_utils.py
index bd43a4f..2520156 100644
--- a/utils/fastapi_utils.py
+++ b/utils/fastapi_utils.py
@@ -518,6 +518,7 @@ async def fetch_bot_config_from_db(bot_user_id: str, user_identifier: Optional[s
     model_server = NEW_API_BASE_URL.rstrip('/') + "/v1" if NEW_API_BASE_URL else ""
 
     config = {
+        "name": bot_name,
         "model": model_name,
         "api_key": api_key,
         "model_server": model_server,
@@ -532,6 +533,9 @@ async def fetch_bot_config_from_db(bot_user_id: str, user_identifier: Optional[s
         "description": settings_data.get("description", ""),
         "suggestions": settings_data.get("suggestions", []),
         "shell_env": settings_data.get("shell_env") or {},
+        "voice_speaker": settings_data.get("voice_speaker", ""),
+        "voice_system_role": settings_data.get("voice_system_role", ""),
+        "voice_speaking_style": settings_data.get("voice_speaking_style", ""),
     }
 
     # 处理 dataset_ids
diff --git a/utils/settings.py b/utils/settings.py
index 9d51b28..2f1dacc 100644
--- a/utils/settings.py
+++ b/utils/settings.py
@@ -118,6 +118,23 @@ NEW_API_TIMEOUT = int(os.getenv("NEW_API_TIMEOUT", "30"))
 # New API 管理员密钥(用于同步用户等管理操作,可选)
 NEW_API_ADMIN_KEY = os.getenv("NEW_API_ADMIN_KEY", "")
 
+# ============================================================
+# Volcengine Realtime Dialogue Configuration
+# ============================================================
+VOLCENGINE_REALTIME_URL = os.getenv(
+    "VOLCENGINE_REALTIME_URL",
+    "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
+)
+# Credentials have no built-in default: supply them through the environment.
+VOLCENGINE_APP_ID = os.getenv("VOLCENGINE_APP_ID", "")
+VOLCENGINE_ACCESS_KEY = os.getenv("VOLCENGINE_ACCESS_KEY", "")
+VOLCENGINE_RESOURCE_ID = os.getenv("VOLCENGINE_RESOURCE_ID", "volc.speech.dialog")
+VOLCENGINE_APP_KEY = os.getenv("VOLCENGINE_APP_KEY", "")
+VOLCENGINE_DEFAULT_SPEAKER = os.getenv(
+    "VOLCENGINE_DEFAULT_SPEAKER", "zh_male_yunzhou_jupiter_bigtts"
+)
+VOLCENGINE_TTS_SAMPLE_RATE = int(os.getenv("VOLCENGINE_TTS_SAMPLE_RATE", "24000"))
+
 # ============================================================
 # Single Agent Mode Configuration
 # ============================================================