"""
Mem0 connection and instance manager.

Responsible for creating, caching, and managing the lifecycle of Mem0 client
instances.
"""

import asyncio
import concurrent.futures
import json
import logging
import os
import re
import sys
import threading
from collections import OrderedDict
from typing import Any, Dict, List, Literal, Optional

import json_repair
from psycopg2 import pool

from embedding.manager import GlobalModelManager, get_model_manager
from utils.settings import MEM0_POOL_SIZE

from .mem0_config import Mem0Config

logger = logging.getLogger("app")

# Dimension of the vectors produced by the shared embedding model. The same
# value is used for the pgvector "embedding_model_dims" setting, so the two
# always agree.
# NOTE(review): earlier comments named both gte-tiny and
# paraphrase-multilingual-MiniLM-L12-v2 (both 384-dim); confirm which model
# GlobalModelManager actually serves.
_EMBEDDING_DIMS = 384


# ============================================================================
# Custom embedding class that uses the project's existing GlobalModelManager
# to avoid loading the same model more than once
# ============================================================================
class CustomMem0Embedding:
    """Mem0 embedder backed by the project's shared GlobalModelManager model.

    Mem0 would otherwise build its own embedding model. Reusing the model the
    project has already loaded saves memory. The model is cached on the class,
    so every instance shares it.
    """

    _model = None  # Shared model instance, cached after the first lookup
    _lock = threading.Lock()  # Guards the one-time model lookup in embed()
    _executor = None  # Lazily created single-thread executor for the async loader

    def __init__(self, config: Optional[Any] = None):
        """Initialize the embedder.

        Args:
            config: Optional Mem0 embedder config. When omitted, a minimal
                object exposing ``embedding_dims`` is created, because Mem0's
                telemetry code reads that attribute.
        """
        if config is None:
            config = type('Config', (), {'embedding_dims': _EMBEDDING_DIMS})()
        self.config = config

    @property
    def embedding_dims(self) -> int:
        """Dimension of the vectors returned by :meth:`embed`."""
        return _EMBEDDING_DIMS

    def _get_model_sync(self):
        """Return the shared model without calling ``asyncio.run()``.

        The manager's synchronous accessor is tried first. If the model is not
        loaded yet, the async loader runs on a private event loop in a worker
        thread. Mem0 calls ``embed`` synchronously from inside this module's
        coroutines, so the calling thread may already have a running loop.

        Returns:
            The embedding model. It is also cached on the class.

        Raises:
            concurrent.futures.TimeoutError: If loading takes longer than 30 s.
            RuntimeError: If the manager provides no model.
        """
        manager = get_model_manager()
        model = manager.get_model_sync()

        if model is None:
            if CustomMem0Embedding._executor is None:
                CustomMem0Embedding._executor = concurrent.futures.ThreadPoolExecutor(
                    max_workers=1, thread_name_prefix="mem0_embed"
                )

            def run_async_in_thread():
                # A fresh loop owned by the worker thread; closed when done
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                try:
                    return loop.run_until_complete(manager.get_model())
                finally:
                    loop.close()

            future = CustomMem0Embedding._executor.submit(run_async_in_thread)
            model = future.result(timeout=30)  # 30-second timeout

        if model is None:
            # Fail loudly here rather than with an AttributeError inside embed()
            raise RuntimeError("Embedding model manager returned no model")

        CustomMem0Embedding._model = model
        return model

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """Return the embedding for ``text`` (synchronous, as Mem0 expects).

        Args:
            text: Text to embed (a string or a list of strings).
            memory_action: Mem0 operation type (add/search/update); unused.

        Returns:
            list: The embedding vector converted from the model's numpy output.
            The result is presumably nested when ``text`` is a list; confirm
            against the model.
        """
        # Double-checked locking: only the first caller performs the lookup
        if CustomMem0Embedding._model is None:
            with CustomMem0Embedding._lock:
                if CustomMem0Embedding._model is None:
                    self._get_model_sync()
        model = CustomMem0Embedding._model
        embeddings = model.encode(text, convert_to_numpy=True)
        return embeddings.tolist()


