#!/usr/bin/env python3
"""
Model pool manager and cache system.

Supports high-concurrency embedding retrieval services.
"""

import asyncio
import hashlib
import logging
import os
import pickle
import threading
import time
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

# NOTE(review): kept ahead of the sentence_transformers import, as in the
# original ordering, in case utils.settings configures the environment at
# import time -- confirm before regrouping.
from utils.settings import SENTENCE_TRANSFORMER_MODEL

import numpy as np
import psutil
from sentence_transformers import SentenceTransformer

logger = logging.getLogger('app')


class GlobalModelManager:
    """Global model manager.

    Owns a single, lazily loaded ``SentenceTransformer`` instance. The
    blocking load and encode calls are dispatched through
    ``loop.run_in_executor(None, ...)`` so the event loop is not blocked
    while they run.
    """

    # Device names accepted from the SENTENCE_TRANSFORMER_DEVICE environment
    # variable; any other value falls back to 'cpu'.
    _SUPPORTED_DEVICES = ('cpu', 'cuda', 'mps')

    def __init__(self, model_name: str = 'TaylorAI/gte-tiny'):
        """Record the configuration; the model itself is loaded on first use.

        Args:
            model_name: Model identifier passed to ``SentenceTransformer``
                when the local model directory does not exist.
        """
        self.model_name = model_name
        self.local_model_path = "./models/gte-tiny"
        self._model: Optional[SentenceTransformer] = None
        # Serializes the first load so concurrent callers trigger one load.
        self._lock = asyncio.Lock()
        # Wall-clock seconds spent loading the model; 0 until loaded.
        self._load_time = 0
        self._device = 'cpu'

        logger.info(f"GlobalModelManager initialized: {model_name}")

    async def get_model(self) -> SentenceTransformer:
        """Get the model instance with lazy loading.

        Returns:
            The loaded ``SentenceTransformer``. The first successful call
            loads it; later calls return the same instance.

        Raises:
            Exception: Any error raised while loading is logged and
                re-raised unchanged.
        """
        # Fast path: already loaded, no need to take the lock.
        if self._model is not None:
            return self._model

        async with self._lock:
            # Re-check under the lock: another task may have completed the
            # load while this one was waiting.
            if self._model is not None:
                return self._model

            try:
                start_time = time.time()

                # Use the local model directory when it exists, otherwise
                # pass the configured model name through.
                if os.path.exists(self.local_model_path):
                    model_path = self.local_model_path
                else:
                    model_path = self.model_name

                # Device comes from the environment; unsupported values fall
                # back to CPU, and the fallback is logged rather than silent.
                requested_device = os.environ.get(
                    'SENTENCE_TRANSFORMER_DEVICE', 'cpu'
                )
                if requested_device not in self._SUPPORTED_DEVICES:
                    logger.warning(
                        "Unsupported SENTENCE_TRANSFORMER_DEVICE %r; "
                        "falling back to 'cpu'",
                        requested_device,
                    )
                    requested_device = 'cpu'
                self._device = requested_device

                logger.info(f"Loading model: {model_path} (device: {self._device})")

                # Constructing the model blocks, so run it in the executor.
                # get_running_loop() is the supported way to obtain the loop
                # from inside a coroutine.
                loop = asyncio.get_running_loop()
                self._model = await loop.run_in_executor(
                    None,
                    lambda: SentenceTransformer(
                        model_path,
                        device=self._device
                    )
                )

                self._load_time = time.time() - start_time
                logger.info(f"Model loading completed: {self._load_time:.2f}s")

                return self._model

            except Exception as e:
                logger.error(f"Model loading failed: {e}")
                raise

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into vectors.

        Args:
            texts: Texts to encode. An empty list short-circuits and returns
                an empty array without loading the model.
            batch_size: Batch size forwarded to ``model.encode``.

        Returns:
            The embeddings as a NumPy array (``np.array([])`` for empty
            input).

        Raises:
            Exception: Any error raised by ``model.encode`` or the array
                conversion is logged and re-raised unchanged.
        """
        if not texts:
            return np.array([])

        model = await self.get_model()

        try:
            # Encoding blocks, so run it in the executor.
            loop = asyncio.get_running_loop()
            embeddings = await loop.run_in_executor(
                None,
                lambda: model.encode(texts, batch_size=batch_size, show_progress_bar=False)
            )

            # Ensure a NumPy array is returned: objects exposing .cpu() or
            # .numpy() are converted through those methods, anything else
            # that is not already an ndarray goes through np.array().
            if hasattr(embeddings, 'cpu'):
                embeddings = embeddings.cpu().numpy()
            elif hasattr(embeddings, 'numpy'):
                embeddings = embeddings.numpy()
            elif not isinstance(embeddings, np.ndarray):
                embeddings = np.array(embeddings)

            return embeddings

        except Exception as e:
            logger.error(f"Text encoding failed: {e}")
            raise

    def get_model_sync(self) -> Optional[SentenceTransformer]:
        """Synchronously get the model instance for synchronous contexts.

        This never triggers a load. If the model is not loaded, return None;
        the caller should ensure the model is initialized via the async API
        first.

        Returns:
            The loaded SentenceTransformer model, or None.
        """
        return self._model

    def get_model_info(self) -> Dict[str, Any]:
        """Get model information.

        Returns:
            A dict with the configured model name and local path, the
            current device, whether the model is loaded, and the recorded
            load time in seconds.
        """
        return {
            "model_name": self.model_name,
            "local_model_path": self.local_model_path,
            "device": self._device,
            "is_loaded": self._model is not None,
            "load_time": self._load_time
        }


# Global instance, created on the first call to get_model_manager().
_model_manager: Optional[GlobalModelManager] = None


def get_model_manager() -> GlobalModelManager:
    """Get the model manager instance.

    Creates the module-level ``GlobalModelManager`` from
    ``SENTENCE_TRANSFORMER_MODEL`` on first call and returns the same
    instance afterwards.
    """
    global _model_manager
    if _model_manager is None:
        _model_manager = GlobalModelManager(SENTENCE_TRANSFORMER_MODEL)
    return _model_manager