#!/usr/bin/env python3
"""
Model service client.

Provides a client interface for interacting with model_server.py.
"""
|
|
|
|
import os
|
|
import asyncio
|
|
import logging
|
|
from typing import List, Union, Dict, Any
|
|
import numpy as np
|
|
import aiohttp
|
|
import json
|
|
from sentence_transformers import util
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class ModelClient:
    """Async HTTP client for the model server (see model_server.py).

    Must be used as an async context manager so the aiohttp session is
    created and released deterministically:

        async with ModelClient() as client:
            embeddings = await client.encode_texts(["hello"])
    """

    def __init__(self, base_url: str = "http://127.0.0.1:8000"):
        # Strip any trailing slash so f"{base_url}/path" joins cleanly.
        self.base_url = base_url.rstrip('/')
        # Set in __aenter__; None means "context not entered yet".
        self.session = None

    async def __aenter__(self) -> "ModelClient":
        self.session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=300)  # 5-minute overall timeout
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        if self.session:
            await self.session.close()

    async def health_check(self) -> Dict[str, Any]:
        """Query GET /health.

        Returns:
            The server's health JSON, or a synthetic
            ``{"status": "unavailable", "error": ...}`` dict when the
            request fails for any reason (this method never raises).
        """
        try:
            async with self.session.get(f"{self.base_url}/health") as response:
                return await response.json()
        except Exception as e:
            logger.error(f"健康检查失败: {e}")
            return {"status": "unavailable", "error": str(e)}

    async def load_model(self) -> bool:
        """Ask the server to preload its model via POST /load_model.

        Returns:
            True on HTTP 200, False on any error (never raises).
        """
        try:
            async with self.session.post(f"{self.base_url}/load_model") as response:
                # Drain/validate the body (a non-JSON reply raises here and
                # is mapped to False below); the parsed value is not needed.
                await response.json()
                return response.status == 200
        except Exception as e:
            logger.error(f"模型加载失败: {e}")
            return False

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode *texts* into embedding vectors via POST /encode.

        Args:
            texts: Texts to embed; an empty list short-circuits to an
                empty array without any network round-trip.
            batch_size: Server-side batching hint.

        Returns:
            Array of embeddings, one row per input text.

        Raises:
            Exception: If the request fails or the server returns a
                non-200 status.
        """
        if not texts:
            return np.array([])

        try:
            data = {
                "texts": texts,
                "batch_size": batch_size
            }

            async with self.session.post(
                f"{self.base_url}/encode",
                json=data
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"编码失败: {error_text}")

                result = await response.json()
                return np.array(result["embeddings"])

        except Exception as e:
            logger.error(f"文本编码失败: {e}")
            raise

    async def compute_similarity(self, text1: str, text2: str) -> float:
        """Compute the similarity of two texts via POST /similarity.

        Returns:
            The server-reported similarity score.

        Raises:
            Exception: If the request fails or the server returns a
                non-200 status.
        """
        try:
            data = {
                "text1": text1,
                "text2": text2
            }

            async with self.session.post(
                f"{self.base_url}/similarity",
                json=data
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"相似度计算失败: {error_text}")

                result = await response.json()
                return result["similarity_score"]

        except Exception as e:
            logger.error(f"相似度计算失败: {e}")
            raise
|
|
|
|
# Synchronous wrapper around the async client
class SyncModelClient:
    """Blocking wrapper around ModelClient for non-async callers.

    Each public method opens and closes a fresh HTTP session per call
    through the async client's context manager, so calls are
    self-contained but carry per-call connection overhead.
    """

    def __init__(self, base_url: str = "http://127.0.0.1:8000"):
        self.base_url = base_url
        self._async_client = ModelClient(base_url)

    def _run_async(self, coro):
        """Drive *coro* to completion on a dedicated event loop.

        The original used asyncio.get_event_loop(), which is deprecated
        when no loop is running (and an error on newer Pythons); a fresh
        loop via asyncio.run() avoids that.  Calling this from inside a
        running loop never worked (run_until_complete raises there), so
        fail fast with a clearer message.

        Raises:
            RuntimeError: If invoked from within a running event loop.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: the normal, supported path.
            return asyncio.run(coro)
        raise RuntimeError(
            "SyncModelClient cannot be used inside a running event loop; "
            "use ModelClient directly instead"
        )

    def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Synchronous version of ModelClient.encode_texts."""
        async def _encode():
            async with self._async_client as client:
                return await client.encode_texts(texts, batch_size)

        return self._run_async(_encode())

    def compute_similarity(self, text1: str, text2: str) -> float:
        """Synchronous version of ModelClient.compute_similarity."""
        async def _similarity():
            async with self._async_client as client:
                return await client.compute_similarity(text1, text2)

        return self._run_async(_similarity())

    def health_check(self) -> Dict[str, Any]:
        """Synchronous version of ModelClient.health_check."""
        async def _health():
            async with self._async_client as client:
                return await client.health_check()

        return self._run_async(_health())

    def load_model(self) -> bool:
        """Synchronous version of ModelClient.load_model."""
        async def _load():
            async with self._async_client as client:
                return await client.load_model()

        return self._run_async(_load())
|
|
|
|
# Global client instance (singleton)
_client_instance = None
|
|
|
|
def get_model_client(base_url: Union[str, None] = None) -> SyncModelClient:
    """Return the process-wide SyncModelClient instance (lazy singleton).

    The server URL is resolved on the FIRST call only, in this order:
    the *base_url* argument, the MODEL_SERVER_URL environment variable,
    then the local default.  Subsequent calls return the cached instance
    and ignore *base_url*.
    """
    global _client_instance

    if _client_instance is None:
        url = base_url or os.environ.get('MODEL_SERVER_URL', 'http://127.0.0.1:8000')
        _client_instance = SyncModelClient(url)

    return _client_instance
|
|
|
|
# Convenience functions
def encode_texts(texts: List[str], batch_size: int = 32) -> np.ndarray:
    """Encode *texts* into embeddings using the shared singleton client."""
    return get_model_client().encode_texts(texts, batch_size)
|
|
|
|
def compute_similarity(text1: str, text2: str) -> float:
    """Compute the similarity of two texts using the shared singleton client."""
    return get_model_client().compute_similarity(text1, text2)
|
|
|
|
def is_service_available() -> bool:
    """Return True when the model service reports a healthy status.

    Queries the server's health endpoint through the singleton client;
    both "running" and "ready" statuses count as available (same set the
    original check accepted), everything else — including the synthetic
    "unavailable" produced on connection failure — counts as down.
    """
    client = get_model_client()
    health = client.health_check()
    # Single .get() lookup, membership test instead of a chained `or`.
    return health.get("status") in ("running", "ready")