"""Mem0 agent middleware.

An ``AgentMiddleware`` implementation that recalls memories before the agent
runs and stores new memories after it finishes.
"""

import asyncio
import logging
import threading
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

from langchain.agents.middleware import AgentMiddleware, AgentState, ModelRequest
from langgraph.runtime import Runtime

from .mem0_config import Mem0Config
from .mem0_manager import Mem0Manager, get_mem0_manager

logger = logging.getLogger("app")

# Imported for type checking only, to avoid a circular import at runtime.
# NOTE(review): "AgentConfig" is used below purely as a string annotation and
# is not imported anywhere in this module.
if TYPE_CHECKING:
    from langchain_core.language_models import BaseChatModel


class Mem0Middleware(AgentMiddleware):
    """Mem0 memory middleware.

    Responsibilities:
        1. ``before_agent`` / ``abefore_agent``: recall memories for the latest
           user message and store the rendered prompt on
           ``agent_config._mem0_context``.
        2. ``wrap_model_call`` / ``awrap_model_call``: append that prompt to
           the system prompt of each model request.
        3. ``after_agent`` / ``aafter_agent``: store the latest user/assistant
           exchange as a memory from a background thread.
    """

    def __init__(
        self,
        mem0_manager: Mem0Manager,
        config: Mem0Config,
        agent_config: "AgentConfig",
    ):
        """Initialise the middleware.

        Args:
            mem0_manager: Manager used to recall and add memories.
            config: Mem0 configuration (enable flag, attribution, prompt).
            agent_config: Shared object used to pass the recalled memory
                prompt between the middleware hooks.
        """
        self.mem0_manager = mem0_manager
        self.config = config
        self.agent_config = agent_config

    # ------------------------------------------------------------------
    # Message helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten message content into plain text.

        Message content may be a plain string or a list of content blocks.
        Calling ``str()`` on a list would yield its Python repr, so text
        blocks are concatenated instead and non-text blocks are ignored.

        Args:
            content: The ``content`` attribute of a message.

        Returns:
            The textual part of the content, or ``""`` when there is none.
        """
        if not content:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    parts.append(str(block.get("text", "")))
            return "".join(parts)
        return str(content)

    def _last_message_text(self, state: AgentState, message_type: type) -> str:
        """Return the text of the most recent message of ``message_type``.

        Args:
            state: Agent state; its ``"messages"`` entry is searched backwards.
            message_type: Message class to look for.

        Returns:
            The message text, or ``""`` when no such message exists or its
            content is empty.
        """
        for msg in reversed(state.get("messages", []) or []):
            if isinstance(msg, message_type):
                return self._content_to_text(msg.content)
        return ""

    def _extract_user_query(self, state: AgentState) -> str:
        """Extract the user query (the last ``HumanMessage``) from the state.

        Args:
            state: Agent state.

        Returns:
            The user query text, or ``""`` when none is found.
        """
        from langchain_core.messages import HumanMessage

        return self._last_message_text(state, HumanMessage)

    def _extract_agent_response(self, state: AgentState) -> str:
        """Extract the agent response (the last ``AIMessage``) from the state.

        Args:
            state: Agent state.

        Returns:
            The agent response text, or ``""`` when none is found.
        """
        from langchain_core.messages import AIMessage

        return self._last_message_text(state, AIMessage)

    def _format_memories(self, memories: List[Dict[str, Any]]) -> str:
        """Render a list of memories as a numbered text block.

        Args:
            memories: Memory dicts; ``"content"`` and ``"fact_type"`` are read,
                defaulting to ``""`` and ``"fact"`` respectively.

        Returns:
            One line per memory in the form ``"<n>. [<fact_type>] <content>"``,
            or ``""`` when ``memories`` is empty.
        """
        if not memories:
            return ""

        return "\n".join(
            f"{i}. [{memory.get('fact_type', 'fact')}] {memory.get('content', '')}"
            for i, memory in enumerate(memories, 1)
        )

    # ------------------------------------------------------------------
    # Recall (before the agent runs)
    # ------------------------------------------------------------------

    def _publish_memory_context(self, memories: List[Dict[str, Any]]) -> None:
        """Store the rendered memory prompt on ``agent_config``.

        Sets ``agent_config._mem0_context`` to the prompt built from
        ``memories``, or to ``None`` when nothing was recalled.

        Args:
            memories: Recalled memories (may be empty).
        """
        if memories:
            memory_text = self._format_memories(memories)
            self.agent_config._mem0_context = self.config.get_memory_prompt([memory_text])
            logger.info(f"Recalled {len(memories)} memories for context")
        else:
            self.agent_config._mem0_context = None

    def before_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None:
        """Recall relevant memories before the agent runs (sync version).

        Args:
            state: Agent state.
            runtime: Runtime context (unused).

        Returns:
            ``state`` after a recall attempt, or ``None`` when the middleware
            is disabled, no user query is present, or an error occurred.
        """
        if not self.config.is_enabled():
            return None

        try:
            query = self._extract_user_query(state)
            if not query:
                logger.debug("No user query found, skipping memory recall")
                return None

            user_id, agent_id = self.config.get_attribution_tuple()

            # Recall is user-level (cross-session), so no session id is passed;
            # the recall helper only accepts query, user_id and agent_id.
            # asyncio.run() raises RuntimeError if an event loop is already
            # running in this thread; that is logged by the handler below.
            memories = asyncio.run(self._recall_memories_async(query, user_id, agent_id))

            self._publish_memory_context(memories)
            return state

        except Exception as e:
            logger.error(f"Error in Mem0Middleware.before_agent: {e}", exc_info=True)
            return None

    async def abefore_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None:
        """Recall relevant memories before the agent runs (async version).

        Args:
            state: Agent state.
            runtime: Runtime context (unused).

        Returns:
            ``state`` after a recall attempt, or ``None`` when the middleware
            is disabled, no user query is present, or an error occurred.
        """
        if not self.config.is_enabled():
            return None

        try:
            query = self._extract_user_query(state)
            if not query:
                logger.debug("No user query found, skipping memory recall")
                return None

            user_id, agent_id = self.config.get_attribution_tuple()

            # Recall is user-level (cross-session).
            memories = await self._recall_memories_async(query, user_id, agent_id)

            self._publish_memory_context(memories)
            return state

        except Exception as e:
            logger.error(f"Error in Mem0Middleware.abefore_agent: {e}", exc_info=True)
            return None

    async def _recall_memories_async(
        self, query: str, user_id: str, agent_id: str
    ) -> List[Dict[str, Any]]:
        """Recall memories through the manager.

        Args:
            query: Query text.
            user_id: User ID.
            agent_id: Agent/Bot ID.

        Returns:
            The memories returned by ``mem0_manager.recall_memories``.
        """
        return await self.mem0_manager.recall_memories(
            query=query,
            user_id=user_id,
            agent_id=agent_id,
            config=self.config,
        )

    # ------------------------------------------------------------------
    # Storage (after the agent runs)
    # ------------------------------------------------------------------

    def _start_augmentation_thread(self, state: AgentState, runtime: Runtime) -> None:
        """Start a daemon thread that stores the latest exchange.

        Args:
            state: Agent state.
            runtime: Runtime context, forwarded to the worker.
        """
        thread = threading.Thread(
            target=self._trigger_augmentation_sync,
            args=(state, runtime),
            daemon=True,
        )
        thread.start()

    def after_agent(self, state: AgentState, runtime: Runtime) -> None:
        """Trigger memory augmentation after the agent runs (sync version).

        The work is handed to a background thread so the caller is not blocked.

        Args:
            state: Agent state.
            runtime: Runtime context.
        """
        if not self.config.is_enabled():
            return

        try:
            self._start_augmentation_thread(state, runtime)
        except Exception as e:
            logger.error(f"Error in Mem0Middleware.after_agent: {e}", exc_info=True)

    async def aafter_agent(self, state: AgentState, runtime: Runtime) -> None:
        """Trigger memory augmentation after the agent runs (async version).

        The work is handed to a background thread so the event loop is not
        blocked.

        Args:
            state: Agent state.
            runtime: Runtime context.
        """
        if not self.config.is_enabled():
            return

        try:
            self._start_augmentation_thread(state, runtime)
        except Exception as e:
            logger.error(f"Error in Mem0Middleware.aafter_agent: {e}", exc_info=True)

    async def _store_conversation(self, state: AgentState) -> None:
        """Store the latest user/assistant exchange as a memory.

        Nothing is stored unless both a user query and an agent response are
        present. Exceptions propagate to the caller.

        Args:
            state: Agent state.
        """
        user_id, agent_id = self.config.get_attribution_tuple()

        user_query = self._extract_user_query(state)
        agent_response = self._extract_agent_response(state)
        if not (user_query and agent_response):
            return

        conversation_text = f"User: {user_query}\nAssistant: {agent_response}"
        await self.mem0_manager.add_memory(
            text=conversation_text,
            user_id=user_id,
            agent_id=agent_id,
            metadata={"type": "conversation"},
            config=self.config,
        )
        logger.debug(f"Stored conversation as memory for user={user_id}, agent={agent_id}")

    def _trigger_augmentation_sync(self, state: AgentState, runtime: Runtime) -> None:
        """Run the memory augmentation task (sync version, run in a thread).

        Args:
            state: Agent state.
            runtime: Runtime context (unused).
        """
        try:
            # asyncio.run() creates a fresh event loop for this thread and
            # closes it afterwards.
            asyncio.run(self._store_conversation(state))
        except Exception as e:
            logger.error(f"Error in _trigger_augmentation_sync: {e}", exc_info=True)

    async def _trigger_augmentation_async(self, state: AgentState, runtime: Runtime) -> None:
        """Run the memory augmentation task (async version).

        Args:
            state: Agent state.
            runtime: Runtime context (unused).
        """
        try:
            await self._store_conversation(state)
        except Exception as e:
            logger.error(f"Error in _trigger_augmentation_async: {e}", exc_info=True)

    # ------------------------------------------------------------------
    # Prompt injection (around each model call)
    # ------------------------------------------------------------------

    def _inject_memory_prompt(self, request: ModelRequest) -> ModelRequest:
        """Return ``request`` with the memory prompt appended to its system prompt.

        Args:
            request: Model request.

        Returns:
            ``request`` unchanged when no memory prompt is available, otherwise
            a copy produced by ``request.override(system_prompt=...)``.
        """
        # getattr: the attribute is only assigned by the before-agent hooks,
        # so it may not exist yet on agent_config.
        memory_prompt = getattr(self.agent_config, "_mem0_context", None)
        if not memory_prompt:
            return request

        current_system_prompt = ""
        if request.system_message:
            current_system_prompt = (
                request.system_message.content
                if hasattr(request.system_message, "content")
                else str(request.system_message)
            )

        new_system_prompt = current_system_prompt + memory_prompt
        return request.override(system_prompt=new_system_prompt)

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Any],
    ) -> Any:
        """Wrap the model call and inject memories into the system prompt (sync).

        Args:
            request: Model request.
            handler: The wrapped handler.

        Returns:
            Whatever ``handler`` returns.
        """
        return handler(self._inject_memory_prompt(request))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Any],
    ) -> Any:
        """Wrap the model call and inject memories into the system prompt (async).

        Args:
            request: Model request.
            handler: The wrapped awaitable handler.

        Returns:
            Whatever ``handler`` returns once awaited.
        """
        return await handler(self._inject_memory_prompt(request))


def create_mem0_middleware(
    bot_id: str,
    user_identifier: str,
    session_id: str,
    agent_config: "AgentConfig",
    enabled: bool = True,
    semantic_search_top_k: int = 20,
    mem0_manager: Optional[Mem0Manager] = None,
    llm_instance: Optional["BaseChatModel"] = None,
) -> Optional[Mem0Middleware]:
    """Factory for ``Mem0Middleware``.

    Args:
        bot_id: Bot ID, used as the Mem0 ``agent_id``.
        user_identifier: User identifier, used as the Mem0 ``user_id``.
        session_id: Session ID stored on the config.
        agent_config: Shared object used to pass data between middleware hooks.
        enabled: Whether the middleware is enabled.
        semantic_search_top_k: Number of results for semantic search.
        mem0_manager: Manager instance; the one returned by
            ``get_mem0_manager()`` is used when omitted.
        llm_instance: LangChain LLM instance passed through to ``Mem0Config``.

    Returns:
        A ``Mem0Middleware`` instance, or ``None`` when ``enabled`` is false.
    """
    if not enabled:
        return None

    manager = mem0_manager or get_mem0_manager()

    config = Mem0Config(
        enabled=True,
        user_id=user_identifier,
        agent_id=bot_id,
        session_id=session_id,
        semantic_search_top_k=semantic_search_top_k,
        llm_instance=llm_instance,
    )

    return Mem0Middleware(mem0_manager=manager, config=config, agent_config=agent_config)