- Fix mem0 connection pool exhausted error with proper pooling
- Convert memory operations to async tasks
- Optimize docker-compose configuration
- Add skill upload functionality
- Reduce cache size for better performance
- Update dependencies

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
"""
|
||
Mem0 连接和实例管理器
|
||
负责管理 Mem0 客户端实例的创建、缓存和生命周期
|
||
"""
|
||
|
||
import logging
|
||
import asyncio
|
||
import threading
|
||
import concurrent.futures
|
||
from typing import Any, Dict, List, Optional, Literal
|
||
from collections import OrderedDict
|
||
from embedding.manager import GlobalModelManager, get_model_manager
|
||
import json_repair
|
||
from psycopg2 import pool
|
||
from utils.settings import (
|
||
MEM0_POOL_SIZE
|
||
)
|
||
from .mem0_config import Mem0Config
|
||
|
||
logger = logging.getLogger("app")
|
||
|
||
|
||
# ============================================================================
# Custom embedding class that reuses the project's existing GlobalModelManager
# so the embedding model is not loaded a second time.
# ============================================================================

class CustomMem0Embedding:
    """
    Custom Mem0 embedding class that delegates to the project's GlobalModelManager.

    Mem0 therefore does not need to load the same model again, which saves memory.
    """

    _model = None  # Class-level cache of the model instance
    _lock = threading.Lock()  # Guards lazy model initialization across threads
    _executor = None  # Thread pool executor for async model loading

    def __init__(self, config: Optional[Any] = None):
        """Initialize the custom embedder."""
        # Provide a minimal config object for compatibility with Mem0's telemetry code
        if config is None:
            config = type('Config', (), {'embedding_dims': 384})()
        self.config = config

    @property
    def embedding_dims(self):
        """Return the embedding dimensionality."""
        return 384  # Dimensionality of gte-tiny

    def _get_model_sync(self):
        """Fetch the model synchronously, avoiding asyncio.run()."""
        # First try to reuse a model the manager has already loaded
        manager = get_model_manager()
        model = manager.get_model_sync()

        if model is not None:
            # Cache the model
            CustomMem0Embedding._model = model
            return model

        # Model not loaded yet: run the async initialization in a thread pool
        if CustomMem0Embedding._executor is None:
            CustomMem0Embedding._executor = concurrent.futures.ThreadPoolExecutor(
                max_workers=1,
                thread_name_prefix="mem0_embed"
            )

        # Run the async code on a dedicated thread with its own event loop
        def run_async_in_thread():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                result = loop.run_until_complete(manager.get_model())
                return result
            finally:
                loop.close()

        future = CustomMem0Embedding._executor.submit(run_async_in_thread)
        model = future.result(timeout=30)  # 30-second timeout

        # Cache the model
        CustomMem0Embedding._model = model
        return model

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Return the embedding vector for a text (synchronous, called by Mem0).

        Args:
            text: text to embed (a string or a list of strings)
            memory_action: memory operation type (add/search/update), currently unused

        Returns:
            list: embedding vector
        """
        # Thread-safe lazy model loading (double-checked locking)
        if CustomMem0Embedding._model is None:
            with CustomMem0Embedding._lock:
                if CustomMem0Embedding._model is None:
                    self._get_model_sync()

        model = CustomMem0Embedding._model
        embeddings = model.encode(text, convert_to_numpy=True)
        return embeddings.tolist()


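# Usage sketch (illustrative only): Mem0 only calls `embed()` and reads
# `embedding_dims`, so the embedder can be exercised on its own, assuming
# GlobalModelManager can provide a loaded sentence-transformer model:
#
#     embedder = CustomMem0Embedding()
#     vector = embedder.embed("user prefers dark mode", memory_action="add")
#     assert len(vector) == embedder.embedding_dims  # 384
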
# Monkey patch: replace mem0's remove_code_blocks with a json_repair-based version
def _remove_code_blocks_with_repair(content: str) -> str:
    """
    Replacement for mem0's remove_code_blocks that uses json_repair.

    json_repair.loads automatically handles:
    - stripping code-block markers (```json, ```, etc.)
    - repairing broken JSON (trailing commas, comments, single quotes, ...)
    """
    import re

    content_stripped = content.strip()

    try:
        # json_repair.loads strips code fences and repairs the JSON
        result = json_repair.loads(content_stripped)
        if isinstance(result, (dict, list)):
            import json
            return json.dumps(result, ensure_ascii=False)
        # An empty string for non-empty input means the content was not JSON;
        # fall back to the original content.
        if result == "" and content_stripped != "":
            # Degraded path: just strip a surrounding code fence if present
            pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$"
            match = re.match(pattern, content_stripped)
            if match:
                return match.group(1).strip()
            return content_stripped
        return str(result)
    except Exception:
        # Parsing failed: degraded path, strip a surrounding code fence if present
        pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$"
        match = re.match(pattern, content_stripped)
        if match:
            return match.group(1).strip()
        return content_stripped


# Apply the monkey patch (before or after mem0 has been imported)
try:
    import sys
    import mem0.memory.utils as mem0_utils
    mem0_utils.remove_code_blocks = _remove_code_blocks_with_repair

    # If mem0.memory.main is already imported, patch its local reference as well
    if 'mem0.memory.main' in sys.modules:
        import mem0.memory.main
        mem0.memory.main.remove_code_blocks = _remove_code_blocks_with_repair
        logger.info("Successfully patched mem0.memory.main.remove_code_blocks with json_repair")
    else:
        logger.info("Successfully patched mem0.memory.utils.remove_code_blocks with json_repair")
except ImportError:
    # mem0 has not been imported yet; the patch will take effect on first import
    pass
except Exception as e:
    logger.warning(f"Failed to patch mem0 remove_code_blocks: {e}")


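# Illustrative example of what the patched function handles (assumed behaviour of
# json_repair on fenced / slightly broken LLM output, not an executed test):
#
#     _remove_code_blocks_with_repair('```json\n{"facts": ["likes tea",]}\n```')
#     # -> '{"facts": ["likes tea"]}'  (fence stripped, trailing comma repaired)
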
class Mem0Manager:
    """
    Mem0 connection and instance manager.

    Responsibilities:
    1. Create and cache Mem0 instances
    2. Multi-tenant isolation (user_id + agent_id)
    3. Use the shared synchronous connection pool provided by DBPoolManager
    4. Expose memory recall and storage interfaces
    """

    def __init__(
        self,
        sync_pool: Optional[pool.SimpleConnectionPool] = None,
    ):
        """Initialize the Mem0Manager.

        Args:
            sync_pool: synchronous PostgreSQL connection pool (shared with DBPoolManager)
        """
        self._sync_pool = sync_pool

        # OrderedDict used as an LRU cache of Mem0 instances
        self._instances: OrderedDict[str, Any] = OrderedDict()
        self._max_instances = MEM0_POOL_SIZE // 2  # Maximum number of cached instances
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize the Mem0Manager.

        Creates the database schema if it does not exist.
        """
        if self._initialized:
            return

        logger.info("Initializing Mem0Manager...")

        try:
            # Mem0 creates its tables automatically; just verify we have a pool
            if self._sync_pool:
                logger.info("Mem0Manager initialized successfully")
            else:
                logger.warning("No database configuration provided for Mem0")

            self._initialized = True
        except Exception as e:
            logger.error(f"Failed to initialize Mem0Manager: {e}")
            # Do not raise; the system is allowed to run without Mem0

    def _get_connection_pool(self) -> Optional[pool.SimpleConnectionPool]:
        """Return the synchronous database connection pool required by Mem0.

        Returns:
            psycopg2 connection pool
        """
        return self._sync_pool

    async def get_mem0(
        self,
        user_id: str,
        agent_id: str,
        session_id: str,
        config: Optional[Mem0Config] = None,
    ) -> Any:
        """Get or create a Mem0 instance.

        Args:
            user_id: user ID (maps to entity_id)
            agent_id: agent/bot ID (maps to process_id)
            session_id: session ID
            config: Mem0 configuration

        Returns:
            Mem0 instance
        """
        # The cache key includes the LLM instance id so different LLMs get different instances
        llm_suffix = ""
        if config and config.llm_instance is not None:
            llm_suffix = f":{id(config.llm_instance)}"
        cache_key = f"{user_id}:{agent_id}{llm_suffix}"

        # Cache hit: move the entry to the end to mark it as most recently used
        if cache_key in self._instances:
            self._instances.move_to_end(cache_key)
            return self._instances[cache_key]

        # Cache full: evict the least recently used entry
        if len(self._instances) >= self._max_instances:
            removed_key, _ = self._instances.popitem(last=False)
            logger.debug(f"Mem0 instance cache full, removed oldest entry: {removed_key}")

        # Create a new instance
        mem0_instance = await self._create_mem0_instance(
            user_id=user_id,
            agent_id=agent_id,
            session_id=session_id,
            config=config,
        )

        # Cache the new instance (new entries are appended at the end)
        self._instances[cache_key] = mem0_instance
        return mem0_instance

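    # Example cache keys (illustrative): "user_42:bot_7" when no LLM override is
    # supplied, or "user_42:bot_7:140221234567" when config.llm_instance is set,
    # where the numeric suffix is id(config.llm_instance).
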
    async def _create_mem0_instance(
        self,
        user_id: str,
        agent_id: str,
        session_id: str,
        config: Optional[Mem0Config] = None,
    ) -> Any:
        """Create a new Mem0 instance.

        Args:
            user_id: user ID
            agent_id: agent/bot ID
            session_id: session ID
            config: Mem0 configuration (carries the LLM instance)

        Returns:
            Mem0 Memory instance
        """
        try:
            from mem0 import Memory
        except ImportError:
            logger.error("mem0 package not installed")
            raise RuntimeError("mem0 package is required but not installed")

        # Get the synchronous connection pool
        connection_pool = self._get_connection_pool()
        if not connection_pool:
            raise ValueError("Database connection pool not available")

        # Create the custom embedder (reuses the shared model, avoids duplicate loading)
        custom_embedder = CustomMem0Embedding()

        # Configure Mem0 to use pgvector.
        # Note: huggingface_base_url is set to a dummy value so HuggingFaceEmbedding
        # does not load a SentenceTransformer; the embedder is replaced below anyway.

        config_dict = {
            "custom_fact_extraction_prompt": (
                config.get_custom_fact_extraction_prompt() if config else None
            ),
            "vector_store": {
                "provider": "pgvector",
                "config": {
                    "connection_pool": connection_pool,
                    "collection_name": f"mem0_{agent_id}".replace("-", "_")[:50],  # isolate by agent_id
                    "embedding_model_dims": 384,  # dimensionality of paraphrase-multilingual-MiniLM-L12-v2
                }
            },
            # Dummy huggingface_base_url to skip model loading (replaced by the custom embedder below)
            "embedder": {
                "provider": "huggingface",
                "config": {
                    "huggingface_base_url": "http://dummy-url-that-will-be-replaced",
                    "api_key": "dummy-key"  # placeholder so the OpenAI client does not fail validation
                }
            }
        }

        # Add the LangChain LLM configuration if one was provided
        if config and config.llm_instance is not None:
            config_dict["llm"] = {
                "provider": "langchain",
                "config": {"model": config.llm_instance}
            }
            logger.info(
                f"Configured LangChain LLM for Mem0: {type(config.llm_instance).__name__}"
            )
        else:
            # No LLM provided: fall back to the default openai configuration.
            # Mem0 uses this LLM to extract memory facts.
            from utils.settings import MASTERKEY, BACKEND_HOST
            import os
            llm_api_key = os.environ.get("OPENAI_API_KEY", "") or MASTERKEY
            config_dict["llm"] = {
                "provider": "openai",
                "config": {
                    "model": "gpt-4o-mini",
                    "api_key": llm_api_key,
                    "openai_base_url": BACKEND_HOST  # use the custom backend
                }
            }

        # Create the Mem0 instance
        mem = Memory.from_config(config_dict)
        logger.debug(f"Original embedder type: {type(mem.embedding_model).__name__}")
        logger.debug(f"Original embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}")

        # Replace the embedder with the custom one so the already-loaded model is reused
        # and Mem0 does not load the model a second time.
        mem.embedding_model = custom_embedder
        logger.debug(f"Replaced embedder type: {type(mem.embedding_model).__name__}")
        logger.debug(f"Replaced embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}")
        logger.info("Replaced Mem0 embedder with CustomMem0Embedding (reusing existing model)")

        logger.info(
            f"Created Mem0 instance: user={user_id}, agent={agent_id}"
        )

        return mem

    async def recall_memories(
        self,
        query: str,
        user_id: str,
        agent_id: str,
        config: Optional[Mem0Config] = None,
    ) -> List[Dict[str, Any]]:
        """Recall relevant memories (user-level, shared across sessions).

        Args:
            query: query text
            user_id: user ID
            agent_id: agent/bot ID
            config: Mem0 configuration

        Returns:
            List of memories, each containing content, similarity, and related fields
        """
        try:
            mem = await self.get_mem0(user_id, agent_id, "default", config)

            # Semantic search, filtered by the agent_id parameter
            limit = config.semantic_search_top_k if config else 20
            results = mem.search(
                query=query,
                limit=limit,
                user_id=user_id,
                agent_id=agent_id,
            )

            # Convert to a uniform format
            memories = []
            for result in results["results"]:
                # Individual Mem0 results may be strings or dicts
                content = result.get("memory", "")
                score = result.get("score", 0.0)
                result_metadata = result.get("metadata") or {}

                memory = {
                    "content": content,
                    "similarity": score,
                    "metadata": result_metadata,
                    "fact_type": result_metadata.get("category", "fact"),
                }
                memories.append(memory)

            logger.info(f"Recalled {len(memories)} memories for user={user_id}, query: {query[:50]}...")
            return memories

        except Exception as e:
            logger.error(f"Failed to recall memories: {e}")
            return []

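    # Shape of each recalled memory returned above (illustrative values):
    #     {"content": "prefers concise answers", "similarity": 0.82,
    #      "metadata": {...}, "fact_type": "fact"}
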
    async def add_memory(
        self,
        text: str,
        user_id: str,
        agent_id: str,
        metadata: Optional[Dict[str, Any]] = None,
        config: Optional[Mem0Config] = None,
    ) -> Dict[str, Any]:
        """Add a new memory (user-level, shared across sessions).

        Args:
            text: memory text
            user_id: user ID
            agent_id: agent/bot ID
            metadata: additional metadata
            config: Mem0 configuration (carries the LLM instance used for fact extraction)

        Returns:
            Result of the add operation
        """
        try:
            mem = await self.get_mem0(user_id, agent_id, "default", config)

            # Add the memory, scoped by the agent_id parameter
            result = mem.add(
                text,
                user_id=user_id,
                agent_id=agent_id,
                metadata=metadata or {}
            )

            logger.info(f"Added memory for user={user_id}, agent={agent_id}: {result}")
            return result

        except Exception as e:
            logger.error(f"Failed to add memory: {e}")
            return {}

    async def get_all_memories(
        self,
        user_id: str,
        agent_id: str,
    ) -> List[Dict[str, Any]]:
        """Get all memories for a user (user-level).

        Args:
            user_id: user ID
            agent_id: agent/bot ID

        Returns:
            List of memories
        """
        try:
            mem = await self.get_mem0(user_id, agent_id, "default")

            # Fetch all memories for the user
            memories = mem.get_all(user_id=user_id)

            # Filter by agent_id
            filtered_memories = [
                m for m in memories
                if m.get("metadata", {}).get("agent_id") == agent_id
            ]

            return filtered_memories

        except Exception as e:
            logger.error(f"Failed to get all memories: {e}")
            return []

    def clear_cache(self, user_id: Optional[str] = None, agent_id: Optional[str] = None) -> None:
        """Remove cached Mem0 instances.

        Args:
            user_id: user ID (None clears for all users)
            agent_id: agent ID (None clears for all agents)
        """
        if user_id is None and agent_id is None:
            self._instances.clear()
            logger.info("Cleared all Mem0 instances from cache")
        else:
            keys_to_remove = []
            for key in self._instances:
                # Key format: "user_id:agent_id:llm_instance_id" or "user_id:agent_id"
                parts = key.split(":")
                if len(parts) >= 2:
                    u_id = parts[0]
                    a_id = parts[1]
                    if user_id and u_id != user_id:
                        continue
                    if agent_id and a_id != agent_id:
                        continue
                    keys_to_remove.append(key)

            for key in keys_to_remove:
                del self._instances[key]

            logger.info(f"Cleared {len(keys_to_remove)} Mem0 instances from cache")

    async def close(self) -> None:
        """Close the manager and release resources."""
        logger.info("Closing Mem0Manager...")

        # Drop cached instances
        self._instances.clear()

        # Note: the shared synchronous connection pool is NOT closed here
        # (it is owned by DBPoolManager)

        self._initialized = False

        logger.info("Mem0Manager closed")


# Global singleton
_global_manager: Optional[Mem0Manager] = None


def get_mem0_manager() -> Mem0Manager:
    """Return the global Mem0Manager singleton.

    Returns:
        Mem0Manager instance
    """
    global _global_manager
    if _global_manager is None:
        _global_manager = Mem0Manager()
    return _global_manager


async def init_global_mem0(
    sync_pool: pool.SimpleConnectionPool,
) -> Mem0Manager:
    """Initialize the global Mem0Manager.

    Args:
        sync_pool: synchronous PostgreSQL connection pool (from DBPoolManager.sync_pool)

    Returns:
        Mem0Manager instance
    """
    manager = get_mem0_manager()
    manager._sync_pool = sync_pool
    await manager.initialize()
    return manager


async def close_global_mem0() -> None:
    """Close the global Mem0Manager."""
    global _global_manager
    if _global_manager is not None:
        await _global_manager.close()
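

# ----------------------------------------------------------------------------
# Minimal wiring sketch (illustrative only). In the real application the sync
# pool comes from DBPoolManager; the DSN below is a hypothetical placeholder,
# and running this requires a reachable Postgres with pgvector plus a usable LLM.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    async def _demo() -> None:
        # Hypothetical local connection pool standing in for DBPoolManager.sync_pool
        sync_pool = pool.SimpleConnectionPool(
            1, MEM0_POOL_SIZE,
            dsn="postgresql://postgres:postgres@localhost:5432/postgres",
        )
        manager = await init_global_mem0(sync_pool)
        await manager.add_memory("User prefers concise answers", user_id="u1", agent_id="bot1")
        print(await manager.recall_memories("answer style", user_id="u1", agent_id="bot1"))
        await close_global_mem0()

    asyncio.run(_demo())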