qwen_agent/embedding/manager.py
朱潮 45f3a61a16 🐛 fix(mem0): fix connection pool exhausted error
修复 Mem0 连接池耗尽错误,问题根因是 CustomMem0Embedding.embed()
方法中使用 asyncio.run() 导致事件循环泄漏。

主要修改:
- 使用线程池替代 asyncio.run() 避免事件循环泄漏
- 添加线程安全的模型缓存机制
- 为 Mem0 实例缓存添加 LRU 机制,最多保留 50 个实例
- 在 GlobalModelManager 中添加 get_model_sync() 同步方法

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-21 18:17:45 +08:00

142 lines
4.5 KiB
Python

#!/usr/bin/env python3
"""
模型池管理器和缓存系统
支持高并发的 embedding 检索服务
"""
import os
import asyncio
import time
import pickle
import hashlib
import logging
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from collections import OrderedDict
from utils.settings import SENTENCE_TRANSFORMER_MODEL
import threading
import psutil
import numpy as np
from sentence_transformers import SentenceTransformer
import logging
logger = logging.getLogger('app')
class GlobalModelManager:
    """Global manager for a shared SentenceTransformer model.

    The model is loaded lazily on first use; concurrent loads are
    serialized with an asyncio lock, and all blocking work (model
    construction, text encoding) runs in the default thread-pool
    executor so the event loop is never blocked.
    """

    def __init__(self, model_name: str = 'TaylorAI/gte-tiny'):
        """Initialize the manager without loading the model.

        Args:
            model_name: Hugging Face model id used when no local copy exists.
        """
        self.model_name = model_name
        # Preferred local checkout; used when the directory exists.
        self.local_model_path = "./models/gte-tiny"
        self._model: Optional["SentenceTransformer"] = None
        # Serializes the one-time load across concurrent coroutines.
        self._lock = asyncio.Lock()
        self._load_time = 0  # seconds spent in the last successful load
        self._device = 'cpu'
        logger.info(f"GlobalModelManager 初始化: {model_name}")

    async def get_model(self) -> "SentenceTransformer":
        """Return the shared model instance, loading it on first call.

        Uses double-checked locking so only one coroutine performs the
        expensive load; the load itself runs in a worker thread.

        Raises:
            Exception: re-raised from SentenceTransformer construction on failure.
        """
        if self._model is not None:
            return self._model
        async with self._lock:
            # Double check: another coroutine may have finished loading
            # while we were waiting on the lock.
            if self._model is not None:
                return self._model
            try:
                start_time = time.time()
                # Prefer the local checkout when present, else fall back to the hub id.
                model_path = self.local_model_path if os.path.exists(self.local_model_path) else self.model_name
                # Device selection from the environment, restricted to known values.
                self._device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
                if self._device not in ['cpu', 'cuda', 'mps']:
                    self._device = 'cpu'
                logger.info(f"加载模型: {model_path} (device: {self._device})")
                # FIX: asyncio.get_event_loop() is deprecated inside coroutines
                # (Python 3.10+) and may create a stray loop when called from a
                # worker thread — the event-loop-leak class of bug. Inside a
                # coroutine, get_running_loop() is the correct call.
                loop = asyncio.get_running_loop()
                self._model = await loop.run_in_executor(
                    None,
                    lambda: SentenceTransformer(
                        model_path,
                        device=self._device
                    )
                )
                self._load_time = time.time() - start_time
                logger.info(f"模型加载完成: {self._load_time:.2f}s")
                return self._model
            except Exception as e:
                logger.error(f"模型加载失败: {e}")
                raise

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode *texts* into embedding vectors.

        Args:
            texts: List of input strings; an empty list yields an empty array.
            batch_size: Batch size passed through to ``model.encode``.

        Returns:
            A numpy array of embeddings (converted from a torch tensor if needed).

        Raises:
            Exception: re-raised from the underlying encode call on failure.
        """
        if not texts:
            return np.array([])
        model = await self.get_model()
        try:
            # Run the blocking encode call in a worker thread (see get_model fix:
            # get_running_loop() instead of the deprecated get_event_loop()).
            loop = asyncio.get_running_loop()
            embeddings = await loop.run_in_executor(
                None,
                lambda: model.encode(texts, batch_size=batch_size, show_progress_bar=False)
            )
            # Normalize the result to a numpy array regardless of backend type.
            if hasattr(embeddings, 'cpu'):
                # torch tensor possibly on GPU: move to host memory first.
                embeddings = embeddings.cpu().numpy()
            elif hasattr(embeddings, 'numpy'):
                embeddings = embeddings.numpy()
            elif not isinstance(embeddings, np.ndarray):
                embeddings = np.array(embeddings)
            return embeddings
        except Exception as e:
            logger.error(f"文本编码失败: {e}")
            raise

    def get_model_sync(self) -> Optional["SentenceTransformer"]:
        """Synchronously return the model instance (for sync contexts).

        Returns None when the model has not been loaded yet; callers are
        expected to have initialized it through the async path first.

        Returns:
            The loaded SentenceTransformer model, or None.
        """
        return self._model

    def get_model_info(self) -> Dict[str, Any]:
        """Return a status snapshot of the manager (name, path, device, load state)."""
        return {
            "model_name": self.model_name,
            "local_model_path": self.local_model_path,
            "device": self._device,
            "is_loaded": self._model is not None,
            "load_time": self._load_time
        }
# Lazily-created, process-wide singleton instance.
_model_manager = None


def get_model_manager() -> GlobalModelManager:
    """Return the shared GlobalModelManager, constructing it on first access."""
    global _model_manager
    if _model_manager is None:
        _model_manager = GlobalModelManager(SENTENCE_TRANSFORMER_MODEL)
    return _model_manager