Compare commits
2 Commits
023189e943
...
8e52b787f8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8e52b787f8 | ||
|
|
c27270588f |
@ -4,7 +4,7 @@ import asyncio
|
|||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
from typing import Union, Optional, Any, List, Dict
|
from typing import Union, Optional, Any, List, Dict
|
||||||
from fastapi import APIRouter, HTTPException, Header
|
from fastapi import APIRouter, HTTPException, Header, Body
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -39,11 +39,19 @@ async def enhanced_generate_stream_response(
|
|||||||
# 用于收集完整的响应内容,用于保存到数据库
|
# 用于收集完整的响应内容,用于保存到数据库
|
||||||
full_response_content = []
|
full_response_content = []
|
||||||
|
|
||||||
|
# 取消管理
|
||||||
|
cancel_event = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 创建输出队列和控制事件
|
# 创建输出队列和控制事件
|
||||||
output_queue = asyncio.Queue()
|
output_queue = asyncio.Queue()
|
||||||
preamble_completed = asyncio.Event()
|
preamble_completed = asyncio.Event()
|
||||||
|
|
||||||
|
# 注册取消事件
|
||||||
|
if config.session_id:
|
||||||
|
from utils.cancel_manager import register_cancel_event, unregister_cancel_event
|
||||||
|
cancel_event = register_cancel_event(config.session_id)
|
||||||
|
|
||||||
# 在流式开始前保存用户消息
|
# 在流式开始前保存用户消息
|
||||||
if config.session_id:
|
if config.session_id:
|
||||||
asyncio.create_task(_save_user_messages(config))
|
asyncio.create_task(_save_user_messages(config))
|
||||||
@ -81,6 +89,11 @@ async def enhanced_generate_stream_response(
|
|||||||
message_tag = ""
|
message_tag = ""
|
||||||
agent, checkpointer = await init_agent(config)
|
agent, checkpointer = await init_agent(config)
|
||||||
async for msg, metadata in agent.astream({"messages": config.messages}, stream_mode="messages", config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS):
|
async for msg, metadata in agent.astream({"messages": config.messages}, stream_mode="messages", config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS):
|
||||||
|
# 检查是否收到取消信号
|
||||||
|
if cancel_event and cancel_event.is_set():
|
||||||
|
logger.info(f"Agent stream cancelled for session_id={config.session_id}")
|
||||||
|
break
|
||||||
|
|
||||||
new_content = ""
|
new_content = ""
|
||||||
|
|
||||||
if isinstance(msg, AIMessageChunk):
|
if isinstance(msg, AIMessageChunk):
|
||||||
@ -124,7 +137,8 @@ async def enhanced_generate_stream_response(
|
|||||||
await output_queue.put(("agent", f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"))
|
await output_queue.put(("agent", f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"))
|
||||||
|
|
||||||
# 发送最终chunk
|
# 发送最终chunk
|
||||||
final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", config.model_name, finish_reason="stop")
|
finish = "cancelled" if (cancel_event and cancel_event.is_set()) else "stop"
|
||||||
|
final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", config.model_name, finish_reason=finish)
|
||||||
await output_queue.put(("agent", f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"))
|
await output_queue.put(("agent", f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"))
|
||||||
# ============ 执行 PostAgent hooks ============
|
# ============ 执行 PostAgent hooks ============
|
||||||
# 注意:这里在单独的异步任务中执行,不阻塞流式输出
|
# 注意:这里在单独的异步任务中执行,不阻塞流式输出
|
||||||
@ -190,6 +204,11 @@ async def enhanced_generate_stream_response(
|
|||||||
break
|
break
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
# 检查是否收到取消信号
|
||||||
|
if cancel_event and cancel_event.is_set():
|
||||||
|
logger.info(f"Output loop cancelled for session_id={config.session_id}")
|
||||||
|
break
|
||||||
|
|
||||||
# 检查是否还有任务在运行
|
# 检查是否还有任务在运行
|
||||||
if all(task.done() for task in [preamble_task_handle, agent_task_handle]):
|
if all(task.done() for task in [preamble_task_handle, agent_task_handle]):
|
||||||
# 所有任务都完成了,退出循环
|
# 所有任务都完成了,退出循环
|
||||||
@ -203,8 +222,13 @@ async def enhanced_generate_stream_response(
|
|||||||
|
|
||||||
# 发送结束标记
|
# 发送结束标记
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
# 清理取消事件
|
||||||
|
if config.session_id:
|
||||||
|
from utils.cancel_manager import unregister_cancel_event
|
||||||
|
unregister_cancel_event(config.session_id)
|
||||||
logger.info(f"Enhanced stream response completed")
|
logger.info(f"Enhanced stream response completed")
|
||||||
|
|
||||||
|
|
||||||
# 流式结束后保存 AI 响应
|
# 流式结束后保存 AI 响应
|
||||||
if full_response_content and config.session_id:
|
if full_response_content and config.session_id:
|
||||||
asyncio.create_task(_save_assistant_response(config, "".join(full_response_content)))
|
asyncio.create_task(_save_assistant_response(config, "".join(full_response_content)))
|
||||||
@ -213,6 +237,10 @@ async def enhanced_generate_stream_response(
|
|||||||
logger.error(f"Error in enhanced_generate_stream_response: {e}")
|
logger.error(f"Error in enhanced_generate_stream_response: {e}")
|
||||||
yield f'data: {{"error": "{str(e)}"}}\n\n'
|
yield f'data: {{"error": "{str(e)}"}}\n\n'
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
# 清理取消事件
|
||||||
|
if config.session_id:
|
||||||
|
from utils.cancel_manager import unregister_cancel_event
|
||||||
|
unregister_cancel_event(config.session_id)
|
||||||
|
|
||||||
|
|
||||||
async def create_agent_and_generate_response(
|
async def create_agent_and_generate_response(
|
||||||
@ -767,6 +795,25 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st
|
|||||||
logger.error(f"Full traceback: {error_details}")
|
logger.error(f"Full traceback: {error_details}")
|
||||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||||
|
|
||||||
|
@router.post("/api/v1/chat/cancel")
async def cancel_chat(session_id: str = Body(..., embed=True)):
    """
    Cancel an in-flight agent inference.

    Request body: {"session_id": "xxxxx"}
    Response: {"success": true/false, "message": "..."}

    Raises:
        HTTPException: 400 when ``session_id`` is empty.
    """
    from utils.cancel_manager import trigger_cancel

    # An empty session_id cannot identify a stream, so reject it up front.
    if not session_id:
        raise HTTPException(status_code=400, detail="session_id is required")

    # trigger_cancel reports whether a cancel event was registered for the session.
    if trigger_cancel(session_id):
        success = True
        message = f"Cancel signal sent for session_id={session_id}"
    else:
        success = False
        message = f"No active inference found for session_id={session_id}"

    return {"success": success, "message": message}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/api/v3/chat/completions")
|
@router.post("/api/v3/chat/completions")
|
||||||
async def chat_completions_v3(request: ChatRequestV3, authorization: Optional[str] = Header(None)):
|
async def chat_completions_v3(request: ChatRequestV3, authorization: Optional[str] = Header(None)):
|
||||||
|
|||||||
33
utils/cancel_manager.py
Normal file
33
utils/cancel_manager.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import asyncio
import logging
from typing import Dict, Optional

logger = logging.getLogger('app')

# Global cancel registry: session_id -> asyncio.Event.
# Only one event is tracked per session_id; a later registration for the same
# session replaces the earlier one.
_cancel_registry: Dict[str, asyncio.Event] = {}


def register_cancel_event(session_id: str) -> asyncio.Event:
    """Create a fresh cancel event and register it under ``session_id``.

    Args:
        session_id: Key under which the event is stored in the registry.

    Returns:
        The newly created, unset ``asyncio.Event``.

    If an event is already registered for ``session_id`` it is replaced and a
    warning is logged: the holder of the previous event can no longer be
    reached through ``trigger_cancel``.
    """
    if session_id in _cancel_registry:
        logger.warning(
            "Replacing existing cancel event for session_id=%s", session_id
        )

    event = asyncio.Event()
    _cancel_registry[session_id] = event
    logger.debug("Cancel event registered for session_id=%s", session_id)
    return event


def trigger_cancel(session_id: str) -> bool:
    """Set the cancel event registered for ``session_id``.

    Args:
        session_id: Key of the event to set.

    Returns:
        True if an event was registered (and is now set), False otherwise.
    """
    event = _cancel_registry.get(session_id)
    # Explicit None check: presence in the registry is what matters, not the
    # truthiness of the event object.
    if event is not None:
        event.set()
        logger.info("Cancel triggered for session_id=%s", session_id)
        return True

    logger.warning("No active session found for session_id=%s", session_id)
    return False


def unregister_cancel_event(
    session_id: str, event: Optional[asyncio.Event] = None
) -> None:
    """Remove the cancel event registered for ``session_id``.

    Args:
        session_id: Key of the event to remove.
        event: Optional event previously returned by ``register_cancel_event``.
            When given, the registry entry is removed only if it is this exact
            object, so a caller holding a stale event cannot remove a newer
            registration for the same session. When omitted, the entry is
            removed unconditionally (the original behaviour).

    Unregistering an unknown ``session_id`` is a no-op.
    """
    current = _cancel_registry.get(session_id)
    if current is None:
        logger.debug(
            "No cancel event to unregister for session_id=%s", session_id
        )
        return

    if event is not None and current is not event:
        # The entry belongs to a newer registration; leave it in place.
        logger.debug(
            "Cancel event for session_id=%s was replaced; skipping unregister",
            session_id,
        )
        return

    del _cancel_registry[session_id]
    logger.debug("Cancel event unregistered for session_id=%s", session_id)
|
||||||
Loading…
Reference in New Issue
Block a user