qwen_agent/agent/mem0_manager.py
朱潮 5eb0b7759d 🐛 fix: 修复 Mem0 连接池耗尽问题,改为操作级连接获取/释放
每个缓存的 Mem0 实例长期持有数据库连接导致并发时连接池耗尽。
改为每次操作前从池中获取连接、操作后立即释放,并添加 Semaphore 限制并发数。

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-02 17:46:00 +08:00

816 lines
29 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Mem0 连接和实例管理器
负责管理 Mem0 客户端实例的创建、缓存和生命周期
"""
import logging
import asyncio
import threading
import concurrent.futures
from typing import Any, Dict, List, Optional, Literal
from collections import OrderedDict
from embedding.manager import GlobalModelManager, get_model_manager
import json_repair
from psycopg2 import pool
from utils.settings import (
MEM0_POOL_SIZE
)
from .mem0_config import Mem0Config
logger = logging.getLogger("app")
# ============================================================================
# 自定义 Embedding 类,使用项目中已有的 GlobalModelManager
# 避免重复加载模型
# ============================================================================
class CustomMem0Embedding:
    """Drop-in embedder for Mem0 that reuses the project's GlobalModelManager.

    Mem0 would otherwise load its own copy of the sentence-transformer;
    routing every embed call through the shared manager keeps a single
    model in memory.
    """

    # Class-level state shared by every instance.
    _model = None             # cached model object once loaded
    _lock = threading.Lock()  # guards the one-time model load
    _executor = None          # lazy single-thread pool for async loading

    def __init__(self, config: Optional[Any] = None):
        """Store a config object; synthesize a minimal stub when absent.

        Mem0's telemetry code reads ``config.embedding_dims``, so a bare
        object carrying that attribute is sufficient.
        """
        self.config = config if config is not None else type(
            'Config', (), {'embedding_dims': 384}
        )()

    @property
    def embedding_dims(self):
        """Dimensionality of the produced vectors (gte-tiny emits 384)."""
        return 384

    def _get_model_sync(self):
        """Fetch the shared model synchronously, never via ``asyncio.run()``.

        Tries the manager's synchronous accessor first. If the model is not
        loaded yet, the async loader is driven on a dedicated worker thread
        with its own event loop (30s timeout) so the caller's loop — if any —
        is never touched.
        """
        manager = get_model_manager()
        ready = manager.get_model_sync()
        if ready is not None:
            CustomMem0Embedding._model = ready
            return ready

        # Model not loaded yet: lazily create the single worker thread.
        if CustomMem0Embedding._executor is None:
            CustomMem0Embedding._executor = concurrent.futures.ThreadPoolExecutor(
                max_workers=1,
                thread_name_prefix="mem0_embed",
            )

        def _load_on_worker():
            # Run the async initialization on a private event loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                return loop.run_until_complete(manager.get_model())
            finally:
                loop.close()

        loaded = CustomMem0Embedding._executor.submit(_load_on_worker).result(timeout=30)
        CustomMem0Embedding._model = loaded
        return loaded

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """Embed *text* (str or list of str) and return vector(s) as lists.

        Args:
            text: input text or batch of texts.
            memory_action: accepted for Mem0 API compatibility; unused.

        Returns:
            list: the embedding vector(s).
        """
        # Double-checked locking: load the model at most once across threads.
        if CustomMem0Embedding._model is None:
            with CustomMem0Embedding._lock:
                if CustomMem0Embedding._model is None:
                    self._get_model_sync()
        return CustomMem0Embedding._model.encode(text, convert_to_numpy=True).tolist()
# Monkey patch target: json_repair-backed replacement for mem0's remove_code_blocks.
def _remove_code_blocks_with_repair(content: str) -> str:
    """Replacement for mem0's ``remove_code_blocks`` backed by ``json_repair``.

    ``json_repair.loads`` handles both concerns at once:
    - stripping fenced code-block markers (```json, ``` ...),
    - repairing broken JSON (trailing commas, comments, single quotes, ...).
    Non-JSON input falls back to a plain fence-stripping regex.
    """
    import re

    stripped = content.strip()

    def _strip_fence(text: str) -> str:
        # Degraded path: peel off a single ```lang ... ``` fence if present.
        fence = re.match(r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$", text)
        return fence.group(1).strip() if fence else text

    try:
        parsed = json_repair.loads(stripped)
        if isinstance(parsed, (dict, list)):
            import json
            return json.dumps(parsed, ensure_ascii=False)
        # json_repair yields "" for non-JSON input: keep the original text.
        if parsed == "" and stripped != "":
            return _strip_fence(stripped)
        return str(parsed)
    except Exception:
        # Parsing failed entirely: degrade to simple fence removal.
        return _strip_fence(stripped)
# Apply the monkey patch (works whether mem0 is imported before or after us).
try:
    import sys
    import mem0.memory.utils as mem0_utils

    mem0_utils.remove_code_blocks = _remove_code_blocks_with_repair
    if 'mem0.memory.main' in sys.modules:
        # mem0.memory.main imported the function by value, so its local
        # reference must be patched as well.
        import mem0.memory.main
        mem0.memory.main.remove_code_blocks = _remove_code_blocks_with_repair
        logger.info("Successfully patched mem0.memory.main.remove_code_blocks with json_repair")
    else:
        logger.info("Successfully patched mem0.memory.utils.remove_code_blocks with json_repair")
except ImportError:
    # mem0 not importable yet — presumably patched on first import; verify.
    pass
except Exception as e:
    logger.warning(f"Failed to patch mem0 remove_code_blocks: {e}")
class Mem0Manager:
    """Connection and instance manager for Mem0.

    Responsibilities:
    1. Create and LRU-cache Mem0 instances.
    2. Isolate tenants by (user_id, agent_id).
    3. Share the synchronous connection pool owned by DBPoolManager.
    4. Expose memory recall/store/delete APIs with per-operation
       connection acquire/release so idle instances never pin connections.
    """

    def __init__(
        self,
        sync_pool: Optional[pool.SimpleConnectionPool] = None,
    ):
        """Initialize the manager.

        Args:
            sync_pool: synchronous PostgreSQL pool (shared with DBPoolManager).
        """
        self._sync_pool = sync_pool
        # LRU cache of Mem0 instances keyed by "user:agent[:llm-id]".
        self._instances: OrderedDict[str, Any] = OrderedDict()
        # Cap the cache at half the pool size, minimum 1.
        # Fix: MEM0_POOL_SIZE / 2 produced a float threshold (and could be
        # < 1 for tiny pools); use integer floor division with a floor of 1.
        self._max_instances = max(MEM0_POOL_SIZE // 2, 1)
        self._initialized = False
        # Bound concurrent Mem0 operations so the pool cannot be exhausted.
        self._semaphore = asyncio.Semaphore(max(MEM0_POOL_SIZE - 2, 1))
async def initialize(self) -> None:
    """One-time setup; idempotent.

    Mem0 creates its own tables lazily, so this only verifies that a
    database pool was supplied. Failures are logged rather than raised so
    the application can still run without memory support.
    """
    if self._initialized:
        return
    logger.info("Initializing Mem0Manager...")
    try:
        if self._sync_pool:
            logger.info("Mem0Manager initialized successfully")
        else:
            logger.warning("No database configuration provided for Mem0")
        self._initialized = True
    except Exception as e:
        logger.error(f"Failed to initialize Mem0Manager: {e}")
def _get_connection_pool(self) -> Optional[pool.SimpleConnectionPool]:
    """Return the shared synchronous pool required by Mem0's PGVector.

    Returns:
        The psycopg2 connection pool, or None when not configured.
    """
    return self._sync_pool
def _cleanup_mem0_instance(self, mem0_instance: Any) -> None:
    """Explicitly hand an instance's DB connection back to the pool.

    Mem0's PGVector grabs a connection at construction and only returns it
    in ``__del__``; since Python GC timing is unpredictable, evicted
    instances must be drained eagerly or the pool runs dry.

    Args:
        mem0_instance: a Mem0 Memory instance.
    """
    try:
        store = getattr(mem0_instance, 'vector_store', None)
        if store is None:
            return
        # PGVector exposes both `conn` and `connection_pool`.
        if not (hasattr(store, 'conn') and hasattr(store, 'connection_pool')):
            return
        if store.conn is None or store.connection_pool is None:
            return
        try:
            # Close the cursor before returning the connection.
            if getattr(store, 'cur', None):
                store.cur.close()
                store.cur = None
            store.connection_pool.putconn(store.conn)
            # Null out so PGVector.__del__ cannot double-release.
            store.conn = None
            logger.debug("Successfully released Mem0 database connection back to pool")
        except Exception as e:
            logger.warning(f"Error releasing Mem0 connection: {e}")
    except Exception as e:
        logger.warning(f"Error cleaning up Mem0 instance: {e}")
def _ensure_connection(self, mem0_instance: Any) -> None:
    """Ensure the instance holds a live DB connection before an operation.

    Re-acquires a connection from the shared pool when a previous
    ``_release_connection`` gave it back.

    Args:
        mem0_instance: a Mem0 Memory instance.

    Raises:
        Exception: re-raised after logging when acquisition fails, since
            the subsequent Mem0 operation could not succeed anyway.
    """
    try:
        if hasattr(mem0_instance, 'vector_store'):
            vs = mem0_instance.vector_store
            if hasattr(vs, 'conn') and vs.conn is None and self._sync_pool:
                conn = self._sync_pool.getconn()
                try:
                    cur = conn.cursor()
                except Exception:
                    # Fix: if cursor creation fails, return the connection
                    # instead of leaking it out of the pool forever.
                    self._sync_pool.putconn(conn)
                    raise
                vs.conn = conn
                vs.cur = cur
                # Keep a pool reference so the connection can be returned later.
                if hasattr(vs, 'connection_pool') and vs.connection_pool is None:
                    vs.connection_pool = self._sync_pool
                logger.debug("Re-acquired Mem0 database connection from pool")
    except Exception as e:
        logger.warning(f"Error ensuring Mem0 connection: {e}")
        raise
def _release_connection(self, mem0_instance: Any) -> None:
    """Return the instance's DB connection to the pool after an operation.

    Unlike ``_cleanup_mem0_instance`` this keeps the ``connection_pool``
    reference intact so ``_ensure_connection`` can re-acquire next time.

    Args:
        mem0_instance: a Mem0 Memory instance.
    """
    try:
        vs = getattr(mem0_instance, 'vector_store', None)
        if vs is None or not hasattr(vs, 'conn') or vs.conn is None:
            return
        # Close the cursor first; the connection is returned afterwards.
        if getattr(vs, 'cur', None):
            vs.cur.close()
            vs.cur = None
        pool_ref = getattr(vs, 'connection_pool', None)
        if pool_ref is not None:
            pool_ref.putconn(vs.conn)
            vs.conn = None
            logger.debug("Released Mem0 database connection back to pool")
    except Exception as e:
        logger.warning(f"Error releasing Mem0 connection: {e}")
async def get_mem0(
    self,
    user_id: str,
    agent_id: str,
    session_id: str,
    config: Optional[Mem0Config] = None,
) -> Any:
    """Return a cached Mem0 instance or create one (LRU-cached).

    Args:
        user_id: user ID (Mem0 entity_id).
        agent_id: agent/bot ID (Mem0 process_id).
        session_id: session ID (memories are user-level; kept for API shape).
        config: optional Mem0 configuration.

    Returns:
        A Mem0 Memory instance.
    """
    # Key includes the LLM object's id so different LLMs never share an instance.
    llm_suffix = ""
    if config and config.llm_instance is not None:
        llm_suffix = f":{id(config.llm_instance)}"
    cache_key = f"{user_id}:{agent_id}{llm_suffix}"
    # Cache hit: bump to most-recently-used position.
    if cache_key in self._instances:
        self._instances.move_to_end(cache_key)
        return self._instances[cache_key]
    # Evict the LRU entry when full, draining its DB connection eagerly.
    if len(self._instances) >= self._max_instances:
        removed_key, removed_instance = self._instances.popitem(last=False)
        self._cleanup_mem0_instance(removed_instance)
        logger.debug(f"Mem0 instance cache full, removed and cleaned: {removed_key}")
    mem0_instance = await self._create_mem0_instance(
        user_id=user_id,
        agent_id=agent_id,
        session_id=session_id,
        config=config,
    )
    # Fix: the await above can interleave with another coroutine creating
    # the same key; prefer the instance that reached the cache first and
    # drain the duplicate instead of silently overwriting it.
    existing = self._instances.get(cache_key)
    if existing is not None:
        self._cleanup_mem0_instance(mem0_instance)
        self._instances.move_to_end(cache_key)
        return existing
    self._instances[cache_key] = mem0_instance
    return mem0_instance
async def _create_mem0_instance(
    self,
    user_id: str,
    agent_id: str,
    session_id: str,
    config: Optional[Mem0Config] = None,
) -> Any:
    """Build a fresh Mem0 Memory instance wired to the shared pool.

    Args:
        user_id: user ID.
        agent_id: agent/bot ID (also namespaces the pgvector collection).
        session_id: session ID (not used by Mem0 itself; kept for API shape).
        config: Mem0 config carrying the LLM instance and custom prompts.

    Returns:
        A configured Mem0 Memory instance whose DB connection has already
        been released back to the pool.

    Raises:
        RuntimeError: if the mem0 package is not installed.
        ValueError: if no database connection pool is available.
    """
    try:
        from mem0 import Memory
    except ImportError:
        logger.error("mem0 package not installed")
        raise RuntimeError("mem0 package is required but not installed")
    # Shared synchronous pool (required by Mem0's pgvector store).
    connection_pool = self._get_connection_pool()
    if not connection_pool:
        raise ValueError("Database connection pool not available")
    # Custom embedder reusing the already-loaded shared model.
    custom_embedder = CustomMem0Embedding()
    # Configure Mem0 for pgvector. The huggingface_base_url below is a
    # dummy so HuggingFaceEmbedding never loads a SentenceTransformer;
    # the embedder object is swapped for custom_embedder right after
    # construction.
    config_dict = {
        "vector_store": {
            "provider": "pgvector",
            "config": {
                "connection_pool": connection_pool,
                "collection_name": f"mem0_{agent_id}".replace("-", "_")[:50],  # isolate per agent_id
                "embedding_model_dims": 384,  # dims of paraphrase-multilingual-MiniLM-L12-v2
            }
        },
        # Dummy HF endpoint; replaced by custom_embedder below.
        "embedder": {
            "provider": "huggingface",
            "config": {
                "huggingface_base_url": "http://dummy-url-that-will-be-replaced",
                "api_key": "dummy-key"  # placeholder so OpenAI client validation passes
            }
        }
    }
    # Optional custom fact-extraction prompt (when a config is provided).
    if config is not None:
        config_dict["custom_fact_extraction_prompt"] = await config.get_custom_fact_extraction_prompt_async()
    # Prefer a caller-supplied LangChain LLM for fact extraction.
    if config and config.llm_instance is not None:
        config_dict["llm"] = {
            "provider": "langchain",
            "config": {"model": config.llm_instance}
        }
        logger.info(
            f"Configured LangChain LLM for Mem0: {type(config.llm_instance).__name__}"
        )
    else:
        # No LLM supplied: fall back to an OpenAI-compatible backend.
        # Mem0 uses this LLM to extract memory facts.
        from utils.settings import MASTERKEY, BACKEND_HOST
        import os
        llm_api_key = os.environ.get("OPENAI_API_KEY", "") or MASTERKEY
        config_dict["llm"] = {
            "provider": "openai",
            "config": {
                "model": "gpt-4o-mini",
                "api_key": llm_api_key,
                "openai_base_url": BACKEND_HOST  # custom backend host
            }
        }
    # Build the Mem0 instance.
    mem = Memory.from_config(config_dict)
    logger.debug(f"Original embedder type: {type(mem.embedding_model).__name__}")
    logger.debug(f"Original embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}")
    # Swap in the shared-model embedder so Mem0 never loads its own copy.
    mem.embedding_model = custom_embedder
    logger.debug(f"Replaced embedder type: {type(mem.embedding_model).__name__}")
    logger.debug(f"Replaced embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}")
    logger.info("Replaced Mem0 embedder with CustomMem0Embedding (reusing existing model)")
    logger.info(
        f"Created Mem0 instance: user={user_id}, agent={agent_id}"
    )
    # PGVector grabbed a connection in its constructor; release it now so
    # idle cached instances do not pin pool connections long-term.
    self._release_connection(mem)
    return mem
async def recall_memories(
    self,
    query: str,
    user_id: str,
    agent_id: str,
    config: Optional[Mem0Config] = None,
) -> List[Dict[str, Any]]:
    """Semantic recall of user-level memories (shared across sessions).

    Args:
        query: free-text query.
        user_id: user ID.
        agent_id: agent/bot ID (passed to Mem0 as a filter).
        config: optional Mem0 config (controls top-k).

    Returns:
        List of dicts with ``content``, ``similarity``, ``metadata`` and
        ``fact_type`` keys; empty list on any failure.
    """
    try:
        # Bound concurrency so parallel operations cannot drain the pool.
        async with self._semaphore:
            mem = await self.get_mem0(user_id, agent_id, "default", config)
            self._ensure_connection(mem)
            try:
                # Semantic search, filtered by agent_id.
                limit = config.semantic_search_top_k if config else 20
                results = mem.search(
                    query=query,
                    limit=limit,
                    user_id=user_id,
                    agent_id=agent_id,
                )
            finally:
                self._release_connection(mem)
        # Fix: results["results"] assumed the new dict response shape only;
        # reuse the shared normalizer so legacy list responses work too,
        # consistently with get_all_memories/delete_memory.
        memories = []
        for result in self._extract_memories_from_response(results):
            # Guard against a present-but-None metadata field.
            result_metadata = result.get("metadata", {}) or {}
            memories.append({
                "content": result.get("memory", ""),
                "similarity": result.get("score", 0.0),
                "metadata": result_metadata,
                "fact_type": result_metadata.get("category", "fact"),
            })
        logger.info(f"Recalled {len(memories)} memories for user={user_id}, query: {query[:50]}...")
        return memories
    except Exception as e:
        logger.error(f"Failed to recall memories: {e}")
        return []
async def add_memory(
    self,
    text: str,
    user_id: str,
    agent_id: str,
    metadata: Optional[Dict[str, Any]] = None,
    config: Optional[Mem0Config] = None,
) -> Dict[str, Any]:
    """Add a new memory (user-level, shared across sessions).

    Args:
        text: memory text to ingest.
        user_id: user ID.
        agent_id: agent/bot ID.
        metadata: extra metadata stored alongside the memory.
        config: Mem0 config (carries the LLM instance for fact extraction).

    Returns:
        Mem0's add() result; empty dict when the operation fails.
    """
    try:
        # Bound concurrency so parallel operations cannot drain the pool.
        async with self._semaphore:
            mem = await self.get_mem0(user_id, agent_id, "default", config)
            # Acquire a DB connection for this operation only.
            self._ensure_connection(mem)
            try:
                # Store the memory, namespaced by user and agent.
                result = mem.add(
                    text,
                    user_id=user_id,
                    agent_id=agent_id,
                    metadata=metadata or {}
                )
            finally:
                self._release_connection(mem)
            logger.info(f"Added memory for user={user_id}, agent={agent_id}: {result}")
            return result
    except Exception as e:
        logger.error(f"Failed to add memory: {e}")
        return {}
def _extract_memories_from_response(self, response: Any) -> List[Dict[str, Any]]:
    """Normalize a Mem0 ``get_all()`` response into a plain list.

    Newer Mem0 returns ``{"results": [...]}``; older versions return the
    list directly. Any other shape is logged and treated as empty.

    Args:
        response: raw value returned by Mem0's get_all().

    Returns:
        List of memory dicts (possibly empty).
    """
    if isinstance(response, list):
        return response
    if isinstance(response, dict) and "results" in response:
        return response["results"]
    logger.warning(f"Unexpected response format from mem.get_all(): {type(response)}")
    return []
def _check_agent_id_match(self, memory: Dict[str, Any], agent_id: str) -> bool:
    """Return True when *memory* belongs to *agent_id*.

    The agent ID may live at the top level (newer Mem0) or inside
    ``metadata`` (older Mem0); both locations are checked.

    Args:
        memory: memory dict as returned by Mem0.
        agent_id: agent ID to match against.

    Returns:
        True on a match in either location, else False.
    """
    if not isinstance(memory, dict):
        return False
    # Top-level field first (newer format).
    if memory.get("agent_id") == agent_id:
        return True
    # Legacy format keeps it under metadata.
    meta = memory.get("metadata", {})
    return isinstance(meta, dict) and meta.get("agent_id") == agent_id
async def get_all_memories(
    self,
    user_id: str,
    agent_id: str,
) -> List[Dict[str, Any]]:
    """List every memory the user has under the given agent.

    Args:
        user_id: user ID.
        agent_id: agent/bot ID used to filter results.

    Returns:
        Filtered list of memory dicts; empty list on failure.
    """
    try:
        async with self._semaphore:
            mem = await self.get_mem0(user_id, agent_id, "default")
            self._ensure_connection(mem)
            try:
                response = mem.get_all(user_id=user_id)
            finally:
                self._release_connection(mem)
        all_memories = self._extract_memories_from_response(response)
        # get_all() is user-scoped only; keep just this agent's records.
        return [m for m in all_memories if self._check_agent_id_match(m, agent_id)]
    except Exception as e:
        logger.error(f"Failed to get all memories: {e}")
        return []
async def delete_memory(
    self,
    memory_id: str,
    user_id: str,
    agent_id: str,
) -> bool:
    """Delete a single memory after verifying ownership.

    Args:
        memory_id: ID of the memory to delete.
        user_id: user ID (must own the memory).
        agent_id: agent/bot ID (must match the memory's agent).

    Returns:
        True if the memory was found, owned, and deleted; False otherwise.
    """
    try:
        async with self._semaphore:
            mem = await self.get_mem0(user_id, agent_id, "default")
            self._ensure_connection(mem)
            try:
                # Fetch the user's memories first so ownership can be
                # verified before issuing the delete.
                response = mem.get_all(user_id=user_id)
                memories = self._extract_memories_from_response(response)
                target_memory = None
                for m in memories:
                    if isinstance(m, dict) and m.get("id") == memory_id:
                        # The agent must match as well.
                        if self._check_agent_id_match(m, agent_id):
                            target_memory = m
                            break
                if not target_memory:
                    logger.warning(f"Memory {memory_id} not found or access denied for user={user_id}, agent={agent_id}")
                    return False
                # Ownership confirmed; delete the record.
                mem.delete(memory_id=memory_id)
            finally:
                self._release_connection(mem)
        logger.info(f"Deleted memory {memory_id} for user={user_id}, agent={agent_id}")
        return True
    except Exception as e:
        logger.error(f"Failed to delete memory {memory_id}: {e}")
        return False
async def delete_all_memories(
    self,
    user_id: str,
    agent_id: str,
) -> int:
    """Delete every memory the user has under the given agent.

    Args:
        user_id: user ID.
        agent_id: agent/bot ID.

    Returns:
        Number of memories actually deleted (0 on failure).
    """
    try:
        async with self._semaphore:
            mem = await self.get_mem0(user_id, agent_id, "default")
            self._ensure_connection(mem)
            deleted_count = 0
            try:
                response = mem.get_all(user_id=user_id)
                candidates = self._extract_memories_from_response(response)
                for record in candidates:
                    # Skip anything that is not this agent's record.
                    if not (isinstance(record, dict) and self._check_agent_id_match(record, agent_id)):
                        continue
                    memory_id = record.get("id")
                    if not memory_id:
                        continue
                    try:
                        mem.delete(memory_id=memory_id)
                        deleted_count += 1
                    except Exception as e:
                        # Best-effort: keep deleting the rest.
                        logger.warning(f"Failed to delete memory {memory_id}: {e}")
            finally:
                self._release_connection(mem)
        logger.info(f"Deleted {deleted_count} memories for user={user_id}, agent={agent_id}")
        return deleted_count
    except Exception as e:
        logger.error(f"Failed to delete all memories: {e}")
        return 0
def clear_cache(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> None:
    """Evict cached Mem0 instances, releasing their DB connections.

    Args:
        user_id: if given, only evict entries for this user.
        agent_id: if given, only evict entries for this agent.
            (Both None clears the whole cache.)
    """
    if user_id is None and agent_id is None:
        # Fix: evicted instances were previously never drained, leaking
        # pooled connections until GC happened to run __del__. Mirror the
        # eviction path in get_mem0()/close() and release eagerly.
        for instance in self._instances.values():
            self._cleanup_mem0_instance(instance)
        self._instances.clear()
        logger.info("Cleared all Mem0 instances from cache")
    else:
        keys_to_remove = []
        for key in self._instances:
            # Key format: "user_id:agent_id" or "user_id:agent_id:llm_id".
            parts = key.split(":")
            if len(parts) < 2:
                continue
            if user_id and parts[0] != user_id:
                continue
            if agent_id and parts[1] != agent_id:
                continue
            keys_to_remove.append(key)
        for key in keys_to_remove:
            # Same fix: drain the connection before dropping the reference.
            self._cleanup_mem0_instance(self._instances[key])
            del self._instances[key]
        logger.info(f"Cleared {len(keys_to_remove)} Mem0 instances from cache")
async def close(self) -> None:
    """Shut down the manager: drain cached instances and reset state.

    The shared sync pool itself is owned by DBPoolManager and is left open.
    """
    logger.info("Closing Mem0Manager...")
    # Release every cached instance's pooled connection.
    for _key, cached in self._instances.items():
        self._cleanup_mem0_instance(cached)
    self._instances.clear()
    self._initialized = False
    logger.info("Mem0Manager closed")
# Module-level singleton holder.
_global_manager: Optional[Mem0Manager] = None


def get_mem0_manager() -> Mem0Manager:
    """Return the process-wide Mem0Manager, creating it on first use.

    Returns:
        The singleton Mem0Manager instance.
    """
    global _global_manager
    if _global_manager is None:
        _global_manager = Mem0Manager()
    return _global_manager
async def init_global_mem0(
    sync_pool: pool.SimpleConnectionPool,
) -> Mem0Manager:
    """Wire the shared sync pool into the global manager and initialize it.

    Args:
        sync_pool: PostgreSQL pool obtained from DBPoolManager.sync_pool.

    Returns:
        The initialized global Mem0Manager.
    """
    manager = get_mem0_manager()
    manager._sync_pool = sync_pool
    await manager.initialize()
    return manager
async def close_global_mem0() -> None:
    """Close the global Mem0Manager if one was ever created."""
    global _global_manager
    if _global_manager is None:
        return
    await _global_manager.close()