qwen_agent/embedding/manager.py
朱潮 3dc119bca8 refactor(mem0): optimize connection pool and async memory handling
- Fix mem0 connection pool exhausted error with proper pooling
- Convert memory operations to async tasks
- Optimize docker-compose configuration
- Add skill upload functionality
- Reduce cache size for better performance
- Update dependencies

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 19:39:12 +08:00

142 lines
4.5 KiB
Python

#!/usr/bin/env python3
"""
模型池管理器和缓存系统
支持高并发的 embedding 检索服务
"""
import os
import asyncio
import time
import pickle
import hashlib
import logging
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from collections import OrderedDict
from utils.settings import SENTENCE_TRANSFORMER_MODEL
import threading
import psutil
import numpy as np
from sentence_transformers import SentenceTransformer
import logging
logger = logging.getLogger('app')
class GlobalModelManager:
    """Process-wide manager for a lazily-loaded SentenceTransformer model.

    The model is loaded at most once and shared by all callers. Loading and
    encoding are blocking (CPU/GPU-bound) operations, so they are dispatched
    to the default thread-pool executor to keep the asyncio event loop free
    for other tasks (supports high-concurrency embedding retrieval).
    """

    def __init__(self, model_name: str = 'TaylorAI/gte-tiny'):
        """Record configuration only; the model itself is loaded lazily.

        Args:
            model_name: Hugging Face model id used when no local copy exists.
        """
        self.model_name = model_name
        # Preferred on-disk snapshot; if absent we fall back to resolving
        # (and possibly downloading) `model_name` by id.
        self.local_model_path = "./models/gte-tiny"
        self._model: Optional[SentenceTransformer] = None
        # asyncio lock guarding the one-time model initialization.
        self._lock = asyncio.Lock()
        self._load_time = 0  # seconds spent loading, for diagnostics
        self._device = 'cpu'
        logger.info(f"GlobalModelManager 初始化: {model_name}")

    async def get_model(self) -> SentenceTransformer:
        """Return the shared model instance, loading it on first call.

        Returns:
            The loaded SentenceTransformer.

        Raises:
            Exception: re-raised from model construction after logging.
        """
        # Fast path: already loaded; safe without the lock because attribute
        # assignment happens-before any later read on the same event loop.
        if self._model is not None:
            return self._model
        async with self._lock:
            # Double-checked: another task may have finished loading while
            # we were waiting on the lock.
            if self._model is not None:
                return self._model
            try:
                start_time = time.time()
                # Prefer the local snapshot when present.
                model_path = self.local_model_path if os.path.exists(self.local_model_path) else self.model_name
                # Device comes from the environment; anything unrecognized
                # falls back to CPU.
                self._device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
                if self._device not in ('cpu', 'cuda', 'mps'):
                    self._device = 'cpu'
                logger.info(f"加载模型: {model_path} (device: {self._device})")
                # Model construction is blocking I/O + CPU work; run it in a
                # thread so the event loop is not stalled.
                loop = asyncio.get_running_loop()
                self._model = await loop.run_in_executor(
                    None,
                    lambda: SentenceTransformer(
                        model_path,
                        device=self._device
                    )
                )
                self._load_time = time.time() - start_time
                logger.info(f"模型加载完成: {self._load_time:.2f}s")
                return self._model
            except Exception as e:
                logger.error(f"模型加载失败: {e}")
                raise

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into an embedding array.

        Args:
            texts: Strings to embed; an empty list short-circuits to an
                empty array without touching the model.
            batch_size: Batch size forwarded to the model's encode().

        Returns:
            numpy.ndarray of embeddings (empty array for empty input).

        Raises:
            Exception: re-raised from encoding after logging.
        """
        if not texts:
            return np.array([])
        model = await self.get_model()
        try:
            # encode() is CPU/GPU-bound; keep it off the event loop.
            loop = asyncio.get_running_loop()
            embeddings = await loop.run_in_executor(
                None,
                lambda: model.encode(texts, batch_size=batch_size, show_progress_bar=False)
            )
            # Normalize torch tensors / plain sequences to a numpy array.
            if hasattr(embeddings, 'cpu'):
                embeddings = embeddings.cpu().numpy()
            elif hasattr(embeddings, 'numpy'):
                embeddings = embeddings.numpy()
            elif not isinstance(embeddings, np.ndarray):
                embeddings = np.array(embeddings)
            return embeddings
        except Exception as e:
            logger.error(f"文本编码失败: {e}")
            raise

    def get_model_sync(self) -> Optional[SentenceTransformer]:
        """Return the model for synchronous contexts, or None if not loaded.

        Callers must ensure the model was initialized via the async
        get_model() beforehand; this never triggers loading itself.
        """
        return self._model

    def get_model_info(self) -> Dict[str, Any]:
        """Return a snapshot of configuration and load state for diagnostics."""
        return {
            "model_name": self.model_name,
            "local_model_path": self.local_model_path,
            "device": self._device,
            "is_loaded": self._model is not None,
            "load_time": self._load_time
        }
# Lazily-created module-level singleton.
_model_manager = None


def get_model_manager() -> GlobalModelManager:
    """Return the shared GlobalModelManager, creating it on first call."""
    global _model_manager
    if _model_manager is None:
        _model_manager = GlobalModelManager(SENTENCE_TRANSFORMER_MODEL)
    return _model_manager