diff --git a/agent/mem0_manager.py b/agent/mem0_manager.py index 92f85a0..5c494ac 100644 --- a/agent/mem0_manager.py +++ b/agent/mem0_manager.py @@ -188,6 +188,9 @@ class Mem0Manager: self._max_instances = MEM0_POOL_SIZE/2 # 最大缓存实例数 self._initialized = False + # 限制并发 Mem0 操作数,防止连接池耗尽 + self._semaphore = asyncio.Semaphore(max(MEM0_POOL_SIZE - 2, 1)) + async def initialize(self) -> None: """初始化 Mem0Manager @@ -234,22 +237,67 @@ class Mem0Manager: vector_store = mem0_instance.vector_store # PGVector 有 conn 和 connection_pool 属性 if hasattr(vector_store, 'conn') and hasattr(vector_store, 'connection_pool'): - if vector_store.connection_pool is not None: + if vector_store.conn is not None and vector_store.connection_pool is not None: try: # 先关闭游标 if hasattr(vector_store, 'cur') and vector_store.cur: vector_store.cur.close() + vector_store.cur = None # 归还连接到池 vector_store.connection_pool.putconn(vector_store.conn) # 标记为已清理,防止 __del__ 重复释放 vector_store.conn = None - vector_store.connection_pool = None logger.debug("Successfully released Mem0 database connection back to pool") except Exception as e: logger.warning(f"Error releasing Mem0 connection: {e}") except Exception as e: logger.warning(f"Error cleaning up Mem0 instance: {e}") + def _ensure_connection(self, mem0_instance: Any) -> None: + """操作前确保 Mem0 实例持有数据库连接 + + 如果连接已被 _release_connection 释放,则重新从池中获取。 + + Args: + mem0_instance: Mem0 Memory 实例 + """ + try: + if hasattr(mem0_instance, 'vector_store'): + vs = mem0_instance.vector_store + if hasattr(vs, 'conn') and vs.conn is None and self._sync_pool: + vs.conn = self._sync_pool.getconn() + vs.cur = vs.conn.cursor() + # 确保 connection_pool 引用存在(用于后续归还) + if hasattr(vs, 'connection_pool') and vs.connection_pool is None: + vs.connection_pool = self._sync_pool + logger.debug("Re-acquired Mem0 database connection from pool") + except Exception as e: + logger.warning(f"Error ensuring Mem0 connection: {e}") + raise + + def _release_connection(self, mem0_instance: Any) -> None: + """操作后释放连接回池 + + 与 _cleanup_mem0_instance 不同,这里保留 connection_pool 引用, + 以便下次 _ensure_connection 可以重新获取连接。 + + Args: + mem0_instance: Mem0 Memory 实例 + """ + try: + if hasattr(mem0_instance, 'vector_store'): + vs = mem0_instance.vector_store + if hasattr(vs, 'conn') and vs.conn is not None: + if hasattr(vs, 'cur') and vs.cur: + vs.cur.close() + vs.cur = None + if hasattr(vs, 'connection_pool') and vs.connection_pool is not None: + vs.connection_pool.putconn(vs.conn) + vs.conn = None + logger.debug("Released Mem0 database connection back to pool") + except Exception as e: + logger.warning(f"Error releasing Mem0 connection: {e}") + async def get_mem0( self, user_id: str, @@ -397,6 +445,9 @@ class Mem0Manager: f"Created Mem0 instance: user={user_id}, agent={agent_id}" ) + # 创建时 PGVector 会 getconn,立即释放以避免长期占用连接 + self._release_connection(mem) + return mem async def recall_memories( @@ -418,16 +469,20 @@ class Mem0Manager: 记忆列表,每个记忆包含 content, similarity 等字段 """ try: - mem = await self.get_mem0(user_id, agent_id, "default", config) - - # 调用 search 进行语义搜索(使用 agent_id 参数过滤) - limit = config.semantic_search_top_k if config else 20 - results = mem.search( - query=query, - limit=limit, - user_id=user_id, - agent_id=agent_id, - ) + async with self._semaphore: + mem = await self.get_mem0(user_id, agent_id, "default", config) + self._ensure_connection(mem) + try: + # 调用 search 进行语义搜索(使用 agent_id 参数过滤) + limit = config.semantic_search_top_k if config else 20 + results = mem.search( + query=query, + limit=limit, + user_id=user_id, + agent_id=agent_id, + ) + finally: + self._release_connection(mem) # 转换为统一格式 memories = [] @@ -436,7 +491,7 @@ class Mem0Manager: content = result.get("memory", "") score = result.get("score", 0.0) result_metadata = result.get("metadata", {}) - + memory = { "content": content, "similarity": score, @@ -473,16 +528,20 @@ class Mem0Manager: 添加的记忆结果 """ try: - mem = await self.get_mem0(user_id, agent_id, "default", config) + async with self._semaphore: + mem = await self.get_mem0(user_id, agent_id, "default", config) + self._ensure_connection(mem) + try: + # 添加记忆(使用 agent_id 参数) + result = mem.add( + text, + user_id=user_id, + agent_id=agent_id, + metadata=metadata or {} + ) + finally: + self._release_connection(mem) - # 添加记忆(使用 agent_id 参数) - result = mem.add( - text, - user_id=user_id, - agent_id=agent_id, - metadata=metadata or {} - ) - logger.info(f"Added memory for user={user_id}, agent={agent_id}: {result}") return result @@ -554,10 +613,14 @@ class Mem0Manager: 记忆列表 """ try: - mem = await self.get_mem0(user_id, agent_id, "default") - - # 获取所有记忆 - response = mem.get_all(user_id=user_id) + async with self._semaphore: + mem = await self.get_mem0(user_id, agent_id, "default") + self._ensure_connection(mem) + try: + # 获取所有记忆 + response = mem.get_all(user_id=user_id) + finally: + self._release_connection(mem) # 从响应中提取记忆列表 memories = self._extract_memories_from_response(response) @@ -591,26 +654,30 @@ class Mem0Manager: 是否删除成功 """ try: - mem = await self.get_mem0(user_id, agent_id, "default") + async with self._semaphore: + mem = await self.get_mem0(user_id, agent_id, "default") + self._ensure_connection(mem) + try: + # 先获取记忆以验证所有权 + response = mem.get_all(user_id=user_id) + memories = self._extract_memories_from_response(response) - # 先获取记忆以验证所有权 - response = mem.get_all(user_id=user_id) - memories = self._extract_memories_from_response(response) + target_memory = None + for m in memories: + if isinstance(m, dict) and m.get("id") == memory_id: + # 验证 agent_id 匹配 + if self._check_agent_id_match(m, agent_id): + target_memory = m + break - target_memory = None - for m in memories: - if isinstance(m, dict) and m.get("id") == memory_id: - # 验证 agent_id 匹配 - if self._check_agent_id_match(m, agent_id): - target_memory = m - break + if not target_memory: + logger.warning(f"Memory {memory_id} not found or access denied for user={user_id}, agent={agent_id}") + return False - if not target_memory: - logger.warning(f"Memory {memory_id} not found or access denied for user={user_id}, agent={agent_id}") - return False - - # 删除记忆 - mem.delete(memory_id=memory_id) + # 删除记忆 + mem.delete(memory_id=memory_id) + finally: + self._release_connection(mem) logger.info(f"Deleted memory {memory_id} for user={user_id}, agent={agent_id}") return True @@ -634,23 +701,27 @@ class Mem0Manager: 删除的记忆数量 """ try: - mem = await self.get_mem0(user_id, agent_id, "default") + async with self._semaphore: + mem = await self.get_mem0(user_id, agent_id, "default") + self._ensure_connection(mem) + try: + # 获取所有记忆 + response = mem.get_all(user_id=user_id) + memories = self._extract_memories_from_response(response) - # 获取所有记忆 - response = mem.get_all(user_id=user_id) - memories = self._extract_memories_from_response(response) - - # 过滤 agent_id 并删除 - deleted_count = 0 - for m in memories: - if isinstance(m, dict) and self._check_agent_id_match(m, agent_id): - memory_id = m.get("id") - if memory_id: - try: - mem.delete(memory_id=memory_id) - deleted_count += 1 - except Exception as e: - logger.warning(f"Failed to delete memory {memory_id}: {e}") + # 过滤 agent_id 并删除 + deleted_count = 0 + for m in memories: + if isinstance(m, dict) and self._check_agent_id_match(m, agent_id): + memory_id = m.get("id") + if memory_id: + try: + mem.delete(memory_id=memory_id) + deleted_count += 1 + except Exception as e: + logger.warning(f"Failed to delete memory {memory_id}: {e}") + finally: + self._release_connection(mem) logger.info(f"Deleted {deleted_count} memories for user={user_id}, agent={agent_id}") return deleted_count @@ -692,7 +763,9 @@ class Mem0Manager: """关闭管理器并清理资源""" logger.info("Closing Mem0Manager...") - # 清理缓存的实例 + # 清理缓存的实例,释放连接 + for key, instance in self._instances.items(): + self._cleanup_mem0_instance(instance) self._instances.clear() # 注意:不关闭共享的同步连接池(由 DBPoolManager 管理)