276 lines
9.0 KiB
Python
276 lines
9.0 KiB
Python
import os
|
||
import time
|
||
import multiprocessing
|
||
from typing import Optional
|
||
from fastapi import APIRouter, HTTPException
|
||
from pydantic import BaseModel
|
||
|
||
from utils import (
|
||
get_global_connection_pool, init_global_connection_pool,
|
||
get_global_file_cache, init_global_file_cache,
|
||
setup_system_optimizations
|
||
)
|
||
from agent.sharded_agent_manager import init_global_sharded_agent_manager
|
||
try:
    from utils.system_optimizer import apply_optimization_profile
except ImportError:
    # The optimizer module is optional: when it cannot be imported, install a
    # stub with the same call signature so the module still imports and the
    # optimize endpoint reports the module as unavailable.
    def apply_optimization_profile(profile):
        """Stand-in for ``utils.system_optimizer.apply_optimization_profile``.

        Applies nothing; returns a status dict echoing *profile*.
        """
        unavailable_status = "system_optimizer not available"
        return {"profile": profile, "status": unavailable_status}
|
||
from utils.fastapi_utils import get_content_from_messages
|
||
from embedding import get_model_manager
|
||
from pydantic import BaseModel
|
||
import logging
|
||
|
||
# Named application logger shared by every handler in this module.
logger = logging.getLogger("app")

# Router collecting the system-management and embedding endpoints below;
# presumably mounted on the FastAPI app elsewhere.
router = APIRouter()
|
||
|
||
|
||
# Request body for POST /api/v1/embedding/encode.
class EncodeRequest(BaseModel):
    # Texts to embed; an empty list is rejected in-band by the handler.
    texts: list[str]
    # Batch size forwarded unchanged to the model manager's encode_texts.
    batch_size: int = 32
|
||
|
||
|
||
# Response body for POST /api/v1/embedding/encode. Failures are reported
# in-band (success=False plus `error`) rather than via an HTTP error status.
class EncodeResponse(BaseModel):
    # Whether encoding succeeded.
    success: bool
    # One embedding vector per input text (empty on failure).
    embeddings: list[list[float]]
    # Shape of the embedding matrix, e.g. [n_texts, dim]; [0, 0] on failure.
    shape: list[int]
    # Seconds spent in the encode call (0.0 on failure).
    processing_time: float
    # Number of input texts.
    total_texts: int
    # Human-readable error message when success is False.
    error: Optional[str] = None
|
||
|
||
|
||
# --- System optimization setup ---------------------------------------------
# Runs once at import time: importing this module creates the global agent
# manager, connection pool and file cache as a side effect.
logger.info("正在初始化系统优化...")
system_optimizer = setup_system_optimizations()

# Global agent manager configuration (overridable via environment variables).
max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "50"))  # enlarged cache size
shard_count = int(os.getenv("SHARD_COUNT", "16"))  # number of shards

# Sharded global agent manager.
agent_manager = init_global_sharded_agent_manager(
    max_cached_agents=max_cached_agents,
    shard_count=shard_count,
)

# Connection-pool settings. Parsed once into named variables so that the
# startup log below reports the values actually in effect rather than the
# defaults.
max_connections_per_host = int(os.getenv("MAX_CONNECTIONS_PER_HOST", "100"))
max_connections_total = int(os.getenv("MAX_CONNECTIONS_TOTAL", "500"))
keepalive_timeout = int(os.getenv("KEEPALIVE_TIMEOUT", "30"))
connect_timeout = int(os.getenv("CONNECT_TIMEOUT", "10"))
total_timeout = int(os.getenv("TOTAL_TIMEOUT", "60"))

connection_pool = init_global_connection_pool(
    max_connections_per_host=max_connections_per_host,
    max_connections_total=max_connections_total,
    keepalive_timeout=keepalive_timeout,
    connect_timeout=connect_timeout,
    total_timeout=total_timeout,
)

# File-cache settings (same rationale as above).
file_cache_size = int(os.getenv("FILE_CACHE_SIZE", "1000"))
file_cache_ttl = int(os.getenv("FILE_CACHE_TTL", "300"))

file_cache = init_global_file_cache(
    cache_size=file_cache_size,
    ttl=file_cache_ttl,
)

logger.info("系统优化初始化完成")
logger.info(f"- 分片Agent管理器: {shard_count} 个分片,最多缓存 {max_cached_agents} 个agent")
logger.info(f"- 连接池: 每主机{max_connections_per_host}连接,总计{max_connections_total}连接")
logger.info(f"- 文件缓存: {file_cache_size}个文件,TTL {file_cache_ttl}秒")
|
||
|
||
|
||
@router.get("/api/health")
async def health_check():
    """Return a fixed status message confirming the API process answers requests."""
    status_payload = {"message": "Database Assistant API is running"}
    return status_payload
