qwen_agent/agent/mem0_manager.py
朱潮 425f3c5bb4 chore: replace Chinese comments and log messages with English
Convert all Chinese comments, docstrings, logger/print output,
HTTPException detail messages, and API response messages to English
across the entire codebase. Functional zh/ja localized strings
(e.g. prompt templates, timezone display names, date formats) are
preserved as-is.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-30 19:45:35 +08:00

817 lines
30 KiB
Python

"""
Mem0 connection and instance manager.
Responsible for creating, caching, and managing the lifecycle of Mem0 client instances.
"""
import logging
import asyncio
import threading
import concurrent.futures
from typing import Any, Dict, List, Optional, Literal
from collections import OrderedDict
from embedding.manager import GlobalModelManager, get_model_manager
import json_repair
from psycopg2 import pool
from utils.settings import (
MEM0_POOL_SIZE
)
from .mem0_config import Mem0Config
logger = logging.getLogger("app")
# ============================================================================
# Custom embedding class that uses the project's existing GlobalModelManager
# Avoid loading the same model more than once
# ============================================================================


class CustomMem0Embedding:
    """
    Mem0-compatible embedder backed by the project's GlobalModelManager.

    The model is fetched once and stored on the class, so every instance (one per
    Mem0 Memory object) shares it instead of Mem0 loading its own copy.
    """

    _model = None  # Class-level cache of the loaded model, shared by all instances
    _lock = threading.Lock()  # Guards the one-time model load performed in embed()
    _executor = None  # Single-worker pool used to run the async model loader

    def __init__(self, config: Optional[Any] = None):
        """Initialize the embedder.

        Args:
            config: Optional Mem0 embedder config. When omitted, a minimal object
                exposing ``embedding_dims = 384`` is created so Mem0's telemetry
                code can read ``self.config``.
        """
        if config is None:
            config = type('Config', (), {'embedding_dims': 384})()
        self.config = config

    @property
    def embedding_dims(self):
        """Return the embedding dimension (384)."""
        return 384  # Dimension of gte-tiny

    def _get_model_sync(self):
        """Load the shared model synchronously and cache it on the class.

        First asks the manager for an already-loaded model. If none is loaded yet,
        the async loader runs on a dedicated worker thread with its own event loop.
        This is because the caller's thread may already be running an event loop,
        in which case asyncio.run() cannot be used there.

        Returns:
            The loaded embedding model.

        Raises:
            RuntimeError: If the manager cannot provide a model.
            concurrent.futures.TimeoutError: If the async load takes longer than 30s.
        """
        manager = get_model_manager()
        model = manager.get_model_sync()
        if model is None:
            if CustomMem0Embedding._executor is None:
                CustomMem0Embedding._executor = concurrent.futures.ThreadPoolExecutor(
                    max_workers=1,
                    thread_name_prefix="mem0_embed",
                )

            def run_async_in_thread():
                # The worker thread has no running loop, so asyncio.run is safe here
                return asyncio.run(manager.get_model())

            future = CustomMem0Embedding._executor.submit(run_async_in_thread)
            model = future.result(timeout=30)  # 30-second timeout
        if model is None:
            # Do not cache None: fail loudly now and retry the load on the next call
            raise RuntimeError("Embedding model is not available from GlobalModelManager")
        CustomMem0Embedding._model = model
        return model

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding vector for text (synchronous method for Mem0).

        Args:
            text: Text to embed (string or list)
            memory_action: Memory operation type (add/search/update), currently unused

        Returns:
            list: Embedding vector (``.tolist()`` of the model's numpy output)
        """
        model = CustomMem0Embedding._model
        if model is None:
            # Double-checked under the lock so only one thread performs the load
            with CustomMem0Embedding._lock:
                model = CustomMem0Embedding._model
                if model is None:
                    model = self._get_model_sync()
        embeddings = model.encode(text, convert_to_numpy=True)
        return embeddings.tolist()
# Monkey patch: replace mem0's remove_code_blocks with json_repair
def _strip_code_fence(content: str) -> str:
    """Return the body of a fenced code block that spans all of *content*, else *content*.

    Only a fence wrapping the whole string (```lang\\n...\\n```) is removed.
    """
    import re

    match = re.match(r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$", content)
    if match:
        return match.group(1).strip()
    return content


def _remove_code_blocks_with_repair(content: str) -> str:
    """
    Replace mem0's remove_code_blocks function with json_repair.

    json_repair.loads automatically handles:
    - Removing code block markers (```json, ``` and similar)
    - Repairing malformed JSON (such as trailing commas, comments, and single quotes)

    Args:
        content: Raw LLM output.

    Returns:
        Valid JSON text when json_repair yields a dict/list/number/bool/null.
        The decoded string when it yields a non-empty string. Otherwise, the
        stripped input with a surrounding code fence removed.
    """
    import json

    content_stripped = content.strip()
    try:
        result = json_repair.loads(content_stripped)
        if isinstance(result, str):
            # json_repair returns "" for non-JSON input; fall back to fence stripping
            if result == "" and content_stripped != "":
                return _strip_code_fence(content_stripped)
            return result
        # dict/list/number/bool/None: serialize back to JSON (str() would give "True"/"None")
        return json.dumps(result, ensure_ascii=False)
    except Exception:
        # If parsing fails, try simple code block removal as a fallback
        return _strip_code_fence(content_stripped)
# Install the json_repair-based remove_code_blocks into mem0 when this module is imported.
try:
    import sys
    import mem0.memory.utils as mem0_utils

    mem0_utils.remove_code_blocks = _remove_code_blocks_with_repair
    if 'mem0.memory.main' not in sys.modules:
        logger.info("Successfully patched mem0.memory.utils.remove_code_blocks with json_repair")
    else:
        # mem0.memory.main keeps its own module-level reference to remove_code_blocks,
        # so replacing the utils attribute alone would not reach it
        import mem0.memory.main
        mem0.memory.main.remove_code_blocks = _remove_code_blocks_with_repair
        logger.info("Successfully patched mem0.memory.main.remove_code_blocks with json_repair")
except ImportError:
    # mem0 (or one of its submodules) could not be imported, so nothing was patched.
    # This block only runs once, at import time of this module; it is not retried later.
    pass
except Exception as e:
    logger.warning(f"Failed to patch mem0 remove_code_blocks: {e}")
class Mem0Manager:
    """
    Mem0 connection and instance manager

    Main responsibilities:
    1. Manage creation and caching of Mem0 instances (bounded LRU cache)
    2. Support multi-tenant isolation (user_id + agent_id)
    3. Use a shared synchronous connection pool (provided by DBPoolManager)
    4. Provide memory recall and storage interfaces

    A database connection is attached to a Mem0 instance only for the duration of
    one operation. Each operation calls _ensure_connection and then
    _release_connection inside the same try/finally.
    """

    def __init__(
        self,
        sync_pool: Optional[pool.SimpleConnectionPool] = None,
    ):
        """Initialize Mem0Manager.

        Args:
            sync_pool: Shared PostgreSQL synchronous connection pool from DBPoolManager
        """
        self._sync_pool = sync_pool
        # LRU cache of Mem0 instances keyed by "user_id:agent_id[:llm_instance_id]"
        self._instances: OrderedDict[str, Any] = OrderedDict()
        # Maximum number of cached instances: ceil(MEM0_POOL_SIZE / 2), at least 1.
        # Kept an integer and never 0 so eviction never pops from an empty cache.
        self._max_instances = max((MEM0_POOL_SIZE + 1) // 2, 1)
        self._initialized = False
        # Limit concurrent Mem0 operations to avoid exhausting the connection pool
        self._semaphore = asyncio.Semaphore(max(MEM0_POOL_SIZE - 2, 1))

    async def initialize(self) -> None:
        """Initialize Mem0Manager.

        Mem0 creates its schema itself; this only reports whether a pool is configured.
        Errors are logged, never raised, so the system can run without Mem0.
        """
        if self._initialized:
            return
        logger.info("Initializing Mem0Manager...")
        try:
            if self._sync_pool:
                logger.info("Mem0Manager initialized successfully")
            else:
                logger.warning("No database configuration provided for Mem0")
            self._initialized = True
        except Exception as e:
            logger.error(f"Failed to initialize Mem0Manager: {e}")

    def _get_connection_pool(self) -> Optional[pool.SimpleConnectionPool]:
        """Get the synchronous database connection pool required by Mem0.

        Returns:
            psycopg2.pool connection pool, or None if none was configured
        """
        return self._sync_pool

    @staticmethod
    def _return_connection(vector_store: Any) -> bool:
        """Close the store's cursor and hand its connection back to its pool.

        Acts only when the store holds both a connection (``conn``) and the pool it
        belongs to (``connection_pool``). A failing ``cursor.close()`` is logged and
        does not stop the connection from being returned, so a broken cursor cannot
        leak a pooled connection.

        Args:
            vector_store: Mem0 vector store object (PGVector)

        Returns:
            True if a connection was returned to the pool, False if there was nothing to return.
        """
        conn = getattr(vector_store, 'conn', None)
        conn_pool = getattr(vector_store, 'connection_pool', None)
        if conn is None or conn_pool is None:
            return False
        cursor = getattr(vector_store, 'cur', None)
        if cursor:
            try:
                cursor.close()
            except Exception as e:
                logger.debug(f"Ignoring error while closing Mem0 cursor: {e}")
            vector_store.cur = None
        conn_pool.putconn(conn)
        # Mark as released so neither __del__ nor a later release returns it again
        vector_store.conn = None
        return True

    def _cleanup_mem0_instance(self, mem0_instance: Any) -> None:
        """Release the database connection of a Mem0 instance that is being discarded.

        Mem0's PGVector implementation acquires and holds a connection during
        initialization, only returning it in __del__. Python GC does not guarantee
        __del__ will run immediately, which can exhaust the connection pool. This
        method releases the connection explicitly. Errors are logged, not raised.

        Args:
            mem0_instance: Mem0 Memory instance
        """
        try:
            vector_store = getattr(mem0_instance, 'vector_store', None)
            if vector_store is not None and self._return_connection(vector_store):
                logger.debug("Successfully released Mem0 database connection back to pool")
        except Exception as e:
            logger.warning(f"Error cleaning up Mem0 instance: {e}")

    def _ensure_connection(self, mem0_instance: Any) -> None:
        """Ensure a Mem0 instance has a database connection before use.

        If the connection was released by _release_connection, reacquire one from
        the shared pool. Call this inside the try whose finally runs
        _release_connection, so a partially acquired connection is still returned.

        Args:
            mem0_instance: Mem0 Memory instance

        Raises:
            Exception: Whatever the pool raises when getconn()/cursor() fail (re-raised after logging)
        """
        try:
            vs = getattr(mem0_instance, 'vector_store', None)
            if vs is None or not hasattr(vs, 'conn') or vs.conn is not None or not self._sync_pool:
                return
            vs.conn = self._sync_pool.getconn()
            # Record the pool before opening the cursor, so that if cursor() fails
            # _release_connection can still hand the connection back
            if hasattr(vs, 'connection_pool') and vs.connection_pool is None:
                vs.connection_pool = self._sync_pool
            vs.cur = vs.conn.cursor()
            logger.debug("Re-acquired Mem0 database connection from pool")
        except Exception as e:
            logger.warning(f"Error ensuring Mem0 connection: {e}")
            raise

    def _release_connection(self, mem0_instance: Any) -> None:
        """Return the instance's connection to the pool after an operation.

        The ``connection_pool`` reference is kept so _ensure_connection can
        reacquire a connection for the next operation. Errors are logged, not raised.

        Args:
            mem0_instance: Mem0 Memory instance
        """
        try:
            vector_store = getattr(mem0_instance, 'vector_store', None)
            if vector_store is not None and self._return_connection(vector_store):
                logger.debug("Released Mem0 database connection back to pool")
        except Exception as e:
            logger.warning(f"Error releasing Mem0 connection: {e}")

    async def get_mem0(
        self,
        user_id: str,
        agent_id: str,
        session_id: str,
        config: Optional[Mem0Config] = None,
    ) -> Any:
        """Get or create a Mem0 instance.

        Instances are cached per user_id/agent_id (plus the identity of the LLM
        instance, if any). session_id is not part of the cache key.

        Args:
            user_id: User ID (mapped to entity_id)
            agent_id: Agent/Bot ID (mapped to process_id)
            session_id: Session ID
            config: Mem0 configuration

        Returns:
            Mem0 instance
        """
        # The cache key includes the LLM instance ID so different LLMs use different instances
        llm_suffix = ""
        if config and config.llm_instance is not None:
            llm_suffix = f":{id(config.llm_instance)}"
        cache_key = f"{user_id}:{agent_id}{llm_suffix}"

        # Cache hit: move to the end to mark it as most recently used
        if cache_key in self._instances:
            self._instances.move_to_end(cache_key)
            return self._instances[cache_key]

        # Create first, so a failed creation does not evict a valid entry
        mem0_instance = await self._create_mem0_instance(
            user_id=user_id,
            agent_id=agent_id,
            session_id=session_id,
            config=config,
        )

        # Another coroutine may have created the same instance while we awaited;
        # keep the cached one and discard the duplicate
        if cache_key in self._instances:
            self._cleanup_mem0_instance(mem0_instance)
            self._instances.move_to_end(cache_key)
            return self._instances[cache_key]

        # Evict least recently used entries until there is room for the new one
        while self._instances and len(self._instances) >= self._max_instances:
            removed_key, removed_instance = self._instances.popitem(last=False)
            # Explicitly release the connection instead of waiting for GC and risking pool exhaustion
            self._cleanup_mem0_instance(removed_instance)
            logger.debug(f"Mem0 instance cache full, removed and cleaned: {removed_key}")

        self._instances[cache_key] = mem0_instance
        return mem0_instance

    async def _create_mem0_instance(
        self,
        user_id: str,
        agent_id: str,
        session_id: str,
        config: Optional[Mem0Config] = None,
    ) -> Any:
        """Create a new Mem0 instance.

        Args:
            user_id: User ID
            agent_id: Agent/Bot ID
            session_id: Session ID
            config: Mem0 configuration, including the LLM instance

        Returns:
            Mem0 Memory instance

        Raises:
            RuntimeError: If the mem0 package is not installed
            ValueError: If no database connection pool is configured
        """
        import re

        try:
            from mem0 import Memory
        except ImportError:
            logger.error("mem0 package not installed")
            raise RuntimeError("mem0 package is required but not installed")

        connection_pool = self._get_connection_pool()
        if not connection_pool:
            raise ValueError("Database connection pool not available")

        # Create a custom embedder that reuses the shared model to avoid duplicate loading
        custom_embedder = CustomMem0Embedding()

        # One collection per agent. The name is used by the pgvector store as a table
        # name (presumably interpolated into SQL; verify against mem0), so keep only
        # identifier characters. Hyphenated IDs map exactly as before ("-" -> "_").
        collection_name = re.sub(r"\W", "_", f"mem0_{agent_id}")[:50]

        # Configure Mem0 to use Pgvector. The huggingface embedder is given a dummy
        # base_url so it does not load SentenceTransformer; it is replaced below.
        config_dict = {
            "vector_store": {
                "provider": "pgvector",
                "config": {
                    "connection_pool": connection_pool,
                    "collection_name": collection_name,
                    "embedding_model_dims": 384,  # Must match CustomMem0Embedding.embedding_dims
                }
            },
            "embedder": {
                "provider": "huggingface",
                "config": {
                    "huggingface_base_url": "http://dummy-url-that-will-be-replaced",
                    "api_key": "dummy-key"  # Placeholder to prevent OpenAI client validation failure
                }
            }
        }

        # Add a custom memory extraction prompt if config is provided
        if config is not None:
            config_dict["custom_fact_extraction_prompt"] = await config.get_custom_fact_extraction_prompt_async()

        if config and config.llm_instance is not None:
            config_dict["llm"] = {
                "provider": "langchain",
                "config": {"model": config.llm_instance}
            }
            logger.info(
                f"Configured LangChain LLM for Mem0: {type(config.llm_instance).__name__}"
            )
        else:
            # No LLM provided: Mem0 extracts memory facts with the default OpenAI-compatible backend
            from utils.settings import MASTERKEY, BACKEND_HOST
            import os
            llm_api_key = os.environ.get("OPENAI_API_KEY", "") or MASTERKEY
            config_dict["llm"] = {
                "provider": "openai",
                "config": {
                    "model": "gpt-4o-mini",
                    "api_key": llm_api_key,
                    "openai_base_url": BACKEND_HOST  # Use the custom backend
                }
            }

        mem = Memory.from_config(config_dict)
        logger.debug(f"Original embedder type: {type(mem.embedding_model).__name__}")
        logger.debug(f"Original embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}")
        # Replace with the custom embedder to reuse the model already loaded by the project
        mem.embedding_model = custom_embedder
        logger.debug(f"Replaced embedder type: {type(mem.embedding_model).__name__}")
        logger.debug(f"Replaced embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}")
        logger.info("Replaced Mem0 embedder with CustomMem0Embedding (reusing existing model)")
        logger.info(
            f"Created Mem0 instance: user={user_id}, agent={agent_id}"
        )
        # PGVector calls getconn during creation; release immediately to avoid holding the connection long-term
        self._release_connection(mem)
        return mem

    @staticmethod
    def _normalize_search_result(result: Any) -> Dict[str, Any]:
        """Convert one Mem0 search hit into {content, similarity, metadata, fact_type}.

        Plain-string hits become the content with similarity 0.0. A missing, None,
        or non-dict ``metadata`` is treated as empty.
        """
        if not isinstance(result, dict):
            return {"content": str(result), "similarity": 0.0, "metadata": {}, "fact_type": "fact"}
        result_metadata = result.get("metadata")
        if not isinstance(result_metadata, dict):
            result_metadata = {}
        return {
            "content": result.get("memory", ""),
            "similarity": result.get("score", 0.0),
            "metadata": result_metadata,
            "fact_type": result_metadata.get("category", "fact"),
        }

    async def recall_memories(
        self,
        query: str,
        user_id: str,
        agent_id: str,
        config: Optional[Mem0Config] = None,
    ) -> List[Dict[str, Any]]:
        """Recall relevant memories at the user level, shared across sessions.

        Args:
            query: Query text
            user_id: User ID
            agent_id: Agent/Bot ID
            config: Mem0 configuration (semantic_search_top_k sets the limit; default 20)

        Returns:
            List of memories with content/similarity/metadata/fact_type; [] on error
        """
        try:
            async with self._semaphore:
                mem = await self.get_mem0(user_id, agent_id, "default", config)
                try:
                    # Inside the try so a partially acquired connection is still released
                    self._ensure_connection(mem)
                    limit = config.semantic_search_top_k if config else 20
                    results = mem.search(
                        query=query,
                        limit=limit,
                        user_id=user_id,
                        agent_id=agent_id,
                    )
                finally:
                    self._release_connection(mem)

            # Accept both {"results": [...]} and a bare list, as for get_all()
            hits = results.get("results", []) if isinstance(results, dict) else results
            if not isinstance(hits, list):
                hits = []
            memories = [self._normalize_search_result(hit) for hit in hits]
            logger.info(f"Recalled {len(memories)} memories for user={user_id}, query: {query[:50]}...")
            return memories
        except Exception as e:
            logger.error(f"Failed to recall memories: {e}")
            return []

    async def add_memory(
        self,
        text: str,
        user_id: str,
        agent_id: str,
        metadata: Optional[Dict[str, Any]] = None,
        config: Optional[Mem0Config] = None,
    ) -> Dict[str, Any]:
        """Add a new memory at the user level, shared across sessions.

        Args:
            text: Memory text
            user_id: User ID
            agent_id: Agent/Bot ID
            metadata: Additional metadata
            config: Mem0 configuration, including the LLM instance for memory extraction

        Returns:
            Result of mem.add(), or {} on error
        """
        try:
            async with self._semaphore:
                mem = await self.get_mem0(user_id, agent_id, "default", config)
                try:
                    self._ensure_connection(mem)
                    result = mem.add(
                        text,
                        user_id=user_id,
                        agent_id=agent_id,
                        metadata=metadata or {}
                    )
                finally:
                    self._release_connection(mem)
            logger.info(f"Added memory for user={user_id}, agent={agent_id}: {result}")
            return result
        except Exception as e:
            logger.error(f"Failed to add memory: {e}")
            return {}

    def _extract_memories_from_response(self, response: Any) -> List[Dict[str, Any]]:
        """Extract the memory list from a Mem0 get_all() response.

        Mem0 get_all() may return one of two formats:
        1. New version: {"results": [...]}
        2. Old version: a list directly

        Args:
            response: Response returned by Mem0 get_all()

        Returns:
            List of memories ([] for an unexpected format)
        """
        if isinstance(response, dict) and "results" in response:
            return response["results"]
        elif isinstance(response, list):
            return response
        else:
            logger.warning(f"Unexpected response format from mem.get_all(): {type(response)}")
            return []

    def _check_agent_id_match(self, memory: Dict[str, Any], agent_id: str) -> bool:
        """Check whether a memory belongs to the specified agent.

        agent_id may appear at the top level (new format) or inside
        memory["metadata"] (old format).

        Args:
            memory: Memory dictionary
            agent_id: Agent ID to match

        Returns:
            Whether it matches
        """
        if not isinstance(memory, dict):
            return False
        if memory.get("agent_id") == agent_id:
            return True
        metadata = memory.get("metadata", {})
        return isinstance(metadata, dict) and metadata.get("agent_id") == agent_id

    async def get_all_memories(
        self,
        user_id: str,
        agent_id: str,
    ) -> List[Dict[str, Any]]:
        """Get all memories for a user, filtered to the given agent.

        Args:
            user_id: User ID
            agent_id: Agent/Bot ID

        Returns:
            List of memories; [] on error
        """
        try:
            async with self._semaphore:
                mem = await self.get_mem0(user_id, agent_id, "default")
                try:
                    self._ensure_connection(mem)
                    response = mem.get_all(user_id=user_id)
                finally:
                    self._release_connection(mem)
            memories = self._extract_memories_from_response(response)
            # Filter client-side so both top-level and metadata agent_id formats match
            return [m for m in memories if self._check_agent_id_match(m, agent_id)]
        except Exception as e:
            logger.error(f"Failed to get all memories: {e}")
            return []

    async def delete_memory(
        self,
        memory_id: str,
        user_id: str,
        agent_id: str,
    ) -> bool:
        """Delete a single memory after verifying it belongs to this user and agent.

        Args:
            memory_id: Memory ID
            user_id: User ID
            agent_id: Agent/Bot ID

        Returns:
            Whether deletion succeeded
        """
        try:
            async with self._semaphore:
                mem = await self.get_mem0(user_id, agent_id, "default")
                try:
                    self._ensure_connection(mem)
                    # First fetch memories to verify ownership
                    response = mem.get_all(user_id=user_id)
                    memories = self._extract_memories_from_response(response)
                    target_memory = next(
                        (
                            m for m in memories
                            if isinstance(m, dict)
                            and m.get("id") == memory_id
                            and self._check_agent_id_match(m, agent_id)
                        ),
                        None,
                    )
                    if not target_memory:
                        logger.warning(f"Memory {memory_id} not found or access denied for user={user_id}, agent={agent_id}")
                        return False
                    mem.delete(memory_id=memory_id)
                finally:
                    self._release_connection(mem)
            logger.info(f"Deleted memory {memory_id} for user={user_id}, agent={agent_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to delete memory {memory_id}: {e}")
            return False

    async def delete_all_memories(
        self,
        user_id: str,
        agent_id: str,
    ) -> int:
        """Delete all memories for a user under the specified agent.

        Individual delete failures are logged and skipped.

        Args:
            user_id: User ID
            agent_id: Agent/Bot ID

        Returns:
            Number of deleted memories (0 on error)
        """
        try:
            async with self._semaphore:
                mem = await self.get_mem0(user_id, agent_id, "default")
                try:
                    self._ensure_connection(mem)
                    response = mem.get_all(user_id=user_id)
                    memories = self._extract_memories_from_response(response)
                    deleted_count = 0
                    for m in memories:
                        if not (isinstance(m, dict) and self._check_agent_id_match(m, agent_id)):
                            continue
                        memory_id = m.get("id")
                        if not memory_id:
                            continue
                        try:
                            mem.delete(memory_id=memory_id)
                            deleted_count += 1
                        except Exception as e:
                            logger.warning(f"Failed to delete memory {memory_id}: {e}")
                finally:
                    self._release_connection(mem)
            logger.info(f"Deleted {deleted_count} memories for user={user_id}, agent={agent_id}")
            return deleted_count
        except Exception as e:
            logger.error(f"Failed to delete all memories: {e}")
            return 0

    def clear_cache(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> None:
        """Clear cached Mem0 instances, releasing their connections like eviction does.

        Args:
            user_id: User ID; clear all if None
            agent_id: Agent ID; clear all if None
        """
        if user_id is None and agent_id is None:
            for instance in self._instances.values():
                self._cleanup_mem0_instance(instance)
            self._instances.clear()
            logger.info("Cleared all Mem0 instances from cache")
            return
        keys_to_remove = []
        for key in self._instances:
            # Key format: "user_id:agent_id" or "user_id:agent_id:llm_instance_id"
            parts = key.split(":")
            if len(parts) < 2:
                continue
            if user_id and parts[0] != user_id:
                continue
            if agent_id and parts[1] != agent_id:
                continue
            keys_to_remove.append(key)
        for key in keys_to_remove:
            self._cleanup_mem0_instance(self._instances.pop(key))
        logger.info(f"Cleared {len(keys_to_remove)} Mem0 instances from cache")

    async def close(self) -> None:
        """Close the manager and release connections held by cached instances."""
        logger.info("Closing Mem0Manager...")
        for instance in self._instances.values():
            self._cleanup_mem0_instance(instance)
        self._instances.clear()
        # Note: do not close the shared synchronous connection pool, which is managed by DBPoolManager
        self._initialized = False
        logger.info("Mem0Manager closed")
# Module-level singleton, created lazily by get_mem0_manager()
_global_manager: Optional[Mem0Manager] = None


def get_mem0_manager() -> Mem0Manager:
    """Return the module-level Mem0Manager, creating it on first use.

    The instance is created without a connection pool; init_global_mem0 attaches one.

    Returns:
        The shared Mem0Manager instance.
    """
    global _global_manager
    if _global_manager is not None:
        return _global_manager
    _global_manager = Mem0Manager()
    return _global_manager
async def init_global_mem0(
    sync_pool: pool.SimpleConnectionPool,
) -> Mem0Manager:
    """Attach *sync_pool* to the global Mem0Manager and initialize it.

    Args:
        sync_pool: PostgreSQL synchronous connection pool from DBPoolManager.sync_pool

    Returns:
        The global Mem0Manager instance.
    """
    shared = get_mem0_manager()
    shared._sync_pool = sync_pool
    await shared.initialize()
    return shared
async def close_global_mem0() -> None:
    """Close the global Mem0Manager if one exists; otherwise do nothing.

    The module-level reference is left in place (not reset to None).
    """
    manager = _global_manager
    if manager is None:
        return
    await manager.close()