add /api/v1/files/process/incremental
This commit is contained in:
parent
40aa71b966
commit
bff5817520
@ -20,7 +20,7 @@ from pydantic import BaseModel, Field
|
|||||||
# Import utility modules
|
# Import utility modules
|
||||||
from utils import (
|
from utils import (
|
||||||
# Models
|
# Models
|
||||||
Message, DatasetRequest, ChatRequest, ChatResponse, QueueTaskRequest, QueueTaskResponse,
|
Message, DatasetRequest, ChatRequest, ChatResponse, QueueTaskRequest, IncrementalTaskRequest, QueueTaskResponse,
|
||||||
QueueStatusResponse, TaskStatusResponse,
|
QueueStatusResponse, TaskStatusResponse,
|
||||||
|
|
||||||
# File utilities
|
# File utilities
|
||||||
@ -47,7 +47,7 @@ from modified_assistant import update_agent_llm
|
|||||||
|
|
||||||
# Import queue manager
|
# Import queue manager
|
||||||
from task_queue.manager import queue_manager
|
from task_queue.manager import queue_manager
|
||||||
from task_queue.integration_tasks import process_files_async, cleanup_project_async
|
from task_queue.integration_tasks import process_files_async, process_files_incremental_async, cleanup_project_async
|
||||||
from task_queue.task_status import task_status_store
|
from task_queue.task_status import task_status_store
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@ -264,6 +264,69 @@ async def process_files_async_endpoint(request: QueueTaskRequest, authorization:
|
|||||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/v1/files/process/incremental")
|
||||||
|
async def process_files_incremental_endpoint(request: IncrementalTaskRequest, authorization: Optional[str] = Header(None)):
|
||||||
|
"""
|
||||||
|
增量处理文件的队列版本API - 支持添加和删除文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: IncrementalTaskRequest containing dataset_id, files_to_add, files_to_remove, system_prompt, mcp_settings, and queue options
|
||||||
|
authorization: Authorization header containing API key (Bearer <API_KEY>)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QueueTaskResponse: Processing result with task ID for tracking
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
dataset_id = request.dataset_id
|
||||||
|
if not dataset_id:
|
||||||
|
raise HTTPException(status_code=400, detail="dataset_id is required")
|
||||||
|
|
||||||
|
# 验证至少有添加或删除操作
|
||||||
|
if not request.files_to_add and not request.files_to_remove:
|
||||||
|
raise HTTPException(status_code=400, detail="At least one of files_to_add or files_to_remove must be provided")
|
||||||
|
|
||||||
|
# 估算处理时间(基于文件数量)
|
||||||
|
estimated_time = 0
|
||||||
|
total_add_files = sum(len(file_list) for file_list in (request.files_to_add or {}).values())
|
||||||
|
total_remove_files = sum(len(file_list) for file_list in (request.files_to_remove or {}).values())
|
||||||
|
total_files = total_add_files + total_remove_files
|
||||||
|
estimated_time = max(30, total_files * 10) # 每个文件预估10秒,最少30秒
|
||||||
|
|
||||||
|
# 创建任务状态记录
|
||||||
|
import uuid
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
task_status_store.set_status(
|
||||||
|
task_id=task_id,
|
||||||
|
unique_id=dataset_id,
|
||||||
|
status="pending"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 提交增量异步任务
|
||||||
|
task = process_files_incremental_async(
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
files_to_add=request.files_to_add,
|
||||||
|
files_to_remove=request.files_to_remove,
|
||||||
|
system_prompt=request.system_prompt,
|
||||||
|
mcp_settings=request.mcp_settings,
|
||||||
|
task_id=task_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return QueueTaskResponse(
|
||||||
|
success=True,
|
||||||
|
message=f"增量文件处理任务已提交到队列 - 添加 {total_add_files} 个文件,删除 {total_remove_files} 个文件,项目ID: {dataset_id}",
|
||||||
|
unique_id=dataset_id,
|
||||||
|
task_id=task_id,
|
||||||
|
task_status="pending",
|
||||||
|
estimated_processing_time=estimated_time
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error submitting incremental file processing task: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/v1/task/{task_id}/status")
|
@app.get("/api/v1/task/{task_id}/status")
|
||||||
async def get_task_status(task_id: str):
|
async def get_task_status(task_id: str):
|
||||||
"""获取任务状态 - 简单可靠"""
|
"""获取任务状态 - 简单可靠"""
|
||||||
|
|||||||
@ -11,7 +11,8 @@ from typing import Dict, List, Optional, Any
|
|||||||
from task_queue.config import huey
|
from task_queue.config import huey
|
||||||
from task_queue.manager import queue_manager
|
from task_queue.manager import queue_manager
|
||||||
from task_queue.task_status import task_status_store
|
from task_queue.task_status import task_status_store
|
||||||
from utils import download_dataset_files, save_processed_files_log
|
from utils import download_dataset_files, save_processed_files_log, load_processed_files_log
|
||||||
|
from utils.dataset_manager import remove_dataset_directory_by_key
|
||||||
|
|
||||||
|
|
||||||
@huey.task()
|
@huey.task()
|
||||||
@ -154,6 +155,196 @@ def process_files_async(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@huey.task()
|
||||||
|
def process_files_incremental_async(
|
||||||
|
dataset_id: str,
|
||||||
|
files_to_add: Optional[Dict[str, List[str]]] = None,
|
||||||
|
files_to_remove: Optional[Dict[str, List[str]]] = None,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
mcp_settings: Optional[List[Dict]] = None,
|
||||||
|
task_id: Optional[str] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
增量处理文件任务 - 支持添加和删除文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: 项目唯一ID
|
||||||
|
files_to_add: 按key分组的要添加的文件路径字典
|
||||||
|
files_to_remove: 按key分组的要删除的文件路径字典
|
||||||
|
system_prompt: 系统提示词
|
||||||
|
mcp_settings: MCP设置
|
||||||
|
task_id: 任务ID(用于状态跟踪)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理结果字典
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
print(f"开始增量处理文件任务,项目ID: {dataset_id}")
|
||||||
|
|
||||||
|
# 如果有task_id,设置初始状态
|
||||||
|
if task_id:
|
||||||
|
task_status_store.set_status(
|
||||||
|
task_id=task_id,
|
||||||
|
unique_id=dataset_id,
|
||||||
|
status="running"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 确保项目目录存在
|
||||||
|
project_dir = os.path.join("projects", "data", dataset_id)
|
||||||
|
if not os.path.exists(project_dir):
|
||||||
|
os.makedirs(project_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 加载现有的处理日志
|
||||||
|
processed_log = load_processed_files_log(dataset_id)
|
||||||
|
print(f"加载现有处理日志,包含 {len(processed_log)} 个文件记录")
|
||||||
|
|
||||||
|
removed_files = []
|
||||||
|
added_files = []
|
||||||
|
|
||||||
|
# 1. 处理删除操作
|
||||||
|
if files_to_remove:
|
||||||
|
print(f"开始处理删除操作,涉及 {len(files_to_remove)} 个key分组")
|
||||||
|
for key, file_list in files_to_remove.items():
|
||||||
|
if not file_list: # 如果文件列表为空,删除整个key分组
|
||||||
|
print(f"删除整个key分组: {key}")
|
||||||
|
if remove_dataset_directory_by_key(dataset_id, key):
|
||||||
|
removed_files.append(f"dataset/{key}")
|
||||||
|
|
||||||
|
# 从处理日志中移除该key的所有记录
|
||||||
|
keys_to_remove = [file_hash for file_hash, file_info in processed_log.items()
|
||||||
|
if file_info.get('key') == key]
|
||||||
|
for file_hash in keys_to_remove:
|
||||||
|
del processed_log[file_hash]
|
||||||
|
removed_files.append(f"log_entry:{file_hash}")
|
||||||
|
else:
|
||||||
|
# 删除特定文件
|
||||||
|
for file_path in file_list:
|
||||||
|
print(f"删除特定文件: {key}/{file_path}")
|
||||||
|
# 计算文件hash以在日志中查找
|
||||||
|
import hashlib
|
||||||
|
file_hash = hashlib.md5(file_path.encode('utf-8')).hexdigest()
|
||||||
|
|
||||||
|
# 从处理日志中移除
|
||||||
|
if file_hash in processed_log:
|
||||||
|
del processed_log[file_hash]
|
||||||
|
removed_files.append(f"log_entry:{file_hash}")
|
||||||
|
|
||||||
|
# 2. 处理添加操作
|
||||||
|
processed_files_by_key = {}
|
||||||
|
if files_to_add:
|
||||||
|
print(f"开始处理添加操作,涉及 {len(files_to_add)} 个key分组")
|
||||||
|
# 使用异步处理下载文件
|
||||||
|
import asyncio
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
processed_files_by_key = loop.run_until_complete(download_dataset_files(dataset_id, files_to_add))
|
||||||
|
total_added_files = sum(len(files_list) for files_list in processed_files_by_key.values())
|
||||||
|
print(f"异步处理了 {total_added_files} 个数据集文件,涉及 {len(processed_files_by_key)} 个key,项目ID: {dataset_id}")
|
||||||
|
|
||||||
|
# 记录添加的文件
|
||||||
|
for key, files_list in processed_files_by_key.items():
|
||||||
|
for file_path in files_list:
|
||||||
|
added_files.append(f"{key}/{file_path}")
|
||||||
|
else:
|
||||||
|
print(f"请求中未提供要添加的文件,项目ID: {dataset_id}")
|
||||||
|
|
||||||
|
# 保存更新后的处理日志
|
||||||
|
save_processed_files_log(dataset_id, processed_log)
|
||||||
|
print(f"已更新处理日志,当前包含 {len(processed_log)} 个文件记录")
|
||||||
|
|
||||||
|
# 保存system_prompt和mcp_settings到项目目录(如果提供)
|
||||||
|
if system_prompt:
|
||||||
|
system_prompt_file = os.path.join(project_dir, "system_prompt.md")
|
||||||
|
with open(system_prompt_file, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(system_prompt)
|
||||||
|
print(f"已保存system_prompt,项目ID: {dataset_id}")
|
||||||
|
|
||||||
|
if mcp_settings:
|
||||||
|
mcp_settings_file = os.path.join(project_dir, "mcp_settings.json")
|
||||||
|
with open(mcp_settings_file, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(mcp_settings, f, ensure_ascii=False, indent=2)
|
||||||
|
print(f"已保存mcp_settings,项目ID: {dataset_id}")
|
||||||
|
|
||||||
|
# 生成项目README.md文件
|
||||||
|
try:
|
||||||
|
from utils.project_manager import save_project_readme
|
||||||
|
save_project_readme(dataset_id)
|
||||||
|
print(f"已生成README.md文件,项目ID: {dataset_id}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"生成README.md失败,项目ID: {dataset_id}, 错误: {str(e)}")
|
||||||
|
# 不影响主要处理流程,继续执行
|
||||||
|
|
||||||
|
# 收集项目目录下所有的 document.txt 文件
|
||||||
|
document_files = []
|
||||||
|
for root, dirs, files_list in os.walk(project_dir):
|
||||||
|
for file in files_list:
|
||||||
|
if file == "document.txt":
|
||||||
|
document_files.append(os.path.join(root, file))
|
||||||
|
|
||||||
|
# 构建结果文件列表
|
||||||
|
result_files = []
|
||||||
|
for key in processed_files_by_key.keys():
|
||||||
|
# 添加对应的dataset document.txt路径
|
||||||
|
document_path = os.path.join("projects", "data", dataset_id, "dataset", key, "document.txt")
|
||||||
|
if os.path.exists(document_path):
|
||||||
|
result_files.append(document_path)
|
||||||
|
|
||||||
|
# 对于没有在processed_files_by_key中但存在的document.txt文件,也添加到结果中
|
||||||
|
existing_document_paths = set(result_files) # 避免重复
|
||||||
|
for doc_file in document_files:
|
||||||
|
if doc_file not in existing_document_paths:
|
||||||
|
result_files.append(doc_file)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"status": "success",
|
||||||
|
"message": f"增量处理完成 - 添加了 {len(added_files)} 个文件,删除了 {len(removed_files)} 个文件,最终保留 {len(result_files)} 个文档文件",
|
||||||
|
"dataset_id": dataset_id,
|
||||||
|
"removed_files": removed_files,
|
||||||
|
"added_files": added_files,
|
||||||
|
"processed_files": result_files,
|
||||||
|
"processed_files_by_key": processed_files_by_key,
|
||||||
|
"document_files": document_files,
|
||||||
|
"total_files_added": sum(len(files_list) for files_list in processed_files_by_key.values()),
|
||||||
|
"total_files_removed": len(removed_files),
|
||||||
|
"final_files_count": len(result_files),
|
||||||
|
"processing_time": time.time()
|
||||||
|
}
|
||||||
|
|
||||||
|
# 更新任务状态为完成
|
||||||
|
if task_id:
|
||||||
|
task_status_store.update_status(
|
||||||
|
task_id=task_id,
|
||||||
|
status="completed",
|
||||||
|
result=result
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"增量文件处理任务完成: {dataset_id}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"增量处理文件时发生错误: {str(e)}"
|
||||||
|
print(error_msg)
|
||||||
|
|
||||||
|
# 更新任务状态为错误
|
||||||
|
if task_id:
|
||||||
|
task_status_store.update_status(
|
||||||
|
task_id=task_id,
|
||||||
|
status="failed",
|
||||||
|
error=error_msg
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": error_msg,
|
||||||
|
"dataset_id": dataset_id,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@huey.task()
|
@huey.task()
|
||||||
def cleanup_project_async(
|
def cleanup_project_async(
|
||||||
unique_id: str,
|
unique_id: str,
|
||||||
|
|||||||
@ -68,6 +68,7 @@ from .api_models import (
|
|||||||
ProjectStatsResponse,
|
ProjectStatsResponse,
|
||||||
ProjectActionResponse,
|
ProjectActionResponse,
|
||||||
QueueTaskRequest,
|
QueueTaskRequest,
|
||||||
|
IncrementalTaskRequest,
|
||||||
QueueTaskResponse,
|
QueueTaskResponse,
|
||||||
QueueStatusResponse,
|
QueueStatusResponse,
|
||||||
TaskStatusResponse,
|
TaskStatusResponse,
|
||||||
|
|||||||
@ -245,6 +245,57 @@ class QueueTaskRequest(BaseModel):
|
|||||||
raise ValueError(f"Files must be a dict with key groups, got {type(v)}")
|
raise ValueError(f"Files must be a dict with key groups, got {type(v)}")
|
||||||
|
|
||||||
|
|
||||||
|
class IncrementalTaskRequest(BaseModel):
|
||||||
|
"""增量文件处理请求模型"""
|
||||||
|
dataset_id: str = Field(..., description="Dataset ID for the project")
|
||||||
|
files_to_add: Optional[Dict[str, List[str]]] = Field(default=None, description="Files to add organized by key groups")
|
||||||
|
files_to_remove: Optional[Dict[str, List[str]]] = Field(default=None, description="Files to remove organized by key groups")
|
||||||
|
system_prompt: Optional[str] = None
|
||||||
|
mcp_settings: Optional[List[Dict]] = None
|
||||||
|
priority: Optional[int] = Field(default=0, description="Task priority (higher number = higher priority)")
|
||||||
|
delay: Optional[int] = Field(default=0, description="Delay execution by N seconds")
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra='allow')
|
||||||
|
|
||||||
|
@field_validator('files_to_add', mode='before')
|
||||||
|
@classmethod
|
||||||
|
def validate_files_to_add(cls, v):
|
||||||
|
"""Validate files_to_add dict format"""
|
||||||
|
if v is None:
|
||||||
|
return None
|
||||||
|
if isinstance(v, dict):
|
||||||
|
for key, value in v.items():
|
||||||
|
if not isinstance(key, str):
|
||||||
|
raise ValueError(f"Key in files_to_add dict must be string, got {type(key)}")
|
||||||
|
if not isinstance(value, list):
|
||||||
|
raise ValueError(f"Value in files_to_add dict must be list, got {type(value)} for key '{key}'")
|
||||||
|
for item in value:
|
||||||
|
if not isinstance(item, str):
|
||||||
|
raise ValueError(f"File paths must be strings, got {type(item)} in key '{key}'")
|
||||||
|
return v
|
||||||
|
else:
|
||||||
|
raise ValueError(f"files_to_add must be a dict with key groups, got {type(v)}")
|
||||||
|
|
||||||
|
@field_validator('files_to_remove', mode='before')
|
||||||
|
@classmethod
|
||||||
|
def validate_files_to_remove(cls, v):
|
||||||
|
"""Validate files_to_remove dict format"""
|
||||||
|
if v is None:
|
||||||
|
return None
|
||||||
|
if isinstance(v, dict):
|
||||||
|
for key, value in v.items():
|
||||||
|
if not isinstance(key, str):
|
||||||
|
raise ValueError(f"Key in files_to_remove dict must be string, got {type(key)}")
|
||||||
|
if not isinstance(value, list):
|
||||||
|
raise ValueError(f"Value in files_to_remove dict must be list, got {type(value)} for key '{key}'")
|
||||||
|
for item in value:
|
||||||
|
if not isinstance(item, str):
|
||||||
|
raise ValueError(f"File paths must be strings, got {type(item)} in key '{key}'")
|
||||||
|
return v
|
||||||
|
else:
|
||||||
|
raise ValueError(f"files_to_remove must be a dict with key groups, got {type(v)}")
|
||||||
|
|
||||||
|
|
||||||
class QueueTaskResponse(BaseModel):
|
class QueueTaskResponse(BaseModel):
|
||||||
"""队列任务响应模型"""
|
"""队列任务响应模型"""
|
||||||
success: bool
|
success: bool
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user