# qwen_agent/embedding/model_client.py
# Last modified: 2025-11-20 13:29:44 +08:00
# 181 lines, 5.9 KiB, Python

#!/usr/bin/env python3
"""
模型服务客户端
提供与 model_server.py 交互的客户端接口
"""
import os
import asyncio
import logging
from typing import List, Union, Dict, Any
import numpy as np
import aiohttp
import json
from sentence_transformers import util
logger = logging.getLogger(__name__)
class ModelClient:
    """Async HTTP client for the embedding model server (model_server.py).

    Must be used as an async context manager so the underlying
    ``aiohttp.ClientSession`` is created and closed properly::

        async with ModelClient() as client:
            vectors = await client.encode_texts(["hello"])
    """

    def __init__(self, base_url: str = "http://127.0.0.1:8000"):
        # Normalize trailing slashes so path joins below never produce "//".
        self.base_url = base_url.rstrip('/')
        # Created in __aenter__; the client is unusable outside `async with`.
        self.session = None

    async def __aenter__(self):
        # 5-minute total timeout: encoding large batches can be slow.
        self.session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=300)
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session:
            await self.session.close()

    async def health_check(self) -> Dict[str, Any]:
        """Return the server's /health payload, or an 'unavailable' dict on error.

        Never raises: a failed probe is reported through the returned dict.
        """
        try:
            async with self.session.get(f"{self.base_url}/health") as response:
                return await response.json()
        except Exception as e:  # best-effort: a health probe must not raise
            logger.error(f"健康检查失败: {e}")
            return {"status": "unavailable", "error": str(e)}

    async def load_model(self) -> bool:
        """Ask the server to preload the model; True iff it answered HTTP 200."""
        try:
            async with self.session.post(f"{self.base_url}/load_model") as response:
                # Only the status code matters here; the response body is
                # not inspected (the original parsed it into an unused var,
                # which also made a 200-with-non-JSON-body return False).
                return response.status == 200
        except Exception as e:
            logger.error(f"模型加载失败: {e}")
            return False

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode *texts* into an embedding matrix.

        Returns an empty array for empty input. On a non-200 response or a
        transport error, logs and re-raises.
        """
        if not texts:
            return np.array([])
        try:
            data = {
                "texts": texts,
                "batch_size": batch_size
            }
            async with self.session.post(
                f"{self.base_url}/encode",
                json=data
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"编码失败: {error_text}")
                result = await response.json()
                return np.array(result["embeddings"])
        except Exception as e:
            logger.error(f"文本编码失败: {e}")
            raise

    async def compute_similarity(self, text1: str, text2: str) -> float:
        """Return the server-computed similarity score of two texts.

        Logs and re-raises on a non-200 response or transport error.
        """
        try:
            data = {
                "text1": text1,
                "text2": text2
            }
            async with self.session.post(
                f"{self.base_url}/similarity",
                json=data
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"相似度计算失败: {error_text}")
                result = await response.json()
                return result["similarity_score"]
        except Exception as e:
            logger.error(f"相似度计算失败: {e}")
            raise
# Synchronous wrapper around the async client
class SyncModelClient:
    """Synchronous model service client.

    Wraps :class:`ModelClient` so callers without an event loop can use it.
    Every call opens a fresh aiohttp session (via the async context manager)
    and closes it before returning.
    """

    def __init__(self, base_url: str = "http://127.0.0.1:8000"):
        self.base_url = base_url
        self._async_client = ModelClient(base_url)

    def _run_async(self, coro):
        """Run *coro* to completion on a private event loop.

        Uses ``asyncio.run`` instead of the deprecated
        ``asyncio.get_event_loop()`` pattern, which warns since Python 3.10
        and can hand back a closed loop. Like the old code, this raises
        RuntimeError if invoked from inside an already-running loop.
        """
        return asyncio.run(coro)

    def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Synchronous text encoding; see ModelClient.encode_texts."""
        async def _encode():
            async with self._async_client as client:
                return await client.encode_texts(texts, batch_size)
        return self._run_async(_encode())

    def compute_similarity(self, text1: str, text2: str) -> float:
        """Synchronous similarity computation; see ModelClient.compute_similarity."""
        async def _similarity():
            async with self._async_client as client:
                return await client.compute_similarity(text1, text2)
        return self._run_async(_similarity())

    def health_check(self) -> Dict[str, Any]:
        """Synchronous health check; see ModelClient.health_check."""
        async def _health():
            async with self._async_client as client:
                return await client.health_check()
        return self._run_async(_health())

    def load_model(self) -> bool:
        """Synchronous model preload; see ModelClient.load_model."""
        async def _load():
            async with self._async_client as client:
                return await client.load_model()
        return self._run_async(_load())
# Global client instance (singleton)
_client_instance = None


def get_model_client(base_url: Union[str, None] = None) -> SyncModelClient:
    """Return the shared :class:`SyncModelClient`, creating it on first call.

    The URL is resolved only once, at creation time: explicit *base_url*
    argument, then the MODEL_SERVER_URL environment variable, then the
    localhost default. NOTE: *base_url* passed on later calls is ignored —
    the first-created instance is always returned.
    """
    global _client_instance
    if _client_instance is None:
        url = base_url or os.environ.get('MODEL_SERVER_URL', 'http://127.0.0.1:8000')
        _client_instance = SyncModelClient(url)
    return _client_instance
# Convenience module-level functions


def encode_texts(texts: List[str], batch_size: int = 32) -> np.ndarray:
    """Encode *texts* into embedding vectors via the shared client."""
    return get_model_client().encode_texts(texts, batch_size)
def compute_similarity(text1: str, text2: str) -> float:
    """Compute the similarity of two texts via the shared client."""
    return get_model_client().compute_similarity(text1, text2)
def is_service_available() -> bool:
    """Return True when the model service reports a usable status.

    The server's /health endpoint is considered healthy when it reports
    either "running" or "ready"; a failed probe yields "unavailable".
    """
    client = get_model_client()
    health = client.health_check()
    # Membership test instead of a chained `== ... or == ...` comparison.
    return health.get("status") in ("running", "ready")