🐛 fix(mem0): fix connection pool exhaustion error
Fixes the Mem0 connection pool exhaustion error. The root cause was that CustomMem0Embedding.embed() called asyncio.run(), leaking event loops.

Main changes:
- Replace asyncio.run() with a thread pool to avoid leaking event loops
- Add a thread-safe model cache
- Add LRU eviction to the Mem0 instance cache, keeping at most 50 instances
- Add a synchronous get_model_sync() method to GlobalModelManager

Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
parent f29fd1fb54
commit 45f3a61a16
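Background for the diff below: the old embed() path called asyncio.run() on every request, while the fix routes the async model load through a single dedicated worker thread that owns and closes its own event loop. A minimal, self-contained sketch of that bridging pattern (not the project's code; load_model and get_model_blocking are illustrative stand-ins):

```python
import asyncio
import concurrent.futures

# One long-lived worker thread dedicated to running async initialization.
_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

async def load_model() -> str:
    # Stand-in for an async loader such as GlobalModelManager.get_model().
    await asyncio.sleep(0.01)
    return "model"

def get_model_blocking(timeout: float = 30.0) -> str:
    """Synchronous entry point for callers that cannot await."""
    def run_in_thread():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(load_model())
        finally:
            loop.close()  # the loop is always closed, so none are leaked
    return _executor.submit(run_in_thread).result(timeout=timeout)

if __name__ == "__main__":
    print(get_model_blocking())  # -> "model"
```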
@@ -5,7 +5,10 @@ Mem0 connection and instance manager

import logging
import asyncio
import threading
import concurrent.futures
from typing import Any, Dict, List, Optional, Literal
from collections import OrderedDict
from embedding.manager import GlobalModelManager, get_model_manager
import json_repair
from psycopg2 import pool
@@ -27,7 +30,9 @@ class CustomMem0Embedding:
    so that Mem0 does not need to load the same model again, which saves memory
    """

    _model_manager = None  # Cached GlobalModelManager instance
    _model = None  # Class variable caching the model instance
    _lock = threading.Lock()  # Thread-safety lock
    _executor = None  # Thread pool executor

    def __init__(self, config: Optional[Any] = None):
        """Initialize the custom embedding"""
@@ -41,6 +46,41 @@ class CustomMem0Embedding:
        """Get the embedding dimension"""
        return 384  # dimension of gte-tiny

    def _get_model_sync(self):
        """Get the model synchronously, avoiding asyncio.run()"""
        # First try to get an already-loaded model from the manager
        manager = get_model_manager()
        model = manager.get_model_sync()

        if model is not None:
            # Cache the model
            CustomMem0Embedding._model = model
            return model

        # If the model is not loaded yet, run the async initialization in the thread pool
        if CustomMem0Embedding._executor is None:
            CustomMem0Embedding._executor = concurrent.futures.ThreadPoolExecutor(
                max_workers=1,
                thread_name_prefix="mem0_embed"
            )

        # Run the async code in a separate thread
        def run_async_in_thread():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                result = loop.run_until_complete(manager.get_model())
                return result
            finally:
                loop.close()

        future = CustomMem0Embedding._executor.submit(run_async_in_thread)
        model = future.result(timeout=30)  # 30-second timeout

        # Cache the model
        CustomMem0Embedding._model = model
        return model

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding vector for a text (synchronous method, called by Mem0)
@@ -52,8 +92,13 @@ class CustomMem0Embedding:
        Returns:
            list: embedding vector
        """
        manager = get_model_manager()
        model = asyncio.run(manager.get_model())
        # Get the model in a thread-safe way
        if CustomMem0Embedding._model is None:
            with CustomMem0Embedding._lock:
                if CustomMem0Embedding._model is None:
                    self._get_model_sync()

        model = CustomMem0Embedding._model
        embeddings = model.encode(text, convert_to_numpy=True)
        return embeddings.tolist()
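The rewritten embed() above uses double-checked locking: the class-level _model is read without the lock on the hot path, and the lock is only taken while the cache is still empty. A small standalone sketch of the same idiom, with a hypothetical _load() standing in for the model initialization:

```python
import threading

class LazyModelCache:
    _model = None
    _lock = threading.Lock()

    @classmethod
    def get(cls):
        # Fast path: no lock once the model has been cached.
        if cls._model is None:
            with cls._lock:
                # Re-check inside the lock: another thread may have
                # finished loading while we were waiting for it.
                if cls._model is None:
                    cls._model = cls._load()
        return cls._model

    @classmethod
    def _load(cls):
        # Hypothetical expensive initialization (e.g. loading a model).
        return object()

assert LazyModelCache.get() is LazyModelCache.get()  # loaded exactly once
```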
@@ -136,8 +181,9 @@ class Mem0Manager:
        """
        self._sync_pool = sync_pool

        # Cache of Mem0 instances: key = f"{user_id}:{agent_id}"
        self._instances: Dict[str, Any] = {}
        # Use an OrderedDict as an LRU cache, keeping at most 50 instances
        self._instances: OrderedDict[str, Any] = OrderedDict()
        self._max_instances = 50  # Maximum number of cached instances
        self._initialized = False

    async def initialize(self) -> None:
@@ -194,10 +240,16 @@ class Mem0Manager:
        llm_suffix = f":{id(config.llm_instance)}"
        cache_key = f"{user_id}:{agent_id}{llm_suffix}"

        # Check the cache
        # Check the cache (and move hits to the end to mark them as recently used)
        if cache_key in self._instances:
            self._instances.move_to_end(cache_key)
            return self._instances[cache_key]

        # Check the cache size and evict the oldest entry when over the limit
        if len(self._instances) >= self._max_instances:
            removed_key, _ = self._instances.popitem(last=False)
            logger.debug(f"Mem0 instance cache full, removed oldest entry: {removed_key}")

        # Create a new instance
        mem0_instance = await self._create_mem0_instance(
            user_id=user_id,
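The cache logic above is the standard OrderedDict LRU idiom: move_to_end() on a hit marks the entry as most recently used, and popitem(last=False) evicts the least recently used entry once the cache is full. Isolated as a tiny sketch with a usage example (the class and keys are illustrative, not part of the project):

```python
from collections import OrderedDict
from typing import Any, Optional

class LRUCache:
    def __init__(self, max_items: int = 50):
        self._items: OrderedDict[str, Any] = OrderedDict()
        self._max_items = max_items

    def get(self, key: str) -> Optional[Any]:
        if key in self._items:
            self._items.move_to_end(key)  # mark as most recently used
            return self._items[key]
        return None

    def put(self, key: str, value: Any) -> None:
        if len(self._items) >= self._max_items:
            oldest_key, _ = self._items.popitem(last=False)  # evict LRU entry
        self._items[key] = value

cache = LRUCache(max_items=2)
cache.put("user1:agent1", "instance-a")
cache.put("user2:agent2", "instance-b")
cache.get("user1:agent1")                # "user1:agent1" is now most recent
cache.put("user3:agent3", "instance-c")  # evicts "user2:agent2"
assert cache.get("user2:agent2") is None
```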
@@ -206,7 +258,7 @@ class Mem0Manager:
            config=config,
        )

        # Cache the instance
        # Cache the instance (new entries are appended at the end automatically)
        self._instances[cache_key] = mem0_instance
        return mem0_instance
@@ -108,6 +108,16 @@ class GlobalModelManager:
            logger.error(f"Text encoding failed: {e}")
            raise

    def get_model_sync(self) -> Optional[SentenceTransformer]:
        """Get the model instance synchronously (for use in synchronous contexts)

        Returns None if the model has not been loaded. Callers should make sure
        the model is initialized via the async method first.

        Returns:
            The loaded SentenceTransformer model, or None
        """
        return self._model

    def get_model_info(self) -> Dict[str, Any]:
        """Return model information"""
        return {
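get_model_sync() deliberately never triggers a load: it only returns what the async path has already cached, which is why the CustomMem0Embedding code falls back to the thread-pool initialization when it gets None. A hedged sketch of that contract with simplified, illustrative names:

```python
import asyncio
from typing import Optional

class ModelManager:
    def __init__(self) -> None:
        self._model: Optional[str] = None

    async def get_model(self) -> str:
        # Async path: lazily load on first use.
        if self._model is None:
            await asyncio.sleep(0)  # stand-in for the real async load
            self._model = "loaded-model"
        return self._model

    def get_model_sync(self) -> Optional[str]:
        # Sync path: never loads, only reports what is already cached.
        return self._model

async def main() -> None:
    manager = ModelManager()
    assert manager.get_model_sync() is None   # nothing loaded yet
    await manager.get_model()                 # async initialization
    assert manager.get_model_sync() == "loaded-model"

asyncio.run(main())
```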