1268 lines
48 KiB
Python
1268 lines
48 KiB
Python
import json
|
||
import os
|
||
import tempfile
|
||
import shutil
|
||
import uuid
|
||
import hashlib
|
||
import requests
|
||
import aiohttp
|
||
from typing import AsyncGenerator, Dict, List, Optional, Union, Any
|
||
from datetime import datetime
|
||
import re
|
||
|
||
import uvicorn
|
||
from fastapi import FastAPI, HTTPException, Depends, Header, UploadFile, File, Form
|
||
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from file_manager_api import router as file_manager_router
|
||
from qwen_agent.llm.schema import ASSISTANT, FUNCTION
|
||
from pydantic import BaseModel, Field
|
||
|
||
|
||
# Import utility modules
|
||
from utils import (
|
||
# Models
|
||
Message, DatasetRequest, ChatRequest, ChatResponse, QueueTaskRequest, IncrementalTaskRequest, QueueTaskResponse,
|
||
QueueStatusResponse, TaskStatusResponse,
|
||
|
||
# File utilities
|
||
download_file, remove_file_or_directory, get_document_preview,
|
||
load_processed_files_log, save_processed_files_log, get_file_hash,
|
||
|
||
# Dataset management
|
||
download_dataset_files, generate_dataset_structure,
|
||
remove_dataset_directory, remove_dataset_directory_by_key,
|
||
|
||
# Project management
|
||
generate_project_readme, save_project_readme, get_project_status,
|
||
remove_project, list_projects, get_project_stats,
|
||
|
||
# Agent management
|
||
get_global_agent_manager, init_global_agent_manager
|
||
)
|
||
|
||
# Import ChatRequestV2 directly from api_models
|
||
from utils.api_models import ChatRequestV2
|
||
|
||
# Import modified_assistant
|
||
from modified_assistant import update_agent_llm
|
||
|
||
# Import queue manager
|
||
from task_queue.manager import queue_manager
|
||
from task_queue.integration_tasks import process_files_async, process_files_incremental_async, cleanup_project_async
|
||
from task_queue.task_status import task_status_store
|
||
|
||
# Set TOKENIZERS_PARALLELISM=false in this process's environment at import time.
# NOTE(review): presumably this targets the Hugging Face `tokenizers` library
# (disabling its parallelism / fork warning) — confirm which dependency needs it.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||
|
||
|
||
def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extension: str) -> tuple[str, int]:
    """
    Choose the on-disk file name for an upload, bumping a version suffix on collisions.

    If neither ``<name><ext>`` nor any ``<name>_<N><ext>`` exists in *upload_dir*
    (or the directory does not exist yet), the plain name is returned with version 1.
    Otherwise every existing copy (the plain file and all versioned files) is deleted
    and ``<name>_<N+1><ext>`` is returned, where N is the highest existing version;
    if only the plain file existed the new version is 2.

    Args:
        upload_dir: Upload directory path.
        name_without_ext: File name without its extension.
        file_extension: File extension, including the leading dot.

    Returns:
        tuple[str, int]: (final file name, version number)
    """
    base_name = name_without_ext + file_extension

    # A missing upload directory simply means nothing has been uploaded yet.
    try:
        entries = os.listdir(upload_dir)
    except FileNotFoundError:
        return base_name, 1

    original_exists = os.path.exists(os.path.join(upload_dir, base_name))

    # NOTE(review): this also matches unrelated files shaped like
    # "<name>_<digits><ext>" (e.g. "report_2023.pdf" for "report.pdf"),
    # which are then deleted below — confirm that is acceptable.
    version_pattern = re.compile(
        re.escape(name_without_ext) + r'_(\d+)' + re.escape(file_extension)
    )

    existing_versions = []
    files_to_delete = []
    for filename in entries:
        if filename == base_name:
            files_to_delete.append(filename)
            continue
        # fullmatch: the whole name must fit the pattern (a `$` anchor would
        # also accept a trailing newline).
        match = version_pattern.fullmatch(filename)
        if match:
            existing_versions.append(int(match.group(1)))
            files_to_delete.append(filename)

    # Nothing related exists: keep the original name (version 1).
    if not original_exists and not existing_versions:
        return base_name, 1

    # Remove every existing copy (plain and versioned); failures are logged, not fatal.
    for filename in files_to_delete:
        file_to_delete = os.path.join(upload_dir, filename)
        try:
            os.remove(file_to_delete)
            print(f"已删除文件: {file_to_delete}")
        except OSError as e:
            print(f"删除文件失败 {file_to_delete}: {e}")

    next_version = max(existing_versions) + 1 if existing_versions else 2
    return f"{name_without_ext}_{next_version}{file_extension}", next_version
|
||
|
||
|
||
# Custom version for qwen-agent messages - keep this function as it's specific to this app
|
||
def get_content_from_messages(messages: List[dict], tool_response: bool = True) -> str:
    """Flatten qwen-agent response messages into one tagged text transcript.

    Each assistant message contributes up to three sections, in this order:
    ``[THINK]`` (reasoning_content), ``[ANSWER]`` (content) and
    ``[TOOL_CALL] <name>`` (function_call arguments). Function messages add a
    ``[TOOL_RESPONSE] <name>`` section when *tool_response* is true. Sections
    are joined with newlines; an empty string is returned if nothing qualifies.

    Raises:
        TypeError: if a message role is neither ASSISTANT nor FUNCTION, or if
            reasoning_content / content is not a plain string.
    """
    TOOL_CALL_S = '[TOOL_CALL]'
    TOOL_RESULT_S = '[TOOL_RESPONSE]'
    THOUGHT_S = '[THINK]'
    ANSWER_S = '[ANSWER]'

    content = []
    for msg in messages:
        role = msg['role']
        if role == ASSISTANT:
            reasoning = msg.get('reasoning_content')
            if reasoning:
                # Explicit check instead of `assert`, which is stripped under `python -O`.
                if not isinstance(reasoning, str):
                    raise TypeError('Now only supports text messages')
                content.append(f'{THOUGHT_S}\n{reasoning}')

            answer = msg.get('content')
            if answer:
                if not isinstance(answer, str):
                    raise TypeError('Now only supports text messages')
                # While streaming, the text may end with a partially emitted
                # "<tool_call" tag; strip that trailing fragment.
                answer = re.sub(r'<t?o?o?l?_?c?a?l?l?$', '', answer)
                # Only add the section if something remains after stripping.
                if answer.strip():
                    content.append(f'{ANSWER_S}\n{answer}')

            function_call = msg.get('function_call')
            if function_call:
                arguments = function_call["arguments"]
                # Strip a trailing partial "</tool_call" tag from streamed arguments.
                # NOTE(review): this pattern also consumes the "}" before the tag,
                # so the shown arguments lose their closing brace — confirm intent.
                arguments = re.sub(r'}\n<\/?t?o?o?l?_?c?a?l?l?$', '', arguments)
                if arguments.strip():
                    content.append(f'{TOOL_CALL_S} {function_call["name"]}\n{arguments}')
        elif role == FUNCTION:
            if tool_response:
                content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
        else:
            raise TypeError(f'Unsupported message role: {role!r}')

    return '\n'.join(content)
