Add Memori (https://github.com/MemoriLabs/Memori) integration for persistent cross-session memory capabilities in both create_agent and create_deep_agent.

## New Files
- agent/memori_config.py: MemoriConfig dataclass for configuration
- agent/memori_manager.py: MemoriManager for connection and instance management
- agent/memori_middleware.py: MemoriMiddleware for memory recall/storage
- tests/: Unit tests for Memori components

## Modified Files
- agent/agent_config.py: Added enable_memori, memori_semantic_search_top_k, etc.
- agent/deep_assistant.py: Integrated MemoriMiddleware into init_agent()
- utils/settings.py: Added MEMORI_* environment variables
- pyproject.toml: Added memori>=3.1.0 dependency

## Features
- Semantic memory search with configurable top-k and threshold
- Multi-tenant isolation (entity_id=user, process_id=bot, session_id)
- Memory injection into system prompt
- Background asynchronous memory augmentation
- Graceful degradation when Memori is unavailable

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
384 lines
12 KiB
Python
384 lines
12 KiB
Python
"""
Memori connection and instance manager.

Manages the creation, caching, and lifecycle of Memori client instances.
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from psycopg_pool import AsyncConnectionPool
|
||
from sqlalchemy import create_engine
|
||
from sqlalchemy.orm import sessionmaker, Session
|
||
|
||
from .memori_config import MemoriConfig
|
||
|
||
# Module-level logger; it uses the fixed logger name "app" rather than __name__.
logger = logging.getLogger("app")
|
||
|
||
|
||
class MemoriManager:
    """Connection and instance manager for Memori.

    Responsibilities:
        1. Create and cache Memori instances.
        2. Provide multi-tenant isolation (``entity_id`` + ``process_id``).
        3. Own the synchronous database engines / session factories.
        4. Expose memory recall and augmentation-wait helpers.

    Cached instances are keyed by the string ``f"{entity_id}:{process_id}"``.
    """

    def __init__(
        self,
        db_pool: "Optional[AsyncConnectionPool]" = None,
        db_url: Optional[str] = None,
        api_key: Optional[str] = None,
    ):
        """Store connection settings; no I/O happens here.

        Args:
            db_pool: PostgreSQL async connection pool (shared with the
                checkpointer).
            db_url: Database connection URL (when no pool is used).
            api_key: Memori API key (used for advanced augmentation).
        """
        self._db_pool = db_pool
        self._db_url = db_url
        self._api_key = api_key

        # Cached Memori instances: key = f"{entity_id}:{process_id}"
        self._instances: Dict[str, Any] = {}
        # Cached SQLAlchemy session factories, one per database URL. Each
        # factory is bound to its own engine, which close() disposes.
        self._sync_engines: Dict[str, Any] = {}
        self._initialized = False

    def _resolve_db_url(self) -> Optional[str]:
        """Return the database URL to use, or ``None`` if none is configured.

        The explicit ``db_url`` wins; otherwise the pool's ``_url`` attribute
        is used when present and non-empty. Both schema creation and instance
        creation go through this helper so they always target the same
        database.
        """
        if self._db_url:
            return self._db_url
        # NOTE(review): psycopg_pool appears to expose ``conninfo`` rather
        # than ``_url`` - confirm that a pool-only setup can resolve a URL.
        pool_url = getattr(self._db_pool, "_url", None)
        return str(pool_url) if pool_url else None

    def _get_session_factory(self, db_url: str) -> Any:
        """Return the cached session factory for ``db_url``, creating it once.

        Args:
            db_url: Database connection URL.

        Returns:
            A SQLAlchemy ``sessionmaker`` bound to an engine for ``db_url``.
        """
        factory = self._sync_engines.get(db_url)
        if factory is None:
            engine = create_engine(db_url, pool_pre_ping=True)
            factory = sessionmaker(bind=engine)
            self._sync_engines[db_url] = factory
        return factory

    @staticmethod
    def _apply_recall_config(mem: Any, config: "Optional[MemoriConfig]") -> None:
        """Copy the recall settings from ``config`` onto a Memori instance."""
        if not config:
            return
        mem.config.recall_facts_limit = config.semantic_search_top_k
        mem.config.recall_relevance_threshold = config.semantic_search_threshold
        mem.config.recall_embeddings_limit = config.semantic_search_embeddings_limit

    @staticmethod
    def _dispose_engine(holder: Any) -> None:
        """Dispose the engine behind a cached entry.

        ``holder`` is normally a ``sessionmaker``, which has no ``dispose()``
        of its own; the engine is read from its ``kw["bind"]``. Objects that
        expose ``dispose()`` directly are disposed as-is.
        """
        dispose = getattr(holder, "dispose", None)
        if dispose is None:
            bind = getattr(holder, "kw", {}).get("bind")
            dispose = getattr(bind, "dispose", None)
        if callable(dispose):
            dispose()

    async def initialize(self) -> None:
        """Initialize the manager.

        Builds the Memori database schema when a database URL can be
        resolved. Failures are logged and swallowed so the application can
        keep running without Memori.
        """
        if self._initialized:
            return

        logger.info("Initializing MemoriManager...")

        try:
            db_url = self._resolve_db_url()
            if db_url:
                await self._build_schema(db_url)

            self._initialized = True
            logger.info("MemoriManager initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize MemoriManager: {e}")
            # Deliberately not re-raised: the system may run without Memori.

    async def _build_schema(self, db_url: str) -> None:
        """Build the Memori database schema.

        Errors (including a missing ``memori`` package) are logged, not
        raised.

        Args:
            db_url: Database connection URL.
        """
        try:
            from memori import Memori

            # Reuse the cached session factory so the engine created here is
            # shared with later instances and disposed by close().
            mem = Memori(conn=self._get_session_factory(db_url))
            mem.config.storage.build()

            logger.info("Memori schema built successfully")
        except ImportError:
            logger.warning("memori package not available, skipping schema build")
        except Exception as e:
            logger.error(f"Failed to build Memori schema: {e}")

    def _get_sync_session(self, db_url: str) -> "Session":
        """Open a new synchronous database session (Memori needs sync access).

        Args:
            db_url: Database connection URL.

        Returns:
            A new SQLAlchemy ``Session``; the caller owns it.
        """
        return self._get_session_factory(db_url)()

    async def get_memori(
        self,
        entity_id: str,
        process_id: str,
        session_id: str,
        config: "Optional[MemoriConfig]" = None,
    ) -> Any:
        """Return the cached Memori instance for a tenant, creating it if needed.

        On a cache hit the instance's ``session_id`` is updated and, when
        ``config`` is given, its recall settings are refreshed.

        Args:
            entity_id: Entity id (usually the user identifier).
            process_id: Process id (usually the bot id).
            session_id: Session id.
            config: Memori configuration.

        Returns:
            The Memori instance.

        Raises:
            RuntimeError: If the ``memori`` package is not installed.
            ValueError: If no database URL can be resolved.
        """
        cache_key = f"{entity_id}:{process_id}"

        cached = self._instances.get(cache_key)
        if cached is not None:
            cached.config.session_id = session_id
            self._apply_recall_config(cached, config)
            return cached

        created = await self._create_memori_instance(
            entity_id=entity_id,
            process_id=process_id,
            session_id=session_id,
            config=config,
        )
        self._instances[cache_key] = created
        return created

    async def _create_memori_instance(
        self,
        entity_id: str,
        process_id: str,
        session_id: str,
        config: "Optional[MemoriConfig]" = None,
    ) -> Any:
        """Create a new Memori instance.

        Args:
            entity_id: Entity id.
            process_id: Process id.
            session_id: Session id.
            config: Memori configuration.

        Returns:
            The configured Memori instance.

        Raises:
            RuntimeError: If the ``memori`` package is not installed.
            ValueError: If no database URL can be resolved.
        """
        try:
            from memori import Memori
        except ImportError as exc:
            logger.error("memori package not installed")
            raise RuntimeError("memori package is required but not installed") from exc

        db_url = self._resolve_db_url()
        if not db_url:
            raise ValueError("Either db_pool or db_url must be provided")

        # Pass the session factory, exactly as _build_schema does, so a single
        # Session object is not shared across requests.
        mem = Memori(conn=self._get_session_factory(db_url))

        # A key on the config takes precedence; otherwise fall back to the
        # manager-level key instead of overwriting it with None.
        api_key = (config.api_key if config else None) or self._api_key
        if api_key:
            mem.config.api_key = api_key

        mem.attribution(entity_id=entity_id, process_id=process_id)
        mem.config.session_id = session_id
        self._apply_recall_config(mem, config)

        logger.info(
            f"Created Memori instance: entity={entity_id}, process={process_id}, session={session_id}"
        )

        return mem

    async def recall_memories(
        self,
        query: str,
        entity_id: str,
        process_id: str,
        session_id: str,
        config: "Optional[MemoriConfig]" = None,
    ) -> List[Dict[str, Any]]:
        """Recall memories relevant to ``query``.

        Results below the similarity threshold (``config`` value, else 0.7)
        are dropped; a missing or ``None`` similarity counts as 0.0. Any
        failure is logged and yields an empty list.

        Args:
            query: Query text.
            entity_id: Entity id.
            process_id: Process id.
            session_id: Session id.
            config: Memori configuration.

        Returns:
            A list of dicts with ``content``, ``similarity``, ``fact_type``
            and ``created_at`` keys.
        """
        try:
            mem = await self.get_memori(entity_id, process_id, session_id, config)

            limit = config.semantic_search_top_k if config else 5
            threshold = config.semantic_search_threshold if config else 0.7

            # NOTE(review): recall() is called synchronously on the event
            # loop - confirm whether it is safe to move to a worker thread.
            results = mem.recall(query=query, limit=limit)

            memories = []
            for result in results or []:
                similarity = result.get("similarity")
                if similarity is None:
                    similarity = 0.0
                if similarity < threshold:
                    continue
                memories.append(
                    {
                        "content": result.get("content", ""),
                        "similarity": similarity,
                        "fact_type": result.get("fact_type", "unknown"),
                        "created_at": result.get("created_at"),
                    }
                )

            logger.info(f"Recalled {len(memories)} memories for query: {query[:50]}...")
            return memories

        except Exception as e:
            logger.error(f"Failed to recall memories: {e}")
            return []

    async def wait_for_augmentation(
        self,
        entity_id: str,
        process_id: str,
        session_id: str,
        timeout: Optional[float] = None,
    ) -> None:
        """Wait for the background augmentation work to finish.

        The synchronous ``wait()`` runs in the default executor. Failures are
        logged, not raised.

        Args:
            entity_id: Entity id.
            process_id: Process id.
            session_id: Session id.
            timeout: Timeout in seconds. ``None`` waits without a limit; any
                other value, including ``0``, is passed to ``wait()``.
        """
        try:
            mem = await self.get_memori(entity_id, process_id, session_id)

            if timeout is None:
                blocking_wait = mem.augmentation.wait
            else:

                def blocking_wait():
                    return mem.augmentation.wait(timeout=timeout)

            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, blocking_wait)

        except Exception as e:
            logger.error(f"Failed to wait for augmentation: {e}")

    def clear_cache(self, entity_id: Optional[str] = None, process_id: Optional[str] = None) -> None:
        """Remove cached Memori instances.

        Keys are matched against the ``"{entity_id}:{process_id}"`` format
        without splitting, so ids that contain ``:`` cannot raise. Such ids
        may match more entries than intended; those are re-created on demand.

        Args:
            entity_id: Entity id; ``None`` or empty matches every entity.
            process_id: Process id; ``None`` or empty matches every process.
        """
        if entity_id is None and process_id is None:
            self._instances.clear()
            logger.info("Cleared all Memori instances from cache")
            return

        prefix = f"{entity_id}:" if entity_id else None
        suffix = f":{process_id}" if process_id else None

        def matches(key: str) -> bool:
            if prefix and suffix:
                return key == f"{entity_id}:{process_id}"
            if prefix:
                return key.startswith(prefix)
            if suffix:
                return key.endswith(suffix)
            return True

        keys_to_remove = [key for key in self._instances if matches(key)]
        for key in keys_to_remove:
            del self._instances[key]

        logger.info(f"Cleared {len(keys_to_remove)} Memori instances from cache")

    async def close(self) -> None:
        """Close the manager: drop cached instances and dispose all engines."""
        logger.info("Closing MemoriManager...")

        self._instances.clear()

        for holder in self._sync_engines.values():
            try:
                self._dispose_engine(holder)
            except Exception as e:
                logger.error(f"Error closing engine: {e}")

        self._sync_engines.clear()
        self._initialized = False

        logger.info("MemoriManager closed")
|
||
|
||
|
||
# Global singleton
|
||
_global_manager: Optional[MemoriManager] = None  # set on first get_memori_manager() call
|
||
|
||
|
||
def get_memori_manager() -> MemoriManager:
    """Return the global MemoriManager singleton, creating it on first use.

    The instance is created with no arguments.

    Returns:
        The shared MemoriManager instance.
    """
    global _global_manager
    manager = _global_manager
    if manager is None:
        manager = MemoriManager()
        _global_manager = manager
    return manager
|
||
|
||
|
||
async def init_global_memori(
    db_pool: Optional[AsyncConnectionPool] = None,
    db_url: Optional[str] = None,
    api_key: Optional[str] = None,
) -> MemoriManager:
    """Configure and initialize the global MemoriManager.

    The three settings are written onto the singleton unconditionally, so
    omitted arguments reset them to ``None``. ``initialize()`` is then
    awaited.

    Args:
        db_pool: PostgreSQL connection pool.
        db_url: Database connection URL.
        api_key: Memori API key.

    Returns:
        The global MemoriManager instance.
    """
    manager = get_memori_manager()
    settings = (
        ("_db_pool", db_pool),
        ("_db_url", db_url),
        ("_api_key", api_key),
    )
    for attribute, value in settings:
        setattr(manager, attribute, value)
    await manager.initialize()
    return manager
|
||
|
||
|
||
async def close_global_memori() -> None:
    """Close the global MemoriManager if one has been created.

    The singleton reference is left in place; only ``close()`` is awaited.
    """
    manager = _global_manager
    if manager is None:
        return
    await manager.close()
|