From 0d50cd8e9ff0e75a733a9dacc6937c08d48db18e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?= Date: Mon, 15 Dec 2025 21:36:13 +0800 Subject: [PATCH] session_id --- agent/deep_assistant.py | 51 +++++++++++++++++++++++++++++++--- agent/sharded_agent_manager.py | 4 ++- routes/chat.py | 32 +++++++++++++++------ utils/api_models.py | 2 ++ 4 files changed, 75 insertions(+), 14 deletions(-) diff --git a/agent/deep_assistant.py b/agent/deep_assistant.py index b596755..32b0f6f 100644 --- a/agent/deep_assistant.py +++ b/agent/deep_assistant.py @@ -1,15 +1,22 @@ import json import logging +import os +import sqlite3 from typing import Any, Dict, Optional from langchain.chat_models import init_chat_model # from deepagents import create_deep_agent from langchain.agents import create_agent +from langchain.agents.middleware import SummarizationMiddleware from langchain_mcp_adapters.client import MultiServerMCPClient from langchain_core.callbacks import BaseCallbackHandler +from langgraph.checkpoint.memory import MemorySaver from utils.fastapi_utils import detect_provider from .guideline_middleware import GuidelineMiddleware +MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 65536)) +MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000)) +SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000 class LoggingCallbackHandler(BaseCallbackHandler): """自定义的 CallbackHandler,使用项目的 logger 来记录日志""" @@ -120,7 +127,24 @@ async def get_tools_from_mcp(mcp): async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None, model_server=None, generate_cfg=None, - system_prompt=None, mcp=None, robot_type=None, language="jp", user_identifier=None): + system_prompt=None, mcp=None, robot_type=None, language="jp", user_identifier=None, + session_id=None): + """ + 初始化 Agent,支持持久化内存和对话摘要 + + Args: + bot_id: Bot ID + model_name: 模型名称 + api_key: API密钥 + model_server: 模型服务器地址 + generate_cfg: 生成配置 + system_prompt: 系统提示 + mcp: MCP配置 + robot_type: 机器人类型 + language: 语言 + user_identifier: 用户标识 + session_id: 会话ID(如果为None,则不启用持久化内存) + """ system = system_prompt if system_prompt else read_system_prompt() mcp = mcp if mcp else read_mcp_settings() mcp_tools = await get_tools_from_mcp(mcp) @@ -143,14 +167,33 @@ async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None, # 创建自定义的日志处理器 logging_handler = LoggingCallbackHandler() + # 构建中间件列表 + middleware = [GuidelineMiddleware(bot_id, llm_instance, system, robot_type, language, user_identifier)] + + # 初始化 checkpointer 和中间件 + checkpointer = None + + if session_id: + checkpointer = MemorySaver() + summarization_middleware = SummarizationMiddleware( + model=llm_instance, + max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS, + messages_to_keep=20, # 摘要后保留最近 20 条消息 + summary_prompt="请简洁地总结以上对话的要点,包括重要的用户信息、讨论过的话题和关键结论。" + ) + middleware.append(summarization_middleware) + agent = create_agent( model=llm_instance, system_prompt=system, tools=mcp_tools, - middleware=[GuidelineMiddleware(bot_id, llm_instance, system, robot_type, language, user_identifier)] + middleware=middleware, + checkpointer=checkpointer # 传入 checkpointer 以启用持久化 ) - # 将 handler 存储在 agent 的属性中,方便在调用时使用 + # 将 handler 和 checkpointer 存储在 agent 的属性中,方便在调用时使用 agent.logging_handler = logging_handler - + agent.checkpointer = checkpointer + agent.bot_id = bot_id + agent.session_id = session_id return agent diff --git a/agent/sharded_agent_manager.py b/agent/sharded_agent_manager.py index d09b9f5..c71c7d7 100644 --- a/agent/sharded_agent_manager.py +++ b/agent/sharded_agent_manager.py @@ -127,7 +127,8 @@ class ShardedAgentManager: system_prompt: Optional[str] = None, mcp_settings: Optional[List[Dict]] = None, robot_type: Optional[str] = "general_agent", - user_identifier: Optional[str] = None): + user_identifier: Optional[str] = None, + session_id: Optional[str] = None): """获取或创建文件预加载的助手实例""" # 更新请求统计 @@ -201,6 +202,7 @@ class ShardedAgentManager: robot_type=robot_type, language=language, user_identifier=user_identifier, + session_id=session_id ) # 缓存实例 diff --git a/routes/chat.py b/routes/chat.py index 5a540ed..a584748 100644 --- a/routes/chat.py +++ b/routes/chat.py @@ -82,7 +82,8 @@ async def enhanced_generate_stream_response( robot_type: str, project_dir: Optional[str], generate_cfg: Optional[dict], - user_identifier: Optional[str] + user_identifier: Optional[str], + session_id: Optional[str] = None, ): """增强的渐进式流式响应生成器 - 并发优化版本""" try: @@ -133,7 +134,8 @@ async def enhanced_generate_stream_response( system_prompt=system_prompt, mcp_settings=mcp_settings, robot_type=robot_type, - user_identifier=user_identifier + user_identifier=user_identifier, + session_id=session_id, ) # 开始流式处理 @@ -141,7 +143,11 @@ async def enhanced_generate_stream_response( chunk_id = 0 message_tag = "" - config = {"callbacks": [agent.logging_handler]} if hasattr(agent, 'logging_handler') else {} + config = {} + if session_id: + config["configurable"] = {"thread_id": session_id} + if hasattr(agent, 'logging_handler'): + config["callbacks"] = [agent.logging_handler] async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config): new_content = "" @@ -265,7 +271,8 @@ async def create_agent_and_generate_response( robot_type: str, project_dir: Optional[str] = None, generate_cfg: Optional[dict] = None, - user_identifier: Optional[str] = None + user_identifier: Optional[str] = None, + session_id: Optional[str] = None ) -> Union[ChatResponse, StreamingResponse]: """创建agent并生成响应的公共逻辑""" if generate_cfg is None: @@ -288,7 +295,8 @@ async def create_agent_and_generate_response( robot_type=robot_type, project_dir=project_dir, generate_cfg=generate_cfg, - user_identifier=user_identifier + user_identifier=user_identifier, + session_id=session_id ), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive"} @@ -307,14 +315,19 @@ async def create_agent_and_generate_response( system_prompt=system_prompt, mcp_settings=mcp_settings, robot_type=robot_type, - user_identifier=user_identifier + user_identifier=user_identifier, + session_id=session_id, ) # 准备最终的消息 final_messages = messages.copy() # 非流式响应 - config = {"callbacks": [agent.logging_handler]} if hasattr(agent, 'logging_handler') else {} + config = {} + if session_id: + config["configurable"] = {"thread_id": session_id} + if hasattr(agent, 'logging_handler'): + config["callbacks"] = [agent.logging_handler] agent_responses = await agent.ainvoke({"messages": final_messages}, config=config) append_messages = agent_responses["messages"][len(final_messages):] response_text = "" @@ -396,7 +409,7 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] = project_dir = create_project_directory(request.dataset_ids, bot_id, request.robot_type) # 收集额外参数作为 generate_cfg - exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings' ,'stream', 'robot_type', 'bot_id', 'user_identifier'} + exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings' ,'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id'} generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields} # 处理消息 @@ -417,7 +430,8 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] = robot_type=request.robot_type, project_dir=project_dir, generate_cfg=generate_cfg, - user_identifier=request.user_identifier + user_identifier=request.user_identifier, + session_id=request.session_id ) except Exception as e: diff --git a/utils/api_models.py b/utils/api_models.py index 2dc1c89..6656506 100644 --- a/utils/api_models.py +++ b/utils/api_models.py @@ -52,6 +52,7 @@ class ChatRequest(BaseModel): mcp_settings: Optional[List[Dict]] = None robot_type: Optional[str] = "general_agent" user_identifier: Optional[str] = "" + session_id: Optional[str] = None class ChatRequestV2(BaseModel): @@ -61,6 +62,7 @@ class ChatRequestV2(BaseModel): bot_id: str language: Optional[str] = "zh" user_identifier: Optional[str] = "" + session_id: Optional[str] = None class FileProcessRequest(BaseModel):