import logging
import multiprocessing
import os
import time
from typing import Optional

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from utils import (
    get_global_connection_pool,
    init_global_connection_pool,
    get_global_file_cache,
    init_global_file_cache,
    setup_system_optimizations,
)
from agent.sharded_agent_manager import init_global_sharded_agent_manager

try:
    from utils.system_optimizer import apply_optimization_profile
except ImportError:
    def apply_optimization_profile(profile):
        """Fallback used when utils.system_optimizer cannot be imported."""
        return {"profile": profile, "status": "system_optimizer not available"}

from embedding import get_model_manager
from utils.settings import (
    MAX_CACHED_AGENTS,
    SHARD_COUNT,
    MAX_CONNECTIONS_PER_HOST,
    MAX_CONNECTIONS_TOTAL,
    KEEPALIVE_TIMEOUT,
    CONNECT_TIMEOUT,
    TOTAL_TIMEOUT,
    FILE_CACHE_SIZE,
    FILE_CACHE_TTL,
    TOKENIZERS_PARALLELISM,
)

logger = logging.getLogger('app')

router = APIRouter()

# Cache types accepted by /api/v1/system/clear-cache; omitting the parameter clears all of them.
_CLEARABLE_CACHE_TYPES = ("agent", "file")


class EncodeRequest(BaseModel):
    """Request body for /api/v1/embedding/encode."""

    texts: list[str]
    batch_size: int = 32


class EncodeResponse(BaseModel):
    """Response body for /api/v1/embedding/encode; failures are reported in-band via `error`."""

    success: bool
    embeddings: list[list[float]]
    shape: list[int]
    processing_time: float
    total_texts: int
    error: Optional[str] = None


# Module-level initialisation of the shared runtime components (runs once at import time).
logger.info("正在初始化系统优化...")
system_optimizer = setup_system_optimizations()

# Agent-manager sizing, taken from settings.
max_cached_agents = MAX_CACHED_AGENTS
shard_count = SHARD_COUNT

agent_manager = init_global_sharded_agent_manager(
    max_cached_agents=max_cached_agents,
    shard_count=shard_count,
)

connection_pool = init_global_connection_pool(
    max_connections_per_host=MAX_CONNECTIONS_PER_HOST,
    max_connections_total=MAX_CONNECTIONS_TOTAL,
    keepalive_timeout=KEEPALIVE_TIMEOUT,
    connect_timeout=CONNECT_TIMEOUT,
    total_timeout=TOTAL_TIMEOUT,
)

file_cache = init_global_file_cache(
    cache_size=FILE_CACHE_SIZE,
    ttl=FILE_CACHE_TTL,
)

logger.info("系统优化初始化完成")
logger.info(f"- 分片Agent管理器: {shard_count} 个分片,最多缓存 {max_cached_agents} 个agent")
# Report the values actually passed to the initialisers above, not hard-coded literals.
logger.info(f"- 连接池: 每主机{MAX_CONNECTIONS_PER_HOST}连接,总计{MAX_CONNECTIONS_TOTAL}连接")
logger.info(f"- 文件缓存: {FILE_CACHE_SIZE}个文件,TTL {FILE_CACHE_TTL}秒")


@router.get("/api/health")
async def health_check():
    """Health check endpoint: report that the API process is up."""
    return {"message": "Database Assistant API is running"}


def _collect_system_stats() -> dict:
    """Return host CPU/memory/disk figures; only the CPU count when psutil is missing."""
    try:
        import psutil
    except ImportError:
        return {
            "cpu_count": multiprocessing.cpu_count(),
            "memory_info": "psutil not available",
        }

    # Sample memory once so total/available/percent describe the same instant.
    memory = psutil.virtual_memory()
    return {
        "cpu_count": multiprocessing.cpu_count(),
        "memory_total_gb": round(memory.total / (1024 ** 3), 2),
        "memory_available_gb": round(memory.available / (1024 ** 3), 2),
        "memory_percent": memory.percent,
        "disk_usage_percent": psutil.disk_usage('/').percent,
    }


@router.get("/api/v1/system/performance")
async def get_performance_stats():
    """Return runtime statistics for the agent manager, connection pool, file cache and host.

    Raises:
        HTTPException: 500 if collecting any of the statistics fails.
    """
    try:
        agent_stats = agent_manager.get_cache_stats()

        # Simplified view: report the configured pool limits rather than live counters.
        pool_stats = {
            "connection_pool": "active",
            "max_connections_per_host": MAX_CONNECTIONS_PER_HOST,
            "max_connections_total": MAX_CONNECTIONS_TOTAL,
            "keepalive_timeout": KEEPALIVE_TIMEOUT,
        }

        # The file cache's attributes are probed defensively; fall back to the configured values.
        file_cache_stats = {
            "cache_size": len(file_cache._cache) if hasattr(file_cache, '_cache') else 0,
            "max_cache_size": getattr(file_cache, 'cache_size', FILE_CACHE_SIZE),
            "ttl": getattr(file_cache, 'ttl', FILE_CACHE_TTL),
        }

        return {
            "success": True,
            "timestamp": int(time.time()),
            "performance": {
                "agent_manager": agent_stats,
                "connection_pool": pool_stats,
                "file_cache": file_cache_stats,
                "system": _collect_system_stats(),
            },
        }
    except Exception as e:
        logger.exception(f"Error getting performance stats: {e}")
        raise HTTPException(status_code=500, detail=f"获取性能统计失败: {str(e)}")


