# Commit note: Fix mem0 connection pool exhausted error with proper pooling;
# convert memory operations to async tasks; optimize docker-compose
# configuration; add skill upload functionality; reduce cache size for better
# performance; update dependencies.
# Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
#!/usr/bin/env python3
|
|
"""
|
|
模型池管理器和缓存系统
|
|
支持高并发的 embedding 检索服务
|
|
"""
|
|
|
|
import os
|
|
import asyncio
|
|
import time
|
|
import pickle
|
|
import hashlib
|
|
import logging
|
|
from typing import Dict, List, Optional, Any, Tuple
|
|
from dataclasses import dataclass
|
|
from collections import OrderedDict
|
|
from utils.settings import SENTENCE_TRANSFORMER_MODEL
|
|
import threading
|
|
import psutil
|
|
import numpy as np
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
import logging
|
|
|
|
logger = logging.getLogger('app')
|
|
|
|
|
|
class GlobalModelManager:
    """Process-wide lazy loader for a single SentenceTransformer model.

    Built for a high-concurrency embedding retrieval service: the model is
    loaded exactly once under an asyncio lock, and every blocking model call
    (construction and encoding) is pushed to the default thread-pool executor
    so the event loop stays responsive.
    """

    def __init__(self, model_name: str = 'TaylorAI/gte-tiny'):
        """Record configuration only; the model itself is loaded lazily.

        Args:
            model_name: Hugging Face model id, used when no local copy exists.
        """
        self.model_name = model_name
        # Preferred on-disk checkout; used only if the directory exists.
        self.local_model_path = "./models/gte-tiny"
        self._model: Optional[SentenceTransformer] = None
        # Serializes the one-time load across concurrent coroutines.
        self._lock = asyncio.Lock()
        self._load_time = 0
        self._device = 'cpu'

        logger.info(f"GlobalModelManager 初始化: {model_name}")

    async def get_model(self) -> SentenceTransformer:
        """Return the shared model instance, loading it on first use.

        Returns:
            The loaded SentenceTransformer model.

        Raises:
            Exception: re-raised from the underlying model load on failure
                (after logging).
        """
        # Fast path: already loaded — no need to take the lock.
        if self._model is not None:
            return self._model

        async with self._lock:
            # Double-check: another coroutine may have finished the load
            # while we were waiting on the lock.
            if self._model is not None:
                return self._model

            try:
                start_time = time.time()

                # Prefer the local checkout when present, otherwise let
                # SentenceTransformer resolve the hub model name.
                model_path = self.local_model_path if os.path.exists(self.local_model_path) else self.model_name

                # Device comes from the environment; any unknown value
                # falls back to CPU.
                self._device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
                if self._device not in ('cpu', 'cuda', 'mps'):
                    self._device = 'cpu'

                logger.info(f"加载模型: {model_path} (device: {self._device})")

                # Model construction is blocking (disk/network + CPU), so run
                # it in the executor. get_running_loop() is the correct,
                # non-deprecated call inside a coroutine (get_event_loop()
                # is deprecated here since Python 3.10).
                loop = asyncio.get_running_loop()
                self._model = await loop.run_in_executor(
                    None,
                    lambda: SentenceTransformer(
                        model_path,
                        device=self._device
                    )
                )

                self._load_time = time.time() - start_time
                logger.info(f"模型加载完成: {self._load_time:.2f}s")

                return self._model

            except Exception as e:
                logger.error(f"模型加载失败: {e}")
                raise

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into embedding vectors.

        Args:
            texts: Strings to embed. An empty list short-circuits to an
                empty array without touching (or loading) the model.
            batch_size: Batch size forwarded to ``model.encode``.

        Returns:
            A numpy array of embeddings (empty array for empty input).

        Raises:
            Exception: re-raised from the underlying encode call on failure
                (after logging).
        """
        if not texts:
            return np.array([])

        model = await self.get_model()

        try:
            # encode() is CPU-bound; keep the event loop free by running it
            # in the default executor.
            loop = asyncio.get_running_loop()
            embeddings = await loop.run_in_executor(
                None,
                lambda: model.encode(texts, batch_size=batch_size, show_progress_bar=False)
            )

            # Normalize possible torch tensors / plain lists into numpy.
            if hasattr(embeddings, 'cpu'):
                embeddings = embeddings.cpu().numpy()
            elif hasattr(embeddings, 'numpy'):
                embeddings = embeddings.numpy()
            elif not isinstance(embeddings, np.ndarray):
                embeddings = np.array(embeddings)

            return embeddings

        except Exception as e:
            logger.error(f"文本编码失败: {e}")
            raise

    def get_model_sync(self) -> Optional[SentenceTransformer]:
        """Synchronously return the model instance (for sync contexts).

        Returns None when the model has not been loaded yet; callers should
        ensure the model was first initialized via the async ``get_model``.

        Returns:
            The loaded SentenceTransformer model, or None.
        """
        return self._model

    def get_model_info(self) -> Dict[str, Any]:
        """Return a status snapshot: name, path, device, load state and time."""
        return {
            "model_name": self.model_name,
            "local_model_path": self.local_model_path,
            "device": self._device,
            "is_loaded": self._model is not None,
            "load_time": self._load_time
        }
|
|
|
|
|
|
# Module-level singleton instance, created lazily on first request.
_model_manager = None


def get_model_manager() -> GlobalModelManager:
    """Return the shared GlobalModelManager, constructing it on first use."""
    global _model_manager
    if _model_manager is not None:
        return _model_manager
    _model_manager = GlobalModelManager(SENTENCE_TRANSFORMER_MODEL)
    return _model_manager
|