#!/usr/bin/env python3
"""
Model pool manager and cache system.

Supports high-concurrency embedding retrieval services.
"""

import os
import asyncio
import time
import pickle
import hashlib
import logging
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from collections import OrderedDict
import threading

import numpy as np
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


@dataclass
class CacheEntry:
    """A single cached embedding dataset."""
    embeddings: np.ndarray
    chunks: List[str]
    chunking_strategy: str
    chunking_params: Dict[str, Any]
    model_path: str
    file_path: str
    file_mtime: float        # file modification time at load
    access_count: int        # number of times this entry was served
    last_access_time: float  # timestamp of the most recent access
    load_time: float         # seconds spent loading from disk
    memory_size: int         # estimated memory footprint in bytes


class EmbeddingCacheManager:
    """LRU-style cache manager for embedding data."""

    def __init__(self, max_cache_size: int = 5, max_memory_mb: int = 1024):
        self.max_cache_size = max_cache_size                 # maximum number of cached entries
        self.max_memory_bytes = max_memory_mb * 1024 * 1024  # maximum memory budget
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._lock = threading.RLock()
        self._current_memory_usage = 0
        logger.info(f"EmbeddingCacheManager initialized: max_size={max_cache_size}, max_memory={max_memory_mb}MB")

    def _get_file_key(self, file_path: str) -> str:
        """Build a cache key from the absolute path and file modification time."""
        try:
            abs_path = os.path.abspath(file_path)
            mtime = os.path.getmtime(abs_path)
            key_data = f"{abs_path}:{mtime}"
            return hashlib.md5(key_data.encode()).hexdigest()
        except Exception as e:
            logger.warning(f"Failed to build cache key for {file_path}: {e}")
            return hashlib.md5(file_path.encode()).hexdigest()

    def _estimate_memory_size(self, embeddings: np.ndarray, chunks: List[str]) -> int:
        """Estimate the memory footprint of an entry in bytes."""
        try:
            embeddings_size = embeddings.nbytes
            chunks_size = sum(len(chunk.encode('utf-8')) for chunk in chunks)
            overhead = 1024 * 1024  # 1MB bookkeeping overhead
            return embeddings_size + chunks_size + overhead
        except Exception:
            return 100 * 1024 * 1024  # conservative 100MB fallback

    def _cleanup_cache(self):
        """Evict entries until both the size and memory limits are satisfied."""
        with self._lock:
            entries = list(self._cache.items())
            to_remove = []
            current_size = len(self._cache)

            # If the entry count exceeds the limit, evict the oldest entries
            # (the OrderedDict is kept in least-recently-used-first order).
            if current_size > self.max_cache_size:
                to_remove.extend(entries[:current_size - self.max_cache_size])

            # If memory is over budget, evict by last access time (LRU) until
            # enough bytes have been reclaimed.
            if self._current_memory_usage > self.max_memory_bytes:
                entries.sort(key=lambda x: x[1].last_access_time)
                accumulated_size = 0
                for key, entry in entries:
                    if accumulated_size >= self._current_memory_usage - self.max_memory_bytes:
                        break
                    to_remove.append((key, entry))
                    accumulated_size += entry.memory_size

            # Perform the eviction; the membership check guards against the
            # same key being selected by both passes above.
            for key, entry in to_remove:
                if key in self._cache:
                    del self._cache[key]
                    self._current_memory_usage -= entry.memory_size
                    logger.info(f"Evicted cache entry: {key} ({entry.memory_size / 1024 / 1024:.1f}MB)")

    async def load_embedding_data(self, file_path: str) -> Optional[CacheEntry]:
        """Load embedding data, serving from the cache when possible."""
        cache_key = self._get_file_key(file_path)

        # Check the cache first.
        with self._lock:
            if cache_key in self._cache:
                entry = self._cache[cache_key]
                entry.access_count += 1
                entry.last_access_time = time.time()
                # Move to the end (most recently used).
                self._cache.move_to_end(cache_key)
                logger.info(f"Cache hit: {file_path}")
                return entry

        # Cache miss: load the data off the event loop.
        try:
            start_time = time.time()

            if not os.path.exists(file_path):
                logger.error(f"File not found: {file_path}")
                return None

            # Run the blocking pickle load in a worker thread so concurrent
            # requests are not stalled.
            def _load():
                with open(file_path, 'rb') as f:
                    return pickle.load(f)

            loop = asyncio.get_running_loop()
            embedding_data = await loop.run_in_executor(None, _load)

            # Support both the new and the legacy data layouts.
            if 'chunks' in embedding_data:
                chunks = embedding_data['chunks']
                embeddings = embedding_data['embeddings']
                chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
                chunking_params = embedding_data.get('chunking_params', {})
                model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
            else:
                # Legacy layout: one chunk per line under the 'sentences' key.
                chunks = embedding_data['sentences']
                embeddings = embedding_data['embeddings']
                chunking_strategy = 'line'
                chunking_params = {}
                model_path = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'

            # Normalize embeddings to a numpy array (handles torch tensors).
            if hasattr(embeddings, 'cpu'):
                embeddings = embeddings.cpu().numpy()
            elif hasattr(embeddings, 'numpy'):
                embeddings = embeddings.numpy()
            elif not isinstance(embeddings, np.ndarray):
                embeddings = np.array(embeddings)

            # Build the cache entry.
            load_time = time.time() - start_time
            file_mtime = os.path.getmtime(file_path)
            memory_size = self._estimate_memory_size(embeddings, chunks)

            entry = CacheEntry(
                embeddings=embeddings,
                chunks=chunks,
                chunking_strategy=chunking_strategy,
                chunking_params=chunking_params,
                model_path=model_path,
                file_path=file_path,
                file_mtime=file_mtime,
                access_count=1,
                last_access_time=time.time(),
                load_time=load_time,
                memory_size=memory_size
            )

            # Insert into the cache, then evict if over budget.
            with self._lock:
                self._cache[cache_key] = entry
                self._current_memory_usage += memory_size
                self._cleanup_cache()

            logger.info(f"Loaded: {file_path} ({memory_size / 1024 / 1024:.1f}MB, {load_time:.2f}s)")
            return entry

        except Exception as e:
            logger.error(f"Failed to load embedding data: {file_path}, {e}")
            return None

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return cache statistics for monitoring."""
        with self._lock:
            return {
                "cache_size": len(self._cache),
                "max_cache_size": self.max_cache_size,
                "memory_usage_mb": self._current_memory_usage / 1024 / 1024,
                "max_memory_mb": self.max_memory_bytes / 1024 / 1024,
                "memory_usage_percent": (self._current_memory_usage / self.max_memory_bytes) * 100,
                "entries": [
                    {
                        "file_path": entry.file_path,
                        "access_count": entry.access_count,
                        "last_access_time": entry.last_access_time,
                        "memory_size_mb": entry.memory_size / 1024 / 1024
                    }
                    for entry in self._cache.values()
                ]
            }

    def clear_cache(self):
        """Drop every cache entry."""
        with self._lock:
            cleared_count = len(self._cache)
            cleared_memory = self._current_memory_usage
            self._cache.clear()
            self._current_memory_usage = 0
            logger.info(f"Cache cleared: {cleared_count} entries, {cleared_memory / 1024 / 1024:.1f}MB")
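
# Usage sketch (illustrative only; the coroutine name and pickle path are
# hypothetical, not part of this module's API). A second call for the same,
# unmodified file is served from the cache without touching the disk.
async def _demo_cache_usage(pkl_path: str = "data/example_embeddings.pkl") -> None:
    cache = EmbeddingCacheManager(max_cache_size=2, max_memory_mb=256)
    entry = await cache.load_embedding_data(pkl_path)  # cold: reads the pickle
    entry = await cache.load_embedding_data(pkl_path)  # warm: cache hit
    if entry is not None:
        logger.info(f"{len(entry.chunks)} chunks, embeddings shape {entry.embeddings.shape}")
    logger.info(f"Cache stats: {cache.get_cache_stats()}")
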
class GlobalModelManager:
    """Process-wide model manager with lazy, load-once semantics."""

    def __init__(self, model_name: str = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
        self.model_name = model_name
        self.local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
        self._model: Optional[SentenceTransformer] = None
        self._lock = asyncio.Lock()
        self._load_time = 0.0
        self._device = 'cpu'
        logger.info(f"GlobalModelManager initialized: {model_name}")

    async def get_model(self) -> SentenceTransformer:
        """Return the model instance, loading it lazily on first use."""
        if self._model is not None:
            return self._model

        async with self._lock:
            # Double-checked locking: another coroutine may have finished
            # loading while we waited for the lock.
            if self._model is not None:
                return self._model

            try:
                start_time = time.time()

                # Prefer a local model directory when it exists; otherwise
                # fall back to loading by model name.
                model_path = self.local_model_path if os.path.exists(self.local_model_path) else self.model_name

                # Read the device from the environment, defaulting to CPU.
                self._device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
                if self._device not in ['cpu', 'cuda', 'mps']:
                    self._device = 'cpu'

                logger.info(f"Loading model: {model_path} (device: {self._device})")

                # Run the blocking constructor in a worker thread.
                loop = asyncio.get_running_loop()
                self._model = await loop.run_in_executor(
                    None,
                    lambda: SentenceTransformer(model_path, device=self._device)
                )

                self._load_time = time.time() - start_time
                logger.info(f"Model loaded in {self._load_time:.2f}s")
                return self._model

            except Exception as e:
                logger.error(f"Model loading failed: {e}")
                raise

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into embedding vectors."""
        if not texts:
            return np.array([])

        model = await self.get_model()

        try:
            # Run the blocking encode in a worker thread.
            loop = asyncio.get_running_loop()
            embeddings = await loop.run_in_executor(
                None,
                lambda: model.encode(texts, batch_size=batch_size, show_progress_bar=False)
            )

            # Normalize to a numpy array (handles torch tensors).
            if hasattr(embeddings, 'cpu'):
                embeddings = embeddings.cpu().numpy()
            elif hasattr(embeddings, 'numpy'):
                embeddings = embeddings.numpy()
            elif not isinstance(embeddings, np.ndarray):
                embeddings = np.array(embeddings)

            return embeddings

        except Exception as e:
            logger.error(f"Text encoding failed: {e}")
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """Return model metadata for status endpoints."""
        return {
            "model_name": self.model_name,
            "local_model_path": self.local_model_path,
            "device": self._device,
            "is_loaded": self._model is not None,
            "load_time": self._load_time
        }
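
# Retrieval sketch (illustrative): one way the two managers can be combined
# into the retrieval flow the module docstring describes. The helper name and
# the top_k default are assumptions, not part of this module's API.
async def _demo_similarity_search(entry: CacheEntry, query: str, top_k: int = 3) -> List[str]:
    model_manager = get_model_manager()
    query_vec = (await model_manager.encode_texts([query]))[0]
    # Rank cached chunks by cosine similarity to the query vector.
    doc_vecs = entry.embeddings
    sims = doc_vecs @ query_vec / (
        np.linalg.norm(doc_vecs, axis=1) * np.linalg.norm(query_vec) + 1e-12
    )
    top_idx = np.argsort(sims)[::-1][:top_k]
    return [entry.chunks[i] for i in top_idx]
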
# Process-wide singleton instances.
_cache_manager: Optional[EmbeddingCacheManager] = None
_model_manager: Optional[GlobalModelManager] = None


def get_cache_manager() -> EmbeddingCacheManager:
    """Return the shared cache manager, creating it on first use."""
    global _cache_manager
    if _cache_manager is None:
        max_cache_size = int(os.getenv("EMBEDDING_MAX_CACHE_SIZE", "5"))
        max_memory_mb = int(os.getenv("EMBEDDING_MAX_MEMORY_MB", "1024"))
        _cache_manager = EmbeddingCacheManager(max_cache_size, max_memory_mb)
    return _cache_manager


def get_model_manager() -> GlobalModelManager:
    """Return the shared model manager, creating it on first use."""
    global _model_manager
    if _model_manager is None:
        model_name = os.getenv("SENTENCE_TRANSFORMER_MODEL",
                               "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
        _model_manager = GlobalModelManager(model_name)
    return _model_manager
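
# Minimal smoke test (illustrative): exercises the singletons end to end.
# Requires the sentence-transformers model to be available locally or
# downloadable; no embedding pickle is needed for this path.
if __name__ == "__main__":
    async def _main() -> None:
        manager = get_model_manager()
        vectors = await manager.encode_texts(["hello world", "embedding cache demo"])
        print("encoded:", vectors.shape)
        print("model info:", manager.get_model_info())
        print("cache stats:", get_cache_manager().get_cache_stats())

    asyncio.run(_main())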