Remove model_client
This commit is contained in:
parent 37784ebefe
commit ef7cb7560f
@@ -3,14 +3,11 @@ Embedding Package
 Provides text encoding and semantic search functionality
 """

-from .manager import get_cache_manager, get_model_manager
-from .search_service import get_search_service
+from .manager import get_model_manager
 from .embedding import embed_document, split_document_by_pages

 __all__ = [
-    'get_cache_manager',
     'get_model_manager',
-    'get_search_service',
     'embed_document',
     'split_document_by_pages'
 ]
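Assuming the post-change `__all__` keeps only get_model_manager, embed_document, and split_document_by_pages (as the hunk counts suggest), downstream code is limited to a surface like this minimal sketch:

# Sketch: imports still available from the package after this commit.
from embedding import get_model_manager, embed_document, split_document_by_pages

model_manager = get_model_manager()  # process-wide singleton (see manager.py below)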
@@ -22,203 +22,6 @@ from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


@dataclass
class CacheEntry:
    """Cache entry"""
    embeddings: np.ndarray
    chunks: List[str]
    chunking_strategy: str
    chunking_params: Dict[str, Any]
    model_path: str
    file_path: str
    file_mtime: float  # file modification time
    access_count: int  # number of accesses
    last_access_time: float  # time of last access
    load_time: float  # time it took to load
    memory_size: int  # memory footprint in bytes


class EmbeddingCacheManager:
    """Embedding data cache manager"""

    def __init__(self, max_cache_size: int = 5, max_memory_mb: int = 1024):
        self.max_cache_size = max_cache_size  # maximum number of cache entries
        self.max_memory_bytes = max_memory_mb * 1024 * 1024  # maximum memory usage
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._lock = threading.RLock()
        self._current_memory_usage = 0

        logger.info(f"EmbeddingCacheManager initialized: max_size={max_cache_size}, max_memory={max_memory_mb}MB")

    def _get_file_key(self, file_path: str) -> str:
        """Generate a cache key for a file"""
        # Build a unique key from the absolute path and modification time
        try:
            abs_path = os.path.abspath(file_path)
            mtime = os.path.getmtime(abs_path)
            key_data = f"{abs_path}:{mtime}"
            return hashlib.md5(key_data.encode()).hexdigest()
        except Exception as e:
            logger.warning(f"Failed to generate file key: {file_path}, {e}")
            return hashlib.md5(file_path.encode()).hexdigest()

    def _estimate_memory_size(self, embeddings: np.ndarray, chunks: List[str]) -> int:
        """Estimate the memory footprint of the data"""
        try:
            embeddings_size = embeddings.nbytes
            chunks_size = sum(len(chunk.encode('utf-8')) for chunk in chunks)
            overhead = 1024 * 1024  # 1MB overhead
            return embeddings_size + chunks_size + overhead
        except Exception:
            return 100 * 1024 * 1024  # default to 100MB

    def _cleanup_cache(self):
        """Evict cache entries to free memory"""
        with self._lock:
            # Sort by access time and count; evict the least-used entries
            entries = list(self._cache.items())

            # Collect the entries to evict
            to_remove = []
            current_size = len(self._cache)

            # If the entry count exceeds the limit, evict the oldest entries
            if current_size > self.max_cache_size:
                to_remove.extend(entries[:current_size - self.max_cache_size])

            # If memory usage exceeds the limit, evict by LRU policy
            if self._current_memory_usage > self.max_memory_bytes:
                # Sort by last access time
                entries.sort(key=lambda x: x[1].last_access_time)

                accumulated_size = 0
                for key, entry in entries:
                    if accumulated_size >= self._current_memory_usage - self.max_memory_bytes:
                        break
                    to_remove.append((key, entry))
                    accumulated_size += entry.memory_size

            # Perform the eviction
            for key, entry in to_remove:
                if key in self._cache:
                    del self._cache[key]
                    self._current_memory_usage -= entry.memory_size
                    logger.info(f"Evicted cache entry: {key} ({entry.memory_size / 1024 / 1024:.1f}MB)")

    async def load_embedding_data(self, file_path: str) -> Optional[CacheEntry]:
        """Load embedding data"""
        cache_key = self._get_file_key(file_path)

        # Check the cache
        with self._lock:
            if cache_key in self._cache:
                entry = self._cache[cache_key]
                entry.access_count += 1
                entry.last_access_time = time.time()
                # Move to the end (most recently used)
                self._cache.move_to_end(cache_key)
                logger.info(f"Cache hit: {file_path}")
                return entry

        # Cache miss; load the data
        try:
            start_time = time.time()

            # Check that the file exists
            if not os.path.exists(file_path):
                logger.error(f"File not found: {file_path}")
                return None

            # Load the embedding data
            with open(file_path, 'rb') as f:
                embedding_data = pickle.load(f)

            # Support both the new and the legacy data layout
            if 'chunks' in embedding_data:
                chunks = embedding_data['chunks']
                embeddings = embedding_data['embeddings']
                chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
                chunking_params = embedding_data.get('chunking_params', {})
                model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
            else:
                chunks = embedding_data['sentences']
                embeddings = embedding_data['embeddings']
                chunking_strategy = 'line'
                chunking_params = {}
                model_path = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'

            # Make sure embeddings is a numpy array
            if hasattr(embeddings, 'cpu'):
                embeddings = embeddings.cpu().numpy()
            elif hasattr(embeddings, 'numpy'):
                embeddings = embeddings.numpy()
            elif not isinstance(embeddings, np.ndarray):
                embeddings = np.array(embeddings)

            # Create the cache entry
            load_time = time.time() - start_time
            file_mtime = os.path.getmtime(file_path)
            memory_size = self._estimate_memory_size(embeddings, chunks)

            entry = CacheEntry(
                embeddings=embeddings,
                chunks=chunks,
                chunking_strategy=chunking_strategy,
                chunking_params=chunking_params,
                model_path=model_path,
                file_path=file_path,
                file_mtime=file_mtime,
                access_count=1,
                last_access_time=time.time(),
                load_time=load_time,
                memory_size=memory_size
            )

            # Add to the cache
            with self._lock:
                self._cache[cache_key] = entry
                self._current_memory_usage += memory_size

                # Clean up the cache
                self._cleanup_cache()

            logger.info(f"Loaded: {file_path} ({memory_size / 1024 / 1024:.1f}MB, {load_time:.2f}s)")
            return entry

        except Exception as e:
            logger.error(f"Failed to load embedding data: {file_path}, {e}")
            return None

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return cache statistics"""
        with self._lock:
            return {
                "cache_size": len(self._cache),
                "max_cache_size": self.max_cache_size,
                "memory_usage_mb": self._current_memory_usage / 1024 / 1024,
                "max_memory_mb": self.max_memory_bytes / 1024 / 1024,
                "memory_usage_percent": (self._current_memory_usage / self.max_memory_bytes) * 100,
                "entries": [
                    {
                        "file_path": entry.file_path,
                        "access_count": entry.access_count,
                        "last_access_time": entry.last_access_time,
                        "memory_size_mb": entry.memory_size / 1024 / 1024
                    }
                    for entry in self._cache.values()
                ]
            }

    def clear_cache(self):
        """Clear the cache"""
        with self._lock:
            cleared_count = len(self._cache)
            cleared_memory = self._current_memory_usage
            self._cache.clear()
            self._current_memory_usage = 0
            logger.info(f"Cache cleared: {cleared_count} entries, {cleared_memory / 1024 / 1024:.1f}MB")


class GlobalModelManager:
    """Global model manager"""
@@ -312,17 +115,8 @@ class GlobalModelManager:


 # Global instances
-_cache_manager = None
 _model_manager = None

-def get_cache_manager() -> EmbeddingCacheManager:
-    """Get the cache manager instance"""
-    global _cache_manager
-    if _cache_manager is None:
-        max_cache_size = int(os.getenv("EMBEDDING_MAX_CACHE_SIZE", "5"))
-        max_memory_mb = int(os.getenv("EMBEDDING_MAX_MEMORY_MB", "1024"))
-        _cache_manager = EmbeddingCacheManager(max_cache_size, max_memory_mb)
-    return _cache_manager

 def get_model_manager() -> GlobalModelManager:
     """Get the model manager instance"""
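For reference, a minimal sketch of how the cache manager removed above was configured and driven. The environment variables are the ones read in get_cache_manager; the .pkl path is hypothetical:

import asyncio
import os

# Must be set before the first get_cache_manager() call, since the
# singleton reads them only once.
os.environ["EMBEDDING_MAX_CACHE_SIZE"] = "3"
os.environ["EMBEDDING_MAX_MEMORY_MB"] = "512"

from embedding.manager import get_cache_manager  # pre-deletion API

async def main():
    cache = get_cache_manager()
    entry = await cache.load_embedding_data("data/report.embedding.pkl")  # hypothetical path
    if entry is not None:
        print(entry.embeddings.shape, len(entry.chunks), entry.chunking_strategy)
    stats = cache.get_cache_stats()
    print(f"{stats['memory_usage_mb']:.1f}MB across {stats['cache_size']} entries")

asyncio.run(main())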
@@ -1,181 +0,0 @@
#!/usr/bin/env python3
"""
Model service client
Provides a client interface for talking to model_server.py
"""

import os
import asyncio
import logging
from typing import List, Union, Dict, Any
import numpy as np
import aiohttp
import json
from sentence_transformers import util

logger = logging.getLogger(__name__)

class ModelClient:
    """Model service client"""

    def __init__(self, base_url: str = "http://127.0.0.1:8000"):
        self.base_url = base_url.rstrip('/')
        self.session = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=300)  # 5-minute timeout
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session:
            await self.session.close()

    async def health_check(self) -> Dict[str, Any]:
        """Check service health"""
        try:
            async with self.session.get(f"{self.base_url}/health") as response:
                return await response.json()
        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return {"status": "unavailable", "error": str(e)}

    async def load_model(self) -> bool:
        """Preload the model"""
        try:
            async with self.session.post(f"{self.base_url}/load_model") as response:
                result = await response.json()
                return response.status == 200
        except Exception as e:
            logger.error(f"Model load failed: {e}")
            return False

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into vectors"""
        if not texts:
            return np.array([])

        try:
            data = {
                "texts": texts,
                "batch_size": batch_size
            }

            async with self.session.post(
                f"{self.base_url}/encode",
                json=data
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"Encoding failed: {error_text}")

                result = await response.json()
                return np.array(result["embeddings"])

        except Exception as e:
            logger.error(f"Text encoding failed: {e}")
            raise

    async def compute_similarity(self, text1: str, text2: str) -> float:
        """Compute the similarity between two texts"""
        try:
            data = {
                "text1": text1,
                "text2": text2
            }

            async with self.session.post(
                f"{self.base_url}/similarity",
                json=data
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"Similarity computation failed: {error_text}")

                result = await response.json()
                return result["similarity_score"]

        except Exception as e:
            logger.error(f"Similarity computation failed: {e}")
            raise

# Synchronous wrapper around the client
class SyncModelClient:
    """Synchronous model client"""

    def __init__(self, base_url: str = "http://127.0.0.1:8000"):
        self.base_url = base_url
        self._async_client = ModelClient(base_url)

    def _run_async(self, coro):
        """Run an async coroutine"""
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        return loop.run_until_complete(coro)

    def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Synchronous text encoding"""
        async def _encode():
            async with self._async_client as client:
                return await client.encode_texts(texts, batch_size)

        return self._run_async(_encode())

    def compute_similarity(self, text1: str, text2: str) -> float:
        """Synchronous similarity computation"""
        async def _similarity():
            async with self._async_client as client:
                return await client.compute_similarity(text1, text2)

        return self._run_async(_similarity())

    def health_check(self) -> Dict[str, Any]:
        """Synchronous health check"""
        async def _health():
            async with self._async_client as client:
                return await client.health_check()

        return self._run_async(_health())

    def load_model(self) -> bool:
        """Synchronous model load"""
        async def _load():
            async with self._async_client as client:
                return await client.load_model()

        return self._run_async(_load())

# Global client instance (singleton)
_client_instance = None

def get_model_client(base_url: str = None) -> SyncModelClient:
    """Get the model client instance (singleton)"""
    global _client_instance

    if _client_instance is None:
        url = base_url or os.environ.get('MODEL_SERVER_URL', 'http://127.0.0.1:8000')
        _client_instance = SyncModelClient(url)

    return _client_instance

# Convenience functions
def encode_texts(texts: List[str], batch_size: int = 32) -> np.ndarray:
    """Convenience text-encoding function"""
    client = get_model_client()
    return client.encode_texts(texts, batch_size)

def compute_similarity(text1: str, text2: str) -> float:
    """Convenience similarity function"""
    client = get_model_client()
    return client.compute_similarity(text1, text2)

def is_service_available() -> bool:
    """Check whether the model service is available"""
    client = get_model_client()
    health = client.health_check()
    return health.get("status") == "running" or health.get("status") == "ready"
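Before its deletion, the client was exercised either through the async context manager or the synchronous singleton wrapper; a minimal sketch against the default server URL:

import asyncio

from model_client import ModelClient, get_model_client  # deleted in this commit

async def main():
    # Async path: the context manager owns the aiohttp session.
    async with ModelClient("http://127.0.0.1:8000") as client:
        vectors = await client.encode_texts(["hello", "world"])
        print(vectors.shape)
        print(await client.compute_similarity("hello", "hi"))

asyncio.run(main())

# Sync path: each call spins up an event loop internally.
client = get_model_client()
print(client.health_check())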
@@ -1,240 +0,0 @@
#!/usr/bin/env python3
"""
Standalone model server
Provides a unified embedding service to reduce memory usage
"""

import os
import asyncio
import logging
from typing import List, Dict, Any
import numpy as np
from sentence_transformers import SentenceTransformer
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from pydantic import BaseModel

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model instance
model = None
model_config = {
    "model_name": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "local_model_path": "./models/paraphrase-multilingual-MiniLM-L12-v2",
    "device": "cpu"
}

# Pydantic models
class TextRequest(BaseModel):
    texts: List[str]
    batch_size: int = 32

class SimilarityRequest(BaseModel):
    text1: str
    text2: str

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    model_path: str
    device: str

class EmbeddingResponse(BaseModel):
    embeddings: List[List[float]]
    shape: List[int]

class SimilarityResponse(BaseModel):
    similarity_score: float

class ModelServer:
    def __init__(self):
        self.model = None
        self.model_loaded = False

    async def load_model(self):
        """Lazily load the model"""
        if self.model_loaded:
            return

        logger.info("Loading model...")

        # Check whether a local copy of the model exists
        if os.path.exists(model_config["local_model_path"]):
            model_path = model_config["local_model_path"]
            logger.info(f"Using local model: {model_path}")
        else:
            model_path = model_config["model_name"]
            logger.info(f"Using HuggingFace model: {model_path}")

        # Read the device configuration from the environment
        device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', model_config["device"])
        if device not in ['cpu', 'cuda', 'mps']:
            logger.warning(f"Unsupported device type '{device}', falling back to CPU")
            device = 'cpu'

        logger.info(f"Using device: {device}")

        try:
            self.model = SentenceTransformer(model_path, device=device)
            self.model_loaded = True
            model_config["device"] = device
            model_config["current_model_path"] = model_path
            logger.info("Model loaded")
        except Exception as e:
            logger.error(f"Model load failed: {e}")
            raise HTTPException(status_code=500, detail=f"Model load failed: {str(e)}")

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into vectors"""
        if not texts:
            return np.array([])

        await self.load_model()

        try:
            embeddings = self.model.encode(
                texts,
                batch_size=batch_size,
                convert_to_tensor=True,
                show_progress_bar=False
            )

            # Convert to a CPU numpy array
            if embeddings.is_cuda:
                embeddings = embeddings.cpu().numpy()
            else:
                embeddings = embeddings.numpy()

            return embeddings
        except Exception as e:
            logger.error(f"Encoding failed: {e}")
            raise HTTPException(status_code=500, detail=f"Encoding failed: {str(e)}")

    async def compute_similarity(self, text1: str, text2: str) -> float:
        """Compute the similarity between two texts"""
        if not text1 or not text2:
            raise HTTPException(status_code=400, detail="Texts must not be empty")

        await self.load_model()

        try:
            embeddings = self.model.encode([text1, text2], convert_to_tensor=True)
            from sentence_transformers import util
            similarity = util.cos_sim(embeddings[0:1], embeddings[1:2])[0][0]
            return similarity.item()
        except Exception as e:
            logger.error(f"Similarity computation failed: {e}")
            raise HTTPException(status_code=500, detail=f"Similarity computation failed: {str(e)}")

# Create the server instance
server = ModelServer()

# Create the FastAPI app
app = FastAPI(
    title="Embedding Model Server",
    description="Unified sentence-embedding and similarity service",
    version="1.0.0"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Health check
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check"""
    try:
        if not server.model_loaded:
            return HealthResponse(
                status="ready",
                model_loaded=False,
                model_path="",
                device=model_config["device"]
            )

        return HealthResponse(
            status="running",
            model_loaded=True,
            model_path=model_config.get("current_model_path", ""),
            device=model_config["device"]
        )
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# Model preloading
@app.post("/load_model")
async def preload_model():
    """Preload the model"""
    await server.load_model()
    return {"message": "Model loaded"}

# Text embedding
@app.post("/encode", response_model=EmbeddingResponse)
async def encode_texts(request: TextRequest):
    """Encode texts into vectors"""
    try:
        embeddings = await server.encode_texts(request.texts, request.batch_size)
        return EmbeddingResponse(
            embeddings=embeddings.tolist(),
            shape=list(embeddings.shape)
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Similarity computation
@app.post("/similarity", response_model=SimilarityResponse)
async def compute_similarity(request: SimilarityRequest):
    """Compute the similarity between two texts"""
    try:
        similarity = await server.compute_similarity(request.text1, request.text2)
        return SimilarityResponse(similarity_score=similarity)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Batch encoding (for document processing)
@app.post("/batch_encode")
async def batch_encode(texts: List[str], batch_size: int = 64):
    """Batch-encode a large number of texts"""
    if not texts:
        return {"embeddings": [], "shape": [0, 0]}

    try:
        embeddings = await server.encode_texts(texts, batch_size)
        return {
            "embeddings": embeddings.tolist(),
            "shape": list(embeddings.shape)
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Start the embedding model server")
    parser.add_argument("--host", default="127.0.0.1", help="server address")
    parser.add_argument("--port", type=int, default=8000, help="server port")
    parser.add_argument("--workers", type=int, default=1, help="number of worker processes")
    parser.add_argument("--reload", action="store_true", help="auto-reload in development mode")

    args = parser.parse_args()

    logger.info(f"Starting model server: http://{args.host}:{args.port}")

    uvicorn.run(
        "model_server:app",
        host=args.host,
        port=args.port,
        workers=args.workers,
        reload=args.reload,
        log_level="info"
    )
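The server above was launched through its argparse entry point (python model_server.py --host 127.0.0.1 --port 8000) and spoke plain JSON. A minimal sketch of exercising the endpoints it defined, using requests (the HTTP client choice is an assumption; any client works):

# Sketch: poking the removed server's endpoints.
import requests

base = "http://127.0.0.1:8000"

print(requests.get(f"{base}/health").json())   # {"status": "ready", ...}
requests.post(f"{base}/load_model")            # force the model load

r = requests.post(f"{base}/encode",
                  json={"texts": ["hello", "world"], "batch_size": 32})
print(r.json()["shape"])                       # e.g. [2, 384] for this MiniLM model

r = requests.post(f"{base}/similarity",
                  json={"text1": "hello", "text2": "hi"})
print(r.json()["similarity_score"])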
@@ -1,259 +0,0 @@
#!/usr/bin/env python3
"""
Semantic search service
Supports high-concurrency semantic search
"""

import asyncio
import time
import logging
from typing import List, Dict, Any, Optional, Tuple
import numpy as np

from .manager import get_cache_manager, get_model_manager, CacheEntry

logger = logging.getLogger(__name__)


class SemanticSearchService:
    """Semantic search service"""

    def __init__(self):
        self.cache_manager = get_cache_manager()
        self.model_manager = get_model_manager()

    async def semantic_search(
        self,
        embedding_file: str,
        query: str,
        top_k: int = 20,
        min_score: float = 0.0
    ) -> Dict[str, Any]:
        """
        Run a semantic search

        Args:
            embedding_file: path to the embedding.pkl file
            query: query keywords
            top_k: number of results to return
            min_score: minimum similarity threshold

        Returns:
            Search results
        """
        start_time = time.time()

        try:
            # Validate the input parameters
            if not embedding_file:
                return {
                    "success": False,
                    "error": "embedding_file must not be empty"
                }

            if not query or not query.strip():
                return {
                    "success": False,
                    "error": "query must not be empty"
                }

            query = query.strip()

            # Load the embedding data
            cache_entry = await self.cache_manager.load_embedding_data(embedding_file)
            if cache_entry is None:
                return {
                    "success": False,
                    "error": f"Failed to load embedding file: {embedding_file}"
                }

            # Encode the query
            query_embeddings = await self.model_manager.encode_texts([query], batch_size=1)
            if len(query_embeddings) == 0:
                return {
                    "success": False,
                    "error": "Query encoding failed"
                }

            query_embedding = query_embeddings[0]

            # Compute similarities
            similarities = self._compute_similarities(query_embedding, cache_entry.embeddings)

            # Take the top_k results
            top_results = self._get_top_results(
                similarities,
                cache_entry.chunks,
                top_k,
                min_score
            )

            processing_time = time.time() - start_time

            return {
                "success": True,
                "query": query,
                "embedding_file": embedding_file,
                "top_k": top_k,
                "processing_time": processing_time,
                "chunking_strategy": cache_entry.chunking_strategy,
                "total_chunks": len(cache_entry.chunks),
                "results": top_results,
                "cache_stats": {
                    "access_count": cache_entry.access_count,
                    "load_time": cache_entry.load_time
                }
            }

        except Exception as e:
            logger.error(f"Semantic search failed: {embedding_file}, {query}, {e}")
            return {
                "success": False,
                "error": f"Search failed: {str(e)}"
            }

    def _compute_similarities(self, query_embedding: np.ndarray, embeddings: np.ndarray) -> np.ndarray:
        """Compute similarity scores"""
        try:
            # Make sure dtype and shape are correct
            if embeddings.ndim == 1:
                embeddings = embeddings.reshape(1, -1)

            # Use cosine similarity
            query_norm = np.linalg.norm(query_embedding)
            embeddings_norm = np.linalg.norm(embeddings, axis=1)

            # Avoid division by zero
            if query_norm == 0:
                query_norm = 1e-8
            embeddings_norm = np.where(embeddings_norm == 0, 1e-8, embeddings_norm)

            # Compute the cosine similarity
            similarities = np.dot(embeddings, query_embedding) / (embeddings_norm * query_norm)

            # Keep the results within a sane range
            similarities = np.clip(similarities, -1.0, 1.0)

            return similarities

        except Exception as e:
            logger.error(f"Similarity computation failed: {e}")
            return np.array([])

    def _get_top_results(
        self,
        similarities: np.ndarray,
        chunks: List[str],
        top_k: int,
        min_score: float
    ) -> List[Dict[str, Any]]:
        """Take the top_k results"""
        try:
            if len(similarities) == 0:
                return []

            # Indices sorted by descending score
            top_indices = np.argsort(-similarities)[:top_k]

            results = []
            for rank, idx in enumerate(top_indices):
                score = similarities[idx]

                # Filter out low-scoring results
                if score < min_score:
                    continue

                chunk = chunks[idx]

                # Build the result entry
                result = {
                    "rank": rank + 1,
                    "score": float(score),
                    "content": chunk,
                    "content_preview": self._get_preview(chunk)
                }

                results.append(result)

            return results

        except Exception as e:
            logger.error(f"Failed to take top_k results: {e}")
            return []

    def _get_preview(self, content: str, max_length: int = 200) -> str:
        """Build a content preview"""
        if not content:
            return ""

        # Clean up the content
        preview = content.replace('\n', ' ').strip()

        # Truncate
        if len(preview) > max_length:
            preview = preview[:max_length] + "..."

        return preview

    async def batch_search(
        self,
        requests: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Batch search

        Args:
            requests: list of search requests, each with embedding_file, query, top_k

        Returns:
            List of search results
        """
        if not requests:
            return []

        # Run the searches concurrently
        tasks = []
        for req in requests:
            embedding_file = req.get("embedding_file", "")
            query = req.get("query", "")
            top_k = req.get("top_k", 20)
            min_score = req.get("min_score", 0.0)

            task = self.semantic_search(embedding_file, query, top_k, min_score)
            tasks.append(task)

        # Wait for all tasks to finish
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Handle exceptions
        processed_results = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                processed_results.append({
                    "success": False,
                    "error": f"Search raised an exception: {str(result)}",
                    "request_index": i
                })
            else:
                result["request_index"] = i
                processed_results.append(result)

        return processed_results

    def get_service_stats(self) -> Dict[str, Any]:
        """Return service statistics"""
        return {
            "cache_manager_stats": self.cache_manager.get_cache_stats(),
            "model_manager_info": self.model_manager.get_model_info()
        }


# Global instance
_search_service = None

def get_search_service() -> SemanticSearchService:
    """Get the search service instance"""
    global _search_service
    if _search_service is None:
        _search_service = SemanticSearchService()
    return _search_service
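A minimal sketch of how this service was invoked before its removal (the .pkl path and query are hypothetical; get_search_service is the factory defined above):

import asyncio

from embedding import get_search_service  # pre-deletion API

async def main():
    service = get_search_service()
    result = await service.semantic_search(
        embedding_file="data/report.embedding.pkl",  # hypothetical path
        query="quarterly revenue",
        top_k=5,
        min_score=0.3,
    )
    if result["success"]:
        for hit in result["results"]:
            print(hit["rank"], f"{hit['score']:.3f}", hit["content_preview"])
    else:
        print(result["error"])

asyncio.run(main())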
@@ -24,7 +24,7 @@ from qwen_agent.llm.schema import ASSISTANT, FUNCTION
 from pydantic import BaseModel, Field

 # Import the semantic search service
-from embedding import get_search_service
+from embedding import get_model_manager

 # Import utility modules
 from utils import (
@@ -1572,7 +1572,7 @@ async def encode_texts(request: EncodeRequest):
         Encoding result
     """
     try:
-        search_service = get_search_service()
+        model_manager = get_model_manager()

         if not request.texts:
             return EncodeResponse(
@@ -1587,7 +1587,7 @@ async def encode_texts(request: EncodeRequest):
         start_time = time.time()

         # Encode the texts with the model manager
-        embeddings = await search_service.model_manager.encode_texts(
+        embeddings = await model_manager.encode_texts(
             request.texts,
             batch_size=request.batch_size
         )
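Taken together with the file deletions above, these hunks decouple the server's /encode endpoint from the search service: encoding now goes straight through the GlobalModelManager singleton. A condensed sketch of the call path after the change (handler and field names as in the hunks; everything else elided):

# Post-change call path inside the /encode handler (condensed sketch).
model_manager = get_model_manager()
embeddings = await model_manager.encode_texts(request.texts, batch_size=request.batch_size)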