qwen_agent/embedding/manager.py
2025-11-20 21:37:46 +08:00

132 lines
4.3 KiB
Python

#!/usr/bin/env python3
"""
模型池管理器和缓存系统
支持高并发的 embedding 检索服务
"""
import os
import asyncio
import time
import pickle
import hashlib
import logging
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from collections import OrderedDict
import threading
import psutil
import numpy as np
from sentence_transformers import SentenceTransformer
logger = logging.getLogger(__name__)
class GlobalModelManager:
    """Process-wide manager for one shared SentenceTransformer model.

    The model is loaded lazily on first use. Loading is guarded by an
    asyncio lock with double-checked locking, so concurrent coroutines
    trigger at most one load. All blocking model operations (construction
    and encoding) run in the default executor to keep the event loop
    responsive.
    """

    # Devices accepted via SENTENCE_TRANSFORMER_DEVICE; anything else falls
    # back to 'cpu'.
    _VALID_DEVICES = frozenset({'cpu', 'cuda', 'mps'})

    def __init__(
        self,
        model_name: str = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
        truncate_dim: int = 128,
    ):
        """Initialize the manager without loading the model.

        Args:
            model_name: Hugging Face model id used when no local snapshot
                exists at ``local_model_path``.
            truncate_dim: Dimension the model truncates embeddings to
                (defaults to 128, the previously hard-coded value).
        """
        self.model_name = model_name
        self.local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
        self.truncate_dim = truncate_dim
        self._model: Optional[SentenceTransformer] = None
        self._lock = asyncio.Lock()
        # Wall-clock seconds spent loading the model (0.0 until loaded).
        self._load_time: float = 0.0
        self._device = 'cpu'
        logger.info(f"GlobalModelManager 初始化: {model_name}")

    async def get_model(self) -> SentenceTransformer:
        """Return the shared model instance, loading it lazily on first call.

        Raises:
            Exception: re-raises any failure from model construction.
        """
        # Fast path: already loaded, no lock needed.
        if self._model is not None:
            return self._model
        async with self._lock:
            # Double-checked locking: another coroutine may have finished
            # loading while we waited on the lock.
            if self._model is not None:
                return self._model
            try:
                start_time = time.time()
                # Prefer the local snapshot when present; otherwise resolve
                # by model name (may download).
                model_path = self.local_model_path if os.path.exists(self.local_model_path) else self.model_name
                # Device comes from the environment; unknown values fall back
                # to CPU.
                self._device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
                if self._device not in self._VALID_DEVICES:
                    self._device = 'cpu'
                logger.info(f"加载模型: {model_path} (device: {self._device})")
                # Model construction blocks; run it in the default executor so
                # the event loop stays responsive. get_running_loop() is the
                # non-deprecated API inside a coroutine.
                loop = asyncio.get_running_loop()
                self._model = await loop.run_in_executor(
                    None,
                    lambda: SentenceTransformer(
                        model_path,
                        device=self._device,
                        truncate_dim=self.truncate_dim,
                    )
                )
                self._load_time = time.time() - start_time
                logger.info(f"模型加载完成: {self._load_time:.2f}s")
                return self._model
            except Exception as e:
                logger.error(f"模型加载失败: {e}")
                raise

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode ``texts`` into an array of embedding vectors.

        Args:
            texts: Strings to embed; an empty list yields an empty array.
            batch_size: Batch size forwarded to ``model.encode``.

        Returns:
            ``np.ndarray`` of embeddings (one row per input text).

        Raises:
            Exception: re-raises any failure from loading or encoding.
        """
        if not texts:
            return np.array([])
        model = await self.get_model()
        try:
            # Encoding is blocking (CPU/GPU bound); off-load to the executor.
            loop = asyncio.get_running_loop()
            embeddings = await loop.run_in_executor(
                None,
                lambda: model.encode(texts, batch_size=batch_size, show_progress_bar=False)
            )
            # Normalize to a numpy array regardless of what encode() returned
            # (GPU tensor, CPU tensor, list, or ndarray).
            if hasattr(embeddings, 'cpu'):
                embeddings = embeddings.cpu().numpy()
            elif hasattr(embeddings, 'numpy'):
                embeddings = embeddings.numpy()
            elif not isinstance(embeddings, np.ndarray):
                embeddings = np.array(embeddings)
            return embeddings
        except Exception as e:
            logger.error(f"文本编码失败: {e}")
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """Return a status snapshot: names, device, load state, load time."""
        return {
            "model_name": self.model_name,
            "local_model_path": self.local_model_path,
            "device": self._device,
            "is_loaded": self._model is not None,
            "load_time": self._load_time
        }
# Lazily-created singleton shared by every caller of get_model_manager().
_model_manager = None


def get_model_manager() -> GlobalModelManager:
    """Return the process-wide GlobalModelManager, creating it on first use.

    The model name is taken from the SENTENCE_TRANSFORMER_MODEL environment
    variable, defaulting to the multilingual MiniLM checkpoint.
    """
    global _model_manager
    if _model_manager is None:
        _model_manager = GlobalModelManager(
            os.getenv(
                "SENTENCE_TRANSFORMER_MODEL",
                "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
            )
        )
    return _model_manager