107 lines
3.4 KiB
Python
107 lines
3.4 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Model pool manager and cache system
|
|
Support high-concurrency embedding retrieval services
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import Dict, List, Any
|
|
from utils.settings import (
|
|
EMBEDDING_API_KEY,
|
|
EMBEDDING_BASE_URL,
|
|
EMBEDDING_DIMENSIONS,
|
|
EMBEDDING_MODEL_NAME,
|
|
EMBEDDING_TIMEOUT,
|
|
)
|
|
import numpy as np
|
|
import requests
|
|
|
|
|
|
# Module-level logger; it is looked up under the fixed name 'app' rather than
# this module's own name.
logger = logging.getLogger('app')
|
|
|
|
|
|
class GlobalModelManager:
    """Client for an OpenAI-compatible embedding API.

    Connection settings (model name, base URL, API key, output dimensions and
    request timeout) are copied from the ``EMBEDDING_*`` settings when the
    instance is created.
    """

    # Model names for which the optional ``dimensions`` field is never sent.
    _MODELS_WITHOUT_DIMENSIONS = ("text-embedding-ada-002", "local-embedding")

    def __init__(self):
        """Capture the embedding API settings and log the configured model."""
        self.external_model_name = EMBEDDING_MODEL_NAME
        # Tolerate an unset base URL here so that the explicit RuntimeError in
        # _encode_texts_external is what reports the misconfiguration, instead
        # of an AttributeError at construction time.
        self.external_base_url = (EMBEDDING_BASE_URL or "").rstrip("/")
        self.external_api_key = EMBEDDING_API_KEY
        self.external_dimensions = EMBEDDING_DIMENSIONS
        self.external_timeout = EMBEDDING_TIMEOUT

        logger.info(
            "GlobalModelManager initialized: external_model=%s",
            self.external_model_name,
        )

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into vectors through the external embedding API.

        The blocking HTTP call is handed to the running loop's default
        executor so the event loop itself is not blocked while waiting.

        Args:
            texts: Strings to embed.
            batch_size: Accepted for interface compatibility. It is not used:
                all texts are sent in a single request.

        Returns:
            ``np.array`` built from the returned embeddings, one entry per
            input text, or ``np.array([])`` when ``texts`` is empty.

        Raises:
            RuntimeError: See ``_encode_texts_external``.
        """
        if not texts:
            return np.array([])

        # get_running_loop() is the documented way to obtain the loop from
        # inside a coroutine; run_in_executor forwards positional arguments.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._encode_texts_external, texts)

    def encode_texts_sync(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Synchronously encode texts. Used by synchronous integrations such as Mem0.

        Args:
            texts: Strings to embed.
            batch_size: Accepted for interface compatibility. It is not used:
                all texts are sent in a single request.

        Returns:
            ``np.array`` built from the returned embeddings, one entry per
            input text, or ``np.array([])`` when ``texts`` is empty.

        Raises:
            RuntimeError: See ``_encode_texts_external``.
        """
        if not texts:
            return np.array([])

        return self._encode_texts_external(texts)

    @staticmethod
    def _extract_embeddings(items: List[Any]) -> List[Any]:
        """Return the ``embedding`` value of each response item in input order.

        When every item carries an integer ``index``, items are sorted by it so
        that each vector lines up with the text at the same position in the
        request. If any item lacks such an index, the response order is kept.
        """
        items = list(items)
        if items and all(
            isinstance(item, dict) and isinstance(item.get("index"), int)
            for item in items
        ):
            items = sorted(items, key=lambda item: item["index"])
        return [item["embedding"] for item in items]

    def _encode_texts_external(self, texts: List[str]) -> np.ndarray:
        """POST ``texts`` to ``<base_url>/embeddings`` and return the vectors.

        Args:
            texts: Strings to embed; sent together as the ``input`` field.

        Returns:
            ``np.array`` built from the embeddings in the response.

        Raises:
            RuntimeError: If no base URL is configured, if the API answers
                with a status other than 200, or if the number of returned
                embeddings differs from ``len(texts)``.

        Exceptions raised by ``requests.post`` or ``response.json()`` are not
        caught here and propagate to the caller.
        """
        if not self.external_base_url:
            raise RuntimeError("EMBEDDING_BASE_URL is required for embedding API calls")

        endpoint = f"{self.external_base_url}/embeddings"
        headers = {"Content-Type": "application/json"}
        # The Authorization header is only sent when an API key is configured.
        if self.external_api_key:
            headers["Authorization"] = f"Bearer {self.external_api_key}"

        payload: Dict[str, Any] = {
            "model": self.external_model_name,
            "input": texts,
        }
        # ``dimensions`` is optional and is skipped for the excluded models.
        if (
            self.external_dimensions
            and self.external_model_name not in self._MODELS_WITHOUT_DIMENSIONS
        ):
            payload["dimensions"] = self.external_dimensions

        response = requests.post(
            endpoint,
            json=payload,
            headers=headers,
            timeout=self.external_timeout,
        )
        if response.status_code != 200:
            raise RuntimeError(f"External embedding API failed: {response.status_code} - {response.text}")

        data = response.json()
        # ``or []`` makes a missing or null "data" field fall through to the
        # count-mismatch error below rather than failing while iterating None.
        embeddings = self._extract_embeddings(data.get("data") or [])
        if len(embeddings) != len(texts):
            raise RuntimeError(
                f"External embedding API returned {len(embeddings)} embeddings for {len(texts)} texts"
            )
        return np.array(embeddings)

    def get_model_info(self) -> Dict[str, Any]:
        """Return the provider label, base URL, model name and dimensions in use."""
        return {
            "provider": "openai_compatible",
            "base_url": self.external_base_url,
            "model_name": self.external_model_name,
            "dimensions": self.external_dimensions,
        }
|
|
|
|
|
|
# Module-wide manager instance; stays None until get_model_manager() is first called.
_model_manager = None


def get_model_manager() -> GlobalModelManager:
    """Return the shared GlobalModelManager, creating it on the first call."""
    global _model_manager
    # Guard clause: reuse the instance built by an earlier call.
    if _model_manager is not None:
        return _model_manager
    _model_manager = GlobalModelManager()
    return _model_manager
|