442 lines
13 KiB
Python
442 lines
13 KiB
Python
"""
|
||
Mem0 连接和实例管理器
|
||
负责管理 Mem0 客户端实例的创建、缓存和生命周期
|
||
"""
|
||
|
||
import logging
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
import json_repair
|
||
from psycopg2 import pool
|
||
|
||
from .mem0_config import Mem0Config
|
||
from utils.settings import MEM0_EMBEDDING_MODEL
|
||
|
||
# Module logger: uses the fixed "app" logger name (not __name__).
logger = logging.getLogger(name="app")
|
||
|
||
|
||
# Monkey patch: 使用 json_repair 替换 mem0 的 remove_code_blocks
|
||
def _remove_code_blocks_with_repair(content: str) -> str:
|
||
"""
|
||
使用 json_repair 替换 mem0 的 remove_code_blocks 函数
|
||
|
||
json_repair.loads 会自动处理:
|
||
- 移除代码块标记(```json, ``` 等)
|
||
- 修复损坏的 JSON(如尾随逗号、注释、单引号等)
|
||
"""
|
||
import re
|
||
|
||
content_stripped = content.strip()
|
||
|
||
try:
|
||
# json_repair.loads 会自动去除代码块并修复 JSON
|
||
result = json_repair.loads(content_stripped)
|
||
if isinstance(result, (dict, list)):
|
||
import json
|
||
return json.dumps(result, ensure_ascii=False)
|
||
# 如果返回空字符串(非 JSON 输入),回退到原内容
|
||
if result == "" and content_stripped != "":
|
||
# 尝试简单的代码块去除(降级处理)
|
||
pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$"
|
||
match = re.match(pattern, content_stripped)
|
||
if match:
|
||
return match.group(1).strip()
|
||
return content_stripped
|
||
return str(result)
|
||
except Exception:
|
||
# 如果解析失败,尝试简单的代码块去除(降级处理)
|
||
pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$"
|
||
match = re.match(pattern, content_stripped)
|
||
if match:
|
||
return match.group(1).strip()
|
||
return content_stripped
|
||
|
||
|
||
# Apply the monkey patch when this module is imported; works whether or not
# mem0 has already been imported elsewhere.


def _patch_mem0_remove_code_blocks() -> None:
    """Install ``_remove_code_blocks_with_repair`` into mem0.

    mem0 modules (e.g. ``mem0.memory.main``) hold their own reference to
    ``remove_code_blocks`` obtained via ``from ... import``. Replacing the
    attribute on ``mem0.memory.utils`` therefore only affects modules imported
    afterwards. Every already-loaded ``mem0`` module that still references
    the original function object is patched too.
    """
    import sys

    try:
        import mem0.memory.utils as mem0_utils
    except ImportError:
        # mem0 is not importable, so there is nothing to patch. This is NOT
        # retried later; a missing mem0 only surfaces when a Memory instance
        # is created.
        return

    original = mem0_utils.remove_code_blocks
    patched = []
    # Snapshot sys.modules: iterating the live dict while imports happen can
    # raise "dictionary changed size during iteration".
    for module_name, module in list(sys.modules.items()):
        if module_name != "mem0" and not module_name.startswith("mem0."):
            continue
        if getattr(module, "remove_code_blocks", None) is original:
            module.remove_code_blocks = _remove_code_blocks_with_repair
            patched.append(module_name)

    logger.info(
        f"Successfully patched remove_code_blocks with json_repair in: {', '.join(sorted(patched))}"
    )


try:
    _patch_mem0_remove_code_blocks()
except Exception as e:
    # Best effort: keep running with mem0's original implementation.
    logger.warning(f"Failed to patch mem0 remove_code_blocks: {e}")
