""" Memori Agent 中间件 实现记忆召回和存储的 AgentMiddleware """ import asyncio import logging from typing import Any, Dict, List, Optional from langchain.agents.middleware import AgentMiddleware, AgentState from langgraph.runtime import Runtime from .memori_config import MemoriConfig from .memori_manager import MemoriManager, get_memori_manager logger = logging.getLogger("app") class MemoriMiddleware(AgentMiddleware): """ Memori 记忆中间件 功能: 1. before_agent: 召回相关记忆并注入到上下文 2. after_agent: 后台异步提取和存储新记忆 """ def __init__( self, memori_manager: MemoriManager, config: MemoriConfig, ): """初始化 MemoriMiddleware Args: memori_manager: MemoriManager 实例 config: MemoriConfig 配置 """ self.memori_manager = memori_manager self.config = config def _extract_user_query(self, state: AgentState) -> str: """从状态中提取用户查询 Args: state: Agent 状态 Returns: 用户查询文本 """ messages = state.get("messages", []) if not messages: return "" # 获取最后一条消息 last_message = messages[-1] # 尝试获取内容 content = getattr(last_message, "content", None) if content is None: content = last_message.get("content", "") if isinstance(last_message, dict) else "" return str(content) if content else "" def _format_memories(self, memories: List[Dict[str, Any]]) -> str: """格式化记忆列表为文本 Args: memories: 记忆列表 Returns: 格式化的记忆文本 """ if not memories: return "" lines = [] for i, memory in enumerate(memories, 1): content = memory.get("content", "") similarity = memory.get("similarity", 0.0) fact_type = memory.get("fact_type", "fact") # 添加相似度分数(调试用) lines.append(f"{i}. [{fact_type}] {content}") return "\n".join(lines) def _inject_memory_context(self, state: AgentState, memory_text: str) -> AgentState: """将记忆上下文注入到状态中 Args: state: 原始状态 memory_text: 记忆文本 Returns: 更新后的状态 """ if not memory_text or not self.config.inject_memory_to_system_prompt: return state # 生成记忆提示 memory_prompt = self.config.get_memory_prompt([memory_text]) # 检查是否有系统消息 messages = state.get("messages", []) if not messages: return state # 在系统消息后添加记忆上下文 from langchain_core.messages import SystemMessage # 查找系统消息 system_message = None for msg in messages: if hasattr(msg, "type") and msg.type == "system": system_message = msg break elif isinstance(msg, dict) and msg.get("role") == "system": system_message = msg break if system_message: # 修改现有系统消息 if hasattr(system_message, "content"): original_content = system_message.content system_message.content = original_content + memory_prompt elif isinstance(system_message, dict): original_content = system_message.get("content", "") system_message["content"] = original_content + memory_prompt else: # 添加新的系统消息 new_messages = list(messages) new_messages.insert(0, SystemMessage(content=memory_prompt)) state = {**state, "messages": new_messages} return state def before_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None: """Agent 执行前:召回相关记忆(同步版本) Args: state: Agent 状态 runtime: 运行时上下文 Returns: 更新后的状态或 None """ if not self.config.is_enabled(): return None try: # 提取用户查询 query = self._extract_user_query(state) if not query: return None # 获取 attribution 参数 entity_id, process_id = self.config.get_attribution_tuple() session_id = self.config.session_id or runtime.config.get("configurable", {}).get("thread_id", "default") # 召回记忆(同步方式 - 在后台任务中执行) memories = asyncio.run(self._recall_memories_async(query, entity_id, process_id, session_id)) if memories: # 格式化记忆 memory_text = self._format_memories(memories) # 注入到状态 updated_state = self._inject_memory_context(state, memory_text) logger.info(f"Injected {len(memories)} memories into context") return updated_state return None except Exception as e: logger.error(f"Error in MemoriMiddleware.before_agent: {e}") return None async def abefore_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None: """Agent 执行前:召回相关记忆(异步版本) Args: state: Agent 状态 runtime: 运行时上下文 Returns: 更新后的状态或 None """ if not self.config.is_enabled(): return None try: # 提取用户查询 query = self._extract_user_query(state) if not query: logger.debug("No user query found, skipping memory recall") return None # 获取 attribution 参数 entity_id, process_id = self.config.get_attribution_tuple() session_id = self.config.session_id or runtime.config.get("configurable", {}).get("thread_id", "default") # 召回记忆 memories = await self._recall_memories_async(query, entity_id, process_id, session_id) if memories: # 格式化记忆 memory_text = self._format_memories(memories) # 注入到状态 updated_state = self._inject_memory_context(state, memory_text) logger.info(f"Injected {len(memories)} memories into context (similarity > {self.config.semantic_search_threshold})") return updated_state return None except Exception as e: logger.error(f"Error in MemoriMiddleware.abefore_agent: {e}") return None async def _recall_memories_async( self, query: str, entity_id: str, process_id: str, session_id: str ) -> List[Dict[str, Any]]: """异步召回记忆 Args: query: 查询文本 entity_id: 实体 ID process_id: 进程 ID session_id: 会话 ID Returns: 记忆列表 """ return await self.memori_manager.recall_memories( query=query, entity_id=entity_id, process_id=process_id, session_id=session_id, config=self.config, ) def after_agent(self, state: AgentState, runtime: Runtime) -> None: """Agent 执行后:触发记忆增强(同步版本) Args: state: Agent 状态 runtime: 运行时上下文 """ if not self.config.is_enabled() or not self.config.augmentation_enabled: return try: # 触发后台增强任务 asyncio.create_task(self._trigger_augmentation_async(state, runtime)) except Exception as e: logger.error(f"Error in MemoriMiddleware.after_agent: {e}") async def aafter_agent(self, state: AgentState, runtime: Runtime) -> None: """Agent 执行后:触发记忆增强(异步版本) 注意:Memori 的增强会自动在后台执行,这里主要是记录日志 Args: state: Agent 状态 runtime: 运行时上下文 """ if not self.config.is_enabled() or not self.config.augmentation_enabled: return try: # 如果配置了等待超时,则等待增强完成 if self.config.augmentation_wait_timeout is not None: entity_id, process_id = self.config.get_attribution_tuple() session_id = self.config.session_id or runtime.config.get("configurable", {}).get("thread_id", "default") await self.memori_manager.wait_for_augmentation( entity_id=entity_id, process_id=process_id, session_id=session_id, timeout=self.config.augmentation_wait_timeout, ) except Exception as e: logger.error(f"Error in MemoriMiddleware.aafter_agent: {e}") async def _trigger_augmentation_async(self, state: AgentState, runtime: Runtime) -> None: """触发记忆增强任务 注意:Memori 的 LLM 客户端注册后会自动捕获对话并进行增强, 这里不需要手动触发,只是确保会话正确设置 Args: state: Agent 状态 runtime: 运行时上下文 """ # Memori 的增强是自动的,这里主要是确保配置正确 # 如果需要手动触发,可以在这里实现 pass def create_memori_middleware( bot_id: str, user_identifier: str, session_id: str, enabled: bool = True, semantic_search_top_k: int = 5, semantic_search_threshold: float = 0.7, memori_manager: Optional[MemoriManager] = None, ) -> Optional[MemoriMiddleware]: """创建 MemoriMiddleware 的工厂函数 Args: bot_id: Bot ID user_identifier: 用户标识 session_id: 会话 ID enabled: 是否启用 semantic_search_top_k: 语义搜索返回数量 semantic_search_threshold: 语义搜索相似度阈值 memori_manager: MemoriManager 实例(如果为 None,使用全局实例) Returns: MemoriMiddleware 实例或 None """ if not enabled: return None # 获取或使用提供的 manager manager = memori_manager or get_memori_manager() # 创建配置 config = MemoriConfig( enabled=True, entity_id=user_identifier, process_id=bot_id, session_id=session_id, semantic_search_top_k=semantic_search_top_k, semantic_search_threshold=semantic_search_threshold, ) return MemoriMiddleware(memori_manager=manager, config=config)