Merge pull request #12 from sparticleinc/bugfix/autobee-20260302-memory-results-fix
fix(memory): handle Mem0 get_all response format with results key
This commit is contained in:
commit
a6f071119f
@ -456,6 +456,55 @@ class Mem0Manager:
|
||||
logger.error(f"Failed to add memory: {e}")
|
||||
return {}
|
||||
|
||||
def _extract_memories_from_response(self, response: Any) -> List[Dict[str, Any]]:
|
||||
"""从 Mem0 get_all() 响应中提取记忆列表
|
||||
|
||||
Mem0 的 get_all() 返回格式可能有两种:
|
||||
1. 新版本: {"results": [...]}
|
||||
2. 旧版本: 直接返回列表
|
||||
|
||||
Args:
|
||||
response: Mem0 get_all() 的响应
|
||||
|
||||
Returns:
|
||||
记忆列表
|
||||
"""
|
||||
if isinstance(response, dict) and "results" in response:
|
||||
return response["results"]
|
||||
elif isinstance(response, list):
|
||||
return response
|
||||
else:
|
||||
logger.warning(f"Unexpected response format from mem.get_all(): {type(response)}")
|
||||
return []
|
||||
|
||||
def _check_agent_id_match(self, memory: Dict[str, Any], agent_id: str) -> bool:
|
||||
"""检查记忆是否属于指定的 agent
|
||||
|
||||
Mem0 的记忆结构中,agent_id 可能在两个位置:
|
||||
1. 顶层: memory["agent_id"]
|
||||
2. metadata 中: memory["metadata"]["agent_id"]
|
||||
|
||||
Args:
|
||||
memory: 记忆字典
|
||||
agent_id: 要匹配的 agent ID
|
||||
|
||||
Returns:
|
||||
是否匹配
|
||||
"""
|
||||
if not isinstance(memory, dict):
|
||||
return False
|
||||
|
||||
# 首先检查顶层 agent_id(新版本格式)
|
||||
if memory.get("agent_id") == agent_id:
|
||||
return True
|
||||
|
||||
# 然后检查 metadata 中的 agent_id(旧版本格式)
|
||||
metadata = memory.get("metadata", {})
|
||||
if isinstance(metadata, dict) and metadata.get("agent_id") == agent_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def get_all_memories(
|
||||
self,
|
||||
user_id: str,
|
||||
@ -474,21 +523,16 @@ class Mem0Manager:
|
||||
mem = await self.get_mem0(user_id, agent_id, "default")
|
||||
|
||||
# 获取所有记忆
|
||||
memories = mem.get_all(user_id=user_id)
|
||||
response = mem.get_all(user_id=user_id)
|
||||
|
||||
# Mem0 的 get_all 可能返回字符串列表或字典列表,需要统一处理
|
||||
filtered_memories = []
|
||||
for m in memories:
|
||||
# 如果是字符串,跳过(旧版本格式,不包含 metadata)
|
||||
if isinstance(m, str):
|
||||
logger.warning(f"Memory item is string, skipping: {m[:50]}...")
|
||||
continue
|
||||
# 从响应中提取记忆列表
|
||||
memories = self._extract_memories_from_response(response)
|
||||
|
||||
# 如果是字典,检查 agent_id
|
||||
if isinstance(m, dict):
|
||||
metadata = m.get("metadata", {})
|
||||
if isinstance(metadata, dict) and metadata.get("agent_id") == agent_id:
|
||||
filtered_memories.append(m)
|
||||
# 过滤 agent_id(agent_id 在顶层,不在 metadata 中)
|
||||
filtered_memories = [
|
||||
m for m in memories
|
||||
if self._check_agent_id_match(m, agent_id)
|
||||
]
|
||||
|
||||
return filtered_memories
|
||||
|
||||
@ -516,21 +560,16 @@ class Mem0Manager:
|
||||
mem = await self.get_mem0(user_id, agent_id, "default")
|
||||
|
||||
# 先获取记忆以验证所有权
|
||||
memories = mem.get_all(user_id=user_id)
|
||||
response = mem.get_all(user_id=user_id)
|
||||
memories = self._extract_memories_from_response(response)
|
||||
|
||||
target_memory = None
|
||||
for m in memories:
|
||||
# 跳过字符串类型的记忆(旧版本格式)
|
||||
if isinstance(m, str):
|
||||
continue
|
||||
|
||||
# 只处理字典类型
|
||||
if isinstance(m, dict):
|
||||
if m.get("id") == memory_id:
|
||||
# 验证 agent_id 匹配
|
||||
metadata = m.get("metadata", {})
|
||||
if isinstance(metadata, dict) and metadata.get("agent_id") == agent_id:
|
||||
target_memory = m
|
||||
break
|
||||
if isinstance(m, dict) and m.get("id") == memory_id:
|
||||
# 验证 agent_id 匹配
|
||||
if self._check_agent_id_match(m, agent_id):
|
||||
target_memory = m
|
||||
break
|
||||
|
||||
if not target_memory:
|
||||
logger.warning(f"Memory {memory_id} not found or access denied for user={user_id}, agent={agent_id}")
|
||||
@ -564,27 +603,20 @@ class Mem0Manager:
|
||||
mem = await self.get_mem0(user_id, agent_id, "default")
|
||||
|
||||
# 获取所有记忆
|
||||
memories = mem.get_all(user_id=user_id)
|
||||
response = mem.get_all(user_id=user_id)
|
||||
memories = self._extract_memories_from_response(response)
|
||||
|
||||
# 过滤 agent_id 并删除
|
||||
# Mem0 的 get_all 可能返回字符串列表或字典列表,需要统一处理
|
||||
deleted_count = 0
|
||||
for m in memories:
|
||||
# 跳过字符串类型的记忆(旧版本格式)
|
||||
if isinstance(m, str):
|
||||
continue
|
||||
|
||||
# 只处理字典类型
|
||||
if isinstance(m, dict):
|
||||
metadata = m.get("metadata", {})
|
||||
if isinstance(metadata, dict) and metadata.get("agent_id") == agent_id:
|
||||
memory_id = m.get("id")
|
||||
if memory_id:
|
||||
try:
|
||||
mem.delete(memory_id=memory_id)
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete memory {memory_id}: {e}")
|
||||
if isinstance(m, dict) and self._check_agent_id_match(m, agent_id):
|
||||
memory_id = m.get("id")
|
||||
if memory_id:
|
||||
try:
|
||||
mem.delete(memory_id=memory_id)
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete memory {memory_id}: {e}")
|
||||
|
||||
logger.info(f"Deleted {deleted_count} memories for user={user_id}, agent={agent_id}")
|
||||
return deleted_count
|
||||
|
||||
Loading…
Reference in New Issue
Block a user