#!/usr/bin/env python3
"""
Model pool manager and caching system.
Supports high-concurrency embedding retrieval services.
"""

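# Configuration is read from environment variables by the managers below:
#   SENTENCE_TRANSFORMER_MODEL  - model name used by GlobalModelManager
#   SENTENCE_TRANSFORMER_DEVICE - 'cpu', 'cuda' or 'mps' (defaults to 'cpu')
#   EMBEDDING_MAX_CACHE_SIZE    - maximum number of cached embedding files
#   EMBEDDING_MAX_MEMORY_MB     - memory budget for the embedding cache, in MB
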
import os
import asyncio
import time
import pickle
import hashlib
import logging
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from collections import OrderedDict
import threading
import psutil
import numpy as np

from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


@dataclass
class CacheEntry:
    """A single cache entry."""
    embeddings: np.ndarray
    chunks: List[str]
    chunking_strategy: str
    chunking_params: Dict[str, Any]
    model_path: str
    file_path: str
    file_mtime: float  # File modification time
    access_count: int  # Number of times the entry was accessed
    last_access_time: float  # Timestamp of the most recent access
    load_time: float  # Time it took to load the entry (seconds)
    memory_size: int  # Estimated memory footprint in bytes


class EmbeddingCacheManager:
    """Cache manager for embedding data."""

    def __init__(self, max_cache_size: int = 5, max_memory_mb: int = 1024):
        self.max_cache_size = max_cache_size  # Maximum number of cache entries
        self.max_memory_bytes = max_memory_mb * 1024 * 1024  # Maximum memory usage in bytes
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._lock = threading.RLock()
        self._current_memory_usage = 0

        logger.info(f"EmbeddingCacheManager initialized: max_size={max_cache_size}, max_memory={max_memory_mb}MB")

    def _get_file_key(self, file_path: str) -> str:
        """Generate a cache key for a file."""
        # Build a unique key from the absolute path and file modification time
        try:
            abs_path = os.path.abspath(file_path)
            mtime = os.path.getmtime(abs_path)
            key_data = f"{abs_path}:{mtime}"
            return hashlib.md5(key_data.encode()).hexdigest()
        except Exception as e:
            logger.warning(f"Failed to generate file key: {file_path}, {e}")
            return hashlib.md5(file_path.encode()).hexdigest()

    def _estimate_memory_size(self, embeddings: np.ndarray, chunks: List[str]) -> int:
        """Estimate the memory footprint of the data."""
        try:
            embeddings_size = embeddings.nbytes
            chunks_size = sum(len(chunk.encode('utf-8')) for chunk in chunks)
            overhead = 1024 * 1024  # 1 MB of overhead
            return embeddings_size + chunks_size + overhead
        except Exception:
            return 100 * 1024 * 1024  # Fall back to a 100 MB estimate

    def _cleanup_cache(self):
        """Evict cache entries to free memory."""
        with self._lock:
            # Order entries by access time and count, evicting the least used first
            entries = list(self._cache.items())

            # Collect the entries that need to be evicted
            to_remove = []
            current_size = len(self._cache)

            # If the entry count exceeds the limit, evict the oldest entries
            if current_size > self.max_cache_size:
                to_remove.extend(entries[:current_size - self.max_cache_size])

            # If memory usage exceeds the limit, evict in LRU order
            if self._current_memory_usage > self.max_memory_bytes:
                # Sort by last access time
                entries.sort(key=lambda x: x[1].last_access_time)

                accumulated_size = 0
                for key, entry in entries:
                    if accumulated_size >= self._current_memory_usage - self.max_memory_bytes:
                        break
                    to_remove.append((key, entry))
                    accumulated_size += entry.memory_size

            # Perform the eviction
            for key, entry in to_remove:
                if key in self._cache:
                    del self._cache[key]
                    self._current_memory_usage -= entry.memory_size
                    logger.info(f"Evicted cache entry: {key} ({entry.memory_size / 1024 / 1024:.1f}MB)")

    async def load_embedding_data(self, file_path: str) -> Optional[CacheEntry]:
        """Load embedding data for a file, using the cache when possible."""
        cache_key = self._get_file_key(file_path)

        # Check the cache first
        with self._lock:
            if cache_key in self._cache:
                entry = self._cache[cache_key]
                entry.access_count += 1
                entry.last_access_time = time.time()
                # Move to the end (most recently used)
                self._cache.move_to_end(cache_key)
                logger.info(f"Cache hit: {file_path}")
                return entry

        # Cache miss: load the data from disk
        try:
            start_time = time.time()

            # Make sure the file exists
            if not os.path.exists(file_path):
                logger.error(f"File does not exist: {file_path}")
                return None

            # Load the embedding data
            with open(file_path, 'rb') as f:
                embedding_data = pickle.load(f)

            # Support both the new and the legacy data layout
            if 'chunks' in embedding_data:
                chunks = embedding_data['chunks']
                embeddings = embedding_data['embeddings']
                chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
                chunking_params = embedding_data.get('chunking_params', {})
                model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
            else:
                chunks = embedding_data['sentences']
                embeddings = embedding_data['embeddings']
                chunking_strategy = 'line'
                chunking_params = {}
                model_path = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'

            # Make sure embeddings is a numpy array
            if hasattr(embeddings, 'cpu'):
                embeddings = embeddings.cpu().numpy()
            elif hasattr(embeddings, 'numpy'):
                embeddings = embeddings.numpy()
            elif not isinstance(embeddings, np.ndarray):
                embeddings = np.array(embeddings)

            # Build the cache entry
            load_time = time.time() - start_time
            file_mtime = os.path.getmtime(file_path)
            memory_size = self._estimate_memory_size(embeddings, chunks)

            entry = CacheEntry(
                embeddings=embeddings,
                chunks=chunks,
                chunking_strategy=chunking_strategy,
                chunking_params=chunking_params,
                model_path=model_path,
                file_path=file_path,
                file_mtime=file_mtime,
                access_count=1,
                last_access_time=time.time(),
                load_time=load_time,
                memory_size=memory_size
            )

            # Add the entry to the cache
            with self._lock:
                self._cache[cache_key] = entry
                self._current_memory_usage += memory_size

            # Enforce the cache limits
            self._cleanup_cache()

            logger.info(f"Load complete: {file_path} ({memory_size / 1024 / 1024:.1f}MB, {load_time:.2f}s)")
            return entry

        except Exception as e:
            logger.error(f"Failed to load embedding data: {file_path}, {e}")
            return None

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return statistics about the cache."""
        with self._lock:
            return {
                "cache_size": len(self._cache),
                "max_cache_size": self.max_cache_size,
                "memory_usage_mb": self._current_memory_usage / 1024 / 1024,
                "max_memory_mb": self.max_memory_bytes / 1024 / 1024,
                "memory_usage_percent": (self._current_memory_usage / self.max_memory_bytes) * 100,
                "entries": [
                    {
                        "file_path": entry.file_path,
                        "access_count": entry.access_count,
                        "last_access_time": entry.last_access_time,
                        "memory_size_mb": entry.memory_size / 1024 / 1024
                    }
                    for entry in self._cache.values()
                ]
            }

    def clear_cache(self):
        """Clear the entire cache."""
        with self._lock:
            cleared_count = len(self._cache)
            cleared_memory = self._current_memory_usage
            self._cache.clear()
            self._current_memory_usage = 0
            logger.info(f"Cache cleared: {cleared_count} entries, {cleared_memory / 1024 / 1024:.1f}MB")


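# Illustrative only: a minimal sketch of the on-disk payload that
# EmbeddingCacheManager.load_embedding_data() understands. The newer layout is
# a pickled dict keyed by 'chunks'/'embeddings' plus optional chunking
# metadata; the legacy layout uses 'sentences' instead of 'chunks'. This
# helper is added for documentation purposes and is not called anywhere in
# this module.
def write_example_embedding_file(path: str, chunks: List[str], embeddings: np.ndarray) -> None:
    """Write an embedding file in the newer layout accepted by the cache manager."""
    payload = {
        "chunks": chunks,
        "embeddings": embeddings,
        "chunking_strategy": "line",
        "chunking_params": {},
        "model_path": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    }
    with open(path, "wb") as f:
        pickle.dump(payload, f)

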
class GlobalModelManager:
    """Global model manager."""

    def __init__(self, model_name: str = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
        self.model_name = model_name
        self.local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
        self._model: Optional[SentenceTransformer] = None
        self._lock = asyncio.Lock()
        self._load_time = 0
        self._device = 'cpu'

        logger.info(f"GlobalModelManager initialized: {model_name}")

    async def get_model(self) -> SentenceTransformer:
        """Return the model instance (loaded lazily)."""
        if self._model is not None:
            return self._model

        async with self._lock:
            # Double-check under the lock
            if self._model is not None:
                return self._model

            try:
                start_time = time.time()

                # Prefer a local copy of the model if one exists
                model_path = self.local_model_path if os.path.exists(self.local_model_path) else self.model_name

                # Read the device configuration
                self._device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
                if self._device not in ['cpu', 'cuda', 'mps']:
                    self._device = 'cpu'

                logger.info(f"Loading model: {model_path} (device: {self._device})")

                # Run the blocking load in an executor so the event loop stays responsive
                loop = asyncio.get_running_loop()
                self._model = await loop.run_in_executor(
                    None,
                    lambda: SentenceTransformer(model_path, device=self._device)
                )

                self._load_time = time.time() - start_time
                logger.info(f"Model loaded in {self._load_time:.2f}s")

                return self._model

            except Exception as e:
                logger.error(f"Failed to load model: {e}")
                raise

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into embedding vectors."""
        if not texts:
            return np.array([])

        model = await self.get_model()

        try:
            # Run the blocking encode call in an executor
            loop = asyncio.get_running_loop()
            embeddings = await loop.run_in_executor(
                None,
                lambda: model.encode(texts, batch_size=batch_size, show_progress_bar=False)
            )

            # Make sure a numpy array is returned
            if hasattr(embeddings, 'cpu'):
                embeddings = embeddings.cpu().numpy()
            elif hasattr(embeddings, 'numpy'):
                embeddings = embeddings.numpy()
            elif not isinstance(embeddings, np.ndarray):
                embeddings = np.array(embeddings)

            return embeddings

        except Exception as e:
            logger.error(f"Failed to encode texts: {e}")
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """Return information about the model."""
        return {
            "model_name": self.model_name,
            "local_model_path": self.local_model_path,
            "device": self._device,
            "is_loaded": self._model is not None,
            "load_time": self._load_time
        }


# Global singleton instances
_cache_manager = None
_model_manager = None


def get_cache_manager() -> EmbeddingCacheManager:
    """Return the cache manager instance, creating it on first use."""
    global _cache_manager
    if _cache_manager is None:
        max_cache_size = int(os.getenv("EMBEDDING_MAX_CACHE_SIZE", "5"))
        max_memory_mb = int(os.getenv("EMBEDDING_MAX_MEMORY_MB", "1024"))
        _cache_manager = EmbeddingCacheManager(max_cache_size, max_memory_mb)
    return _cache_manager


def get_model_manager() -> GlobalModelManager:
    """Return the model manager instance, creating it on first use."""
    global _model_manager
    if _model_manager is None:
        model_name = os.getenv("SENTENCE_TRANSFORMER_MODEL", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
        _model_manager = GlobalModelManager(model_name)
    return _model_manager
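

# A minimal usage sketch (an illustration, not part of the original service
# wiring): it assumes a pre-computed embedding file at
# './data/example_embeddings.pkl' and ranks chunks for a query by cosine
# similarity, which is one reasonable scoring choice for sentence-transformer
# vectors.
async def _demo_retrieval(query: str, embedding_file: str = "./data/example_embeddings.pkl") -> None:
    cache_manager = get_cache_manager()
    model_manager = get_model_manager()

    # Load (or fetch from the cache) the pre-computed chunk embeddings.
    entry = await cache_manager.load_embedding_data(embedding_file)
    if entry is None:
        logger.error("No embedding data available for the demo")
        return

    # Encode the query and score it against every cached chunk.
    query_vec = (await model_manager.encode_texts([query]))[0]
    norms = np.linalg.norm(entry.embeddings, axis=1) * np.linalg.norm(query_vec)
    scores = entry.embeddings @ query_vec / np.clip(norms, 1e-12, None)

    # Log the top three chunks by similarity.
    for idx in np.argsort(scores)[::-1][:3]:
        logger.info(f"score={scores[idx]:.3f} chunk={entry.chunks[idx][:80]!r}")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(_demo_retrieval("example query"))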