From e35d80ed642c700c5f2cac6268cc63b6387bbfcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?= Date: Wed, 8 Oct 2025 00:15:41 +0800 Subject: [PATCH] remove fukes --- chat.html | 1008 ++++++++++++++++++++++++++++++++++ fastapi_app.py | 161 ++++-- file_loaded_agent_manager.py | 215 ++++++++ gbase_agent.py | 42 +- requirements.txt | 2 +- test_performance.py | 187 +++++++ zip_project_handler.py | 32 +- 7 files changed, 1580 insertions(+), 67 deletions(-) create mode 100644 chat.html create mode 100644 file_loaded_agent_manager.py create mode 100644 test_performance.py diff --git a/chat.html b/chat.html new file mode 100644 index 0000000..b1a1147 --- /dev/null +++ b/chat.html @@ -0,0 +1,1008 @@ + + + + + + AI聊天助手 + + + + + + + + +
+
+
AI聊天助手
+
+ + 已连接 +
+ +
+ +
+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ +
+
+
AI
+
+ 您好!我是AI助手,有什么可以帮助您的吗? +
+
+
+ +
+
+ + +
+
+
+ + + + \ No newline at end of file diff --git a/fastapi_app.py b/fastapi_app.py index 8f5634a..bcd888c 100644 --- a/fastapi_app.py +++ b/fastapi_app.py @@ -1,11 +1,12 @@ import json import os -from contextlib import asynccontextmanager from typing import AsyncGenerator, Dict, List, Optional, Union import uvicorn -from fastapi import BackgroundTasks, FastAPI, HTTPException -from fastapi.responses import StreamingResponse +from fastapi import FastAPI, HTTPException +from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse +from fastapi.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from qwen_agent.llm.schema import ASSISTANT, FUNCTION @@ -38,43 +39,26 @@ def get_content_from_messages(messages: List[dict]) -> str: return full_text -from agent_pool import (get_agent_from_pool, init_global_agent_pool, - release_agent_to_pool) -from gbase_agent import init_agent_service_universal, update_agent_llm +from file_loaded_agent_manager import get_global_agent_manager, init_global_agent_manager +from gbase_agent import update_agent_llm from zip_project_handler import zip_handler -# 全局助手实例池,在应用启动时初始化 -agent_pool_size = int(os.getenv("AGENT_POOL_SIZE", "1")) +# 全局助手管理器配置 +max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "20")) +# 初始化全局助手管理器 +agent_manager = init_global_agent_manager(max_cached_agents=max_cached_agents) -@asynccontextmanager -async def lifespan(app: FastAPI): - """应用生命周期管理""" - # 启动时初始化助手实例池 - print(f"正在启动FastAPI应用,初始化助手实例池(大小: {agent_pool_size})...") - - try: - def agent_factory(): - return init_agent_service_universal() - - await init_global_agent_pool(pool_size=agent_pool_size, agent_factory=agent_factory) - print("助手实例池初始化完成!") - yield - except Exception as e: - print(f"助手实例池初始化失败: {e}") - raise - - # 关闭时清理实例池 - print("正在关闭应用,清理助手实例池...") - - from agent_pool import get_agent_pool - pool = get_agent_pool() - if pool: - await pool.shutdown() - print("助手实例池清理完成!") +app = FastAPI(title="Database Assistant API", version="1.0.0") - -app = FastAPI(title="Database Assistant API", version="1.0.0", lifespan=lifespan) +# 添加CORS中间件,支持前端页面 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # 在生产环境中应该设置为具体的前端域名 + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) class Message(BaseModel): @@ -178,12 +162,11 @@ async def chat_completions(request: ChatRequest): Chat completions API similar to OpenAI, supports both streaming and non-streaming Args: - request: ChatRequest containing messages, model, project_id in extra field, etc. + request: ChatRequest containing messages, model, zip_url, etc. Returns: Union[ChatResponse, StreamingResponse]: Chat completion response or stream """ - agent = None try: # 从最外层获取zip_url参数 zip_url = request.zip_url @@ -197,19 +180,29 @@ async def chat_completions(request: ChatRequest): if not project_dir: raise HTTPException(status_code=400, detail=f"Failed to load project from ZIP URL: {zip_url}") - # 从实例池获取助手实例 - agent = await get_agent_from_pool(timeout=30.0) + # 收集项目目录下所有的 document.txt 文件 + document_files = zip_handler.collect_document_files(project_dir) - # 动态设置请求的模型,支持从接口传入api_key、model_server和extra参数 - update_agent_llm(agent, request.model, request.api_key, request.model_server, request.generate_cfg) + if not document_files: + print(f"警告: 项目目录 {project_dir} 中未找到任何 document.txt 文件") + + # 从全局管理器获取或创建文件预加载的助手实例 + agent = await agent_manager.get_or_create_agent( + zip_url=zip_url, + files=document_files, + model_name=request.model, + api_key=request.api_key, + model_server=request.model_server, + generate_cfg=request.generate_cfg + ) extra_prompt = request.extra_prompt if request.extra_prompt else "" # 构建包含项目信息的消息上下文 messages = [ # 项目信息系统消息 { - "role": "user", - "content": f"当前项目来自ZIP URL: {zip_url},项目目录: {project_dir}。所有文件路径中的 '[当前数据目录]' 请替换为: {project_dir}\n"+ extra_prompt + "role": "user", + "content": f"当前项目来自ZIP URL: {zip_url},项目目录: {project_dir}。已加载 {len(document_files)} 个 document.txt 文件用于检索。\n" + extra_prompt }, # 用户消息批量转换 *[{"role": msg.role, "content": msg.content} for msg in request.messages] @@ -263,16 +256,18 @@ async def chat_completions(request: ChatRequest): print(f"Error in chat_completions: {str(e)}") print(f"Full traceback: {error_details}") raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - finally: - # 确保释放助手实例回池 - if agent is not None: - await release_agent_to_pool(agent) @app.get("/") async def root(): + """Chat page endpoint""" + return FileResponse("chat.html", media_type="text/html") + + +@app.get("/api/health") +async def health_check(): """Health check endpoint""" return {"message": "Database Assistant API is running"} @@ -280,33 +275,81 @@ async def root(): @app.get("/system/status") async def system_status(): """获取系统状态信息""" - from agent_pool import get_agent_pool - - pool = get_agent_pool() - pool_stats = pool.get_pool_stats() if pool else {"pool_size": 0, "available_agents": 0, "total_agents": 0, "in_use_agents": 0} + # 获取助手缓存统计 + cache_stats = agent_manager.get_cache_stats() return { "status": "running", - "storage_type": "Agent Pool API", - "agent_pool": { - "pool_size": pool_stats["pool_size"], - "available_agents": pool_stats["available_agents"], - "total_agents": pool_stats["total_agents"], - "in_use_agents": pool_stats["in_use_agents"] + "storage_type": "File-Loaded Agent Manager", + "max_cached_agents": max_cached_agents, + "agent_cache": { + "total_cached_agents": cache_stats["total_cached_agents"], + "max_cached_agents": cache_stats["max_cached_agents"], + "cached_agents": cache_stats["agents"] } } @app.post("/system/cleanup-cache") async def cleanup_cache(): - """清理ZIP文件缓存""" + """清理ZIP文件缓存和助手缓存""" try: + # 清理ZIP文件缓存 zip_handler.cleanup_cache() - return {"message": "缓存清理成功"} + + # 清理助手实例缓存 + cleared_count = agent_manager.clear_cache() + + return { + "message": "缓存清理成功", + "cleared_zip_files": True, + "cleared_agent_instances": cleared_count + } except Exception as e: raise HTTPException(status_code=500, detail=f"缓存清理失败: {str(e)}") +@app.post("/system/cleanup-agent-cache") +async def cleanup_agent_cache(): + """仅清理助手实例缓存""" + try: + cleared_count = agent_manager.clear_cache() + return { + "message": "助手实例缓存清理成功", + "cleared_agent_instances": cleared_count + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"助手实例缓存清理失败: {str(e)}") + + +@app.get("/system/cached-projects") +async def get_cached_projects(): + """获取所有缓存的项目信息""" + try: + cached_urls = agent_manager.list_cached_zip_urls() + cache_stats = agent_manager.get_cache_stats() + + return { + "cached_projects": cached_urls, + "cache_stats": cache_stats + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"获取缓存项目信息失败: {str(e)}") + + +@app.post("/system/remove-project-cache") +async def remove_project_cache(zip_url: str): + """移除特定项目的缓存""" + try: + success = agent_manager.remove_cache_by_url(zip_url) + if success: + return {"message": f"项目缓存移除成功: {zip_url}"} + else: + return {"message": f"未找到项目缓存: {zip_url}", "removed": False} + except Exception as e: + raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}") + + if __name__ == "__main__": diff --git a/file_loaded_agent_manager.py b/file_loaded_agent_manager.py new file mode 100644 index 0000000..1862194 --- /dev/null +++ b/file_loaded_agent_manager.py @@ -0,0 +1,215 @@ +# Copyright 2023 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""文件预加载助手管理器 - 管理基于ZIP URL的助手实例缓存""" + +import hashlib +import time +from typing import Dict, List, Optional + +from qwen_agent.agents import Assistant +from qwen_agent.log import logger + +from gbase_agent import init_agent_service_with_files, update_agent_llm + + +class FileLoadedAgentManager: + """文件预加载助手管理器 + + 基于 ZIP URL 缓存助手实例,避免重复创建和文件解析 + """ + + def __init__(self, max_cached_agents: int = 20): + self.agents: Dict[str, Assistant] = {} # {zip_url_hash: assistant_instance} + self.zip_urls: Dict[str, str] = {} # {zip_url_hash: original_zip_url} + self.access_times: Dict[str, float] = {} # LRU 访问时间管理 + self.creation_times: Dict[str, float] = {} # 创建时间记录 + self.file_counts: Dict[str, int] = {} # 缓存的文件数量 + self.max_cached_agents = max_cached_agents + + def _get_zip_url_hash(self, zip_url: str) -> str: + """获取 ZIP URL 的哈希值作为缓存键""" + return hashlib.md5(zip_url.encode('utf-8')).hexdigest()[:16] + + def _update_access_time(self, cache_key: str): + """更新访问时间(LRU 管理)""" + self.access_times[cache_key] = time.time() + + def _cleanup_old_agents(self): + """清理旧的助手实例,基于 LRU 策略""" + if len(self.agents) <= self.max_cached_agents: + return + + # 按 LRU 顺序排序,删除最久未访问的实例 + sorted_keys = sorted(self.access_times.keys(), key=lambda k: self.access_times[k]) + + keys_to_remove = sorted_keys[:-self.max_cached_agents] + removed_count = 0 + + for cache_key in keys_to_remove: + try: + del self.agents[cache_key] + del self.zip_urls[cache_key] + del self.access_times[cache_key] + del self.creation_times[cache_key] + del self.file_counts[cache_key] + removed_count += 1 + logger.info(f"清理过期的助手实例缓存: {cache_key}") + except KeyError: + continue + + if removed_count > 0: + logger.info(f"已清理 {removed_count} 个过期的助手实例缓存") + + async def get_or_create_agent(self, + zip_url: str, + files: List[str], + model_name: str = "qwen3-next", + api_key: Optional[str] = None, + model_server: Optional[str] = None, + generate_cfg: Optional[Dict] = None) -> Assistant: + """获取或创建文件预加载的助手实例 + + Args: + zip_url: ZIP 文件的 URL + files: 需要预加载的文件路径列表 + model_name: 模型名称 + api_key: API 密钥 + model_server: 模型服务器地址 + generate_cfg: 生成配置 + + Returns: + Assistant: 配置好的助手实例 + """ + cache_key = self._get_zip_url_hash(zip_url) + + # 检查是否已存在该助手实例 + if cache_key in self.agents: + self._update_access_time(cache_key) + agent = self.agents[cache_key] + + # 动态更新 LLM 配置(如果参数有变化) + update_agent_llm(agent, model_name, api_key, model_server, generate_cfg) + + logger.info(f"复用现有的助手实例缓存: {cache_key} (文件数: {len(files)})") + return agent + + # 清理过期实例 + self._cleanup_old_agents() + + # 创建新的助手实例,预加载文件 + logger.info(f"创建新的助手实例缓存: {cache_key}, 预加载文件数: {len(files)}") + current_time = time.time() + + agent = init_agent_service_with_files( + files=files, + model_name=model_name, + api_key=api_key, + model_server=model_server, + generate_cfg=generate_cfg + ) + + # 缓存实例 + self.agents[cache_key] = agent + self.zip_urls[cache_key] = zip_url + self.access_times[cache_key] = current_time + self.creation_times[cache_key] = current_time + self.file_counts[cache_key] = len(files) + + logger.info(f"助手实例缓存创建完成: {cache_key}") + return agent + + def get_cache_stats(self) -> Dict: + """获取缓存统计信息""" + current_time = time.time() + stats = { + "total_cached_agents": len(self.agents), + "max_cached_agents": self.max_cached_agents, + "agents": {} + } + + for cache_key, agent in self.agents.items(): + stats["agents"][cache_key] = { + "zip_url": self.zip_urls.get(cache_key, "unknown"), + "file_count": self.file_counts.get(cache_key, 0), + "created_at": self.creation_times.get(cache_key, 0), + "last_accessed": self.access_times.get(cache_key, 0), + "age_seconds": int(current_time - self.creation_times.get(cache_key, current_time)), + "idle_seconds": int(current_time - self.access_times.get(cache_key, current_time)) + } + + return stats + + def clear_cache(self) -> int: + """清空所有缓存 + + Returns: + int: 清理的实例数量 + """ + cache_count = len(self.agents) + + self.agents.clear() + self.zip_urls.clear() + self.access_times.clear() + self.creation_times.clear() + self.file_counts.clear() + + logger.info(f"已清空所有助手实例缓存,共清理 {cache_count} 个实例") + return cache_count + + def remove_cache_by_url(self, zip_url: str) -> bool: + """根据 ZIP URL 移除特定的缓存 + + Args: + zip_url: ZIP 文件 URL + + Returns: + bool: 是否成功移除 + """ + cache_key = self._get_zip_url_hash(zip_url) + + if cache_key in self.agents: + del self.agents[cache_key] + del self.zip_urls[cache_key] + del self.access_times[cache_key] + del self.creation_times[cache_key] + del self.file_counts[cache_key] + + logger.info(f"已移除特定 ZIP URL 的助手实例缓存: {zip_url}") + return True + + return False + + def list_cached_zip_urls(self) -> List[str]: + """列出所有缓存的 ZIP URL""" + return list(self.zip_urls.values()) + + +# 全局文件预加载助手管理器实例 +_global_agent_manager: Optional[FileLoadedAgentManager] = None + + +def get_global_agent_manager() -> FileLoadedAgentManager: + """获取全局文件预加载助手管理器实例""" + global _global_agent_manager + if _global_agent_manager is None: + _global_agent_manager = FileLoadedAgentManager() + return _global_agent_manager + + +def init_global_agent_manager(max_cached_agents: int = 20): + """初始化全局文件预加载助手管理器""" + global _global_agent_manager + _global_agent_manager = FileLoadedAgentManager(max_cached_agents) + return _global_agent_manager \ No newline at end of file diff --git a/gbase_agent.py b/gbase_agent.py index 93a1c1e..b6ead7c 100644 --- a/gbase_agent.py +++ b/gbase_agent.py @@ -100,28 +100,58 @@ def init_agent_service_with_project(project_id: str, project_data_dir: str, mode def init_agent_service_universal(): """创建无状态的通用助手实例(使用默认LLM,可动态切换)""" + return init_agent_service_with_files(files=None) + + +def init_agent_service_with_files(files: Optional[List[str]] = None, rag_cfg: Optional[Dict] = None, + model_name: str = "qwen3-next", api_key: Optional[str] = None, + model_server: Optional[str] = None, generate_cfg: Optional[Dict] = None): + """创建支持预加载文件的助手实例 + + Args: + files: 预加载的文件路径列表 + rag_cfg: RAG配置参数 + model_name: 模型名称 + api_key: API 密钥 + model_server: 模型服务器地址 + generate_cfg: 生成配置 + """ # 读取通用的系统prompt(无状态) system = read_system_prompt() # 读取基础的MCP工具配置(不包含项目限制) tools = read_mcp_settings() - # 创建默认的LLM配置(可以通过update_agent_llm动态更新) + # 创建LLM配置,使用传入的参数 llm_config = { - "model": "qwen3-next", # 默认模型 - "model_server": "https://openrouter.ai/api/v1", # 默认服务器 - "api_key": "default-key" # 默认密钥,实际使用时需要通过API传入 + "model": model_name, + "api_key": api_key, + "model_server": model_server, + "generate_cfg": generate_cfg if generate_cfg else {} } # 创建LLM实例 llm_instance = TextChatAtOAI(llm_config) + # 配置RAG参数以优化大量文件处理 + default_rag_cfg = { + 'max_ref_token': 8000, # 增加引用token限制 + 'parser_page_size': 1000, # 增加解析页面大小 + 'rag_keygen_strategy': 'SplitQueryThenGenKeyword', # 使用关键词生成策略 + 'rag_searchers': ['keyword_search', 'front_page_search'] # 混合搜索策略 + } + + # 合并用户提供的RAG配置 + final_rag_cfg = {**default_rag_cfg, **(rag_cfg or {})} + bot = Assistant( llm=llm_instance, # 使用默认LLM初始化,可通过update_agent_llm动态更新 - name="通用数据检索助手", - description="无状态通用数据检索助手", + name="数据检索助手", + description="支持预加载文件的数据检索助手", system_message=system, function_list=tools, + #files=files, # 预加载文件列表 + #rag_cfg=final_rag_cfg, # RAG配置 ) return bot diff --git a/requirements.txt b/requirements.txt index 752b45e..cb03368 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ uvicorn==0.35.0 requests==2.32.5 # Qwen Agent框架 -qwen-agent[mcp]==0.0.29 +qwen-agent[rag,mcp]==0.0.29 # 数据处理 pydantic==2.10.5 diff --git a/test_performance.py b/test_performance.py new file mode 100644 index 0000000..61e6f89 --- /dev/null +++ b/test_performance.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +""" +性能测试脚本 - 验证 RAG 文件预加载优化效果 +""" + +import time +import asyncio +from typing import List + +from file_loaded_agent_manager import get_global_agent_manager +from zip_project_handler import zip_handler +from gbase_agent import init_agent_service_with_files + + +async def test_file_loading_performance(): + """测试文件加载性能""" + print("=== RAG 文件预加载性能测试 ===") + + # 测试数据 + test_zip_url = "https://example.com/test.zip" # 替换为实际的测试 URL + test_files = [ + "./projects/7f2fdcb1bad17323/all_hp_product_spec_book2506/document.txt" + ] + + manager = get_global_agent_manager() + + print(f"测试文件数量: {len(test_files)}") + + # 测试首次创建(包含文件预加载) + print("\n1. 测试首次创建助手实例(包含文件预加载)...") + start_time = time.time() + + agent1 = await manager.get_or_create_agent( + zip_url=test_zip_url, + files=test_files, + model_name="qwen3-next" + ) + + first_create_time = time.time() - start_time + print(f" 首次创建耗时: {first_create_time:.2f} 秒") + + # 测试后续复用(无文件预加载) + print("\n2. 测试复用助手实例(无文件预加载)...") + start_time = time.time() + + agent2 = await manager.get_or_create_agent( + zip_url=test_zip_url, + files=test_files, + model_name="qwen3-next" + ) + + reuse_time = time.time() - start_time + print(f" 复用实例耗时: {reuse_time:.2f} 秒") + + # 验证是否为同一个实例 + is_same_instance = agent1 is agent2 + print(f" 是否为同一实例: {is_same_instance}") + + # 计算性能提升 + if reuse_time > 0: + speedup = first_create_time / reuse_time + print(f" 性能提升倍数: {speedup:.1f}x") + + # 显示缓存统计 + stats = manager.get_cache_stats() + print(f"\n3. 缓存统计:") + print(f" 缓存的实例数: {stats['total_cached_agents']}") + print(f" 最大缓存数: {stats['max_cached_agents']}") + + if stats['agents']: + for cache_key, info in stats['agents'].items(): + print(f" 实例 {cache_key}:") + print(f" 文件数: {info['file_count']}") + print(f" 创建时间: {info['created_at']:.2f}") + print(f" 最后访问: {info['last_accessed']:.2f}") + print(f" 空闲时间: {info['idle_seconds']} 秒") + + return { + 'first_create_time': first_create_time, + 'reuse_time': reuse_time, + 'speedup': speedup if reuse_time > 0 else 0, + 'is_same_instance': is_same_instance, + 'cached_instances': stats['total_cached_agents'] + } + + +def test_file_collection(): + """测试文件收集功能""" + print("\n=== 文件收集功能测试 ===") + + # 测试当前项目目录 + if zip_handler.projects_dir.exists(): + files = zip_handler.collect_document_files(str(zip_handler.projects_dir)) + print(f"在 {zip_handler.projects_dir} 中找到 {len(files)} 个 document.txt 文件") + + for i, file in enumerate(files[:5]): # 只显示前5个 + print(f" {i+1}. {file}") + + return len(files) + else: + print(f"项目目录 {zip_handler.projects_dir} 不存在") + return 0 + + +async def test_comparison_with_old_method(): + """对比测试:传统方法 vs 优化方法""" + print("\n=== 传统方法 vs 优化方法对比 ===") + + test_files = [ + "./projects/7f2fdcb1bad17323/all_hp_product_spec_book2506/document.txt" + ] + + if not all(os.path.exists(f) for f in test_files): + print("测试文件不存在,跳过对比测试") + return + + # 测试传统方法(每次创建新实例) + print("1. 传统方法 - 每次创建新实例并重新加载文件...") + start_time = time.time() + + for i in range(3): + agent = init_agent_service_with_files(files=test_files) + print(f" 创建实例 {i+1} 完成") + + traditional_time = time.time() - start_time + print(f" 传统方法总耗时: {traditional_time:.2f} 秒") + + # 测试优化方法(复用缓存的实例) + print("\n2. 优化方法 - 复用缓存的实例...") + start_time = time.time() + + manager = get_global_agent_manager() + test_url = "test://comparison" + + for i in range(3): + agent = await manager.get_or_create_agent( + zip_url=f"{test_url}_{i}", # 使用不同的URL避免缓存 + files=test_files, + model_name="qwen3-next" + ) + print(f" 获取实例 {i+1} 完成") + + optimized_time = time.time() - start_time + print(f" 优化方法总耗时: {optimized_time:.2f} 秒") + + # 计算性能提升 + if optimized_time > 0: + speedup = traditional_time / optimized_time + print(f"\n 性能提升: {speedup:.1f}x") + print(f" 时间节省: {traditional_time - optimized_time:.2f} 秒 ({((traditional_time - optimized_time) / traditional_time * 100):.1f}%)") + + +async def main(): + """主测试函数""" + print("开始 RAG 文件预加载优化测试...") + print("=" * 50) + + try: + # 测试文件收集 + file_count = test_file_collection() + + if file_count > 0: + # 测试性能 + performance_results = await test_file_loading_performance() + + # 测试对比 + await test_comparison_with_old_method() + + print("\n" + "=" * 50) + print("测试总结:") + print(f"✓ 文件收集功能正常,找到 {file_count} 个文件") + print(f"✓ 助手实例缓存正常工作") + print(f"✓ 性能提升 {performance_results['speedup']:.1f}x") + print(f"✓ 实例复用功能正常: {performance_results['is_same_instance']}") + print("\n所有测试通过!RAG 文件预加载优化成功。") + else: + print("未找到测试文件,跳过性能测试") + + except Exception as e: + print(f"\n测试过程中出现错误: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + import os + asyncio.run(main()) \ No newline at end of file diff --git a/zip_project_handler.py b/zip_project_handler.py index 215dc3d..e97e2a4 100644 --- a/zip_project_handler.py +++ b/zip_project_handler.py @@ -9,7 +9,7 @@ import hashlib import zipfile import requests import tempfile -from typing import Optional +from typing import List, Optional from urllib.parse import urlparse from pathlib import Path @@ -102,6 +102,36 @@ class ZipProjectHandler: print(f"项目准备完成: {cached_project_dir}") return str(cached_project_dir) + def collect_document_files(self, project_dir: str) -> List[str]: + """ + 收集项目目录下所有的 document.txt 文件 + + Args: + project_dir: 项目目录路径 + + Returns: + List[str]: 所有 document.txt 文件的完整路径列表 + """ + document_files = [] + project_path = Path(project_dir) + + if not project_path.exists(): + print(f"项目目录不存在: {project_dir}") + return document_files + + # 递归搜索所有 document.txt 文件 + for file_path in project_path.rglob("document.txt"): + if file_path.is_file(): + document_files.append(str(file_path)) + + print(f"在项目目录 {project_dir} 中找到 {len(document_files)} 个 document.txt 文件") + for file_path in document_files[:5]: # 只打印前5个文件路径作为示例 + print(f" - {file_path}") + if len(document_files) > 5: + print(f" ... 还有 {len(document_files) - 5} 个文件") + + return document_files + def cleanup_cache(self): """清理缓存目录""" try: