Delete model_client

朱潮 2025-11-20 20:54:38 +08:00
parent 37784ebefe
commit ef7cb7560f
6 changed files with 6 additions and 895 deletions

View File

@@ -3,14 +3,11 @@ Embedding Package
Provides text encoding and semantic search functionality
"""
from .manager import get_cache_manager, get_model_manager
from .search_service import get_search_service
from .manager import get_model_manager
from .embedding import embed_document, split_document_by_pages
__all__ = [
'get_cache_manager',
'get_model_manager',
'get_search_service',
'embed_document',
'split_document_by_pages'
]
]
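After this change the package exports only the model manager and the document-embedding helpers. A minimal usage sketch of the remaining surface, assuming encode_texts on the manager is awaitable and returns a NumPy array, as the call sites later in this diff suggest:
import asyncio
from embedding import get_model_manager

async def main():
    manager = get_model_manager()  # module-level singleton
    embeddings = await manager.encode_texts(["hello world"], batch_size=32)
    print(embeddings.shape)  # (1, embedding_dim)

asyncio.run(main())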

View File

@@ -22,203 +22,6 @@ from sentence_transformers import SentenceTransformer
logger = logging.getLogger(__name__)
@dataclass
class CacheEntry:
"""缓存条目"""
embeddings: np.ndarray
chunks: List[str]
chunking_strategy: str
chunking_params: Dict[str, Any]
model_path: str
file_path: str
file_mtime: float # 文件修改时间
access_count: int # 访问次数
last_access_time: float # 最后访问时间
load_time: float # 加载时间
memory_size: int # 占用内存大小(字节)
class EmbeddingCacheManager:
"""Embedding 数据缓存管理器"""
def __init__(self, max_cache_size: int = 5, max_memory_mb: int = 1024):
self.max_cache_size = max_cache_size # 最大缓存条目数
self.max_memory_bytes = max_memory_mb * 1024 * 1024 # 最大内存使用量
self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
self._lock = threading.RLock()
self._current_memory_usage = 0
logger.info(f"EmbeddingCacheManager 初始化: max_size={max_cache_size}, max_memory={max_memory_mb}MB")
def _get_file_key(self, file_path: str) -> str:
"""生成文件缓存键"""
# 使用绝对路径和文件修改时间生成唯一键
try:
abs_path = os.path.abspath(file_path)
mtime = os.path.getmtime(abs_path)
key_data = f"{abs_path}:{mtime}"
return hashlib.md5(key_data.encode()).hexdigest()
except Exception as e:
logger.warning(f"生成文件键失败: {file_path}, {e}")
return hashlib.md5(file_path.encode()).hexdigest()
def _estimate_memory_size(self, embeddings: np.ndarray, chunks: List[str]) -> int:
"""估算数据内存占用"""
try:
embeddings_size = embeddings.nbytes
chunks_size = sum(len(chunk.encode('utf-8')) for chunk in chunks)
overhead = 1024 * 1024 # 1MB 开销
return embeddings_size + chunks_size + overhead
except Exception:
return 100 * 1024 * 1024 # 默认100MB
def _cleanup_cache(self):
"""清理缓存以释放内存"""
with self._lock:
# 按访问时间和次数排序,清理最少使用的条目
entries = list(self._cache.items())
# 计算需要清理的条目
to_remove = []
current_size = len(self._cache)
# 如果缓存条目超限,清理最老的条目
if current_size > self.max_cache_size:
to_remove.extend(entries[:current_size - self.max_cache_size])
# 如果内存超限按LRU策略清理
if self._current_memory_usage > self.max_memory_bytes:
# 按最后访问时间排序
entries.sort(key=lambda x: x[1].last_access_time)
accumulated_size = 0
for key, entry in entries:
if accumulated_size >= self._current_memory_usage - self.max_memory_bytes:
break
to_remove.append((key, entry))
accumulated_size += entry.memory_size
# 执行清理
for key, entry in to_remove:
if key in self._cache:
del self._cache[key]
self._current_memory_usage -= entry.memory_size
logger.info(f"清理缓存条目: {key} ({entry.memory_size / 1024 / 1024:.1f}MB)")
async def load_embedding_data(self, file_path: str) -> Optional[CacheEntry]:
"""加载 embedding 数据"""
cache_key = self._get_file_key(file_path)
# 检查缓存
with self._lock:
if cache_key in self._cache:
entry = self._cache[cache_key]
entry.access_count += 1
entry.last_access_time = time.time()
# 移动到末尾(最近使用)
self._cache.move_to_end(cache_key)
logger.info(f"缓存命中: {file_path}")
return entry
# 缓存未命中,异步加载数据
try:
start_time = time.time()
# 检查文件是否存在
if not os.path.exists(file_path):
logger.error(f"文件不存在: {file_path}")
return None
# 加载 embedding 数据
with open(file_path, 'rb') as f:
embedding_data = pickle.load(f)
# 兼容新旧数据结构
if 'chunks' in embedding_data:
chunks = embedding_data['chunks']
embeddings = embedding_data['embeddings']
chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
chunking_params = embedding_data.get('chunking_params', {})
model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
else:
chunks = embedding_data['sentences']
embeddings = embedding_data['embeddings']
chunking_strategy = 'line'
chunking_params = {}
model_path = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
# Make sure embeddings is a numpy array
if hasattr(embeddings, 'cpu'):
embeddings = embeddings.cpu().numpy()
elif hasattr(embeddings, 'numpy'):
embeddings = embeddings.numpy()
elif not isinstance(embeddings, np.ndarray):
embeddings = np.array(embeddings)
# Build the cache entry
load_time = time.time() - start_time
file_mtime = os.path.getmtime(file_path)
memory_size = self._estimate_memory_size(embeddings, chunks)
entry = CacheEntry(
embeddings=embeddings,
chunks=chunks,
chunking_strategy=chunking_strategy,
chunking_params=chunking_params,
model_path=model_path,
file_path=file_path,
file_mtime=file_mtime,
access_count=1,
last_access_time=time.time(),
load_time=load_time,
memory_size=memory_size
)
# Add to the cache
with self._lock:
self._cache[cache_key] = entry
self._current_memory_usage += memory_size
# Evict entries if over the limits
self._cleanup_cache()
logger.info(f"Load complete: {file_path} ({memory_size / 1024 / 1024:.1f}MB, {load_time:.2f}s)")
return entry
except Exception as e:
logger.error(f"Failed to load embedding data: {file_path}, {e}")
return None
def get_cache_stats(self) -> Dict[str, Any]:
"""获取缓存统计信息"""
with self._lock:
return {
"cache_size": len(self._cache),
"max_cache_size": self.max_cache_size,
"memory_usage_mb": self._current_memory_usage / 1024 / 1024,
"max_memory_mb": self.max_memory_bytes / 1024 / 1024,
"memory_usage_percent": (self._current_memory_usage / self.max_memory_bytes) * 100,
"entries": [
{
"file_path": entry.file_path,
"access_count": entry.access_count,
"last_access_time": entry.last_access_time,
"memory_size_mb": entry.memory_size / 1024 / 1024
}
for entry in self._cache.values()
]
}
def clear_cache(self):
"""清空缓存"""
with self._lock:
cleared_count = len(self._cache)
cleared_memory = self._current_memory_usage
self._cache.clear()
self._current_memory_usage = 0
logger.info(f"清空缓存: {cleared_count} 个条目, {cleared_memory / 1024 / 1024:.1f}MB")
class GlobalModelManager:
"""全局模型管理器"""
@@ -312,17 +115,8 @@ class GlobalModelManager:
# Global instances
_cache_manager = None
_model_manager = None
def get_cache_manager() -> EmbeddingCacheManager:
"""获取缓存管理器实例"""
global _cache_manager
if _cache_manager is None:
max_cache_size = int(os.getenv("EMBEDDING_MAX_CACHE_SIZE", "5"))
max_memory_mb = int(os.getenv("EMBEDDING_MAX_MEMORY_MB", "1024"))
_cache_manager = EmbeddingCacheManager(max_cache_size, max_memory_mb)
return _cache_manager
def get_model_manager() -> GlobalModelManager:
"""获取模型管理器实例"""
@@ -330,4 +124,4 @@ def get_model_manager() -> GlobalModelManager:
if _model_manager is None:
model_name = os.getenv("SENTENCE_TRANSFORMER_MODEL", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
_model_manager = GlobalModelManager(model_name)
return _model_manager
return _model_manager
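A short configuration sketch for the code that remains: the model is still selected through the SENTENCE_TRANSFORMER_MODEL environment variable read above (EMBEDDING_MAX_CACHE_SIZE and EMBEDDING_MAX_MEMORY_MB become dead configuration once the cache manager is gone):
import os
from embedding import get_model_manager

# Must be set before the first call; later calls reuse the singleton.
os.environ.setdefault(
    "SENTENCE_TRANSFORMER_MODEL",
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
)
manager = get_model_manager()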

View File

@@ -1,181 +0,0 @@
#!/usr/bin/env python3
"""
Model service client
Provides a client interface for interacting with model_server.py
"""
import os
import asyncio
import logging
from typing import List, Union, Dict, Any
import numpy as np
import aiohttp
import json
from sentence_transformers import util
logger = logging.getLogger(__name__)
class ModelClient:
"""模型服务客户端"""
def __init__(self, base_url: str = "http://127.0.0.1:8000"):
self.base_url = base_url.rstrip('/')
self.session = None
async def __aenter__(self):
self.session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=300) # 5-minute timeout
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.session:
await self.session.close()
async def health_check(self) -> Dict[str, Any]:
"""检查服务健康状态"""
try:
async with self.session.get(f"{self.base_url}/health") as response:
return await response.json()
except Exception as e:
logger.error(f"健康检查失败: {e}")
return {"status": "unavailable", "error": str(e)}
async def load_model(self) -> bool:
"""预加载模型"""
try:
async with self.session.post(f"{self.base_url}/load_model") as response:
result = await response.json()
return response.status == 200
except Exception as e:
logger.error(f"模型加载失败: {e}")
return False
async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
"""编码文本为向量"""
if not texts:
return np.array([])
try:
data = {
"texts": texts,
"batch_size": batch_size
}
async with self.session.post(
f"{self.base_url}/encode",
json=data
) as response:
if response.status != 200:
error_text = await response.text()
raise Exception(f"编码失败: {error_text}")
result = await response.json()
return np.array(result["embeddings"])
except Exception as e:
logger.error(f"文本编码失败: {e}")
raise
async def compute_similarity(self, text1: str, text2: str) -> float:
"""计算两个文本的相似度"""
try:
data = {
"text1": text1,
"text2": text2
}
async with self.session.post(
f"{self.base_url}/similarity",
json=data
) as response:
if response.status != 200:
error_text = await response.text()
raise Exception(f"相似度计算失败: {error_text}")
result = await response.json()
return result["similarity_score"]
except Exception as e:
logger.error(f"相似度计算失败: {e}")
raise
# Synchronous wrapper around the async client
class SyncModelClient:
"""Synchronous model client"""
def __init__(self, base_url: str = "http://127.0.0.1:8000"):
self.base_url = base_url
self._async_client = ModelClient(base_url)
def _run_async(self, coro):
"""Run an async coroutine"""
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(coro)
def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
"""同步版本的文本编码"""
async def _encode():
async with self._async_client as client:
return await client.encode_texts(texts, batch_size)
return self._run_async(_encode())
def compute_similarity(self, text1: str, text2: str) -> float:
"""同步版本的相似度计算"""
async def _similarity():
async with self._async_client as client:
return await client.compute_similarity(text1, text2)
return self._run_async(_similarity())
def health_check(self) -> Dict[str, Any]:
"""同步版本的健康检查"""
async def _health():
async with self._async_client as client:
return await client.health_check()
return self._run_async(_health())
def load_model(self) -> bool:
"""同步版本的模型加载"""
async def _load():
async with self._async_client as client:
return await client.load_model()
return self._run_async(_load())
# Global client instance (singleton pattern)
_client_instance = None
def get_model_client(base_url: str = None) -> SyncModelClient:
"""Return the model client instance (singleton)"""
global _client_instance
if _client_instance is None:
url = base_url or os.environ.get('MODEL_SERVER_URL', 'http://127.0.0.1:8000')
_client_instance = SyncModelClient(url)
return _client_instance
# Convenience functions
def encode_texts(texts: List[str], batch_size: int = 32) -> np.ndarray:
"""Convenience function for text encoding"""
client = get_model_client()
return client.encode_texts(texts, batch_size)
def compute_similarity(text1: str, text2: str) -> float:
"""便捷的相似度计算函数"""
client = get_model_client()
return client.compute_similarity(text1, text2)
def is_service_available() -> bool:
"""检查模型服务是否可用"""
client = get_model_client()
health = client.health_check()
return health.get("status") == "running" or health.get("status") == "ready"

View File

@@ -1,240 +0,0 @@
#!/usr/bin/env python3
"""
Standalone model server
Provides a unified embedding service to reduce memory usage
"""
import os
import asyncio
import logging
from typing import List, Dict, Any
import numpy as np
from sentence_transformers import SentenceTransformer
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from pydantic import BaseModel
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global model instance
model = None
model_config = {
"model_name": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
"local_model_path": "./models/paraphrase-multilingual-MiniLM-L12-v2",
"device": "cpu"
}
# Pydantic models
class TextRequest(BaseModel):
texts: List[str]
batch_size: int = 32
class SimilarityRequest(BaseModel):
text1: str
text2: str
class HealthResponse(BaseModel):
status: str
model_loaded: bool
model_path: str
device: str
class EmbeddingResponse(BaseModel):
embeddings: List[List[float]]
shape: List[int]
class SimilarityResponse(BaseModel):
similarity_score: float
class ModelServer:
def __init__(self):
self.model = None
self.model_loaded = False
async def load_model(self):
"""延迟加载模型"""
if self.model_loaded:
return
logger.info("正在加载模型...")
# 检查本地模型是否存在
if os.path.exists(model_config["local_model_path"]):
model_path = model_config["local_model_path"]
logger.info(f"使用本地模型: {model_path}")
else:
model_path = model_config["model_name"]
logger.info(f"使用 HuggingFace 模型: {model_path}")
# 从环境变量获取设备配置
device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', model_config["device"])
if device not in ['cpu', 'cuda', 'mps']:
logger.warning(f"不支持的设备类型 '{device}',使用默认 CPU")
device = 'cpu'
logger.info(f"使用设备: {device}")
try:
self.model = SentenceTransformer(model_path, device=device)
self.model_loaded = True
model_config["device"] = device
model_config["current_model_path"] = model_path
logger.info("模型加载完成")
except Exception as e:
logger.error(f"模型加载失败: {e}")
raise HTTPException(status_code=500, detail=f"模型加载失败: {str(e)}")
async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
"""编码文本为向量"""
if not texts:
return np.array([])
await self.load_model()
try:
embeddings = self.model.encode(
texts,
batch_size=batch_size,
convert_to_tensor=True,
show_progress_bar=False
)
# 转换为 CPU numpy 数组
if embeddings.is_cuda:
embeddings = embeddings.cpu().numpy()
else:
embeddings = embeddings.numpy()
return embeddings
except Exception as e:
logger.error(f"编码失败: {e}")
raise HTTPException(status_code=500, detail=f"编码失败: {str(e)}")
async def compute_similarity(self, text1: str, text2: str) -> float:
"""计算两个文本的相似度"""
if not text1 or not text2:
raise HTTPException(status_code=400, detail="文本不能为空")
await self.load_model()
try:
embeddings = self.model.encode([text1, text2], convert_to_tensor=True)
from sentence_transformers import util
similarity = util.cos_sim(embeddings[0:1], embeddings[1:2])[0][0]
return similarity.item()
except Exception as e:
logger.error(f"相似度计算失败: {e}")
raise HTTPException(status_code=500, detail=f"相似度计算失败: {str(e)}")
# Create the server instance
server = ModelServer()
# Create the FastAPI application
app = FastAPI(
title="Embedding Model Server",
description="Unified sentence embedding and similarity computation service",
version="1.0.0"
)
# Add the CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Health check
@app.get("/health", response_model=HealthResponse)
async def health_check():
"""Health check"""
try:
if not server.model_loaded:
return HealthResponse(
status="ready",
model_loaded=False,
model_path="",
device=model_config["device"]
)
return HealthResponse(
status="running",
model_loaded=True,
model_path=model_config.get("current_model_path", ""),
device=model_config["device"]
)
except Exception as e:
logger.error(f"健康检查失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Preload the model
@app.post("/load_model")
async def preload_model():
"""Preload the model"""
await server.load_model()
return {"message": "Model loaded"}
# Text embedding
@app.post("/encode", response_model=EmbeddingResponse)
async def encode_texts(request: TextRequest):
"""Encode texts into vectors"""
try:
embeddings = await server.encode_texts(request.texts, request.batch_size)
return EmbeddingResponse(
embeddings=embeddings.tolist(),
shape=list(embeddings.shape)
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Similarity computation
@app.post("/similarity", response_model=SimilarityResponse)
async def compute_similarity(request: SimilarityRequest):
"""Compute the similarity between two texts"""
try:
similarity = await server.compute_similarity(request.text1, request.text2)
return SimilarityResponse(similarity_score=similarity)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Batch encoding (for document processing)
@app.post("/batch_encode")
async def batch_encode(texts: List[str], batch_size: int = 64):
"""Encode large numbers of texts in batches"""
if not texts:
return {"embeddings": [], "shape": [0, 0]}
try:
embeddings = await server.encode_texts(texts, batch_size)
return {
"embeddings": embeddings.tolist(),
"shape": list(embeddings.shape)
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Start the embedding model server")
parser.add_argument("--host", default="127.0.0.1", help="Server address")
parser.add_argument("--port", type=int, default=8000, help="Server port")
parser.add_argument("--workers", type=int, default=1, help="Number of worker processes")
parser.add_argument("--reload", action="store_true", help="Auto-reload in development mode")
args = parser.parse_args()
logger.info(f"Starting model server: http://{args.host}:{args.port}")
uvicorn.run(
"model_server:app",
host=args.host,
port=args.port,
workers=args.workers,
reload=args.reload,
log_level="info"
)
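Equivalently, the deleted server was reachable over plain HTTP; a sketch of the request and response shapes defined by the Pydantic models above (the requests library is assumed to be available, and the server must already be running on 127.0.0.1:8000):
import requests

BASE = "http://127.0.0.1:8000"

requests.post(f"{BASE}/load_model", timeout=300)           # warm up the model
print(requests.get(f"{BASE}/health", timeout=10).json())   # {"status": ..., "model_loaded": ...}

enc = requests.post(
    f"{BASE}/encode",
    json={"texts": ["hello", "world"], "batch_size": 32},
    timeout=300,
).json()
print(enc["shape"])  # e.g. [2, 384]

sim = requests.post(
    f"{BASE}/similarity",
    json={"text1": "hello", "text2": "world"},
    timeout=300,
).json()
print(sim["similarity_score"])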

View File

@@ -1,259 +0,0 @@
#!/usr/bin/env python3
"""
Semantic search service
Supports high-concurrency semantic search
"""
import asyncio
import time
import logging
from typing import List, Dict, Any, Optional, Tuple
import numpy as np
from .manager import get_cache_manager, get_model_manager, CacheEntry
logger = logging.getLogger(__name__)
class SemanticSearchService:
"""语义检索服务"""
def __init__(self):
self.cache_manager = get_cache_manager()
self.model_manager = get_model_manager()
async def semantic_search(
self,
embedding_file: str,
query: str,
top_k: int = 20,
min_score: float = 0.0
) -> Dict[str, Any]:
"""
Run a semantic search
Args:
embedding_file: path to the embedding.pkl file
query: query keywords
top_k: number of results to return
min_score: minimum similarity threshold
Returns:
Search results
"""
start_time = time.time()
try:
# Validate input parameters
if not embedding_file:
return {
"success": False,
"error": "The embedding_file parameter must not be empty"
}
if not query or not query.strip():
return {
"success": False,
"error": "The query parameter must not be empty"
}
query = query.strip()
# Load the embedding data
cache_entry = await self.cache_manager.load_embedding_data(embedding_file)
if cache_entry is None:
return {
"success": False,
"error": f"Unable to load embedding file: {embedding_file}"
}
# Encode the query
query_embeddings = await self.model_manager.encode_texts([query], batch_size=1)
if len(query_embeddings) == 0:
return {
"success": False,
"error": "Query encoding failed"
}
query_embedding = query_embeddings[0]
# Compute similarities
similarities = self._compute_similarities(query_embedding, cache_entry.embeddings)
# Get the top_k results
top_results = self._get_top_results(
top_results = self._get_top_results(
similarities,
cache_entry.chunks,
top_k,
min_score
)
processing_time = time.time() - start_time
return {
"success": True,
"query": query,
"embedding_file": embedding_file,
"top_k": top_k,
"processing_time": processing_time,
"chunking_strategy": cache_entry.chunking_strategy,
"total_chunks": len(cache_entry.chunks),
"results": top_results,
"cache_stats": {
"access_count": cache_entry.access_count,
"load_time": cache_entry.load_time
}
}
except Exception as e:
logger.error(f"语义搜索失败: {embedding_file}, {query}, {e}")
return {
"success": False,
"error": f"搜索失败: {str(e)}"
}
def _compute_similarities(self, query_embedding: np.ndarray, embeddings: np.ndarray) -> np.ndarray:
"""计算相似度分数"""
try:
# 确保数据类型和形状正确
if embeddings.ndim == 1:
embeddings = embeddings.reshape(1, -1)
# 使用余弦相似度
query_norm = np.linalg.norm(query_embedding)
embeddings_norm = np.linalg.norm(embeddings, axis=1)
# 避免除零错误
if query_norm == 0:
query_norm = 1e-8
embeddings_norm = np.where(embeddings_norm == 0, 1e-8, embeddings_norm)
# 计算余弦相似度
similarities = np.dot(embeddings, query_embedding) / (embeddings_norm * query_norm)
# 确保结果在合理范围内
similarities = np.clip(similarities, -1.0, 1.0)
return similarities
except Exception as e:
logger.error(f"相似度计算失败: {e}")
return np.array([])
def _get_top_results(
self,
similarities: np.ndarray,
chunks: List[str],
top_k: int,
min_score: float
) -> List[Dict[str, Any]]:
"""获取 top_k 结果"""
try:
if len(similarities) == 0:
return []
# 获取排序后的索引
top_indices = np.argsort(-similarities)[:top_k]
results = []
for rank, idx in enumerate(top_indices):
score = similarities[idx]
# 过滤低分结果
if score < min_score:
continue
chunk = chunks[idx]
# 创建结果条目
result = {
"rank": rank + 1,
"score": float(score),
"content": chunk,
"content_preview": self._get_preview(chunk)
}
results.append(result)
return results
except Exception as e:
logger.error(f"获取 top_k 结果失败: {e}")
return []
def _get_preview(self, content: str, max_length: int = 200) -> str:
"""获取内容预览"""
if not content:
return ""
# 清理内容
preview = content.replace('\n', ' ').strip()
# 截断
if len(preview) > max_length:
preview = preview[:max_length] + "..."
return preview
async def batch_search(
self,
requests: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""
Batch search
Args:
requests: list of search requests; each request contains embedding_file, query, top_k
Returns:
List of search results
"""
if not requests:
return []
# Run the searches concurrently
tasks = []
for req in requests:
embedding_file = req.get("embedding_file", "")
query = req.get("query", "")
top_k = req.get("top_k", 20)
min_score = req.get("min_score", 0.0)
task = self.semantic_search(embedding_file, query, top_k, min_score)
tasks.append(task)
# Wait for all tasks to complete
results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle exceptions
processed_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
processed_results.append({
"success": False,
"error": f"Search exception: {str(result)}",
"request_index": i
})
else:
result["request_index"] = i
processed_results.append(result)
return processed_results
def get_service_stats(self) -> Dict[str, Any]:
"""获取服务统计信息"""
return {
"cache_manager_stats": self.cache_manager.get_cache_stats(),
"model_manager_info": self.model_manager.get_model_info()
}
# Global instance
_search_service = None
def get_search_service() -> SemanticSearchService:
"""Return the search service instance"""
global _search_service
if _search_service is None:
_search_service = SemanticSearchService()
return _search_service
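And a sketch of how the deleted search service was used end to end, with a placeholder file path and query; the result keys match the dictionaries built in semantic_search above:
import asyncio

async def demo():
    service = get_search_service()
    result = await service.semantic_search(
        embedding_file="document.embedding.pkl",  # placeholder path
        query="machine learning",
        top_k=5,
        min_score=0.3,
    )
    if result["success"]:
        for hit in result["results"]:
            print(hit["rank"], round(hit["score"], 3), hit["content_preview"])
    else:
        print("search failed:", result["error"])

asyncio.run(demo())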

View File

@@ -24,7 +24,7 @@ from qwen_agent.llm.schema import ASSISTANT, FUNCTION
from pydantic import BaseModel, Field
# Import the semantic search service
from embedding import get_search_service
from embedding import get_model_manager
# Import utility modules
from utils import (
@@ -1572,7 +1572,7 @@ async def encode_texts(request: EncodeRequest):
Encoding results
"""
try:
search_service = get_search_service()
model_manager = get_model_manager()
if not request.texts:
return EncodeResponse(
@@ -1587,7 +1587,7 @@ async def encode_texts(request: EncodeRequest):
start_time = time.time()
# Encode the texts using the model manager
embeddings = await search_service.model_manager.encode_texts(
embeddings = await model_manager.encode_texts(
request.texts,
batch_size=request.batch_size
)