🐛 fix(mem0): fix connection pool exhausted error

修复 Mem0 连接池耗尽错误,问题根因是 CustomMem0Embedding.embed()
方法中使用 asyncio.run() 导致事件循环泄漏。

主要修改:
- 使用线程池替代 asyncio.run() 避免事件循环泄漏
- 添加线程安全的模型缓存机制
- 为 Mem0 实例缓存添加 LRU 机制,最多保留 50 个实例
- 在 GlobalModelManager 中添加 get_model_sync() 同步方法

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
朱潮 2026-01-21 18:17:45 +08:00
parent f29fd1fb54
commit 45f3a61a16
2 changed files with 69 additions and 7 deletions

View File

@ -5,7 +5,10 @@ Mem0 连接和实例管理器
import logging import logging
import asyncio import asyncio
import threading
import concurrent.futures
from typing import Any, Dict, List, Optional, Literal from typing import Any, Dict, List, Optional, Literal
from collections import OrderedDict
from embedding.manager import GlobalModelManager, get_model_manager from embedding.manager import GlobalModelManager, get_model_manager
import json_repair import json_repair
from psycopg2 import pool from psycopg2 import pool
@ -27,7 +30,9 @@ class CustomMem0Embedding:
这样 Mem0 就不需要再次加载同一个模型节省内存 这样 Mem0 就不需要再次加载同一个模型节省内存
""" """
_model_manager = None # 缓存 GlobalModelManager 实例 _model = None # 类变量,缓存模型实例
_lock = threading.Lock() # 线程安全锁
_executor = None # 线程池执行器
def __init__(self, config: Optional[Any] = None): def __init__(self, config: Optional[Any] = None):
"""初始化自定义 Embedding""" """初始化自定义 Embedding"""
@ -41,6 +46,41 @@ class CustomMem0Embedding:
"""获取 embedding 维度""" """获取 embedding 维度"""
return 384 # gte-tiny 的维度 return 384 # gte-tiny 的维度
def _get_model_sync(self):
    """Synchronously obtain the shared embedding model without asyncio.run().

    Precondition: the caller MUST hold ``CustomMem0Embedding._lock`` (as
    ``embed`` does).  ``threading.Lock`` is not reentrant, so this method
    deliberately does not acquire the lock itself; the lazy executor
    creation and the class-level model-cache writes below rely on the
    caller's lock for thread safety.

    Returns:
        The loaded SentenceTransformer model (also cached on the class
        attribute ``_model``).

    Raises:
        concurrent.futures.TimeoutError: if async model loading does not
            finish within 30 seconds.
    """
    # Fast path: the manager may already hold a fully initialised model.
    manager = get_model_manager()
    model = manager.get_model_sync()
    if model is not None:
        CustomMem0Embedding._model = model
        return model

    # Model not loaded yet: run the async initialiser on a dedicated
    # single-worker thread so we never touch the caller's event loop.
    if CustomMem0Embedding._executor is None:
        CustomMem0Embedding._executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=1,
            thread_name_prefix="mem0_embed",
        )

    def _load_in_thread():
        # asyncio.run() creates a fresh loop, runs the coroutine, then
        # shuts down async generators and closes the loop on every exit
        # path -- unlike a manual new_event_loop()/close() pair, it cannot
        # leak the loop if get_model() raises mid-flight.
        return asyncio.run(manager.get_model())

    future = CustomMem0Embedding._executor.submit(_load_in_thread)
    # NOTE(review): on timeout the worker thread keeps running in the
    # background; 30 s is the original budget for model initialisation.
    model = future.result(timeout=30)

    CustomMem0Embedding._model = model
    return model
def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
    """Return the embedding vector for ``text``.

    Synchronous entry point invoked by Mem0.

    Args:
        text: The text to embed.
        memory_action: Optional hint from Mem0 ("add", "search" or
            "update"); currently unused.

    Returns:
        list: The embedding vector.
    """
    # Double-checked locking: only the first caller pays the model-loading
    # cost; subsequent calls take the lock-free fast path.
    if CustomMem0Embedding._model is None:
        with CustomMem0Embedding._lock:
            if CustomMem0Embedding._model is None:
                self._get_model_sync()

    cached_model = CustomMem0Embedding._model
    vectors = cached_model.encode(text, convert_to_numpy=True)
    return vectors.tolist()
@ -136,8 +181,9 @@ class Mem0Manager:
""" """
self._sync_pool = sync_pool self._sync_pool = sync_pool
# 缓存 Mem0 实例: key = f"{user_id}:{agent_id}" # 使用 OrderedDict 实现 LRU 缓存,最多保留 50 个实例
self._instances: Dict[str, Any] = {} self._instances: OrderedDict[str, Any] = OrderedDict()
self._max_instances = 50 # 最大缓存实例数
self._initialized = False self._initialized = False
async def initialize(self) -> None: async def initialize(self) -> None:
@ -194,10 +240,16 @@ class Mem0Manager:
llm_suffix = f":{id(config.llm_instance)}" llm_suffix = f":{id(config.llm_instance)}"
cache_key = f"{user_id}:{agent_id}{llm_suffix}" cache_key = f"{user_id}:{agent_id}{llm_suffix}"
# 检查缓存 # 检查缓存(同时移动到末尾表示最近使用)
if cache_key in self._instances: if cache_key in self._instances:
self._instances.move_to_end(cache_key)
return self._instances[cache_key] return self._instances[cache_key]
# 检查缓存大小,超过则移除最旧的
if len(self._instances) >= self._max_instances:
removed_key, _ = self._instances.popitem(last=False)
logger.debug(f"Mem0 instance cache full, removed oldest entry: {removed_key}")
# 创建新实例 # 创建新实例
mem0_instance = await self._create_mem0_instance( mem0_instance = await self._create_mem0_instance(
user_id=user_id, user_id=user_id,
@ -206,7 +258,7 @@ class Mem0Manager:
config=config, config=config,
) )
# 缓存实例 # 缓存实例(新实例自动在末尾)
self._instances[cache_key] = mem0_instance self._instances[cache_key] = mem0_instance
return mem0_instance return mem0_instance

View File

@ -108,6 +108,16 @@ class GlobalModelManager:
logger.error(f"文本编码失败: {e}") logger.error(f"文本编码失败: {e}")
raise raise
def get_model_sync(self) -> Optional[SentenceTransformer]:
    """Return the already-loaded model for synchronous callers.

    Does not trigger loading: if the model has not yet been initialised
    via the async path, ``None`` is returned and the caller is expected
    to perform (or wait for) async initialisation itself.

    Returns:
        The loaded SentenceTransformer model, or None if not loaded.
    """
    return self._model
def get_model_info(self) -> Dict[str, Any]: def get_model_info(self) -> Dict[str, Any]:
"""获取模型信息""" """获取模型信息"""
return { return {