Merge branch 'staging' into prod
This commit is contained in:
commit
667fdb8a3b
13
Dockerfile
13
Dockerfile
@ -9,7 +9,7 @@ ENV PYTHONPATH=/app
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# 安装系统依赖(含 LibreOffice 和 sharp 所需的 libvips)
|
||||
RUN apt-get update && apt-get install -y \
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
wget \
|
||||
gnupg2 \
|
||||
@ -25,7 +25,8 @@ RUN apt-get update && apt-get install -y \
|
||||
|
||||
# 安装Node.js (支持npx命令)
|
||||
RUN curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
|
||||
apt-get install -y nodejs
|
||||
apt-get install -y --no-install-recommends nodejs && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 安装uv (Python包管理器)
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
@ -35,7 +36,10 @@ ENV PATH="/root/.cargo/bin:$PATH"
|
||||
|
||||
# 复制requirements文件并安装Python依赖
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
RUN grep -Ev '^(torch|triton|nvidia-[^=]+|sentence-transformers|transformers|tokenizers|safetensors|scikit-learn|scipy|huggingface-hub|hf-xet)==' requirements.txt > /tmp/requirements.runtime.txt && \
|
||||
! grep -E '^(torch|triton|nvidia-[^=]+|sentence-transformers|transformers|tokenizers|safetensors|scikit-learn|scipy|huggingface-hub|hf-xet)==' /tmp/requirements.runtime.txt && \
|
||||
pip install --no-cache-dir -r /tmp/requirements.runtime.txt && \
|
||||
rm -f /tmp/requirements.runtime.txt
|
||||
|
||||
# 安装 Playwright 并下载 Chromium
|
||||
RUN pip install --no-cache-dir playwright && \
|
||||
@ -48,9 +52,6 @@ RUN mkdir -p /app/projects
|
||||
RUN mkdir -p /app/public
|
||||
RUN mkdir -p /app/models
|
||||
|
||||
# 下载sentence-transformers模型到models目录
|
||||
RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('TaylorAI/gte-tiny'); model.save('/app/models/gte-tiny')"
|
||||
|
||||
FROM base AS bytecode-builder
|
||||
|
||||
# 复制应用代码,仅在构建阶段编译为字节码
|
||||
|
||||
@ -10,7 +10,8 @@ ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# 安装系统依赖(含 LibreOffice 和 sharp 所需的 libvips)
|
||||
RUN sed -i 's|http://deb.debian.org|http://mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources && \
|
||||
apt-get update && apt-get install -y \
|
||||
apt-get -o Acquire::Retries=3 update && \
|
||||
apt-get -o Acquire::Retries=3 install -y --no-install-recommends \
|
||||
curl \
|
||||
wget \
|
||||
gnupg2 \
|
||||
@ -26,7 +27,8 @@ RUN sed -i 's|http://deb.debian.org|http://mirrors.aliyun.com|g' /etc/apt/source
|
||||
|
||||
# 安装Node.js (支持npx命令)
|
||||
RUN curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
|
||||
apt-get install -y nodejs
|
||||
apt-get -o Acquire::Retries=3 install -y --no-install-recommends nodejs && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 安装uv (Python包管理器)
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
@ -36,7 +38,10 @@ ENV PATH="/root/.cargo/bin:$PATH"
|
||||
|
||||
# 复制requirements文件并安装Python依赖
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -i https://mirrors.aliyun.com/pypi/simple/ -r requirements.txt
|
||||
RUN grep -Ev '^(torch|triton|nvidia-[^=]+|sentence-transformers|transformers|tokenizers|safetensors|scikit-learn|scipy|huggingface-hub|hf-xet)==' requirements.txt > /tmp/requirements.runtime.txt && \
|
||||
! grep -E '^(torch|triton|nvidia-[^=]+|sentence-transformers|transformers|tokenizers|safetensors|scikit-learn|scipy|huggingface-hub|hf-xet)==' /tmp/requirements.runtime.txt && \
|
||||
pip install --no-cache-dir -i https://mirrors.aliyun.com/pypi/simple/ -r /tmp/requirements.runtime.txt && \
|
||||
rm -f /tmp/requirements.runtime.txt
|
||||
|
||||
# 安装 Playwright 并下载 Chromium
|
||||
RUN pip install --no-cache-dir -i https://mirrors.aliyun.com/pypi/simple/ playwright && \
|
||||
|
||||
@ -5,14 +5,18 @@ Responsible for creating, caching, and managing the lifecycle of Mem0 client ins
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import threading
|
||||
import concurrent.futures
|
||||
from typing import Any, Dict, List, Optional, Literal
|
||||
from collections import OrderedDict
|
||||
from embedding.manager import GlobalModelManager, get_model_manager
|
||||
from urllib.parse import unquote, urlparse
|
||||
from embedding.manager import get_model_manager
|
||||
import json_repair
|
||||
from psycopg2 import pool
|
||||
from utils.settings import (
|
||||
CHECKPOINT_DB_URL,
|
||||
EMBEDDING_API_KEY,
|
||||
EMBEDDING_BASE_URL,
|
||||
EMBEDDING_DIMENSIONS,
|
||||
EMBEDDING_MODEL_NAME,
|
||||
MEM0_POOL_SIZE
|
||||
)
|
||||
from .mem0_config import Mem0Config
|
||||
@ -27,15 +31,9 @@ logger = logging.getLogger("app")
|
||||
|
||||
class CustomMem0Embedding:
|
||||
"""
|
||||
Custom Mem0 embedding class that directly uses the project's existing GlobalModelManager
|
||||
|
||||
This prevents Mem0 from loading the same model again and saves memory
|
||||
Custom Mem0 embedding class backed by the external embedding API.
|
||||
"""
|
||||
|
||||
_model = None # Class variable caching the model instance
|
||||
_lock = threading.Lock() # Thread-safe lock
|
||||
_executor = None # Thread pool executor
|
||||
|
||||
def __init__(self, config: Optional[Any] = None):
|
||||
"""Initialize the custom embedding."""
|
||||
# Create a simple config object compatible with Mem0 telemetry code
|
||||
@ -46,42 +44,7 @@ class CustomMem0Embedding:
|
||||
@property
|
||||
def embedding_dims(self):
|
||||
"""Get the embedding dimension."""
|
||||
return 384 # Dimension of gte-tiny
|
||||
|
||||
def _get_model_sync(self):
|
||||
"""Synchronously get the model without using asyncio.run()."""
|
||||
# First try to get an already-loaded model from the manager
|
||||
manager = get_model_manager()
|
||||
model = manager.get_model_sync()
|
||||
|
||||
if model is not None:
|
||||
# Cache the model
|
||||
CustomMem0Embedding._model = model
|
||||
return model
|
||||
|
||||
# If the model is not loaded, run async initialization in a thread pool
|
||||
if CustomMem0Embedding._executor is None:
|
||||
CustomMem0Embedding._executor = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=1,
|
||||
thread_name_prefix="mem0_embed"
|
||||
)
|
||||
|
||||
# Run async code in a dedicated thread
|
||||
def run_async_in_thread():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
result = loop.run_until_complete(manager.get_model())
|
||||
return result
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
future = CustomMem0Embedding._executor.submit(run_async_in_thread)
|
||||
model = future.result(timeout=30) # 30-second timeout
|
||||
|
||||
# Cache the model
|
||||
CustomMem0Embedding._model = model
|
||||
return model
|
||||
return EMBEDDING_DIMENSIONS
|
||||
|
||||
def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
|
||||
"""
|
||||
@ -94,15 +57,11 @@ class CustomMem0Embedding:
|
||||
Returns:
|
||||
list: Embedding vector
|
||||
"""
|
||||
# Retrieve the model in a thread-safe manner
|
||||
if CustomMem0Embedding._model is None:
|
||||
with CustomMem0Embedding._lock:
|
||||
if CustomMem0Embedding._model is None:
|
||||
self._get_model_sync()
|
||||
|
||||
model = CustomMem0Embedding._model
|
||||
embeddings = model.encode(text, convert_to_numpy=True)
|
||||
return embeddings.tolist()
|
||||
manager = get_model_manager()
|
||||
input_texts = text if isinstance(text, list) else [text]
|
||||
embeddings = manager.encode_texts_sync(input_texts, batch_size=1)
|
||||
result = embeddings.tolist()
|
||||
return result if isinstance(text, list) else result[0]
|
||||
|
||||
# Monkey patch: replace mem0's remove_code_blocks with json_repair
|
||||
def _remove_code_blocks_with_repair(content: str) -> str:
|
||||
@ -233,27 +192,68 @@ class Mem0Manager:
|
||||
mem0_instance: Mem0 Memory instance
|
||||
"""
|
||||
try:
|
||||
# Mem0 Memory instances have a vector_store attribute of type PGVector
|
||||
if hasattr(mem0_instance, 'vector_store'):
|
||||
vector_store = mem0_instance.vector_store
|
||||
# PGVector has conn and connection_pool attributes
|
||||
if hasattr(vector_store, 'conn') and hasattr(vector_store, 'connection_pool'):
|
||||
if vector_store.conn is not None and vector_store.connection_pool is not None:
|
||||
try:
|
||||
# Close the cursor first
|
||||
if hasattr(vector_store, 'cur') and vector_store.cur:
|
||||
vector_store.cur.close()
|
||||
vector_store.cur = None
|
||||
# Return the connection to the pool
|
||||
vector_store.connection_pool.putconn(vector_store.conn)
|
||||
# Mark as cleaned up to prevent __del__ from releasing it again
|
||||
vector_store.conn = None
|
||||
logger.debug("Successfully released Mem0 database connection back to pool")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error releasing Mem0 connection: {e}")
|
||||
vector_store = getattr(mem0_instance, 'vector_store', None)
|
||||
if vector_store is not None and getattr(vector_store, 'conn', None) is not None:
|
||||
try:
|
||||
if getattr(vector_store, 'cur', None):
|
||||
vector_store.cur.close()
|
||||
vector_store.cur = None
|
||||
connection_pool = getattr(vector_store, 'connection_pool', None)
|
||||
if connection_pool is not None:
|
||||
connection_pool.putconn(vector_store.conn)
|
||||
logger.debug("Successfully released Mem0 database connection back to pool")
|
||||
else:
|
||||
vector_store.conn.close()
|
||||
logger.debug("Successfully closed Mem0 database connection")
|
||||
vector_store.conn = None
|
||||
except Exception as e:
|
||||
logger.warning(f"Error releasing Mem0 connection: {e}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cleaning up Mem0 instance: {e}")
|
||||
|
||||
def _build_pgvector_config(self, agent_id: str) -> Dict[str, Any]:
|
||||
"""Build Mem0 PGVector config using only fields accepted by mem0."""
|
||||
parsed_url = urlparse(CHECKPOINT_DB_URL)
|
||||
if parsed_url.scheme not in ("postgresql", "postgres"):
|
||||
raise ValueError(f"Unsupported CHECKPOINT_DB_URL scheme: {parsed_url.scheme}")
|
||||
|
||||
return {
|
||||
"dbname": unquote(parsed_url.path.lstrip("/") or "postgres"),
|
||||
"user": unquote(parsed_url.username or ""),
|
||||
"password": unquote(parsed_url.password or ""),
|
||||
"host": parsed_url.hostname or "localhost",
|
||||
"port": parsed_url.port or 5432,
|
||||
"collection_name": f"mem0_{agent_id}".replace("-", "_")[:50],
|
||||
"embedding_model_dims": EMBEDDING_DIMENSIONS,
|
||||
}
|
||||
|
||||
def _attach_pool_to_vector_store(self, mem0_instance: Any) -> None:
|
||||
"""Move Mem0's runtime vector store onto the shared psycopg2 pool."""
|
||||
vector_store = getattr(mem0_instance, 'vector_store', None)
|
||||
if vector_store is None:
|
||||
return
|
||||
|
||||
if getattr(vector_store, 'cur', None):
|
||||
vector_store.cur.close()
|
||||
vector_store.cur = None
|
||||
if getattr(vector_store, 'conn', None) is not None:
|
||||
vector_store.conn.close()
|
||||
vector_store.conn = None
|
||||
vector_store.connection_pool = self._sync_pool
|
||||
|
||||
def _close_telemetry_vector_store(self, mem0_instance: Any) -> None:
|
||||
"""Close Mem0's migration telemetry vector-store connection after init."""
|
||||
vector_store = getattr(mem0_instance, '_telemetry_vector_store', None)
|
||||
if vector_store is None:
|
||||
return
|
||||
|
||||
if getattr(vector_store, 'cur', None):
|
||||
vector_store.cur.close()
|
||||
vector_store.cur = None
|
||||
if getattr(vector_store, 'conn', None) is not None:
|
||||
vector_store.conn.close()
|
||||
vector_store.conn = None
|
||||
|
||||
def _ensure_connection(self, mem0_instance: Any) -> None:
|
||||
"""Ensure a Mem0 instance has a database connection before use.
|
||||
|
||||
@ -268,8 +268,7 @@ class Mem0Manager:
|
||||
if hasattr(vs, 'conn') and vs.conn is None and self._sync_pool:
|
||||
vs.conn = self._sync_pool.getconn()
|
||||
vs.cur = vs.conn.cursor()
|
||||
# Ensure the connection_pool reference exists for later return
|
||||
if hasattr(vs, 'connection_pool') and vs.connection_pool is None:
|
||||
if not hasattr(vs, 'connection_pool') or vs.connection_pool is None:
|
||||
vs.connection_pool = self._sync_pool
|
||||
logger.debug("Re-acquired Mem0 database connection from pool")
|
||||
except Exception as e:
|
||||
@ -292,8 +291,11 @@ class Mem0Manager:
|
||||
if hasattr(vs, 'cur') and vs.cur:
|
||||
vs.cur.close()
|
||||
vs.cur = None
|
||||
if hasattr(vs, 'connection_pool') and vs.connection_pool is not None:
|
||||
vs.connection_pool.putconn(vs.conn)
|
||||
connection_pool = getattr(vs, 'connection_pool', None)
|
||||
if connection_pool is not None:
|
||||
connection_pool.putconn(vs.conn)
|
||||
else:
|
||||
vs.conn.close()
|
||||
vs.conn = None
|
||||
logger.debug("Released Mem0 database connection back to pool")
|
||||
except Exception as e:
|
||||
@ -376,28 +378,25 @@ class Mem0Manager:
|
||||
if not connection_pool:
|
||||
raise ValueError("Database connection pool not available")
|
||||
|
||||
# Create a custom embedder that reuses the shared model to avoid duplicate loading
|
||||
# Create a custom embedder backed by the external embedding API.
|
||||
custom_embedder = CustomMem0Embedding()
|
||||
|
||||
# Configure Mem0 to use Pgvector
|
||||
# Note: use huggingface_base_url here to bypass local model loading
|
||||
# Set a dummy base_url so HuggingFaceEmbedding does not load SentenceTransformer
|
||||
|
||||
# Configure Mem0 to use Pgvector.
|
||||
# Mem0 validates this config strictly, so connection_pool is attached after creation.
|
||||
pgvector_config = self._build_pgvector_config(agent_id)
|
||||
config_dict = {
|
||||
"vector_store": {
|
||||
"provider": "pgvector",
|
||||
"config": {
|
||||
"connection_pool": connection_pool,
|
||||
"collection_name": f"mem0_{agent_id}".replace("-", "_")[:50], # Isolate by agent_id
|
||||
"embedding_model_dims": 384, # Dimension of paraphrase-multilingual-MiniLM-L12-v2
|
||||
}
|
||||
"config": pgvector_config,
|
||||
},
|
||||
# Use huggingface_base_url to bypass model loading; it will later be replaced with the custom embedder
|
||||
# The embedder is replaced immediately after Memory is created.
|
||||
"embedder": {
|
||||
"provider": "huggingface",
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"huggingface_base_url": "http://dummy-url-that-will-be-replaced",
|
||||
"api_key": "dummy-key" # Placeholder to prevent OpenAI client validation failure
|
||||
"api_key": EMBEDDING_API_KEY,
|
||||
"openai_base_url": EMBEDDING_BASE_URL,
|
||||
"model": EMBEDDING_MODEL_NAME,
|
||||
"embedding_dims": EMBEDDING_DIMENSIONS,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -432,6 +431,8 @@ class Mem0Manager:
|
||||
|
||||
# Create the Mem0 instance
|
||||
mem = Memory.from_config(config_dict)
|
||||
self._attach_pool_to_vector_store(mem)
|
||||
self._close_telemetry_vector_store(mem)
|
||||
logger.debug(f"Original embedder type: {type(mem.embedding_model).__name__}")
|
||||
logger.debug(f"Original embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}")
|
||||
|
||||
|
||||
@ -4,128 +4,93 @@ Model pool manager and cache system
|
||||
Support high-concurrency embedding retrieval services
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import time
|
||||
import pickle
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
from collections import OrderedDict
|
||||
from utils.settings import SENTENCE_TRANSFORMER_MODEL
|
||||
import threading
|
||||
import psutil
|
||||
from typing import Dict, List, Any
|
||||
from utils.settings import (
|
||||
EMBEDDING_API_KEY,
|
||||
EMBEDDING_BASE_URL,
|
||||
EMBEDDING_DIMENSIONS,
|
||||
EMBEDDING_MODEL_NAME,
|
||||
EMBEDDING_TIMEOUT,
|
||||
)
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
|
||||
class GlobalModelManager:
|
||||
"""Global model manager"""
|
||||
"""OpenAI-compatible embedding API manager."""
|
||||
|
||||
def __init__(self, model_name: str = 'TaylorAI/gte-tiny'):
|
||||
self.model_name = model_name
|
||||
self.local_model_path = "./models/gte-tiny"
|
||||
self._model: Optional[SentenceTransformer] = None
|
||||
self._lock = asyncio.Lock()
|
||||
self._load_time = 0
|
||||
self._device = 'cpu'
|
||||
def __init__(self):
|
||||
self.external_model_name = EMBEDDING_MODEL_NAME
|
||||
self.external_base_url = EMBEDDING_BASE_URL.rstrip("/")
|
||||
self.external_api_key = EMBEDDING_API_KEY
|
||||
self.external_dimensions = EMBEDDING_DIMENSIONS
|
||||
self.external_timeout = EMBEDDING_TIMEOUT
|
||||
|
||||
logger.info(f"GlobalModelManager initialized: {model_name}")
|
||||
|
||||
async def get_model(self) -> SentenceTransformer:
|
||||
"""Get the model instance with lazy loading"""
|
||||
if self._model is not None:
|
||||
return self._model
|
||||
|
||||
async with self._lock:
|
||||
# Double-check
|
||||
if self._model is not None:
|
||||
return self._model
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# Check the local model
|
||||
model_path = self.local_model_path if os.path.exists(self.local_model_path) else self.model_name
|
||||
|
||||
# Get device configuration
|
||||
self._device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
|
||||
if self._device not in ['cpu', 'cuda', 'mps']:
|
||||
self._device = 'cpu'
|
||||
|
||||
logger.info(f"Loading model: {model_path} (device: {self._device})")
|
||||
|
||||
# Run blocking operations in the event loop executor
|
||||
loop = asyncio.get_event_loop()
|
||||
self._model = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: SentenceTransformer(
|
||||
model_path,
|
||||
device=self._device
|
||||
)
|
||||
)
|
||||
|
||||
self._load_time = time.time() - start_time
|
||||
logger.info(f"Model loading completed: {self._load_time:.2f}s")
|
||||
|
||||
return self._model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Model loading failed: {e}")
|
||||
raise
|
||||
logger.info(f"GlobalModelManager initialized: external_model={self.external_model_name}")
|
||||
|
||||
async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
|
||||
"""Encode texts into vectors"""
|
||||
"""Encode texts into vectors through the external embedding API."""
|
||||
if not texts:
|
||||
return np.array([])
|
||||
|
||||
model = await self.get_model()
|
||||
|
||||
try:
|
||||
# Run blocking operations in the event loop executor
|
||||
loop = asyncio.get_event_loop()
|
||||
embeddings = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: model.encode(texts, batch_size=batch_size, show_progress_bar=False)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self._encode_texts_external(texts)
|
||||
)
|
||||
|
||||
def encode_texts_sync(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
|
||||
"""Synchronously encode texts. Used by synchronous integrations such as Mem0."""
|
||||
if not texts:
|
||||
return np.array([])
|
||||
|
||||
return self._encode_texts_external(texts)
|
||||
|
||||
def _encode_texts_external(self, texts: List[str]) -> np.ndarray:
|
||||
if not self.external_base_url:
|
||||
raise RuntimeError("EMBEDDING_BASE_URL is required for embedding API calls")
|
||||
|
||||
endpoint = f"{self.external_base_url}/embeddings"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.external_api_key:
|
||||
headers["Authorization"] = f"Bearer {self.external_api_key}"
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"model": self.external_model_name,
|
||||
"input": texts,
|
||||
}
|
||||
if self.external_dimensions and self.external_model_name != "text-embedding-ada-002":
|
||||
payload["dimensions"] = self.external_dimensions
|
||||
|
||||
response = requests.post(
|
||||
endpoint,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=self.external_timeout,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(f"External embedding API failed: {response.status_code} - {response.text}")
|
||||
|
||||
data = response.json()
|
||||
embeddings = [item["embedding"] for item in data.get("data", [])]
|
||||
if len(embeddings) != len(texts):
|
||||
raise RuntimeError(
|
||||
f"External embedding API returned {len(embeddings)} embeddings for {len(texts)} texts"
|
||||
)
|
||||
|
||||
# Ensure a NumPy array is returned
|
||||
if hasattr(embeddings, 'cpu'):
|
||||
embeddings = embeddings.cpu().numpy()
|
||||
elif hasattr(embeddings, 'numpy'):
|
||||
embeddings = embeddings.numpy()
|
||||
elif not isinstance(embeddings, np.ndarray):
|
||||
embeddings = np.array(embeddings)
|
||||
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Text encoding failed: {e}")
|
||||
raise
|
||||
|
||||
def get_model_sync(self) -> Optional[SentenceTransformer]:
|
||||
"""Synchronously get the model instance for synchronous contexts
|
||||
|
||||
If the model is not loaded, return None. The caller should ensure the model is initialized via the async API first.
|
||||
|
||||
Returns:
|
||||
The loaded SentenceTransformer model, or None
|
||||
"""
|
||||
return self._model
|
||||
return np.array(embeddings)
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""Get model information"""
|
||||
return {
|
||||
"model_name": self.model_name,
|
||||
"local_model_path": self.local_model_path,
|
||||
"device": self._device,
|
||||
"is_loaded": self._model is not None,
|
||||
"load_time": self._load_time
|
||||
"provider": "openai_compatible",
|
||||
"base_url": self.external_base_url,
|
||||
"model_name": self.external_model_name,
|
||||
"dimensions": self.external_dimensions,
|
||||
}
|
||||
|
||||
|
||||
@ -137,5 +102,5 @@ def get_model_manager() -> GlobalModelManager:
|
||||
"""Get the model manager instance"""
|
||||
global _model_manager
|
||||
if _model_manager is None:
|
||||
_model_manager = GlobalModelManager(SENTENCE_TRANSFORMER_MODEL)
|
||||
_model_manager = GlobalModelManager()
|
||||
return _model_manager
|
||||
|
||||
@ -6,8 +6,8 @@ Data merging functions for combining processed file results.
|
||||
import os
|
||||
import pickle
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import json
|
||||
from typing import Dict
|
||||
from utils.settings import EMBEDDING_MODEL_NAME
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger('app')
|
||||
@ -202,7 +202,7 @@ def merge_embeddings_by_group(unique_id: str, group_name: str) -> Dict:
|
||||
dimensions = 0
|
||||
chunking_strategy = 'unknown'
|
||||
chunking_params = {}
|
||||
model_path = 'TaylorAI/gte-tiny'
|
||||
model_path = EMBEDDING_MODEL_NAME
|
||||
|
||||
for filename_stem, embedding_path in sorted(embedding_files):
|
||||
try:
|
||||
|
||||
@ -30,7 +30,12 @@ PROJECT_NAME = os.getenv("PROJECT_NAME", "support")
|
||||
TOKENIZERS_PARALLELISM = os.getenv("TOKENIZERS_PARALLELISM", "true")
|
||||
|
||||
# Embedding Model Settings
|
||||
SENTENCE_TRANSFORMER_MODEL = os.getenv("SENTENCE_TRANSFORMER_MODEL", "TaylorAI/gte-tiny")
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "sk-hsKClH0Z695EkK5fDdB2Ec2fE13f4fC1B627BdBb8e554b5b-4")
|
||||
EMBEDDING_BASE_URL = os.getenv("EMBEDDING_BASE_URL", "https://one-dev.felo.me/v1")
|
||||
EMBEDDING_API_KEY = os.getenv("EMBEDDING_API_KEY", OPENAI_API_KEY)
|
||||
EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL_NAME", "text-embedding-3-small")
|
||||
EMBEDDING_DIMENSIONS = int(os.getenv("EMBEDDING_DIMENSIONS", "384"))
|
||||
EMBEDDING_TIMEOUT = int(os.getenv("EMBEDDING_TIMEOUT", "30"))
|
||||
|
||||
# Tool Output Length Control Settings
|
||||
TOOL_OUTPUT_MAX_LENGTH = SUMMARIZATION_MAX_TOKENS
|
||||
|
||||
Loading…
Reference in New Issue
Block a user