qwen_agent/embedding/manager.py

107 lines
3.4 KiB
Python

#!/usr/bin/env python3
"""
Embedding manager: a client for an OpenAI-compatible embedding API
Provides async and sync text-encoding entry points for embedding retrieval services
"""
import asyncio
import logging
from typing import Dict, List, Any
from utils.settings import (
EMBEDDING_API_KEY,
EMBEDDING_BASE_URL,
EMBEDDING_DIMENSIONS,
EMBEDDING_MODEL_NAME,
EMBEDDING_TIMEOUT,
)
import numpy as np
import requests
# Module logger: this module logs under the shared "app" logger name rather than __name__.
logger = logging.getLogger(name="app")
class GlobalModelManager:
    """OpenAI-compatible embedding API manager.

    Connection settings (model name, base URL, API key, output dimensions,
    request timeout) are read from ``utils.settings`` once, at construction.
    Texts are sent to ``<base_url>/embeddings``; both the async
    (:meth:`encode_texts`) and sync (:meth:`encode_texts_sync`) entry points
    return a numpy array with one row per input text, in input order.
    """

    def __init__(self):
        self.external_model_name = EMBEDDING_MODEL_NAME
        # Tolerate an unset (None) base URL here; _encode_texts_external raises
        # a descriptive RuntimeError when an actual request is attempted.
        self.external_base_url = (EMBEDDING_BASE_URL or "").rstrip("/")
        self.external_api_key = EMBEDDING_API_KEY
        self.external_dimensions = EMBEDDING_DIMENSIONS
        self.external_timeout = EMBEDDING_TIMEOUT
        logger.info(
            "GlobalModelManager initialized: external_model=%s",
            self.external_model_name,
        )

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into vectors through the external embedding API.

        The blocking HTTP calls run in the event loop's default executor so the
        loop itself is not blocked.

        Args:
            texts: Strings to embed.
            batch_size: Maximum number of texts sent per HTTP request. A
                non-positive value (or ``None``) sends everything in one request.

        Returns:
            ``np.array([])`` for empty input; otherwise an array with one row
            per input text, in input order.

        Raises:
            RuntimeError: If the base URL is unset, the API returns a non-200
                status, or the number of returned embeddings does not match.
        """
        if not texts:
            return np.array([])
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._encode_texts_external, texts, batch_size
        )

    def encode_texts_sync(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Synchronously encode texts. Used by synchronous integrations such as Mem0.

        Same arguments, return value and errors as :meth:`encode_texts`.
        """
        if not texts:
            return np.array([])
        return self._encode_texts_external(texts, batch_size)

    def _encode_texts_external(self, texts: List[str], batch_size: int = 0) -> np.ndarray:
        """POST *texts* to the embeddings endpoint, chunked by *batch_size*.

        Embeddings from consecutive chunks are concatenated, so row ``i`` of
        the result corresponds to ``texts[i]``.
        """
        if not self.external_base_url:
            raise RuntimeError("EMBEDDING_BASE_URL is required for embedding API calls")
        endpoint = f"{self.external_base_url}/embeddings"
        headers = self._build_headers()
        vectors: List[Any] = []
        for batch in self._split_batches(texts, batch_size):
            response = requests.post(
                endpoint,
                json=self._build_payload(batch),
                headers=headers,
                timeout=self.external_timeout,
            )
            if response.status_code != 200:
                raise RuntimeError(f"External embedding API failed: {response.status_code} - {response.text}")
            vectors.extend(self._extract_embeddings(response.json(), len(batch)))
        return np.array(vectors)

    def _build_headers(self) -> Dict[str, str]:
        """Return JSON request headers, adding a Bearer token when an API key is set."""
        headers = {"Content-Type": "application/json"}
        if self.external_api_key:
            headers["Authorization"] = f"Bearer {self.external_api_key}"
        return headers

    def _build_payload(self, batch: List[str]) -> Dict[str, Any]:
        """Return the JSON body for one embeddings request covering *batch*."""
        payload: Dict[str, Any] = {
            "model": self.external_model_name,
            "input": batch,
        }
        # The optional "dimensions" field is omitted for text-embedding-ada-002
        # (OpenAI documents it only for text-embedding-3 and later models).
        if self.external_dimensions and self.external_model_name != "text-embedding-ada-002":
            payload["dimensions"] = self.external_dimensions
        return payload

    @staticmethod
    def _split_batches(texts: List[str], batch_size: int) -> List[List[str]]:
        """Split *texts* into consecutive chunks of at most *batch_size* items.

        ``None``, a non-positive value, or a value >= ``len(texts)`` yields the
        whole input as a single chunk (the original object, unsliced).
        """
        if not batch_size or batch_size < 0 or batch_size >= len(texts):
            return [texts]
        return [texts[start:start + batch_size] for start in range(0, len(texts), batch_size)]

    @staticmethod
    def _extract_embeddings(data: Dict[str, Any], expected: int) -> List[Any]:
        """Pull the embedding vectors out of a decoded embeddings response.

        Args:
            data: Decoded JSON response body.
            expected: Number of texts that were sent in the request.

        Raises:
            RuntimeError: If the number of embeddings differs from *expected*.
        """
        items = data.get("data") or []
        # Each embedding object carries an "index" giving its input position;
        # sort on it when every item has one so rows line up with the inputs.
        if items and all("index" in item for item in items):
            items = sorted(items, key=lambda item: item["index"])
        embeddings = [item["embedding"] for item in items]
        if len(embeddings) != expected:
            raise RuntimeError(
                f"External embedding API returned {len(embeddings)} embeddings for {expected} texts"
            )
        return embeddings

    def get_model_info(self) -> Dict[str, Any]:
        """Get model information"""
        return {
            "provider": "openai_compatible",
            "base_url": self.external_base_url,
            "model_name": self.external_model_name,
            "dimensions": self.external_dimensions,
        }
# Process-wide manager instance; created lazily by get_model_manager().
_model_manager = None


def get_model_manager() -> GlobalModelManager:
    """Return the shared GlobalModelManager, constructing it on the first call."""
    global _model_manager
    if _model_manager is not None:
        return _model_manager
    _model_manager = GlobalModelManager()
    return _model_manager