Merge branch 'feature/pre-memory-prompt' into dev
This commit is contained in:
commit
ab6b68268e
@ -188,6 +188,9 @@ class Mem0Manager:
|
||||
self._max_instances = MEM0_POOL_SIZE/2 # 最大缓存实例数
|
||||
self._initialized = False
|
||||
|
||||
# 限制并发 Mem0 操作数,防止连接池耗尽
|
||||
self._semaphore = asyncio.Semaphore(max(MEM0_POOL_SIZE - 2, 1))
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化 Mem0Manager
|
||||
|
||||
@ -234,22 +237,67 @@ class Mem0Manager:
|
||||
vector_store = mem0_instance.vector_store
|
||||
# PGVector 有 conn 和 connection_pool 属性
|
||||
if hasattr(vector_store, 'conn') and hasattr(vector_store, 'connection_pool'):
|
||||
if vector_store.connection_pool is not None:
|
||||
if vector_store.conn is not None and vector_store.connection_pool is not None:
|
||||
try:
|
||||
# 先关闭游标
|
||||
if hasattr(vector_store, 'cur') and vector_store.cur:
|
||||
vector_store.cur.close()
|
||||
vector_store.cur = None
|
||||
# 归还连接到池
|
||||
vector_store.connection_pool.putconn(vector_store.conn)
|
||||
# 标记为已清理,防止 __del__ 重复释放
|
||||
vector_store.conn = None
|
||||
vector_store.connection_pool = None
|
||||
logger.debug("Successfully released Mem0 database connection back to pool")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error releasing Mem0 connection: {e}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cleaning up Mem0 instance: {e}")
|
||||
|
||||
def _ensure_connection(self, mem0_instance: Any) -> None:
|
||||
"""操作前确保 Mem0 实例持有数据库连接
|
||||
|
||||
如果连接已被 _release_connection 释放,则重新从池中获取。
|
||||
|
||||
Args:
|
||||
mem0_instance: Mem0 Memory 实例
|
||||
"""
|
||||
try:
|
||||
if hasattr(mem0_instance, 'vector_store'):
|
||||
vs = mem0_instance.vector_store
|
||||
if hasattr(vs, 'conn') and vs.conn is None and self._sync_pool:
|
||||
vs.conn = self._sync_pool.getconn()
|
||||
vs.cur = vs.conn.cursor()
|
||||
# 确保 connection_pool 引用存在(用于后续归还)
|
||||
if hasattr(vs, 'connection_pool') and vs.connection_pool is None:
|
||||
vs.connection_pool = self._sync_pool
|
||||
logger.debug("Re-acquired Mem0 database connection from pool")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error ensuring Mem0 connection: {e}")
|
||||
raise
|
||||
|
||||
def _release_connection(self, mem0_instance: Any) -> None:
|
||||
"""操作后释放连接回池
|
||||
|
||||
与 _cleanup_mem0_instance 不同,这里保留 connection_pool 引用,
|
||||
以便下次 _ensure_connection 可以重新获取连接。
|
||||
|
||||
Args:
|
||||
mem0_instance: Mem0 Memory 实例
|
||||
"""
|
||||
try:
|
||||
if hasattr(mem0_instance, 'vector_store'):
|
||||
vs = mem0_instance.vector_store
|
||||
if hasattr(vs, 'conn') and vs.conn is not None:
|
||||
if hasattr(vs, 'cur') and vs.cur:
|
||||
vs.cur.close()
|
||||
vs.cur = None
|
||||
if hasattr(vs, 'connection_pool') and vs.connection_pool is not None:
|
||||
vs.connection_pool.putconn(vs.conn)
|
||||
vs.conn = None
|
||||
logger.debug("Released Mem0 database connection back to pool")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error releasing Mem0 connection: {e}")
|
||||
|
||||
async def get_mem0(
|
||||
self,
|
||||
user_id: str,
|
||||
@ -397,6 +445,9 @@ class Mem0Manager:
|
||||
f"Created Mem0 instance: user={user_id}, agent={agent_id}"
|
||||
)
|
||||
|
||||
# 创建时 PGVector 会 getconn,立即释放以避免长期占用连接
|
||||
self._release_connection(mem)
|
||||
|
||||
return mem
|
||||
|
||||
async def recall_memories(
|
||||
@ -418,16 +469,20 @@ class Mem0Manager:
|
||||
记忆列表,每个记忆包含 content, similarity 等字段
|
||||
"""
|
||||
try:
|
||||
mem = await self.get_mem0(user_id, agent_id, "default", config)
|
||||
|
||||
# 调用 search 进行语义搜索(使用 agent_id 参数过滤)
|
||||
limit = config.semantic_search_top_k if config else 20
|
||||
results = mem.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
async with self._semaphore:
|
||||
mem = await self.get_mem0(user_id, agent_id, "default", config)
|
||||
self._ensure_connection(mem)
|
||||
try:
|
||||
# 调用 search 进行语义搜索(使用 agent_id 参数过滤)
|
||||
limit = config.semantic_search_top_k if config else 20
|
||||
results = mem.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
finally:
|
||||
self._release_connection(mem)
|
||||
|
||||
# 转换为统一格式
|
||||
memories = []
|
||||
@ -473,15 +528,19 @@ class Mem0Manager:
|
||||
添加的记忆结果
|
||||
"""
|
||||
try:
|
||||
mem = await self.get_mem0(user_id, agent_id, "default", config)
|
||||
|
||||
# 添加记忆(使用 agent_id 参数)
|
||||
result = mem.add(
|
||||
text,
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
async with self._semaphore:
|
||||
mem = await self.get_mem0(user_id, agent_id, "default", config)
|
||||
self._ensure_connection(mem)
|
||||
try:
|
||||
# 添加记忆(使用 agent_id 参数)
|
||||
result = mem.add(
|
||||
text,
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
finally:
|
||||
self._release_connection(mem)
|
||||
|
||||
logger.info(f"Added memory for user={user_id}, agent={agent_id}: {result}")
|
||||
return result
|
||||
@ -554,10 +613,14 @@ class Mem0Manager:
|
||||
记忆列表
|
||||
"""
|
||||
try:
|
||||
mem = await self.get_mem0(user_id, agent_id, "default")
|
||||
|
||||
# 获取所有记忆
|
||||
response = mem.get_all(user_id=user_id)
|
||||
async with self._semaphore:
|
||||
mem = await self.get_mem0(user_id, agent_id, "default")
|
||||
self._ensure_connection(mem)
|
||||
try:
|
||||
# 获取所有记忆
|
||||
response = mem.get_all(user_id=user_id)
|
||||
finally:
|
||||
self._release_connection(mem)
|
||||
|
||||
# 从响应中提取记忆列表
|
||||
memories = self._extract_memories_from_response(response)
|
||||
@ -591,26 +654,30 @@ class Mem0Manager:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
mem = await self.get_mem0(user_id, agent_id, "default")
|
||||
async with self._semaphore:
|
||||
mem = await self.get_mem0(user_id, agent_id, "default")
|
||||
self._ensure_connection(mem)
|
||||
try:
|
||||
# 先获取记忆以验证所有权
|
||||
response = mem.get_all(user_id=user_id)
|
||||
memories = self._extract_memories_from_response(response)
|
||||
|
||||
# 先获取记忆以验证所有权
|
||||
response = mem.get_all(user_id=user_id)
|
||||
memories = self._extract_memories_from_response(response)
|
||||
target_memory = None
|
||||
for m in memories:
|
||||
if isinstance(m, dict) and m.get("id") == memory_id:
|
||||
# 验证 agent_id 匹配
|
||||
if self._check_agent_id_match(m, agent_id):
|
||||
target_memory = m
|
||||
break
|
||||
|
||||
target_memory = None
|
||||
for m in memories:
|
||||
if isinstance(m, dict) and m.get("id") == memory_id:
|
||||
# 验证 agent_id 匹配
|
||||
if self._check_agent_id_match(m, agent_id):
|
||||
target_memory = m
|
||||
break
|
||||
if not target_memory:
|
||||
logger.warning(f"Memory {memory_id} not found or access denied for user={user_id}, agent={agent_id}")
|
||||
return False
|
||||
|
||||
if not target_memory:
|
||||
logger.warning(f"Memory {memory_id} not found or access denied for user={user_id}, agent={agent_id}")
|
||||
return False
|
||||
|
||||
# 删除记忆
|
||||
mem.delete(memory_id=memory_id)
|
||||
# 删除记忆
|
||||
mem.delete(memory_id=memory_id)
|
||||
finally:
|
||||
self._release_connection(mem)
|
||||
|
||||
logger.info(f"Deleted memory {memory_id} for user={user_id}, agent={agent_id}")
|
||||
return True
|
||||
@ -634,23 +701,27 @@ class Mem0Manager:
|
||||
删除的记忆数量
|
||||
"""
|
||||
try:
|
||||
mem = await self.get_mem0(user_id, agent_id, "default")
|
||||
async with self._semaphore:
|
||||
mem = await self.get_mem0(user_id, agent_id, "default")
|
||||
self._ensure_connection(mem)
|
||||
try:
|
||||
# 获取所有记忆
|
||||
response = mem.get_all(user_id=user_id)
|
||||
memories = self._extract_memories_from_response(response)
|
||||
|
||||
# 获取所有记忆
|
||||
response = mem.get_all(user_id=user_id)
|
||||
memories = self._extract_memories_from_response(response)
|
||||
|
||||
# 过滤 agent_id 并删除
|
||||
deleted_count = 0
|
||||
for m in memories:
|
||||
if isinstance(m, dict) and self._check_agent_id_match(m, agent_id):
|
||||
memory_id = m.get("id")
|
||||
if memory_id:
|
||||
try:
|
||||
mem.delete(memory_id=memory_id)
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete memory {memory_id}: {e}")
|
||||
# 过滤 agent_id 并删除
|
||||
deleted_count = 0
|
||||
for m in memories:
|
||||
if isinstance(m, dict) and self._check_agent_id_match(m, agent_id):
|
||||
memory_id = m.get("id")
|
||||
if memory_id:
|
||||
try:
|
||||
mem.delete(memory_id=memory_id)
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete memory {memory_id}: {e}")
|
||||
finally:
|
||||
self._release_connection(mem)
|
||||
|
||||
logger.info(f"Deleted {deleted_count} memories for user={user_id}, agent={agent_id}")
|
||||
return deleted_count
|
||||
@ -692,7 +763,9 @@ class Mem0Manager:
|
||||
"""关闭管理器并清理资源"""
|
||||
logger.info("Closing Mem0Manager...")
|
||||
|
||||
# 清理缓存的实例
|
||||
# 清理缓存的实例,释放连接
|
||||
for key, instance in self._instances.items():
|
||||
self._cleanup_mem0_instance(instance)
|
||||
self._instances.clear()
|
||||
|
||||
# 注意:不关闭共享的同步连接池(由 DBPoolManager 管理)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user