This commit is contained in:
朱潮 2026-03-16 22:24:44 +08:00
commit 8e52b787f8
2 changed files with 82 additions and 2 deletions

View File

@ -4,7 +4,7 @@ import asyncio
import shutil
import time
from typing import Union, Optional, Any, List, Dict
from fastapi import APIRouter, HTTPException, Header
from fastapi import APIRouter, HTTPException, Header, Body
from fastapi.responses import StreamingResponse
import logging
@ -39,11 +39,19 @@ async def enhanced_generate_stream_response(
# 用于收集完整的响应内容,用于保存到数据库
full_response_content = []
# 取消管理
cancel_event = None
try:
# 创建输出队列和控制事件
output_queue = asyncio.Queue()
preamble_completed = asyncio.Event()
# 注册取消事件
if config.session_id:
from utils.cancel_manager import register_cancel_event, unregister_cancel_event
cancel_event = register_cancel_event(config.session_id)
# 在流式开始前保存用户消息
if config.session_id:
asyncio.create_task(_save_user_messages(config))
@ -81,6 +89,11 @@ async def enhanced_generate_stream_response(
message_tag = ""
agent, checkpointer = await init_agent(config)
async for msg, metadata in agent.astream({"messages": config.messages}, stream_mode="messages", config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS):
# 检查是否收到取消信号
if cancel_event and cancel_event.is_set():
logger.info(f"Agent stream cancelled for session_id={config.session_id}")
break
new_content = ""
if isinstance(msg, AIMessageChunk):
@ -124,7 +137,8 @@ async def enhanced_generate_stream_response(
await output_queue.put(("agent", f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"))
# 发送最终chunk
final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", config.model_name, finish_reason="stop")
finish = "cancelled" if (cancel_event and cancel_event.is_set()) else "stop"
final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", config.model_name, finish_reason=finish)
await output_queue.put(("agent", f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"))
# ============ 执行 PostAgent hooks ============
# 注意:这里在单独的异步任务中执行,不阻塞流式输出
@ -190,6 +204,11 @@ async def enhanced_generate_stream_response(
break
except asyncio.TimeoutError:
# 检查是否收到取消信号
if cancel_event and cancel_event.is_set():
logger.info(f"Output loop cancelled for session_id={config.session_id}")
break
# 检查是否还有任务在运行
if all(task.done() for task in [preamble_task_handle, agent_task_handle]):
# 所有任务都完成了,退出循环
@ -203,8 +222,13 @@ async def enhanced_generate_stream_response(
# 发送结束标记
yield "data: [DONE]\n\n"
# 清理取消事件
if config.session_id:
from utils.cancel_manager import unregister_cancel_event
unregister_cancel_event(config.session_id)
logger.info(f"Enhanced stream response completed")
# 流式结束后保存 AI 响应
if full_response_content and config.session_id:
asyncio.create_task(_save_assistant_response(config, "".join(full_response_content)))
@ -213,6 +237,10 @@ async def enhanced_generate_stream_response(
logger.error(f"Error in enhanced_generate_stream_response: {e}")
yield f'data: {{"error": "{str(e)}"}}\n\n'
yield "data: [DONE]\n\n"
# 清理取消事件
if config.session_id:
from utils.cancel_manager import unregister_cancel_event
unregister_cancel_event(config.session_id)
async def create_agent_and_generate_response(
@ -767,6 +795,25 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st
logger.error(f"Full traceback: {error_details}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@router.post("/api/v1/chat/cancel")
async def cancel_chat(session_id: str = Body(..., embed=True)):
"""
取消正在进行的 agent 推理
请求体: {"session_id": "xxxxx"}
响应: {"success": true/false, "message": "..."}
"""
from utils.cancel_manager import trigger_cancel
if not session_id:
raise HTTPException(status_code=400, detail="session_id is required")
found = trigger_cancel(session_id)
if found:
return {"success": True, "message": f"Cancel signal sent for session_id={session_id}"}
else:
return {"success": False, "message": f"No active inference found for session_id={session_id}"}
@router.post("/api/v3/chat/completions")
async def chat_completions_v3(request: ChatRequestV3, authorization: Optional[str] = Header(None)):

33
utils/cancel_manager.py Normal file
View File

@ -0,0 +1,33 @@
import asyncio
import logging
from typing import Dict
logger = logging.getLogger('app')
# 全局取消注册表: session_id -> asyncio.Event
_cancel_registry: Dict[str, asyncio.Event] = {}
def register_cancel_event(session_id: str) -> asyncio.Event:
"""注册一个取消事件"""
event = asyncio.Event()
_cancel_registry[session_id] = event
logger.debug(f"Cancel event registered for session_id={session_id}")
return event
def trigger_cancel(session_id: str) -> bool:
"""触发取消事件"""
event = _cancel_registry.get(session_id)
if event:
event.set()
logger.info(f"Cancel triggered for session_id={session_id}")
return True
logger.warning(f"No active session found for session_id={session_id}")
return False
def unregister_cancel_event(session_id: str) -> None:
"""清理取消事件"""
_cancel_registry.pop(session_id, None)
logger.debug(f"Cancel event unregistered for session_id={session_id}")