qwen_agent/embedding/model_server.py
2025-11-20 13:29:44 +08:00

240 lines
7.5 KiB
Python

#!/usr/bin/env python3
"""
独立的模型服务器
提供统一的 embedding 服务,减少内存占用
"""
import os
import asyncio
import logging
from typing import List, Dict, Any
import numpy as np
from sentence_transformers import SentenceTransformer
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from pydantic import BaseModel
# Configure logging for the whole process.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Module-level placeholder for a global model instance.
# NOTE(review): this name is never assigned again below — the live model is
# held on the ModelServer instance; confirm whether this global is still needed.
model = None
# Default model configuration. "device" can be overridden at load time via
# the SENTENCE_TRANSFORMER_DEVICE environment variable; load_model() also
# records "current_model_path" here for the /health endpoint.
model_config = {
    "model_name": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "local_model_path": "./models/paraphrase-multilingual-MiniLM-L12-v2",
    "device": "cpu"
}
# ---- Pydantic request/response schemas ----
class TextRequest(BaseModel):
    """Request body for POST /encode."""
    texts: List[str]      # sentences to embed
    batch_size: int = 32  # encode batch size forwarded to the model
class SimilarityRequest(BaseModel):
    """Request body for POST /similarity: the two texts to compare."""
    text1: str
    text2: str
class HealthResponse(BaseModel):
    """Response body for GET /health."""
    # NOTE(review): in pydantic v2 field names starting with "model_" clash
    # with the protected "model_" namespace and emit a warning — confirm the
    # pydantic version in use.
    status: str        # "ready" (not loaded yet) or "running" (loaded)
    model_loaded: bool
    model_path: str    # path/name actually loaded, "" if not loaded
    device: str
class EmbeddingResponse(BaseModel):
    """Response body for POST /encode: vectors plus their array shape."""
    embeddings: List[List[float]]
    shape: List[int]  # [num_texts, embedding_dim]
class SimilarityResponse(BaseModel):
    """Response body for POST /similarity."""
    similarity_score: float  # cosine similarity of the two texts
