#!/usr/bin/env python3
"""
Standalone model server.

Provides a unified embedding service to reduce memory usage.
"""
|
|
|
|
import os
|
|
import asyncio
|
|
import logging
|
|
from typing import List, Dict, Any
|
|
import numpy as np
|
|
from sentence_transformers import SentenceTransformer
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
import uvicorn
|
|
from pydantic import BaseModel
|
|
|
|
# Logging setup: one module-level logger shared by the whole server.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model handle (populated lazily; see ModelServer.load_model).
model = None

# Model configuration. "device" is mutated in place once the effective
# device is known, and "current_model_path" is added after loading.
model_config = dict(
    model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    local_model_path="./models/paraphrase-multilingual-MiniLM-L12-v2",
    device="cpu",
)
|
|
|
|
# Pydantic request/response models
class TextRequest(BaseModel):
    """Request body for /encode: texts to embed plus an encode batch size."""
    texts: List[str]
    batch_size: int = 32  # forwarded to SentenceTransformer.encode
|
|
|
|
class SimilarityRequest(BaseModel):
    """Request body for /similarity: the two texts to compare."""
    text1: str
    text2: str
|
|
|
|
class HealthResponse(BaseModel):
    """Response body for /health."""
    status: str        # "ready" before the model is loaded, "running" after
    model_loaded: bool
    model_path: str    # resolved model path; "" until the model is loaded
    device: str        # effective device ("cpu", "cuda" or "mps")
|
|
|
|
class EmbeddingResponse(BaseModel):
    """Response body for /encode: embeddings plus their (rows, dim) shape."""
    embeddings: List[List[float]]
    shape: List[int]
|
|
|
|
class SimilarityResponse(BaseModel):
    """Response body for /similarity."""
    similarity_score: float  # cosine similarity