|
||
|
||
|
||
# Helper functions are now imported from utils module
|
||
|
||
|
||
|
||
|
||
|
||
# Global agent-manager configuration: the MAX_CACHED_AGENTS environment variable
# (default "20") is passed to init_global_agent_manager as max_cached_agents.
max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "20"))

# Initialize the global agent manager; request handlers in this module use it
# to obtain agent instances.
agent_manager = init_global_agent_manager(max_cached_agents=max_cached_agents)

app = FastAPI(title="Database Assistant API", version="1.0.0")

# Serve the local "public" directory as static files under /public.
app.mount("/public", StaticFiles(directory="public"), name="static")

# CORS middleware so browser front-end pages can call this API.
# NOTE(review): wildcard origins combined with allow_credentials=True — Starlette's
# CORSMiddleware docs say credentials require explicitly listed origins; review
# before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], # In production this should be set to the specific front-end domain(s)
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
    allow_headers=[
        "Authorization", "Content-Type", "Accept", "Origin", "User-Agent",
        "DNT", "Cache-Control", "Range", "X-Requested-With"
    ],
)
|
||
|
||
|
||
# Models are now imported from utils module
|
||
|
||
|
||
async def generate_stream_response(agent, messages, tool_response: bool, model: str) -> AsyncGenerator[str, None]:
    """Stream an agent run as OpenAI-compatible server-sent events.

    Each step of ``agent.run`` yields the full response so far. It is flattened
    with get_content_from_messages, and only the newly appended text is sent as
    a ``chat.completion.chunk``. A final chunk with ``finish_reason="stop"`` is
    then sent, followed by ``data: [DONE]``. On error, a single ``error`` event is
    emitted instead and the stream ends.

    Args:
        agent: Agent exposing ``run(messages=...)``, which yields message lists.
        messages: Conversation messages passed to the agent.
        tool_response: Whether tool results are included in the streamed text.
        model: Model name echoed in every chunk.
    """
    import time

    # All chunks of one completion share a single id and creation timestamp,
    # as in the OpenAI streaming format.
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    def _sse_chunk(delta: dict, finish_reason: Optional[str]) -> str:
        """Serialize one chat.completion.chunk as an SSE data line."""
        chunk_data = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason
            }]
        }
        return f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"

    accumulated_content = ""
    try:
        # NOTE(review): agent.run is a synchronous iterator, so each step runs on
        # the event-loop thread and blocks other requests while it computes.
        for response in agent.run(messages=messages):
            previous_content = accumulated_content
            accumulated_content = get_content_from_messages(response, tool_response=tool_response)

            # Send only the delta; if the transcript was rewritten rather than
            # extended, resend it in full.
            if accumulated_content.startswith(previous_content):
                new_content = accumulated_content[len(previous_content):]
            else:
                new_content = accumulated_content

            if new_content:
                yield _sse_chunk({"content": new_content}, None)

        # Completion marker, then the stream terminator.
        yield _sse_chunk({}, "stop")
        yield "data: [DONE]\n\n"

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in generate_stream_response: {str(e)}")
        print(f"Full traceback: {error_details}")

        error_data = {
            "error": {
                "message": f"Stream error: {str(e)}",
                "type": "internal_error"
            }
        }
        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
