🐛 fix: 修复 Mem0 连接池耗尽问题,改为操作级连接获取/释放

每个缓存的 Mem0 实例长期持有数据库连接导致并发时连接池耗尽。
改为每次操作前从池中获取连接、操作后立即释放,并添加 Semaphore 限制并发数。

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
朱潮 2026-04-02 17:46:00 +08:00
parent 7da8466b3d
commit 5eb0b7759d

View File

@ -188,6 +188,9 @@ class Mem0Manager:
self._max_instances = MEM0_POOL_SIZE/2 # 最大缓存实例数 self._max_instances = MEM0_POOL_SIZE/2 # 最大缓存实例数
self._initialized = False self._initialized = False
# 限制并发 Mem0 操作数,防止连接池耗尽
self._semaphore = asyncio.Semaphore(max(MEM0_POOL_SIZE - 2, 1))
async def initialize(self) -> None: async def initialize(self) -> None:
"""初始化 Mem0Manager """初始化 Mem0Manager
@ -234,22 +237,67 @@ class Mem0Manager:
vector_store = mem0_instance.vector_store vector_store = mem0_instance.vector_store
# PGVector 有 conn 和 connection_pool 属性 # PGVector 有 conn 和 connection_pool 属性
if hasattr(vector_store, 'conn') and hasattr(vector_store, 'connection_pool'): if hasattr(vector_store, 'conn') and hasattr(vector_store, 'connection_pool'):
if vector_store.connection_pool is not None: if vector_store.conn is not None and vector_store.connection_pool is not None:
try: try:
# 先关闭游标 # 先关闭游标
if hasattr(vector_store, 'cur') and vector_store.cur: if hasattr(vector_store, 'cur') and vector_store.cur:
vector_store.cur.close() vector_store.cur.close()
vector_store.cur = None
# 归还连接到池 # 归还连接到池
vector_store.connection_pool.putconn(vector_store.conn) vector_store.connection_pool.putconn(vector_store.conn)
# 标记为已清理,防止 __del__ 重复释放 # 标记为已清理,防止 __del__ 重复释放
vector_store.conn = None vector_store.conn = None
vector_store.connection_pool = None
logger.debug("Successfully released Mem0 database connection back to pool") logger.debug("Successfully released Mem0 database connection back to pool")
except Exception as e: except Exception as e:
logger.warning(f"Error releasing Mem0 connection: {e}") logger.warning(f"Error releasing Mem0 connection: {e}")
except Exception as e: except Exception as e:
logger.warning(f"Error cleaning up Mem0 instance: {e}") logger.warning(f"Error cleaning up Mem0 instance: {e}")
def _ensure_connection(self, mem0_instance: Any) -> None:
    """Make sure a Mem0 instance holds a database connection before an operation.

    When the vector store's ``conn`` is ``None`` (for example because
    ``_release_connection`` handed it back), a connection is borrowed from
    ``self._sync_pool`` and a cursor is opened on it. Instances without a
    ``vector_store``/``conn`` attribute, instances that already hold a
    connection, and managers without a sync pool are left untouched.

    The connection and cursor are only published on the vector store once
    both exist. If opening the cursor fails, the borrowed connection is
    returned to the pool so the instance is not left holding a connection
    without a cursor.

    Args:
        mem0_instance: Mem0 Memory instance.

    Raises:
        Exception: Any error raised while acquiring the connection or the
            cursor is logged and re-raised.
    """
    try:
        vs = getattr(mem0_instance, 'vector_store', None)
        if vs is None or not hasattr(vs, 'conn') or vs.conn is not None:
            return

        pool = self._sync_pool
        if not pool:
            return

        conn = pool.getconn()
        try:
            cur = conn.cursor()
        except Exception:
            # Callers invoke this method before entering their try/finally,
            # so nobody else would ever give this connection back.
            pool.putconn(conn)
            raise

        vs.conn = conn
        vs.cur = cur
        # Keep a pool reference on the vector store so the connection can
        # be returned later.
        if hasattr(vs, 'connection_pool') and vs.connection_pool is None:
            vs.connection_pool = pool
        logger.debug("Re-acquired Mem0 database connection from pool")
    except Exception as e:
        logger.warning(f"Error ensuring Mem0 connection: {e}")
        raise
