156 lines
4.2 KiB
Python
156 lines
4.2 KiB
Python
import os
|
|
import time
|
|
import multiprocessing
|
|
from typing import Optional
|
|
from fastapi import APIRouter, HTTPException
|
|
from pydantic import BaseModel
|
|
|
|
from utils import (
|
|
setup_system_optimizations
|
|
)
|
|
|
|
from embedding import get_model_manager
|
|
from pydantic import BaseModel
|
|
import logging
|
|
|
|
# Module-wide logger; all log calls in this file go through the 'app' logger.
logger = logging.getLogger("app")

# Router on which the endpoints defined below are registered.
router = APIRouter()
|
|
|
|
|
|
# Request body accepted by the text-encoding endpoint.
# (Documented with comments rather than a class docstring so the generated
# schema for this model is left exactly as it was.)
class EncodeRequest(BaseModel):
    # Texts to encode, one embedding is requested per entry.
    texts: list[str]
    # Batch size forwarded to the model manager; 32 when the client omits it.
    batch_size: int = 32
|
|
|
|
|
|
# Response body returned by the text-encoding endpoint, for both successful
# and failed requests.
# (Documented with comments rather than a class docstring so the generated
# schema for this model is left exactly as it was.)
class EncodeResponse(BaseModel):
    # True when encoding completed; False when `error` describes a failure.
    success: bool
    # One vector per input text; empty on failure.
    embeddings: list[list[float]]
    # Shape reported by the embedding result; the endpoint sends [0, 0] on failure.
    shape: list[int]
    # Seconds spent encoding; the endpoint sends 0.0 on failure.
    processing_time: float
    # Number of texts received in the request.
    total_texts: int
    # Human-readable failure reason; None when the request succeeded.
    error: Optional[str] = None
|
|
|
|
|
|
# Initialise the system optimisation settings.
# NOTE: this runs as a side effect of importing the module; the returned
# object is kept in the module-level `system_optimizer`.
logger.info("正在初始化系统优化...")
system_optimizer = setup_system_optimizations()
logger.info("系统优化初始化完成")
|
|
|
|
|
|
@router.get("/api/health")
async def health_check():
    """Health check endpoint"""
    # Static liveness payload: nothing is probed, so a response only shows
    # that the application is up and routing requests.
    liveness_payload = {"message": "Database Assistant API is running"}
    return liveness_payload
|
|
|
|
|
|
def _file_cache_stats():
    """Summarise the module-level file cache, using defaults when it is absent."""
    # `file_cache` is not defined or imported anywhere in this module, so a bare
    # reference raised NameError and turned every request into an HTTP 500.
    # Resolve it through globals() so the endpoint still answers without it,
    # and automatically picks the cache up if the module defines one.
    cache = globals().get("file_cache")
    return {
        "cache_size": len(cache._cache) if hasattr(cache, "_cache") else 0,
        "max_cache_size": getattr(cache, "cache_size", 1000),
        "ttl": getattr(cache, "ttl", 300),
    }


def _system_stats():
    """Collect CPU, memory and disk figures; memory/disk need the optional psutil package."""
    cpu_count = multiprocessing.cpu_count()
    try:
        import psutil
    except ImportError:
        return {
            "cpu_count": cpu_count,
            "memory_info": "psutil not available",
        }

    # Sample memory once so the three reported figures describe the same instant.
    memory = psutil.virtual_memory()
    bytes_per_gib = 1024 ** 3
    return {
        "cpu_count": cpu_count,
        "memory_total_gb": round(memory.total / bytes_per_gib, 2),
        "memory_available_gb": round(memory.available / bytes_per_gib, 2),
        "memory_percent": memory.percent,
        "disk_usage_percent": psutil.disk_usage('/').percent,
    }


@router.get("/api/v1/system/performance")
async def get_performance_stats():
    """Return system performance statistics.

    Returns:
        A dict with ``success`` (always True on this path), ``timestamp``
        (Unix time in whole seconds) and ``performance``, which holds the
        ``connection_pool``, ``file_cache`` and ``system`` sections.

    Raises:
        HTTPException: status 500 when collecting any statistic fails.
    """
    try:
        # Connection-pool statistics (simplified): these values are fixed
        # constants, not measurements taken from a live pool.
        pool_stats = {
            "connection_pool": "active",
            "max_connections_per_host": 100,
            "max_connections_total": 500,
            "keepalive_timeout": 30,
        }

        return {
            "success": True,
            "timestamp": int(time.time()),
            "performance": {
                "connection_pool": pool_stats,
                "file_cache": _file_cache_stats(),
                "system": _system_stats(),
            },
        }
    except Exception as e:
        # Endpoint boundary: keep the traceback in the log, report a 500 to the client.
        logger.exception(f"Error getting performance stats: {str(e)}")
        raise HTTPException(status_code=500, detail=f"获取性能统计失败: {str(e)}") from e
|
|
|
|
|
|
|
|
@router.post("/api/v1/embedding/encode", response_model=EncodeResponse)
async def encode_texts(request: EncodeRequest):
    """Encode a list of texts into embedding vectors.

    Args:
        request: Encoding request carrying ``texts`` and ``batch_size``.

    Returns:
        An ``EncodeResponse``. Failures (empty ``texts``, a non-positive
        ``batch_size``, or any exception raised while encoding) are reported
        in the response body with ``success=False`` and ``error`` set; this
        handler does not raise for them.
    """

    def failure(message: str) -> EncodeResponse:
        # Single definition of the failure payload so every error path agrees.
        return EncodeResponse(
            success=False,
            embeddings=[],
            shape=[0, 0],
            processing_time=0.0,
            total_texts=len(request.texts),
            error=message,
        )

    # Validate before touching the model manager so a bad request is rejected
    # without doing any model work.
    if not request.texts:
        return failure("texts 不能为空")
    if request.batch_size < 1:
        # A zero or negative batch size cannot form batches; reject it here
        # instead of passing it through to the model.
        return failure("batch_size 必须为正整数")

    try:
        model_manager = get_model_manager()

        # perf_counter is monotonic, so the measured duration cannot be
        # distorted (or go negative) if the system clock is adjusted.
        started = time.perf_counter()

        # Encode the texts through the model manager.
        embeddings = await model_manager.encode_texts(
            request.texts,
            batch_size=request.batch_size,
        )

        processing_time = time.perf_counter() - started

        # Convert to plain nested lists for the JSON response.
        # NOTE(review): relies on the result exposing .tolist() and .shape
        # (presumably a NumPy-style array) - confirm against the model manager.
        return EncodeResponse(
            success=True,
            embeddings=embeddings.tolist(),
            shape=list(embeddings.shape),
            processing_time=processing_time,
            total_texts=len(request.texts),
        )
    except Exception as e:
        # Endpoint boundary: keep the traceback in the log and report the
        # failure in the response body.
        logger.exception(f"文本编码 API 错误: {e}")
        return failure(f"编码失败: {str(e)}")
|