diff --git a/fastapi_app.py b/fastapi_app.py
index bf6551c..b1479dd 100644
--- a/fastapi_app.py
+++ b/fastapi_app.py
@@ -20,7 +20,7 @@ from pydantic import BaseModel, Field
 # Import utility modules
 from utils import (
     # Models
-    Message, DatasetRequest, ChatRequest, ChatResponse, QueueTaskRequest, QueueTaskResponse,
+    Message, DatasetRequest, ChatRequest, ChatResponse, QueueTaskRequest, IncrementalTaskRequest, QueueTaskResponse,
     QueueStatusResponse, TaskStatusResponse,
 
     # File utilities
@@ -47,7 +47,7 @@ from modified_assistant import update_agent_llm
 
 # Import queue manager
 from task_queue.manager import queue_manager
-from task_queue.integration_tasks import process_files_async, cleanup_project_async
+from task_queue.integration_tasks import process_files_async, process_files_incremental_async, cleanup_project_async
 from task_queue.task_status import task_status_store
 
 import re
@@ -264,6 +264,69 @@ async def process_files_async_endpoint(request: QueueTaskRequest, authorization:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 
+@app.post("/api/v1/files/process/incremental")
+async def process_files_incremental_endpoint(request: IncrementalTaskRequest, authorization: Optional[str] = Header(None)):
+    """
+    Queue-backed API for incremental file processing - supports adding and removing files
+
+    Args:
+        request: IncrementalTaskRequest containing dataset_id, files_to_add, files_to_remove, system_prompt, mcp_settings, and queue options
+        authorization: Authorization header containing API key (Bearer token)
+
+    Returns:
+        QueueTaskResponse: Processing result with task ID for tracking
+    """
+    try:
+        dataset_id = request.dataset_id
+        if not dataset_id:
+            raise HTTPException(status_code=400, detail="dataset_id is required")
+
+        # Require at least one add or remove operation
+        if not request.files_to_add and not request.files_to_remove:
+            raise HTTPException(status_code=400, detail="At least one of files_to_add or files_to_remove must be provided")
+
+        # Estimate the processing time (based on the number of files)
+        estimated_time = 0
+        total_add_files = sum(len(file_list) for file_list in (request.files_to_add or {}).values())
+        total_remove_files = sum(len(file_list) for file_list in (request.files_to_remove or {}).values())
+        total_files = total_add_files + total_remove_files
+        estimated_time = max(30, total_files * 10)  # Roughly 10 seconds per file, 30 seconds minimum
+
+        # Create the task status record
+        import uuid
+        task_id = str(uuid.uuid4())
+        task_status_store.set_status(
+            task_id=task_id,
+            unique_id=dataset_id,
+            status="pending"
+        )
+
+        # Submit the incremental async task
+        task = process_files_incremental_async(
+            dataset_id=dataset_id,
+            files_to_add=request.files_to_add,
+            files_to_remove=request.files_to_remove,
+            system_prompt=request.system_prompt,
+            mcp_settings=request.mcp_settings,
+            task_id=task_id
+        )
+
+        return QueueTaskResponse(
+            success=True,
+            message=f"Incremental file processing task submitted to queue - adding {total_add_files} file(s), removing {total_remove_files} file(s), project ID: {dataset_id}",
+            unique_id=dataset_id,
+            task_id=task_id,
+            task_status="pending",
+            estimated_processing_time=estimated_time
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        print(f"Error submitting incremental file processing task: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+
 @app.get("/api/v1/task/{task_id}/status")
 async def get_task_status(task_id: str):
     """Get task status - simple and reliable"""
diff --git a/task_queue/integration_tasks.py b/task_queue/integration_tasks.py
index 3104701..1bb59cc 100644
--- a/task_queue/integration_tasks.py
+++ b/task_queue/integration_tasks.py
@@ -11,7 +11,8 @@ from typing import Dict, List, Optional, Any
 from task_queue.config import huey
 from task_queue.manager import queue_manager
 from task_queue.task_status import task_status_store
-from utils import download_dataset_files, save_processed_files_log
+from utils import download_dataset_files, save_processed_files_log, load_processed_files_log
+from utils.dataset_manager import remove_dataset_directory_by_key
 
 
 @huey.task()
@@ -154,6 +155,196 @@ def process_files_async(
         }
 
 
+@huey.task()
+def process_files_incremental_async(
+    dataset_id: str,
+    files_to_add: Optional[Dict[str, List[str]]] = None,
+    files_to_remove: Optional[Dict[str, List[str]]] = None,
+    system_prompt: Optional[str] = None,
+    mcp_settings: Optional[List[Dict]] = None,
+    task_id: Optional[str] = None
+) -> Dict[str, Any]:
+    """
+    Incremental file processing task - supports adding and removing files
+
+    Args:
+        dataset_id: Unique project ID
+        files_to_add: Dict of file paths to add, grouped by key
+        files_to_remove: Dict of file paths to remove, grouped by key
+        system_prompt: System prompt
+        mcp_settings: MCP settings
+        task_id: Task ID (used for status tracking)
+
+    Returns:
+        Dict describing the processing result
+    """
+    try:
+        print(f"Starting incremental file processing task, project ID: {dataset_id}")
+
+        # If a task_id was provided, set the initial status
+        if task_id:
+            task_status_store.set_status(
+                task_id=task_id,
+                unique_id=dataset_id,
+                status="running"
+            )
+
+        # Make sure the project directory exists
+        project_dir = os.path.join("projects", "data", dataset_id)
+        if not os.path.exists(project_dir):
+            os.makedirs(project_dir, exist_ok=True)
+
+        # Load the existing processed-files log
+        processed_log = load_processed_files_log(dataset_id)
+        print(f"Loaded existing processing log with {len(processed_log)} file records")
+
+        removed_files = []
+        added_files = []
+
+        # 1. Handle removals
+        if files_to_remove:
+            print(f"Starting removal operations across {len(files_to_remove)} key groups")
+            for key, file_list in files_to_remove.items():
+                if not file_list:  # An empty file list removes the entire key group
+                    print(f"Removing entire key group: {key}")
+                    if remove_dataset_directory_by_key(dataset_id, key):
+                        removed_files.append(f"dataset/{key}")
+
+                    # Remove every log record belonging to this key
+                    keys_to_remove = [file_hash for file_hash, file_info in processed_log.items()
+                                      if file_info.get('key') == key]
+                    for file_hash in keys_to_remove:
+                        del processed_log[file_hash]
+                        removed_files.append(f"log_entry:{file_hash}")
+                else:
+                    # Remove specific files
+                    for file_path in file_list:
+                        print(f"Removing file: {key}/{file_path}")
+                        # Compute the file hash to look it up in the log
+                        import hashlib
+                        file_hash = hashlib.md5(file_path.encode('utf-8')).hexdigest()
+
+                        # Drop it from the processing log
+                        if file_hash in processed_log:
+                            del processed_log[file_hash]
+                            removed_files.append(f"log_entry:{file_hash}")
+
+        # 2. Handle additions
+        processed_files_by_key = {}
+        if files_to_add:
+            print(f"Starting add operations across {len(files_to_add)} key groups")
+            # Download the files via the async helper
+            import asyncio
+            try:
+                loop = asyncio.get_event_loop()
+            except RuntimeError:
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+
+            processed_files_by_key = loop.run_until_complete(download_dataset_files(dataset_id, files_to_add))
+            total_added_files = sum(len(files_list) for files_list in processed_files_by_key.values())
+            print(f"Asynchronously processed {total_added_files} dataset files across {len(processed_files_by_key)} keys, project ID: {dataset_id}")
+
+            # Record the added files
+            for key, files_list in processed_files_by_key.items():
+                for file_path in files_list:
+                    added_files.append(f"{key}/{file_path}")
+        else:
+            print(f"No files to add were provided in the request, project ID: {dataset_id}")
+
+        # Save the updated processing log
+        save_processed_files_log(dataset_id, processed_log)
+        print(f"Processing log updated, it now contains {len(processed_log)} file records")
+
+        # Save system_prompt and mcp_settings to the project directory (if provided)
+        if system_prompt:
+            system_prompt_file = os.path.join(project_dir, "system_prompt.md")
+            with open(system_prompt_file, 'w', encoding='utf-8') as f:
+                f.write(system_prompt)
+            print(f"Saved system_prompt, project ID: {dataset_id}")
+
+        if mcp_settings:
+            mcp_settings_file = os.path.join(project_dir, "mcp_settings.json")
+            with open(mcp_settings_file, 'w', encoding='utf-8') as f:
+                json.dump(mcp_settings, f, ensure_ascii=False, indent=2)
+            print(f"Saved mcp_settings, project ID: {dataset_id}")
+
+        # Generate the project README.md
+        try:
+            from utils.project_manager import save_project_readme
+            save_project_readme(dataset_id)
+            print(f"Generated README.md, project ID: {dataset_id}")
+        except Exception as e:
+            print(f"Failed to generate README.md, project ID: {dataset_id}, error: {str(e)}")
+            # Does not affect the main processing flow; keep going
+
+        # Collect every document.txt file under the project directory
+        document_files = []
+        for root, dirs, files_list in os.walk(project_dir):
+            for file in files_list:
+                if file == "document.txt":
+                    document_files.append(os.path.join(root, file))
+
+        # Build the result file list
+        result_files = []
+        for key in processed_files_by_key.keys():
+            # Add the matching dataset document.txt path
+            document_path = os.path.join("projects", "data", dataset_id, "dataset", key, "document.txt")
+            if os.path.exists(document_path):
+                result_files.append(document_path)
+
+        # Also include document.txt files that exist but are not in processed_files_by_key
+        existing_document_paths = set(result_files)  # Avoid duplicates
+        for doc_file in document_files:
+            if doc_file not in existing_document_paths:
+                result_files.append(doc_file)
+
+        result = {
+            "status": "success",
+            "message": f"Incremental processing complete - added {len(added_files)} file(s), removed {len(removed_files)} file(s), {len(result_files)} document file(s) retained",
+            "dataset_id": dataset_id,
+            "removed_files": removed_files,
+            "added_files": added_files,
+            "processed_files": result_files,
+            "processed_files_by_key": processed_files_by_key,
+            "document_files": document_files,
+            "total_files_added": sum(len(files_list) for files_list in processed_files_by_key.values()),
+            "total_files_removed": len(removed_files),
+            "final_files_count": len(result_files),
+            "processing_time": time.time()
+        }
+
+        # Mark the task as completed
+        if task_id:
+            task_status_store.update_status(
+                task_id=task_id,
+                status="completed",
+                result=result
+            )
+
+        print(f"Incremental file processing task finished: {dataset_id}")
+        return result
+
+    except Exception as e:
+        error_msg = f"Error during incremental file processing: {str(e)}"
+        print(error_msg)
+
+        # Mark the task as failed
+        if task_id:
+            task_status_store.update_status(
+                task_id=task_id,
+                status="failed",
+                error=error_msg
+            )
+
+        return {
+            "status": "error",
+            "message": error_msg,
+            "dataset_id": dataset_id,
+            "error": str(e)
+        }
+
+
 @huey.task()
 def cleanup_project_async(
     unique_id: str,
diff --git a/utils/__init__.py b/utils/__init__.py
index 51ea8ab..fd14a47 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -68,6 +68,7 @@ from .api_models import (
     ProjectStatsResponse,
     ProjectActionResponse,
     QueueTaskRequest,
+    IncrementalTaskRequest,
     QueueTaskResponse,
     QueueStatusResponse,
     TaskStatusResponse,
diff --git a/utils/api_models.py b/utils/api_models.py
index 12a8555..e15a9c9 100644
--- a/utils/api_models.py
+++ b/utils/api_models.py
@@ -245,6 +245,57 @@ class QueueTaskRequest(BaseModel):
             raise ValueError(f"Files must be a dict with key groups, got {type(v)}")
 
 
+class IncrementalTaskRequest(BaseModel):
+    """Incremental file processing request model"""
+    dataset_id: str = Field(..., description="Dataset ID for the project")
+    files_to_add: Optional[Dict[str, List[str]]] = Field(default=None, description="Files to add organized by key groups")
+    files_to_remove: Optional[Dict[str, List[str]]] = Field(default=None, description="Files to remove organized by key groups")
+    system_prompt: Optional[str] = None
+    mcp_settings: Optional[List[Dict]] = None
+    priority: Optional[int] = Field(default=0, description="Task priority (higher number = higher priority)")
+    delay: Optional[int] = Field(default=0, description="Delay execution by N seconds")
+
+    model_config = ConfigDict(extra='allow')
+
+    @field_validator('files_to_add', mode='before')
+    @classmethod
+    def validate_files_to_add(cls, v):
+        """Validate files_to_add dict format"""
+        if v is None:
+            return None
+        if isinstance(v, dict):
+            for key, value in v.items():
+                if not isinstance(key, str):
+                    raise ValueError(f"Key in files_to_add dict must be string, got {type(key)}")
+                if not isinstance(value, list):
+                    raise ValueError(f"Value in files_to_add dict must be list, got {type(value)} for key '{key}'")
+                for item in value:
+                    if not isinstance(item, str):
+                        raise ValueError(f"File paths must be strings, got {type(item)} in key '{key}'")
+            return v
+        else:
+            raise ValueError(f"files_to_add must be a dict with key groups, got {type(v)}")
+
+    @field_validator('files_to_remove', mode='before')
+    @classmethod
+    def validate_files_to_remove(cls, v):
+        """Validate files_to_remove dict format"""
+        if v is None:
+            return None
+        if isinstance(v, dict):
+            for key, value in v.items():
+                if not isinstance(key, str):
+                    raise ValueError(f"Key in files_to_remove dict must be string, got {type(key)}")
+                if not isinstance(value, list):
+                    raise ValueError(f"Value in files_to_remove dict must be list, got {type(value)} for key '{key}'")
+                for item in value:
+                    if not isinstance(item, str):
+                        raise ValueError(f"File paths must be strings, got {type(item)} in key '{key}'")
+            return v
+        else:
+            raise ValueError(f"files_to_remove must be a dict with key groups, got {type(v)}")
+
+
 class QueueTaskResponse(BaseModel):
     """Queue task response model"""
     success: bool
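
Usage sketch (not part of the diff): a minimal client showing how the new incremental endpoint and the existing status endpoint could be exercised together. The base URL, API key, dataset ID, key group names, and file paths below are placeholder assumptions, as is the exact JSON shape of the /status response; only the endpoint paths, request fields, and task states ("pending"/"completed"/"failed") come from the code above.

import time
import requests  # assumed HTTP client, not part of this repository

BASE_URL = "http://localhost:8000"                      # assumed local deployment
HEADERS = {"Authorization": "Bearer YOUR_API_KEY"}      # placeholder API key

# Add two files under the "docs" key group and drop the whole "legacy" group;
# an empty list for a key removes that entire key group (see the task above).
payload = {
    "dataset_id": "demo-dataset",
    "files_to_add": {"docs": ["manual.pdf", "faq.md"]},
    "files_to_remove": {"legacy": []},
}

resp = requests.post(f"{BASE_URL}/api/v1/files/process/incremental", json=payload, headers=HEADERS)
resp.raise_for_status()
task_id = resp.json()["task_id"]
print("Submitted incremental task:", task_id)

# Poll the task status endpoint until the worker reports a terminal state.
while True:
    status = requests.get(f"{BASE_URL}/api/v1/task/{task_id}/status", headers=HEADERS).json()
    if status.get("status") in ("completed", "failed"):
        print("Final status:", status)
        break
    time.sleep(5)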