import json
import logging
import os
import sqlite3
from typing import Any, Dict, Optional, List

from langchain.chat_models import init_chat_model
# from deepagents import create_deep_agent
from langchain.agents import create_agent
from langchain.agents.middleware import SummarizationMiddleware
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.checkpoint.memory import MemorySaver

from utils.fastapi_utils import detect_provider
from .guideline_middleware import GuidelineMiddleware
from .tool_output_length_middleware import ToolOutputLengthMiddleware
from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
from agent.agent_config import AgentConfig


# Utility functions
def read_system_prompt():
    """Read the generic, stateless system prompt."""
    with open("./prompt/system_prompt_default.md", "r", encoding="utf-8") as f:
        return f.read().strip()


def read_mcp_settings():
    """Read the MCP tool configuration."""
    with open("./mcp/mcp_settings.json", "r") as f:
        mcp_settings_json = json.load(f)
    return mcp_settings_json


async def get_tools_from_mcp(mcp):
    """Extract tools from an MCP configuration."""
    # Defensive checks: mcp must be a non-empty list whose first element
    # contains an "mcpServers" mapping.
    if not isinstance(mcp, list) or len(mcp) == 0 or "mcpServers" not in mcp[0]:
        return []
    if not isinstance(mcp[0]["mcpServers"], dict):
        return []

    # Normalize each server entry: drop the "type" field and, when "transport"
    # is missing, default it to "http" if a URL is present, otherwise "stdio".
    for cfg in mcp[0]["mcpServers"].values():
        if "type" in cfg:
            cfg.pop("type")
        if "transport" not in cfg:
            cfg["transport"] = "http" if "url" in cfg else "stdio"

    try:
        mcp_client = MultiServerMCPClient(mcp[0]["mcpServers"])
        mcp_tools = await mcp_client.get_tools()
        return mcp_tools
    except Exception:
        # Return an empty list on failure so upstream callers don't have to handle errors.
        return []


async def init_agent(config: AgentConfig):
    """
    Initialize the agent, with optional persistent memory and conversation summarization.

    Args:
        config: AgentConfig object containing all initialization parameters.
            If config.mcp_settings is empty, the MCP settings file is read instead.
    """
    # Fall back to the on-disk MCP settings / system prompt when not provided in config.
    mcp_settings = config.mcp_settings if config.mcp_settings else read_mcp_settings()
    system_prompt = config.system_prompt if config.system_prompt else read_system_prompt()
    mcp_tools = await get_tools_from_mcp(mcp_settings)

    # Detect the provider, or use the one specified via model_server.
    model_provider, base_url = detect_provider(config.model_name, config.model_server)

    # Build the model parameters.
    model_kwargs = {
        "model": config.model_name,
        "model_provider": model_provider,
        "temperature": 0.8,
        "base_url": base_url,
        "api_key": config.api_key,
    }
    if config.generate_cfg:
        model_kwargs.update(config.generate_cfg)
    llm_instance = init_chat_model(**model_kwargs)

    def _gen_cfg(key, default):
        """Read an optional key from generate_cfg (dict or object), falling back to a default."""
        if not config.generate_cfg:
            return default
        if isinstance(config.generate_cfg, dict):
            value = config.generate_cfg.get(key)
        else:
            value = getattr(config.generate_cfg, key, None)
        return default if value is None else value

    # Build the middleware list.
    middleware = []

    # Only add GuidelineMiddleware when enable_thinking is True.
    if config.enable_thinking:
        middleware.append(GuidelineMiddleware(llm_instance, config, system_prompt))

    # Add the middleware that caps tool output length.
    tool_output_middleware = ToolOutputLengthMiddleware(
        max_length=_gen_cfg("tool_output_max_length", TOOL_OUTPUT_MAX_LENGTH),
        truncation_strategy=_gen_cfg("tool_output_truncation_strategy", "smart"),
        tool_filters=_gen_cfg("tool_output_filters", None),  # restrict to specific tools
        exclude_tools=_gen_cfg("tool_output_exclude", []),   # tools to exclude
        preserve_code_blocks=_gen_cfg("preserve_code_blocks", True),
        preserve_json=_gen_cfg("preserve_json", True),
    )
    middleware.append(tool_output_middleware)

    # Initialize the checkpointer and summarization middleware.
    checkpointer = None
    if config.session_id:
        checkpointer = MemorySaver()
        summarization_middleware = SummarizationMiddleware(
            model=llm_instance,
            max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS,
            messages_to_keep=20,  # keep the 20 most recent messages after summarization
            summary_prompt="请简洁地总结以上对话的要点,包括重要的用户信息、讨论过的话题和关键结论。",
        )
        middleware.append(summarization_middleware)

    agent = create_agent(
        model=llm_instance,
        system_prompt=system_prompt,
        tools=mcp_tools,
        middleware=middleware,
        checkpointer=checkpointer,  # pass the checkpointer to enable persistence
    )
    return agent
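

# Minimal usage sketch: shows how a caller might drive init_agent end to end.
# The AgentConfig keyword arguments and all values below are illustrative
# assumptions (not defaults shipped with this project); the invocation pattern
# follows the standard LangGraph agent interface, where the thread_id ties the
# conversation to the MemorySaver checkpointer created above.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        # Assumes AgentConfig accepts these fields as keyword arguments,
        # mirroring the attributes accessed in init_agent.
        config = AgentConfig(
            model_name="qwen-plus",              # placeholder model name
            model_server=None,                   # let detect_provider infer the provider
            api_key=os.environ.get("LLM_API_KEY"),
            session_id="demo-session",           # enables MemorySaver + summarization
            enable_thinking=False,
            mcp_settings=None,                   # falls back to ./mcp/mcp_settings.json
            system_prompt=None,                  # falls back to ./prompt/system_prompt_default.md
            generate_cfg={"tool_output_max_length": 4000},
        )
        agent = await init_agent(config)
        result = await agent.ainvoke(
            {"messages": [{"role": "user", "content": "Hello"}]},
            config={"configurable": {"thread_id": "demo-session"}},
        )
        print(result["messages"][-1].content)

    asyncio.run(_demo())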