class ModelServer:
    """Lazily loads a SentenceTransformer model and serves encoding and
    similarity requests for the FastAPI endpoints in this module."""

    def __init__(self):
        self.model = None          # SentenceTransformer instance once loaded
        self.model_loaded = False  # set True after a successful load
        # Serializes the (blocking) first load so concurrent requests do not
        # each construct their own model instance.
        self._load_lock = asyncio.Lock()

    async def load_model(self):
        """Load the model on first use; no-op if already loaded.

        Prefers the local model directory when it exists, otherwise falls
        back to the HuggingFace model name. The device comes from the
        SENTENCE_TRANSFORMER_DEVICE environment variable, defaulting to
        model_config["device"]; unsupported values fall back to CPU.

        Raises:
            HTTPException: 500 if the model fails to load.
        """
        if self.model_loaded:
            return
        async with self._load_lock:
            # Re-check: another coroutine may have finished loading while
            # this one was waiting on the lock.
            if self.model_loaded:
                return
            logger.info("正在加载模型...")
            # Prefer a local copy of the model if one exists on disk.
            if os.path.exists(model_config["local_model_path"]):
                model_path = model_config["local_model_path"]
                logger.info(f"使用本地模型: {model_path}")
            else:
                model_path = model_config["model_name"]
                logger.info(f"使用 HuggingFace 模型: {model_path}")
            # Device selection from the environment, validated against the
            # supported set; anything else falls back to CPU.
            device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', model_config["device"])
            if device not in ['cpu', 'cuda', 'mps']:
                logger.warning(f"不支持的设备类型 '{device}',使用默认 CPU")
                device = 'cpu'
            logger.info(f"使用设备: {device}")
            try:
                self.model = SentenceTransformer(model_path, device=device)
                self.model_loaded = True
                # Record what was actually used so /health can report it.
                model_config["device"] = device
                model_config["current_model_path"] = model_path
                logger.info("模型加载完成")
            except Exception as e:
                logger.error(f"模型加载失败: {e}")
                raise HTTPException(status_code=500, detail=f"模型加载失败: {str(e)}")

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into embedding vectors.

        Args:
            texts: Sentences to embed. An empty list short-circuits to an
                empty array without loading the model.
            batch_size: Encode batch size forwarded to the model.

        Returns:
            A 2-D numpy array of embeddings (empty array for empty input).

        Raises:
            HTTPException: 500 on load or encode failure.
        """
        if not texts:
            return np.array([])
        await self.load_model()
        try:
            embeddings = self.model.encode(
                texts,
                batch_size=batch_size,
                convert_to_tensor=True,
                show_progress_bar=False
            )
            # Bug fix: the old `is_cuda` check missed MPS tensors, whose
            # .numpy() raises. Unconditional .cpu() is a no-op for tensors
            # already on the CPU and handles CUDA and MPS alike.
            return embeddings.cpu().numpy()
        except Exception as e:
            logger.error(f"编码失败: {e}")
            raise HTTPException(status_code=500, detail=f"编码失败: {str(e)}")

    async def compute_similarity(self, text1: str, text2: str) -> float:
        """Compute the cosine similarity between two texts.

        Raises:
            HTTPException: 400 if either text is empty; 500 on model failure.
        """
        if not text1 or not text2:
            raise HTTPException(status_code=400, detail="文本不能为空")
        await self.load_model()
        try:
            embeddings = self.model.encode([text1, text2], convert_to_tensor=True)
            from sentence_transformers import util
            # Slice to keep 2-D inputs for cos_sim, then unwrap the scalar.
            similarity = util.cos_sim(embeddings[0:1], embeddings[1:2])[0][0]
            return similarity.item()
        except Exception as e:
            logger.error(f"相似度计算失败: {e}")
            raise HTTPException(status_code=500, detail=f"相似度计算失败: {str(e)}")
# Single shared server instance used by all endpoints below.
server = ModelServer()
# Create the FastAPI application.
app = FastAPI(
    title="Embedding Model Server",
    description="统一的句子嵌入和相似度计算服务",
    version="1.0.0"
)
# CORS middleware: wide open (any origin, method, and header).
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# very permissive — confirm this service is only reachable from trusted
# networks (e.g. localhost/internal) before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Health check
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the embedding model has been loaded yet."""
    try:
        loaded = server.model_loaded
        if loaded:
            state = "running"
            path = model_config.get("current_model_path", "")
        else:
            state = "ready"
            path = ""
        return HealthResponse(
            status=state,
            model_loaded=loaded,
            model_path=path,
            device=model_config["device"],
        )
    except Exception as e:
        logger.error(f"健康检查失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Preload the model (otherwise it loads lazily on the first request).
@app.post("/load_model")
async def preload_model():
    """Eagerly load the embedding model and report completion."""
    await server.load_model()
    return {"message": "模型加载完成"}
# Text embedding
@app.post("/encode", response_model=EmbeddingResponse)
async def encode_texts(request: TextRequest):
    """Encode the requested texts and return embeddings plus their shape.

    Raises:
        HTTPException: propagated from the model layer with its original
            status/detail, or 500 for unexpected errors.
    """
    try:
        embeddings = await server.encode_texts(request.texts, request.batch_size)
        return EmbeddingResponse(
            embeddings=embeddings.tolist(),
            shape=list(embeddings.shape)
        )
    except HTTPException:
        # Bug fix: the generic handler below used to catch HTTPException too,
        # re-wrapping it and losing the original status code and detail.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Similarity computation
@app.post("/similarity", response_model=SimilarityResponse)
async def compute_similarity(request: SimilarityRequest):
    """Compute the cosine similarity between the two request texts.

    Raises:
        HTTPException: propagated from the model layer with its original
            status/detail (e.g. 400 for empty texts), or 500 for
            unexpected errors.
    """
    try:
        similarity = await server.compute_similarity(request.text1, request.text2)
        return SimilarityResponse(similarity_score=similarity)
    except HTTPException:
        # Bug fix: previously the 400 raised for empty input was caught by
        # the generic handler and re-wrapped as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Batch encoding (for document processing)
@app.post("/batch_encode")
async def batch_encode(texts: List[str], batch_size: int = 64):
    """Batch-encode a large number of texts.

    `texts` is read from the JSON request body; `batch_size` arrives as a
    query parameter. An empty list returns an empty result without
    touching the model.

    Raises:
        HTTPException: propagated from the model layer with its original
            status/detail, or 500 for unexpected errors.
    """
    if not texts:
        return {"embeddings": [], "shape": [0, 0]}
    try:
        embeddings = await server.encode_texts(texts, batch_size)
        return {
            "embeddings": embeddings.tolist(),
            "shape": list(embeddings.shape)
        }
    except HTTPException:
        # Bug fix: preserve the model layer's status/detail instead of
        # re-wrapping every HTTPException as a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="启动嵌入模型服务器")
    parser.add_argument("--host", default="127.0.0.1", help="服务器地址")
    parser.add_argument("--port", type=int, default=8000, help="服务器端口")
    parser.add_argument("--workers", type=int, default=1, help="工作进程数")
    parser.add_argument("--reload", action="store_true", help="开发模式自动重载")
    args = parser.parse_args()
    logger.info(f"启动模型服务器: http://{args.host}:{args.port}")
    # The app is given as an import string so uvicorn can re-import it in
    # worker/reload subprocesses.
    # NOTE(review): "model_server:app" assumes the process is launched from
    # this file's directory; from the project root the import string would
    # be "qwen_agent.embedding.model_server:app" — confirm how it is run.
    uvicorn.run(
        "model_server:app",
        host=args.host,
        port=args.port,
        workers=args.workers,
        reload=args.reload,
        log_level="info"
    )