@router.post("/api/v1/system/optimize")
async def optimize_system(profile: str = "balanced"):
    """Apply the named system optimisation profile and return the resulting config.

    Raises:
        HTTPException: 500 if applying the profile fails.
    """
    try:
        config = apply_optimization_profile(profile)
        return {
            "success": True,
            "message": f"已应用 {profile} 优化配置",
            "config": config,
        }
    except Exception as e:
        logger.exception(f"Error applying optimization profile: {e}")
        raise HTTPException(status_code=500, detail=f"应用优化配置失败: {str(e)}")


@router.post("/api/v1/system/clear-cache")
async def clear_system_cache(cache_type: Optional[str] = None):
    """Clear the agent cache, the file cache, or both.

    Args:
        cache_type: "agent", "file", or None to clear every cache.

    Returns:
        A dict whose ``cleared_counts`` maps each cleared cache to the number of
        entries it held.

    Raises:
        HTTPException: 400 for an unknown ``cache_type``; 500 if clearing fails.
    """
    # Previously an unknown type cleared nothing yet reported success.
    # Raised outside the try so it is not converted into a 500.
    if cache_type is not None and cache_type not in _CLEARABLE_CACHE_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的缓存类型: {cache_type},可选值: {', '.join(_CLEARABLE_CACHE_TYPES)}",
        )

    try:
        cleared_counts = {}

        if cache_type in (None, "agent"):
            cleared_counts["agent_cache"] = agent_manager.clear_cache()

        # The file cache offers no public clear here, so its private dict is emptied directly.
        if cache_type in (None, "file") and hasattr(file_cache, '_cache'):
            file_count = len(file_cache._cache)
            file_cache._cache.clear()
            cleared_counts["file_cache"] = file_count

        return {
            "success": True,
            "message": "已清理指定类型的缓存",
            "cleared_counts": cleared_counts,
        }
    except Exception as e:
        logger.exception(f"Error clearing cache: {e}")
        raise HTTPException(status_code=500, detail=f"清理缓存失败: {str(e)}")


@router.get("/api/v1/system/config")
async def get_system_config():
    """Return the effective system configuration values loaded from settings.

    Raises:
        HTTPException: 500 if building the response fails.
    """
    try:
        return {
            "success": True,
            "config": {
                "max_cached_agents": max_cached_agents,
                "shard_count": shard_count,
                "tokenizer_parallelism": TOKENIZERS_PARALLELISM,
                # Stringified for response compatibility with existing clients.
                "max_connections_per_host": str(MAX_CONNECTIONS_PER_HOST),
                "max_connections_total": str(MAX_CONNECTIONS_TOTAL),
                "file_cache_size": str(FILE_CACHE_SIZE),
                "file_cache_ttl": str(FILE_CACHE_TTL),
            },
        }
    except Exception as e:
        logger.exception(f"Error getting system config: {e}")
        raise HTTPException(status_code=500, detail=f"获取系统配置失败: {str(e)}")


@router.post("/system/remove-project-cache")
async def remove_project_cache(dataset_id: str):
    """Remove cached agents associated with ``dataset_id``.

    Returns:
        A message plus ``removed_count`` (0 when nothing was cached for the id).

    Raises:
        HTTPException: 500 if removal fails.
    """
    try:
        removed_count = agent_manager.remove_cache_by_unique_id(dataset_id)
    except Exception as e:
        # Log like the other endpoints instead of failing silently server-side.
        logger.exception(f"Error removing project cache for {dataset_id}: {e}")
        raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}")

    if removed_count > 0:
        return {"message": f"项目缓存移除成功: {dataset_id}", "removed_count": removed_count}
    return {"message": f"未找到项目缓存: {dataset_id}", "removed_count": 0}


def _encode_error_response(total_texts: int, error: str) -> EncodeResponse:
    """Build the in-band failure payload returned by encode_texts."""
    return EncodeResponse(
        success=False,
        embeddings=[],
        shape=[0, 0],
        processing_time=0.0,
        total_texts=total_texts,
        error=error,
    )


@router.post("/api/v1/embedding/encode", response_model=EncodeResponse)
async def encode_texts(request: EncodeRequest):
    """Encode texts into embedding vectors.

    Failures are reported in the response body (``success=False`` plus ``error``)
    rather than as HTTP error statuses.

    Args:
        request: The texts to encode and the batch size to use.

    Returns:
        EncodeResponse with the embeddings, their shape, and the encode time in seconds.
    """
    if not request.texts:
        return _encode_error_response(0, "texts 不能为空")
    # A non-positive batch size cannot form valid batches; reject it up front.
    if request.batch_size <= 0:
        return _encode_error_response(len(request.texts), "batch_size 必须大于 0")

    try:
        model_manager = get_model_manager()

        # perf_counter is monotonic, so the measured duration is immune to clock adjustments.
        start_time = time.perf_counter()
        embeddings = await model_manager.encode_texts(
            request.texts,
            batch_size=request.batch_size,
        )
        processing_time = time.perf_counter() - start_time

        return EncodeResponse(
            success=True,
            embeddings=embeddings.tolist(),
            shape=list(embeddings.shape),
            processing_time=processing_time,
            total_texts=len(request.texts),
        )
    except Exception as e:
        logger.exception(f"文本编码 API 错误: {e}")
        return _encode_error_response(len(request.texts), f"编码失败: {str(e)}")