|
||
|
||
|
||
@router.get("/api/v1/system/performance")
async def get_performance_stats():
    """Return a snapshot of runtime performance statistics.

    The payload aggregates:
      * agent-manager cache stats from ``agent_manager.get_cache_stats()``;
      * connection-pool settings, read from the same environment variables
        and defaults used when the pool is created at import time;
      * file-cache occupancy and limits;
      * host resources: CPU count always, plus memory/disk usage when
        ``psutil`` is importable.

    Raises:
        HTTPException: 500 if collecting any of the statistics fails.
    """
    try:
        agent_stats = agent_manager.get_cache_stats()

        # Report the configured values rather than hard-coded numbers, so the
        # output stays accurate when the pool is tuned via environment variables.
        pool_stats = {
            "connection_pool": "active",
            "max_connections_per_host": int(os.getenv("MAX_CONNECTIONS_PER_HOST", "100")),
            "max_connections_total": int(os.getenv("MAX_CONNECTIONS_TOTAL", "500")),
            "keepalive_timeout": int(os.getenv("KEEPALIVE_TIMEOUT", "30")),
        }

        # The file cache's internals are probed defensively; the fallbacks
        # mirror the FILE_CACHE_SIZE / FILE_CACHE_TTL defaults.
        file_cache_stats = {
            "cache_size": len(getattr(file_cache, "_cache", ())),
            "max_cache_size": getattr(file_cache, "cache_size", 1000),
            "ttl": getattr(file_cache, "ttl", 300),
        }

        try:
            import psutil
        except ImportError:
            system_stats = {
                "cpu_count": multiprocessing.cpu_count(),
                "memory_info": "psutil not available",
            }
        else:
            # One snapshot, so that total/available/percent are mutually consistent.
            memory = psutil.virtual_memory()
            system_stats = {
                "cpu_count": multiprocessing.cpu_count(),
                "memory_total_gb": round(memory.total / (1024 ** 3), 2),
                "memory_available_gb": round(memory.available / (1024 ** 3), 2),
                "memory_percent": memory.percent,
                "disk_usage_percent": psutil.disk_usage("/").percent,
            }

        return {
            "success": True,
            "timestamp": int(time.time()),
            "performance": {
                "agent_manager": agent_stats,
                "connection_pool": pool_stats,
                "file_cache": file_cache_stats,
                "system": system_stats,
            },
        }

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Error getting performance stats: {str(e)}")
        raise HTTPException(status_code=500, detail=f"获取性能统计失败: {str(e)}") from e
