diff --git a/routes/chat.py b/routes/chat.py index fa848e2..a6e803b 100644 --- a/routes/chat.py +++ b/routes/chat.py @@ -4,7 +4,7 @@ import asyncio import shutil import time from typing import Union, Optional, Any, List, Dict -from fastapi import APIRouter, HTTPException, Header +from fastapi import APIRouter, HTTPException, Header, Body from fastapi.responses import StreamingResponse import logging @@ -39,11 +39,19 @@ async def enhanced_generate_stream_response( # 用于收集完整的响应内容,用于保存到数据库 full_response_content = [] + # 取消管理 + cancel_event = None + try: # 创建输出队列和控制事件 output_queue = asyncio.Queue() preamble_completed = asyncio.Event() + # 注册取消事件 + if config.session_id: + from utils.cancel_manager import register_cancel_event, unregister_cancel_event + cancel_event = register_cancel_event(config.session_id) + # 在流式开始前保存用户消息 if config.session_id: asyncio.create_task(_save_user_messages(config)) @@ -81,6 +89,11 @@ async def enhanced_generate_stream_response( message_tag = "" agent, checkpointer = await init_agent(config) async for msg, metadata in agent.astream({"messages": config.messages}, stream_mode="messages", config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS): + # 检查是否收到取消信号 + if cancel_event and cancel_event.is_set(): + logger.info(f"Agent stream cancelled for session_id={config.session_id}") + break + new_content = "" if isinstance(msg, AIMessageChunk): @@ -124,7 +137,8 @@ async def enhanced_generate_stream_response( await output_queue.put(("agent", f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n")) # 发送最终chunk - final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", config.model_name, finish_reason="stop") + finish = "cancelled" if (cancel_event and cancel_event.is_set()) else "stop" + final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", config.model_name, finish_reason=finish) await output_queue.put(("agent", f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n")) # ============ 执行 PostAgent hooks ============ # 
注意:这里在单独的异步任务中执行,不阻塞流式输出 @@ -190,6 +204,11 @@ async def enhanced_generate_stream_response( break except asyncio.TimeoutError: + # 检查是否收到取消信号 + if cancel_event and cancel_event.is_set(): + logger.info(f"Output loop cancelled for session_id={config.session_id}") + break + # 检查是否还有任务在运行 if all(task.done() for task in [preamble_task_handle, agent_task_handle]): # 所有任务都完成了,退出循环 @@ -203,8 +222,13 @@ async def enhanced_generate_stream_response( # 发送结束标记 yield "data: [DONE]\n\n" + # 清理取消事件 + if config.session_id: + from utils.cancel_manager import unregister_cancel_event + unregister_cancel_event(config.session_id) logger.info(f"Enhanced stream response completed") + # 流式结束后保存 AI 响应 if full_response_content and config.session_id: asyncio.create_task(_save_assistant_response(config, "".join(full_response_content))) @@ -213,6 +237,10 @@ async def enhanced_generate_stream_response( logger.error(f"Error in enhanced_generate_stream_response: {e}") yield f'data: {{"error": "{str(e)}"}}\n\n' yield "data: [DONE]\n\n" + # 清理取消事件 + if config.session_id: + from utils.cancel_manager import unregister_cancel_event + unregister_cancel_event(config.session_id) async def create_agent_and_generate_response( @@ -767,6 +795,25 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st logger.error(f"Full traceback: {error_details}") raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") +@router.post("/api/v1/chat/cancel") +async def cancel_chat(session_id: str = Body(..., embed=True)): + """ + 取消正在进行的 agent 推理 + + 请求体: {"session_id": "xxxxx"} + 响应: {"success": true/false, "message": "..."} + """ + from utils.cancel_manager import trigger_cancel + + if not session_id: + raise HTTPException(status_code=400, detail="session_id is required") + + found = trigger_cancel(session_id) + if found: + return {"success": True, "message": f"Cancel signal sent for session_id={session_id}"} + else: + return {"success": False, "message": f"No active 
"""Registry of per-session cancel events.

The registry is a module-level dict, so it is only visible to code running in
the same Python process. Each ``session_id`` maps to at most one
``asyncio.Event``; setting that event is the signal that the work registered
under the session should stop.
"""
import asyncio
import logging
from typing import Dict, Optional

logger = logging.getLogger('app')

# Global cancel registry: session_id -> asyncio.Event
_cancel_registry: Dict[str, asyncio.Event] = {}


def register_cancel_event(session_id: str) -> asyncio.Event:
    """Create and register a fresh cancel event for ``session_id``.

    If the session already has a registered event it is replaced, which means
    ``trigger_cancel`` will only reach the newest registration. The holder of
    the older event should pass that event to ``unregister_cancel_event`` so
    its cleanup does not remove the replacement.

    Args:
        session_id: Key under which the event is stored.

    Returns:
        The newly created, not-yet-set ``asyncio.Event``.
    """
    if session_id in _cancel_registry:
        # Surface overlapping registrations instead of overwriting silently.
        logger.warning(
            "Cancel event for session_id=%s is being replaced by a new registration",
            session_id,
        )
    event = asyncio.Event()
    _cancel_registry[session_id] = event
    logger.debug("Cancel event registered for session_id=%s", session_id)
    return event


def trigger_cancel(session_id: str) -> bool:
    """Set the cancel event registered for ``session_id``, if any.

    Args:
        session_id: Key of the registration to signal.

    Returns:
        ``True`` if an event was registered and has been set, ``False`` if no
        event is registered under ``session_id``.
    """
    event = _cancel_registry.get(session_id)
    if event is None:
        logger.warning("No active session found for session_id=%s", session_id)
        return False
    event.set()
    logger.info("Cancel triggered for session_id=%s", session_id)
    return True


def unregister_cancel_event(
    session_id: str, event: Optional[asyncio.Event] = None
) -> None:
    """Remove the cancel event registered for ``session_id``.

    Args:
        session_id: Key of the registration to remove.
        event: Optional event previously returned by ``register_cancel_event``.
            When given, the entry is removed only if it is still this exact
            object, so a stale cleanup cannot drop a newer registration made
            under the same ``session_id``. When omitted, whatever is
            registered is removed.

    Calling this for a ``session_id`` with no registration is a no-op.
    """
    current = _cancel_registry.get(session_id)
    if current is None:
        return
    if event is not None and current is not event:
        # Someone re-registered this session after `event` was created;
        # that newer registration must stay reachable by trigger_cancel.
        logger.debug(
            "Cancel event for session_id=%s was replaced; skipping unregister",
            session_id,
        )
        return
    del _cancel_registry[session_id]
    logger.debug("Cancel event unregistered for session_id=%s", session_id)