Convert all Chinese comments, docstrings, logger/print output, HTTPException detail messages, and API response messages to English across the entire codebase. Functional zh/ja localized strings (e.g. prompt templates, timezone display names, date formats) are preserved as-is.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
817 lines
30 KiB
Python
817 lines
30 KiB
Python
"""
|
|
Mem0 connection and instance manager.
|
|
Responsible for creating, caching, and managing the lifecycle of Mem0 client instances.
|
|
"""
|
|
|
|
import logging
|
|
import asyncio
|
|
import threading
|
|
import concurrent.futures
|
|
from typing import Any, Dict, List, Optional, Literal
|
|
from collections import OrderedDict
|
|
from embedding.manager import GlobalModelManager, get_model_manager
|
|
import json_repair
|
|
from psycopg2 import pool
|
|
from utils.settings import (
|
|
MEM0_POOL_SIZE
|
|
)
|
|
from .mem0_config import Mem0Config
|
|
|
|
logger = logging.getLogger("app")
|
|
|
|
|
|
# ============================================================================
|
|
# Custom embedding class that uses the project's existing GlobalModelManager
|
|
# Avoid loading the same model more than once
|
|
# ============================================================================
|
|
|
|
class CustomMem0Embedding:
    """
    Custom Mem0 embedding class that directly uses the project's existing GlobalModelManager

    This prevents Mem0 from loading the same model again and saves memory.
    The resolved model is cached on the class, so all instances share it.
    """

    # Embedding dimension reported to Mem0.
    # NOTE(review): the original comment attributes 384 to gte-tiny - confirm
    # against the model actually served by GlobalModelManager.
    _EMBEDDING_DIMS = 384
    # Seconds to wait for the asynchronous model load before giving up.
    _MODEL_LOAD_TIMEOUT = 30

    _model = None  # Class variable caching the model instance
    _lock = threading.Lock()  # Serialises the first model load performed by embed()
    _executor = None  # Lazily created single-worker thread pool for the async load

    def __init__(self, config: Optional[Any] = None):
        """Initialize the custom embedding.

        Args:
            config: Optional config object. When omitted, a minimal object
                exposing ``embedding_dims`` is created so that Mem0 telemetry
                code can read it.
        """
        if config is None:
            config = type('Config', (), {'embedding_dims': self._EMBEDDING_DIMS})()
        self.config = config

    @property
    def embedding_dims(self):
        """Get the embedding dimension."""
        return self._EMBEDDING_DIMS

    def _get_model_sync(self):
        """Synchronously get the model without calling asyncio.run() in the caller's thread.

        Returns:
            The loaded model; it is also cached on the class.

        Raises:
            RuntimeError: If the model manager yields no model.
            concurrent.futures.TimeoutError: If the asynchronous load does not
                finish within ``_MODEL_LOAD_TIMEOUT`` seconds.
        """
        # First try to get an already-loaded model from the manager
        manager = get_model_manager()
        model = manager.get_model_sync()

        if model is None:
            # The model is not loaded yet: run the async initialization on a
            # dedicated worker thread that owns its own event loop.
            if CustomMem0Embedding._executor is None:
                CustomMem0Embedding._executor = concurrent.futures.ThreadPoolExecutor(
                    max_workers=1,
                    thread_name_prefix="mem0_embed"
                )

            def run_async_in_thread():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                try:
                    return loop.run_until_complete(manager.get_model())
                finally:
                    loop.close()

            future = CustomMem0Embedding._executor.submit(run_async_in_thread)
            model = future.result(timeout=self._MODEL_LOAD_TIMEOUT)

        if model is None:
            # Fail here with a clear message instead of caching None and
            # crashing later on `None.encode(...)`.
            raise RuntimeError("Embedding model is not available from the model manager")

        # Cache the model
        CustomMem0Embedding._model = model
        return model

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding vector for text (synchronous method for Mem0)

        Args:
            text: Text to embed (string or list)
            memory_action: Memory operation type (add/search/update), currently unused

        Returns:
            list: Embedding vector
        """
        # Load the model at most once: check, take the lock, check again.
        if CustomMem0Embedding._model is None:
            with CustomMem0Embedding._lock:
                if CustomMem0Embedding._model is None:
                    self._get_model_sync()

        model = CustomMem0Embedding._model
        embeddings = model.encode(text, convert_to_numpy=True)
        return embeddings.tolist()
|
|
|
|
# Monkey patch: replace mem0's remove_code_blocks with json_repair
|
|
def _remove_code_blocks_with_repair(content: str) -> str:
|
|
"""
|
|
Replace mem0's remove_code_blocks function with json_repair
|
|
|
|
json_repair.loads automatically handles:
|
|
- Removing code block markers (```json, ``` and similar)
|
|
- Repairing malformed JSON (such as trailing commas, comments, and single quotes)
|
|
"""
|
|
import re
|
|
|
|
content_stripped = content.strip()
|
|
|
|
try:
|
|
# json_repair.loads automatically strips code blocks and repairs JSON
|
|
result = json_repair.loads(content_stripped)
|
|
if isinstance(result, (dict, list)):
|
|
import json
|
|
return json.dumps(result, ensure_ascii=False)
|
|
# If an empty string is returned for non-JSON input, fall back to the original content
|
|
if result == "" and content_stripped != "":
|
|
# Try simple code block removal as a fallback
|
|
pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$"
|
|
match = re.match(pattern, content_stripped)
|
|
if match:
|
|
return match.group(1).strip()
|
|
return content_stripped
|
|
return str(result)
|
|
except Exception:
|
|
# If parsing fails, try simple code block removal as a fallback
|
|
pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$"
|
|
match = re.match(pattern, content_stripped)
|
|
if match:
|
|
return match.group(1).strip()
|
|
return content_stripped
|
|
|
|
|
|
# Apply the monkey patch when this module is imported.
try:
    import sys
    import mem0.memory.utils as mem0_utils

    mem0_utils.remove_code_blocks = _remove_code_blocks_with_repair

    # If mem0.memory.main has already been imported, patch its local reference as well
    if 'mem0.memory.main' in sys.modules:
        import mem0.memory.main

        mem0.memory.main.remove_code_blocks = _remove_code_blocks_with_repair
        logger.info("Successfully patched mem0.memory.main.remove_code_blocks with json_repair")
    else:
        logger.info("Successfully patched mem0.memory.utils.remove_code_blocks with json_repair")
except ImportError as e:
    # mem0 could not be imported, so nothing was patched. Nothing in this
    # module re-applies the patch later; record the skip instead of hiding it.
    logger.debug(f"mem0 is not importable; remove_code_blocks patch skipped: {e}")
except Exception as e:
    logger.warning(f"Failed to patch mem0 remove_code_blocks: {e}")
|
|
|
|
|
|
class Mem0Manager:
    """
    Mem0 connection and instance manager

    Main responsibilities:
    1. Manage creation and caching of Mem0 instances
    2. Support multi-tenant isolation (user_id + agent_id)
    3. Use a shared synchronous connection pool (provided by DBPoolManager)
    4. Provide memory recall and storage interfaces
    """

    def __init__(
        self,
        sync_pool: Optional[pool.SimpleConnectionPool] = None,
    ):
        """Initialize Mem0Manager.

        Args:
            sync_pool: Shared PostgreSQL synchronous connection pool from DBPoolManager
        """
        self._sync_pool = sync_pool

        # OrderedDict doubles as an LRU cache: hits are moved to the end and
        # the oldest entry is evicted from the front.
        self._instances: OrderedDict[str, Any] = OrderedDict()
        self._max_instances = self._cache_capacity(MEM0_POOL_SIZE)
        self._initialized = False

        # Limit concurrent Mem0 operations to avoid exhausting the connection pool
        self._semaphore = asyncio.Semaphore(max(MEM0_POOL_SIZE - 2, 1))

    @staticmethod
    def _cache_capacity(pool_size: int) -> int:
        """Return the maximum number of cached Mem0 instances for *pool_size*.

        Half the pool size, rounded up, and never less than one, so that the
        eviction step in get_mem0() always has an entry to evict.
        """
        return max((pool_size + 1) // 2, 1)

    async def initialize(self) -> None:
        """Mark the manager as initialized.

        No database work is done here; the method only logs whether a
        connection pool has been provided. Calling it again is a no-op.
        """
        if self._initialized:
            return

        logger.info("Initializing Mem0Manager...")

        try:
            # Mem0 creates the schema automatically; only the presence of a pool is checked here
            if self._sync_pool:
                logger.info("Mem0Manager initialized successfully")
            else:
                logger.warning("No database configuration provided for Mem0")

            self._initialized = True
        except Exception as e:
            logger.error(f"Failed to initialize Mem0Manager: {e}")
            # Do not raise; allow the system to run without Mem0

    def _get_connection_pool(self) -> Optional[pool.SimpleConnectionPool]:
        """Get the synchronous database connection pool required by Mem0.

        Returns:
            psycopg2.pool connection pool, or None if none was provided
        """
        return self._sync_pool

    def _cleanup_mem0_instance(self, mem0_instance: Any) -> None:
        """Clean up a Mem0 instance and release its database connection.

        Mem0's PGVector implementation acquires and holds a connection during
        initialization, only returning it in __del__. Python GC does not guarantee
        __del__ will run immediately, which can exhaust the connection pool. This
        method releases the connection explicitly. Errors are logged, never raised.

        Args:
            mem0_instance: Mem0 Memory instance
        """
        try:
            vector_store = getattr(mem0_instance, 'vector_store', None)
            if vector_store is None:
                return

            conn = getattr(vector_store, 'conn', None)
            conn_pool = getattr(vector_store, 'connection_pool', None)
            if conn is None or conn_pool is None:
                return

            try:
                # Close the cursor first
                if getattr(vector_store, 'cur', None):
                    vector_store.cur.close()
                    vector_store.cur = None
                # Return the connection to the pool
                conn_pool.putconn(conn)
                # Mark as cleaned up to prevent __del__ from releasing it again
                vector_store.conn = None
                logger.debug("Successfully released Mem0 database connection back to pool")
            except Exception as e:
                logger.warning(f"Error releasing Mem0 connection: {e}")
        except Exception as e:
            logger.warning(f"Error cleaning up Mem0 instance: {e}")

    def _ensure_connection(self, mem0_instance: Any) -> None:
        """Ensure a Mem0 instance has a database connection before use.

        If the connection was released by _release_connection, reacquire it from the pool.

        Args:
            mem0_instance: Mem0 Memory instance

        Raises:
            Exception: Whatever acquiring the connection or cursor raises,
                after logging a warning.
        """
        try:
            vs = getattr(mem0_instance, 'vector_store', None)
            if vs is not None and hasattr(vs, 'conn') and vs.conn is None and self._sync_pool:
                vs.conn = self._sync_pool.getconn()
                vs.cur = vs.conn.cursor()
                # Ensure the connection_pool reference exists for later return
                if hasattr(vs, 'connection_pool') and vs.connection_pool is None:
                    vs.connection_pool = self._sync_pool
                logger.debug("Re-acquired Mem0 database connection from pool")
        except Exception as e:
            logger.warning(f"Error ensuring Mem0 connection: {e}")
            raise

    def _release_connection(self, mem0_instance: Any) -> None:
        """Release the connection back to the pool after use.

        The instance stays usable: _ensure_connection can reacquire a
        connection later. The connection is returned to the vector store's
        own pool, or to the manager's shared pool when the vector store has
        none, so it is not dropped without being returned. Errors are logged,
        never raised.

        Args:
            mem0_instance: Mem0 Memory instance
        """
        try:
            vs = getattr(mem0_instance, 'vector_store', None)
            if vs is None or getattr(vs, 'conn', None) is None:
                return

            if getattr(vs, 'cur', None):
                vs.cur.close()
                vs.cur = None

            target_pool = getattr(vs, 'connection_pool', None)
            if target_pool is None:
                # Fall back to the shared pool the connection was taken from.
                target_pool = self._sync_pool
            if target_pool is not None:
                target_pool.putconn(vs.conn)
            else:
                logger.warning("No connection pool available to return Mem0 connection to")
            vs.conn = None
            logger.debug("Released Mem0 database connection back to pool")
        except Exception as e:
            logger.warning(f"Error releasing Mem0 connection: {e}")

    async def get_mem0(
        self,
        user_id: str,
        agent_id: str,
        session_id: str,
        config: Optional[Mem0Config] = None,
    ) -> Any:
        """Get or create a Mem0 instance.

        Args:
            user_id: User ID
            agent_id: Agent/Bot ID
            session_id: Session ID (passed through to instance creation; not part of the cache key)
            config: Mem0 configuration

        Returns:
            Mem0 instance
        """
        # The cache key includes the LLM instance ID so different LLMs use different instances
        llm_suffix = ""
        if config and config.llm_instance is not None:
            llm_suffix = f":{id(config.llm_instance)}"
        cache_key = f"{user_id}:{agent_id}{llm_suffix}"

        # Check the cache and move hits to the end to mark recent use
        if cache_key in self._instances:
            self._instances.move_to_end(cache_key)
            return self._instances[cache_key]

        # Check cache size and evict the oldest entry if needed.
        # The emptiness check keeps popitem() from raising on an empty cache.
        if self._instances and len(self._instances) >= self._max_instances:
            removed_key, removed_instance = self._instances.popitem(last=False)
            # Explicitly release the connection instead of waiting for GC and risking pool exhaustion
            self._cleanup_mem0_instance(removed_instance)
            logger.debug(f"Mem0 instance cache full, removed and cleaned: {removed_key}")

        # Create a new instance
        mem0_instance = await self._create_mem0_instance(
            user_id=user_id,
            agent_id=agent_id,
            session_id=session_id,
            config=config,
        )

        # Cache the instance; new entries are automatically appended to the end
        self._instances[cache_key] = mem0_instance
        return mem0_instance

    async def _create_mem0_instance(
        self,
        user_id: str,
        agent_id: str,
        session_id: str,
        config: Optional[Mem0Config] = None,
    ) -> Any:
        """Create a new Mem0 instance.

        Args:
            user_id: User ID
            agent_id: Agent/Bot ID
            session_id: Session ID
            config: Mem0 configuration, including the LLM instance

        Returns:
            Mem0 Memory instance

        Raises:
            RuntimeError: If the mem0 package is not installed.
            ValueError: If no database connection pool is available.
        """
        try:
            from mem0 import Memory
        except ImportError:
            logger.error("mem0 package not installed")
            raise RuntimeError("mem0 package is required but not installed")

        # Get the synchronous connection pool
        connection_pool = self._get_connection_pool()
        if not connection_pool:
            raise ValueError("Database connection pool not available")

        # Create a custom embedder that reuses the shared model to avoid duplicate loading
        custom_embedder = CustomMem0Embedding()

        # Configure Mem0 to use Pgvector
        # Note: use huggingface_base_url here to bypass local model loading
        # Set a dummy base_url so HuggingFaceEmbedding does not load SentenceTransformer
        config_dict = {
            "vector_store": {
                "provider": "pgvector",
                "config": {
                    "connection_pool": connection_pool,
                    "collection_name": f"mem0_{agent_id}".replace("-", "_")[:50],  # Isolate by agent_id
                    # Take the dimension from the embedder that is installed below.
                    "embedding_model_dims": custom_embedder.embedding_dims,
                }
            },
            # Use huggingface_base_url to bypass model loading; it will later be replaced with the custom embedder
            "embedder": {
                "provider": "huggingface",
                "config": {
                    "huggingface_base_url": "http://dummy-url-that-will-be-replaced",
                    "api_key": "dummy-key"  # Placeholder to prevent OpenAI client validation failure
                }
            }
        }

        # Add a custom memory extraction prompt if config is provided
        if config is not None:
            config_dict["custom_fact_extraction_prompt"] = await config.get_custom_fact_extraction_prompt_async()

        # Add LangChain LLM configuration if provided
        if config and config.llm_instance is not None:
            config_dict["llm"] = {
                "provider": "langchain",
                "config": {"model": config.llm_instance}
            }
            logger.info(
                f"Configured LangChain LLM for Mem0: {type(config.llm_instance).__name__}"
            )
        else:
            # If no LLM is provided, use the default OpenAI configuration
            # Mem0's LLM is used to extract memory facts
            from utils.settings import MASTERKEY, BACKEND_HOST
            import os

            llm_api_key = os.environ.get("OPENAI_API_KEY", "") or MASTERKEY
            config_dict["llm"] = {
                "provider": "openai",
                "config": {
                    "model": "gpt-4o-mini",
                    "api_key": llm_api_key,
                    "openai_base_url": BACKEND_HOST  # Use the custom backend
                }
            }

        # Create the Mem0 instance
        mem = Memory.from_config(config_dict)
        logger.debug(f"Original embedder type: {type(mem.embedding_model).__name__}")
        logger.debug(f"Original embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}")

        # Replace with the custom embedder to reuse the model already loaded by the project
        # This prevents Mem0 from loading the model again
        mem.embedding_model = custom_embedder
        logger.debug(f"Replaced embedder type: {type(mem.embedding_model).__name__}")
        logger.debug(f"Replaced embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}")
        logger.info("Replaced Mem0 embedder with CustomMem0Embedding (reusing existing model)")

        logger.info(
            f"Created Mem0 instance: user={user_id}, agent={agent_id}"
        )

        # PGVector calls getconn during creation; release immediately to avoid holding the connection long-term
        self._release_connection(mem)

        return mem

    @staticmethod
    def _format_recalled_memory(result: Any) -> Dict[str, Any]:
        """Convert one Mem0 search hit into the normalized recall format.

        Accepts a dictionary hit or a plain string. A missing or None
        ``metadata`` is treated as an empty dictionary.
        """
        if isinstance(result, dict):
            content = result.get("memory", "")
            score = result.get("score", 0.0)
            result_metadata = result.get("metadata")
        else:
            content, score, result_metadata = str(result), 0.0, None

        if not isinstance(result_metadata, dict):
            result_metadata = {}

        return {
            "content": content,
            "similarity": score,
            "metadata": result_metadata,
            "fact_type": result_metadata.get("category", "fact"),
        }

    async def recall_memories(
        self,
        query: str,
        user_id: str,
        agent_id: str,
        config: Optional[Mem0Config] = None,
    ) -> List[Dict[str, Any]]:
        """Recall relevant memories at the user level, shared across sessions.

        Args:
            query: Query text
            user_id: User ID
            agent_id: Agent/Bot ID
            config: Mem0 configuration

        Returns:
            List of memories, each containing content, similarity, metadata
            and fact_type. An empty list is returned on any failure.
        """
        try:
            async with self._semaphore:
                mem = await self.get_mem0(user_id, agent_id, "default", config)
                try:
                    # Acquire inside the try so a half-acquired connection is still released.
                    self._ensure_connection(mem)
                    # Call search for semantic retrieval, filtering with the agent_id parameter
                    limit = config.semantic_search_top_k if config else 20
                    results = mem.search(
                        query=query,
                        limit=limit,
                        user_id=user_id,
                        agent_id=agent_id,
                    )
                finally:
                    self._release_connection(mem)

            # Convert to a normalized format; Mem0 may return results as strings or dictionaries
            memories = [
                self._format_recalled_memory(result)
                for result in self._extract_memories_from_response(results)
            ]

            logger.info(f"Recalled {len(memories)} memories for user={user_id}, query: {query[:50]}...")
            return memories

        except Exception as e:
            logger.error(f"Failed to recall memories: {e}")
            return []

    async def add_memory(
        self,
        text: str,
        user_id: str,
        agent_id: str,
        metadata: Optional[Dict[str, Any]] = None,
        config: Optional[Mem0Config] = None,
    ) -> Dict[str, Any]:
        """Add a new memory at the user level, shared across sessions.

        Args:
            text: Memory text
            user_id: User ID
            agent_id: Agent/Bot ID
            metadata: Additional metadata
            config: Mem0 configuration, including the LLM instance for memory extraction

        Returns:
            Result of the added memory, or an empty dictionary on failure
        """
        try:
            async with self._semaphore:
                mem = await self.get_mem0(user_id, agent_id, "default", config)
                try:
                    self._ensure_connection(mem)
                    # Add the memory using the agent_id parameter
                    result = mem.add(
                        text,
                        user_id=user_id,
                        agent_id=agent_id,
                        metadata=metadata or {}
                    )
                finally:
                    self._release_connection(mem)

            logger.info(f"Added memory for user={user_id}, agent={agent_id}: {result}")
            return result

        except Exception as e:
            logger.error(f"Failed to add memory: {e}")
            return {}

    def _extract_memories_from_response(self, response: Any) -> List[Dict[str, Any]]:
        """Extract the memory list from a Mem0 response.

        Two formats are handled:
        1. New version: {"results": [...]}
        2. Old version: a list directly

        Args:
            response: Response returned by Mem0 (get_all() or search())

        Returns:
            List of memories; empty for any other response type
        """
        if isinstance(response, dict) and "results" in response:
            return response["results"]
        if isinstance(response, list):
            return response

        logger.warning(f"Unexpected response format from mem.get_all(): {type(response)}")
        return []

    def _check_agent_id_match(self, memory: Dict[str, Any], agent_id: str) -> bool:
        """Check whether a memory belongs to the specified agent.

        The agent_id is looked up in two places:
        1. Top level: memory["agent_id"]
        2. Inside metadata: memory["metadata"]["agent_id"]

        Args:
            memory: Memory dictionary
            agent_id: Agent ID to match

        Returns:
            Whether it matches
        """
        if not isinstance(memory, dict):
            return False

        # First check the top-level agent_id field (new format)
        if memory.get("agent_id") == agent_id:
            return True

        # Then check agent_id inside metadata (old format)
        metadata = memory.get("metadata", {})
        return isinstance(metadata, dict) and metadata.get("agent_id") == agent_id

    async def get_all_memories(
        self,
        user_id: str,
        agent_id: str,
    ) -> List[Dict[str, Any]]:
        """Get all memories of a user that belong to the given agent.

        Args:
            user_id: User ID
            agent_id: Agent/Bot ID

        Returns:
            List of memories; empty on failure
        """
        try:
            async with self._semaphore:
                mem = await self.get_mem0(user_id, agent_id, "default")
                try:
                    self._ensure_connection(mem)
                    # Get all memories
                    response = mem.get_all(user_id=user_id)
                finally:
                    self._release_connection(mem)

            # Extract the memory list from the response
            memories = self._extract_memories_from_response(response)

            # Keep only the memories owned by this agent
            return [m for m in memories if self._check_agent_id_match(m, agent_id)]

        except Exception as e:
            logger.error(f"Failed to get all memories: {e}")
            return []

    async def delete_memory(
        self,
        memory_id: str,
        user_id: str,
        agent_id: str,
    ) -> bool:
        """Delete a single memory.

        Args:
            memory_id: Memory ID
            user_id: User ID
            agent_id: Agent/Bot ID

        Returns:
            Whether deletion succeeded
        """
        try:
            async with self._semaphore:
                mem = await self.get_mem0(user_id, agent_id, "default")
                try:
                    self._ensure_connection(mem)
                    # First fetch memories to verify ownership
                    response = mem.get_all(user_id=user_id)
                    memories = self._extract_memories_from_response(response)

                    target_memory = None
                    for m in memories:
                        if isinstance(m, dict) and m.get("id") == memory_id:
                            # Verify the agent_id match
                            if self._check_agent_id_match(m, agent_id):
                                target_memory = m
                            break

                    if not target_memory:
                        logger.warning(f"Memory {memory_id} not found or access denied for user={user_id}, agent={agent_id}")
                        return False

                    # Delete the memory
                    mem.delete(memory_id=memory_id)
                finally:
                    self._release_connection(mem)

            logger.info(f"Deleted memory {memory_id} for user={user_id}, agent={agent_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to delete memory {memory_id}: {e}")
            return False

    async def delete_all_memories(
        self,
        user_id: str,
        agent_id: str,
    ) -> int:
        """Delete all memories for a user under the specified agent.

        Args:
            user_id: User ID
            agent_id: Agent/Bot ID

        Returns:
            Number of deleted memories
        """
        try:
            async with self._semaphore:
                mem = await self.get_mem0(user_id, agent_id, "default")
                try:
                    self._ensure_connection(mem)
                    # Get all memories
                    response = mem.get_all(user_id=user_id)
                    memories = self._extract_memories_from_response(response)

                    # Filter by agent_id and delete matching memories
                    deleted_count = 0
                    for m in memories:
                        if isinstance(m, dict) and self._check_agent_id_match(m, agent_id):
                            memory_id = m.get("id")
                            if memory_id:
                                try:
                                    mem.delete(memory_id=memory_id)
                                    deleted_count += 1
                                except Exception as e:
                                    logger.warning(f"Failed to delete memory {memory_id}: {e}")
                finally:
                    self._release_connection(mem)

            logger.info(f"Deleted {deleted_count} memories for user={user_id}, agent={agent_id}")
            return deleted_count

        except Exception as e:
            logger.error(f"Failed to delete all memories: {e}")
            return 0

    def clear_cache(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> None:
        """Clear cached Mem0 instances, releasing any connection they still hold.

        Args:
            user_id: User ID; clear all if None
            agent_id: Agent ID; clear all if None
        """
        if user_id is None and agent_id is None:
            for instance in self._instances.values():
                self._cleanup_mem0_instance(instance)
            self._instances.clear()
            logger.info("Cleared all Mem0 instances from cache")
            return

        keys_to_remove = []
        for key in self._instances:
            # Key format: "user_id:agent_id" optionally followed by ":<id of the LLM instance>"
            parts = key.split(":")
            if len(parts) < 2:
                continue
            if user_id and parts[0] != user_id:
                continue
            if agent_id and parts[1] != agent_id:
                continue
            keys_to_remove.append(key)

        for key in keys_to_remove:
            # Same explicit release as on LRU eviction and close()
            self._cleanup_mem0_instance(self._instances.pop(key))

        logger.info(f"Cleared {len(keys_to_remove)} Mem0 instances from cache")

    async def close(self) -> None:
        """Close the manager and clean up resources."""
        logger.info("Closing Mem0Manager...")

        # Clean up cached instances and release connections
        for instance in self._instances.values():
            self._cleanup_mem0_instance(instance)
        self._instances.clear()

        # Note: do not close the shared synchronous connection pool, which is managed by DBPoolManager

        self._initialized = False

        logger.info("Mem0Manager closed")
|
|
|
|
|
|
# Process-wide Mem0Manager, created lazily by get_mem0_manager()
_global_manager: Optional[Mem0Manager] = None


def get_mem0_manager() -> Mem0Manager:
    """Return the shared Mem0Manager, creating it on first use.

    Returns:
        Mem0Manager instance
    """
    global _global_manager
    if _global_manager is not None:
        return _global_manager

    _global_manager = Mem0Manager()
    return _global_manager
|
|
|
|
|
|
async def init_global_mem0(
    sync_pool: pool.SimpleConnectionPool,
) -> Mem0Manager:
    """Attach the shared pool to the global Mem0Manager and initialize it.

    Args:
        sync_pool: PostgreSQL synchronous connection pool from DBPoolManager.sync_pool

    Returns:
        Mem0Manager instance
    """
    shared_manager = get_mem0_manager()
    # The singleton is created without a pool; hand it the shared one before initializing.
    shared_manager._sync_pool = sync_pool
    await shared_manager.initialize()
    return shared_manager
|
|
|
|
|
|
async def close_global_mem0() -> None:
    """Close the global Mem0Manager if one has been created."""
    shared_manager = _global_manager
    if shared_manager is None:
        return
    await shared_manager.close()
|