|
|
|
|
class ModelServer:
    """Lazily-loaded SentenceTransformer wrapper shared by all endpoints."""

    def __init__(self):
        # The model is loaded on first use, not at construction time.
        self.model = None
        self.model_loaded = False

    async def load_model(self):
        """Load the model on first call; subsequent calls are no-ops.

        Prefers a local model directory when it exists (avoids a network
        download), otherwise falls back to the HuggingFace model name.
        The device may be overridden via SENTENCE_TRANSFORMER_DEVICE.

        Raises:
            HTTPException: 500 if loading fails.
        """
        if self.model_loaded:
            return

        logger.info("正在加载模型...")

        # Prefer a local model directory when present.
        if os.path.exists(model_config["local_model_path"]):
            model_path = model_config["local_model_path"]
            logger.info(f"使用本地模型: {model_path}")
        else:
            model_path = model_config["model_name"]
            logger.info(f"使用 HuggingFace 模型: {model_path}")

        # Environment variable takes precedence over the configured device.
        device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', model_config["device"])
        if device not in ('cpu', 'cuda', 'mps'):
            logger.warning(f"不支持的设备类型 '{device}',使用默认 CPU")
            device = 'cpu'

        logger.info(f"使用设备: {device}")

        try:
            self.model = SentenceTransformer(model_path, device=device)
            self.model_loaded = True
            # Record the effective settings so /health can report them.
            model_config["device"] = device
            model_config["current_model_path"] = model_path
            logger.info("模型加载完成")
        except Exception as e:
            logger.error(f"模型加载失败: {e}")
            raise HTTPException(status_code=500, detail=f"模型加载失败: {str(e)}")

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into a numpy array of embeddings.

        Args:
            texts: Texts to embed; an empty list yields an empty array.
            batch_size: Number of texts encoded per model forward pass.

        Raises:
            HTTPException: 500 if encoding fails.
        """
        if not texts:
            return np.array([])

        await self.load_model()

        try:
            # Bug fix: the original requested convert_to_tensor=True and only
            # handled the CUDA case (tensor.is_cuda) before calling .numpy();
            # on the "mps" device — accepted above — .numpy() fails on a
            # non-CPU tensor. convert_to_numpy=True lets the library perform
            # the device transfer for every device.
            embeddings = self.model.encode(
                texts,
                batch_size=batch_size,
                convert_to_numpy=True,
                show_progress_bar=False
            )
            return embeddings
        except Exception as e:
            logger.error(f"编码失败: {e}")
            raise HTTPException(status_code=500, detail=f"编码失败: {str(e)}")

    async def compute_similarity(self, text1: str, text2: str) -> float:
        """Compute the cosine similarity between two texts.

        Raises:
            HTTPException: 400 when either text is empty, 500 on failure.
        """
        if not text1 or not text2:
            raise HTTPException(status_code=400, detail="文本不能为空")

        await self.load_model()

        try:
            embeddings = self.model.encode([text1, text2], convert_to_tensor=True)
            from sentence_transformers import util
            similarity = util.cos_sim(embeddings[0:1], embeddings[1:2])[0][0]
            return similarity.item()
        except Exception as e:
            logger.error(f"相似度计算失败: {e}")
            raise HTTPException(status_code=500, detail=f"相似度计算失败: {str(e)}")
|
|
|
|
# Shared server instance used by every endpoint.
server = ModelServer()

# FastAPI application exposing the embedding service.
app = FastAPI(
    title="Embedding Model Server",
    description="统一的句子嵌入和相似度计算服务",
    version="1.0.0",
)

# Permissive CORS: any origin, method and header is accepted.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
# Health check
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report server status and whether the model has been loaded."""
    try:
        loaded = server.model_loaded
        # Status is "ready" until the model is loaded, then "running";
        # the model path is only known (and reported) after loading.
        return HealthResponse(
            status="running" if loaded else "ready",
            model_loaded=loaded,
            model_path=model_config.get("current_model_path", "") if loaded else "",
            device=model_config["device"],
        )
    except Exception as e:
        logger.error(f"健康检查失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
# Model preloading
@app.post("/load_model")
async def preload_model():
    """Trigger model loading eagerly instead of on the first request."""
    await server.load_model()
    return {"message": "模型加载完成"}
|
|
|
|
# Text embedding
@app.post("/encode", response_model=EmbeddingResponse)
async def encode_texts(request: TextRequest):
    """Encode the request's texts into embedding vectors.

    Returns the embeddings as nested lists plus their shape.
    """
    try:
        embeddings = await server.encode_texts(request.texts, request.batch_size)
        return EmbeddingResponse(
            embeddings=embeddings.tolist(),
            shape=list(embeddings.shape),
        )
    except HTTPException:
        # Bug fix: HTTPException is a subclass of Exception, so the generic
        # handler below used to swallow the server layer's deliberate errors
        # and re-wrap them as stringified 500s. Re-raise them unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
# Similarity computation
@app.post("/similarity", response_model=SimilarityResponse)
async def compute_similarity(request: SimilarityRequest):
    """Compute the cosine similarity between the two request texts."""
    try:
        similarity = await server.compute_similarity(request.text1, request.text2)
        return SimilarityResponse(similarity_score=similarity)
    except HTTPException:
        # Bug fix: without this clause the generic handler below converted
        # the server layer's deliberate 400 for empty input into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
# Batch encoding (for document processing)
@app.post("/batch_encode")
async def batch_encode(texts: List[str], batch_size: int = 64):
    """Encode a large batch of texts.

    Returns a dict with the embeddings as nested lists and their shape;
    an empty input yields {"embeddings": [], "shape": [0, 0]}.
    """
    if not texts:
        return {"embeddings": [], "shape": [0, 0]}

    try:
        embeddings = await server.encode_texts(texts, batch_size)
        return {
            "embeddings": embeddings.tolist(),
            "shape": list(embeddings.shape),
        }
    except HTTPException:
        # Bug fix: re-raise server-layer HTTP errors instead of letting the
        # generic handler below re-wrap them as stringified 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
if __name__ == "__main__":
    import argparse

    # CLI options for the development server.
    parser = argparse.ArgumentParser(description="启动嵌入模型服务器")
    parser.add_argument("--host", default="127.0.0.1", help="服务器地址")
    parser.add_argument("--port", type=int, default=8000, help="服务器端口")
    parser.add_argument("--workers", type=int, default=1, help="工作进程数")
    parser.add_argument("--reload", action="store_true", help="开发模式自动重载")
    args = parser.parse_args()

    logger.info(f"启动模型服务器: http://{args.host}:{args.port}")

    # NOTE(review): each uvicorn worker process loads its own model copy, so
    # --workers > 1 multiplies memory usage; also confirm --workers together
    # with --reload is the intended combination (uvicorn ignores workers in
    # reload mode).
    uvicorn.run(
        "model_server:app",
        host=args.host,
        port=args.port,
        workers=args.workers,
        reload=args.reload,
        log_level="info",
    )