import os
import uuid
import shutil
from datetime import datetime
from typing import Optional, List

from fastapi import APIRouter, HTTPException, Header, UploadFile, File, Form
from pydantic import BaseModel

from utils import (
    DatasetRequest, QueueTaskRequest, IncrementalTaskRequest, QueueTaskResponse,
    load_processed_files_log, remove_file_or_directory, remove_dataset_directory_by_key
)
from utils.fastapi_utils import get_versioned_filename
from task_queue.manager import queue_manager
from task_queue.integration_tasks import process_files_async, process_files_incremental_async, cleanup_project_async
from task_queue.task_status import task_status_store

router = APIRouter()
@router.post("/api/v1/files/process/async")
|
||
async def process_files_async_endpoint(request: QueueTaskRequest, authorization: Optional[str] = Header(None)):
|
||
"""
|
||
异步处理文件的队列版本API
|
||
与 /api/v1/files/process 功能相同,但使用队列异步处理
|
||
|
||
Args:
|
||
request: QueueTaskRequest containing dataset_id, files, system_prompt, mcp_settings, and queue options
|
||
authorization: Authorization header containing API key (Bearer <API_KEY>)
|
||
|
||
Returns:
|
||
QueueTaskResponse: Processing result with task ID for tracking
|
||
"""
|
||
try:
|
||
dataset_id = request.dataset_id
|
||
if not dataset_id:
|
||
raise HTTPException(status_code=400, detail="dataset_id is required")
|
||
|
||
# 估算处理时间(基于文件数量)
|
||
estimated_time = 0
|
||
if request.upload_folder:
|
||
# 对于upload_folder,无法预先估算文件数量,使用默认时间
|
||
estimated_time = 120 # 默认2分钟
|
||
elif request.files:
|
||
total_files = sum(len(file_list) for file_list in request.files.values())
|
||
estimated_time = max(30, total_files * 10) # 每个文件预估10秒,最少30秒
|
||
|
||
# 创建任务状态记录
|
||
import uuid
|
||
task_id = str(uuid.uuid4())
|
||
task_status_store.set_status(
|
||
task_id=task_id,
|
||
unique_id=dataset_id,
|
||
status="pending"
|
||
)
|
||
|
||
# 提交异步任务
|
||
task = process_files_async(
|
||
dataset_id=dataset_id,
|
||
files=request.files,
|
||
upload_folder=request.upload_folder,
|
||
task_id=task_id
|
||
)
|
||
|
||
# 构建更详细的消息
|
||
message = f"文件处理任务已提交到队列,项目ID: {dataset_id}"
|
||
if request.upload_folder:
|
||
group_count = len(request.upload_folder)
|
||
message += f",将从 {group_count} 个上传文件夹自动扫描文件"
|
||
elif request.files:
|
||
total_files = sum(len(file_list) for file_list in request.files.values())
|
||
message += f",包含 {total_files} 个文件"
|
||
|
||
return QueueTaskResponse(
|
||
success=True,
|
||
message=message,
|
||
dataset_id=dataset_id,
|
||
task_id=task_id, # 使用我们自己的task_id
|
||
task_status="pending",
|
||
estimated_processing_time=estimated_time
|
||
)
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
print(f"Error submitting async file processing task: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||
|
||
|
||
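

# Usage sketch: submit a job, then poll the task status endpoint defined below.
# The host, dataset ID, and payload are hypothetical; adjust to your deployment.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/api/v1/files/process/async",
#       json={"dataset_id": "demo-dataset", "files": {"group1": ["a.pdf", "b.txt"]}},
#       headers={"Authorization": "Bearer <API_KEY>"},
#   )
#   task_id = resp.json()["task_id"]
#   status = requests.get(f"http://localhost:8000/api/v1/task/{task_id}/status").json()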


@router.post("/api/v1/files/process/incremental")
async def process_files_incremental_endpoint(request: IncrementalTaskRequest, authorization: Optional[str] = Header(None)):
    """
    Queue-based API for incremental file processing - supports adding and removing files.

    Args:
        request: IncrementalTaskRequest containing dataset_id, files_to_add, files_to_remove, system_prompt, mcp_settings, and queue options
        authorization: Authorization header containing API key (Bearer <API_KEY>)

    Returns:
        QueueTaskResponse: Processing result with task ID for tracking
    """
    try:
        dataset_id = request.dataset_id
        if not dataset_id:
            raise HTTPException(status_code=400, detail="dataset_id is required")

        # Require at least one add or remove operation
        if not request.files_to_add and not request.files_to_remove:
            raise HTTPException(status_code=400, detail="At least one of files_to_add or files_to_remove must be provided")

        # Estimate processing time from the number of files
        total_add_files = sum(len(file_list) for file_list in (request.files_to_add or {}).values())
        total_remove_files = sum(len(file_list) for file_list in (request.files_to_remove or {}).values())
        total_files = total_add_files + total_remove_files
        estimated_time = max(30, total_files * 10)  # ~10 seconds per file, 30 seconds minimum

        # Create the task status record
        task_id = str(uuid.uuid4())
        task_status_store.set_status(
            task_id=task_id,
            unique_id=dataset_id,
            status="pending"
        )

        # Submit the incremental async task
        task = process_files_incremental_async(
            dataset_id=dataset_id,
            files_to_add=request.files_to_add,
            files_to_remove=request.files_to_remove,
            system_prompt=request.system_prompt,
            mcp_settings=request.mcp_settings,
            task_id=task_id
        )

        return QueueTaskResponse(
            success=True,
            message=f"Incremental file processing task submitted to the queue - adding {total_add_files} file(s), removing {total_remove_files} file(s), project ID: {dataset_id}",
            dataset_id=dataset_id,
            task_id=task_id,
            task_status="pending",
            estimated_processing_time=estimated_time
        )

    except HTTPException:
        raise
    except Exception as e:
        print(f"Error submitting incremental file processing task: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
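

# Usage sketch: incremental update with hypothetical file paths and dataset ID.
# At least one of files_to_add / files_to_remove must be non-empty:
#
#   import requests
#   requests.post(
#       "http://localhost:8000/api/v1/files/process/incremental",
#       json={
#           "dataset_id": "demo-dataset",
#           "files_to_add": {"group1": ["new.pdf"]},
#           "files_to_remove": {"group1": ["old.pdf"]},
#       },
#       headers={"Authorization": "Bearer <API_KEY>"},
#   )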


@router.get("/api/v1/files/{dataset_id}/status")
async def get_files_processing_status(dataset_id: str):
    """Get the file processing status for a project."""
    try:
        # Load processed files log
        processed_log = load_processed_files_log(dataset_id)

        # Get project directory info
        project_dir = os.path.join("projects", "data", dataset_id)
        project_exists = os.path.exists(project_dir)

        # Collect document.txt files
        document_files = []
        if project_exists:
            for root, dirs, files in os.walk(project_dir):
                for file in files:
                    if file == "document.txt":
                        document_files.append(os.path.join(root, file))

        return {
            "dataset_id": dataset_id,
            "project_exists": project_exists,
            "processed_files_count": len(processed_log),
            "processed_files": processed_log,
            "document_files_count": len(document_files),
            "document_files": document_files,
            "log_file_exists": os.path.exists(os.path.join(project_dir, "processed_files.json"))
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get file processing status: {str(e)}")
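

# Example response shape (illustrative values only; the fields mirror the
# dict returned above, and the processed_files entry shown is hypothetical):
#
#   {
#       "dataset_id": "demo-dataset",
#       "project_exists": true,
#       "processed_files_count": 1,
#       "processed_files": {"<file_hash>": {"filename": "a.pdf", "key": "..."}},
#       "document_files_count": 1,
#       "document_files": ["projects/data/demo-dataset/dataset/a/document.txt"],
#       "log_file_exists": true
#   }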


@router.post("/api/v1/files/{dataset_id}/reset")
async def reset_files_processing(dataset_id: str):
    """Reset a project's file processing state, removing the processing log and all files."""
    try:
        project_dir = os.path.join("projects", "data", dataset_id)
        log_file = os.path.join(project_dir, "processed_files.json")

        # Load processed log to know what files to remove
        processed_log = load_processed_files_log(dataset_id)

        removed_files = []
        # Remove all processed files and their dataset directories
        for file_hash, file_info in processed_log.items():
            # Remove local file in files directory
            if 'local_path' in file_info:
                if remove_file_or_directory(file_info['local_path']):
                    removed_files.append(file_info['local_path'])

            # Handle new key-based structure first
            if 'key' in file_info:
                # Remove dataset directory by key
                key = file_info['key']
                if remove_dataset_directory_by_key(dataset_id, key):
                    removed_files.append(f"dataset/{key}")
            elif 'filename' in file_info:
                # Fall back to the old filename-based structure
                filename_without_ext = os.path.splitext(file_info['filename'])[0]
                dataset_dir = os.path.join(project_dir, "dataset", filename_without_ext)
                if remove_file_or_directory(dataset_dir):
                    removed_files.append(dataset_dir)

            # Also remove any specific dataset path if present (fallback)
            if 'dataset_path' in file_info:
                if remove_file_or_directory(file_info['dataset_path']):
                    removed_files.append(file_info['dataset_path'])

        # Remove the log file
        if remove_file_or_directory(log_file):
            removed_files.append(log_file)

        # Remove the entire files directory
        files_dir = os.path.join(project_dir, "files")
        if remove_file_or_directory(files_dir):
            removed_files.append(files_dir)

        # Also remove the entire dataset directory (cleans up any remaining files)
        dataset_dir = os.path.join(project_dir, "dataset")
        if remove_file_or_directory(dataset_dir):
            removed_files.append(dataset_dir)

        # Remove README.md if it exists
        readme_file = os.path.join(project_dir, "README.md")
        if remove_file_or_directory(readme_file):
            removed_files.append(readme_file)

        return {
            "message": f"File processing state reset successfully: {dataset_id}",
            "removed_files_count": len(removed_files),
            "removed_files": removed_files
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to reset file processing state: {str(e)}")


@router.post("/api/v1/files/{dataset_id}/cleanup/async")
async def cleanup_project_async_endpoint(dataset_id: str, remove_all: bool = False):
    """Clean up project files asynchronously."""
    try:
        task = cleanup_project_async(dataset_id=dataset_id, remove_all=remove_all)

        return {
            "success": True,
            "message": f"Project cleanup task submitted to the queue, project ID: {dataset_id}",
            "dataset_id": dataset_id,
            "task_id": task.id,
            "action": "remove_all" if remove_all else "cleanup_logs"
        }
    except Exception as e:
        print(f"Error submitting cleanup task: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to submit cleanup task: {str(e)}")
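

# Usage sketch (hypothetical host and dataset ID); remove_all=true removes
# everything, while the default only cleans up processing logs:
#
#   curl -X POST "http://localhost:8000/api/v1/files/demo-dataset/cleanup/async?remove_all=true"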
@router.post("/api/v1/upload")
|
||
async def upload_file(file: UploadFile = File(...), folder: Optional[str] = Form(None)):
|
||
"""
|
||
文件上传API接口,上传文件到 ./projects/uploads/ 目录下
|
||
|
||
可以指定自定义文件夹名,如果不指定则使用日期文件夹
|
||
指定文件夹时使用原始文件名并支持版本控制
|
||
|
||
Args:
|
||
file: 上传的文件
|
||
folder: 可选的自定义文件夹名
|
||
|
||
Returns:
|
||
dict: 包含文件路径和文件夹信息的响应
|
||
"""
|
||
try:
|
||
# 调试信息
|
||
print(f"Received folder parameter: {folder}")
|
||
print(f"File received: {file.filename if file else 'None'}")
|
||
|
||
# 确定上传文件夹
|
||
if folder:
|
||
# 使用指定的自定义文件夹
|
||
target_folder = folder
|
||
# 安全性检查:防止路径遍历攻击
|
||
target_folder = os.path.basename(target_folder)
|
||
else:
|
||
# 获取当前日期并格式化为年月日
|
||
current_date = datetime.now()
|
||
target_folder = current_date.strftime("%Y%m%d")
|
||
|
||
# 创建上传目录
|
||
upload_dir = os.path.join("projects", "uploads", target_folder)
|
||
os.makedirs(upload_dir, exist_ok=True)
|
||
|
||
# 处理文件名
|
||
if not file.filename:
|
||
raise HTTPException(status_code=400, detail="文件名不能为空")
|
||
|
||
# 解析文件名和扩展名
|
||
original_filename = file.filename
|
||
name_without_ext, file_extension = os.path.splitext(original_filename)
|
||
|
||
# 根据是否指定文件夹决定命名策略
|
||
if folder:
|
||
# 使用原始文件名,支持版本控制
|
||
final_filename, version = get_versioned_filename(upload_dir, name_without_ext, file_extension)
|
||
file_path = os.path.join(upload_dir, final_filename)
|
||
|
||
# 保存文件
|
||
with open(file_path, "wb") as buffer:
|
||
shutil.copyfileobj(file.file, buffer)
|
||
|
||
return {
|
||
"success": True,
|
||
"message": f"文件上传成功{' (版本: ' + str(version) + ')' if version > 1 else ''}",
|
||
"file_path": file_path,
|
||
"folder": target_folder,
|
||
"original_filename": original_filename,
|
||
"version": version
|
||
}
|
||
else:
|
||
# 使用UUID唯一文件名(原有逻辑)
|
||
unique_filename = f"{uuid.uuid4()}{file_extension}"
|
||
file_path = os.path.join(upload_dir, unique_filename)
|
||
|
||
# 保存文件
|
||
with open(file_path, "wb") as buffer:
|
||
shutil.copyfileobj(file.file, buffer)
|
||
|
||
return {
|
||
"success": True,
|
||
"message": "文件上传成功",
|
||
"file_path": file_path,
|
||
"folder": target_folder,
|
||
"original_filename": original_filename
|
||
}
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
print(f"Error uploading file: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"文件上传失败: {str(e)}")
|
||
|
||
|
||
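

# Usage sketch (hypothetical host and file); multipart upload with an optional
# custom folder, omit the folder field to fall back to the YYYYMMDD date folder:
#
#   curl -X POST "http://localhost:8000/api/v1/upload" \
#        -F "file=@report.pdf" \
#        -F "folder=my-project"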


# Task management routes that are related to file processing
@router.get("/api/v1/task/{task_id}/status")
async def get_task_status(task_id: str):
    """Get task status - simple and reliable."""
    try:
        status_data = task_status_store.get_status(task_id)

        if not status_data:
            return {
                "success": False,
                "message": "Task does not exist or has expired",
                "task_id": task_id,
                "status": "not_found"
            }

        return {
            "success": True,
            "message": "Task status retrieved successfully",
            "task_id": task_id,
            "status": status_data["status"],
            "unique_id": status_data["unique_id"],
            "created_at": status_data["created_at"],
            "updated_at": status_data["updated_at"],
            "result": status_data.get("result"),
            "error": status_data.get("error")
        }

    except Exception as e:
        print(f"Error getting task status: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get task status: {str(e)}")
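

# Polling sketch (hypothetical host; the terminal status names depend on the
# task backend, so adjust the set below to match task_status_store's values):
#
#   import time, requests
#   def wait_for_task(task_id: str, timeout: float = 600.0) -> dict:
#       deadline = time.time() + timeout
#       while time.time() < deadline:
#           data = requests.get(f"http://localhost:8000/api/v1/task/{task_id}/status").json()
#           if data.get("status") in ("success", "failure", "not_found"):
#               return data
#           time.sleep(5)
#       raise TimeoutError(f"task {task_id} did not finish in {timeout}s")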


@router.delete("/api/v1/task/{task_id}")
async def delete_task(task_id: str):
    """Delete a task record."""
    try:
        success = task_status_store.delete_status(task_id)
        if success:
            return {
                "success": True,
                "message": f"Task record deleted: {task_id}",
                "task_id": task_id
            }
        else:
            return {
                "success": False,
                "message": f"Task record does not exist: {task_id}",
                "task_id": task_id
            }
    except Exception as e:
        print(f"Error deleting task: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to delete task record: {str(e)}")


@router.get("/api/v1/tasks")
async def list_tasks(status: Optional[str] = None, dataset_id: Optional[str] = None, limit: int = 100):
    """List tasks, with optional filtering."""
    try:
        if status or dataset_id:
            # Use the search helper when filters are given
            tasks = task_status_store.search_tasks(status=status, unique_id=dataset_id, limit=limit)
        else:
            # Otherwise fetch all tasks
            all_tasks = task_status_store.list_all()
            tasks = list(all_tasks.values())[:limit]

        return {
            "success": True,
            "message": "Task list retrieved successfully",
            "total_tasks": len(tasks),
            "tasks": tasks,
            "filters": {
                "status": status,
                "dataset_id": dataset_id,
                "limit": limit
            }
        }

    except Exception as e:
        print(f"Error listing tasks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to list tasks: {str(e)}")


@router.get("/api/v1/tasks/statistics")
async def get_task_statistics():
    """Get task statistics."""
    try:
        stats = task_status_store.get_statistics()

        return {
            "success": True,
            "message": "Statistics retrieved successfully",
            "statistics": stats
        }

    except Exception as e:
        print(f"Error getting statistics: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}")


@router.post("/api/v1/tasks/cleanup")
async def cleanup_tasks(older_than_days: int = 7):
    """Clean up old task records."""
    try:
        deleted_count = task_status_store.cleanup_old_tasks(older_than_days=older_than_days)

        return {
            "success": True,
            "message": f"Cleaned up {deleted_count} old task record(s)",
            "deleted_count": deleted_count,
            "older_than_days": older_than_days
        }

    except Exception as e:
        print(f"Error cleaning up tasks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to clean up task records: {str(e)}")