|
||
|
||
|
||
@router.post("/api/v1/system/optimize")
async def optimize_system(profile: str = "balanced"):
    """Apply a named optimization profile and echo the resulting configuration.

    Args:
        profile: profile name passed unchanged to ``apply_optimization_profile``.

    Returns:
        ``{"success": True, "message": ..., "config": <profile result>}``.

    Raises:
        HTTPException: 500 when applying the profile raises.
    """
    try:
        applied_config = apply_optimization_profile(profile)
    except Exception as exc:
        logger.error(f"Error applying optimization profile: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"应用优化配置失败: {str(exc)}")

    return {
        "success": True,
        "message": f"已应用 {profile} 优化配置",
        "config": applied_config,
    }
|
||
|
||
|
||
@router.post("/api/v1/system/clear-cache")
async def clear_system_cache(cache_type: Optional[str] = None):
    """Clear in-memory caches.

    Args:
        cache_type: ``"agent"`` clears only the agent cache, ``"file"`` only
            the file cache, and ``None`` (the default) clears both.

    Returns:
        A dict with ``success``, a message, and ``cleared_counts``.
        ``"agent_cache"`` holds whatever ``agent_manager.clear_cache()``
        returns (presumably the number of evicted agents). ``"file_cache"``
        holds the number of file-cache entries removed; it is present only
        when the cache exposes ``_cache``.

    Raises:
        HTTPException: 400 for an unrecognized ``cache_type``; 500 if clearing fails.
    """
    # Validate before entering the try block: the generic handler below would
    # otherwise convert this client error into a 500. Without this check an
    # unknown type cleared nothing but still reported success.
    if cache_type not in (None, "agent", "file"):
        raise HTTPException(
            status_code=400,
            detail=f"不支持的缓存类型: {cache_type}(可选: agent, file)",
        )

    try:
        cleared_counts = {}

        if cache_type in (None, "agent"):
            cleared_counts["agent_cache"] = agent_manager.clear_cache()

        # NOTE(review): this reaches into the cache's private dict; if the
        # cache guards it with a lock, clearing here bypasses that lock.
        if cache_type in (None, "file") and hasattr(file_cache, "_cache"):
            cleared_counts["file_cache"] = len(file_cache._cache)
            file_cache._cache.clear()

        return {
            "success": True,
            "message": "已清理指定类型的缓存",
            "cleared_counts": cleared_counts,
        }

    except Exception as e:
        logger.error(f"Error clearing cache: {str(e)}")
        raise HTTPException(status_code=500, detail=f"清理缓存失败: {str(e)}") from e
|
||
|
||
|
||
@router.get("/api/v1/system/config")
async def get_system_config():
    """Report the current system configuration.

    ``max_cached_agents`` and ``shard_count`` are the integers parsed when
    the module was imported. The remaining entries are re-read from the
    environment on every call and returned as raw strings, falling back to
    the defaults shown.

    Raises:
        HTTPException: 500 if building the payload fails.
    """
    env = os.getenv
    try:
        current_config = {
            "max_cached_agents": max_cached_agents,
            "shard_count": shard_count,
            "tokenizer_parallelism": env("TOKENIZERS_PARALLELISM", "true"),
            "max_connections_per_host": env("MAX_CONNECTIONS_PER_HOST", "100"),
            "max_connections_total": env("MAX_CONNECTIONS_TOTAL", "500"),
            "file_cache_size": env("FILE_CACHE_SIZE", "1000"),
            "file_cache_ttl": env("FILE_CACHE_TTL", "300"),
        }
        return {"success": True, "config": current_config}
    except Exception as exc:
        logger.error(f"Error getting system config: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"获取系统配置失败: {str(exc)}")
|
||
|
||
|
||
@router.post("/system/remove-project-cache")
async def remove_project_cache(dataset_id: str):
    """Evict cached agents associated with one project/dataset.

    NOTE(review): unlike the sibling endpoints this route has no ``/api/v1``
    prefix; it is kept as-is so that existing callers keep working.

    Args:
        dataset_id: identifier passed to ``agent_manager.remove_cache_by_unique_id``.

    Returns:
        A message plus ``removed_count``; ``removed_count`` is 0 when nothing
        was cached for ``dataset_id``.

    Raises:
        HTTPException: 500 if eviction fails.
    """
    try:
        removed_count = agent_manager.remove_cache_by_unique_id(dataset_id)
        if removed_count > 0:
            return {"message": f"项目缓存移除成功: {dataset_id}", "removed_count": removed_count}
        return {"message": f"未找到项目缓存: {dataset_id}", "removed_count": 0}
    except Exception as e:
        # Log server-side like every other handler in this module; previously
        # the failure surfaced only in the HTTP response.
        logger.error(f"Error removing project cache for {dataset_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}") from e
|
||
|
||
|
||
@router.post("/api/v1/embedding/encode", response_model=EncodeResponse)
async def encode_texts(request: EncodeRequest):
    """Encode a batch of texts into embedding vectors.

    Args:
        request: encode request carrying ``texts`` and ``batch_size``.

    Returns:
        EncodeResponse. Errors are reported in-band (``success=False`` with
        ``error`` set) rather than as an HTTP error status. This covers both
        an empty ``texts`` list and any exception raised while encoding.
    """
    try:
        model_manager = get_model_manager()

        if not request.texts:
            return EncodeResponse(
                success=False,
                embeddings=[],
                shape=[0, 0],
                processing_time=0.0,
                total_texts=0,
                error="texts 不能为空",
            )

        # perf_counter is monotonic and high-resolution. time.time() follows
        # the wall clock and can jump (NTP/manual adjustments), which yields
        # wrong or even negative durations.
        start_time = time.perf_counter()
        embeddings = await model_manager.encode_texts(
            request.texts,
            batch_size=request.batch_size,
        )
        processing_time = time.perf_counter() - start_time

        # Convert to nested Python lists for JSON serialization.
        # NOTE(review): assumes encode_texts returns a numpy-like array that
        # exposes .tolist() and .shape — confirm against the model manager.
        embeddings_list = embeddings.tolist()

        return EncodeResponse(
            success=True,
            embeddings=embeddings_list,
            shape=list(embeddings.shape),
            processing_time=processing_time,
            total_texts=len(request.texts),
        )

    except Exception as e:
        # logger.exception keeps the traceback for diagnosing encode failures.
        logger.exception(f"文本编码 API 错误: {e}")
        return EncodeResponse(
            success=False,
            embeddings=[],
            shape=[0, 0],
            processing_time=0.0,
            total_texts=len(request.texts) if request else 0,
            error=f"编码失败: {str(e)}",
        )
|