def _release_connection(self, mem0_instance: Any) -> None:
    """Return a Mem0 instance's database connection to the pool after an operation.

    Unlike ``_cleanup_mem0_instance``, the ``connection_pool`` reference on
    the vector store is kept, so ``_ensure_connection`` can borrow a
    connection again for the next operation.

    This is best-effort and never raises: failures are logged as warnings.
    A failing cursor close does not prevent the connection from being
    returned, and ``conn`` is always cleared so the next
    ``_ensure_connection`` call re-acquires instead of reusing a connection
    whose cursor is gone.

    Args:
        mem0_instance: Mem0 Memory instance.
    """
    try:
        vs = getattr(mem0_instance, 'vector_store', None)
        conn = getattr(vs, 'conn', None)
        if conn is None:
            return

        try:
            cur = getattr(vs, 'cur', None)
            if cur:
                try:
                    cur.close()
                except Exception as e:
                    # A broken cursor must not block returning the connection.
                    logger.warning(f"Error closing Mem0 cursor: {e}")
                finally:
                    vs.cur = None

            pool = getattr(vs, 'connection_pool', None)
            if pool is not None:
                pool.putconn(conn)
        finally:
            # Clear the reference even if putconn failed, so the instance
            # does not keep a connection without a usable cursor.
            vs.conn = None
        logger.debug("Released Mem0 database connection back to pool")
    except Exception as e:
        logger.warning(f"Error releasing Mem0 connection: {e}")
