From c27270588fb42afbc7460e408d2e5870e96a6b9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?= Date: Mon, 16 Mar 2026 22:22:39 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=8F=96=E6=B6=88=E6=8E=A8?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent/agent_config.py | 3 +-- routes/chat.py | 51 +++++++++++++++++++++++++++++++++++++++-- utils/api_models.py | 1 - utils/cancel_manager.py | 33 ++++++++++++++++++++++++++ 4 files changed, 83 insertions(+), 5 deletions(-) create mode 100644 utils/cancel_manager.py diff --git a/agent/agent_config.py b/agent/agent_config.py index d8693af..a1503a9 100644 --- a/agent/agent_config.py +++ b/agent/agent_config.py @@ -203,7 +203,7 @@ class AgentConfig: enable_memori=enable_memori, memori_semantic_search_top_k=bot_config.get("memori_semantic_search_top_k", MEM0_SEMANTIC_SEARCH_TOP_K), trace_id=trace_id, - shell_env=getattr(request, 'shell_env', None) or bot_config.get("shell_env") or {}, + shell_env=bot_config.get("shell_env") or {}, ) # 在创建 config 时尽早准备 checkpoint 消息 @@ -218,7 +218,6 @@ class AgentConfig: config.safe_print() return config - def invoke_config(self): """返回Langchain需要的配置字典""" config = {} diff --git a/routes/chat.py b/routes/chat.py index f59e33a..0be1aaa 100644 --- a/routes/chat.py +++ b/routes/chat.py @@ -4,7 +4,7 @@ import asyncio import shutil import time from typing import Union, Optional, Any, List, Dict -from fastapi import APIRouter, HTTPException, Header +from fastapi import APIRouter, HTTPException, Header, Body from fastapi.responses import StreamingResponse import logging @@ -39,11 +39,19 @@ async def enhanced_generate_stream_response( # 用于收集完整的响应内容,用于保存到数据库 full_response_content = [] + # 取消管理 + cancel_event = None + try: # 创建输出队列和控制事件 output_queue = asyncio.Queue() preamble_completed = asyncio.Event() + # 注册取消事件 + if config.session_id: + from utils.cancel_manager import register_cancel_event, 
unregister_cancel_event + cancel_event = register_cancel_event(config.session_id) + # 在流式开始前保存用户消息 if config.session_id: asyncio.create_task(_save_user_messages(config)) @@ -81,6 +89,11 @@ async def enhanced_generate_stream_response( message_tag = "" agent, checkpointer = await init_agent(config) async for msg, metadata in agent.astream({"messages": config.messages}, stream_mode="messages", config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS): + # 检查是否收到取消信号 + if cancel_event and cancel_event.is_set(): + logger.info(f"Agent stream cancelled for session_id={config.session_id}") + break + new_content = "" if isinstance(msg, AIMessageChunk): @@ -124,7 +137,8 @@ async def enhanced_generate_stream_response( await output_queue.put(("agent", f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n")) # 发送最终chunk - final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", config.model_name, finish_reason="stop") + finish = "cancelled" if (cancel_event and cancel_event.is_set()) else "stop" + final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", config.model_name, finish_reason=finish) await output_queue.put(("agent", f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n")) # ============ 执行 PostAgent hooks ============ # 注意:这里在单独的异步任务中执行,不阻塞流式输出 @@ -190,6 +204,11 @@ async def enhanced_generate_stream_response( break except asyncio.TimeoutError: + # 检查是否收到取消信号 + if cancel_event and cancel_event.is_set(): + logger.info(f"Output loop cancelled for session_id={config.session_id}") + break + # 检查是否还有任务在运行 if all(task.done() for task in [preamble_task_handle, agent_task_handle]): # 所有任务都完成了,退出循环 @@ -203,8 +222,13 @@ async def enhanced_generate_stream_response( # 发送结束标记 yield "data: [DONE]\n\n" + # 清理取消事件 + if config.session_id: + from utils.cancel_manager import unregister_cancel_event + unregister_cancel_event(config.session_id) logger.info(f"Enhanced stream response completed") + # 流式结束后保存 AI 响应 if full_response_content and config.session_id: 
asyncio.create_task(_save_assistant_response(config, "".join(full_response_content))) @@ -213,6 +237,10 @@ async def enhanced_generate_stream_response( logger.error(f"Error in enhanced_generate_stream_response: {e}") yield f'data: {{"error": "{str(e)}"}}\n\n' yield "data: [DONE]\n\n" + # 清理取消事件 + if config.session_id: + from utils.cancel_manager import unregister_cancel_event + unregister_cancel_event(config.session_id) async def create_agent_and_generate_response( @@ -767,6 +795,25 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st logger.error(f"Full traceback: {error_details}") raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") +@router.post("/api/v1/chat/cancel") +async def cancel_chat(session_id: str = Body(..., embed=True)): + """ + 取消正在进行的 agent 推理 + + 请求体: {"session_id": "xxxxx"} + 响应: {"success": true/false, "message": "..."} + """ + from utils.cancel_manager import trigger_cancel + + if not session_id: + raise HTTPException(status_code=400, detail="session_id is required") + + found = trigger_cancel(session_id) + if found: + return {"success": True, "message": f"Cancel signal sent for session_id={session_id}"} + else: + return {"success": False, "message": f"No active inference found for session_id={session_id}"} + # ============================================================================ # 聊天历史查询接口 diff --git a/utils/api_models.py b/utils/api_models.py index 386c6f9..5985056 100644 --- a/utils/api_models.py +++ b/utils/api_models.py @@ -68,7 +68,6 @@ class ChatRequestV2(BaseModel): language: Optional[str] = "zh" user_identifier: Optional[str] = "" session_id: Optional[str] = None - shell_env: Optional[Dict[str, str]] = None model_config = ConfigDict(extra='allow') diff --git a/utils/cancel_manager.py b/utils/cancel_manager.py new file mode 100644 index 0000000..ada1aad --- /dev/null +++ b/utils/cancel_manager.py @@ -0,0 +1,33 @@ +import asyncio +import logging +from typing import Dict + 
+logger = logging.getLogger('app')
+
+# Module-level cancel registry: maps session_id -> the asyncio.Event registered for it
+_cancel_registry: Dict[str, asyncio.Event] = {}
+
+
+def register_cancel_event(session_id: str) -> asyncio.Event:
+    """Create a new asyncio.Event, store it in the registry under session_id, and return it."""
+    event = asyncio.Event()
+    _cancel_registry[session_id] = event  # replaces any event already stored under this session_id
+    logger.debug(f"Cancel event registered for session_id={session_id}")
+    return event
+
+
+def trigger_cancel(session_id: str) -> bool:
+    """Set the event registered for session_id; return True if one was found, False otherwise."""
+    event = _cancel_registry.get(session_id)  # None when nothing is registered for this id
+    if event:
+        event.set()
+        logger.info(f"Cancel triggered for session_id={session_id}")
+        return True
+    logger.warning(f"No active session found for session_id={session_id}")
+    return False
+
+
+def unregister_cancel_event(session_id: str) -> None:
+    """Remove the registry entry for session_id, if there is one."""
+    _cancel_registry.pop(session_id, None)  # default of None: no KeyError when the id is absent
+    logger.debug(f"Cancel event unregistered for session_id={session_id}")