|
||
|
||
|
||
# Models are now imported from utils module
|
||
|
||
|
||
@app.post("/api/v1/files/process/async")
async def process_files_async_endpoint(request: QueueTaskRequest, authorization: Optional[str] = Header(None)):
    """
    Queue-based asynchronous file-processing API.
    Same functionality as /api/v1/files/process, but the work is processed asynchronously via the queue.

    Args:
        request: QueueTaskRequest containing dataset_id, files, system_prompt, mcp_settings, and queue options
        authorization: Authorization header containing API key (Bearer <API_KEY>)

    Returns:
        QueueTaskResponse: Processing result with task ID for tracking

    Raises:
        HTTPException: 400 if dataset_id is missing; 500 for any other failure.
    """
    try:
        dataset_id = request.dataset_id
        if not dataset_id:
            raise HTTPException(status_code=400, detail="dataset_id is required")

        # Estimate processing time from the request shape.
        total_files = 0
        if request.upload_folder:
            # Folder contents are scanned later, so the file count is unknown here: default 2 minutes.
            estimated_time = 120
        elif request.files:
            total_files = sum(len(file_list) for file_list in request.files.values())
            estimated_time = max(30, total_files * 10)  # ~10 s per file, at least 30 s
        else:
            estimated_time = 0

        # Create the task status record under our own task id (uuid is imported at module level).
        task_id = str(uuid.uuid4())
        task_status_store.set_status(
            task_id=task_id,
            unique_id=dataset_id,
            status="pending"
        )

        # Submit the processing job; task_id is passed so the job can presumably
        # update the same status record — confirm in task_queue.integration_tasks.
        process_files_async(
            dataset_id=dataset_id,
            files=request.files,
            upload_folder=request.upload_folder,
            task_id=task_id
        )

        # Build a more descriptive message.
        message = f"文件处理任务已提交到队列,项目ID: {dataset_id}"
        if request.upload_folder:
            group_count = len(request.upload_folder)
            message += f",将从 {group_count} 个上传文件夹自动扫描文件"
        elif request.files:
            message += f",包含 {total_files} 个文件"

        return QueueTaskResponse(
            success=True,
            message=message,
            dataset_id=dataset_id,
            task_id=task_id,
            task_status="pending",
            estimated_processing_time=estimated_time
        )

    except HTTPException:
        raise
    except Exception as e:
        print(f"Error submitting async file processing task: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||
|
||
|
||
@app.post("/api/v1/files/process/incremental")
async def process_files_incremental_endpoint(request: IncrementalTaskRequest, authorization: Optional[str] = Header(None)):
    """
    Queue-based incremental file-processing API - supports adding and removing files.

    Args:
        request: IncrementalTaskRequest containing dataset_id, files_to_add, files_to_remove, system_prompt, mcp_settings, and queue options
        authorization: Authorization header containing API key (Bearer <API_KEY>)

    Returns:
        QueueTaskResponse: Processing result with task ID for tracking
    """
    try:
        dataset_id = request.dataset_id
        if not dataset_id:
            raise HTTPException(status_code=400, detail="dataset_id is required")

        additions = request.files_to_add
        removals = request.files_to_remove
        # The request must ask for at least one add or remove operation.
        if not (additions or removals):
            raise HTTPException(status_code=400, detail="At least one of files_to_add or files_to_remove must be provided")

        # Count files across all groups; a missing mapping counts as zero.
        add_count = sum(map(len, (additions or {}).values()))
        remove_count = sum(map(len, (removals or {}).values()))
        # Rough estimate: ~10 s per file, never below 30 s.
        estimated_time = max(30, (add_count + remove_count) * 10)

        # Register the task as pending before handing it to the queue.
        new_task_id = str(uuid.uuid4())
        task_status_store.set_status(task_id=new_task_id, unique_id=dataset_id, status="pending")

        process_files_incremental_async(
            dataset_id=dataset_id,
            files_to_add=additions,
            files_to_remove=removals,
            system_prompt=request.system_prompt,
            mcp_settings=request.mcp_settings,
            task_id=new_task_id
        )

        return QueueTaskResponse(
            success=True,
            message=f"增量文件处理任务已提交到队列 - 添加 {add_count} 个文件,删除 {remove_count} 个文件,项目ID: {dataset_id}",
            dataset_id=dataset_id,
            task_id=new_task_id,
            task_status="pending",
            estimated_processing_time=estimated_time
        )

    except HTTPException:
        raise
    except Exception as e:
        print(f"Error submitting incremental file processing task: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||
|
||
|
||
@app.get("/api/v1/task/{task_id}/status")
async def get_task_status(task_id: str):
    """Return the stored status record for one task.

    Unknown or expired task ids produce ``success: False`` with status
    ``not_found`` instead of an HTTP error. Store failures become HTTP 500.
    """
    try:
        record = task_status_store.get_status(task_id)

        if not record:
            return {
                "success": False,
                "message": "任务不存在或已过期",
                "task_id": task_id,
                "status": "not_found"
            }

        payload = {"success": True, "message": "任务状态获取成功", "task_id": task_id}
        # Mandatory fields of the stored record (KeyError if absent -> 500 below).
        for field in ("status", "unique_id", "created_at", "updated_at"):
            payload[field] = record[field]
        # Optional fields default to None.
        payload["result"] = record.get("result")
        payload["error"] = record.get("error")
        return payload

    except Exception as e:
        print(f"Error getting task status: {str(e)}")
        raise HTTPException(status_code=500, detail=f"获取任务状态失败: {str(e)}")
|
||
|
||
|
||
@app.delete("/api/v1/task/{task_id}")
async def delete_task(task_id: str):
    """Delete a task status record.

    Returns ``success: False`` when the store reports nothing was deleted.
    Store failures become HTTP 500.
    """
    try:
        deleted = task_status_store.delete_status(task_id)
        if not deleted:
            return {
                "success": False,
                "message": f"任务记录不存在: {task_id}",
                "task_id": task_id
            }
        return {
            "success": True,
            "message": f"任务记录已删除: {task_id}",
            "task_id": task_id
        }
    except Exception as e:
        print(f"Error deleting task: {str(e)}")
        raise HTTPException(status_code=500, detail=f"删除任务记录失败: {str(e)}")
|
||
|
||
|
||
@app.get("/api/v1/tasks")
async def list_tasks(status: Optional[str] = None, dataset_id: Optional[str] = None, limit: int = 100):
    """List task records, optionally filtered by status and/or dataset id.

    With no filter, every record is listed and the first ``limit`` are returned.
    Otherwise the store's search is used with ``limit`` passed through.
    """
    try:
        if not (status or dataset_id):
            # No filters: take everything, then truncate.
            every_task = task_status_store.list_all()
            tasks = list(every_task.values())[:limit]
        else:
            tasks = task_status_store.search_tasks(status=status, unique_id=dataset_id, limit=limit)

        applied_filters = {
            "status": status,
            "dataset_id": dataset_id,
            "limit": limit
        }
        return {
            "success": True,
            "message": "任务列表获取成功",
            "total_tasks": len(tasks),
            "tasks": tasks,
            "filters": applied_filters
        }

    except Exception as e:
        print(f"Error listing tasks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"获取任务列表失败: {str(e)}")
|
||
|
||
|
||
@app.get("/api/v1/tasks/statistics")
async def get_task_statistics():
    """Return aggregate task statistics from the task status store."""
    try:
        return {
            "success": True,
            "message": "统计信息获取成功",
            "statistics": task_status_store.get_statistics()
        }
    except Exception as e:
        print(f"Error getting statistics: {str(e)}")
        raise HTTPException(status_code=500, detail=f"获取统计信息失败: {str(e)}")
|
||
|
||
|
||
@app.post("/api/v1/tasks/cleanup")
async def cleanup_tasks(older_than_days: int = 7):
    """Delete task records older than ``older_than_days`` days and report how many were removed."""
    try:
        removed = task_status_store.cleanup_old_tasks(older_than_days=older_than_days)
    except Exception as e:
        print(f"Error cleaning up tasks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"清理任务记录失败: {str(e)}")

    return {
        "success": True,
        "message": f"已清理 {removed} 条旧任务记录",
        "deleted_count": removed,
        "older_than_days": older_than_days
    }
|
||
|
||
|
||
def _project_path_timestamps(path: str) -> Dict[str, Optional[float]]:
    """Return created_at/updated_at for path, or None values if it can no longer be stat'ed."""
    try:
        return {"created_at": os.path.getctime(path), "updated_at": os.path.getmtime(path)}
    except OSError:
        # The directory can disappear between listdir() and stat(); report it
        # instead of failing the whole listing.
        return {"created_at": None, "updated_at": None}


def _load_json_or_empty(path: str) -> Any:
    """Parse the UTF-8 JSON file at path, or return {} when the file does not exist."""
    if not os.path.exists(path):
        return {}
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _count_files_recursive(root_dir: str, skip_suffix: Optional[str] = None) -> int:
    """Count files under root_dir (0 if it does not exist), ignoring names ending in skip_suffix."""
    total = 0
    for _root, _dirs, files in os.walk(root_dir):
        total += sum(1 for name in files if not (skip_suffix and name.endswith(skip_suffix)))
    return total


def _collect_robot_project_entries(robot_dir: str) -> List[Dict[str, Any]]:
    """Describe every sub-directory of robot_dir as a robot project entry."""
    entries: List[Dict[str, Any]] = []
    if not os.path.exists(robot_dir):
        return entries
    for item in os.listdir(robot_dir):
        item_path = os.path.join(robot_dir, item)
        if not os.path.isdir(item_path):
            continue
        try:
            config_data = _load_json_or_empty(os.path.join(item_path, "robot_config.json"))
            entry = {
                "id": item,
                "name": config_data.get("name", item),
                "type": "robot",
                "status": config_data.get("status", "active"),
                "file_count": _count_files_recursive(os.path.join(item_path, "dataset")),
                "config": config_data,
            }
        except Exception as e:
            # One unreadable project must not hide the others; list it with placeholders.
            print(f"Error reading robot project {item}: {str(e)}")
            entry = {
                "id": item,
                "name": item,
                "type": "robot",
                "status": "unknown",
                "file_count": 0,
                "config": {},
            }
        entry.update(_project_path_timestamps(item_path))
        entries.append(entry)
    return entries


def _collect_dataset_entries(data_dir: str) -> List[Dict[str, Any]]:
    """Describe every sub-directory of data_dir as a dataset entry."""
    entries: List[Dict[str, Any]] = []
    if not os.path.exists(data_dir):
        return entries
    for item in os.listdir(data_dir):
        item_path = os.path.join(data_dir, item)
        if not os.path.isdir(item_path):
            continue
        try:
            log_data = _load_json_or_empty(os.path.join(item_path, "processing_log.json"))
            # Status: an explicit status in the processing log wins; otherwise a
            # "processed" entry means completed; otherwise active.
            status = log_data.get("status") or (
                "completed" if os.path.exists(os.path.join(item_path, "processed")) else "active"
            )
            entry = {
                "id": item,
                "name": f"数据集 - {item[:8]}...",
                "type": "dataset",
                "status": status,
                # .pkl files are excluded from the count.
                "file_count": _count_files_recursive(item_path, skip_suffix='.pkl'),
                "log_data": log_data,
            }
        except Exception as e:
            print(f"Error reading dataset {item}: {str(e)}")
            entry = {
                "id": item,
                "name": f"数据集 - {item[:8]}...",
                "type": "dataset",
                "status": "unknown",
                "file_count": 0,
                "log_data": {},
            }
        entry.update(_project_path_timestamps(item_path))
        entries.append(entry)
    return entries


@app.get("/api/v1/projects")
async def list_all_projects():
    """List all projects: robot projects (projects/robot) and datasets (projects/data).

    Returns both groups separately, plus the combined list under "projects"
    for backward compatibility. Unexpected failures become HTTP 500.
    """
    try:
        robot_projects = _collect_robot_project_entries("projects/robot")
        datasets = _collect_dataset_entries("projects/data")
        all_projects = robot_projects + datasets

        return {
            "success": True,
            "message": "项目列表获取成功",
            "total_projects": len(all_projects),
            "robot_projects": robot_projects,
            "datasets": datasets,
            "projects": all_projects  # kept for backward compatibility
        }

    except Exception as e:
        print(f"Error listing projects: {str(e)}")
        raise HTTPException(status_code=500, detail=f"获取项目列表失败: {str(e)}")
|
||
|
||
|
||
@app.get("/api/v1/projects/robot")
async def list_robot_projects():
    """List robot projects only (the "robot_projects" part of list_all_projects)."""
    try:
        response = await list_all_projects()
        return {
            "success": True,
            "message": "机器人项目列表获取成功",
            "total_projects": len(response["robot_projects"]),
            "projects": response["robot_projects"]
        }
    except HTTPException:
        # list_all_projects already produced a proper HTTP error; don't wrap it again.
        raise
    except Exception as e:
        print(f"Error listing robot projects: {str(e)}")
        raise HTTPException(status_code=500, detail=f"获取机器人项目列表失败: {str(e)}")
|
||
|
||
|
||
@app.get("/api/v1/projects/datasets")
async def list_datasets():
    """List datasets only (the "datasets" part of list_all_projects)."""
    try:
        response = await list_all_projects()
        return {
            "success": True,
            "message": "数据集列表获取成功",
            "total_projects": len(response["datasets"]),
            "projects": response["datasets"]
        }
    except HTTPException:
        # list_all_projects already produced a proper HTTP error; don't wrap it again.
        raise
    except Exception as e:
        print(f"Error listing datasets: {str(e)}")
        raise HTTPException(status_code=500, detail=f"获取数据集列表失败: {str(e)}")
|
||
|
||
|
||
@app.get("/api/v1/projects/{dataset_id}/tasks")
async def get_project_tasks(dataset_id: str):
    """Return every task record whose unique_id equals dataset_id."""
    try:
        project_tasks = task_status_store.get_by_unique_id(dataset_id)
        result = {
            "success": True,
            "message": "项目任务获取成功",
            "dataset_id": dataset_id,
            "total_tasks": len(project_tasks),
            "tasks": project_tasks
        }
    except Exception as e:
        print(f"Error getting project tasks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"获取项目任务失败: {str(e)}")
    return result
|
||
|
||
|
||
@app.post("/api/v1/files/{dataset_id}/cleanup/async")
async def cleanup_project_async_endpoint(dataset_id: str, remove_all: bool = False):
    """Submit an asynchronous cleanup job for a project's files.

    ``remove_all`` selects the "remove_all" action; otherwise the reported
    action is "cleanup_logs". The response carries the queued job's ``id``.
    """
    try:
        queued = cleanup_project_async(dataset_id=dataset_id, remove_all=remove_all)
        action = "remove_all" if remove_all else "cleanup_logs"
        return {
            "success": True,
            "message": f"项目清理任务已提交到队列,项目ID: {dataset_id}",
            "dataset_id": dataset_id,
            "task_id": queued.id,
            "action": action
        }
    except Exception as e:
        print(f"Error submitting cleanup task: {str(e)}")
        raise HTTPException(status_code=500, detail=f"提交清理任务失败: {str(e)}")
|
||
|
||
|
||
@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)):
    """
    Chat completions API similar to OpenAI, supports both streaming and non-streaming

    Args:
        request: ChatRequest containing messages, model, optional dataset_ids list, required bot_id, system_prompt, mcp_settings, and files
        authorization: Authorization header containing API key (Bearer <API_KEY>); in v1 its value is used as the model API key

    Returns:
        Union[ChatResponse, StreamingResponse]: Chat completion response or stream

    Raises:
        HTTPException: 400 if bot_id is missing; HTTP errors raised downstream are
            propagated unchanged; any other failure becomes 500.

    Notes:
        - dataset_ids: optional; when provided it must be a list of project IDs (a single project also uses list form)
        - bot_id: required, the bot ID
        - A bot project directory projects/robot/{bot_id}/ is created only when robot_type == "catalog_agent" and dataset_ids is a non-empty list
        - Any other robot_type (including the default "agent") creates no directory
        - An empty dataset_ids [], None, or an omitted value creates no directory
        - Combining multiple knowledge bases is delegated to utils.multi_project_manager.create_robot_project
        - Request fields not consumed explicitly are forwarded to the agent as generate_cfg

    Required Parameters:
        - bot_id: str - target bot ID
        - messages: List[Message] - conversation messages

    Optional Parameters:
        - dataset_ids: List[str] - source knowledge-base project IDs (a single project also uses list form)
        - robot_type: str - bot type, defaults to "agent"

    Example:
        {"bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]}
        {"dataset_ids": ["project-123"], "bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]}
        {"dataset_ids": ["project-123", "project-456"], "bot_id": "my-bot-002", "messages": [{"role": "user", "content": "Hello"}]}
        {"dataset_ids": ["project-123"], "bot_id": "my-catalog-bot", "robot_type": "catalog_agent", "messages": [{"role": "user", "content": "Hello"}]}
    """
    try:
        # v1: the Authorization header value is used as the model API key.
        api_key = extract_api_key_from_auth(authorization)

        # bot_id is required.
        bot_id = request.bot_id
        if not bot_id:
            raise HTTPException(status_code=400, detail="bot_id is required")

        # Create the project directory (only for catalog_agent with dataset_ids).
        project_dir = create_project_directory(request.dataset_ids, bot_id, request.robot_type)

        # Every request field not handled explicitly is forwarded as generate_cfg.
        exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id'}
        generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}

        # Normalize messages and append the reply-language instruction.
        messages = process_messages(request.messages, request.language)

        # Shared agent creation + response generation logic.
        return await create_agent_and_generate_response(
            bot_id=bot_id,
            api_key=api_key,
            messages=messages,
            stream=request.stream,
            tool_response=True,
            model_name=request.model,
            model_server=request.model_server,
            language=request.language,
            system_prompt=request.system_prompt,
            mcp_settings=request.mcp_settings,
            robot_type=request.robot_type,
            project_dir=project_dir,
            generate_cfg=generate_cfg
        )

    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) instead of turning them into 500s.
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in chat_completions: {str(e)}")
        print(f"Full traceback: {error_details}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||
|
||
|
||
async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
    """Fetch a bot's configuration from the backend API.

    Calls ``GET {BACKEND_HOST}/v1/agent_bot_config/{bot_id}``, authenticated with
    the v2 token from generate_v2_auth_token. Returns the response's ``data``
    object, or ``{}`` when it is missing or null.

    Raises:
        HTTPException: 400 if the backend returns a non-200 status or
            ``success`` is falsy; 500 on connection errors or any other failure.
    """
    try:
        # Strip a trailing slash so the joined URL does not contain "//".
        backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai").rstrip("/")
        url = f"{backend_host}/v1/agent_bot_config/{bot_id}"

        auth_token = generate_v2_auth_token(bot_id)
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {auth_token}"
        }
        # Log only the URL: the headers contain the bearer token.
        print(url)

        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=30)) as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch bot config: API returned status code {response.status}"
                    )

                response_data = await response.json()

                if not response_data.get("success"):
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch bot config: {response_data.get('message', 'Unknown error')}"
                    )

                # "data" may be present but null; callers call .get() on the result.
                return response_data.get("data") or {}

    except HTTPException:
        raise
    except aiohttp.ClientError as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to connect to backend API: {str(e)}"
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch bot config: {str(e)}"
        ) from e
