add /api/v1/files/process/incremental
This commit is contained in:
parent
40aa71b966
commit
bff5817520
@ -20,7 +20,7 @@ from pydantic import BaseModel, Field
|
||||
# Import utility modules
|
||||
from utils import (
|
||||
# Models
|
||||
Message, DatasetRequest, ChatRequest, ChatResponse, QueueTaskRequest, QueueTaskResponse,
|
||||
Message, DatasetRequest, ChatRequest, ChatResponse, QueueTaskRequest, IncrementalTaskRequest, QueueTaskResponse,
|
||||
QueueStatusResponse, TaskStatusResponse,
|
||||
|
||||
# File utilities
|
||||
@ -47,7 +47,7 @@ from modified_assistant import update_agent_llm
|
||||
|
||||
# Import queue manager
|
||||
from task_queue.manager import queue_manager
|
||||
from task_queue.integration_tasks import process_files_async, cleanup_project_async
|
||||
from task_queue.integration_tasks import process_files_async, process_files_incremental_async, cleanup_project_async
|
||||
from task_queue.task_status import task_status_store
|
||||
import re
|
||||
|
||||
@ -264,6 +264,69 @@ async def process_files_async_endpoint(request: QueueTaskRequest, authorization:
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
|
||||
@app.post("/api/v1/files/process/incremental")
|
||||
async def process_files_incremental_endpoint(request: IncrementalTaskRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
增量处理文件的队列版本API - 支持添加和删除文件
|
||||
|
||||
Args:
|
||||
request: IncrementalTaskRequest containing dataset_id, files_to_add, files_to_remove, system_prompt, mcp_settings, and queue options
|
||||
authorization: Authorization header containing API key (Bearer <API_KEY>)
|
||||
|
||||
Returns:
|
||||
QueueTaskResponse: Processing result with task ID for tracking
|
||||
"""
|
||||
try:
|
||||
dataset_id = request.dataset_id
|
||||
if not dataset_id:
|
||||
raise HTTPException(status_code=400, detail="dataset_id is required")
|
||||
|
||||
# 验证至少有添加或删除操作
|
||||
if not request.files_to_add and not request.files_to_remove:
|
||||
raise HTTPException(status_code=400, detail="At least one of files_to_add or files_to_remove must be provided")
|
||||
|
||||
# 估算处理时间(基于文件数量)
|
||||
estimated_time = 0
|
||||
total_add_files = sum(len(file_list) for file_list in (request.files_to_add or {}).values())
|
||||
total_remove_files = sum(len(file_list) for file_list in (request.files_to_remove or {}).values())
|
||||
total_files = total_add_files + total_remove_files
|
||||
estimated_time = max(30, total_files * 10) # 每个文件预估10秒,最少30秒
|
||||
|
||||
# 创建任务状态记录
|
||||
import uuid
|
||||
task_id = str(uuid.uuid4())
|
||||
task_status_store.set_status(
|
||||
task_id=task_id,
|
||||
unique_id=dataset_id,
|
||||
status="pending"
|
||||
)
|
||||
|
||||
# 提交增量异步任务
|
||||
task = process_files_incremental_async(
|
||||
dataset_id=dataset_id,
|
||||
files_to_add=request.files_to_add,
|
||||
files_to_remove=request.files_to_remove,
|
||||
system_prompt=request.system_prompt,
|
||||
mcp_settings=request.mcp_settings,
|
||||
task_id=task_id
|
||||
)
|
||||
|
||||
return QueueTaskResponse(
|
||||
success=True,
|
||||
message=f"增量文件处理任务已提交到队列 - 添加 {total_add_files} 个文件,删除 {total_remove_files} 个文件,项目ID: {dataset_id}",
|
||||
unique_id=dataset_id,
|
||||
task_id=task_id,
|
||||
task_status="pending",
|
||||
estimated_processing_time=estimated_time
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"Error submitting incremental file processing task: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/api/v1/task/{task_id}/status")
|
||||
async def get_task_status(task_id: str):
|
||||
"""获取任务状态 - 简单可靠"""
|
||||
|
||||
@ -11,7 +11,8 @@ from typing import Dict, List, Optional, Any
|
||||
from task_queue.config import huey
|
||||
from task_queue.manager import queue_manager
|
||||
from task_queue.task_status import task_status_store
|
||||
from utils import download_dataset_files, save_processed_files_log
|
||||
from utils import download_dataset_files, save_processed_files_log, load_processed_files_log
|
||||
from utils.dataset_manager import remove_dataset_directory_by_key
|
||||
|
||||
|
||||
@huey.task()
|
||||
@ -154,6 +155,196 @@ def process_files_async(
|
||||
}
|
||||
|
||||
|
||||
@huey.task()
|
||||
def process_files_incremental_async(
|
||||
dataset_id: str,
|
||||
files_to_add: Optional[Dict[str, List[str]]] = None,
|
||||
files_to_remove: Optional[Dict[str, List[str]]] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
mcp_settings: Optional[List[Dict]] = None,
|
||||
task_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
增量处理文件任务 - 支持添加和删除文件
|
||||
|
||||
Args:
|
||||
dataset_id: 项目唯一ID
|
||||
files_to_add: 按key分组的要添加的文件路径字典
|
||||
files_to_remove: 按key分组的要删除的文件路径字典
|
||||
system_prompt: 系统提示词
|
||||
mcp_settings: MCP设置
|
||||
task_id: 任务ID(用于状态跟踪)
|
||||
|
||||
Returns:
|
||||
处理结果字典
|
||||
"""
|
||||
try:
|
||||
print(f"开始增量处理文件任务,项目ID: {dataset_id}")
|
||||
|
||||
# 如果有task_id,设置初始状态
|
||||
if task_id:
|
||||
task_status_store.set_status(
|
||||
task_id=task_id,
|
||||
unique_id=dataset_id,
|
||||
status="running"
|
||||
)
|
||||
|
||||
# 确保项目目录存在
|
||||
project_dir = os.path.join("projects", "data", dataset_id)
|
||||
if not os.path.exists(project_dir):
|
||||
os.makedirs(project_dir, exist_ok=True)
|
||||
|
||||
# 加载现有的处理日志
|
||||
processed_log = load_processed_files_log(dataset_id)
|
||||
print(f"加载现有处理日志,包含 {len(processed_log)} 个文件记录")
|
||||
|
||||
removed_files = []
|
||||
added_files = []
|
||||
|
||||
# 1. 处理删除操作
|
||||
if files_to_remove:
|
||||
print(f"开始处理删除操作,涉及 {len(files_to_remove)} 个key分组")
|
||||
for key, file_list in files_to_remove.items():
|
||||
if not file_list: # 如果文件列表为空,删除整个key分组
|
||||
print(f"删除整个key分组: {key}")
|
||||
if remove_dataset_directory_by_key(dataset_id, key):
|
||||
removed_files.append(f"dataset/{key}")
|
||||
|
||||
# 从处理日志中移除该key的所有记录
|
||||
keys_to_remove = [file_hash for file_hash, file_info in processed_log.items()
|
||||
if file_info.get('key') == key]
|
||||
for file_hash in keys_to_remove:
|
||||
del processed_log[file_hash]
|
||||
removed_files.append(f"log_entry:{file_hash}")
|
||||
else:
|
||||
# 删除特定文件
|
||||
for file_path in file_list:
|
||||
print(f"删除特定文件: {key}/{file_path}")
|
||||
# 计算文件hash以在日志中查找
|
||||
import hashlib
|
||||
file_hash = hashlib.md5(file_path.encode('utf-8')).hexdigest()
|
||||
|
||||
# 从处理日志中移除
|
||||
if file_hash in processed_log:
|
||||
del processed_log[file_hash]
|
||||
removed_files.append(f"log_entry:{file_hash}")
|
||||
|
||||
# 2. 处理添加操作
|
||||
processed_files_by_key = {}
|
||||
if files_to_add:
|
||||
print(f"开始处理添加操作,涉及 {len(files_to_add)} 个key分组")
|
||||
# 使用异步处理下载文件
|
||||
import asyncio
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
processed_files_by_key = loop.run_until_complete(download_dataset_files(dataset_id, files_to_add))
|
||||
total_added_files = sum(len(files_list) for files_list in processed_files_by_key.values())
|
||||
print(f"异步处理了 {total_added_files} 个数据集文件,涉及 {len(processed_files_by_key)} 个key,项目ID: {dataset_id}")
|
||||
|
||||
# 记录添加的文件
|
||||
for key, files_list in processed_files_by_key.items():
|
||||
for file_path in files_list:
|
||||
added_files.append(f"{key}/{file_path}")
|
||||
else:
|
||||
print(f"请求中未提供要添加的文件,项目ID: {dataset_id}")
|
||||
|
||||
# 保存更新后的处理日志
|
||||
save_processed_files_log(dataset_id, processed_log)
|
||||
print(f"已更新处理日志,当前包含 {len(processed_log)} 个文件记录")
|
||||
|
||||
# 保存system_prompt和mcp_settings到项目目录(如果提供)
|
||||
if system_prompt:
|
||||
system_prompt_file = os.path.join(project_dir, "system_prompt.md")
|
||||
with open(system_prompt_file, 'w', encoding='utf-8') as f:
|
||||
f.write(system_prompt)
|
||||
print(f"已保存system_prompt,项目ID: {dataset_id}")
|
||||
|
||||
if mcp_settings:
|
||||
mcp_settings_file = os.path.join(project_dir, "mcp_settings.json")
|
||||
with open(mcp_settings_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(mcp_settings, f, ensure_ascii=False, indent=2)
|
||||
print(f"已保存mcp_settings,项目ID: {dataset_id}")
|
||||
|
||||
# 生成项目README.md文件
|
||||
try:
|
||||
from utils.project_manager import save_project_readme
|
||||
save_project_readme(dataset_id)
|
||||
print(f"已生成README.md文件,项目ID: {dataset_id}")
|
||||
except Exception as e:
|
||||
print(f"生成README.md失败,项目ID: {dataset_id}, 错误: {str(e)}")
|
||||
# 不影响主要处理流程,继续执行
|
||||
|
||||
# 收集项目目录下所有的 document.txt 文件
|
||||
document_files = []
|
||||
for root, dirs, files_list in os.walk(project_dir):
|
||||
for file in files_list:
|
||||
if file == "document.txt":
|
||||
document_files.append(os.path.join(root, file))
|
||||
|
||||
# 构建结果文件列表
|
||||
result_files = []
|
||||
for key in processed_files_by_key.keys():
|
||||
# 添加对应的dataset document.txt路径
|
||||
document_path = os.path.join("projects", "data", dataset_id, "dataset", key, "document.txt")
|
||||
if os.path.exists(document_path):
|
||||
result_files.append(document_path)
|
||||
|
||||
# 对于没有在processed_files_by_key中但存在的document.txt文件,也添加到结果中
|
||||
existing_document_paths = set(result_files) # 避免重复
|
||||
for doc_file in document_files:
|
||||
if doc_file not in existing_document_paths:
|
||||
result_files.append(doc_file)
|
||||
|
||||
result = {
|
||||
"status": "success",
|
||||
"message": f"增量处理完成 - 添加了 {len(added_files)} 个文件,删除了 {len(removed_files)} 个文件,最终保留 {len(result_files)} 个文档文件",
|
||||
"dataset_id": dataset_id,
|
||||
"removed_files": removed_files,
|
||||
"added_files": added_files,
|
||||
"processed_files": result_files,
|
||||
"processed_files_by_key": processed_files_by_key,
|
||||
"document_files": document_files,
|
||||
"total_files_added": sum(len(files_list) for files_list in processed_files_by_key.values()),
|
||||
"total_files_removed": len(removed_files),
|
||||
"final_files_count": len(result_files),
|
||||
"processing_time": time.time()
|
||||
}
|
||||
|
||||
# 更新任务状态为完成
|
||||
if task_id:
|
||||
task_status_store.update_status(
|
||||
task_id=task_id,
|
||||
status="completed",
|
||||
result=result
|
||||
)
|
||||
|
||||
print(f"增量文件处理任务完成: {dataset_id}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"增量处理文件时发生错误: {str(e)}"
|
||||
print(error_msg)
|
||||
|
||||
# 更新任务状态为错误
|
||||
if task_id:
|
||||
task_status_store.update_status(
|
||||
task_id=task_id,
|
||||
status="failed",
|
||||
error=error_msg
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "error",
|
||||
"message": error_msg,
|
||||
"dataset_id": dataset_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
@huey.task()
|
||||
def cleanup_project_async(
|
||||
unique_id: str,
|
||||
|
||||
@ -68,6 +68,7 @@ from .api_models import (
|
||||
ProjectStatsResponse,
|
||||
ProjectActionResponse,
|
||||
QueueTaskRequest,
|
||||
IncrementalTaskRequest,
|
||||
QueueTaskResponse,
|
||||
QueueStatusResponse,
|
||||
TaskStatusResponse,
|
||||
|
||||
@ -245,6 +245,57 @@ class QueueTaskRequest(BaseModel):
|
||||
raise ValueError(f"Files must be a dict with key groups, got {type(v)}")
|
||||
|
||||
|
||||
class IncrementalTaskRequest(BaseModel):
|
||||
"""增量文件处理请求模型"""
|
||||
dataset_id: str = Field(..., description="Dataset ID for the project")
|
||||
files_to_add: Optional[Dict[str, List[str]]] = Field(default=None, description="Files to add organized by key groups")
|
||||
files_to_remove: Optional[Dict[str, List[str]]] = Field(default=None, description="Files to remove organized by key groups")
|
||||
system_prompt: Optional[str] = None
|
||||
mcp_settings: Optional[List[Dict]] = None
|
||||
priority: Optional[int] = Field(default=0, description="Task priority (higher number = higher priority)")
|
||||
delay: Optional[int] = Field(default=0, description="Delay execution by N seconds")
|
||||
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
@field_validator('files_to_add', mode='before')
|
||||
@classmethod
|
||||
def validate_files_to_add(cls, v):
|
||||
"""Validate files_to_add dict format"""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, dict):
|
||||
for key, value in v.items():
|
||||
if not isinstance(key, str):
|
||||
raise ValueError(f"Key in files_to_add dict must be string, got {type(key)}")
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(f"Value in files_to_add dict must be list, got {type(value)} for key '{key}'")
|
||||
for item in value:
|
||||
if not isinstance(item, str):
|
||||
raise ValueError(f"File paths must be strings, got {type(item)} in key '{key}'")
|
||||
return v
|
||||
else:
|
||||
raise ValueError(f"files_to_add must be a dict with key groups, got {type(v)}")
|
||||
|
||||
@field_validator('files_to_remove', mode='before')
|
||||
@classmethod
|
||||
def validate_files_to_remove(cls, v):
|
||||
"""Validate files_to_remove dict format"""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, dict):
|
||||
for key, value in v.items():
|
||||
if not isinstance(key, str):
|
||||
raise ValueError(f"Key in files_to_remove dict must be string, got {type(key)}")
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(f"Value in files_to_remove dict must be list, got {type(value)} for key '{key}'")
|
||||
for item in value:
|
||||
if not isinstance(item, str):
|
||||
raise ValueError(f"File paths must be strings, got {type(item)} in key '{key}'")
|
||||
return v
|
||||
else:
|
||||
raise ValueError(f"files_to_remove must be a dict with key groups, got {type(v)}")
|
||||
|
||||
|
||||
class QueueTaskResponse(BaseModel):
|
||||
"""队列任务响应模型"""
|
||||
success: bool
|
||||
|
||||
Loading…
Reference in New Issue
Block a user