#!/usr/bin/env python3
"""
Model service client.

Provides client-side interfaces for talking to model_server.py over HTTP:
an asynchronous client (``ModelClient``), a blocking wrapper around it
(``SyncModelClient``), and module-level convenience functions that share a
single ``SyncModelClient`` instance.
"""

import os
import asyncio
import logging
from typing import List, Union, Dict, Any, Optional

import numpy as np
import aiohttp
import json
from sentence_transformers import util

logger = logging.getLogger(__name__)

# Default address of the model server, used when no URL is supplied.
_DEFAULT_BASE_URL = "http://127.0.0.1:8000"


class ModelServiceError(Exception):
    """Raised when the model server answers a request with a non-200 status."""


class ModelClient:
    """Asynchronous client for the model service.

    Must be used as an async context manager; the HTTP session only exists
    between ``__aenter__`` and ``__aexit__``::

        async with ModelClient() as client:
            vectors = await client.encode_texts(["hello"])
    """

    def __init__(self, base_url: str = "http://127.0.0.1:8000"):
        """Store the server address.

        Args:
            base_url: Root URL of the model server; trailing slashes are
                stripped so endpoint paths can be appended directly.
        """
        self.base_url = base_url.rstrip('/')
        # Created in __aenter__, closed and reset to None in __aexit__.
        self.session = None

    async def __aenter__(self):
        """Open the HTTP session and return this client."""
        self.session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=300)  # 5-minute total timeout
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP session, if one is open."""
        if self.session:
            try:
                await self.session.close()
            finally:
                # Drop the reference so a later call outside the context
                # fails with a clear message instead of hitting a closed
                # session.
                self.session = None

    def _require_session(self):
        """Return the open session or raise if the client is not entered.

        Raises:
            RuntimeError: If no session is currently open.
        """
        if self.session is None:
            raise RuntimeError(
                "ModelClient has no open session; "
                "use it inside 'async with ModelClient(...)'"
            )
        return self.session

    async def health_check(self) -> Dict[str, Any]:
        """Query the server's ``/health`` endpoint.

        Returns:
            The decoded JSON body of the response. If anything goes wrong
            (connection error, invalid JSON, no open session, ...), a dict
            ``{"status": "unavailable", "error": <message>}`` is returned
            instead of raising.
        """
        try:
            session = self._require_session()
            async with session.get(f"{self.base_url}/health") as response:
                return await response.json()
        except Exception as e:
            logger.error(f"健康检查失败: {e}")
            return {"status": "unavailable", "error": str(e)}

    async def load_model(self) -> bool:
        """Ask the server to preload its model via ``POST /load_model``.

        Returns:
            True if the server answered with status 200 and a JSON body;
            False on any other status or on any error (which is logged).
        """
        try:
            session = self._require_session()
            async with session.post(f"{self.base_url}/load_model") as response:
                # The body is decoded but not used; a body that is not valid
                # JSON raises here and is therefore reported as a failure.
                await response.json()
                return response.status == 200
        except Exception as e:
            logger.error(f"模型加载失败: {e}")
            return False

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into vectors via ``POST /encode``.

        Args:
            texts: Texts to encode. An empty list short-circuits without
                contacting the server.
            batch_size: Forwarded to the server as ``batch_size``.

        Returns:
            An array built from the ``embeddings`` field of the response,
            or an empty array when ``texts`` is empty.

        Raises:
            ModelServiceError: If the server responds with a non-200 status.
            Exception: Any other error (connection, decoding, missing
                session) is logged and re-raised unchanged.
        """
        if not texts:
            return np.array([])

        try:
            session = self._require_session()
            data = {
                "texts": texts,
                "batch_size": batch_size
            }
            async with session.post(
                f"{self.base_url}/encode",
                json=data
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise ModelServiceError(f"编码失败: {error_text}")

                result = await response.json()
                return np.array(result["embeddings"])
        except Exception as e:
            logger.error(f"文本编码失败: {e}")
            raise

    async def compute_similarity(self, text1: str, text2: str) -> float:
        """Compute the similarity of two texts via ``POST /similarity``.

        Args:
            text1: First text.
            text2: Second text.

        Returns:
            The ``similarity_score`` field of the server's JSON response.

        Raises:
            ModelServiceError: If the server responds with a non-200 status.
            Exception: Any other error (connection, decoding, missing
                session) is logged and re-raised unchanged.
        """
        try:
            session = self._require_session()
            data = {
                "text1": text1,
                "text2": text2
            }
            async with session.post(
                f"{self.base_url}/similarity",
                json=data
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise ModelServiceError(f"相似度计算失败: {error_text}")

                result = await response.json()
                return result["similarity_score"]
        except Exception as e:
            logger.error(f"相似度计算失败: {e}")
            raise


# Synchronous wrapper around the async client
class SyncModelClient:
    """Blocking model client.

    Each call opens a fresh ``ModelClient`` session, runs one request to
    completion on the calling thread's event loop, and closes the session.
    """

    def __init__(self, base_url: str = "http://127.0.0.1:8000"):
        """Create the wrapper and its underlying async client.

        Args:
            base_url: Root URL of the model server.
        """
        self.base_url = base_url
        self._async_client = ModelClient(base_url)

    def _run_async(self, coro):
        """Run ``coro`` to completion on this thread and return its result.

        Raises:
            RuntimeError: If called from inside a running event loop, where
                blocking on the coroutine is not possible.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No loop is running in this thread: safe to block.
        else:
            # Close the coroutine so it does not trigger a
            # "coroutine was never awaited" warning.
            coro.close()
            raise RuntimeError(
                "SyncModelClient cannot be used from a running event loop; "
                "use ModelClient with 'await' instead"
            )

        try:
            loop = asyncio.get_event_loop()
            if loop.is_closed():
                # A closed loop cannot run anything; replace it below.
                raise RuntimeError("current event loop is closed")
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        return loop.run_until_complete(coro)

    def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Blocking version of ``ModelClient.encode_texts``."""
        async def _encode():
            async with self._async_client as client:
                return await client.encode_texts(texts, batch_size)

        return self._run_async(_encode())

    def compute_similarity(self, text1: str, text2: str) -> float:
        """Blocking version of ``ModelClient.compute_similarity``."""
        async def _similarity():
            async with self._async_client as client:
                return await client.compute_similarity(text1, text2)

        return self._run_async(_similarity())

    def health_check(self) -> Dict[str, Any]:
        """Blocking version of ``ModelClient.health_check``."""
        async def _health():
            async with self._async_client as client:
                return await client.health_check()

        return self._run_async(_health())

    def load_model(self) -> bool:
        """Blocking version of ``ModelClient.load_model``."""
        async def _load():
            async with self._async_client as client:
                return await client.load_model()

        return self._run_async(_load())


# Shared client instance (singleton)
_client_instance = None


def get_model_client(base_url: Optional[str] = None) -> SyncModelClient:
    """Return the shared ``SyncModelClient`` instance, creating it on demand.

    Args:
        base_url: Server URL to use. When omitted, the existing instance is
            returned as-is; if none exists yet, the URL comes from the
            ``MODEL_SERVER_URL`` environment variable, falling back to
            ``http://127.0.0.1:8000``. When given and it differs from the
            existing instance's URL (ignoring trailing slashes), the shared
            instance is replaced by one pointing at ``base_url``.

    Returns:
        The shared client.
    """
    global _client_instance

    if _client_instance is None:
        url = base_url or os.environ.get('MODEL_SERVER_URL', _DEFAULT_BASE_URL)
        _client_instance = SyncModelClient(url)
    elif base_url and base_url.rstrip('/') != _client_instance.base_url.rstrip('/'):
        # An explicit, different URL was requested: honor it rather than
        # silently handing back a client bound to another server.
        _client_instance = SyncModelClient(base_url)

    return _client_instance


# Convenience functions
def encode_texts(texts: List[str], batch_size: int = 32) -> np.ndarray:
    """Encode texts using the shared client; see ``ModelClient.encode_texts``."""
    client = get_model_client()
    return client.encode_texts(texts, batch_size)


def compute_similarity(text1: str, text2: str) -> float:
    """Compare two texts using the shared client; see ``ModelClient.compute_similarity``."""
    client = get_model_client()
    return client.compute_similarity(text1, text2)


def is_service_available() -> bool:
    """Return True if the health check reports status "running" or "ready"."""
    client = get_model_client()
    health = client.health_check()
    return health.get("status") in ("running", "ready")