|
||
|
||
|
||
def process_messages(messages: List["Message"], language: Optional[str] = None) -> List[Dict[str, str]]:
    """Convert request messages into plain role/content dicts for the agent.

    Assistant messages keep only the stripped text after their last "[ANSWER]"
    marker (the whole stripped content if no marker is present; possibly empty).
    If *language* maps to a known reply instruction (zh, en, ja/jp,
    case-insensitive), that instruction is appended to the last message.

    Args:
        messages: Objects exposing ``role`` and ``content`` attributes.
        language: Optional reply-language code.

    Returns:
        A new list of ``{"role": ..., "content": ...}`` dicts.
    """
    # Reply-language instructions, each ending with punctuation native to its language.
    language_map = {
        'zh': '请用中文回复。',
        'en': 'Please reply in English.',
        'ja': '日本語で回答してください。',
        'jp': '日本語で回答してください。',
    }

    processed_messages = []
    for msg in messages:
        content = msg.content
        if msg.role == "assistant":
            # Drop earlier [THINK]/[TOOL_CALL] sections: keep only the last [ANSWER] segment.
            content = content.split("[ANSWER]")[-1].strip()
        processed_messages.append({"role": msg.role, "content": content})

    # Append the reply-language instruction to the end of the last message.
    if processed_messages and language:
        language_instruction = language_map.get(language.lower(), '')
        if language_instruction:
            processed_messages[-1]['content'] += f"\n\n{language_instruction}"

    return processed_messages
