diff --git a/agent/mem0_manager.py b/agent/mem0_manager.py index 464bc8f..1b73833 100644 --- a/agent/mem0_manager.py +++ b/agent/mem0_manager.py @@ -456,6 +456,55 @@ class Mem0Manager: logger.error(f"Failed to add memory: {e}") return {} + def _extract_memories_from_response(self, response: Any) -> List[Dict[str, Any]]: + """从 Mem0 get_all() 响应中提取记忆列表 + + Mem0 的 get_all() 返回格式可能有两种: + 1. 新版本: {"results": [...]} + 2. 旧版本: 直接返回列表 + + Args: + response: Mem0 get_all() 的响应 + + Returns: + 记忆列表 + """ + if isinstance(response, dict) and "results" in response: + return response["results"] + elif isinstance(response, list): + return response + else: + logger.warning(f"Unexpected response format from mem.get_all(): {type(response)}") + return [] + + def _check_agent_id_match(self, memory: Dict[str, Any], agent_id: str) -> bool: + """检查记忆是否属于指定的 agent + + Mem0 的记忆结构中,agent_id 可能在两个位置: + 1. 顶层: memory["agent_id"] + 2. metadata 中: memory["metadata"]["agent_id"] + + Args: + memory: 记忆字典 + agent_id: 要匹配的 agent ID + + Returns: + 是否匹配 + """ + if not isinstance(memory, dict): + return False + + # 首先检查顶层 agent_id(新版本格式) + if memory.get("agent_id") == agent_id: + return True + + # 然后检查 metadata 中的 agent_id(旧版本格式) + metadata = memory.get("metadata", {}) + if isinstance(metadata, dict) and metadata.get("agent_id") == agent_id: + return True + + return False + async def get_all_memories( self, user_id: str, @@ -474,21 +523,16 @@ class Mem0Manager: mem = await self.get_mem0(user_id, agent_id, "default") # 获取所有记忆 - memories = mem.get_all(user_id=user_id) + response = mem.get_all(user_id=user_id) - # Mem0 的 get_all 可能返回字符串列表或字典列表,需要统一处理 - filtered_memories = [] - for m in memories: - # 如果是字符串,跳过(旧版本格式,不包含 metadata) - if isinstance(m, str): - logger.warning(f"Memory item is string, skipping: {m[:50]}...") - continue + # 从响应中提取记忆列表 + memories = self._extract_memories_from_response(response) - # 如果是字典,检查 agent_id - if isinstance(m, dict): - metadata = m.get("metadata", {}) - if isinstance(metadata, dict) and metadata.get("agent_id") == agent_id: - filtered_memories.append(m) + # 过滤 agent_id(agent_id 在顶层,不在 metadata 中) + filtered_memories = [ + m for m in memories + if self._check_agent_id_match(m, agent_id) + ] return filtered_memories @@ -516,21 +560,16 @@ class Mem0Manager: mem = await self.get_mem0(user_id, agent_id, "default") # 先获取记忆以验证所有权 - memories = mem.get_all(user_id=user_id) + response = mem.get_all(user_id=user_id) + memories = self._extract_memories_from_response(response) + target_memory = None for m in memories: - # 跳过字符串类型的记忆(旧版本格式) - if isinstance(m, str): - continue - - # 只处理字典类型 - if isinstance(m, dict): - if m.get("id") == memory_id: - # 验证 agent_id 匹配 - metadata = m.get("metadata", {}) - if isinstance(metadata, dict) and metadata.get("agent_id") == agent_id: - target_memory = m - break + if isinstance(m, dict) and m.get("id") == memory_id: + # 验证 agent_id 匹配 + if self._check_agent_id_match(m, agent_id): + target_memory = m + break if not target_memory: logger.warning(f"Memory {memory_id} not found or access denied for user={user_id}, agent={agent_id}") @@ -564,27 +603,20 @@ class Mem0Manager: mem = await self.get_mem0(user_id, agent_id, "default") # 获取所有记忆 - memories = mem.get_all(user_id=user_id) + response = mem.get_all(user_id=user_id) + memories = self._extract_memories_from_response(response) # 过滤 agent_id 并删除 - # Mem0 的 get_all 可能返回字符串列表或字典列表,需要统一处理 deleted_count = 0 for m in memories: - # 跳过字符串类型的记忆(旧版本格式) - if isinstance(m, str): - continue - - # 只处理字典类型 - if isinstance(m, dict): - metadata = m.get("metadata", {}) - if isinstance(metadata, dict) and metadata.get("agent_id") == agent_id: - memory_id = m.get("id") - if memory_id: - try: - mem.delete(memory_id=memory_id) - deleted_count += 1 - except Exception as e: - logger.warning(f"Failed to delete memory {memory_id}: {e}") + if isinstance(m, dict) and self._check_agent_id_match(m, agent_id): + memory_id = m.get("id") + if memory_id: + try: + mem.delete(memory_id=memory_id) + deleted_count += 1 + except Exception as e: + logger.warning(f"Failed to delete memory {memory_id}: {e}") logger.info(f"Deleted {deleted_count} memories for user={user_id}, agent={agent_id}") return deleted_count