""" Memori 连接和实例管理器 负责管理 Memori 客户端实例的创建、缓存和生命周期 """ import asyncio import logging from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Optional from psycopg_pool import AsyncConnectionPool from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, Session from .memori_config import MemoriConfig logger = logging.getLogger("app") class MemoriManager: """ Memori 连接和实例管理器 主要功能: 1. 管理 Memori 实例的创建和缓存 2. 支持多租户隔离(entity_id + process_id) 3. 处理数据库连接和会话管理 4. 提供记忆召回和存储接口 """ def __init__( self, db_pool: Optional[AsyncConnectionPool] = None, db_url: Optional[str] = None, api_key: Optional[str] = None, ): """初始化 MemoriManager Args: db_pool: PostgreSQL 异步连接池(与 Checkpointer 共享) db_url: 数据库连接 URL(如果不使用连接池) api_key: Memori API 密钥(用于高级增强功能) """ self._db_pool = db_pool self._db_url = db_url self._api_key = api_key # 缓存 Memori 实例: key = f"{entity_id}:{process_id}" self._instances: Dict[str, Any] = {} self._sync_engines: Dict[str, Any] = {} self._initialized = False async def initialize(self) -> None: """初始化 MemoriManager 创建数据库表结构(如果不存在) """ if self._initialized: return logger.info("Initializing MemoriManager...") try: # 创建第一个 Memori 实例来初始化表结构 if self._db_pool or self._db_url: db_url = self._db_url or getattr(self._db_pool, "_url", None) if db_url: await self._build_schema(db_url) self._initialized = True logger.info("MemoriManager initialized successfully") except Exception as e: logger.error(f"Failed to initialize MemoriManager: {e}") # 不抛出异常,允许系统在没有 Memori 的情况下运行 async def _build_schema(self, db_url: str) -> None: """构建 Memori 数据库表结构 Args: db_url: 数据库连接 URL """ try: from memori import Memori # 创建同步引擎用于初始化 engine = create_engine(db_url) SessionLocal = sessionmaker(bind=engine) # 创建 Memori 实例并构建表结构 mem = Memori(conn=SessionLocal) mem.config.storage.build() logger.info("Memori schema built successfully") except ImportError: logger.warning("memori package not available, skipping schema build") except Exception as e: logger.error(f"Failed to build Memori schema: {e}") def _get_sync_session(self, db_url: str) -> Session: """获取同步数据库会话(Memori 需要) Args: db_url: 数据库连接 URL Returns: SQLAlchemy Session """ if db_url not in self._sync_engines: from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker engine = create_engine(db_url, pool_pre_ping=True) self._sync_engines[db_url] = sessionmaker(bind=engine) return self._sync_engines[db_url]() async def get_memori( self, entity_id: str, process_id: str, session_id: str, config: Optional[MemoriConfig] = None, ) -> Any: """获取或创建 Memori 实例 Args: entity_id: 实体 ID(通常是 user_identifier) process_id: 进程 ID(通常是 bot_id) session_id: 会话 ID config: Memori 配置 Returns: Memori 实例 """ cache_key = f"{entity_id}:{process_id}" # 检查缓存 if cache_key in self._instances: memori_instance = self._instances[cache_key] # 更新会话 memori_instance.config.session_id = session_id return memori_instance # 创建新实例 memori_instance = await self._create_memori_instance( entity_id=entity_id, process_id=process_id, session_id=session_id, config=config, ) # 缓存实例 self._instances[cache_key] = memori_instance return memori_instance async def _create_memori_instance( self, entity_id: str, process_id: str, session_id: str, config: Optional[MemoriConfig] = None, ) -> Any: """创建新的 Memori 实例 Args: entity_id: 实体 ID process_id: 进程 ID session_id: 会话 ID config: Memori 配置 Returns: Memori 实例 """ try: from memori import Memori except ImportError: logger.error("memori package not installed") raise RuntimeError("memori package is required but not installed") # 

        # Resolve the database connection URL (the shared pool's conninfo takes precedence)
        db_url = self._db_url
        if self._db_pool is not None and getattr(self._db_pool, "conninfo", None):
            db_url = str(self._db_pool.conninfo)

        if not db_url:
            raise ValueError("Either db_pool or db_url must be provided")

        # Create a synchronous session (Memori currently requires a synchronous connection)
        session = self._get_sync_session(db_url)

        # Create the Memori instance
        mem = Memori(conn=session)

        # Set the API key if one was provided
        if self._api_key or (config and config.api_key):
            api_key = config.api_key if config else self._api_key
            mem.config.api_key = api_key

        # Set attribution for multi-tenant isolation
        mem.attribution(entity_id=entity_id, process_id=process_id)

        # Bind the current session
        mem.config.session_id = session_id

        # Configure recall parameters
        if config:
            mem.config.recall_facts_limit = config.semantic_search_top_k
            mem.config.recall_relevance_threshold = config.semantic_search_threshold
            mem.config.recall_embeddings_limit = config.semantic_search_embeddings_limit

        logger.info(
            f"Created Memori instance: entity={entity_id}, process={process_id}, session={session_id}"
        )
        return mem

    async def recall_memories(
        self,
        query: str,
        entity_id: str,
        process_id: str,
        session_id: str,
        config: Optional[MemoriConfig] = None,
    ) -> List[Dict[str, Any]]:
        """Recall relevant memories.

        Args:
            query: Query text
            entity_id: Entity ID
            process_id: Process ID
            session_id: Session ID
            config: Memori configuration

        Returns:
            List of memories; each entry contains content, similarity, and related fields
        """
        try:
            mem = await self.get_memori(entity_id, process_id, session_id, config)

            # Run a semantic search via recall()
            results = mem.recall(query=query, limit=config.semantic_search_top_k if config else 5)

            # Convert results to a uniform format
            memories = []
            for result in results:
                memory = {
                    "content": result.get("content", ""),
                    "similarity": result.get("similarity", 0.0),
                    "fact_type": result.get("fact_type", "unknown"),
                    "created_at": result.get("created_at"),
                }
                # Filter out low-relevance memories
                threshold = config.semantic_search_threshold if config else 0.7
                if memory["similarity"] >= threshold:
                    memories.append(memory)

            logger.info(f"Recalled {len(memories)} memories for query: {query[:50]}...")
            return memories
        except Exception as e:
            logger.error(f"Failed to recall memories: {e}")
            return []

    async def wait_for_augmentation(
        self,
        entity_id: str,
        process_id: str,
        session_id: str,
        timeout: Optional[float] = None,
    ) -> None:
        """Wait for background augmentation tasks to finish.

        Args:
            entity_id: Entity ID
            process_id: Process ID
            session_id: Session ID
            timeout: Timeout in seconds
        """
        try:
            mem = await self.get_memori(entity_id, process_id, session_id)

            loop = asyncio.get_running_loop()
            if timeout:
                # Run the synchronous wait() in a thread pool with a timeout
                await loop.run_in_executor(None, lambda: mem.augmentation.wait(timeout=timeout))
            else:
                # Wait indefinitely
                await loop.run_in_executor(None, mem.augmentation.wait)
        except Exception as e:
            logger.error(f"Failed to wait for augmentation: {e}")

    def clear_cache(self, entity_id: Optional[str] = None, process_id: Optional[str] = None) -> None:
        """Clear cached Memori instances.

        Args:
            entity_id: Entity ID (clear all entities if None)
            process_id: Process ID (clear all processes if None)
        """
        if entity_id is None and process_id is None:
            self._instances.clear()
            logger.info("Cleared all Memori instances from cache")
        else:
            keys_to_remove = []
            for key in self._instances:
                # Cache keys have the form f"{entity_id}:{process_id}"
                e_id, p_id = key.split(":", 1)
                if entity_id and e_id != entity_id:
                    continue
                if process_id and p_id != process_id:
                    continue
                keys_to_remove.append(key)

            for key in keys_to_remove:
                del self._instances[key]

            logger.info(f"Cleared {len(keys_to_remove)} Memori instances from cache")

    async def close(self) -> None:
        """Close the manager and release resources."""
        logger.info("Closing MemoriManager...")

        # Drop cached instances
        self._instances.clear()

        # Dispose of the synchronous engines
        for engine, _ in self._sync_engines.values():
            try:
                engine.dispose()
            except Exception as e:
                logger.error(f"Error closing engine: {e}")
        self._sync_engines.clear()

        self._initialized = False
logger.info("MemoriManager closed") # 全局单例 _global_manager: Optional[MemoriManager] = None def get_memori_manager() -> MemoriManager: """获取全局 MemoriManager 单例 Returns: MemoriManager 实例 """ global _global_manager if _global_manager is None: _global_manager = MemoriManager() return _global_manager async def init_global_memori( db_pool: Optional[AsyncConnectionPool] = None, db_url: Optional[str] = None, api_key: Optional[str] = None, ) -> MemoriManager: """初始化全局 MemoriManager Args: db_pool: PostgreSQL 连接池 db_url: 数据库连接 URL api_key: Memori API 密钥 Returns: MemoriManager 实例 """ manager = get_memori_manager() manager._db_pool = db_pool manager._db_url = db_url manager._api_key = api_key await manager.initialize() return manager async def close_global_memori() -> None: """关闭全局 MemoriManager""" global _global_manager if _global_manager is not None: await _global_manager.close()