|
||
|
||
|
||
async def create_agent_and_generate_response(
    bot_id: str,
    api_key: str,
    messages: List[Dict[str, str]],
    stream: bool,
    tool_response: bool,
    model_name: str,
    model_server: str,
    language: str,
    system_prompt: Optional[str],
    mcp_settings: Optional[List[Dict]],
    robot_type: str,
    project_dir: Optional[str] = None,
    generate_cfg: Optional[Dict] = None
) -> Union[ChatResponse, StreamingResponse]:
    """Shared path for the chat endpoints: get an agent, then answer.

    The agent comes from the global agent manager (created or reused there).
    With *stream* true an SSE StreamingResponse is returned. Otherwise the agent
    runs to completion and an OpenAI-style ChatResponse is built.

    Raises:
        HTTPException: 500 if the non-streaming run produced no messages.
    """
    agent = await agent_manager.get_or_create_agent(
        bot_id=bot_id,
        project_dir=project_dir,
        model_name=model_name,
        api_key=api_key,
        model_server=model_server,
        generate_cfg={} if generate_cfg is None else generate_cfg,
        language=language,
        system_prompt=system_prompt,
        mcp_settings=mcp_settings,
        robot_type=robot_type
    )

    if stream:
        return StreamingResponse(
            generate_stream_response(agent, messages, tool_response, model_name),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
        )

    # Non-streaming: run to completion (synchronously) and flatten the result.
    final_responses = agent.run_nonstream(messages)
    if not final_responses:
        raise HTTPException(status_code=500, detail="No response from agent")

    content = get_content_from_messages(final_responses, tool_response=tool_response)

    # NOTE: "token" counts below are character counts, not tokenizer tokens.
    prompt_size = sum(len(msg.get("content", "")) for msg in messages)
    completion_size = len(content)

    return ChatResponse(
        choices=[{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": content
            },
            "finish_reason": "stop"
        }],
        usage={
            "prompt_tokens": prompt_size,
            "completion_tokens": completion_size,
            "total_tokens": prompt_size + completion_size
        }
    )