|
||
|
||
|
||
class Mem0Manager:
    """Mem0 connection and instance manager.

    Responsibilities:

    1. Create and cache Mem0 ``Memory`` instances.
    2. Multi-tenant isolation (user_id + agent_id).
    3. Reuse the shared synchronous connection pool provided by DBPoolManager.
    4. Provide memory recall and storage helpers.

    NOTE(review): the mem0 calls made here (``Memory.from_config``, ``search``,
    ``add``, ``get_all``) are synchronous and block the event loop while they
    run. They are intentionally not offloaded to worker threads, because
    psycopg2's ``SimpleConnectionPool`` is documented as not shareable across
    threads.
    """

    # Embedding dimension of paraphrase-multilingual-MiniLM-L12-v2; keep in
    # sync with MEM0_EMBEDDING_MODEL (used for both embedder and vector store).
    _EMBEDDING_DIMS = 384

    def __init__(
        self,
        sync_pool: Optional[pool.SimpleConnectionPool] = None,
    ):
        """Create a manager.

        Args:
            sync_pool: PostgreSQL synchronous connection pool (shared with DBPoolManager).
        """
        self._sync_pool = sync_pool

        # Cached Mem0 instances keyed by (user_id, agent_id, llm_id), where
        # llm_id is id(config.llm_instance) or None. A tuple key (instead of a
        # ":"-joined string) stays unambiguous for ids that contain ":".
        self._instances: Dict[tuple, Any] = {}
        self._initialized = False

    async def initialize(self) -> None:
        """Mark the manager as initialized; repeated calls are no-ops.

        No database work happens here. This only logs whether a connection
        pool has been configured. Table creation is expected to be done by
        mem0 when an instance is created (NOTE(review): confirm for the
        installed mem0 version).
        """
        if self._initialized:
            return

        logger.info("Initializing Mem0Manager...")
        if self._sync_pool:
            logger.info("Mem0Manager initialized successfully")
        else:
            # Not fatal: the system may run without Mem0. Instance creation
            # raises ValueError until a pool is provided.
            logger.warning("No database configuration provided for Mem0")

        self._initialized = True

    def _get_connection_pool(self) -> Optional[pool.SimpleConnectionPool]:
        """Return the synchronous psycopg2 connection pool used by Mem0 (may be None)."""
        return self._sync_pool

    @staticmethod
    def _cache_key(user_id: str, agent_id: str, config: Optional[Mem0Config]) -> tuple:
        """Build the instance-cache key for a tenant and an optional LLM instance.

        The LLM object is baked into a Memory instance's config at creation
        time, so each distinct LLM object needs its own Memory instance. The
        cached instance keeps its LLM alive, so the id() cannot be recycled
        while the entry exists.
        """
        llm_id = None
        if config and config.llm_instance is not None:
            llm_id = id(config.llm_instance)
        return (user_id, agent_id, llm_id)

    @staticmethod
    def _extract_results(raw: Any) -> List[Any]:
        """Normalize a mem0 search/get_all response to a plain list.

        Accepts both the ``{"results": [...]}`` dict shape and a bare list.
        """
        if isinstance(raw, dict):
            return list(raw.get("results") or [])
        return list(raw or [])

    @staticmethod
    def _to_memory_dict(item: Any) -> Dict[str, Any]:
        """Convert one mem0 search hit into this module's memory dict format."""
        if isinstance(item, str):
            # mem0 may return bare strings; treat them as unscored facts.
            return {"content": item, "similarity": 0.0, "metadata": {}, "fact_type": "fact"}

        # "metadata" can be present but None, so `or {}` rather than a .get default.
        metadata = item.get("metadata") or {}
        return {
            "content": item.get("memory", ""),
            "similarity": item.get("score", 0.0),
            "metadata": metadata,
            "fact_type": metadata.get("category", "fact"),
        }

    async def get_mem0(
        self,
        user_id: str,
        agent_id: str,
        session_id: str,
        config: Optional[Mem0Config] = None,
    ) -> Any:
        """Return a cached Mem0 instance, creating it on first use.

        Args:
            user_id: User ID (corresponds to entity_id).
            agent_id: Agent/Bot ID (corresponds to process_id).
            session_id: Session ID (passed through; not part of the cache key).
            config: Mem0 configuration.

        Returns:
            A mem0 ``Memory`` instance.

        NOTE(review): entries are only evicted by clear_cache()/close(). If
        callers build a new LLM object per request, every request creates and
        caches a new Memory instance.
        """
        cache_key = self._cache_key(user_id, agent_id, config)
        if cache_key in self._instances:
            return self._instances[cache_key]

        mem0_instance = await self._create_mem0_instance(
            user_id=user_id,
            agent_id=agent_id,
            session_id=session_id,
            config=config,
        )
        self._instances[cache_key] = mem0_instance
        return mem0_instance

    async def _create_mem0_instance(
        self,
        user_id: str,
        agent_id: str,
        session_id: str,
        config: Optional[Mem0Config] = None,
    ) -> Any:
        """Create a new pgvector-backed Mem0 ``Memory`` instance.

        Args:
            user_id: User ID (only logged here).
            agent_id: Agent/Bot ID; selects the pgvector collection.
            session_id: Session ID (currently unused).
            config: Mem0 configuration, optionally carrying a LangChain LLM instance.

        Returns:
            A mem0 ``Memory`` instance.

        Raises:
            RuntimeError: If the mem0 package is not installed.
            ValueError: If no database connection pool is available.
        """
        import re

        try:
            from mem0 import Memory
        except ImportError:
            logger.error("mem0 package not installed")
            raise RuntimeError("mem0 package is required but not installed")

        connection_pool = self._get_connection_pool()
        if not connection_pool:
            raise ValueError("Database connection pool not available")

        # One collection per agent. Presumably mem0's pgvector store uses this
        # as a table name, so every character outside [A-Za-z0-9_] becomes "_"
        # (ids made of letters, digits, "-" and "_" map exactly as before).
        collection_name = re.sub(r"[^0-9A-Za-z_]", "_", f"mem0_{agent_id}")[:50]

        config_dict: Dict[str, Any] = {
            "vector_store": {
                "provider": "pgvector",
                "config": {
                    "connection_pool": connection_pool,
                    "collection_name": collection_name,
                    "embedding_model_dims": self._EMBEDDING_DIMS,
                },
            },
            "embedder": {
                "provider": "huggingface",
                "config": {
                    "model": MEM0_EMBEDDING_MODEL,
                    "embedding_dims": self._EMBEDDING_DIMS,
                },
            },
        }

        # Use the caller's LangChain LLM for memory extraction when provided.
        if config and config.llm_instance is not None:
            config_dict["llm"] = {
                "provider": "langchain",
                "config": {"model": config.llm_instance},
            }
            logger.info(
                f"Configured LangChain LLM for Mem0: {type(config.llm_instance).__name__}"
            )

        mem = Memory.from_config(config_dict)
        logger.info(f"Created Mem0 instance: user={user_id}, agent={agent_id}")
        return mem

    async def recall_memories(
        self,
        query: str,
        user_id: str,
        agent_id: str,
        config: Optional[Mem0Config] = None,
    ) -> List[Dict[str, Any]]:
        """Recall memories relevant to *query* (user level, shared across sessions).

        Args:
            query: Query text.
            user_id: User ID.
            agent_id: Agent/Bot ID (used as a search filter).
            config: Mem0 configuration; ``semantic_search_top_k`` sets the limit (default 20).

        Returns:
            List of dicts with ``content``, ``similarity``, ``metadata`` and
            ``fact_type`` keys. Empty on any failure (errors are logged).
        """
        try:
            mem = await self.get_mem0(user_id, agent_id, "default", config)

            limit = config.semantic_search_top_k if config else 20
            raw = mem.search(
                query=query,
                limit=limit,
                user_id=user_id,
                agent_id=agent_id,
            )
            memories = [self._to_memory_dict(item) for item in self._extract_results(raw)]

            logger.info(f"Recalled {len(memories)} memories for user={user_id}, query: {query[:50]}...")
            return memories

        except Exception as e:
            logger.error(f"Failed to recall memories: {e}", exc_info=True)
            return []

    async def add_memory(
        self,
        text: str,
        user_id: str,
        agent_id: str,
        metadata: Optional[Dict[str, Any]] = None,
        config: Optional[Mem0Config] = None,
    ) -> Dict[str, Any]:
        """Add a new memory (user level, shared across sessions).

        Args:
            text: Memory text.
            user_id: User ID.
            agent_id: Agent/Bot ID.
            metadata: Extra metadata stored with the memory.
            config: Mem0 configuration (carries the LLM used for memory extraction).

        Returns:
            mem0's ``add`` result, or ``{}`` on failure (errors are logged).
        """
        try:
            mem = await self.get_mem0(user_id, agent_id, "default", config)

            result = mem.add(
                text,
                user_id=user_id,
                agent_id=agent_id,
                metadata=metadata or {},
            )

            logger.info(f"Added memory for user={user_id}, agent={agent_id}: {result}")
            return result

        except Exception as e:
            logger.error(f"Failed to add memory: {e}", exc_info=True)
            return {}

    async def get_all_memories(
        self,
        user_id: str,
        agent_id: str,
        config: Optional[Mem0Config] = None,
    ) -> List[Dict[str, Any]]:
        """Return all memories stored for *user_id* under *agent_id*.

        Args:
            user_id: User ID.
            agent_id: Agent/Bot ID; passed to mem0 as a filter.
            config: Optional Mem0 configuration. Passing the same config used
                for recall/add reuses that cached instance.

        Returns:
            List of mem0 memory dicts, or ``[]`` on failure (errors are logged).

        NOTE(review): mem0's ``get_all`` may apply its own default result
        limit; confirm for the installed version.
        """
        try:
            mem = await self.get_mem0(user_id, agent_id, "default", config)
            # get_all returns the same {"results": [...]} envelope as search,
            # so it must be unwrapped rather than iterated directly.
            raw = mem.get_all(user_id=user_id, agent_id=agent_id)
            return self._extract_results(raw)

        except Exception as e:
            logger.error(f"Failed to get all memories: {e}", exc_info=True)
            return []

    def clear_cache(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> None:
        """Drop cached Mem0 instances.

        Args:
            user_id: Only drop instances for this user (falsy = any user).
            agent_id: Only drop instances for this agent (falsy = any agent).

        When both are None, the whole cache is cleared.
        """
        if user_id is None and agent_id is None:
            self._instances.clear()
            logger.info("Cleared all Mem0 instances from cache")
            return

        stale_keys = [
            key
            for key in self._instances
            if (not user_id or key[0] == user_id) and (not agent_id or key[1] == agent_id)
        ]
        for key in stale_keys:
            del self._instances[key]

        logger.info(f"Cleared {len(stale_keys)} Mem0 instances from cache")

    async def close(self) -> None:
        """Release cached instances and reset the initialized flag.

        The shared synchronous pool is NOT closed here; DBPoolManager owns it.
        """
        logger.info("Closing Mem0Manager...")
        self._instances.clear()
        self._initialized = False
        logger.info("Mem0Manager closed")
|
||
|
||
|
||
# Process-wide singleton, created lazily by get_mem0_manager().
_global_manager: Optional[Mem0Manager] = None


def get_mem0_manager() -> Mem0Manager:
    """Return the process-wide Mem0Manager, creating it on first use.

    Returns:
        The shared Mem0Manager instance (without a pool until init_global_mem0 runs).
    """
    global _global_manager
    if _global_manager is not None:
        return _global_manager
    _global_manager = Mem0Manager()
    return _global_manager
|
||
|
||
|
||
async def init_global_mem0(
    sync_pool: pool.SimpleConnectionPool,
) -> Mem0Manager:
    """Attach the shared sync pool to the global Mem0Manager and initialize it.

    Args:
        sync_pool: PostgreSQL synchronous connection pool (from DBPoolManager.sync_pool).

    Returns:
        The global Mem0Manager instance.
    """
    manager = get_mem0_manager()

    # Cached Memory instances captured the previous pool in their vector-store
    # config; drop them so new instances are built against the new pool.
    previous_pool = manager._sync_pool
    if previous_pool is not None and previous_pool is not sync_pool:
        manager.clear_cache()

    manager._sync_pool = sync_pool
    await manager.initialize()
    return manager
|
||
|
||
|
||
async def close_global_mem0() -> None:
    """Close the global Mem0Manager if one has been created; otherwise do nothing."""
    manager = _global_manager
    if manager is None:
        return
    await manager.close()
|