增加取消推理
This commit is contained in:
parent
32fd8c8656
commit
c27270588f
@ -203,7 +203,7 @@ class AgentConfig:
|
||||
enable_memori=enable_memori,
|
||||
memori_semantic_search_top_k=bot_config.get("memori_semantic_search_top_k", MEM0_SEMANTIC_SEARCH_TOP_K),
|
||||
trace_id=trace_id,
|
||||
shell_env=getattr(request, 'shell_env', None) or bot_config.get("shell_env") or {},
|
||||
shell_env=bot_config.get("shell_env") or {},
|
||||
)
|
||||
|
||||
# 在创建 config 时尽早准备 checkpoint 消息
|
||||
@ -218,7 +218,6 @@ class AgentConfig:
|
||||
|
||||
config.safe_print()
|
||||
return config
|
||||
|
||||
def invoke_config(self):
|
||||
"""返回Langchain需要的配置字典"""
|
||||
config = {}
|
||||
|
||||
@ -4,7 +4,7 @@ import asyncio
|
||||
import shutil
|
||||
import time
|
||||
from typing import Union, Optional, Any, List, Dict
|
||||
from fastapi import APIRouter, HTTPException, Header
|
||||
from fastapi import APIRouter, HTTPException, Header, Body
|
||||
from fastapi.responses import StreamingResponse
|
||||
import logging
|
||||
|
||||
@ -39,11 +39,19 @@ async def enhanced_generate_stream_response(
|
||||
# 用于收集完整的响应内容,用于保存到数据库
|
||||
full_response_content = []
|
||||
|
||||
# 取消管理
|
||||
cancel_event = None
|
||||
|
||||
try:
|
||||
# 创建输出队列和控制事件
|
||||
output_queue = asyncio.Queue()
|
||||
preamble_completed = asyncio.Event()
|
||||
|
||||
# 注册取消事件
|
||||
if config.session_id:
|
||||
from utils.cancel_manager import register_cancel_event, unregister_cancel_event
|
||||
cancel_event = register_cancel_event(config.session_id)
|
||||
|
||||
# 在流式开始前保存用户消息
|
||||
if config.session_id:
|
||||
asyncio.create_task(_save_user_messages(config))
|
||||
@ -81,6 +89,11 @@ async def enhanced_generate_stream_response(
|
||||
message_tag = ""
|
||||
agent, checkpointer = await init_agent(config)
|
||||
async for msg, metadata in agent.astream({"messages": config.messages}, stream_mode="messages", config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS):
|
||||
# 检查是否收到取消信号
|
||||
if cancel_event and cancel_event.is_set():
|
||||
logger.info(f"Agent stream cancelled for session_id={config.session_id}")
|
||||
break
|
||||
|
||||
new_content = ""
|
||||
|
||||
if isinstance(msg, AIMessageChunk):
|
||||
@ -124,7 +137,8 @@ async def enhanced_generate_stream_response(
|
||||
await output_queue.put(("agent", f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"))
|
||||
|
||||
# 发送最终chunk
|
||||
final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", config.model_name, finish_reason="stop")
|
||||
finish = "cancelled" if (cancel_event and cancel_event.is_set()) else "stop"
|
||||
final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", config.model_name, finish_reason=finish)
|
||||
await output_queue.put(("agent", f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"))
|
||||
# ============ 执行 PostAgent hooks ============
|
||||
# 注意:这里在单独的异步任务中执行,不阻塞流式输出
|
||||
@ -190,6 +204,11 @@ async def enhanced_generate_stream_response(
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 检查是否收到取消信号
|
||||
if cancel_event and cancel_event.is_set():
|
||||
logger.info(f"Output loop cancelled for session_id={config.session_id}")
|
||||
break
|
||||
|
||||
# 检查是否还有任务在运行
|
||||
if all(task.done() for task in [preamble_task_handle, agent_task_handle]):
|
||||
# 所有任务都完成了,退出循环
|
||||
@ -203,8 +222,13 @@ async def enhanced_generate_stream_response(
|
||||
|
||||
# 发送结束标记
|
||||
yield "data: [DONE]\n\n"
|
||||
# 清理取消事件
|
||||
if config.session_id:
|
||||
from utils.cancel_manager import unregister_cancel_event
|
||||
unregister_cancel_event(config.session_id)
|
||||
logger.info(f"Enhanced stream response completed")
|
||||
|
||||
|
||||
# 流式结束后保存 AI 响应
|
||||
if full_response_content and config.session_id:
|
||||
asyncio.create_task(_save_assistant_response(config, "".join(full_response_content)))
|
||||
@ -213,6 +237,10 @@ async def enhanced_generate_stream_response(
|
||||
logger.error(f"Error in enhanced_generate_stream_response: {e}")
|
||||
yield f'data: {{"error": "{str(e)}"}}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
# 清理取消事件
|
||||
if config.session_id:
|
||||
from utils.cancel_manager import unregister_cancel_event
|
||||
unregister_cancel_event(config.session_id)
|
||||
|
||||
|
||||
async def create_agent_and_generate_response(
|
||||
@ -767,6 +795,25 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st
|
||||
logger.error(f"Full traceback: {error_details}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
@router.post("/api/v1/chat/cancel")
|
||||
async def cancel_chat(session_id: str = Body(..., embed=True)):
|
||||
"""
|
||||
取消正在进行的 agent 推理
|
||||
|
||||
请求体: {"session_id": "xxxxx"}
|
||||
响应: {"success": true/false, "message": "..."}
|
||||
"""
|
||||
from utils.cancel_manager import trigger_cancel
|
||||
|
||||
if not session_id:
|
||||
raise HTTPException(status_code=400, detail="session_id is required")
|
||||
|
||||
found = trigger_cancel(session_id)
|
||||
if found:
|
||||
return {"success": True, "message": f"Cancel signal sent for session_id={session_id}"}
|
||||
else:
|
||||
return {"success": False, "message": f"No active inference found for session_id={session_id}"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 聊天历史查询接口
|
||||
|
||||
@ -68,7 +68,6 @@ class ChatRequestV2(BaseModel):
|
||||
language: Optional[str] = "zh"
|
||||
user_identifier: Optional[str] = ""
|
||||
session_id: Optional[str] = None
|
||||
shell_env: Optional[Dict[str, str]] = None
|
||||
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
33
utils/cancel_manager.py
Normal file
33
utils/cancel_manager.py
Normal file
@ -0,0 +1,33 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
# 全局取消注册表: session_id -> asyncio.Event
|
||||
_cancel_registry: Dict[str, asyncio.Event] = {}
|
||||
|
||||
|
||||
def register_cancel_event(session_id: str) -> asyncio.Event:
|
||||
"""注册一个取消事件"""
|
||||
event = asyncio.Event()
|
||||
_cancel_registry[session_id] = event
|
||||
logger.debug(f"Cancel event registered for session_id={session_id}")
|
||||
return event
|
||||
|
||||
|
||||
def trigger_cancel(session_id: str) -> bool:
|
||||
"""触发取消事件"""
|
||||
event = _cancel_registry.get(session_id)
|
||||
if event:
|
||||
event.set()
|
||||
logger.info(f"Cancel triggered for session_id={session_id}")
|
||||
return True
|
||||
logger.warning(f"No active session found for session_id={session_id}")
|
||||
return False
|
||||
|
||||
|
||||
def unregister_cancel_event(session_id: str) -> None:
|
||||
"""清理取消事件"""
|
||||
_cancel_registry.pop(session_id, None)
|
||||
logger.debug(f"Cancel event unregistered for session_id={session_id}")
|
||||
Loading…
Reference in New Issue
Block a user