|
||
|
||
|
||
def create_project_directory(dataset_ids: Optional[List[str]], bot_id: str, robot_type: str = "agent") -> Optional[str]:
    """Create the robot project directory when the request calls for one.

    Only ``robot_type == "catalog_agent"`` with a non-empty *dataset_ids* list
    triggers creation (via utils.multi_project_manager.create_robot_project).
    Returns that helper's result, or None when nothing is created or creation fails.
    """
    needs_directory = robot_type == "catalog_agent" and bool(dataset_ids)
    if not needs_directory:
        return None

    try:
        # Imported lazily so the module is only loaded when a directory is needed.
        from utils.multi_project_manager import create_robot_project
        project_path = create_robot_project(dataset_ids, bot_id)
    except Exception as e:
        # Best effort: a failure is logged and the request proceeds without a directory.
        print(f"Error creating project directory: {e}")
        return None
    return project_path
|
||
|
||
|
||
def extract_api_key_from_auth(authorization: Optional[str]) -> Optional[str]:
    """Extract the credential from an Authorization header value.

    Returns None for a missing/empty header. A "Bearer <token>" value yields
    <token>; the scheme name is matched case-insensitively (RFC 7235) and
    surrounding whitespace is ignored. Any other value is returned as-is
    (stripped), so a bare API key also works.
    """
    if not authorization:
        return None

    value = authorization.strip()
    parts = value.split(None, 1)
    if len(parts) == 2 and parts[0].lower() == "bearer":
        return parts[1].strip()
    return value
|
||
|
||
|
||
def generate_v2_auth_token(bot_id: str) -> str:
    """Derive the v2 API token for a bot: hex MD5 of "<MASTERKEY>:<bot_id>".

    MASTERKEY is read from the environment and defaults to "master".
    NOTE(review): a plain MD5 digest is not a keyed MAC; consider HMAC if the
    backend contract can change.
    """
    secret = os.getenv("MASTERKEY", "master")
    digest = hashlib.md5()
    digest.update(f"{secret}:{bot_id}".encode())
    return digest.hexdigest()
