qwen_agent/agent/mem0_manager.py
2026-01-20 21:30:32 +08:00

442 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Mem0 连接和实例管理器
负责管理 Mem0 客户端实例的创建、缓存和生命周期
"""
import logging
from typing import Any, Dict, List, Optional
import json_repair
from psycopg2 import pool
from .mem0_config import Mem0Config
from utils.settings import MEM0_EMBEDDING_MODEL
logger = logging.getLogger("app")
# Monkey patch: 使用 json_repair 替换 mem0 的 remove_code_blocks
def _remove_code_blocks_with_repair(content: str) -> str:
"""
使用 json_repair 替换 mem0 的 remove_code_blocks 函数
json_repair.loads 会自动处理:
- 移除代码块标记(```json, ``` 等)
- 修复损坏的 JSON如尾随逗号、注释、单引号等
"""
import re
content_stripped = content.strip()
try:
# json_repair.loads 会自动去除代码块并修复 JSON
result = json_repair.loads(content_stripped)
if isinstance(result, (dict, list)):
import json
return json.dumps(result, ensure_ascii=False)
# 如果返回空字符串(非 JSON 输入),回退到原内容
if result == "" and content_stripped != "":
# 尝试简单的代码块去除(降级处理)
pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$"
match = re.match(pattern, content_stripped)
if match:
return match.group(1).strip()
return content_stripped
return str(result)
except Exception:
# 如果解析失败,尝试简单的代码块去除(降级处理)
pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$"
match = re.match(pattern, content_stripped)
if match:
return match.group(1).strip()
return content_stripped
# Apply the monkey patch at import time of this module.
try:
    import sys

    import mem0.memory.utils as mem0_utils

    mem0_utils.remove_code_blocks = _remove_code_blocks_with_repair
    if 'mem0.memory.main' not in sys.modules:
        logger.info("Successfully patched mem0.memory.utils.remove_code_blocks with json_repair")
    else:
        # mem0.memory.main is already loaded, so its own module-level
        # reference to remove_code_blocks has to be rebound as well.
        import mem0.memory.main

        mem0.memory.main.remove_code_blocks = _remove_code_blocks_with_repair
        logger.info("Successfully patched mem0.memory.main.remove_code_blocks with json_repair")
except ImportError:
    # mem0 could not be imported, so there is nothing to patch here.
    pass
except Exception as e:
    logger.warning(f"Failed to patch mem0 remove_code_blocks: {e}")
class Mem0Manager:
    """Connection and instance manager for Mem0.

    Responsibilities:
    1. Create and cache Mem0 instances.
    2. Keep tenants apart by ``user_id`` + ``agent_id``.
    3. Work on a shared synchronous connection pool that is handed in from
       outside and never closed here.
    4. Expose memory recall and storage helpers.
    """

    # Vector size used for both the pgvector collection and the embedder.
    # The original note ties 384 to paraphrase-multilingual-MiniLM-L12-v2;
    # it has to agree with MEM0_EMBEDDING_MODEL -- TODO confirm in settings.
    _EMBEDDING_DIMS = 384

    def __init__(
        self,
        sync_pool: Optional[pool.SimpleConnectionPool] = None,
    ):
        """Create the manager.

        Args:
            sync_pool: Synchronous PostgreSQL connection pool. May be ``None``
                and assigned later (see ``init_global_mem0``).
        """
        self._sync_pool = sync_pool
        # Cached Mem0 instances keyed by "user_id:agent_id[:id(llm_instance)]".
        self._instances: Dict[str, Any] = {}
        self._initialized = False

    async def initialize(self) -> None:
        """Mark the manager as initialized.

        No schema work happens here; the method only logs whether a
        connection pool is present. Calling it again after success is a no-op.
        """
        if self._initialized:
            return
        logger.info("Initializing Mem0Manager...")
        try:
            if self._sync_pool:
                logger.info("Mem0Manager initialized successfully")
            else:
                logger.warning("No database configuration provided for Mem0")
            self._initialized = True
        except Exception as e:
            # Deliberately not re-raised: the system may run without Mem0.
            logger.error(f"Failed to initialize Mem0Manager: {e}")

    def _get_connection_pool(self) -> Optional[pool.SimpleConnectionPool]:
        """Return the synchronous connection pool, or ``None`` if unset."""
        return self._sync_pool

    @staticmethod
    def _extract_results(payload: Any) -> List[Any]:
        """Normalize a Mem0 response to a plain list of items.

        Accepts either a ``{"results": [...]}`` mapping or a bare list; any
        other or empty payload yields an empty list.
        """
        if isinstance(payload, dict):
            payload = payload.get("results")
        return list(payload) if payload else []

    @staticmethod
    def _belongs_to_agent(item: Any, agent_id: str) -> bool:
        """Tell whether a memory item is tagged with ``agent_id``.

        The tag is looked up at the top level of the item first and, for
        compatibility with the previous filter, inside its ``metadata``.
        """
        if not isinstance(item, dict):
            return False
        if item.get("agent_id") == agent_id:
            return True
        metadata = item.get("metadata") or {}
        return isinstance(metadata, dict) and metadata.get("agent_id") == agent_id

    async def get_mem0(
        self,
        user_id: str,
        agent_id: str,
        session_id: str,
        config: Optional[Mem0Config] = None,
    ) -> Any:
        """Return the cached Mem0 instance for a tenant, creating it if needed.

        Args:
            user_id: User ID.
            agent_id: Agent/Bot ID.
            session_id: Session ID; only forwarded to instance creation and
                not part of the cache key.
            config: Mem0 configuration.

        Returns:
            The Mem0 instance.

        Raises:
            RuntimeError: If the ``mem0`` package is not installed.
            ValueError: If no connection pool is available.
        """
        # The LLM instance identity is part of the key so that different
        # LLMs do not share one Mem0 instance.
        llm_suffix = ""
        if config and config.llm_instance is not None:
            llm_suffix = f":{id(config.llm_instance)}"
        cache_key = f"{user_id}:{agent_id}{llm_suffix}"

        cached = self._instances.get(cache_key)
        if cached is not None:
            return cached

        mem0_instance = await self._create_mem0_instance(
            user_id=user_id,
            agent_id=agent_id,
            session_id=session_id,
            config=config,
        )
        self._instances[cache_key] = mem0_instance
        return mem0_instance

    async def _create_mem0_instance(
        self,
        user_id: str,
        agent_id: str,
        session_id: str,
        config: Optional[Mem0Config] = None,
    ) -> Any:
        """Build a new Mem0 ``Memory`` backed by pgvector.

        Args:
            user_id: User ID (used for logging only).
            agent_id: Agent/Bot ID; determines the collection name.
            session_id: Session ID (currently unused here).
            config: Mem0 configuration, optionally carrying an LLM instance.

        Returns:
            The Mem0 ``Memory`` instance.

        Raises:
            RuntimeError: If the ``mem0`` package is not installed.
            ValueError: If no connection pool is available.
        """
        try:
            from mem0 import Memory
        except ImportError as exc:
            logger.error("mem0 package not installed")
            raise RuntimeError("mem0 package is required but not installed") from exc

        connection_pool = self._get_connection_pool()
        if not connection_pool:
            raise ValueError("Database connection pool not available")

        config_dict = {
            "vector_store": {
                "provider": "pgvector",
                "config": {
                    "connection_pool": connection_pool,
                    # One collection per agent; dashes replaced and the name
                    # capped at 50 characters.
                    "collection_name": f"mem0_{agent_id}".replace("-", "_")[:50],
                    "embedding_model_dims": self._EMBEDDING_DIMS,
                },
            },
            "embedder": {
                "provider": "huggingface",
                "config": {
                    "model": MEM0_EMBEDDING_MODEL,
                    "embedding_dims": self._EMBEDDING_DIMS,
                },
            },
        }

        # Plug in the LangChain LLM when one was supplied.
        if config and config.llm_instance is not None:
            config_dict["llm"] = {
                "provider": "langchain",
                "config": {"model": config.llm_instance},
            }
            logger.info(
                f"Configured LangChain LLM for Mem0: {type(config.llm_instance).__name__}"
            )

        mem = Memory.from_config(config_dict)
        logger.info(
            f"Created Mem0 instance: user={user_id}, agent={agent_id}"
        )
        return mem

    async def recall_memories(
        self,
        query: str,
        user_id: str,
        agent_id: str,
        config: Optional[Mem0Config] = None,
    ) -> List[Dict[str, Any]]:
        """Recall memories relevant to ``query`` for a user/agent pair.

        Args:
            query: Query text.
            user_id: User ID.
            agent_id: Agent/Bot ID.
            config: Mem0 configuration; supplies ``semantic_search_top_k``
                (20 is used when no config is given).

        Returns:
            A list of dicts with the keys ``content``, ``similarity``,
            ``metadata`` and ``fact_type``. Any failure is logged and yields
            an empty list.
        """
        try:
            mem = await self.get_mem0(user_id, agent_id, "default", config)
            limit = config.semantic_search_top_k if config else 20
            response = mem.search(
                query=query,
                limit=limit,
                user_id=user_id,
                agent_id=agent_id,
            )

            memories = []
            for item in self._extract_results(response):
                # Results may be plain strings or dicts; normalize strings so
                # a single odd entry does not discard the whole recall.
                if isinstance(item, str):
                    item = {"memory": item}
                elif not isinstance(item, dict):
                    continue
                # "metadata" can be missing or None.
                item_metadata = item.get("metadata") or {}
                memories.append(
                    {
                        "content": item.get("memory", ""),
                        "similarity": item.get("score", 0.0),
                        "metadata": item_metadata,
                        "fact_type": item_metadata.get("category", "fact"),
                    }
                )
            logger.info(f"Recalled {len(memories)} memories for user={user_id}, query: {query[:50]}...")
            return memories
        except Exception as e:
            logger.error(f"Failed to recall memories: {e}")
            return []

    async def add_memory(
        self,
        text: str,
        user_id: str,
        agent_id: str,
        metadata: Optional[Dict[str, Any]] = None,
        config: Optional[Mem0Config] = None,
    ) -> Dict[str, Any]:
        """Store a new memory for a user/agent pair.

        Args:
            text: Memory text.
            user_id: User ID.
            agent_id: Agent/Bot ID.
            metadata: Extra metadata; an empty dict is sent when omitted.
            config: Mem0 configuration (may carry the LLM instance).

        Returns:
            Whatever ``mem.add`` returns, or ``{}`` when the call fails
            (the failure is logged).
        """
        try:
            mem = await self.get_mem0(user_id, agent_id, "default", config)
            result = mem.add(
                text,
                user_id=user_id,
                agent_id=agent_id,
                metadata=metadata or {},
            )
            logger.info(f"Added memory for user={user_id}, agent={agent_id}: {result}")
            return result
        except Exception as e:
            logger.error(f"Failed to add memory: {e}")
            return {}

    async def get_all_memories(
        self,
        user_id: str,
        agent_id: str,
    ) -> List[Dict[str, Any]]:
        """Return every stored memory of a user for one agent.

        Args:
            user_id: User ID.
            agent_id: Agent/Bot ID.

        Returns:
            The memory items tagged with ``agent_id``. Any failure is logged
            and yields an empty list.
        """
        try:
            mem = await self.get_mem0(user_id, agent_id, "default")
            # Scope the query the same way search()/add() are scoped.
            response = mem.get_all(user_id=user_id, agent_id=agent_id)
            # The response is unwrapped first (it may be {"results": [...]}),
            # then filtered defensively in case the backend ignored agent_id.
            return [
                item
                for item in self._extract_results(response)
                if self._belongs_to_agent(item, agent_id)
            ]
        except Exception as e:
            logger.error(f"Failed to get all memories: {e}")
            return []

    def clear_cache(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> None:
        """Drop cached Mem0 instances.

        Args:
            user_id: Only drop entries of this user; ``None`` matches any user.
            agent_id: Only drop entries of this agent; ``None`` matches any
                agent. With both ``None`` the whole cache is cleared.
        """
        if user_id is None and agent_id is None:
            self._instances.clear()
            logger.info("Cleared all Mem0 instances from cache")
            return

        keys_to_remove = []
        for key in self._instances:
            # Key layout: "user_id:agent_id" with an optional ":<llm id>" tail.
            parts = key.split(":")
            if len(parts) < 2:
                continue
            if user_id and parts[0] != user_id:
                continue
            if agent_id and parts[1] != agent_id:
                continue
            keys_to_remove.append(key)

        for key in keys_to_remove:
            del self._instances[key]
        logger.info(f"Cleared {len(keys_to_remove)} Mem0 instances from cache")

    async def close(self) -> None:
        """Release cached instances and reset the initialized flag.

        The connection pool itself is shared and therefore left open.
        """
        logger.info("Closing Mem0Manager...")
        self._instances.clear()
        self._initialized = False
        logger.info("Mem0Manager closed")
# Process-wide singleton, created lazily by get_mem0_manager().
_global_manager: Optional[Mem0Manager] = None


def get_mem0_manager() -> Mem0Manager:
    """Return the global Mem0Manager singleton, creating it on first use.

    Returns:
        The shared Mem0Manager instance.
    """
    global _global_manager
    if _global_manager is not None:
        return _global_manager
    _global_manager = Mem0Manager()
    return _global_manager
async def init_global_mem0(
    sync_pool: pool.SimpleConnectionPool,
) -> Mem0Manager:
    """Attach a connection pool to the global Mem0Manager and initialize it.

    Args:
        sync_pool: Synchronous PostgreSQL connection pool to assign to the
            global manager.

    Returns:
        The global Mem0Manager instance.
    """
    mgr = get_mem0_manager()
    mgr._sync_pool = sync_pool
    await mgr.initialize()
    return mgr
async def close_global_mem0() -> None:
    """Close the global Mem0Manager if one has been created."""
    global _global_manager
    if _global_manager is None:
        return
    await _global_manager.close()