# ============================================================================
# Monkey patch: replace mem0's remove_code_blocks with a json_repair version
# ============================================================================
_CODE_FENCE_RE = re.compile(r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$")


def _strip_code_fence(content: str) -> str:
    """Return the body of a single fenced code block, or ``content`` unchanged."""
    match = _CODE_FENCE_RE.match(content)
    if match:
        return match.group(1).strip()
    return content


def _remove_code_blocks_with_repair(content: str) -> str:
    """Drop-in replacement for mem0's ``remove_code_blocks`` using json_repair.

    json_repair.loads strips code fences (```json, ``` and similar) and repairs
    malformed JSON (trailing commas, comments, single quotes).

    Args:
        content: Raw LLM output.

    Returns:
        Compact JSON text when the content parses to an object or array.
        Otherwise the text with any surrounding code fence removed.
    """
    content_stripped = content.strip()
    try:
        result = json_repair.loads(content_stripped)
        if isinstance(result, (dict, list)):
            return json.dumps(result, ensure_ascii=False)
    except Exception:
        # Unparseable input: fall back to plain fence removal
        return _strip_code_fence(content_stripped)

    # json_repair yields "" for non-JSON input; keep the original text instead
    if result == "" and content_stripped != "":
        return _strip_code_fence(content_stripped)
    return str(result)


try:
    import mem0.memory.utils as mem0_utils

    mem0_utils.remove_code_blocks = _remove_code_blocks_with_repair

    # mem0.memory.main holds its own reference to remove_code_blocks, so
    # patch that module too when it is already loaded.
    _mem0_main = sys.modules.get('mem0.memory.main')
    if _mem0_main is not None:
        _mem0_main.remove_code_blocks = _remove_code_blocks_with_repair
        logger.info("Successfully patched mem0.memory.main.remove_code_blocks with json_repair")
    else:
        logger.info("Successfully patched mem0.memory.utils.remove_code_blocks with json_repair")
except ImportError:
    # mem0 is unavailable, so no patch is applied.
    # Mem0Manager._create_mem0_instance raises RuntimeError if mem0 is needed.
    pass
except Exception as e:
    logger.warning(f"Failed to patch mem0 remove_code_blocks: {e}")


class Mem0Manager:
    """Mem0 connection and instance manager.

    Main responsibilities:
    1. Create and cache Mem0 instances (LRU, bounded by MEM0_POOL_SIZE)
    2. Multi-tenant isolation (user_id + agent_id)
    3. Use a shared synchronous connection pool (provided by DBPoolManager)
    4. Provide memory recall and storage interfaces

    Mem0 operations (search/add/get_all/delete) are synchronous and are called
    directly on the event-loop thread, so they block the loop while they run.
    Because there is no ``await`` between ``_ensure_connection`` and
    ``_release_connection``, no other coroutine can touch an instance mid-use.
    This is what makes evicting or sharing cached instances safe.
    """

    def __init__(
        self,
        sync_pool: Optional[pool.SimpleConnectionPool] = None,
    ):
        """Initialize Mem0Manager.

        Args:
            sync_pool: Shared PostgreSQL synchronous connection pool from DBPoolManager
        """
        self._sync_pool = sync_pool
        # LRU cache: least recently used entries sit at the front of the OrderedDict
        self._instances: OrderedDict[str, Any] = OrderedDict()
        # Cache at most half the pool size (rounded up) and never fewer than one
        # instance. A bound of 0 would make eviction pop from an empty cache.
        self._max_instances = max((MEM0_POOL_SIZE + 1) // 2, 1)
        self._initialized = False
        # Limit concurrent Mem0 operations to avoid exhausting the connection pool
        self._semaphore = asyncio.Semaphore(max(MEM0_POOL_SIZE - 2, 1))

    async def initialize(self) -> None:
        """Initialize Mem0Manager.

        Mem0 creates its schema itself, so this only reports whether a
        connection pool is configured. Failures are logged, not raised, so the
        system can run without Mem0.
        """
        if self._initialized:
            return

        logger.info("Initializing Mem0Manager...")
        try:
            if self._sync_pool:
                logger.info("Mem0Manager initialized successfully")
            else:
                logger.warning("No database configuration provided for Mem0")
            self._initialized = True
        except Exception as e:
            logger.error(f"Failed to initialize Mem0Manager: {e}")

    def _get_connection_pool(self) -> Optional[pool.SimpleConnectionPool]:
        """Get the synchronous database connection pool required by Mem0.

        Returns:
            psycopg2.pool connection pool, or None when not configured
        """
        return self._sync_pool

    def _cleanup_mem0_instance(self, mem0_instance: Any) -> None:
        """Release a Mem0 instance's database connection before dropping it.

        Mem0's PGVector acquires a connection during initialization and only
        returns it in ``__del__``. Python GC does not guarantee ``__del__`` runs
        promptly, which can exhaust the pool, so the connection is released
        explicitly here. ``conn`` is reset to None so ``__del__`` does not
        release it a second time.

        Args:
            mem0_instance: Mem0 Memory instance
        """
        self._release_connection(mem0_instance)

    def _ensure_connection(self, mem0_instance: Any) -> None:
        """Ensure a Mem0 instance holds a database connection before use.

        If ``_release_connection`` returned the connection earlier, a new one
        is acquired from the shared pool.

        Args:
            mem0_instance: Mem0 Memory instance

        Raises:
            Exception: Whatever the pool or cursor creation raises. The error
                is logged before being re-raised.
        """
        try:
            vs = getattr(mem0_instance, 'vector_store', None)
            if vs is None or not hasattr(vs, 'conn') or vs.conn is not None:
                return
            if not self._sync_pool:
                return

            conn = self._sync_pool.getconn()
            try:
                cur = conn.cursor()
            except Exception:
                # Do not leak the pooled connection if no cursor can be opened
                self._sync_pool.putconn(conn)
                raise
            vs.conn = conn
            vs.cur = cur
            # Make sure _release_connection knows where to return the connection
            if getattr(vs, 'connection_pool', None) is None:
                vs.connection_pool = self._sync_pool
            logger.debug("Re-acquired Mem0 database connection from pool")
        except Exception as e:
            logger.warning(f"Error ensuring Mem0 connection: {e}")
            raise

    def _release_connection(self, mem0_instance: Any) -> None:
        """Return an instance's connection to the pool after use.

        The instance keeps its ``connection_pool`` reference, so
        ``_ensure_connection`` can reacquire a connection on next use. Errors
        are logged and swallowed (best effort).

        Args:
            mem0_instance: Mem0 Memory instance
        """
        try:
            vs = getattr(mem0_instance, 'vector_store', None)
            conn = getattr(vs, 'conn', None)
            if conn is None:
                return

            owner_pool = getattr(vs, 'connection_pool', None)
            if owner_pool is None:
                # Connections handed to Mem0 always come from the shared pool
                owner_pool = self._sync_pool
            if owner_pool is None:
                # Keep the connection attached rather than silently dropping it
                logger.warning("No connection pool available to release Mem0 connection")
                return

            cur = getattr(vs, 'cur', None)
            if cur:
                try:
                    cur.close()
                except Exception as e:
                    # A broken cursor must not prevent returning the connection
                    logger.debug(f"Error closing Mem0 cursor: {e}")
                vs.cur = None

            owner_pool.putconn(conn)
            vs.conn = None
            logger.debug("Released Mem0 database connection back to pool")
        except Exception as e:
            logger.warning(f"Error releasing Mem0 connection: {e}")

    def _evict_overflow(self) -> None:
        """Evict least recently used instances until the cache fits its bound."""
        while len(self._instances) > self._max_instances:
            removed_key, removed_instance = self._instances.popitem(last=False)
            # Release the connection explicitly instead of waiting for GC
            self._cleanup_mem0_instance(removed_instance)
            logger.debug(f"Mem0 instance cache full, removed and cleaned: {removed_key}")

    async def get_mem0(
        self,
        user_id: str,
        agent_id: str,
        session_id: str,
        config: Optional[Mem0Config] = None,
    ) -> Any:
        """Get a cached Mem0 instance or create a new one.

        Args:
            user_id: User ID (mapped to entity_id)
            agent_id: Agent/Bot ID (mapped to process_id)
            session_id: Session ID
            config: Mem0 configuration

        Returns:
            Mem0 instance
        """
        # The cache key includes the LLM instance identity so different LLMs
        # get different Mem0 instances
        llm_suffix = ""
        if config and config.llm_instance is not None:
            llm_suffix = f":{id(config.llm_instance)}"
        cache_key = f"{user_id}:{agent_id}{llm_suffix}"

        if cache_key in self._instances:
            self._instances.move_to_end(cache_key)  # mark as most recently used
            return self._instances[cache_key]

        mem0_instance = await self._create_mem0_instance(
            user_id=user_id,
            agent_id=agent_id,
            session_id=session_id,
            config=config,
        )

        # Another coroutine may have cached the same key while we awaited
        # creation; keep that one and discard ours instead of leaking it.
        if cache_key in self._instances:
            self._cleanup_mem0_instance(mem0_instance)
            self._instances.move_to_end(cache_key)
            logger.debug(f"Discarded duplicate Mem0 instance created concurrently: {cache_key}")
            return self._instances[cache_key]

        # New entries are appended at the end (most recently used)
        self._instances[cache_key] = mem0_instance
        self._evict_overflow()
        return mem0_instance

    async def _create_mem0_instance(
        self,
        user_id: str,
        agent_id: str,
        session_id: str,
        config: Optional[Mem0Config] = None,
    ) -> Any:
        """Create a new Mem0 instance.

        Args:
            user_id: User ID
            agent_id: Agent/Bot ID
            session_id: Session ID
            config: Mem0 configuration, including the LLM instance

        Returns:
            Mem0 Memory instance, with its database connection already released

        Raises:
            RuntimeError: If the mem0 package is not installed.
            ValueError: If no connection pool is configured.
        """
        try:
            from mem0 import Memory
        except ImportError as err:
            logger.error("mem0 package not installed")
            raise RuntimeError("mem0 package is required but not installed") from err

        connection_pool = self._get_connection_pool()
        if not connection_pool:
            raise ValueError("Database connection pool not available")

        # Custom embedder that reuses the shared model to avoid duplicate loading
        custom_embedder = CustomMem0Embedding()

        # NOTE(review): agent_id is interpolated into the collection (table)
        # name with only "-" replaced; confirm agent IDs are trusted/sanitized.
        config_dict = {
            "vector_store": {
                "provider": "pgvector",
                "config": {
                    "connection_pool": connection_pool,
                    "collection_name": f"mem0_{agent_id}".replace("-", "_")[:50],  # Isolate by agent_id
                    "embedding_model_dims": _EMBEDDING_DIMS,
                }
            },
            # A dummy huggingface_base_url keeps HuggingFaceEmbedding from
            # loading a SentenceTransformer; the embedder is replaced below.
            "embedder": {
                "provider": "huggingface",
                "config": {
                    "huggingface_base_url": "http://dummy-url-that-will-be-replaced",
                    "api_key": "dummy-key"  # Placeholder to prevent OpenAI client validation failure
                }
            }
        }

        # Custom memory extraction prompt, if a config is provided
        if config is not None:
            config_dict["custom_fact_extraction_prompt"] = await config.get_custom_fact_extraction_prompt_async()

        if config and config.llm_instance is not None:
            config_dict["llm"] = {
                "provider": "langchain",
                "config": {"model": config.llm_instance}
            }
            logger.info(
                f"Configured LangChain LLM for Mem0: {type(config.llm_instance).__name__}"
            )
        else:
            # Without an explicit LLM, fall back to an OpenAI-compatible backend;
            # Mem0's LLM is used to extract memory facts
            from utils.settings import MASTERKEY, BACKEND_HOST
            llm_api_key = os.environ.get("OPENAI_API_KEY", "") or MASTERKEY
            config_dict["llm"] = {
                "provider": "openai",
                "config": {
                    "model": "gpt-4o-mini",
                    "api_key": llm_api_key,
                    "openai_base_url": BACKEND_HOST  # Use the custom backend
                }
            }

        mem = Memory.from_config(config_dict)

        logger.debug(f"Original embedder type: {type(mem.embedding_model).__name__}")
        logger.debug(f"Original embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}")

        # Swap in the custom embedder so Mem0 reuses the project's model
        mem.embedding_model = custom_embedder

        logger.debug(f"Replaced embedder type: {type(mem.embedding_model).__name__}")
        logger.debug(f"Replaced embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}")
        logger.info("Replaced Mem0 embedder with CustomMem0Embedding (reusing existing model)")

        logger.info(
            f"Created Mem0 instance: user={user_id}, agent={agent_id}"
        )

        # PGVector calls getconn during creation; release immediately so the
        # cached instance does not hold a connection long-term
        self._release_connection(mem)
        return mem

    @staticmethod
    def _normalize_search_result(result: Any) -> Dict[str, Any]:
        """Convert one Mem0 search hit (a dict or a bare string) to the normalized format."""
        if isinstance(result, dict):
            content = result.get("memory", "")
            score = result.get("score", 0.0)
            metadata = result.get("metadata") or {}
        else:
            content = str(result)
            score = 0.0
            metadata = {}
        if not isinstance(metadata, dict):
            metadata = {}
        return {
            "content": content,
            "similarity": score,
            "metadata": metadata,
            "fact_type": metadata.get("category", "fact"),
        }

    async def recall_memories(
        self,
        query: str,
        user_id: str,
        agent_id: str,
        config: Optional[Mem0Config] = None,
    ) -> List[Dict[str, Any]]:
        """Recall relevant memories at the user level, shared across sessions.

        Args:
            query: Query text
            user_id: User ID
            agent_id: Agent/Bot ID
            config: Mem0 configuration

        Returns:
            List of memories with keys content, similarity, metadata and
            fact_type. An empty list is returned on any error.
        """
        try:
            async with self._semaphore:
                mem = await self.get_mem0(user_id, agent_id, "default", config)
                self._ensure_connection(mem)
                try:
                    limit = config.semantic_search_top_k if config else 20
                    results = mem.search(
                        query=query,
                        limit=limit,
                        user_id=user_id,
                        agent_id=agent_id,
                    )
                finally:
                    self._release_connection(mem)

            hits = self._extract_memories_from_response(results, source="mem.search()")
            memories = [self._normalize_search_result(hit) for hit in hits]

            logger.info(f"Recalled {len(memories)} memories for user={user_id}, query: {query[:50]}...")
            return memories

        except Exception as e:
            logger.error(f"Failed to recall memories: {e}")
            return []

    async def add_memory(
        self,
        text: str,
        user_id: str,
        agent_id: str,
        metadata: Optional[Dict[str, Any]] = None,
        config: Optional[Mem0Config] = None,
    ) -> Dict[str, Any]:
        """Add a new memory at the user level, shared across sessions.

        Args:
            text: Memory text
            user_id: User ID
            agent_id: Agent/Bot ID
            metadata: Additional metadata
            config: Mem0 configuration, including the LLM instance for memory extraction

        Returns:
            Mem0's add() result, or {} on error
        """
        try:
            async with self._semaphore:
                mem = await self.get_mem0(user_id, agent_id, "default", config)
                self._ensure_connection(mem)
                try:
                    result = mem.add(
                        text,
                        user_id=user_id,
                        agent_id=agent_id,
                        metadata=metadata or {}
                    )
                finally:
                    self._release_connection(mem)

            logger.info(f"Added memory for user={user_id}, agent={agent_id}: {result}")
            return result

        except Exception as e:
            logger.error(f"Failed to add memory: {e}")
            return {}

    def _extract_memories_from_response(
        self,
        response: Any,
        source: str = "mem.get_all()",
    ) -> List[Dict[str, Any]]:
        """Extract the memory list from a Mem0 response.

        Mem0 responses come in one of two shapes:
        1. New version: {"results": [...]}
        2. Old version: a list directly

        Args:
            response: Response returned by a Mem0 call (get_all() or search())
            source: Name of the call, used only in the warning message

        Returns:
            List of memories (empty for unrecognized shapes)
        """
        if isinstance(response, dict) and "results" in response:
            return response["results"]
        if isinstance(response, list):
            return response
        logger.warning(f"Unexpected response format from {source}: {type(response)}")
        return []

    def _check_agent_id_match(self, memory: Dict[str, Any], agent_id: str) -> bool:
        """Check whether a memory belongs to the specified agent.

        agent_id may appear either at the top level (new format) or inside
        memory["metadata"] (old format).

        Args:
            memory: Memory dictionary
            agent_id: Agent ID to match

        Returns:
            True if either location matches
        """
        if not isinstance(memory, dict):
            return False
        if memory.get("agent_id") == agent_id:
            return True
        metadata = memory.get("metadata", {})
        return isinstance(metadata, dict) and metadata.get("agent_id") == agent_id

    async def get_all_memories(
        self,
        user_id: str,
        agent_id: str,
    ) -> List[Dict[str, Any]]:
        """Get all memories for a user, filtered to the given agent.

        Args:
            user_id: User ID
            agent_id: Agent/Bot ID

        Returns:
            List of memories, or [] on error
        """
        try:
            async with self._semaphore:
                mem = await self.get_mem0(user_id, agent_id, "default")
                self._ensure_connection(mem)
                try:
                    # NOTE(review): mem0's get_all may cap the number of results
                    # returned; confirm this covers every memory of the user.
                    response = mem.get_all(user_id=user_id)
                finally:
                    self._release_connection(mem)

            memories = self._extract_memories_from_response(response)
            return [m for m in memories if self._check_agent_id_match(m, agent_id)]

        except Exception as e:
            logger.error(f"Failed to get all memories: {e}")
            return []

    async def delete_memory(
        self,
        memory_id: str,
        user_id: str,
        agent_id: str,
    ) -> bool:
        """Delete a single memory after verifying it belongs to user and agent.

        Args:
            memory_id: Memory ID
            user_id: User ID
            agent_id: Agent/Bot ID

        Returns:
            True if the memory was found, owned by the caller and deleted
        """
        try:
            async with self._semaphore:
                mem = await self.get_mem0(user_id, agent_id, "default")
                self._ensure_connection(mem)
                try:
                    # Fetch the user's memories first to verify ownership
                    response = mem.get_all(user_id=user_id)
                    memories = self._extract_memories_from_response(response)
                    target_memory = next(
                        (
                            m for m in memories
                            if isinstance(m, dict)
                            and m.get("id") == memory_id
                            and self._check_agent_id_match(m, agent_id)
                        ),
                        None,
                    )

                    if not target_memory:
                        logger.warning(f"Memory {memory_id} not found or access denied for user={user_id}, agent={agent_id}")
                        return False

                    mem.delete(memory_id=memory_id)
                finally:
                    self._release_connection(mem)

            logger.info(f"Deleted memory {memory_id} for user={user_id}, agent={agent_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to delete memory {memory_id}: {e}")
            return False

    async def delete_all_memories(
        self,
        user_id: str,
        agent_id: str,
    ) -> int:
        """Delete all memories for a user under the specified agent.

        Individual delete failures are logged and skipped.

        Args:
            user_id: User ID
            agent_id: Agent/Bot ID

        Returns:
            Number of deleted memories (0 on error)
        """
        try:
            async with self._semaphore:
                mem = await self.get_mem0(user_id, agent_id, "default")
                self._ensure_connection(mem)
                try:
                    # NOTE(review): if get_all caps its results, only the first
                    # page is deleted here; confirm against the mem0 version.
                    response = mem.get_all(user_id=user_id)
                    memories = self._extract_memories_from_response(response)

                    deleted_count = 0
                    for m in memories:
                        if not (isinstance(m, dict) and self._check_agent_id_match(m, agent_id)):
                            continue
                        memory_id = m.get("id")
                        if not memory_id:
                            continue
                        try:
                            mem.delete(memory_id=memory_id)
                            deleted_count += 1
                        except Exception as e:
                            logger.warning(f"Failed to delete memory {memory_id}: {e}")
                finally:
                    self._release_connection(mem)

            logger.info(f"Deleted {deleted_count} memories for user={user_id}, agent={agent_id}")
            return deleted_count

        except Exception as e:
            logger.error(f"Failed to delete all memories: {e}")
            return 0

    def clear_cache(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> None:
        """Clear cached Mem0 instances, releasing their connections.

        Args:
            user_id: User ID; clear all if None
            agent_id: Agent ID; clear all if None
        """
        if user_id is None and agent_id is None:
            for instance in self._instances.values():
                self._cleanup_mem0_instance(instance)
            self._instances.clear()
            logger.info("Cleared all Mem0 instances from cache")
            return

        keys_to_remove = []
        for key in self._instances:
            # Key format: "user_id:agent_id" or "user_id:agent_id:<id(llm_instance)>"
            # NOTE(review): IDs that themselves contain ":" are not parsed correctly.
            parts = key.split(":")
            if len(parts) < 2:
                continue
            u_id, a_id = parts[0], parts[1]
            if user_id and u_id != user_id:
                continue
            if agent_id and a_id != agent_id:
                continue
            keys_to_remove.append(key)

        for key in keys_to_remove:
            self._cleanup_mem0_instance(self._instances.pop(key))
        logger.info(f"Cleared {len(keys_to_remove)} Mem0 instances from cache")

    async def close(self) -> None:
        """Close the manager: release cached instances' connections and clear the cache.

        The shared synchronous pool is not closed; DBPoolManager owns it.
        """
        logger.info("Closing Mem0Manager...")

        for instance in self._instances.values():
            self._cleanup_mem0_instance(instance)
        self._instances.clear()

        self._initialized = False
        logger.info("Mem0Manager closed")


# Global singleton
_global_manager: Optional[Mem0Manager] = None


def get_mem0_manager() -> Mem0Manager:
    """Get the global Mem0Manager singleton, creating it on first use.

    Returns:
        Mem0Manager instance
    """
    global _global_manager
    if _global_manager is None:
        _global_manager = Mem0Manager()
    return _global_manager


async def init_global_mem0(
    sync_pool: pool.SimpleConnectionPool,
) -> Mem0Manager:
    """Initialize the global Mem0Manager with the shared connection pool.

    Args:
        sync_pool: PostgreSQL synchronous connection pool from DBPoolManager.sync_pool

    Returns:
        Mem0Manager instance
    """
    manager = get_mem0_manager()
    manager._sync_pool = sync_pool
    await manager.initialize()
    return manager


async def close_global_mem0() -> None:
    """Close the global Mem0Manager, if one was created."""
    global _global_manager
    if _global_manager is not None:
        await _global_manager.close()