""" Mem0 连接和实例管理器 负责管理 Mem0 客户端实例的创建、缓存和生命周期 """ import logging import asyncio import threading import concurrent.futures from typing import Any, Dict, List, Optional, Literal from collections import OrderedDict from embedding.manager import GlobalModelManager, get_model_manager import json_repair from psycopg2 import pool from utils.settings import ( MEM0_POOL_SIZE ) from .mem0_config import Mem0Config logger = logging.getLogger("app") # ============================================================================ # 自定义 Embedding 类,使用项目中已有的 GlobalModelManager # 避免重复加载模型 # ============================================================================ class CustomMem0Embedding: """ 自定义 Mem0 Embedding 类,直接使用项目中已有的 GlobalModelManager 这样 Mem0 就不需要再次加载同一个模型,节省内存 """ _model = None # 类变量,缓存模型实例 _lock = threading.Lock() # 线程安全锁 _executor = None # 线程池执行器 def __init__(self, config: Optional[Any] = None): """初始化自定义 Embedding""" # 创建一个简单的 config 对象来兼容 Mem0 的 telemetry 代码 if config is None: config = type('Config', (), {'embedding_dims': 384})() self.config = config @property def embedding_dims(self): """获取 embedding 维度""" return 384 # gte-tiny 的维度 def _get_model_sync(self): """同步获取模型,避免 asyncio.run()""" # 首先尝试从 manager 获取已加载的模型 manager = get_model_manager() model = manager.get_model_sync() if model is not None: # 缓存模型 CustomMem0Embedding._model = model return model # 如果模型未加载,使用线程池运行异步初始化 if CustomMem0Embedding._executor is None: CustomMem0Embedding._executor = concurrent.futures.ThreadPoolExecutor( max_workers=1, thread_name_prefix="mem0_embed" ) # 在独立线程中运行异步代码 def run_async_in_thread(): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: result = loop.run_until_complete(manager.get_model()) return result finally: loop.close() future = CustomMem0Embedding._executor.submit(run_async_in_thread) model = future.result(timeout=30) # 30秒超时 # 缓存模型 CustomMem0Embedding._model = model return model def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None): """ 获取文本的 embedding 向量(同步方法,供 Mem0 调用) Args: text: 要嵌入的文本(字符串或列表) memory_action: 记忆操作类型 (add/search/update),当前未使用 Returns: list: embedding 向量 """ # 线程安全地获取模型 if CustomMem0Embedding._model is None: with CustomMem0Embedding._lock: if CustomMem0Embedding._model is None: self._get_model_sync() model = CustomMem0Embedding._model embeddings = model.encode(text, convert_to_numpy=True) return embeddings.tolist() # Monkey patch: 使用 json_repair 替换 mem0 的 remove_code_blocks def _remove_code_blocks_with_repair(content: str) -> str: """ 使用 json_repair 替换 mem0 的 remove_code_blocks 函数 json_repair.loads 会自动处理: - 移除代码块标记(```json, ``` 等) - 修复损坏的 JSON(如尾随逗号、注释、单引号等) """ import re content_stripped = content.strip() try: # json_repair.loads 会自动去除代码块并修复 JSON result = json_repair.loads(content_stripped) if isinstance(result, (dict, list)): import json return json.dumps(result, ensure_ascii=False) # 如果返回空字符串(非 JSON 输入),回退到原内容 if result == "" and content_stripped != "": # 尝试简单的代码块去除(降级处理) pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$" match = re.match(pattern, content_stripped) if match: return match.group(1).strip() return content_stripped return str(result) except Exception: # 如果解析失败,尝试简单的代码块去除(降级处理) pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$" match = re.match(pattern, content_stripped) if match: return match.group(1).strip() return content_stripped # 执行 monkey patch(在 mem0 导入之前或之后) try: import sys import mem0.memory.utils as mem0_utils mem0_utils.remove_code_blocks = _remove_code_blocks_with_repair # 如果 mem0.memory.main 已经导入,也要 patch 它的本地引用 if 'mem0.memory.main' in sys.modules: import mem0.memory.main mem0.memory.main.remove_code_blocks = _remove_code_blocks_with_repair logger.info("Successfully patched mem0.memory.main.remove_code_blocks with json_repair") else: logger.info("Successfully patched mem0.memory.utils.remove_code_blocks with json_repair") except ImportError: # mem0 还未导入,patch 将在首次导入时生效 pass except Exception as e: logger.warning(f"Failed to patch mem0 remove_code_blocks: {e}") class Mem0Manager: """ Mem0 连接和实例管理器 主要功能: 1. 管理 Mem0 实例的创建和缓存 2. 支持多租户隔离(user_id + agent_id) 3. 使用共享的同步连接池(由 DBPoolManager 提供) 4. 提供记忆召回和存储接口 """ def __init__( self, sync_pool: Optional[pool.SimpleConnectionPool] = None, ): """初始化 Mem0Manager Args: sync_pool: PostgreSQL 同步连接池(与 DBPoolManager 共享) """ self._sync_pool = sync_pool # 使用 OrderedDict 实现 LRU 缓存,最多保留 50 个实例 self._instances: OrderedDict[str, Any] = OrderedDict() self._max_instances = MEM0_POOL_SIZE/2 # 最大缓存实例数 self._initialized = False async def initialize(self) -> None: """初始化 Mem0Manager 创建数据库表结构(如果不存在) """ if self._initialized: return logger.info("Initializing Mem0Manager...") try: # Mem0 会自动创建表结构,这里只需验证连接 if self._sync_pool: logger.info("Mem0Manager initialized successfully") else: logger.warning("No database configuration provided for Mem0") self._initialized = True except Exception as e: logger.error(f"Failed to initialize Mem0Manager: {e}") # 不抛出异常,允许系统在没有 Mem0 的情况下运行 def _get_connection_pool(self) -> Optional[pool.SimpleConnectionPool]: """获取同步数据库连接池(Mem0 需要) Returns: psycopg2.pool 连接池 """ return self._sync_pool async def get_mem0( self, user_id: str, agent_id: str, session_id: str, config: Optional[Mem0Config] = None, ) -> Any: """获取或创建 Mem0 实例 Args: user_id: 用户 ID(对应 entity_id) agent_id: Agent/Bot ID(对应 process_id) session_id: 会话 ID config: Mem0 配置 Returns: Mem0 实例 """ # 缓存键包含 LLM 实例 ID,以确保不同 LLM 使用不同实例 llm_suffix = "" if config and config.llm_instance is not None: llm_suffix = f":{id(config.llm_instance)}" cache_key = f"{user_id}:{agent_id}{llm_suffix}" # 检查缓存(同时移动到末尾表示最近使用) if cache_key in self._instances: self._instances.move_to_end(cache_key) return self._instances[cache_key] # 检查缓存大小,超过则移除最旧的 if len(self._instances) >= self._max_instances: removed_key, _ = self._instances.popitem(last=False) logger.debug(f"Mem0 instance cache full, removed oldest entry: {removed_key}") # 创建新实例 mem0_instance = await self._create_mem0_instance( user_id=user_id, agent_id=agent_id, session_id=session_id, config=config, ) # 缓存实例(新实例自动在末尾) self._instances[cache_key] = mem0_instance return mem0_instance async def _create_mem0_instance( self, user_id: str, agent_id: str, session_id: str, config: Optional[Mem0Config] = None, ) -> Any: """创建新的 Mem0 实例 Args: user_id: 用户 ID agent_id: Agent/Bot ID session_id: 会话 ID config: Mem0 配置(包含 LLM 实例) Returns: Mem0 Memory 实例 """ try: from mem0 import Memory except ImportError: logger.error("mem0 package not installed") raise RuntimeError("mem0 package is required but not installed") # 获取同步连接池 connection_pool = self._get_connection_pool() if not connection_pool: raise ValueError("Database connection pool not available") # 创建自定义 embedder(使用共享模型,避免重复加载) custom_embedder = CustomMem0Embedding() # 配置 Mem0 使用 Pgvector # 注意:这里使用 huggingface_base_url 来绕过本地模型加载 # 设置一个假的 base_url,这样 HuggingFaceEmbedding 就不会加载 SentenceTransformer config_dict = { "custom_fact_extraction_prompt": config.get_custom_fact_extraction_prompt(), "vector_store": { "provider": "pgvector", "config": { "connection_pool": connection_pool, "collection_name": f"mem0_{agent_id}".replace("-", "_")[:50], # 按 agent_id 隔离 "embedding_model_dims": 384, # paraphrase-multilingual-MiniLM-L12-v2 的维度 } }, # 使用 huggingface_base_url 绕过模型加载(稍后会被替换为自定义 embedder) "embedder": { "provider": "huggingface", "config": { "huggingface_base_url": "http://dummy-url-that-will-be-replaced", "api_key": "dummy-key" # 占位符,防止 OpenAI client 验证失败 } } } # 添加 LangChain LLM 配置(如果提供了) if config and config.llm_instance is not None: config_dict["llm"] = { "provider": "langchain", "config": {"model": config.llm_instance} } logger.info( f"Configured LangChain LLM for Mem0: {type(config.llm_instance).__name__}" ) else: # 如果没有提供 LLM,使用默认的 openai 配置 # Mem0 的 LLM 用于提取记忆事实 from utils.settings import MASTERKEY, BACKEND_HOST import os llm_api_key = os.environ.get("OPENAI_API_KEY", "") or MASTERKEY config_dict["llm"] = { "provider": "openai", "config": { "model": "gpt-4o-mini", "api_key": llm_api_key, "openai_base_url": BACKEND_HOST # 使用自定义 backend } } # 创建 Mem0 实例 mem = Memory.from_config(config_dict) logger.debug(f"Original embedder type: {type(mem.embedding_model).__name__}") logger.debug(f"Original embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}") # 替换为自定义 embedder,复用项目中已加载的模型 # 这样 Mem0 就不会重复加载模型 mem.embedding_model = custom_embedder logger.debug(f"Replaced embedder type: {type(mem.embedding_model).__name__}") logger.debug(f"Replaced embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}") logger.info("Replaced Mem0 embedder with CustomMem0Embedding (reusing existing model)") logger.info( f"Created Mem0 instance: user={user_id}, agent={agent_id}" ) return mem async def recall_memories( self, query: str, user_id: str, agent_id: str, config: Optional[Mem0Config] = None, ) -> List[Dict[str, Any]]: """召回相关记忆(用户级别,跨会话共享) Args: query: 查询文本 user_id: 用户 ID agent_id: Agent/Bot ID config: Mem0 配置 Returns: 记忆列表,每个记忆包含 content, similarity 等字段 """ try: mem = await self.get_mem0(user_id, agent_id, "default", config) # 调用 search 进行语义搜索(使用 agent_id 参数过滤) limit = config.semantic_search_top_k if config else 20 results = mem.search( query=query, limit=limit, user_id=user_id, agent_id=agent_id, ) # 转换为统一格式 memories = [] for result in results["results"]: # Mem0 返回结果可能是字符串或字典 content = result.get("memory", "") score = result.get("score", 0.0) result_metadata = result.get("metadata", {}) memory = { "content": content, "similarity": score, "metadata": result_metadata, "fact_type": result_metadata.get("category", "fact"), } memories.append(memory) logger.info(f"Recalled {len(memories)} memories for user={user_id}, query: {query[:50]}...") return memories except Exception as e: logger.error(f"Failed to recall memories: {e}") return [] async def add_memory( self, text: str, user_id: str, agent_id: str, metadata: Optional[Dict[str, Any]] = None, config: Optional[Mem0Config] = None, ) -> Dict[str, Any]: """添加新记忆(用户级别,跨会话共享) Args: text: 记忆文本 user_id: 用户 ID agent_id: Agent/Bot ID metadata: 额外的元数据 config: Mem0 配置(包含 LLM 实例用于记忆提取) Returns: 添加的记忆结果 """ try: mem = await self.get_mem0(user_id, agent_id, "default", config) # 添加记忆(使用 agent_id 参数) result = mem.add( text, user_id=user_id, agent_id=agent_id, metadata=metadata or {} ) logger.info(f"Added memory for user={user_id}, agent={agent_id}: {result}") return result except Exception as e: logger.error(f"Failed to add memory: {e}") return {} async def get_all_memories( self, user_id: str, agent_id: str, ) -> List[Dict[str, Any]]: """获取用户的所有记忆(用户级别) Args: user_id: 用户 ID agent_id: Agent/Bot ID Returns: 记忆列表 """ try: mem = await self.get_mem0(user_id, agent_id, "default") # 获取所有记忆 memories = mem.get_all(user_id=user_id) # 过滤 agent_id filtered_memories = [ m for m in memories if m.get("metadata", {}).get("agent_id") == agent_id ] return filtered_memories except Exception as e: logger.error(f"Failed to get all memories: {e}") return [] def clear_cache(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> None: """清除缓存的 Mem0 实例 Args: user_id: 用户 ID(如果为 None,清除所有) agent_id: Agent ID(如果为 None,清除所有) """ if user_id is None and agent_id is None: self._instances.clear() logger.info("Cleared all Mem0 instances from cache") else: keys_to_remove = [] for key in self._instances: # 新格式: "user_id:agent_id:llm_model_name" 或 "user_id:agent_id" parts = key.split(":") if len(parts) >= 2: u_id = parts[0] a_id = parts[1] if user_id and u_id != user_id: continue if agent_id and a_id != agent_id: continue keys_to_remove.append(key) for key in keys_to_remove: del self._instances[key] logger.info(f"Cleared {len(keys_to_remove)} Mem0 instances from cache") async def close(self) -> None: """关闭管理器并清理资源""" logger.info("Closing Mem0Manager...") # 清理缓存的实例 self._instances.clear() # 注意:不关闭共享的同步连接池(由 DBPoolManager 管理) self._initialized = False logger.info("Mem0Manager closed") # 全局单例 _global_manager: Optional[Mem0Manager] = None def get_mem0_manager() -> Mem0Manager: """获取全局 Mem0Manager 单例 Returns: Mem0Manager 实例 """ global _global_manager if _global_manager is None: _global_manager = Mem0Manager() return _global_manager async def init_global_mem0( sync_pool: pool.SimpleConnectionPool, ) -> Mem0Manager: """初始化全局 Mem0Manager Args: sync_pool: PostgreSQL 同步连接池(从 DBPoolManager.sync_pool 获取) Returns: Mem0Manager 实例 """ manager = get_mem0_manager() manager._sync_pool = sync_pool await manager.initialize() return manager async def close_global_mem0() -> None: """关闭全局 Mem0Manager""" global _global_manager if _global_manager is not None: await _global_manager.close()