|
||
|
||
|
||
@app.post("/api/v2/chat/completions")
async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[str] = Header(None)):
    """
    Chat completions API v2 with simplified parameters.

    Only messages, stream, tool_response, bot_id and language come from the
    request; model, model server, system prompt, MCP settings, robot type and
    dataset IDs are read from the bot configuration returned by
    ``fetch_bot_config``.

    Args:
        request: ChatRequestV2 containing only the essential parameters.
        authorization: Authorization header used for v2 authentication
            (unlike v1, it is NOT used as the model API key).

    Returns:
        Union[ChatResponse, StreamingResponse]: chat completion response or stream.

    Required Parameters:
        - bot_id: str - target bot ID
        - messages: List[Message] - conversation messages

    Optional Parameters:
        - stream: bool - stream the output, default false
        - tool_response: bool - include tool responses, default false
        - language: str - reply language, default "ja"

    Authentication:
        - Requires the token MD5(MASTERKEY:bot_id)
        - Header format: ``Authorization: Bearer {token}``
        - The model API key, if any, is taken from the bot config

    Raises:
        HTTPException: 400 without bot_id, 401 without a token, 403 for a
        wrong token, 500 for any unexpected error.
    """
    import hmac  # constant-time comparison of the auth token

    try:
        # bot_id is mandatory
        bot_id = request.bot_id
        if not bot_id:
            raise HTTPException(status_code=400, detail="bot_id is required")

        # v2 authentication
        expected_token = generate_v2_auth_token(bot_id)
        provided_token = extract_api_key_from_auth(authorization)

        if not provided_token:
            raise HTTPException(
                status_code=401,
                detail="Authorization header is required for v2 API"
            )

        # Compare in constant time (bytes, so non-ASCII input cannot raise) and
        # never echo any part of the expected token: returning it to an
        # unauthenticated caller leaks secret-derived material.
        if not hmac.compare_digest(provided_token.encode(), expected_token.encode()):
            raise HTTPException(
                status_code=403,
                detail="Invalid authentication token"
            )

        # Fetch the bot configuration from the backend API.
        bot_config = await fetch_bot_config(bot_id)

        # v2: the model API key comes only from the backend bot config. The
        # Authorization header has already been consumed for authentication
        # and is deliberately not reused as the API key.
        api_key = bot_config.get("api_key")

        # Create the project directory from the dataset_ids in the bot config.
        project_dir = create_project_directory(
            bot_config.get("dataset_ids", []),
            bot_id,
            bot_config.get("robot_type", "agent")
        )

        messages = process_messages(request.messages, request.language)

        # Shared agent-creation and response-generation path.
        return await create_agent_and_generate_response(
            bot_id=bot_id,
            api_key=api_key,
            messages=messages,
            stream=request.stream,
            tool_response=request.tool_response,
            model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
            model_server=bot_config.get("model_server", ""),
            language=request.language or bot_config.get("language", "ja"),
            system_prompt=bot_config.get("system_prompt"),
            mcp_settings=bot_config.get("mcp_settings", []),
            robot_type=bot_config.get("robot_type", "agent"),
            project_dir=project_dir,
            generate_cfg={}  # v2 does not forward any extra generate_cfg
        )

    except HTTPException:
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in chat_completions_v2: {str(e)}")
        print(f"Full traceback: {error_details}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||
|
||
|
||
@app.post("/api/v1/upload")
async def upload_file(file: UploadFile = File(...), folder: Optional[str] = Form(None)):
    """
    Upload a file into ``./projects/uploads/<folder>/``.

    With ``folder``, the original filename is kept and ``get_versioned_filename``
    picks the stored name and version number (versioning support). Without
    ``folder``, the file goes into a ``YYYYMMDD`` folder for the current date
    under a random UUID name that keeps the original extension.

    Both the folder name and the client-supplied filename are reduced to their
    final path component, so an upload cannot be written outside the uploads
    directory.

    Args:
        file: The uploaded file.
        folder: Optional custom folder name.

    Returns:
        dict: ``success``, ``message``, ``file_path``, ``folder`` and
        ``original_filename`` (plus ``version`` when ``folder`` was given).

    Raises:
        HTTPException: 400 for an empty/invalid filename or folder name,
        500 for any other failure.
    """
    try:
        # Debug output
        print(f"Received folder parameter: {folder}")
        print(f"File received: {file.filename if file else 'None'}")

        # Validate the filename before creating any directory.
        if not file.filename:
            raise HTTPException(status_code=400, detail="文件名不能为空")
        original_filename = file.filename

        # The filename is client-controlled. Keep only its final component
        # (treating "\" as a separator too) so names such as "../../x" cannot
        # escape the upload directory.
        safe_filename = os.path.basename(original_filename.replace("\\", "/"))
        if safe_filename in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="文件名无效")

        if folder:
            # Same traversal guard for the folder: use only its last component,
            # and reject "." / ".." because they resolve to the uploads root or
            # its parent.
            target_folder = os.path.basename(folder.replace("\\", "/").rstrip("/"))
            if target_folder in ("", ".", ".."):
                raise HTTPException(status_code=400, detail="文件夹名无效")
        else:
            # Default to a per-day folder such as 20240131.
            target_folder = datetime.now().strftime("%Y%m%d")

        upload_dir = os.path.join("projects", "uploads", target_folder)
        os.makedirs(upload_dir, exist_ok=True)

        name_without_ext, file_extension = os.path.splitext(safe_filename)

        if folder:
            # Keep the original name; get_versioned_filename chooses the stored
            # filename and its version number.
            final_filename, version = get_versioned_filename(upload_dir, name_without_ext, file_extension)
        else:
            # Random unique name, original extension preserved.
            final_filename = f"{uuid.uuid4()}{file_extension}"

        file_path = os.path.join(upload_dir, final_filename)

        # NOTE(review): synchronous copy inside an async endpoint; large
        # uploads block the event loop while this runs.
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        if folder:
            return {
                "success": True,
                "message": f"文件上传成功{' (版本: ' + str(version) + ')' if version > 1 else ''}",
                "file_path": file_path,
                "folder": target_folder,
                "original_filename": original_filename,
                "version": version
            }

        return {
            "success": True,
            "message": "文件上传成功",
            "file_path": file_path,
            "folder": target_folder,
            "original_filename": original_filename
        }

    except HTTPException:
        raise
    except Exception as e:
        print(f"Error uploading file: {str(e)}")
        raise HTTPException(status_code=500, detail=f"文件上传失败: {str(e)}")
