qwen_agent/embedding/manager.py

#!/usr/bin/env python3
"""
Model pool manager and cache system.
Supports a high-concurrency embedding retrieval service.
"""
import os
import asyncio
import time
import pickle
import hashlib
import logging
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from collections import OrderedDict
import threading
import numpy as np
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


@dataclass
class CacheEntry:
    """A single cache entry."""
    embeddings: np.ndarray
    chunks: List[str]
    chunking_strategy: str
    chunking_params: Dict[str, Any]
    model_path: str
    file_path: str
    file_mtime: float        # file modification time
    access_count: int        # number of accesses
    last_access_time: float  # time of most recent access
    load_time: float         # time spent loading from disk
    memory_size: int         # estimated memory footprint in bytes


class EmbeddingCacheManager:
    """Cache manager for precomputed embedding data."""

    def __init__(self, max_cache_size: int = 5, max_memory_mb: int = 1024):
        self.max_cache_size = max_cache_size                 # maximum number of cache entries
        self.max_memory_bytes = max_memory_mb * 1024 * 1024  # maximum memory usage in bytes
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._lock = threading.RLock()
        self._current_memory_usage = 0
        logger.info(f"EmbeddingCacheManager initialized: max_size={max_cache_size}, max_memory={max_memory_mb}MB")

    def _get_file_key(self, file_path: str) -> str:
        """Build a cache key for a file."""
        # Use the absolute path plus the file's mtime to build a unique key,
        # so a modified file never hits a stale entry.
        try:
            abs_path = os.path.abspath(file_path)
            mtime = os.path.getmtime(abs_path)
            key_data = f"{abs_path}:{mtime}"
            return hashlib.md5(key_data.encode()).hexdigest()
        except Exception as e:
            logger.warning(f"Failed to build file key: {file_path}, {e}")
            return hashlib.md5(file_path.encode()).hexdigest()

    def _estimate_memory_size(self, embeddings: np.ndarray, chunks: List[str]) -> int:
        """Estimate the memory footprint of a cache entry."""
        try:
            embeddings_size = embeddings.nbytes
            chunks_size = sum(len(chunk.encode('utf-8')) for chunk in chunks)
            overhead = 1024 * 1024  # 1MB of bookkeeping overhead
            return embeddings_size + chunks_size + overhead
        except Exception:
            return 100 * 1024 * 1024  # fall back to a conservative 100MB

    def _cleanup_cache(self):
        """Evict entries to stay within the size and memory limits."""
        with self._lock:
            entries = list(self._cache.items())
            to_remove = []
            current_size = len(self._cache)
            # If the entry count exceeds the limit, evict the oldest entries
            # (the OrderedDict is kept in LRU order via move_to_end).
            if current_size > self.max_cache_size:
                to_remove.extend(entries[:current_size - self.max_cache_size])
            # If memory usage exceeds the limit, evict by LRU until the
            # excess is covered.
            if self._current_memory_usage > self.max_memory_bytes:
                entries.sort(key=lambda x: x[1].last_access_time)
                excess = self._current_memory_usage - self.max_memory_bytes
                accumulated_size = 0
                for key, entry in entries:
                    if accumulated_size >= excess:
                        break
                    to_remove.append((key, entry))
                    accumulated_size += entry.memory_size
            # Perform the eviction; the membership check guards against
            # entries selected by both rules above.
            for key, entry in to_remove:
                if key in self._cache:
                    del self._cache[key]
                    self._current_memory_usage -= entry.memory_size
                    logger.info(f"Evicted cache entry: {key} ({entry.memory_size / 1024 / 1024:.1f}MB)")

    async def load_embedding_data(self, file_path: str) -> Optional[CacheEntry]:
        """Load embedding data for a file, serving from cache when possible."""
        cache_key = self._get_file_key(file_path)
        # Check the cache first.
        with self._lock:
            if cache_key in self._cache:
                entry = self._cache[cache_key]
                entry.access_count += 1
                entry.last_access_time = time.time()
                # Move to the end (most recently used).
                self._cache.move_to_end(cache_key)
                logger.info(f"Cache hit: {file_path}")
                return entry
        # Cache miss: load the data from disk.
        try:
            start_time = time.time()
            if not os.path.exists(file_path):
                logger.error(f"File does not exist: {file_path}")
                return None
            with open(file_path, 'rb') as f:
                embedding_data = pickle.load(f)
            # Support both the new and the legacy on-disk layout.
            if 'chunks' in embedding_data:
                chunks = embedding_data['chunks']
                embeddings = embedding_data['embeddings']
                chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
                chunking_params = embedding_data.get('chunking_params', {})
                model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
            else:
                # Legacy layout: per-line sentences, no chunking metadata.
                chunks = embedding_data['sentences']
                embeddings = embedding_data['embeddings']
                chunking_strategy = 'line'
                chunking_params = {}
                model_path = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
            # Normalize embeddings to a numpy array (they may be torch tensors).
            if hasattr(embeddings, 'cpu'):
                embeddings = embeddings.cpu().numpy()
            elif hasattr(embeddings, 'numpy'):
                embeddings = embeddings.numpy()
            elif not isinstance(embeddings, np.ndarray):
                embeddings = np.array(embeddings)
            # Build the cache entry.
            load_time = time.time() - start_time
            file_mtime = os.path.getmtime(file_path)
            memory_size = self._estimate_memory_size(embeddings, chunks)
            entry = CacheEntry(
                embeddings=embeddings,
                chunks=chunks,
                chunking_strategy=chunking_strategy,
                chunking_params=chunking_params,
                model_path=model_path,
                file_path=file_path,
                file_mtime=file_mtime,
                access_count=1,
                last_access_time=time.time(),
                load_time=load_time,
                memory_size=memory_size
            )
            # Insert into the cache, then evict if limits are exceeded
            # (the RLock allows the nested acquisition inside _cleanup_cache).
            with self._lock:
                self._cache[cache_key] = entry
                self._current_memory_usage += memory_size
                self._cleanup_cache()
            logger.info(f"Load complete: {file_path} ({memory_size / 1024 / 1024:.1f}MB, {load_time:.2f}s)")
            return entry
        except Exception as e:
            logger.error(f"Failed to load embedding data: {file_path}, {e}")
            return None

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return cache statistics."""
        with self._lock:
            return {
                "cache_size": len(self._cache),
                "max_cache_size": self.max_cache_size,
                "memory_usage_mb": self._current_memory_usage / 1024 / 1024,
                "max_memory_mb": self.max_memory_bytes / 1024 / 1024,
                "memory_usage_percent": (self._current_memory_usage / self.max_memory_bytes) * 100,
                "entries": [
                    {
                        "file_path": entry.file_path,
                        "access_count": entry.access_count,
                        "last_access_time": entry.last_access_time,
                        "memory_size_mb": entry.memory_size / 1024 / 1024
                    }
                    for entry in self._cache.values()
                ]
            }

    def clear_cache(self):
        """Clear the entire cache."""
        with self._lock:
            cleared_count = len(self._cache)
            cleared_memory = self._current_memory_usage
            self._cache.clear()
            self._current_memory_usage = 0
            logger.info(f"Cache cleared: {cleared_count} entries, {cleared_memory / 1024 / 1024:.1f}MB")


class GlobalModelManager:
    """Process-wide manager for the SentenceTransformer model."""

    def __init__(self, model_name: str = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
        self.model_name = model_name
        self.local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
        self._model: Optional[SentenceTransformer] = None
        self._lock = asyncio.Lock()
        self._load_time = 0.0
        self._device = 'cpu'
        logger.info(f"GlobalModelManager initialized: {model_name}")

    async def get_model(self) -> SentenceTransformer:
        """Return the model instance, loading it lazily on first use."""
        if self._model is not None:
            return self._model
        async with self._lock:
            # Double-checked locking: another coroutine may have loaded the
            # model while we were waiting for the lock.
            if self._model is not None:
                return self._model
            try:
                start_time = time.time()
                # Prefer a local copy of the model if one exists.
                model_path = self.local_model_path if os.path.exists(self.local_model_path) else self.model_name
                # Resolve the device from the environment.
                self._device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
                if self._device not in ['cpu', 'cuda', 'mps']:
                    self._device = 'cpu'
                logger.info(f"Loading model: {model_path} (device: {self._device})")
                # Run the blocking load in an executor so the event loop stays responsive.
                loop = asyncio.get_running_loop()
                self._model = await loop.run_in_executor(
                    None,
                    lambda: SentenceTransformer(model_path, device=self._device)
                )
                self._load_time = time.time() - start_time
                logger.info(f"Model loaded in {self._load_time:.2f}s")
                return self._model
            except Exception as e:
                logger.error(f"Model loading failed: {e}")
                raise

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into embedding vectors."""
        if not texts:
            return np.array([])
        model = await self.get_model()
        try:
            # Run the blocking encode call in an executor.
            loop = asyncio.get_running_loop()
            embeddings = await loop.run_in_executor(
                None,
                lambda: model.encode(texts, batch_size=batch_size, show_progress_bar=False)
            )
            # Normalize the result to a numpy array.
            if hasattr(embeddings, 'cpu'):
                embeddings = embeddings.cpu().numpy()
            elif hasattr(embeddings, 'numpy'):
                embeddings = embeddings.numpy()
            elif not isinstance(embeddings, np.ndarray):
                embeddings = np.array(embeddings)
            return embeddings
        except Exception as e:
            logger.error(f"Text encoding failed: {e}")
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """Return information about the model."""
        return {
            "model_name": self.model_name,
            "local_model_path": self.local_model_path,
            "device": self._device,
            "is_loaded": self._model is not None,
            "load_time": self._load_time
        }
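

# A minimal usage sketch for the model manager (inside a coroutine); the
# sample strings are illustrative only:
#
#     manager = GlobalModelManager()
#     vectors = await manager.encode_texts(["hello world", "another sentence"], batch_size=32)
#     print(vectors.shape)  # (2, embedding_dim)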


# Global singleton instances.
_cache_manager: Optional[EmbeddingCacheManager] = None
_model_manager: Optional[GlobalModelManager] = None


def get_cache_manager() -> EmbeddingCacheManager:
    """Return the shared cache manager instance, creating it on first use."""
    global _cache_manager
    if _cache_manager is None:
        max_cache_size = int(os.getenv("EMBEDDING_MAX_CACHE_SIZE", "5"))
        max_memory_mb = int(os.getenv("EMBEDDING_MAX_MEMORY_MB", "1024"))
        _cache_manager = EmbeddingCacheManager(max_cache_size, max_memory_mb)
    return _cache_manager


def get_model_manager() -> GlobalModelManager:
    """Return the shared model manager instance, creating it on first use."""
    global _model_manager
    if _model_manager is None:
        model_name = os.getenv("SENTENCE_TRANSFORMER_MODEL", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
        _model_manager = GlobalModelManager(model_name)
    return _model_manager
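

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the service): it assumes the
    # configured sentence-transformers model is available locally or can be
    # downloaded, and the sample texts are illustrative only.
    async def _demo() -> None:
        model_manager = get_model_manager()
        vectors = await model_manager.encode_texts(["embedding manager demo", "a second sample sentence"])
        print("encoded shape:", vectors.shape)
        print("model info:", model_manager.get_model_info())
        print("cache stats:", get_cache_manager().get_cache_stats())

    asyncio.run(_demo())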