""" Mem0 agent middleware. Implements memory recall and storage for AgentMiddleware. """ import asyncio import logging import threading from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from langchain.agents.middleware import AgentMiddleware, AgentState, ModelRequest from langgraph.runtime import Runtime from .mem0_config import Mem0Config from .mem0_manager import Mem0Manager, get_mem0_manager logger = logging.getLogger("app") # Avoid circular imports if TYPE_CHECKING: from langchain_core.language_models import BaseChatModel class Mem0Middleware(AgentMiddleware): """ Mem0 memory middleware. Features: 1. before_agent: recall relevant memories and inject them into context 2. after_agent: asynchronously extract and store new memories in the background """ def __init__( self, mem0_manager: Mem0Manager, config: Mem0Config, agent_config: "AgentConfig", ): """Initialize Mem0Middleware. Args: mem0_manager: Mem0Manager instance config: Mem0Config configuration agent_config: AgentConfig instance used to pass data between middlewares """ self.mem0_manager = mem0_manager self.config = config self.agent_config = agent_config def _extract_user_query(self, state: AgentState) -> str: """Extract the user query from state, using the last HumanMessage. Args: state: Agent state Returns: User query text """ from langchain_core.messages import HumanMessage messages = state.get("messages", []) if not messages: return "" # Find the last HumanMessage for msg in reversed(messages): if isinstance(msg, HumanMessage): return str(msg.content) if msg.content else "" return "" def _extract_agent_response(self, state: AgentState) -> str: """Extract the agent response from state, using the last AIMessage. Args: state: Agent state Returns: Agent response text """ from langchain_core.messages import AIMessage messages = state.get("messages", []) if not messages: return "" # Find the last AIMessage for msg in reversed(messages): if isinstance(msg, AIMessage): return str(msg.content) if msg.content else "" return "" def _format_memories(self, memories: List[Dict[str, Any]]) -> str: """Format a list of memories as text. Args: memories: List of memories Returns: Formatted memory text """ if not memories: return "" lines = [] for i, memory in enumerate(memories, 1): content = memory.get("content", "") fact_type = memory.get("fact_type", "fact") lines.append(f"{i}. [{fact_type}] {content}") return "\n".join(lines) def before_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None: """Recall relevant memories before agent execution, synchronous version. Args: state: Agent state runtime: Runtime context Returns: Updated state or None """ if not self.config.is_enabled(): return None try: # Extract the user query query = self._extract_user_query(state) if not query: return None # Get attribution parameters user_id, agent_id = self.config.get_attribution_tuple() session_id = self.config.session_id or runtime.config.get("configurable", {}).get("thread_id", "default") # Recall memories synchronously by running the async method memories = asyncio.run(self._recall_memories_async(query, user_id, agent_id, session_id)) if memories: # Format memories and append them to the memory prompt memory_text = self._format_memories(memories) memory_prompt = self.config.get_memory_prompt([memory_text]) self.agent_config._mem0_context = memory_prompt logger.info(f"Recalled {len(memories)} memories for context") else: self.agent_config._mem0_context = None return state except Exception as e: logger.error(f"Error in Mem0Middleware.before_agent: {e}") return None async def abefore_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None: """Recall relevant memories before agent execution, asynchronous version. Args: state: Agent state runtime: Runtime context Returns: Updated state or None """ if not self.config.is_enabled(): return None try: # Extract the user query query = self._extract_user_query(state) if not query: logger.debug("No user query found, skipping memory recall") return None # Get attribution parameters user_id, agent_id = self.config.get_attribution_tuple() # Recall user-level memories across sessions memories = await self._recall_memories_async(query, user_id, agent_id) if memories: # Format memories and append them to the memory prompt memory_text = self._format_memories(memories) memory_prompt = self.config.get_memory_prompt([memory_text]) self.agent_config._mem0_context = memory_prompt logger.info(f"Recalled {len(memories)} memories for context") else: self.agent_config._mem0_context = None return state except Exception as e: logger.error(f"Error in Mem0Middleware.abefore_agent: {e}") return None async def _recall_memories_async( self, query: str, user_id: str, agent_id: str ) -> List[Dict[str, Any]]: """Recall memories asynchronously. Args: query: Query text user_id: User ID agent_id: Agent/Bot ID Returns: List of memories """ return await self.mem0_manager.recall_memories( query=query, user_id=user_id, agent_id=agent_id, config=self.config, ) def after_agent(self, state: AgentState, runtime: Runtime) -> None: """Trigger memory augmentation after agent execution, synchronous version. Runs in a background thread to avoid blocking the main flow. Args: state: Agent state runtime: Runtime context """ if not self.config.is_enabled(): return try: # Run in a background thread so the main flow is never blocked thread = threading.Thread( target=self._trigger_augmentation_sync, args=(state, runtime), daemon=True, ) thread.start() except Exception as e: logger.error(f"Error in Mem0Middleware.after_agent: {e}") async def aafter_agent(self, state: AgentState, runtime: Runtime) -> None: """Trigger memory augmentation after agent execution, asynchronous version. Runs in a background thread to avoid blocking the event loop. Args: state: Agent state runtime: Runtime context """ if not self.config.is_enabled(): return try: # Run in a background thread so the event loop is never blocked thread = threading.Thread( target=self._trigger_augmentation_sync, args=(state, runtime), daemon=True, ) thread.start() except Exception as e: logger.error(f"Error in Mem0Middleware.aafter_agent: {e}") def _trigger_augmentation_sync(self, state: AgentState, runtime: Runtime) -> None: """Trigger the memory augmentation task, synchronous version executed in a thread. Extracts information from the conversation and stores it in Mem0 at the user level across sessions. Args: state: Agent state runtime: Runtime context """ try: # Get attribution parameters user_id, agent_id = self.config.get_attribution_tuple() # Extract the user query and agent response user_query = self._extract_user_query(state) agent_response = self._extract_agent_response(state) # Store the conversation as user-level memory if user_query and agent_response: conversation_text = f"User: {user_query}\nAssistant: {agent_response}" # Run async code in a new event loop because this runs in a thread loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: loop.run_until_complete( self.mem0_manager.add_memory( text=conversation_text, user_id=user_id, agent_id=agent_id, metadata={"type": "conversation"}, config=self.config, ) ) logger.debug(f"Stored conversation as memory for user={user_id}, agent={agent_id}") finally: loop.close() except Exception as e: logger.error(f"Error in _trigger_augmentation_sync: {e}") async def _trigger_augmentation_async(self, state: AgentState, runtime: Runtime) -> None: """Trigger the memory augmentation task. Extracts information from the conversation and stores it in Mem0 at the user level across sessions. Args: state: Agent state runtime: Runtime context """ try: # Get attribution parameters user_id, agent_id = self.config.get_attribution_tuple() # Extract the user query and agent response user_query = self._extract_user_query(state) agent_response = self._extract_agent_response(state) # Store the conversation as user-level memory if user_query and agent_response: conversation_text = f"User: {user_query}\nAssistant: {agent_response}" await self.mem0_manager.add_memory( text=conversation_text, user_id=user_id, agent_id=agent_id, metadata={"type": "conversation"}, config=self.config, ) logger.debug(f"Stored conversation as memory for user={user_id}, agent={agent_id}") except Exception as e: logger.error(f"Error in _trigger_augmentation_async: {e}") def wrap_model_call( self, request: ModelRequest, handler: Callable[[ModelRequest], Any], ) -> Any: """Wrap model calls and inject memories into the system prompt, synchronous version. Args: request: Model request handler: Original handler Returns: Model response """ # Get the assembled memory prompt from agent_config memory_prompt = self.agent_config._mem0_context if not memory_prompt: return handler(request) # Get the current system prompt current_system_prompt = "" if request.system_message: content = request.system_message.content if hasattr(request.system_message, "content") else str(request.system_message) # content may be a list or a string; make sure it is converted to str if isinstance(content, list): current_system_prompt = "\n".join(str(item) for item in content) else: current_system_prompt = str(content) if content else "" # Make sure memory_prompt is also a string if isinstance(memory_prompt, list): memory_prompt = "\n".join(str(item) for item in memory_prompt) else: memory_prompt = str(memory_prompt) if memory_prompt else "" # Update the system prompt new_system_prompt = current_system_prompt + memory_prompt return handler(request.override(system_prompt=new_system_prompt)) async def awrap_model_call( self, request: ModelRequest, handler: Callable[[ModelRequest], Any], ) -> Any: """Wrap model calls and inject memories into the system prompt, asynchronous version. Args: request: Model request handler: Original handler Returns: Model response """ # Get the assembled memory prompt from agent_config memory_prompt = self.agent_config._mem0_context if not memory_prompt: return await handler(request) # Get the current system prompt current_system_prompt = "" if request.system_message: content = request.system_message.content if hasattr(request.system_message, "content") else str(request.system_message) # content may be a list or a string; make sure it is converted to str if isinstance(content, list): current_system_prompt = "\n".join(str(item) for item in content) else: current_system_prompt = str(content) if content else "" # Make sure memory_prompt is also a string if isinstance(memory_prompt, list): memory_prompt = "\n".join(str(item) for item in memory_prompt) else: memory_prompt = str(memory_prompt) if memory_prompt else "" # Update the system prompt new_system_prompt = current_system_prompt + memory_prompt return await handler(request.override(system_prompt=new_system_prompt)) def create_mem0_middleware( bot_id: str, user_identifier: str, session_id: str, agent_config: "AgentConfig", enabled: bool = True, semantic_search_top_k: int = 20, mem0_manager: Optional[Mem0Manager] = None, llm_instance: Optional["BaseChatModel"] = None, ) -> Optional[Mem0Middleware]: """Factory function for creating Mem0Middleware. Args: bot_id: Bot ID user_identifier: User identifier session_id: Session ID agent_config: AgentConfig instance used to pass data between middlewares enabled: Whether the middleware is enabled semantic_search_top_k: Number of semantic search results to return mem0_manager: Mem0Manager instance; uses the global instance if None llm_instance: LangChain LLM instance used for Mem0 memory extraction and augmentation Returns: Mem0Middleware instance or None """ if not enabled: return None # Use the provided manager or get the default one manager = mem0_manager or get_mem0_manager() # Create configuration config = Mem0Config( enabled=True, user_id=user_identifier, agent_id=bot_id, session_id=session_id, semantic_search_top_k=semantic_search_top_k, llm_instance=llm_instance, ) return Mem0Middleware(mem0_manager=manager, config=config, agent_config=agent_config)