""" Mem0 连接和实例管理器 负责管理 Mem0 客户端实例的创建、缓存和生命周期 """ import logging import asyncio from typing import Any, Dict, List, Optional, Literal from embedding.manager import GlobalModelManager, get_model_manager import json_repair from psycopg2 import pool from .mem0_config import Mem0Config logger = logging.getLogger("app") # ============================================================================ # 自定义 Embedding 类,使用项目中已有的 GlobalModelManager # 避免重复加载模型 # ============================================================================ class CustomMem0Embedding: """ 自定义 Mem0 Embedding 类,直接使用项目中已有的 GlobalModelManager 这样 Mem0 就不需要再次加载同一个模型,节省内存 """ _model_manager = None # 缓存 GlobalModelManager 实例 def __init__(self, config: Optional[Any] = None): """初始化自定义 Embedding""" # 创建一个简单的 config 对象来兼容 Mem0 的 telemetry 代码 if config is None: config = type('Config', (), {'embedding_dims': 384})() self.config = config @property def embedding_dims(self): """获取 embedding 维度""" return 384 # gte-tiny 的维度 def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None): """ 获取文本的 embedding 向量(同步方法,供 Mem0 调用) Args: text: 要嵌入的文本(字符串或列表) memory_action: 记忆操作类型 (add/search/update),当前未使用 Returns: list: embedding 向量 """ manager = get_model_manager() model = asyncio.run(manager.get_model()) embeddings = model.encode(text, convert_to_numpy=True) return embeddings.tolist() # Monkey patch: 使用 json_repair 替换 mem0 的 remove_code_blocks def _remove_code_blocks_with_repair(content: str) -> str: """ 使用 json_repair 替换 mem0 的 remove_code_blocks 函数 json_repair.loads 会自动处理: - 移除代码块标记(```json, ``` 等) - 修复损坏的 JSON(如尾随逗号、注释、单引号等) """ import re content_stripped = content.strip() try: # json_repair.loads 会自动去除代码块并修复 JSON result = json_repair.loads(content_stripped) if isinstance(result, (dict, list)): import json return json.dumps(result, ensure_ascii=False) # 如果返回空字符串(非 JSON 输入),回退到原内容 if result == "" and content_stripped != "": # 尝试简单的代码块去除(降级处理) pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$" match = re.match(pattern, content_stripped) if match: return match.group(1).strip() return content_stripped return str(result) except Exception: # 如果解析失败,尝试简单的代码块去除(降级处理) pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$" match = re.match(pattern, content_stripped) if match: return match.group(1).strip() return content_stripped # 执行 monkey patch(在 mem0 导入之前或之后) try: import sys import mem0.memory.utils as mem0_utils mem0_utils.remove_code_blocks = _remove_code_blocks_with_repair # 如果 mem0.memory.main 已经导入,也要 patch 它的本地引用 if 'mem0.memory.main' in sys.modules: import mem0.memory.main mem0.memory.main.remove_code_blocks = _remove_code_blocks_with_repair logger.info("Successfully patched mem0.memory.main.remove_code_blocks with json_repair") else: logger.info("Successfully patched mem0.memory.utils.remove_code_blocks with json_repair") except ImportError: # mem0 还未导入,patch 将在首次导入时生效 pass except Exception as e: logger.warning(f"Failed to patch mem0 remove_code_blocks: {e}") class Mem0Manager: """ Mem0 连接和实例管理器 主要功能: 1. 管理 Mem0 实例的创建和缓存 2. 支持多租户隔离(user_id + agent_id) 3. 使用共享的同步连接池(由 DBPoolManager 提供) 4. 提供记忆召回和存储接口 """ def __init__( self, sync_pool: Optional[pool.SimpleConnectionPool] = None, ): """初始化 Mem0Manager Args: sync_pool: PostgreSQL 同步连接池(与 DBPoolManager 共享) """ self._sync_pool = sync_pool # 缓存 Mem0 实例: key = f"{user_id}:{agent_id}" self._instances: Dict[str, Any] = {} self._initialized = False async def initialize(self) -> None: """初始化 Mem0Manager 创建数据库表结构(如果不存在) """ if self._initialized: return logger.info("Initializing Mem0Manager...") try: # Mem0 会自动创建表结构,这里只需验证连接 if self._sync_pool: logger.info("Mem0Manager initialized successfully") else: logger.warning("No database configuration provided for Mem0") self._initialized = True except Exception as e: logger.error(f"Failed to initialize Mem0Manager: {e}") # 不抛出异常,允许系统在没有 Mem0 的情况下运行 def _get_connection_pool(self) -> Optional[pool.SimpleConnectionPool]: """获取同步数据库连接池(Mem0 需要) Returns: psycopg2.pool 连接池 """ return self._sync_pool async def get_mem0( self, user_id: str, agent_id: str, session_id: str, config: Optional[Mem0Config] = None, ) -> Any: """获取或创建 Mem0 实例 Args: user_id: 用户 ID(对应 entity_id) agent_id: Agent/Bot ID(对应 process_id) session_id: 会话 ID config: Mem0 配置 Returns: Mem0 实例 """ # 缓存键包含 LLM 实例 ID,以确保不同 LLM 使用不同实例 llm_suffix = "" if config and config.llm_instance is not None: llm_suffix = f":{id(config.llm_instance)}" cache_key = f"{user_id}:{agent_id}{llm_suffix}" # 检查缓存 if cache_key in self._instances: return self._instances[cache_key] # 创建新实例 mem0_instance = await self._create_mem0_instance( user_id=user_id, agent_id=agent_id, session_id=session_id, config=config, ) # 缓存实例 self._instances[cache_key] = mem0_instance return mem0_instance async def _create_mem0_instance( self, user_id: str, agent_id: str, session_id: str, config: Optional[Mem0Config] = None, ) -> Any: """创建新的 Mem0 实例 Args: user_id: 用户 ID agent_id: Agent/Bot ID session_id: 会话 ID config: Mem0 配置(包含 LLM 实例) Returns: Mem0 Memory 实例 """ try: from mem0 import Memory except ImportError: logger.error("mem0 package not installed") raise RuntimeError("mem0 package is required but not installed") # 获取同步连接池 connection_pool = self._get_connection_pool() if not connection_pool: raise ValueError("Database connection pool not available") # 创建自定义 embedder(使用共享模型,避免重复加载) custom_embedder = CustomMem0Embedding() # 配置 Mem0 使用 Pgvector # 注意:这里使用 huggingface_base_url 来绕过本地模型加载 # 设置一个假的 base_url,这样 HuggingFaceEmbedding 就不会加载 SentenceTransformer config_dict = { "custom_fact_extraction_prompt": config.get_custom_fact_extraction_prompt(), "vector_store": { "provider": "pgvector", "config": { "connection_pool": connection_pool, "collection_name": f"mem0_{agent_id}".replace("-", "_")[:50], # 按 agent_id 隔离 "embedding_model_dims": 384, # paraphrase-multilingual-MiniLM-L12-v2 的维度 } }, # 使用 huggingface_base_url 绕过模型加载(稍后会被替换为自定义 embedder) "embedder": { "provider": "huggingface", "config": { "huggingface_base_url": "http://dummy-url-that-will-be-replaced", "api_key": "dummy-key" # 占位符,防止 OpenAI client 验证失败 } } } # 添加 LangChain LLM 配置(如果提供了) if config and config.llm_instance is not None: config_dict["llm"] = { "provider": "langchain", "config": {"model": config.llm_instance} } logger.info( f"Configured LangChain LLM for Mem0: {type(config.llm_instance).__name__}" ) else: # 如果没有提供 LLM,使用默认的 openai 配置 # Mem0 的 LLM 用于提取记忆事实 from utils.settings import MASTERKEY, BACKEND_HOST import os llm_api_key = os.environ.get("OPENAI_API_KEY", "") or MASTERKEY config_dict["llm"] = { "provider": "openai", "config": { "model": "gpt-4o-mini", "api_key": llm_api_key, "openai_base_url": BACKEND_HOST # 使用自定义 backend } } # 创建 Mem0 实例 mem = Memory.from_config(config_dict) logger.debug(f"Original embedder type: {type(mem.embedding_model).__name__}") logger.debug(f"Original embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}") # 替换为自定义 embedder,复用项目中已加载的模型 # 这样 Mem0 就不会重复加载模型 mem.embedding_model = custom_embedder logger.debug(f"Replaced embedder type: {type(mem.embedding_model).__name__}") logger.debug(f"Replaced embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}") logger.info("Replaced Mem0 embedder with CustomMem0Embedding (reusing existing model)") logger.info( f"Created Mem0 instance: user={user_id}, agent={agent_id}" ) return mem async def recall_memories( self, query: str, user_id: str, agent_id: str, config: Optional[Mem0Config] = None, ) -> List[Dict[str, Any]]: """召回相关记忆(用户级别,跨会话共享) Args: query: 查询文本 user_id: 用户 ID agent_id: Agent/Bot ID config: Mem0 配置 Returns: 记忆列表,每个记忆包含 content, similarity 等字段 """ try: mem = await self.get_mem0(user_id, agent_id, "default", config) # 调用 search 进行语义搜索(使用 agent_id 参数过滤) limit = config.semantic_search_top_k if config else 20 results = mem.search( query=query, limit=limit, user_id=user_id, agent_id=agent_id, ) # 转换为统一格式 memories = [] for result in results["results"]: # Mem0 返回结果可能是字符串或字典 content = result.get("memory", "") score = result.get("score", 0.0) result_metadata = result.get("metadata", {}) memory = { "content": content, "similarity": score, "metadata": result_metadata, "fact_type": result_metadata.get("category", "fact"), } memories.append(memory) logger.info(f"Recalled {len(memories)} memories for user={user_id}, query: {query[:50]}...") return memories except Exception as e: logger.error(f"Failed to recall memories: {e}") return [] async def add_memory( self, text: str, user_id: str, agent_id: str, metadata: Optional[Dict[str, Any]] = None, config: Optional[Mem0Config] = None, ) -> Dict[str, Any]: """添加新记忆(用户级别,跨会话共享) Args: text: 记忆文本 user_id: 用户 ID agent_id: Agent/Bot ID metadata: 额外的元数据 config: Mem0 配置(包含 LLM 实例用于记忆提取) Returns: 添加的记忆结果 """ try: mem = await self.get_mem0(user_id, agent_id, "default", config) # 添加记忆(使用 agent_id 参数) result = mem.add( text, user_id=user_id, agent_id=agent_id, metadata=metadata or {} ) logger.info(f"Added memory for user={user_id}, agent={agent_id}: {result}") return result except Exception as e: logger.error(f"Failed to add memory: {e}") return {} async def get_all_memories( self, user_id: str, agent_id: str, ) -> List[Dict[str, Any]]: """获取用户的所有记忆(用户级别) Args: user_id: 用户 ID agent_id: Agent/Bot ID Returns: 记忆列表 """ try: mem = await self.get_mem0(user_id, agent_id, "default") # 获取所有记忆 memories = mem.get_all(user_id=user_id) # 过滤 agent_id filtered_memories = [ m for m in memories if m.get("metadata", {}).get("agent_id") == agent_id ] return filtered_memories except Exception as e: logger.error(f"Failed to get all memories: {e}") return [] def clear_cache(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> None: """清除缓存的 Mem0 实例 Args: user_id: 用户 ID(如果为 None,清除所有) agent_id: Agent ID(如果为 None,清除所有) """ if user_id is None and agent_id is None: self._instances.clear() logger.info("Cleared all Mem0 instances from cache") else: keys_to_remove = [] for key in self._instances: # 新格式: "user_id:agent_id:llm_model_name" 或 "user_id:agent_id" parts = key.split(":") if len(parts) >= 2: u_id = parts[0] a_id = parts[1] if user_id and u_id != user_id: continue if agent_id and a_id != agent_id: continue keys_to_remove.append(key) for key in keys_to_remove: del self._instances[key] logger.info(f"Cleared {len(keys_to_remove)} Mem0 instances from cache") async def close(self) -> None: """关闭管理器并清理资源""" logger.info("Closing Mem0Manager...") # 清理缓存的实例 self._instances.clear() # 注意:不关闭共享的同步连接池(由 DBPoolManager 管理) self._initialized = False logger.info("Mem0Manager closed") # 全局单例 _global_manager: Optional[Mem0Manager] = None def get_mem0_manager() -> Mem0Manager: """获取全局 Mem0Manager 单例 Returns: Mem0Manager 实例 """ global _global_manager if _global_manager is None: _global_manager = Mem0Manager() return _global_manager async def init_global_mem0( sync_pool: pool.SimpleConnectionPool, ) -> Mem0Manager: """初始化全局 Mem0Manager Args: sync_pool: PostgreSQL 同步连接池(从 DBPoolManager.sync_pool 获取) Returns: Mem0Manager 实例 """ manager = get_mem0_manager() manager._sync_pool = sync_pool await manager.initialize() return manager async def close_global_mem0() -> None: """关闭全局 Mem0Manager""" global _global_manager if _global_manager is not None: await _global_manager.close()