|
||
|
||
|
||
@app.get("/api/health")
async def health_check():
    """Return a fixed status message confirming the API process responds."""
    status_payload = {"message": "Database Assistant API is running"}
    return status_payload
|
||
|
||
|
||
@app.post("/system/remove-project-cache")
async def remove_project_cache(dataset_id: str):
    """Remove the cache entries held by ``agent_manager`` for one project.

    Args:
        dataset_id: Project identifier, passed to
            ``agent_manager.remove_cache_by_unique_id`` as the unique ID.

    Returns:
        dict with a status message and ``removed_count`` (0 when nothing
        matching was cached).

    Raises:
        HTTPException: 500 if the cache removal raises.
    """
    try:
        evicted = agent_manager.remove_cache_by_unique_id(dataset_id)
        if not (evicted > 0):
            return {"message": f"未找到项目缓存: {dataset_id}", "removed_count": 0}
        return {"message": f"项目缓存移除成功: {dataset_id}", "removed_count": evicted}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}")
|
||
|
||
|
||
@app.get("/api/v1/files/{dataset_id}/status")
async def get_files_processing_status(dataset_id: str):
    """Report the file-processing status of a dataset project.

    Args:
        dataset_id: Project/dataset identifier; must be a single path
            component (no separators, not "." or "..").

    Returns:
        dict containing the processed-files log returned by
        ``load_processed_files_log`` and its length, every ``document.txt``
        found under ``projects/data/<dataset_id>``, and whether the project
        directory and its ``processed_files.json`` exist.

    Raises:
        HTTPException: 400 for an invalid ``dataset_id``, 500 on any other error.
    """
    # dataset_id comes from the URL; refuse values that would resolve outside
    # projects/data/ (e.g. ".." would make the walk below scan all projects).
    if dataset_id in ("", ".", "..") or os.path.basename(dataset_id) != dataset_id:
        raise HTTPException(status_code=400, detail=f"无效的数据集ID: {dataset_id}")

    try:
        processed_log = load_processed_files_log(dataset_id)

        project_dir = os.path.join("projects", "data", dataset_id)
        project_exists = os.path.exists(project_dir)

        # Collect every document.txt anywhere under the project directory.
        document_files = []
        if project_exists:
            document_files = [
                os.path.join(root, name)
                for root, _dirs, files in os.walk(project_dir)
                for name in files
                if name == "document.txt"
            ]

        return {
            "dataset_id": dataset_id,
            "project_exists": project_exists,
            "processed_files_count": len(processed_log),
            "processed_files": processed_log,
            "document_files_count": len(document_files),
            "document_files": document_files,
            "log_file_exists": os.path.exists(os.path.join(project_dir, "processed_files.json"))
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取文件处理状态失败: {str(e)}")
|
||
|
||
|
||
@app.post("/api/v1/files/{dataset_id}/reset")
async def reset_files_processing(dataset_id: str):
    """Reset a project's file-processing state, deleting its log and files.

    For each entry in the processed-files log, removes the recorded
    ``local_path``, the per-file dataset directory (located by ``key`` or,
    for the legacy layout, by ``filename``) and any ``dataset_path``. It then
    removes the log file, the project's ``files/`` and ``dataset/``
    directories and its ``README.md``.

    Args:
        dataset_id: Project/dataset identifier; must be a single path
            component (no separators, not "." or "..").

    Returns:
        dict with a message, the number of removed paths, and the paths for
        which the removal helpers reported success.

    Raises:
        HTTPException: 400 for an invalid ``dataset_id``, 500 on any other error.
    """
    # This endpoint deletes files, so refuse IDs that would resolve outside
    # projects/data/ (".." would otherwise target projects/ itself).
    if dataset_id in ("", ".", "..") or os.path.basename(dataset_id) != dataset_id:
        raise HTTPException(status_code=400, detail=f"无效的数据集ID: {dataset_id}")

    try:
        project_dir = os.path.join("projects", "data", dataset_id)
        log_file = os.path.join(project_dir, "processed_files.json")

        # The processed log records which files/directories belong to the project.
        processed_log = load_processed_files_log(dataset_id)

        removed_files = []
        for file_info in processed_log.values():
            # Local copy of the source file in the files directory.
            if 'local_path' in file_info and remove_file_or_directory(file_info['local_path']):
                removed_files.append(file_info['local_path'])

            # Per-file dataset directory: key-based layout first, then the
            # legacy filename-based layout.
            if 'key' in file_info:
                key = file_info['key']
                if remove_dataset_directory_by_key(dataset_id, key):
                    removed_files.append(f"dataset/{key}")
            elif 'filename' in file_info:
                filename_without_ext = os.path.splitext(file_info['filename'])[0]
                legacy_dataset_dir = os.path.join(project_dir, "dataset", filename_without_ext)
                if remove_file_or_directory(legacy_dataset_dir):
                    removed_files.append(legacy_dataset_dir)

            # Explicit dataset path, if one was recorded.
            if 'dataset_path' in file_info and remove_file_or_directory(file_info['dataset_path']):
                removed_files.append(file_info['dataset_path'])

        # Project-level artefacts: the log itself, the whole files/ and dataset/
        # directories (catching anything not listed in the log), and README.md.
        for path in (
            log_file,
            os.path.join(project_dir, "files"),
            os.path.join(project_dir, "dataset"),
            os.path.join(project_dir, "README.md"),
        ):
            if remove_file_or_directory(path):
                removed_files.append(path)

        return {
            "message": f"文件处理状态重置成功: {dataset_id}",
            "removed_files_count": len(removed_files),
            "removed_files": removed_files
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"重置文件处理状态失败: {str(e)}")
|
||
|
||
# Register the file-manager API routes on the application.
app.include_router(file_manager_router)


if __name__ == "__main__":
    # Launch the FastAPI application when this module is run as a script.
    for startup_line in (
        "Starting FastAPI server...",
        "File Manager API available at: http://localhost:8001/api/v1/files",
        "Web Interface available at: http://localhost:8001/public/file-manager.html",
    ):
        print(startup_line)
    uvicorn.run(app, host="0.0.0.0", port=8001)
|