async def get_mem0( async def get_mem0(
self, self,
user_id: str, user_id: str,
@ -397,6 +445,9 @@ class Mem0Manager:
f"Created Mem0 instance: user={user_id}, agent={agent_id}" f"Created Mem0 instance: user={user_id}, agent={agent_id}"
) )
# 创建时 PGVector 会 getconn立即释放以避免长期占用连接
self._release_connection(mem)
return mem return mem
async def recall_memories( async def recall_memories(
@ -418,16 +469,20 @@ class Mem0Manager:
记忆列表每个记忆包含 content, similarity 等字段 记忆列表每个记忆包含 content, similarity 等字段
""" """
try: try:
mem = await self.get_mem0(user_id, agent_id, "default", config) async with self._semaphore:
mem = await self.get_mem0(user_id, agent_id, "default", config)
# 调用 search 进行语义搜索(使用 agent_id 参数过滤) self._ensure_connection(mem)
limit = config.semantic_search_top_k if config else 20 try:
results = mem.search( # 调用 search 进行语义搜索(使用 agent_id 参数过滤)
query=query, limit = config.semantic_search_top_k if config else 20
limit=limit, results = mem.search(
user_id=user_id, query=query,
agent_id=agent_id, limit=limit,
) user_id=user_id,
agent_id=agent_id,
)
finally:
self._release_connection(mem)
# 转换为统一格式 # 转换为统一格式
memories = [] memories = []
@ -436,7 +491,7 @@ class Mem0Manager:
content = result.get("memory", "") content = result.get("memory", "")
score = result.get("score", 0.0) score = result.get("score", 0.0)
result_metadata = result.get("metadata", {}) result_metadata = result.get("metadata", {})
memory = { memory = {
"content": content, "content": content,
"similarity": score, "similarity": score,
@ -473,16 +528,20 @@ class Mem0Manager:
添加的记忆结果 添加的记忆结果
""" """
try: try:
mem = await self.get_mem0(user_id, agent_id, "default", config) async with self._semaphore:
mem = await self.get_mem0(user_id, agent_id, "default", config)
self._ensure_connection(mem)
try:
# 添加记忆(使用 agent_id 参数)
result = mem.add(
text,
user_id=user_id,
agent_id=agent_id,
metadata=metadata or {}
)
finally:
self._release_connection(mem)
# 添加记忆(使用 agent_id 参数)
result = mem.add(
text,
user_id=user_id,
agent_id=agent_id,
metadata=metadata or {}
)
logger.info(f"Added memory for user={user_id}, agent={agent_id}: {result}") logger.info(f"Added memory for user={user_id}, agent={agent_id}: {result}")
return result return result
@ -554,10 +613,14 @@ class Mem0Manager:
记忆列表 记忆列表
""" """
try: try:
mem = await self.get_mem0(user_id, agent_id, "default") async with self._semaphore:
mem = await self.get_mem0(user_id, agent_id, "default")
# 获取所有记忆 self._ensure_connection(mem)
response = mem.get_all(user_id=user_id) try:
# 获取所有记忆
response = mem.get_all(user_id=user_id)
finally:
self._release_connection(mem)
# 从响应中提取记忆列表 # 从响应中提取记忆列表
memories = self._extract_memories_from_response(response) memories = self._extract_memories_from_response(response)
@ -591,26 +654,30 @@ class Mem0Manager:
是否删除成功 是否删除成功
""" """
try: try:
mem = await self.get_mem0(user_id, agent_id, "default") async with self._semaphore:
mem = await self.get_mem0(user_id, agent_id, "default")
self._ensure_connection(mem)
try:
# 先获取记忆以验证所有权
response = mem.get_all(user_id=user_id)
memories = self._extract_memories_from_response(response)
# 先获取记忆以验证所有权 target_memory = None
response = mem.get_all(user_id=user_id) for m in memories:
memories = self._extract_memories_from_response(response) if isinstance(m, dict) and m.get("id") == memory_id:
# 验证 agent_id 匹配
if self._check_agent_id_match(m, agent_id):
target_memory = m
break
target_memory = None if not target_memory:
for m in memories: logger.warning(f"Memory {memory_id} not found or access denied for user={user_id}, agent={agent_id}")
if isinstance(m, dict) and m.get("id") == memory_id: return False
# 验证 agent_id 匹配
if self._check_agent_id_match(m, agent_id):
target_memory = m
break
if not target_memory: # 删除记忆
logger.warning(f"Memory {memory_id} not found or access denied for user={user_id}, agent={agent_id}") mem.delete(memory_id=memory_id)
return False finally:
self._release_connection(mem)
# 删除记忆
mem.delete(memory_id=memory_id)
logger.info(f"Deleted memory {memory_id} for user={user_id}, agent={agent_id}") logger.info(f"Deleted memory {memory_id} for user={user_id}, agent={agent_id}")
return True return True
@ -634,23 +701,27 @@ class Mem0Manager:
删除的记忆数量 删除的记忆数量
""" """
try: try:
mem = await self.get_mem0(user_id, agent_id, "default") async with self._semaphore:
mem = await self.get_mem0(user_id, agent_id, "default")
self._ensure_connection(mem)
try:
# 获取所有记忆
response = mem.get_all(user_id=user_id)
memories = self._extract_memories_from_response(response)
# 获取所有记忆 # 过滤 agent_id 并删除
response = mem.get_all(user_id=user_id) deleted_count = 0
memories = self._extract_memories_from_response(response) for m in memories:
if isinstance(m, dict) and self._check_agent_id_match(m, agent_id):
# 过滤 agent_id 并删除 memory_id = m.get("id")
deleted_count = 0 if memory_id:
for m in memories: try:
if isinstance(m, dict) and self._check_agent_id_match(m, agent_id): mem.delete(memory_id=memory_id)
memory_id = m.get("id") deleted_count += 1
if memory_id: except Exception as e:
try: logger.warning(f"Failed to delete memory {memory_id}: {e}")
mem.delete(memory_id=memory_id) finally:
deleted_count += 1 self._release_connection(mem)
except Exception as e:
logger.warning(f"Failed to delete memory {memory_id}: {e}")
logger.info(f"Deleted {deleted_count} memories for user={user_id}, agent={agent_id}") logger.info(f"Deleted {deleted_count} memories for user={user_id}, agent={agent_id}")
return deleted_count return deleted_count
@ -692,7 +763,9 @@ class Mem0Manager:
"""关闭管理器并清理资源""" """关闭管理器并清理资源"""
logger.info("Closing Mem0Manager...") logger.info("Closing Mem0Manager...")
# 清理缓存的实例 # 清理缓存的实例,释放连接
for key, instance in self._instances.items():
self._cleanup_mem0_instance(instance)
self._instances.clear() self._instances.clear()
# 注意:不关闭共享的同步连接池(由 DBPoolManager 管理) # 注意:不关闭共享的同步连接池(由 DBPoolManager 管理)