diff --git a/fastapi_app.py b/fastapi_app.py index 5bc2d2d..b1fdf28 100644 --- a/fastapi_app.py +++ b/fastapi_app.py @@ -1,5 +1,8 @@ import json import os +import aiofiles +import aiohttp +import hashlib from typing import AsyncGenerator, Dict, List, Optional, Union import uvicorn @@ -41,19 +44,345 @@ def get_content_from_messages(messages: List[dict]) -> str: from file_loaded_agent_manager import get_global_agent_manager, init_global_agent_manager from gbase_agent import update_agent_llm -from zip_project_handler import zip_handler -def get_zip_url_from_unique_id(unique_id: str) -> Optional[str]: - """从unique_map.json中读取zip_url""" +async def download_file(url: str, destination_path: str) -> bool: + """Download file from URL to destination path""" try: - with open('unique_map.json', 'r', encoding='utf-8') as f: - unique_map = json.load(f) - return unique_map.get(unique_id) + os.makedirs(os.path.dirname(destination_path), exist_ok=True) + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status == 200: + async with aiofiles.open(destination_path, 'wb') as f: + async for chunk in response.content.iter_chunked(8192): + await f.write(chunk) + return True + else: + print(f"Failed to download file from {url}, status: {response.status}") + return False except Exception as e: - print(f"Error reading unique_map.json: {e}") + print(f"Error downloading file from {url}: {str(e)}") + return False + + +def get_file_hash(file_path: str) -> str: + """Generate MD5 hash for a file path/URL""" + return hashlib.md5(file_path.encode('utf-8')).hexdigest() + +def load_processed_files_log(unique_id: str) -> Dict[str, Dict]: + """Load processed files log for a project""" + log_file = os.path.join("projects", unique_id, "processed_files.json") + if os.path.exists(log_file): + try: + with open(log_file, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + print(f"Error loading processed files log: {str(e)}") + return {} + +def save_processed_files_log(unique_id: str, processed_log: Dict[str, Dict]): + """Save processed files log for a project""" + log_file = os.path.join("projects", unique_id, "processed_files.json") + try: + os.makedirs(os.path.dirname(log_file), exist_ok=True) + with open(log_file, 'w', encoding='utf-8') as f: + json.dump(processed_log, f, ensure_ascii=False, indent=2) + except Exception as e: + print(f"Error saving processed files log: {str(e)}") + +def remove_file_or_directory(path: str): + """Remove file or directory if it exists""" + if os.path.exists(path): + try: + if os.path.isdir(path): + import shutil + shutil.rmtree(path) + print(f"Removed directory: {path}") + else: + os.remove(path) + print(f"Removed file: {path}") + return True + except Exception as e: + print(f"Error removing {path}: {str(e)}") + return False + +def remove_dataset_directory(unique_id: str, filename_without_ext: str): + """Remove the entire dataset directory for a specific file""" + dataset_dir = os.path.join("projects", unique_id, "dataset", filename_without_ext) + if remove_file_or_directory(dataset_dir): + print(f"Removed dataset directory: {dataset_dir}") + return True + return False + +def get_document_preview(document_path: str, max_lines: int = 10) -> str: + """Get preview of document content (first max_lines lines)""" + try: + with open(document_path, 'r', encoding='utf-8') as f: + lines = [] + for i, line in enumerate(f): + if i >= max_lines: + break + lines.append(line.rstrip()) + return '\n'.join(lines) + except Exception as e: + 
print(f"Error reading document preview from {document_path}: {str(e)}") + return f"Error reading document: {str(e)}" + +def generate_dataset_structure(unique_id: str) -> str: + """Generate dataset directory structure as a string""" + dataset_dir = os.path.join("projects", unique_id, "dataset") + structure_lines = [] + + def build_tree(path: str, prefix: str = "", is_last: bool = True): + try: + items = sorted(os.listdir(path)) + items = [item for item in items if not item.startswith('.')] # Hide hidden files + + for i, item in enumerate(items): + item_path = os.path.join(path, item) + is_dir = os.path.isdir(item_path) + + # Determine tree symbols + if i == len(items) - 1: + current_prefix = "└── " if is_last else "├── " + next_prefix = " " if is_last else "│ " + else: + current_prefix = "├── " + next_prefix = "│ " + + line = prefix + current_prefix + item + if is_dir: + line += "/" + structure_lines.append(line) + + # Recursively process subdirectories + if is_dir: + build_tree(item_path, prefix + next_prefix, i == len(items) - 1) + + except Exception as e: + print(f"Error building tree for {path}: {str(e)}") + + structure_lines.append("dataset/") + if os.path.exists(dataset_dir): + build_tree(dataset_dir) + else: + structure_lines.append(" (empty)") + + return '\n'.join(structure_lines) + +def generate_project_readme(unique_id: str) -> str: + """Generate README.md content for a project""" + project_dir = os.path.join("projects", unique_id) + dataset_dir = os.path.join(project_dir, "dataset") + + readme_content = f"""# Project: {unique_id} + +## Dataset Structure + +``` +{generate_dataset_structure(unique_id)} +``` + +## Files Description + +""" + + if not os.path.exists(dataset_dir): + readme_content += "No dataset files available.\n" + else: + # Get all document directories + doc_dirs = [] + try: + for item in sorted(os.listdir(dataset_dir)): + item_path = os.path.join(dataset_dir, item) + if os.path.isdir(item_path): + doc_dirs.append(item) + except Exception as e: + print(f"Error listing dataset directories: {str(e)}") + + if not doc_dirs: + readme_content += "No document directories found.\n" + else: + for doc_dir in doc_dirs: + doc_path = os.path.join(dataset_dir, doc_dir) + document_file = os.path.join(doc_path, "document.txt") + pagination_file = os.path.join(doc_path, "pagination.txt") + embeddings_file = os.path.join(doc_path, "document_embeddings.pkl") + + readme_content += f"### {doc_dir}\n\n" + readme_content += f"**Files:**\n" + readme_content += f"- `document.txt`" + if os.path.exists(document_file): + readme_content += " ✓" + readme_content += "\n" + + readme_content += f"- `pagination.txt`" + if os.path.exists(pagination_file): + readme_content += " ✓" + readme_content += "\n" + + readme_content += f"- `document_embeddings.pkl`" + if os.path.exists(embeddings_file): + readme_content += " ✓" + readme_content += "\n\n" + + # Add document preview + if os.path.exists(document_file): + readme_content += f"**Content Preview (first 10 lines):**\n\n```\n" + preview = get_document_preview(document_file, 10) + readme_content += preview + readme_content += "\n```\n\n" + else: + readme_content += f"**Content Preview:** Not available\n\n" + + readme_content += f"""--- +*Generated on {__import__('datetime').datetime.now().strftime('%Y-%m-%d %H:%M:%S')}* +""" + + return readme_content + +def save_project_readme(unique_id: str): + """Generate and save README.md for a project""" + try: + readme_content = generate_project_readme(unique_id) + readme_path = os.path.join("projects", unique_id, 
"README.md") + + with open(readme_path, 'w', encoding='utf-8') as f: + f.write(readme_content) + + print(f"Generated README.md for project {unique_id}") + return readme_path + except Exception as e: + print(f"Error generating README for project {unique_id}: {str(e)}") return None +async def download_dataset_files(unique_id: str, files: List[str]) -> List[str]: + """Download or copy dataset files to projects/{unique_id}/files directory with processing state management""" + if not files: + return [] + + # Load existing processed files log + processed_log = load_processed_files_log(unique_id) + files_dir = os.path.join("projects", unique_id, "files") + + # Convert files list to a set for easy comparison + new_files_hashes = {get_file_hash(file_path): file_path for file_path in files} + existing_files_hashes = set(processed_log.keys()) + + # Files to process (new or modified) + files_to_process = [] + # Files to remove (no longer in the list) + files_to_remove = existing_files_hashes - set(new_files_hashes.keys()) + + processed_files = [] + + # Remove files that are no longer in the list + for file_hash in files_to_remove: + file_info = processed_log[file_hash] + + # Remove local file in files directory + if 'local_path' in file_info: + remove_file_or_directory(file_info['local_path']) + + # Remove the entire dataset directory for this file + if 'filename' in file_info: + filename_without_ext = os.path.splitext(file_info['filename'])[0] + remove_dataset_directory(unique_id, filename_without_ext) + + # Also remove any specific dataset path if exists (fallback) + if 'dataset_path' in file_info: + remove_file_or_directory(file_info['dataset_path']) + + # Remove from log + del processed_log[file_hash] + print(f"Removed file from processing: {file_info.get('original_path', 'unknown')}") + + # Process new files + for file_path in files: + file_hash = get_file_hash(file_path) + + # Check if file was already processed + if file_hash in processed_log: + file_info = processed_log[file_hash] + if 'local_path' in file_info and os.path.exists(file_info['local_path']): + processed_files.append(file_info['local_path']) + print(f"Skipped already processed file: {file_path}") + continue + + # Extract filename from URL or path + filename = file_path.split("/")[-1] + if not filename: + filename = f"file_{len(processed_files)}" + + destination_path = os.path.join(files_dir, filename) + + # Check if it's a URL (remote file) or local file + success = False + if file_path.startswith(('http://', 'https://')): + # Download remote file + success = await download_file(file_path, destination_path) + else: + # Copy local file + try: + import shutil + os.makedirs(files_dir, exist_ok=True) + shutil.copy2(file_path, destination_path) + success = True + print(f"Copied local file: {file_path} -> {destination_path}") + except Exception as e: + print(f"Failed to copy local file {file_path}: {str(e)}") + + if success: + processed_files.append(destination_path) + # Update processed log + processed_log[file_hash] = { + 'original_path': file_path, + 'local_path': destination_path, + 'filename': filename, + 'processed_at': str(__import__('datetime').datetime.now()), + 'file_type': 'remote' if file_path.startswith(('http://', 'https://')) else 'local' + } + print(f"Successfully processed file: {file_path}") + else: + print(f"Failed to process file: {file_path}") + + # After downloading/copying files, organize them into dataset structure + if processed_files: + try: + from organize_dataset_files import organize_single_project_files + + 
# Update dataset paths in the log after organization + old_processed_log = processed_log.copy() + organize_single_project_files(unique_id, skip_processed=True) + + # Try to update dataset paths in the log + for file_hash, file_info in old_processed_log.items(): + if 'local_path' in file_info and os.path.exists(file_info['local_path']): + # Construct expected dataset path based on known structure + filename_without_ext = os.path.splitext(file_info['filename'])[0] + dataset_path = os.path.join("projects", unique_id, "dataset", filename_without_ext, "document.txt") + if os.path.exists(dataset_path): + processed_log[file_hash]['dataset_path'] = dataset_path + + print(f"Organized files for project {unique_id} into dataset structure (skipping already processed files)") + except Exception as e: + print(f"Failed to organize files for project {unique_id}: {str(e)}") + + # Save the updated processed log + save_processed_files_log(unique_id, processed_log) + + # Generate README.md after processing files + try: + save_project_readme(unique_id) + except Exception as e: + print(f"Failed to generate README for project {unique_id}: {str(e)}") + + return processed_files + + + + + # 全局助手管理器配置 max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "20")) @@ -80,11 +409,17 @@ class Message(BaseModel): content: str +class DatasetRequest(BaseModel): + system_prompt: Optional[str] = None + mcp_settings: Optional[List[Dict]] = None + files: Optional[List[str]] = None + unique_id: Optional[str] = None + + class ChatRequest(BaseModel): messages: List[Message] model: str = "qwen3-next" model_server: str = "" - zip_url: Optional[str] = None unique_id: Optional[str] = None stream: Optional[bool] = False @@ -170,13 +505,102 @@ async def generate_stream_response(agent, messages, request) -> AsyncGenerator[s yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n" +class FileProcessRequest(BaseModel): + unique_id: str + files: Optional[List[str]] = None + system_prompt: Optional[str] = None + mcp_settings: Optional[List[Dict]] = None + + class Config: + extra = 'allow' + + +class FileProcessResponse(BaseModel): + success: bool + message: str + unique_id: str + processed_files: List[str] + + +@app.post("/api/v1/files/process") +async def process_files(request: FileProcessRequest, authorization: Optional[str] = Header(None)): + """ + Process dataset files for a given unique_id + + Args: + request: FileProcessRequest containing unique_id, files, system_prompt, and mcp_settings + authorization: Authorization header containing API key (Bearer ) + + Returns: + FileProcessResponse: Processing result with file list + """ + try: + + unique_id = request.unique_id + if not unique_id: + raise HTTPException(status_code=400, detail="unique_id is required") + + # 处理文件:只使用request.files + processed_files = [] + if request.files: + # 使用请求中的文件 + processed_files = await download_dataset_files(unique_id, request.files) + print(f"Processed {len(processed_files)} dataset files for unique_id: {unique_id}") + else: + print(f"No files provided in request for unique_id: {unique_id}") + + # 使用unique_id获取项目目录 + project_dir = os.path.join("projects", unique_id) + if not os.path.exists(project_dir): + raise HTTPException(status_code=400, detail=f"Project directory not found for unique_id: {unique_id}") + + # 收集项目目录下所有的 document.txt 文件 + document_files = [] + for root, dirs, files in os.walk(project_dir): + for file in files: + if file == "document.txt": + document_files.append(os.path.join(root, file)) + + # 合并所有处理的文件 + all_files = 
document_files + processed_files + + if not all_files: + print(f"警告: 项目目录 {project_dir} 中未找到任何 document.txt 文件") + + # 保存system_prompt和mcp_settings到项目目录(如果提供) + if request.system_prompt: + system_prompt_file = os.path.join(project_dir, "system_prompt.md") + with open(system_prompt_file, 'w', encoding='utf-8') as f: + f.write(request.system_prompt) + print(f"Saved system_prompt for unique_id: {unique_id}") + + if request.mcp_settings: + mcp_settings_file = os.path.join(project_dir, "mcp_settings.json") + with open(mcp_settings_file, 'w', encoding='utf-8') as f: + json.dump(request.mcp_settings, f, ensure_ascii=False, indent=2) + print(f"Saved mcp_settings for unique_id: {unique_id}") + + return FileProcessResponse( + success=True, + message=f"Successfully processed {len(all_files)} files", + unique_id=unique_id, + processed_files=all_files + ) + + except HTTPException: + raise + except Exception as e: + print(f"Error processing files: {str(e)}") + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + + @app.post("/api/v1/chat/completions") async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)): """ Chat completions API similar to OpenAI, supports both streaming and non-streaming Args: - request: ChatRequest containing messages, model, zip_url, etc. + request: ChatRequest containing messages, model, dataset with unique_id, system_prompt, mcp_settings, and files authorization: Authorization header containing API key (Bearer ) Returns: @@ -192,39 +616,23 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] = else: api_key = authorization - # 从最外层获取zip_url和unique_id参数 - zip_url = request.zip_url + # 获取unique_id unique_id = request.unique_id + if not unique_id: + raise HTTPException(status_code=400, detail="unique_id is required") - # 如果提供了unique_id,从unique_map.json中读取zip_url - if unique_id: - zip_url = get_zip_url_from_unique_id(unique_id) - if not zip_url: - raise HTTPException(status_code=400, detail=f"No zip_url found for unique_id: {unique_id}") - - if not zip_url: - raise HTTPException(status_code=400, detail="zip_url is required") - - # 使用ZIP URL获取项目数据 - print(f"从ZIP URL加载项目: {zip_url}") - project_dir = zip_handler.get_project_from_zip(zip_url, unique_id if unique_id else None) - if not project_dir: - raise HTTPException(status_code=400, detail=f"Failed to load project from ZIP URL: {zip_url}") - - # 收集项目目录下所有的 document.txt 文件 - document_files = zip_handler.collect_document_files(project_dir) - - if not document_files: - print(f"警告: 项目目录 {project_dir} 中未找到任何 document.txt 文件") + # 使用unique_id获取项目目录 + project_dir = os.path.join("projects", unique_id) + if not os.path.exists(project_dir): + raise HTTPException(status_code=400, detail=f"Project directory not found for unique_id: {unique_id}") # 收集额外参数作为 generate_cfg - exclude_fields = {'messages', 'model', 'model_server', 'zip_url', 'unique_id', 'stream'} + exclude_fields = {'messages', 'model', 'model_server', 'unique_id', 'stream'} generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields} - # 从全局管理器获取或创建文件预加载的助手实例 + # 从全局管理器获取或创建助手实例(配置读取逻辑已在agent_manager内部处理) agent = await agent_manager.get_or_create_agent( - zip_url=zip_url, - files=document_files, + unique_id=unique_id, project_dir=project_dir, model_name=request.model, api_key=api_key, @@ -322,17 +730,13 @@ async def system_status(): @app.post("/system/cleanup-cache") async def cleanup_cache(): - """清理ZIP文件缓存和助手缓存""" + """清理助手缓存""" try: - # 清理ZIP文件缓存 - 
zip_handler.cleanup_cache() - # 清理助手实例缓存 cleared_count = agent_manager.clear_cache() return { "message": "缓存清理成功", - "cleared_zip_files": True, "cleared_agent_instances": cleared_count } except Exception as e: @@ -356,11 +760,9 @@ async def cleanup_agent_cache(): async def get_cached_projects(): """获取所有缓存的项目信息""" try: - cached_urls = agent_manager.list_cached_zip_urls() cache_stats = agent_manager.get_cache_stats() return { - "cached_projects": cached_urls, "cache_stats": cache_stats } except Exception as e: @@ -368,18 +770,108 @@ async def get_cached_projects(): @app.post("/system/remove-project-cache") -async def remove_project_cache(zip_url: str): +async def remove_project_cache(unique_id: str): """移除特定项目的缓存""" try: - success = agent_manager.remove_cache_by_url(zip_url) + success = agent_manager.remove_cache_by_unique_id(unique_id) if success: - return {"message": f"项目缓存移除成功: {zip_url}"} + return {"message": f"项目缓存移除成功: {unique_id}"} else: - return {"message": f"未找到项目缓存: {zip_url}", "removed": False} + return {"message": f"未找到项目缓存: {unique_id}", "removed": False} except Exception as e: raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}") +@app.get("/api/v1/files/{unique_id}/status") +async def get_files_processing_status(unique_id: str): + """获取项目的文件处理状态""" + try: + # Load processed files log + processed_log = load_processed_files_log(unique_id) + + # Get project directory info + project_dir = os.path.join("projects", unique_id) + project_exists = os.path.exists(project_dir) + + # Collect document.txt files + document_files = [] + if project_exists: + for root, dirs, files in os.walk(project_dir): + for file in files: + if file == "document.txt": + document_files.append(os.path.join(root, file)) + + return { + "unique_id": unique_id, + "project_exists": project_exists, + "processed_files_count": len(processed_log), + "processed_files": processed_log, + "document_files_count": len(document_files), + "document_files": document_files, + "log_file_exists": os.path.exists(os.path.join("projects", unique_id, "processed_files.json")) + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"获取文件处理状态失败: {str(e)}") + + +@app.post("/api/v1/files/{unique_id}/reset") +async def reset_files_processing(unique_id: str): + """重置项目的文件处理状态,删除处理日志和所有文件""" + try: + project_dir = os.path.join("projects", unique_id) + log_file = os.path.join("projects", unique_id, "processed_files.json") + + # Load processed log to know what files to remove + processed_log = load_processed_files_log(unique_id) + + removed_files = [] + # Remove all processed files and their dataset directories + for file_hash, file_info in processed_log.items(): + # Remove local file in files directory + if 'local_path' in file_info: + if remove_file_or_directory(file_info['local_path']): + removed_files.append(file_info['local_path']) + + # Remove the entire dataset directory for this file + if 'filename' in file_info: + filename_without_ext = os.path.splitext(file_info['filename'])[0] + dataset_dir = os.path.join("projects", unique_id, "dataset", filename_without_ext) + if remove_file_or_directory(dataset_dir): + removed_files.append(dataset_dir) + + # Also remove any specific dataset path if exists (fallback) + if 'dataset_path' in file_info: + if remove_file_or_directory(file_info['dataset_path']): + removed_files.append(file_info['dataset_path']) + + # Remove the log file + if remove_file_or_directory(log_file): + removed_files.append(log_file) + + # Remove the entire files directory + files_dir = 
os.path.join(project_dir, "files") + if remove_file_or_directory(files_dir): + removed_files.append(files_dir) + + # Also remove the entire dataset directory (clean up any remaining files) + dataset_dir = os.path.join(project_dir, "dataset") + if remove_file_or_directory(dataset_dir): + removed_files.append(dataset_dir) + + # Remove README.md if exists + readme_file = os.path.join(project_dir, "README.md") + if remove_file_or_directory(readme_file): + removed_files.append(readme_file) + + return { + "message": f"文件处理状态重置成功: {unique_id}", + "removed_files_count": len(removed_files), + "removed_files": removed_files + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"重置文件处理状态失败: {str(e)}") + + if __name__ == "__main__": diff --git a/file_loaded_agent_manager.py b/file_loaded_agent_manager.py index 1e69656..e4d838c 100644 --- a/file_loaded_agent_manager.py +++ b/file_loaded_agent_manager.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""文件预加载助手管理器 - 管理基于ZIP URL的助手实例缓存""" +"""文件预加载助手管理器 - 管理基于unique_id的助手实例缓存""" import hashlib import time @@ -27,20 +27,19 @@ from gbase_agent import init_agent_service_with_files, update_agent_llm class FileLoadedAgentManager: """文件预加载助手管理器 - 基于 ZIP URL 缓存助手实例,避免重复创建和文件解析 + 基于 unique_id 缓存助手实例,避免重复创建和文件解析 """ def __init__(self, max_cached_agents: int = 20): - self.agents: Dict[str, Assistant] = {} # {zip_url_hash: assistant_instance} - self.zip_urls: Dict[str, str] = {} # {zip_url_hash: original_zip_url} + self.agents: Dict[str, Assistant] = {} # {unique_id: assistant_instance} + self.unique_ids: Dict[str, str] = {} # {cache_key: unique_id} self.access_times: Dict[str, float] = {} # LRU 访问时间管理 self.creation_times: Dict[str, float] = {} # 创建时间记录 - self.file_counts: Dict[str, int] = {} # 缓存的文件数量 self.max_cached_agents = max_cached_agents - def _get_zip_url_hash(self, zip_url: str) -> str: - """获取 ZIP URL 的哈希值作为缓存键""" - return hashlib.md5(zip_url.encode('utf-8')).hexdigest()[:16] + def _get_cache_key(self, unique_id: str) -> str: + """获取 unique_id 的哈希值作为缓存键""" + return hashlib.md5(unique_id.encode('utf-8')).hexdigest()[:16] def _update_access_time(self, cache_key: str): """更新访问时间(LRU 管理)""" @@ -60,10 +59,9 @@ class FileLoadedAgentManager: for cache_key in keys_to_remove: try: del self.agents[cache_key] - del self.zip_urls[cache_key] + del self.unique_ids[cache_key] del self.access_times[cache_key] del self.creation_times[cache_key] - del self.file_counts[cache_key] removed_count += 1 logger.info(f"清理过期的助手实例缓存: {cache_key}") except KeyError: @@ -73,8 +71,7 @@ class FileLoadedAgentManager: logger.info(f"已清理 {removed_count} 个过期的助手实例缓存") async def get_or_create_agent(self, - zip_url: str, - files: List[str], + unique_id: str, project_dir: str, model_name: str = "qwen3-next", api_key: Optional[str] = None, @@ -83,7 +80,7 @@ class FileLoadedAgentManager: """获取或创建文件预加载的助手实例 Args: - zip_url: ZIP 文件的 URL + unique_id: 项目的唯一标识符 files: 需要预加载的文件路径列表 project_dir: 项目目录路径,用于读取system_prompt.md和mcp_settings.json model_name: 模型名称 @@ -97,12 +94,16 @@ class FileLoadedAgentManager: import os import json - # 从项目目录读取system_prompt.md和mcp_settings.json + # 读取system_prompt:优先从项目目录读取,然后降级到全局配置 + # 降级到全局配置 system_prompt_template = "" - system_prompt_path = os.path.join(project_dir, "system_prompt.md") - if os.path.exists(system_prompt_path): - with open(system_prompt_path, "r", encoding="utf-8") as f: - system_prompt_template = f.read().strip() + # 尝试从项目目录读取 + system_prompt_file = 
os.path.join(project_dir, "system_prompt.md") + if not os.path.exists(system_prompt_file): + system_prompt_file = "./system_prompt.md" + + with open(system_prompt_file, "r", encoding="utf-8") as f: + system_prompt_template = f.read().strip() readme = "" readme_path = os.path.join(project_dir, "README.md") @@ -110,55 +111,69 @@ class FileLoadedAgentManager: with open(readme_path, "r", encoding="utf-8") as f: readme = f.read().strip() dataset_dir = os.path.join(project_dir, "dataset") - - system_prompt = system_prompt_template.replace("{dataset_dir}", str(dataset_dir)).replace("{readme}", str(readme)) + + final_system_prompt = system_prompt_template.replace("{dataset_dir}", str(dataset_dir)).replace("{readme}", str(readme)) + logger.info(f"Loaded global system_prompt for unique_id: {unique_id}") + if not final_system_prompt: + logger.info(f"No system_prompt found for unique_id: {unique_id}") - mcp_settings = {} - mcp_settings_path = os.path.join(project_dir, "mcp_settings.json") - if os.path.exists(mcp_settings_path): - with open(mcp_settings_path, "r", encoding="utf-8") as f: - mcp_settings = json.load(f) + # 读取mcp_settings:优先从项目目录读取,然后降级到全局配置 + final_mcp_settings = None - cache_key = self._get_zip_url_hash(zip_url) + # 尝试从项目目录读取 + mcp_settings_file = os.path.join(project_dir, "mcp_settings.json") + if os.path.exists(mcp_settings_file): + with open(mcp_settings_file, 'r', encoding='utf-8') as f: + final_mcp_settings = json.load(f) + logger.info(f"Loaded mcp_settings from project directory for unique_id: {unique_id}") + else: + # 降级到全局配置 + mcp_settings_path = "./mcp/mcp_settings.json" + if os.path.exists(mcp_settings_path): + with open(mcp_settings_path, "r", encoding="utf-8") as f: + final_mcp_settings = json.load(f) + logger.info(f"Loaded global mcp_settings for unique_id: {unique_id}") + else: + final_mcp_settings = [] + logger.info(f"No mcp_settings found for unique_id: {unique_id}") + + if final_mcp_settings is None: + final_mcp_settings = [] + + cache_key = self._get_cache_key(unique_id) # 检查是否已存在该助手实例 if cache_key in self.agents: self._update_access_time(cache_key) agent = self.agents[cache_key] - # 动态更新 LLM 配置(如果参数有变化) - update_agent_llm(agent, model_name, api_key, model_server, generate_cfg) + # 动态更新 LLM 配置和系统设置(如果参数有变化) + update_agent_llm(agent, model_name, api_key, model_server, generate_cfg, final_system_prompt, final_mcp_settings) - # 如果从项目目录读取到了system_prompt,更新agent的系统消息 - if system_prompt: - agent.system_message = system_prompt - - logger.info(f"复用现有的助手实例缓存: {cache_key} (文件数: {len(files)})") + logger.info(f"复用现有的助手实例缓存: {cache_key} (unique_id: {unique_id}") return agent # 清理过期实例 self._cleanup_old_agents() # 创建新的助手实例,预加载文件 - logger.info(f"创建新的助手实例缓存: {cache_key}, 预加载文件数: {len(files)}") + logger.info(f"创建新的助手实例缓存: {cache_key}, unique_id: {unique_id}") current_time = time.time() agent = init_agent_service_with_files( - files=files, model_name=model_name, api_key=api_key, model_server=model_server, generate_cfg=generate_cfg, - system_prompt=system_prompt, - mcp=mcp_settings + system_prompt=final_system_prompt, + mcp=final_mcp_settings ) # 缓存实例 self.agents[cache_key] = agent - self.zip_urls[cache_key] = zip_url + self.unique_ids[cache_key] = unique_id self.access_times[cache_key] = current_time self.creation_times[cache_key] = current_time - self.file_counts[cache_key] = len(files) logger.info(f"助手实例缓存创建完成: {cache_key}") return agent @@ -174,8 +189,7 @@ class FileLoadedAgentManager: for cache_key, agent in self.agents.items(): stats["agents"][cache_key] = { - "zip_url": 
self.zip_urls.get(cache_key, "unknown"), - "file_count": self.file_counts.get(cache_key, 0), + "unique_id": self.unique_ids.get(cache_key, "unknown"), "created_at": self.creation_times.get(cache_key, 0), "last_accessed": self.access_times.get(cache_key, 0), "age_seconds": int(current_time - self.creation_times.get(cache_key, current_time)), @@ -193,40 +207,34 @@ class FileLoadedAgentManager: cache_count = len(self.agents) self.agents.clear() - self.zip_urls.clear() + self.unique_ids.clear() self.access_times.clear() self.creation_times.clear() - self.file_counts.clear() logger.info(f"已清空所有助手实例缓存,共清理 {cache_count} 个实例") return cache_count - def remove_cache_by_url(self, zip_url: str) -> bool: - """根据 ZIP URL 移除特定的缓存 + def remove_cache_by_unique_id(self, unique_id: str) -> bool: + """根据 unique_id 移除特定的缓存 Args: - zip_url: ZIP 文件 URL + unique_id: 项目的唯一标识符 Returns: bool: 是否成功移除 """ - cache_key = self._get_zip_url_hash(zip_url) + cache_key = self._get_cache_key(unique_id) if cache_key in self.agents: del self.agents[cache_key] - del self.zip_urls[cache_key] + del self.unique_ids[cache_key] del self.access_times[cache_key] del self.creation_times[cache_key] - del self.file_counts[cache_key] - logger.info(f"已移除特定 ZIP URL 的助手实例缓存: {zip_url}") + logger.info(f"已移除特定 unique_id 的助手实例缓存: {unique_id}") return True return False - - def list_cached_zip_urls(self) -> List[str]: - """列出所有缓存的 ZIP URL""" - return list(self.zip_urls.values()) # 全局文件预加载助手管理器实例 diff --git a/gbase_agent.py b/gbase_agent.py index dd0de33..6740722 100644 --- a/gbase_agent.py +++ b/gbase_agent.py @@ -103,7 +103,7 @@ def init_agent_service_universal(): return init_agent_service_with_files(files=None) -def init_agent_service_with_files(files: Optional[List[str]] = None, rag_cfg: Optional[Dict] = None, +def init_agent_service_with_files(rag_cfg: Optional[Dict] = None, model_name: str = "qwen3-next", api_key: Optional[str] = None, model_server: Optional[str] = None, generate_cfg: Optional[Dict] = None, system_prompt: Optional[str] = None, mcp: Optional[List[Dict]] = None): @@ -160,8 +160,8 @@ def init_agent_service_with_files(files: Optional[List[str]] = None, rag_cfg: Op return bot -def update_agent_llm(agent, model_name: str, api_key: str = None, model_server: str = None,generate_cfg: Dict = None): - """动态更新助手实例的LLM,支持从接口传入参数""" +def update_agent_llm(agent, model_name: str, api_key: str = None, model_server: str = None, generate_cfg: Dict = None, system_prompt: str = None, mcp_settings: List[Dict] = None): + """动态更新助手实例的LLM和配置,支持从接口传入参数""" # 获取基础配置 llm_config = { @@ -181,6 +181,14 @@ def update_agent_llm(agent, model_name: str, api_key: str = None, model_server: # 动态设置LLM agent.llm = llm_instance + # 更新系统消息(如果提供) + if system_prompt: + agent.system_message = system_prompt + + # 更新MCP设置(如果提供) + if mcp_settings: + agent.function_list = mcp_settings + return agent diff --git a/mcp/mcp_settings.json b/mcp/mcp_settings.json index aa319d9..7cb1a07 100644 --- a/mcp/mcp_settings.json +++ b/mcp/mcp_settings.json @@ -5,10 +5,10 @@ "command": "mcp-ripgrep", "args": [] }, - "json-reader": { + "semantic_search": { "command": "python", "args": [ - "./mcp/json_reader_server.py" + "./mcp/semantic_search_server.py" ] }, "multi-keyword-search": { diff --git a/mcp/multi_keyword_search_server.py b/mcp/multi_keyword_search_server.py index 616a6c4..c45eb9f 100644 --- a/mcp/multi_keyword_search_server.py +++ b/mcp/multi_keyword_search_server.py @@ -125,13 +125,20 @@ def multi_keyword_search(keywords: List[str], file_paths: List[str], try: # 解析相对路径 if not 
os.path.isabs(file_path): + # 移除 projects/ 前缀(如果存在) + clean_path = file_path + if clean_path.startswith('projects/'): + clean_path = clean_path[9:] # 移除 'projects/' 前缀 + elif clean_path.startswith('./projects/'): + clean_path = clean_path[11:] # 移除 './projects/' 前缀 + # 尝试在项目目录中查找文件 - full_path = os.path.join(project_data_dir, file_path.lstrip('./')) + full_path = os.path.join(project_data_dir, clean_path.lstrip('./')) if os.path.exists(full_path): valid_paths.append(full_path) else: # 如果直接路径不存在,尝试递归查找 - found = find_file_in_project(file_path, project_data_dir) + found = find_file_in_project(clean_path, project_data_dir) if found: valid_paths.append(found) else: diff --git a/mcp/semantic_search_server.py b/mcp/semantic_search_server.py index ef90174..64ca075 100644 --- a/mcp/semantic_search_server.py +++ b/mcp/semantic_search_server.py @@ -87,11 +87,18 @@ def semantic_search(query: str, embeddings_file: str, top_k: int = 20) -> Dict[s try: # 解析相对路径 if not os.path.isabs(embeddings_file): + # 移除 projects/ 前缀(如果存在) + clean_path = embeddings_file + if clean_path.startswith('projects/'): + clean_path = clean_path[9:] # 移除 'projects/' 前缀 + elif clean_path.startswith('./projects/'): + clean_path = clean_path[11:] # 移除 './projects/' 前缀 + # 尝试在项目目录中查找文件 - full_path = os.path.join(project_data_dir, embeddings_file.lstrip('./')) + full_path = os.path.join(project_data_dir, clean_path.lstrip('./')) if not os.path.exists(full_path): # 如果直接路径不存在,尝试递归查找 - found = find_file_in_project(embeddings_file, project_data_dir) + found = find_file_in_project(clean_path, project_data_dir) if found: embeddings_file = found else: diff --git a/organize_dataset_files.py b/organize_dataset_files.py new file mode 100644 index 0000000..638492b --- /dev/null +++ b/organize_dataset_files.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +import os +import shutil +from pathlib import Path + +def is_file_already_processed(target_file: Path, pagination_file: Path, embeddings_file: Path) -> bool: + """Check if a file has already been processed (document.txt, pagination.txt, and embeddings exist)""" + if not target_file.exists(): + return False + + # Check if pagination and embeddings files exist and are not empty + if pagination_file.exists() and embeddings_file.exists(): + # Check file sizes to ensure they're not empty + if pagination_file.stat().st_size > 0 and embeddings_file.stat().st_size > 0: + return True + + return False + +def organize_single_project_files(unique_id: str, skip_processed=True): + """Organize files for a single project from projects/{unique_id}/files to projects/{unique_id}/dataset/{file_name}/document.txt""" + + project_dir = Path("projects") / unique_id + + if not project_dir.exists(): + print(f"Project directory not found: {project_dir}") + return + + print(f"Organizing files for project: {unique_id} (skip_processed={skip_processed})") + + files_dir = project_dir / "files" + dataset_dir = project_dir / "dataset" + + # Check if files directory exists and has files + if not files_dir.exists(): + print(f" No files directory found, skipping...") + return + + files = list(files_dir.glob("*")) + if not files: + print(f" Files directory is empty, skipping...") + return + + # Create dataset directory if it doesn't exist + dataset_dir.mkdir(exist_ok=True) + + # Copy each file to its own directory + for file_path in files: + if file_path.is_file(): + # Get filename without extension as directory name + file_name_without_ext = file_path.stem + target_dir = dataset_dir / file_name_without_ext + target_file = target_dir 
/ "document.txt" + pagination_file = target_dir / "pagination.txt" + embeddings_file = target_dir / "document_embeddings.pkl" + + # Check if file is already processed + if skip_processed and is_file_already_processed(target_file, pagination_file, embeddings_file): + print(f" Skipping already processed file: {file_path.name}") + continue + + print(f" Copying {file_path.name} -> {target_file.relative_to(project_dir)}") + + # Create target directory + target_dir.mkdir(exist_ok=True) + + # Copy and rename file + shutil.copy2(str(file_path), str(target_file)) + + print(f" Files remain in original location (copied to dataset structure)") + + # Process each document.txt file: split pages and generate embeddings + if not skip_processed: + import sys + sys.path.append(os.path.join(os.path.dirname(__file__), 'embedding')) + + from embedding import split_document_by_pages, embed_document + + for file_path in files: + if file_path.is_file(): + file_name_without_ext = file_path.stem + target_dir = dataset_dir / file_name_without_ext + document_file = target_dir / "document.txt" + pagination_file = target_dir / "pagination.txt" + embeddings_file = target_dir / "document_embeddings.pkl" + + # Skip if already processed + if is_file_already_processed(document_file, pagination_file, embeddings_file): + print(f" Skipping document processing for already processed file: {file_path.name}") + continue + + # Split document by pages + print(f" Splitting pages for {document_file.name}") + try: + pages = split_document_by_pages(str(document_file), str(pagination_file)) + print(f" Generated {len(pages)} pages") + except Exception as e: + print(f" Failed to split pages: {e}") + continue + + # Generate embeddings + print(f" Generating embeddings for {document_file.name}") + try: + # Set local model path for faster processing + local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2" + if not os.path.exists(local_model_path): + local_model_path = None # Fallback to HuggingFace model + + embedding_data = embed_document( + str(document_file), + str(embeddings_file), + chunking_strategy='smart', + model_path=local_model_path, + max_chunk_size=800, + overlap=100 + ) + + if embedding_data: + print(f" Generated embeddings for {len(embedding_data['chunks'])} chunks") + else: + print(f" Failed to generate embeddings") + except Exception as e: + print(f" Failed to generate embeddings: {e}") + + print(f" Document processing completed for project {unique_id}") + else: + print(f" Skipping document processing (skip_processed=True)") + + +def organize_dataset_files(): + """Move files from projects/{unique_id}/files to projects/{unique_id}/dataset/{file_name}/document.txt""" + + projects_dir = Path("projects") + + if not projects_dir.exists(): + print("Projects directory not found") + return + + # Get all project directories (exclude cache and other non-project dirs) + project_dirs = [d for d in projects_dir.iterdir() + if d.is_dir() and d.name != "_cache" and not d.name.startswith(".")] + + for project_dir in project_dirs: + print(f"\nProcessing project: {project_dir.name}") + + files_dir = project_dir / "files" + dataset_dir = project_dir / "dataset" + + # Check if files directory exists and has files + if not files_dir.exists(): + print(f" No files directory found, skipping...") + continue + + files = list(files_dir.glob("*")) + if not files: + print(f" Files directory is empty, skipping...") + continue + + # Create dataset directory if it doesn't exist + dataset_dir.mkdir(exist_ok=True) + + # Move each file to its own 
directory + for file_path in files: + if file_path.is_file(): + # Get filename without extension as directory name + file_name_without_ext = file_path.stem + target_dir = dataset_dir / file_name_without_ext + target_file = target_dir / "document.txt" + + print(f" Copying {file_path.name} -> {target_file.relative_to(project_dir)}") + + # Create target directory + target_dir.mkdir(exist_ok=True) + + # Copy and rename file + shutil.copy2(str(file_path), str(target_file)) + + print(f" Files remain in original location (copied to dataset structure)") + + print("\nFile organization complete!") + +if __name__ == "__main__": + organize_dataset_files() \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index b667448..9b248e6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,17 @@ # This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +[[package]] +name = "aiofiles" +version = "25.1.0" +description = "File support for asyncio." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "aiofiles-25.1.0-py3-none-any.whl", hash = "sha256:abe311e527c862958650f9438e859c1fa7568a141b22abcd015e120e86a85695"}, + {file = "aiofiles-25.1.0.tar.gz", hash = "sha256:a8d728f0a29de45dc521f18f07297428d56992a742f0cd2701ba86e44d23d5b2"}, +] + [[package]] name = "aiohappyeyeballs" version = "2.6.1" @@ -3949,4 +3961,4 @@ propcache = ">=0.2.1" [metadata] lock-version = "2.1" python-versions = "3.12.0" -content-hash = "e3237e20c799a13a0854f786a34a7c7cb2d5020902281c92ebff2e497492edd1" +content-hash = "06c3b78c8107692eb5944b144ae4df02862fa5e4e8a198f6ccfa07c6743a49cf" diff --git a/pyproject.toml b/pyproject.toml index f48a15b..a6d4049 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,8 @@ dependencies = [ "transformers", "sentence-transformers", "numpy<2", + "aiohttp", + "aiofiles", ] diff --git a/unique_map.json b/unique_map.json deleted file mode 100644 index 88438e9..0000000 --- a/unique_map.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "b743ccc3-13be-43ea-8ec9-4ce9c86103b3": [ - "public/all_hp_product_spec_book2506.txt" - ] -}
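
The patch replaces the zip_url flow with a two-step, unique_id-based flow: register the dataset files once via /api/v1/files/process, then chat against the same unique_id via /api/v1/chat/completions. A minimal client sketch of that flow follows; the host/port, API key, unique_id and file URL are placeholders and the `requests` package is assumed to be available, while the endpoint paths and request fields are the ones defined in the handlers above.

# Minimal client sketch for the new unique_id-based flow (placeholders marked).
import requests

BASE = "http://localhost:8000"                        # assumed host/port
HEADERS = {"Authorization": "Bearer sk-placeholder"}  # placeholder API key
unique_id = "demo-project-001"                        # placeholder project id

# 1) Register dataset files: the server downloads/copies them, organizes the
#    dataset/ structure, updates processed_files.json and regenerates README.md.
resp = requests.post(
    f"{BASE}/api/v1/files/process",
    json={
        "unique_id": unique_id,
        "files": ["https://example.com/spec_book.txt"],   # placeholder URL
        "system_prompt": "You answer questions about the attached documents.",
    },
    headers=HEADERS,
)
print(resp.json())   # {"success": true, "processed_files": [...], ...}

# 2) Optionally inspect the processing state of the project.
print(requests.get(f"{BASE}/api/v1/files/{unique_id}/status").json())

# 3) Chat against the project; zip_url is gone, unique_id is now required.
resp = requests.post(
    f"{BASE}/api/v1/chat/completions",
    json={
        "model": "qwen3-next",
        "unique_id": unique_id,
        "stream": False,
        "messages": [{"role": "user", "content": "Summarize the documents."}],
    },
    headers=HEADERS,
)
print(resp.json())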
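
After a successful /api/v1/files/process call, the code above expects a project directory shaped roughly like the following (file names are illustrative; the structure comes from download_dataset_files, organize_single_project_files and save_project_readme):

projects/<unique_id>/
├── files/                       # raw downloaded or copied source files
│   └── spec_book.txt
├── dataset/
│   └── spec_book/
│       ├── document.txt             # copied source, one directory per file stem
│       ├── pagination.txt           # present once split_document_by_pages has run
│       └── document_embeddings.pkl  # present once embed_document has run
├── processed_files.json         # sync log keyed by md5(original path/URL)
├── system_prompt.md             # optional per-project override
├── mcp_settings.json            # optional per-project override
└── README.md                    # regenerated by save_project_readme()

Note that download_dataset_files calls organize_single_project_files with skip_processed=True, so the pagination and embedding files are only produced when the document-processing step is run separately with skip_processed=False.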
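
The incremental behaviour of download_dataset_files hinges on that processed_files.json log: every source path or URL is keyed by the MD5 of the string itself, and set arithmetic over the hashes decides what to delete, skip or fetch. A condensed, standalone sketch of the bookkeeping (the real function additionally checks that the logged local_path still exists before skipping a file):

import hashlib
from typing import Dict, List, Tuple

def file_hash(path: str) -> str:
    # same scheme as get_file_hash(): hash of the path string, not of the file contents
    return hashlib.md5(path.encode("utf-8")).hexdigest()

def classify(files: List[str], processed_log: Dict[str, dict]) -> Tuple[List[str], List[str], List[str]]:
    incoming = {file_hash(p): p for p in files}
    stale = set(processed_log) - set(incoming)                          # logged but no longer requested: remove
    reused = [p for h, p in incoming.items() if h in processed_log]     # already processed: skip
    fresh = [p for h, p in incoming.items() if h not in processed_log]  # new: download or copy
    return sorted(stale), reused, fresh

log = {file_hash("https://example.com/a.txt"): {"filename": "a.txt"}}
print(classify(["https://example.com/a.txt", "local/b.txt"], log))
# -> ([], ['https://example.com/a.txt'], ['local/b.txt'])

Because the key is derived from the path string rather than the file contents, re-publishing new content under an identical URL will not trigger a re-download; changing the URL or calling /api/v1/files/{unique_id}/reset forces reprocessing.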
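
get_or_create_agent no longer receives a preloaded file list; instead it resolves system_prompt and mcp_settings per project, falling back to the repo-level defaults. The resolution order implemented in the patch is equivalent to this small sketch (like the patch, it assumes a global ./system_prompt.md exists when the project has no override):

import json
import os

def resolve_project_config(project_dir: str) -> tuple[str, list]:
    # system_prompt.md: project copy first, otherwise the global ./system_prompt.md
    prompt_path = os.path.join(project_dir, "system_prompt.md")
    if not os.path.exists(prompt_path):
        prompt_path = "./system_prompt.md"   # read unconditionally, as in the patch
    with open(prompt_path, "r", encoding="utf-8") as f:
        system_prompt = f.read().strip()

    # mcp_settings.json: project copy, then ./mcp/mcp_settings.json, then []
    mcp_path = os.path.join(project_dir, "mcp_settings.json")
    if not os.path.exists(mcp_path):
        mcp_path = "./mcp/mcp_settings.json"
    mcp_settings: list = []
    if os.path.exists(mcp_path):
        with open(mcp_path, "r", encoding="utf-8") as f:
            mcp_settings = json.load(f)
    return system_prompt, mcp_settings

The resolved template is then formatted by substituting {dataset_dir} and {readme} before being handed to init_agent_service_with_files or, for cached agents, to update_agent_llm.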
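
Both MCP servers (multi_keyword_search_server.py and semantic_search_server.py) now apply the same prefix stripping before resolving a relative path against project_data_dir. A shared helper equivalent to the two duplicated branches could look like this (hypothetical refactor, not part of the patch):

import os

def normalize_project_path(path: str, project_data_dir: str) -> str:
    """Strip a leading 'projects/' or './projects/' and resolve against the data dir."""
    if os.path.isabs(path):
        return path
    clean = path
    if clean.startswith("./projects/"):
        clean = clean[len("./projects/"):]
    elif clean.startswith("projects/"):
        clean = clean[len("projects/"):]
    return os.path.join(project_data_dir, clean.lstrip("./"))

# "projects/<id>/dataset/spec/document.txt" and "./projects/<id>/dataset/spec/document.txt"
# then resolve to the same location, with find_file_in_project() still used as the
# fallback when the joined path does not exist.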