remove fukes
This commit is contained in:
parent
655b702383
commit
e35d80ed64
159
fastapi_app.py
159
fastapi_app.py
@ -1,11 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from typing import AsyncGenerator, Dict, List, Optional, Union
|
from typing import AsyncGenerator, Dict, List, Optional, Union
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import BackgroundTasks, FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from qwen_agent.llm.schema import ASSISTANT, FUNCTION
|
from qwen_agent.llm.schema import ASSISTANT, FUNCTION
|
||||||
|
|
||||||
@ -38,43 +39,26 @@ def get_content_from_messages(messages: List[dict]) -> str:
|
|||||||
|
|
||||||
return full_text
|
return full_text
|
||||||
|
|
||||||
from agent_pool import (get_agent_from_pool, init_global_agent_pool,
|
from file_loaded_agent_manager import get_global_agent_manager, init_global_agent_manager
|
||||||
release_agent_to_pool)
|
from gbase_agent import update_agent_llm
|
||||||
from gbase_agent import init_agent_service_universal, update_agent_llm
|
|
||||||
from zip_project_handler import zip_handler
|
from zip_project_handler import zip_handler
|
||||||
|
|
||||||
# 全局助手实例池,在应用启动时初始化
|
# 全局助手管理器配置
|
||||||
agent_pool_size = int(os.getenv("AGENT_POOL_SIZE", "1"))
|
max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "20"))
|
||||||
|
|
||||||
|
# 初始化全局助手管理器
|
||||||
|
agent_manager = init_global_agent_manager(max_cached_agents=max_cached_agents)
|
||||||
|
|
||||||
@asynccontextmanager
|
app = FastAPI(title="Database Assistant API", version="1.0.0")
|
||||||
async def lifespan(app: FastAPI):
|
|
||||||
"""应用生命周期管理"""
|
|
||||||
# 启动时初始化助手实例池
|
|
||||||
print(f"正在启动FastAPI应用,初始化助手实例池(大小: {agent_pool_size})...")
|
|
||||||
|
|
||||||
try:
|
# 添加CORS中间件,支持前端页面
|
||||||
def agent_factory():
|
app.add_middleware(
|
||||||
return init_agent_service_universal()
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # 在生产环境中应该设置为具体的前端域名
|
||||||
await init_global_agent_pool(pool_size=agent_pool_size, agent_factory=agent_factory)
|
allow_credentials=True,
|
||||||
print("助手实例池初始化完成!")
|
allow_methods=["*"],
|
||||||
yield
|
allow_headers=["*"],
|
||||||
except Exception as e:
|
)
|
||||||
print(f"助手实例池初始化失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# 关闭时清理实例池
|
|
||||||
print("正在关闭应用,清理助手实例池...")
|
|
||||||
|
|
||||||
from agent_pool import get_agent_pool
|
|
||||||
pool = get_agent_pool()
|
|
||||||
if pool:
|
|
||||||
await pool.shutdown()
|
|
||||||
print("助手实例池清理完成!")
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Database Assistant API", version="1.0.0", lifespan=lifespan)
|
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
@ -178,12 +162,11 @@ async def chat_completions(request: ChatRequest):
|
|||||||
Chat completions API similar to OpenAI, supports both streaming and non-streaming
|
Chat completions API similar to OpenAI, supports both streaming and non-streaming
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: ChatRequest containing messages, model, project_id in extra field, etc.
|
request: ChatRequest containing messages, model, zip_url, etc.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Union[ChatResponse, StreamingResponse]: Chat completion response or stream
|
Union[ChatResponse, StreamingResponse]: Chat completion response or stream
|
||||||
"""
|
"""
|
||||||
agent = None
|
|
||||||
try:
|
try:
|
||||||
# 从最外层获取zip_url参数
|
# 从最外层获取zip_url参数
|
||||||
zip_url = request.zip_url
|
zip_url = request.zip_url
|
||||||
@ -197,11 +180,21 @@ async def chat_completions(request: ChatRequest):
|
|||||||
if not project_dir:
|
if not project_dir:
|
||||||
raise HTTPException(status_code=400, detail=f"Failed to load project from ZIP URL: {zip_url}")
|
raise HTTPException(status_code=400, detail=f"Failed to load project from ZIP URL: {zip_url}")
|
||||||
|
|
||||||
# 从实例池获取助手实例
|
# 收集项目目录下所有的 document.txt 文件
|
||||||
agent = await get_agent_from_pool(timeout=30.0)
|
document_files = zip_handler.collect_document_files(project_dir)
|
||||||
|
|
||||||
# 动态设置请求的模型,支持从接口传入api_key、model_server和extra参数
|
if not document_files:
|
||||||
update_agent_llm(agent, request.model, request.api_key, request.model_server, request.generate_cfg)
|
print(f"警告: 项目目录 {project_dir} 中未找到任何 document.txt 文件")
|
||||||
|
|
||||||
|
# 从全局管理器获取或创建文件预加载的助手实例
|
||||||
|
agent = await agent_manager.get_or_create_agent(
|
||||||
|
zip_url=zip_url,
|
||||||
|
files=document_files,
|
||||||
|
model_name=request.model,
|
||||||
|
api_key=request.api_key,
|
||||||
|
model_server=request.model_server,
|
||||||
|
generate_cfg=request.generate_cfg
|
||||||
|
)
|
||||||
|
|
||||||
extra_prompt = request.extra_prompt if request.extra_prompt else ""
|
extra_prompt = request.extra_prompt if request.extra_prompt else ""
|
||||||
# 构建包含项目信息的消息上下文
|
# 构建包含项目信息的消息上下文
|
||||||
@ -209,7 +202,7 @@ async def chat_completions(request: ChatRequest):
|
|||||||
# 项目信息系统消息
|
# 项目信息系统消息
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"当前项目来自ZIP URL: {zip_url},项目目录: {project_dir}。所有文件路径中的 '[当前数据目录]' 请替换为: {project_dir}\n"+ extra_prompt
|
"content": f"当前项目来自ZIP URL: {zip_url},项目目录: {project_dir}。已加载 {len(document_files)} 个 document.txt 文件用于检索。\n" + extra_prompt
|
||||||
},
|
},
|
||||||
# 用户消息批量转换
|
# 用户消息批量转换
|
||||||
*[{"role": msg.role, "content": msg.content} for msg in request.messages]
|
*[{"role": msg.role, "content": msg.content} for msg in request.messages]
|
||||||
@ -263,16 +256,18 @@ async def chat_completions(request: ChatRequest):
|
|||||||
print(f"Error in chat_completions: {str(e)}")
|
print(f"Error in chat_completions: {str(e)}")
|
||||||
print(f"Full traceback: {error_details}")
|
print(f"Full traceback: {error_details}")
|
||||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||||
finally:
|
|
||||||
# 确保释放助手实例回池
|
|
||||||
if agent is not None:
|
|
||||||
await release_agent_to_pool(agent)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def root():
|
async def root():
|
||||||
|
"""Chat page endpoint"""
|
||||||
|
return FileResponse("chat.html", media_type="text/html")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/health")
|
||||||
|
async def health_check():
|
||||||
"""Health check endpoint"""
|
"""Health check endpoint"""
|
||||||
return {"message": "Database Assistant API is running"}
|
return {"message": "Database Assistant API is running"}
|
||||||
|
|
||||||
@ -280,33 +275,81 @@ async def root():
|
|||||||
@app.get("/system/status")
|
@app.get("/system/status")
|
||||||
async def system_status():
|
async def system_status():
|
||||||
"""获取系统状态信息"""
|
"""获取系统状态信息"""
|
||||||
from agent_pool import get_agent_pool
|
# 获取助手缓存统计
|
||||||
|
cache_stats = agent_manager.get_cache_stats()
|
||||||
pool = get_agent_pool()
|
|
||||||
pool_stats = pool.get_pool_stats() if pool else {"pool_size": 0, "available_agents": 0, "total_agents": 0, "in_use_agents": 0}
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "running",
|
"status": "running",
|
||||||
"storage_type": "Agent Pool API",
|
"storage_type": "File-Loaded Agent Manager",
|
||||||
"agent_pool": {
|
"max_cached_agents": max_cached_agents,
|
||||||
"pool_size": pool_stats["pool_size"],
|
"agent_cache": {
|
||||||
"available_agents": pool_stats["available_agents"],
|
"total_cached_agents": cache_stats["total_cached_agents"],
|
||||||
"total_agents": pool_stats["total_agents"],
|
"max_cached_agents": cache_stats["max_cached_agents"],
|
||||||
"in_use_agents": pool_stats["in_use_agents"]
|
"cached_agents": cache_stats["agents"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/system/cleanup-cache")
|
@app.post("/system/cleanup-cache")
|
||||||
async def cleanup_cache():
|
async def cleanup_cache():
|
||||||
"""清理ZIP文件缓存"""
|
"""清理ZIP文件缓存和助手缓存"""
|
||||||
try:
|
try:
|
||||||
|
# 清理ZIP文件缓存
|
||||||
zip_handler.cleanup_cache()
|
zip_handler.cleanup_cache()
|
||||||
return {"message": "缓存清理成功"}
|
|
||||||
|
# 清理助手实例缓存
|
||||||
|
cleared_count = agent_manager.clear_cache()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": "缓存清理成功",
|
||||||
|
"cleared_zip_files": True,
|
||||||
|
"cleared_agent_instances": cleared_count
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"缓存清理失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"缓存清理失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/system/cleanup-agent-cache")
async def cleanup_agent_cache():
    """Clear only the cached assistant instances (the ZIP file cache is untouched).

    Returns:
        dict: A status message plus ``cleared_agent_instances``, the value
        returned by ``agent_manager.clear_cache()``.

    Raises:
        HTTPException: 500 if clearing the cache raises.
    """
    # Keep the try body down to the single call that can fail.
    try:
        cleared_count = agent_manager.clear_cache()
    except Exception as e:
        # Chain the original error so the logged traceback keeps the root cause.
        raise HTTPException(status_code=500, detail=f"助手实例缓存清理失败: {str(e)}") from e

    return {
        "message": "助手实例缓存清理成功",
        "cleared_agent_instances": cleared_count
    }
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/system/cached-projects")
async def get_cached_projects():
    """Report which ZIP URLs currently have a cached assistant instance.

    Returns:
        dict: ``cached_projects`` (from ``agent_manager.list_cached_zip_urls()``)
        and ``cache_stats`` (from ``agent_manager.get_cache_stats()``).

    Raises:
        HTTPException: 500 if either manager call raises.
    """
    try:
        cached_urls = agent_manager.list_cached_zip_urls()
        cache_stats = agent_manager.get_cache_stats()
    except Exception as e:
        # Chain the original error so the logged traceback keeps the root cause.
        raise HTTPException(status_code=500, detail=f"获取缓存项目信息失败: {str(e)}") from e

    return {
        "cached_projects": cached_urls,
        "cache_stats": cache_stats
    }
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/system/remove-project-cache")
async def remove_project_cache(zip_url: str):
    """Drop the cached assistant instance for one ZIP URL.

    Args:
        zip_url: The ZIP URL whose cache entry should be removed.

    Returns:
        dict: ``message`` plus a boolean ``removed`` flag. The flag is present
        in both outcomes so clients can branch on it without parsing the
        message text.

    Raises:
        HTTPException: 500 if the removal itself raises.
    """
    try:
        removed = agent_manager.remove_cache_by_url(zip_url)
    except Exception as e:
        # Chain the original error so the logged traceback keeps the root cause.
        raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}") from e

    if removed:
        return {"message": f"项目缓存移除成功: {zip_url}", "removed": True}
    return {"message": f"未找到项目缓存: {zip_url}", "removed": False}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
215
file_loaded_agent_manager.py
Normal file
215
file_loaded_agent_manager.py
Normal file
@ -0,0 +1,215 @@
|
|||||||
|
# Copyright 2023
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""文件预加载助手管理器 - 管理基于ZIP URL的助手实例缓存"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from qwen_agent.agents import Assistant
|
||||||
|
from qwen_agent.log import logger
|
||||||
|
|
||||||
|
from gbase_agent import init_agent_service_with_files, update_agent_llm
|
||||||
|
|
||||||
|
|
||||||
|
class FileLoadedAgentManager:
    """Bounded cache of assistant instances keyed by ZIP URL.

    One assistant is kept per ZIP URL so that repeated requests for the same
    project reuse it instead of building a new one. At most
    ``max_cached_agents`` entries are retained; when the bound is exceeded the
    least recently accessed entries are evicted first.
    """

    def __init__(self, max_cached_agents: int = 20):
        # The five dicts below all share the same key (see _get_zip_url_hash)
        # and are always added to / removed from together.
        self.agents: Dict[str, Assistant] = {}        # cache key -> assistant instance
        self.zip_urls: Dict[str, str] = {}            # cache key -> original ZIP URL
        self.access_times: Dict[str, float] = {}      # cache key -> last access (time.time())
        self.creation_times: Dict[str, float] = {}    # cache key -> creation time (time.time())
        self.file_counts: Dict[str, int] = {}         # cache key -> number of files passed at creation
        self.max_cached_agents = max_cached_agents

    def _get_zip_url_hash(self, zip_url: str) -> str:
        """Return the cache key for a ZIP URL: the first 16 hex digits of its MD5."""
        # MD5 is used only to derive a short, stable key, not for security.
        return hashlib.md5(zip_url.encode('utf-8')).hexdigest()[:16]

    def _update_access_time(self, cache_key: str):
        """Record "now" as the last access time of an entry (drives LRU order)."""
        self.access_times[cache_key] = time.time()

    def _lru_eviction_candidates(self) -> List[str]:
        """Return the keys that must go to get back within the size bound.

        Keys are ordered least recently accessed first. The list is empty when
        the cache is already within ``max_cached_agents``.
        """
        # Count the excess explicitly: slicing with [:-max] silently selects
        # nothing when max is 0.
        excess = len(self.agents) - max(self.max_cached_agents, 0)
        if excess <= 0:
            return []
        oldest_first = sorted(self.agents, key=lambda key: self.access_times.get(key, 0.0))
        return oldest_first[:excess]

    def _evict(self, cache_key: str) -> bool:
        """Remove one entry from every bookkeeping dict.

        Returns:
            bool: True if an assistant was cached under ``cache_key``.
        """
        existed = cache_key in self.agents
        # pop() with a default keeps the dicts in sync even if one of them is
        # already missing the key.
        for table in (self.agents, self.zip_urls, self.access_times,
                      self.creation_times, self.file_counts):
            table.pop(cache_key, None)
        return existed

    def _cleanup_old_agents(self):
        """Evict least recently used entries until the cache fits its bound."""
        removed_count = 0

        for cache_key in self._lru_eviction_candidates():
            if self._evict(cache_key):
                removed_count += 1
                logger.info(f"清理过期的助手实例缓存: {cache_key}")

        if removed_count > 0:
            logger.info(f"已清理 {removed_count} 个过期的助手实例缓存")

    async def get_or_create_agent(self,
                                  zip_url: str,
                                  files: List[str],
                                  model_name: str = "qwen3-next",
                                  api_key: Optional[str] = None,
                                  model_server: Optional[str] = None,
                                  generate_cfg: Optional[Dict] = None) -> "Assistant":
        """Return the cached assistant for ``zip_url``, creating it on a miss.

        On a cache hit the existing instance is returned after its LLM
        settings are refreshed through ``update_agent_llm``; ``files`` is only
        used for logging in that case. On a miss a new instance is built with
        ``init_agent_service_with_files`` and stored.

        Args:
            zip_url: URL of the project ZIP; it determines the cache key.
            files: File paths to hand to the new assistant.
            model_name: Model name.
            api_key: API key.
            model_server: Model server address.
            generate_cfg: Generation configuration.

        Returns:
            Assistant: The cached or newly created assistant instance.
        """
        cache_key = self._get_zip_url_hash(zip_url)

        # Cache hit: reuse the instance.
        if cache_key in self.agents:
            self._update_access_time(cache_key)
            agent = self.agents[cache_key]

            # Apply the LLM settings of the current request.
            update_agent_llm(agent, model_name, api_key, model_server, generate_cfg)

            logger.info(f"复用现有的助手实例缓存: {cache_key} (文件数: {len(files)})")
            return agent

        # Cache miss: build a new instance.
        logger.info(f"创建新的助手实例缓存: {cache_key}, 预加载文件数: {len(files)}")
        current_time = time.time()

        agent = init_agent_service_with_files(
            files=files,
            model_name=model_name,
            api_key=api_key,
            model_server=model_server,
            generate_cfg=generate_cfg
        )

        self.agents[cache_key] = agent
        self.zip_urls[cache_key] = zip_url
        self.access_times[cache_key] = current_time
        self.creation_times[cache_key] = current_time
        self.file_counts[cache_key] = len(files)

        logger.info(f"助手实例缓存创建完成: {cache_key}")

        # Enforce the bound AFTER inserting. Trimming before the insert let
        # the cache settle at max_cached_agents + 1 entries.
        self._cleanup_old_agents()
        return agent

    def get_cache_stats(self) -> Dict:
        """Return a snapshot of the cache.

        Returns:
            Dict: ``total_cached_agents``, ``max_cached_agents`` and ``agents``,
            a mapping of cache key to per-entry details (URL, file count,
            creation / last-access timestamps, age and idle time in seconds).
        """
        current_time = time.time()
        stats = {
            "total_cached_agents": len(self.agents),
            "max_cached_agents": self.max_cached_agents,
            "agents": {}
        }

        for cache_key in self.agents:
            created_at = self.creation_times.get(cache_key, current_time)
            last_accessed = self.access_times.get(cache_key, current_time)
            stats["agents"][cache_key] = {
                "zip_url": self.zip_urls.get(cache_key, "unknown"),
                "file_count": self.file_counts.get(cache_key, 0),
                "created_at": self.creation_times.get(cache_key, 0),
                "last_accessed": self.access_times.get(cache_key, 0),
                "age_seconds": int(current_time - created_at),
                "idle_seconds": int(current_time - last_accessed)
            }

        return stats

    def clear_cache(self) -> int:
        """Remove every cached entry.

        Returns:
            int: Number of assistant instances that were cached.
        """
        cache_count = len(self.agents)

        self.agents.clear()
        self.zip_urls.clear()
        self.access_times.clear()
        self.creation_times.clear()
        self.file_counts.clear()

        logger.info(f"已清空所有助手实例缓存,共清理 {cache_count} 个实例")
        return cache_count

    def remove_cache_by_url(self, zip_url: str) -> bool:
        """Remove the cache entry belonging to one ZIP URL.

        Args:
            zip_url: ZIP file URL.

        Returns:
            bool: True if an entry existed and was removed, False otherwise.
        """
        if not self._evict(self._get_zip_url_hash(zip_url)):
            return False

        logger.info(f"已移除特定 ZIP URL 的助手实例缓存: {zip_url}")
        return True

    def list_cached_zip_urls(self) -> List[str]:
        """Return the original ZIP URLs of all cached entries."""
        return list(self.zip_urls.values())
|
||||||
|
|
||||||
|
|
||||||
|
# 全局文件预加载助手管理器实例
|
||||||
|
_global_agent_manager: Optional[FileLoadedAgentManager] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_global_agent_manager() -> FileLoadedAgentManager:
    """Return the module-level manager, creating one with default settings on first use."""
    global _global_agent_manager
    manager = _global_agent_manager
    if manager is None:
        # Nothing has initialised the manager yet: fall back to the defaults.
        manager = FileLoadedAgentManager()
        _global_agent_manager = manager
    return manager
|
||||||
|
|
||||||
|
|
||||||
|
def init_global_agent_manager(max_cached_agents: int = 20):
    """Replace the module-level manager with a new one and return it.

    Args:
        max_cached_agents: Value passed to the ``FileLoadedAgentManager``
            constructor.

    Returns:
        The newly created manager. Any previously installed manager is no
        longer referenced by this module.
    """
    global _global_agent_manager
    manager = FileLoadedAgentManager(max_cached_agents)
    _global_agent_manager = manager
    return manager
|
||||||
@ -100,28 +100,58 @@ def init_agent_service_with_project(project_id: str, project_data_dir: str, mode
|
|||||||
|
|
||||||
def init_agent_service_universal():
|
def init_agent_service_universal():
|
||||||
"""创建无状态的通用助手实例(使用默认LLM,可动态切换)"""
|
"""创建无状态的通用助手实例(使用默认LLM,可动态切换)"""
|
||||||
|
return init_agent_service_with_files(files=None)
|
||||||
|
|
||||||
|
|
||||||
|
def init_agent_service_with_files(files: Optional[List[str]] = None, rag_cfg: Optional[Dict] = None,
|
||||||
|
model_name: str = "qwen3-next", api_key: Optional[str] = None,
|
||||||
|
model_server: Optional[str] = None, generate_cfg: Optional[Dict] = None):
|
||||||
|
"""创建支持预加载文件的助手实例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: 预加载的文件路径列表
|
||||||
|
rag_cfg: RAG配置参数
|
||||||
|
model_name: 模型名称
|
||||||
|
api_key: API 密钥
|
||||||
|
model_server: 模型服务器地址
|
||||||
|
generate_cfg: 生成配置
|
||||||
|
"""
|
||||||
# 读取通用的系统prompt(无状态)
|
# 读取通用的系统prompt(无状态)
|
||||||
system = read_system_prompt()
|
system = read_system_prompt()
|
||||||
|
|
||||||
# 读取基础的MCP工具配置(不包含项目限制)
|
# 读取基础的MCP工具配置(不包含项目限制)
|
||||||
tools = read_mcp_settings()
|
tools = read_mcp_settings()
|
||||||
|
|
||||||
# 创建默认的LLM配置(可以通过update_agent_llm动态更新)
|
# 创建LLM配置,使用传入的参数
|
||||||
llm_config = {
|
llm_config = {
|
||||||
"model": "qwen3-next", # 默认模型
|
"model": model_name,
|
||||||
"model_server": "https://openrouter.ai/api/v1", # 默认服务器
|
"api_key": api_key,
|
||||||
"api_key": "default-key" # 默认密钥,实际使用时需要通过API传入
|
"model_server": model_server,
|
||||||
|
"generate_cfg": generate_cfg if generate_cfg else {}
|
||||||
}
|
}
|
||||||
|
|
||||||
# 创建LLM实例
|
# 创建LLM实例
|
||||||
llm_instance = TextChatAtOAI(llm_config)
|
llm_instance = TextChatAtOAI(llm_config)
|
||||||
|
|
||||||
|
# 配置RAG参数以优化大量文件处理
|
||||||
|
default_rag_cfg = {
|
||||||
|
'max_ref_token': 8000, # 增加引用token限制
|
||||||
|
'parser_page_size': 1000, # 增加解析页面大小
|
||||||
|
'rag_keygen_strategy': 'SplitQueryThenGenKeyword', # 使用关键词生成策略
|
||||||
|
'rag_searchers': ['keyword_search', 'front_page_search'] # 混合搜索策略
|
||||||
|
}
|
||||||
|
|
||||||
|
# 合并用户提供的RAG配置
|
||||||
|
final_rag_cfg = {**default_rag_cfg, **(rag_cfg or {})}
|
||||||
|
|
||||||
bot = Assistant(
|
bot = Assistant(
|
||||||
llm=llm_instance, # 使用默认LLM初始化,可通过update_agent_llm动态更新
|
llm=llm_instance, # 使用默认LLM初始化,可通过update_agent_llm动态更新
|
||||||
name="通用数据检索助手",
|
name="数据检索助手",
|
||||||
description="无状态通用数据检索助手",
|
description="支持预加载文件的数据检索助手",
|
||||||
system_message=system,
|
system_message=system,
|
||||||
function_list=tools,
|
function_list=tools,
|
||||||
|
#files=files, # 预加载文件列表
|
||||||
|
#rag_cfg=final_rag_cfg, # RAG配置
|
||||||
)
|
)
|
||||||
|
|
||||||
return bot
|
return bot
|
||||||
|
|||||||
@ -6,7 +6,7 @@ uvicorn==0.35.0
|
|||||||
requests==2.32.5
|
requests==2.32.5
|
||||||
|
|
||||||
# Qwen Agent框架
|
# Qwen Agent框架
|
||||||
qwen-agent[mcp]==0.0.29
|
qwen-agent[rag,mcp]==0.0.29
|
||||||
|
|
||||||
# 数据处理
|
# 数据处理
|
||||||
pydantic==2.10.5
|
pydantic==2.10.5
|
||||||
|
|||||||
187
test_performance.py
Normal file
187
test_performance.py
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
性能测试脚本 - 验证 RAG 文件预加载优化效果
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from file_loaded_agent_manager import get_global_agent_manager
|
||||||
|
from zip_project_handler import zip_handler
|
||||||
|
from gbase_agent import init_agent_service_with_files
|
||||||
|
|
||||||
|
|
||||||
|
async def test_file_loading_performance():
    """Time the first creation of an assistant against a cached reuse.

    The same ZIP URL is requested twice from the global agent manager: the
    first call has to build the assistant, the second should be served from
    the cache.

    Returns:
        dict: ``first_create_time`` and ``reuse_time`` in seconds, ``speedup``
        (0 when the reuse time is not positive), ``is_same_instance`` and
        ``cached_instances``.
    """
    print("=== RAG 文件预加载性能测试 ===")

    # Test data
    test_zip_url = "https://example.com/test.zip"  # replace with a real test URL
    test_files = [
        "./projects/7f2fdcb1bad17323/all_hp_product_spec_book2506/document.txt"
    ]

    manager = get_global_agent_manager()

    print(f"测试文件数量: {len(test_files)}")

    # First request: the assistant has to be created.
    print("\n1. 测试首次创建助手实例(包含文件预加载)...")
    # perf_counter() is monotonic and high-resolution; time.time() can jump
    # and may be too coarse to resolve a fast cache hit.
    start_time = time.perf_counter()

    agent1 = await manager.get_or_create_agent(
        zip_url=test_zip_url,
        files=test_files,
        model_name="qwen3-next"
    )

    first_create_time = time.perf_counter() - start_time
    print(f" 首次创建耗时: {first_create_time:.2f} 秒")

    # Second request with the same URL: expected to hit the cache.
    print("\n2. 测试复用助手实例(无文件预加载)...")
    start_time = time.perf_counter()

    agent2 = await manager.get_or_create_agent(
        zip_url=test_zip_url,
        files=test_files,
        model_name="qwen3-next"
    )

    reuse_time = time.perf_counter() - start_time
    print(f" 复用实例耗时: {reuse_time:.2f} 秒")

    # Check that the cache handed back the very same object.
    is_same_instance = agent1 is agent2
    print(f" 是否为同一实例: {is_same_instance}")

    # Speed-up factor; stays 0 when the reuse time is not positive.
    speedup = 0
    if reuse_time > 0:
        speedup = first_create_time / reuse_time
        print(f" 性能提升倍数: {speedup:.1f}x")

    # Cache statistics
    stats = manager.get_cache_stats()
    print(f"\n3. 缓存统计:")
    print(f" 缓存的实例数: {stats['total_cached_agents']}")
    print(f" 最大缓存数: {stats['max_cached_agents']}")

    if stats['agents']:
        for cache_key, info in stats['agents'].items():
            print(f" 实例 {cache_key}:")
            print(f" 文件数: {info['file_count']}")
            print(f" 创建时间: {info['created_at']:.2f}")
            print(f" 最后访问: {info['last_accessed']:.2f}")
            print(f" 空闲时间: {info['idle_seconds']} 秒")

    return {
        'first_create_time': first_create_time,
        'reuse_time': reuse_time,
        'speedup': speedup,
        'is_same_instance': is_same_instance,
        'cached_instances': stats['total_cached_agents']
    }
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_collection():
    """Collect the document.txt files under the handler's projects directory.

    Prints how many files were found and lists at most the first five.

    Returns:
        int: Number of files found, or 0 when the projects directory does
        not exist.
    """
    print("\n=== 文件收集功能测试 ===")

    projects_dir = zip_handler.projects_dir

    # Guard clause: nothing to scan when the directory is absent.
    if not projects_dir.exists():
        print(f"项目目录 {projects_dir} 不存在")
        return 0

    files = zip_handler.collect_document_files(str(projects_dir))
    print(f"在 {projects_dir} 中找到 {len(files)} 个 document.txt 文件")

    # Show only the first five paths as a sample.
    for position, path in enumerate(files[:5], start=1):
        print(f" {position}. {path}")

    return len(files)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_comparison_with_old_method():
    """Compare direct assistant creation with creation through the manager.

    Three assistants are built directly with ``init_agent_service_with_files``
    and three are requested from the global agent manager; the total time of
    each phase is printed. The comparison is skipped when a test file is
    missing.
    """
    # Imported here because the module only binds `os` under its __main__
    # guard; without this the function raises NameError when the module is
    # imported rather than run as a script.
    import os

    print("\n=== 传统方法 vs 优化方法对比 ===")

    test_files = [
        "./projects/7f2fdcb1bad17323/all_hp_product_spec_book2506/document.txt"
    ]

    if not all(os.path.exists(f) for f in test_files):
        print("测试文件不存在,跳过对比测试")
        return

    # Phase 1: build a fresh instance every time.
    print("1. 传统方法 - 每次创建新实例并重新加载文件...")
    # perf_counter() is the monotonic, high-resolution clock meant for intervals.
    start_time = time.perf_counter()

    for i in range(3):
        init_agent_service_with_files(files=test_files)
        print(f" 创建实例 {i+1} 完成")

    traditional_time = time.perf_counter() - start_time
    print(f" 传统方法总耗时: {traditional_time:.2f} 秒")

    # Phase 2: go through the manager.
    print("\n2. 优化方法 - 复用缓存的实例...")
    start_time = time.perf_counter()

    manager = get_global_agent_manager()
    test_url = "test://comparison"

    # NOTE(review): every iteration uses a distinct URL, so each call is a
    # cache miss and this phase also measures creation, not reuse. Confirm
    # whether that is the intended comparison.
    for i in range(3):
        await manager.get_or_create_agent(
            zip_url=f"{test_url}_{i}",
            files=test_files,
            model_name="qwen3-next"
        )
        print(f" 获取实例 {i+1} 完成")

    optimized_time = time.perf_counter() - start_time
    print(f" 优化方法总耗时: {optimized_time:.2f} 秒")

    # Report the gain.
    if optimized_time > 0:
        speedup = traditional_time / optimized_time
        saved = traditional_time - optimized_time
        # Avoid ZeroDivisionError if the first phase measured as zero.
        saved_pct = saved / traditional_time * 100 if traditional_time > 0 else 0.0
        print(f"\n 性能提升: {speedup:.1f}x")
        print(f" 时间节省: {saved:.2f} 秒 ({saved_pct:.1f}%)")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Run the file-collection check and, if files exist, both timing tests.

    A summary is printed after the timing tests. Any exception raised along
    the way is reported with its traceback instead of propagating.
    """
    import traceback

    separator = "=" * 50
    print("开始 RAG 文件预加载优化测试...")
    print(separator)

    try:
        # Step 1: file collection.
        file_count = test_file_collection()

        # Without any files there is nothing to time.
        if not file_count > 0:
            print("未找到测试文件,跳过性能测试")
            return

        # Step 2: first-create versus reuse timing.
        performance_results = await test_file_loading_performance()

        # Step 3: direct creation versus the manager.
        await test_comparison_with_old_method()

        print("\n" + separator)
        print("测试总结:")
        print(f"✓ 文件收集功能正常,找到 {file_count} 个文件")
        print("✓ 助手实例缓存正常工作")
        print(f"✓ 性能提升 {performance_results['speedup']:.1f}x")
        print(f"✓ 实例复用功能正常: {performance_results['is_same_instance']}")
        print("\n所有测试通过!RAG 文件预加载优化成功。")

    except Exception as e:
        print(f"\n测试过程中出现错误: {e}")
        traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point. `os` is bound at module level only on this path;
    # test_comparison_with_old_method calls os.path.exists.
    import os
    asyncio.run(main())
|
||||||
@ -9,7 +9,7 @@ import hashlib
|
|||||||
import zipfile
|
import zipfile
|
||||||
import requests
|
import requests
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -102,6 +102,36 @@ class ZipProjectHandler:
|
|||||||
print(f"项目准备完成: {cached_project_dir}")
|
print(f"项目准备完成: {cached_project_dir}")
|
||||||
return str(cached_project_dir)
|
return str(cached_project_dir)
|
||||||
|
|
||||||
|
def collect_document_files(self, project_dir: str) -> List[str]:
    """Collect every ``document.txt`` file below a project directory.

    The directory is searched recursively. Only regular files are returned;
    a directory that happens to be named ``document.txt`` is ignored.

    Args:
        project_dir: Path of the project directory.

    Returns:
        List[str]: Paths of all ``document.txt`` files, sorted so the result
        does not depend on filesystem traversal order. Empty when the
        directory does not exist.
    """
    project_path = Path(project_dir)

    if not project_path.exists():
        print(f"项目目录不存在: {project_dir}")
        return []

    # rglob() yields entries in an unspecified order; sort for a stable list.
    document_files = sorted(
        str(file_path)
        for file_path in project_path.rglob("document.txt")
        if file_path.is_file()
    )

    print(f"在项目目录 {project_dir} 中找到 {len(document_files)} 个 document.txt 文件")
    # Print only the first five paths as a sample.
    for file_path in document_files[:5]:
        print(f" - {file_path}")
    if len(document_files) > 5:
        print(f" ... 还有 {len(document_files) - 5} 个文件")

    return document_files
|
||||||
|
|
||||||
def cleanup_cache(self):
|
def cleanup_cache(self):
|
||||||
"""清理缓存目录"""
|
"""清理缓存目录"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user