qwen_agent/agent/memori_manager.py
朱潮 4d6ee6ae0c fix: pass db_url to init_global_memori
Add db_url property to MemoriManager that falls back to
CHECKPOINT_DB_URL setting, and pass it explicitly from
fastapi_app.py to ensure Memori can create sync connections.

This fixes the error "Either db_pool or db_url must be provided"
when recalling memories.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-20 08:27:34 +08:00

390 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Memori 连接和实例管理器
负责管理 Memori 客户端实例的创建、缓存和生命周期
"""
import asyncio
import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from psycopg_pool import AsyncConnectionPool
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from .memori_config import MemoriConfig
# Module-level logger. Note: this uses the shared "app" logger name rather than __name__.
logger = logging.getLogger("app")
class MemoriManager:
    """Manager for Memori client instances and their database connections.

    Responsibilities:
    1. Create Memori instances and cache them per (entity_id, process_id).
    2. Multi-tenant isolation via Memori attribution (entity_id + process_id).
    3. Own the synchronous SQLAlchemy engines / session factories Memori uses.
    4. Provide async entry points for memory recall and augmentation waiting.
    """

    def __init__(
        self,
        db_pool: Optional[AsyncConnectionPool] = None,
        db_url: Optional[str] = None,
        api_key: Optional[str] = None,
    ):
        """Store configuration; no database work happens here.

        Args:
            db_pool: PostgreSQL async connection pool (shared with the Checkpointer).
                This class only stores it; Memori is given sync SQLAlchemy
                sessions built from the database URL instead.
            db_url: Database URL for the sync engines. When empty, the ``db_url``
                property falls back to ``CHECKPOINT_DB_URL``.
            api_key: Memori API key (for advanced augmentation features).
        """
        self._db_pool = db_pool
        self._db_url = db_url
        self._api_key = api_key
        # Cached Memori instances keyed by (entity_id, process_id). A tuple key
        # stays unambiguous even when an ID itself contains ":".
        self._instances: Dict[tuple, Any] = {}
        # Sync SQLAlchemy engines keyed by database URL; disposed in close().
        self._sync_engines: Dict[str, Any] = {}
        # sessionmaker factories bound to the engines above, keyed by database URL.
        self._session_factories: Dict[str, Any] = {}
        self._initialized = False

    @property
    def db_url(self) -> Optional[str]:
        """Database URL used for Memori's sync connections.

        Returns the explicitly configured URL if set, otherwise
        ``CHECKPOINT_DB_URL`` from ``utils.settings``.
        """
        if self._db_url:
            return self._db_url
        # Imported lazily: settings are only consulted when no explicit URL was configured.
        from utils.settings import CHECKPOINT_DB_URL
        return CHECKPOINT_DB_URL

    async def initialize(self) -> None:
        """Build the Memori database schema once (no-op after success).

        Errors are logged and swallowed so the application can run without Memori.
        """
        if self._initialized:
            return
        logger.info("Initializing MemoriManager...")
        try:
            # Resolve the URL exactly like _create_memori_instance does (explicit
            # URL, then CHECKPOINT_DB_URL) so the schema is built in the same
            # database that Memori instances will connect to.
            db_url = self.db_url
            if db_url:
                await self._build_schema(db_url)
            else:
                logger.warning("No database URL configured; skipping Memori schema build")
            self._initialized = True
            logger.info("MemoriManager initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize MemoriManager: {e}")
            # Deliberately not re-raised: the system may run without Memori.

    async def _build_schema(self, db_url: str) -> None:
        """Create the Memori tables in ``db_url`` if they do not exist.

        Reuses the cached engine for ``db_url`` so that close() disposes it,
        instead of creating a throwaway engine that is never released.

        Args:
            db_url: Database connection URL.
        """
        try:
            from memori import Memori
        except ImportError:
            logger.warning("memori package not available, skipping schema build")
            return
        try:
            mem = Memori(conn=self._get_session_factory(db_url))
            mem.config.storage.build()
            logger.info("Memori schema built successfully")
        except Exception as e:
            logger.error(f"Failed to build Memori schema: {e}")

    def _get_session_factory(self, db_url: str) -> sessionmaker:
        """Return the cached sessionmaker for ``db_url``, creating engine and factory on first use.

        Args:
            db_url: Database connection URL.

        Returns:
            A SQLAlchemy ``sessionmaker`` bound to the cached engine for ``db_url``.
        """
        factory = self._session_factories.get(db_url)
        if factory is None:
            # pool_pre_ping checks pooled connections for liveness on checkout.
            engine = create_engine(db_url, pool_pre_ping=True)
            # Track the engine itself (not the factory) so close() can dispose it.
            self._sync_engines[db_url] = engine
            factory = sessionmaker(bind=engine)
            self._session_factories[db_url] = factory
        return factory

    def _get_sync_session(self, db_url: str) -> Session:
        """Open a new sync SQLAlchemy Session for ``db_url``.

        Kept for backward compatibility; the caller owns (and must close) the session.

        Args:
            db_url: Database connection URL.

        Returns:
            A new SQLAlchemy Session.
        """
        return self._get_session_factory(db_url)()

    async def get_memori(
        self,
        entity_id: str,
        process_id: str,
        session_id: str,
        config: Optional[MemoriConfig] = None,
    ) -> Any:
        """Return the cached Memori instance for (entity_id, process_id), creating it if needed.

        Args:
            entity_id: Entity ID (typically the user identifier).
            process_id: Process ID (typically the bot ID).
            session_id: Session ID to set on the instance.
            config: Memori configuration; only applied when a new instance is created.

        Returns:
            The Memori instance.
        """
        cache_key = (entity_id, process_id)
        memori_instance = self._instances.get(cache_key)
        if memori_instance is not None:
            # NOTE(review): the cached instance is shared by every session of this
            # entity/process pair; its session_id is switched in place here, which
            # may interleave if two sessions use it concurrently — confirm.
            memori_instance.config.session_id = session_id
            return memori_instance
        memori_instance = await self._create_memori_instance(
            entity_id=entity_id,
            process_id=process_id,
            session_id=session_id,
            config=config,
        )
        self._instances[cache_key] = memori_instance
        return memori_instance

    async def _create_memori_instance(
        self,
        entity_id: str,
        process_id: str,
        session_id: str,
        config: Optional[MemoriConfig] = None,
    ) -> Any:
        """Create and configure a new Memori instance.

        Args:
            entity_id: Entity ID.
            process_id: Process ID.
            session_id: Session ID.
            config: Memori configuration (API key and recall parameters).

        Returns:
            The new Memori instance.

        Raises:
            RuntimeError: If the memori package is not installed.
            ValueError: If no database URL is available.
        """
        try:
            from memori import Memori
        except ImportError as err:
            logger.error("memori package not installed")
            raise RuntimeError("memori package is required but not installed") from err

        db_url = self.db_url
        if not db_url:
            raise ValueError("Database URL not available")

        # Pass the sessionmaker itself (a callable factory), as _build_schema does,
        # rather than one pre-opened Session shared by every call on this instance.
        mem = Memori(conn=self._get_session_factory(db_url))

        # A per-call config key wins; otherwise fall back to the manager-level key.
        # (Previously a config without api_key overwrote the manager key with None.)
        api_key = (config.api_key if config else None) or self._api_key
        if api_key:
            mem.config.api_key = api_key

        mem.attribution(entity_id=entity_id, process_id=process_id)
        mem.config.session_id = session_id

        if config:
            mem.config.recall_facts_limit = config.semantic_search_top_k
            mem.config.recall_relevance_threshold = config.semantic_search_threshold
            mem.config.recall_embeddings_limit = config.semantic_search_embeddings_limit

        logger.info(
            f"Created Memori instance: entity={entity_id}, process={process_id}, session={session_id}"
        )
        return mem

    async def recall_memories(
        self,
        query: str,
        entity_id: str,
        process_id: str,
        session_id: str,
        config: Optional[MemoriConfig] = None,
    ) -> List[Dict[str, Any]]:
        """Recall memories relevant to ``query``.

        Args:
            query: Query text.
            entity_id: Entity ID.
            process_id: Process ID.
            session_id: Session ID.
            config: Memori configuration (top-k and similarity threshold).

        Returns:
            A list of dicts with ``content``, ``similarity``, ``fact_type`` and
            ``created_at`` keys, filtered by the similarity threshold. Returns an
            empty list on any error (which is logged).
        """
        try:
            top_k = config.semantic_search_top_k if config else 5
            threshold = config.semantic_search_threshold if config else 0.7

            mem = await self.get_memori(entity_id, process_id, session_id, config)
            results = mem.recall(query=query, limit=top_k)

            memories = []
            for result in results:
                # Treat a missing or null similarity as 0.0 so one malformed row
                # is filtered out instead of failing the whole recall.
                similarity = result.get("similarity")
                if similarity is None:
                    similarity = 0.0
                if similarity < threshold:
                    continue
                memories.append(
                    {
                        "content": result.get("content", ""),
                        "similarity": similarity,
                        "fact_type": result.get("fact_type", "unknown"),
                        "created_at": result.get("created_at"),
                    }
                )

            logger.info(f"Recalled {len(memories)} memories for query: {query[:50]}...")
            return memories
        except Exception as e:
            logger.error(f"Failed to recall memories: {e}")
            return []

    async def wait_for_augmentation(
        self,
        entity_id: str,
        process_id: str,
        session_id: str,
        timeout: Optional[float] = None,
    ) -> None:
        """Wait for Memori's background augmentation to finish.

        Args:
            entity_id: Entity ID.
            process_id: Process ID.
            session_id: Session ID.
            timeout: Timeout in seconds; ``None`` waits without a limit. ``0`` is
                passed through as a real zero timeout (it used to mean "forever").
        """
        try:
            mem = await self.get_memori(entity_id, process_id, session_id)
            loop = asyncio.get_running_loop()
            # wait() is synchronous; run it in the default executor so the event
            # loop is not blocked.
            if timeout is not None:
                await loop.run_in_executor(None, lambda: mem.augmentation.wait(timeout=timeout))
            else:
                await loop.run_in_executor(None, mem.augmentation.wait)
        except Exception as e:
            logger.error(f"Failed to wait for augmentation: {e}")

    def clear_cache(self, entity_id: Optional[str] = None, process_id: Optional[str] = None) -> None:
        """Drop cached Memori instances.

        Args:
            entity_id: Only drop instances for this entity; ``None`` matches any.
            process_id: Only drop instances for this process; ``None`` matches any.
                When both are ``None``, the whole cache is cleared.
        """
        if entity_id is None and process_id is None:
            self._instances.clear()
            logger.info("Cleared all Memori instances from cache")
            return

        keys_to_remove = [
            (e_id, p_id)
            for e_id, p_id in self._instances
            if (entity_id is None or e_id == entity_id)
            and (process_id is None or p_id == process_id)
        ]
        for key in keys_to_remove:
            del self._instances[key]
        logger.info(f"Cleared {len(keys_to_remove)} Memori instances from cache")

    async def close(self) -> None:
        """Drop cached instances and dispose every sync engine this manager created."""
        logger.info("Closing MemoriManager...")
        self._instances.clear()
        for engine in self._sync_engines.values():
            try:
                engine.dispose()
            except Exception as e:
                logger.error(f"Error closing engine: {e}")
        self._sync_engines.clear()
        self._session_factories.clear()
        self._initialized = False
        logger.info("MemoriManager closed")
# Process-wide MemoriManager singleton, created lazily by get_memori_manager().
_global_manager: Optional[MemoriManager] = None


def get_memori_manager() -> MemoriManager:
    """Return the process-wide MemoriManager, creating an unconfigured one on first call.

    Returns:
        The shared MemoriManager instance.
    """
    global _global_manager
    if _global_manager is not None:
        return _global_manager
    _global_manager = MemoriManager()
    return _global_manager
async def init_global_memori(
    db_pool: Optional[AsyncConnectionPool] = None,
    db_url: Optional[str] = None,
    api_key: Optional[str] = None,
) -> MemoriManager:
    """Configure the global MemoriManager and run its initialization.

    The singleton's pool, URL and API key are overwritten with the given
    values (``None`` included), then ``initialize()`` is awaited. Because
    ``initialize()`` returns early once the manager is initialized, calling this
    again with a new URL does not rebuild the schema.

    Args:
        db_pool: PostgreSQL connection pool.
        db_url: Database connection URL.
        api_key: Memori API key.

    Returns:
        The global MemoriManager instance.
    """
    manager = get_memori_manager()
    manager._api_key = api_key
    manager._db_url = db_url
    manager._db_pool = db_pool
    await manager.initialize()
    return manager
async def close_global_memori() -> None:
    """Close the global MemoriManager if one has been created.

    The module-level reference is not reset, so a later ``get_memori_manager()``
    call returns the same (closed) manager object.
    """
    manager = _global_manager
    if manager is None:
        return
    await manager.close()