import json
import logging
import os
import sqlite3
from typing import Any, Dict, Optional

from langchain.chat_models import init_chat_model
# from deepagents import create_deep_agent
from langchain.agents import create_agent
from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.callbacks import BaseCallbackHandler
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.checkpoint.memory import MemorySaver

from utils.fastapi_utils import detect_provider
from utils.settings import SUMMARIZATION_MAX_TOKENS

from .guideline_middleware import GuidelineMiddleware


class LoggingCallbackHandler(BaseCallbackHandler):
    """Custom CallbackHandler that routes LangChain callback events to the project logger."""

    def __init__(self, logger_name: str = 'app'):
        self.logger = logging.getLogger(logger_name)

    # on_llm_start is intentionally disabled: logging every full prompt was too
    # verbose.  Re-enable if prompt-level tracing is needed.
    # def on_llm_start(
    #     self, serialized: Optional[Dict[str, Any]], prompts: list[str], **kwargs: Any
    # ) -> None:
    #     self.logger.info("🤖 LLM Start - Input Messages:")
    #     if prompts:
    #         for i, prompt in enumerate(prompts):
    #             self.logger.info(f" Message {i+1}:\n{prompt}")
    #     else:
    #         self.logger.info(" No prompts")

    def on_llm_end(self, response, **kwargs: Any) -> None:
        """Called when the LLM finishes; logs each non-blank line of the output."""
        self.logger.info("✅ LLM End - Output:")
        if hasattr(response, 'generations') and response.generations:
            for generation_list in response.generations:
                for generation in generation_list:
                    if hasattr(generation, 'text'):
                        text = generation.text
                    elif hasattr(generation, 'message'):
                        # BUG FIX: `generation.message` is a message object
                        # (e.g. AIMessage), not a str — the original called
                        # .split() directly on it, which raised AttributeError.
                        # Log its .content instead, falling back to str() of
                        # the object when no .content attribute exists.
                        text = str(getattr(generation.message, 'content', generation.message))
                    else:
                        continue
                    for output in text.split("\n"):
                        if output.strip():
                            self.logger.info(f"{output}")

    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
        """Called when the LLM raises an error."""
        self.logger.error(f"❌ LLM Error: {error}")

    def on_tool_start(
        self, serialized: Optional[Dict[str, Any]], input_str: str, **kwargs: Any
    ) -> None:
        """Called when a tool invocation starts; logs the tool name and truncated input."""
        if serialized is None:
            tool_name = 'unknown_tool'
        else:
            tool_name = serialized.get('name', 'unknown_tool')
        self.logger.info(f"🔧 Tool Start - {tool_name} with input: {str(input_str)[:100]}")

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Called when a tool invocation finishes."""
        self.logger.info(f"✅ Tool End Output: {output}")

    def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
        """Called when a tool invocation raises an error."""
        self.logger.error(f"❌ Tool Error: {error}")

    def on_agent_action(self, action, **kwargs: Any) -> None:
        """Called when the agent takes an action."""
        self.logger.info(f"🎯 Agent Action: {action.log}")


# Utility functions

def read_system_prompt() -> str:
    """Read the generic, stateless system prompt from disk."""
    with open("./prompt/system_prompt_default.md", "r", encoding="utf-8") as f:
        return f.read().strip()


def read_mcp_settings():
    """Read the MCP tool configuration JSON.

    Returns:
        The parsed JSON object from ./mcp/mcp_settings.json.
    """
    # BUG FIX: open with an explicit UTF-8 encoding (matching
    # read_system_prompt) so the config parses correctly regardless of the
    # platform's default encoding.
    with open("./mcp/mcp_settings.json", "r", encoding="utf-8") as f:
        return json.load(f)


async def get_tools_from_mcp(mcp):
    """Extract LangChain tools from an MCP configuration.

    Args:
        mcp: Expected to be a non-empty list whose first element is a dict
            containing an "mcpServers" mapping of server name -> server config.

    Returns:
        A list of tools, or [] when the config is malformed or the MCP
        client fails.
    """
    # Defensive checks: mcp must be a non-empty list whose first element has
    # an "mcpServers" key.
    if not isinstance(mcp, list) or len(mcp) == 0 or "mcpServers" not in mcp[0]:
        return []
    servers = mcp[0]["mcpServers"]
    # BUG FIX: validate that "mcpServers" is a dict BEFORE iterating its
    # .values() — the original performed this check only after the loop, so a
    # non-dict value raised AttributeError instead of returning [].
    if not isinstance(servers, dict):
        return []
    # Normalize each server config: drop the legacy "type" field and, when
    # "transport" is missing, default it to "http" for URL-based servers and
    # "stdio" otherwise.
    for cfg in servers.values():
        cfg.pop("type", None)
        if "transport" not in cfg:
            cfg["transport"] = "http" if "url" in cfg else "stdio"
    try:
        mcp_client = MultiServerMCPClient(servers)
        return await mcp_client.get_tools()
    except Exception:
        # Best-effort: return an empty tool list so callers don't crash, but
        # log the failure instead of swallowing it silently.
        logging.getLogger('app').exception("Failed to load tools from MCP config")
        return []


async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None,
                     model_server=None, generate_cfg=None, system_prompt=None,
                     mcp=None, robot_type=None, language="jp",
                     user_identifier=None, session_id=None):
    """Initialize an agent with optional persistent memory and conversation summarization.

    Args:
        bot_id: Bot ID.
        model_name: Model name.
        api_key: API key.
        model_server: Model server address.
        generate_cfg: Extra generation config merged into the model kwargs.
        system_prompt: System prompt; defaults to the prompt file on disk.
        mcp: MCP config; defaults to ./mcp/mcp_settings.json.
        robot_type: Robot type (forwarded to GuidelineMiddleware).
        language: Language code.
        user_identifier: User identifier.
        session_id: Session ID; when None, persistent memory is disabled.

    Returns:
        The created agent, with `logging_handler`, `checkpointer`, `bot_id`
        and `session_id` attached as attributes for convenient use at call time.
    """
    system = system_prompt if system_prompt else read_system_prompt()
    mcp = mcp if mcp else read_mcp_settings()
    mcp_tools = await get_tools_from_mcp(mcp)

    # Detect (or use the explicitly supplied) provider and endpoint.
    model_provider, base_url = detect_provider(model_name, model_server)

    # Build the model kwargs; entries from generate_cfg override the defaults.
    model_kwargs = {
        "model": model_name,
        "model_provider": model_provider,
        "temperature": 0.8,
        "base_url": base_url,
        "api_key": api_key,
    }
    if generate_cfg:
        model_kwargs.update(generate_cfg)
    llm_instance = init_chat_model(**model_kwargs)

    # Custom logging handler (attached to the agent below for use at call time).
    logging_handler = LoggingCallbackHandler()

    # Middleware list; summarization is appended only for session-scoped agents.
    middleware = [GuidelineMiddleware(bot_id, llm_instance, system, robot_type,
                                      language, user_identifier)]

    # Enable persistence + summarization only when a session_id is supplied.
    checkpointer = None
    if session_id:
        checkpointer = MemorySaver()
        summarization_middleware = SummarizationMiddleware(
            model=llm_instance,
            max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS,
            messages_to_keep=20,  # keep the 20 most recent messages after summarizing
            summary_prompt="请简洁地总结以上对话的要点,包括重要的用户信息、讨论过的话题和关键结论。"
        )
        middleware.append(summarization_middleware)

    agent = create_agent(
        model=llm_instance,
        system_prompt=system,
        tools=mcp_tools,
        middleware=middleware,
        checkpointer=checkpointer  # enables persistence when not None
    )

    # Stash handler / checkpointer / ids on the agent so callers can reach
    # them when invoking the agent.
    agent.logging_handler = logging_handler
    agent.checkpointer = checkpointer
    agent.bot_id = bot_id
    agent.session_id = session_id
    return agent