import json
import os
import tempfile
import shutil
import time
from typing import AsyncGenerator, Dict, List, Optional, Union, Any
from datetime import datetime

import uvicorn
from fastapi import FastAPI, HTTPException, Depends, Header
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from qwen_agent.llm.schema import ASSISTANT, FUNCTION
from pydantic import BaseModel, Field

# Import utility modules
from utils import (
    # Models
    Message, DatasetRequest, ChatRequest, FileProcessRequest,
    FileProcessResponse, ChatResponse, QueueTaskRequest, QueueTaskResponse,
    QueueStatusResponse, TaskStatusResponse,

    # File utilities
    download_file, remove_file_or_directory, get_document_preview,
    load_processed_files_log, save_processed_files_log, get_file_hash,

    # Dataset management
    download_dataset_files, generate_dataset_structure,
    remove_dataset_directory, remove_dataset_directory_by_key,

    # Project management
    generate_project_readme, save_project_readme, get_project_status,
    remove_project, list_projects, get_project_stats,

    # Agent management
    get_global_agent_manager, init_global_agent_manager
)

# Import gbase_agent
from gbase_agent import update_agent_llm

# Import queue manager
from task_queue.manager import queue_manager
from task_queue.integration_tasks import process_files_async, cleanup_project_async
from task_queue.task_status import task_status_store

os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Custom version for qwen-agent messages - keep this function as it's specific to this app
def get_content_from_messages(messages: List[dict]) -> str:
    """Extract content from qwen-agent messages with special formatting"""
    full_text = ''
    content = []
    TOOL_CALL_S = '[TOOL_CALL]'
    TOOL_RESULT_S = '[TOOL_RESPONSE]'
    THOUGHT_S = '[THINK]'
    ANSWER_S = '[ANSWER]'

    for msg in messages:
        if msg['role'] == ASSISTANT:
            if msg.get('reasoning_content'):
                assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
                content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
            if msg.get('content'):
                assert isinstance(msg['content'], str), 'Now only supports text messages'
                content.append(f'{ANSWER_S}\n{msg["content"]}')
            if msg.get('function_call'):
                content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{msg["function_call"]["arguments"]}')
        elif msg['role'] == FUNCTION:
            content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
        else:
            raise TypeError(f'Unsupported message role: {msg["role"]}')
    if content:
        full_text = '\n'.join(content)

    return full_text
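
# Illustrative sketch (not executed) of the formatting above, assuming the qwen-agent role
# constants resolve to 'assistant' and 'function' and using a hypothetical tool name:
#
#   messages = [
#       {'role': 'assistant', 'reasoning_content': 'Need a row count.',
#        'function_call': {'name': 'run_sql', 'arguments': '{"query": "SELECT COUNT(*) FROM t"}'}},
#       {'role': 'function', 'name': 'run_sql', 'content': '42'},
#       {'role': 'assistant', 'content': 'There are 42 rows.'},
#   ]
#
#   get_content_from_messages(messages) would then return roughly:
#       [THINK]
#       Need a row count.
#       [TOOL_CALL] run_sql
#       {"query": "SELECT COUNT(*) FROM t"}
#       [TOOL_RESPONSE] run_sql
#       42
#       [ANSWER]
#       There are 42 rows.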


# Helper functions are now imported from utils module


# Global agent manager configuration
max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "20"))

# Initialize the global agent manager
agent_manager = init_global_agent_manager(max_cached_agents=max_cached_agents)

app = FastAPI(title="Database Assistant API", version="1.0.0")

# Mount the public folder as a static file server
app.mount("/public", StaticFiles(directory="public"), name="static")

# Add CORS middleware so the frontend pages can call the API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # in production this should be restricted to the actual frontend domain(s)
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Models are now imported from utils module

async def generate_stream_response(agent, messages, request) -> AsyncGenerator[str, None]:
    """Generate a streaming response"""
    accumulated_content = ""
    chunk_id = 0
    try:
        for response in agent.run(messages=messages):
            previous_content = accumulated_content
            accumulated_content = get_content_from_messages(response)

            # Compute the newly added content
            if accumulated_content.startswith(previous_content):
                new_content = accumulated_content[len(previous_content):]
            else:
                new_content = accumulated_content

            # Only emit a chunk when there is new content
            if new_content:
                chunk_id += 1
                # Build a streaming chunk in OpenAI format
                chunk_data = {
                    "id": f"chatcmpl-{chunk_id}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": new_content
                        },
                        "finish_reason": None
                    }]
                }

                yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"

        # Send the final completion chunk
        final_chunk = {
            "id": f"chatcmpl-{chunk_id + 1}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": request.model,
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }]
        }
        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"

        # Send the end-of-stream marker
        yield "data: [DONE]\n\n"

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in generate_stream_response: {str(e)}")
        print(f"Full traceback: {error_details}")

        error_data = {
            "error": {
                "message": f"Stream error: {str(e)}",
                "type": "internal_error"
            }
        }
        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"


# Models are now imported from utils module


@app.post("/api/v1/files/process")
async def process_files(request: FileProcessRequest, authorization: Optional[str] = Header(None)):
    """
    Process dataset files for a given unique_id.
    Files are organized by key groups, and each group is combined into a single document.txt file.
    Supports zip files, which are extracted and their txt/md contents combined.

    Args:
        request: FileProcessRequest containing unique_id, files (key-grouped dict), system_prompt, and mcp_settings
        authorization: Authorization header containing API key (Bearer <API_KEY>)

    Returns:
        FileProcessResponse: Processing result with file list
    """
    try:
        unique_id = request.unique_id
        if not unique_id:
            raise HTTPException(status_code=400, detail="unique_id is required")

        # Process files using the key-grouped format
        processed_files_by_key = {}
        if request.files:
            # Use the files from the request (grouped by key)
            processed_files_by_key = await download_dataset_files(unique_id, request.files)
            total_files = sum(len(files) for files in processed_files_by_key.values())
            print(f"Processed {total_files} dataset files across {len(processed_files_by_key)} keys for unique_id: {unique_id}")
        else:
            print(f"No files provided in request for unique_id: {unique_id}")

        # Resolve the project directory from unique_id
        project_dir = os.path.join("projects", unique_id)
        if not os.path.exists(project_dir):
            raise HTTPException(status_code=400, detail=f"Project directory not found for unique_id: {unique_id}")

        # Collect all document.txt files under the project directory
        document_files = []
        for root, dirs, files in os.walk(project_dir):
            for file in files:
                if file == "document.txt":
                    document_files.append(os.path.join(root, file))

        # Merge all processed files (including the newly key-grouped files)
        all_files = document_files.copy()
        for key, files in processed_files_by_key.items():
            all_files.extend(files)

        if not all_files:
            print(f"Warning: no document.txt files found in project directory {project_dir}")

        # Save system_prompt and mcp_settings to the project directory (if provided)
        if request.system_prompt:
            system_prompt_file = os.path.join(project_dir, "system_prompt.md")
            with open(system_prompt_file, 'w', encoding='utf-8') as f:
                f.write(request.system_prompt)
            print(f"Saved system_prompt for unique_id: {unique_id}")

        if request.mcp_settings:
            mcp_settings_file = os.path.join(project_dir, "mcp_settings.json")
            with open(mcp_settings_file, 'w', encoding='utf-8') as f:
                json.dump(request.mcp_settings, f, ensure_ascii=False, indent=2)
            print(f"Saved mcp_settings for unique_id: {unique_id}")

        # Generate the project README.md
        try:
            save_project_readme(unique_id)
            print(f"Generated README.md for unique_id: {unique_id}")
        except Exception as e:
            print(f"Failed to generate README.md for unique_id: {unique_id}, error: {str(e)}")
            # Do not interrupt the main processing flow; continue

        # The result contains file information grouped by key
        result_files = []
        for key in processed_files_by_key.keys():
            # Add the corresponding dataset document.txt path
            document_path = os.path.join("projects", unique_id, "dataset", key, "document.txt")
            if os.path.exists(document_path):
                result_files.append(document_path)

        # Also include document.txt files that exist but are not covered by processed_files_by_key
        existing_document_paths = set(result_files)  # avoid duplicates
        for doc_file in document_files:
            if doc_file not in existing_document_paths:
                result_files.append(doc_file)

        return FileProcessResponse(
            success=True,
            message=f"Successfully processed {len(result_files)} document files across {len(processed_files_by_key)} keys",
            unique_id=unique_id,
            processed_files=result_files
        )

    except HTTPException:
        raise
    except Exception as e:
        print(f"Error processing files: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
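
# A hypothetical request sketch for the endpoint above (field names follow FileProcessRequest;
# the exact shape of each file entry is defined in utils, and the values below are placeholders):
#
#   POST /api/v1/files/process
#   Authorization: Bearer <API_KEY>
#   {
#       "unique_id": "demo-project",
#       "files": {"sales_db": ["https://example.com/schema.zip"]},
#       "system_prompt": "You are a database assistant ...",
#       "mcp_settings": {}
#   }
#
# Each key under "files" ends up as projects/<unique_id>/dataset/<key>/document.txt.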


@app.post("/api/v1/files/process/async")
async def process_files_async_endpoint(request: QueueTaskRequest, authorization: Optional[str] = Header(None)):
    """
    Queue-based asynchronous version of the file processing API.
    Provides the same functionality as /api/v1/files/process, but processes the files asynchronously via a queue.

    Args:
        request: QueueTaskRequest containing unique_id, files, system_prompt, mcp_settings, and queue options
        authorization: Authorization header containing API key (Bearer <API_KEY>)

    Returns:
        QueueTaskResponse: Processing result with task ID for tracking
    """
    try:
        unique_id = request.unique_id
        if not unique_id:
            raise HTTPException(status_code=400, detail="unique_id is required")

        # Estimate the processing time based on the number of files
        estimated_time = 0
        if request.files:
            total_files = sum(len(file_list) for file_list in request.files.values())
            estimated_time = max(30, total_files * 10)  # roughly 10 seconds per file, at least 30 seconds

        # Create a task status record
        import uuid
        task_id = str(uuid.uuid4())
        task_status_store.set_status(
            task_id=task_id,
            unique_id=unique_id,
            status="pending"
        )

        # Submit the asynchronous task
        task = process_files_async(
            unique_id=unique_id,
            files=request.files,
            system_prompt=request.system_prompt,
            mcp_settings=request.mcp_settings,
            task_id=task_id
        )

        return QueueTaskResponse(
            success=True,
            message=f"File processing task submitted to the queue, project ID: {unique_id}",
            unique_id=unique_id,
            task_id=task_id,  # use our own task_id
            task_status="pending",
            estimated_processing_time=estimated_time
        )

    except HTTPException:
        raise
    except Exception as e:
        print(f"Error submitting async file processing task: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
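
# A typical flow against the async endpoint above, sketched with placeholder values
# (the concrete status values come from task_status_store; only "pending" is set here):
#
#   1. POST /api/v1/files/process/async   -> {"task_id": "<uuid>", "task_status": "pending", ...}
#   2. GET  /api/v1/task/<uuid>/status    -> {"status": ..., "result": ..., "error": ...}
#      (the status endpoint is defined below)
#   3. DELETE /api/v1/task/<uuid>         -> optionally remove the record once it is no longer needed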


@app.get("/api/v1/task/{task_id}/status")
async def get_task_status(task_id: str):
    """Get the status of a task - simple and reliable"""
    try:
        status_data = task_status_store.get_status(task_id)

        if not status_data:
            return {
                "success": False,
                "message": "Task does not exist or has expired",
                "task_id": task_id,
                "status": "not_found"
            }

        return {
            "success": True,
            "message": "Task status retrieved successfully",
            "task_id": task_id,
            "status": status_data["status"],
            "unique_id": status_data["unique_id"],
            "created_at": status_data["created_at"],
            "updated_at": status_data["updated_at"],
            "result": status_data.get("result"),
            "error": status_data.get("error")
        }

    except Exception as e:
        print(f"Error getting task status: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get task status: {str(e)}")


@app.delete("/api/v1/task/{task_id}")
async def delete_task(task_id: str):
    """Delete a task record"""
    try:
        success = task_status_store.delete_status(task_id)
        if success:
            return {
                "success": True,
                "message": f"Task record deleted: {task_id}",
                "task_id": task_id
            }
        else:
            return {
                "success": False,
                "message": f"Task record not found: {task_id}",
                "task_id": task_id
            }
    except Exception as e:
        print(f"Error deleting task: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to delete task record: {str(e)}")


@app.get("/api/v1/tasks")
async def list_tasks(status: Optional[str] = None, unique_id: Optional[str] = None, limit: int = 100):
    """List tasks with optional filtering"""
    try:
        if status or unique_id:
            # Use the search function
            tasks = task_status_store.search_tasks(status=status, unique_id=unique_id, limit=limit)
        else:
            # Get all tasks
            all_tasks = task_status_store.list_all()
            tasks = list(all_tasks.values())[:limit]

        return {
            "success": True,
            "message": "Task list retrieved successfully",
            "total_tasks": len(tasks),
            "tasks": tasks,
            "filters": {
                "status": status,
                "unique_id": unique_id,
                "limit": limit
            }
        }

    except Exception as e:
        print(f"Error listing tasks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to list tasks: {str(e)}")


@app.get("/api/v1/tasks/statistics")
async def get_task_statistics():
    """Get task statistics"""
    try:
        stats = task_status_store.get_statistics()

        return {
            "success": True,
            "message": "Statistics retrieved successfully",
            "statistics": stats
        }

    except Exception as e:
        print(f"Error getting statistics: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}")


@app.post("/api/v1/tasks/cleanup")
async def cleanup_tasks(older_than_days: int = 7):
    """Clean up old task records"""
    try:
        deleted_count = task_status_store.cleanup_old_tasks(older_than_days=older_than_days)

        return {
            "success": True,
            "message": f"Cleaned up {deleted_count} old task records",
            "deleted_count": deleted_count,
            "older_than_days": older_than_days
        }

    except Exception as e:
        print(f"Error cleaning up tasks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to clean up task records: {str(e)}")


@app.get("/api/v1/projects/{unique_id}/tasks")
async def get_project_tasks(unique_id: str):
    """Get all tasks for a given project"""
    try:
        tasks = task_status_store.get_by_unique_id(unique_id)

        return {
            "success": True,
            "message": "Project tasks retrieved successfully",
            "unique_id": unique_id,
            "total_tasks": len(tasks),
            "tasks": tasks
        }

    except Exception as e:
        print(f"Error getting project tasks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get project tasks: {str(e)}")


@app.post("/api/v1/files/{unique_id}/cleanup/async")
async def cleanup_project_async_endpoint(unique_id: str, remove_all: bool = False):
    """Asynchronously clean up project files"""
    try:
        task = cleanup_project_async(unique_id=unique_id, remove_all=remove_all)

        return {
            "success": True,
            "message": f"Project cleanup task submitted to the queue, project ID: {unique_id}",
            "unique_id": unique_id,
            "task_id": task.id,
            "action": "remove_all" if remove_all else "cleanup_logs"
        }
    except Exception as e:
        print(f"Error submitting cleanup task: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to submit cleanup task: {str(e)}")
@app.post("/api/v1/chat/completions")
|
||
async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)):
|
||
"""
|
||
Chat completions API similar to OpenAI, supports both streaming and non-streaming
|
||
|
||
Args:
|
||
request: ChatRequest containing messages, model, dataset with unique_id, system_prompt, mcp_settings, and files
|
||
authorization: Authorization header containing API key (Bearer <API_KEY>)
|
||
|
||
Returns:
|
||
Union[ChatResponse, StreamingResponse]: Chat completion response or stream
|
||
"""
|
||
try:
|
||
# 从Authorization header中提取API key
|
||
api_key = None
|
||
if authorization:
|
||
# 移除 "Bearer " 前缀
|
||
if authorization.startswith("Bearer "):
|
||
api_key = authorization[7:]
|
||
else:
|
||
api_key = authorization
|
||
|
||
# 获取unique_id
|
||
unique_id = request.unique_id
|
||
if not unique_id:
|
||
raise HTTPException(status_code=400, detail="unique_id is required")
|
||
|
||
# 使用unique_id获取项目目录
|
||
project_dir = os.path.join("projects", unique_id)
|
||
if not os.path.exists(project_dir):
|
||
raise HTTPException(status_code=400, detail=f"Project directory not found for unique_id: {unique_id}")
|
||
|
||
# 收集额外参数作为 generate_cfg
|
||
exclude_fields = {'messages', 'model', 'model_server', 'unique_id', 'stream'}
|
||
generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
|
||
|
||
# 从全局管理器获取或创建助手实例(配置读取逻辑已在agent_manager内部处理)
|
||
agent = await agent_manager.get_or_create_agent(
|
||
unique_id=unique_id,
|
||
project_dir=project_dir,
|
||
model_name=request.model,
|
||
api_key=api_key,
|
||
model_server=request.model_server,
|
||
generate_cfg=generate_cfg
|
||
)
|
||
# 构建包含项目信息的消息上下文
|
||
messages = []
|
||
for msg in request.messages:
|
||
if msg.role == "assistant":
|
||
# 对assistant消息进行[ANSWER]分割处理,只保留最后一段
|
||
content_parts = msg.content.split("[ANSWER]")
|
||
if content_parts:
|
||
# 取最后一段非空文本
|
||
last_part = content_parts[-1].strip()
|
||
messages.append({"role": msg.role, "content": last_part})
|
||
else:
|
||
messages.append({"role": msg.role, "content": msg.content})
|
||
else:
|
||
messages.append({"role": msg.role, "content": msg.content})
|
||
|
||
# 根据stream参数决定返回流式还是非流式响应
|
||
if request.stream:
|
||
return StreamingResponse(
|
||
generate_stream_response(agent, messages, request),
|
||
media_type="text/event-stream",
|
||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
|
||
)
|
||
else:
|
||
# 非流式响应
|
||
final_responses = agent.run_nonstream(messages)
|
||
|
||
if final_responses and len(final_responses) > 0:
|
||
# 取最后一个响应
|
||
final_response = final_responses[-1]
|
||
|
||
# 如果返回的是Message对象,需要转换为字典
|
||
if hasattr(final_response, 'model_dump'):
|
||
final_response = final_response.model_dump()
|
||
elif hasattr(final_response, 'dict'):
|
||
final_response = final_response.dict()
|
||
|
||
content = final_response.get("content", "")
|
||
|
||
# 构造OpenAI格式的响应
|
||
return ChatResponse(
|
||
choices=[{
|
||
"index": 0,
|
||
"message": {
|
||
"role": "assistant",
|
||
"content": content
|
||
},
|
||
"finish_reason": "stop"
|
||
}],
|
||
usage={
|
||
"prompt_tokens": sum(len(msg.content) for msg in request.messages),
|
||
"completion_tokens": len(content),
|
||
"total_tokens": sum(len(msg.content) for msg in request.messages) + len(content)
|
||
}
|
||
)
|
||
else:
|
||
raise HTTPException(status_code=500, detail="No response from agent")
|
||
|
||
except Exception as e:
|
||
import traceback
|
||
error_details = traceback.format_exc()
|
||
print(f"Error in chat_completions: {str(e)}")
|
||
print(f"Full traceback: {error_details}")
|
||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
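
# A hypothetical request sketch for the chat endpoint above (OpenAI-style body plus the extra
# fields this API expects; the model name and server URL are placeholders):
#
#   POST /api/v1/chat/completions
#   Authorization: Bearer <API_KEY>
#   {
#       "unique_id": "demo-project",
#       "model": "qwen-max",
#       "model_server": "https://example.com/v1",
#       "stream": true,
#       "messages": [{"role": "user", "content": "How many tables are in the dataset?"}]
#   }
#
# With "stream": true the reply is a text/event-stream of chat.completion.chunk events ending
# with "data: [DONE]"; otherwise a single ChatResponse with choices and usage is returned.
# Any additional body fields are forwarded to the agent as generate_cfg.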


@app.get("/api/health")
async def health_check():
    """Health check endpoint"""
    return {"message": "Database Assistant API is running"}


@app.get("/system/status")
async def system_status():
    """Get system status information"""
    # Get the agent cache statistics
    cache_stats = agent_manager.get_cache_stats()

    return {
        "status": "running",
        "storage_type": "File-Loaded Agent Manager",
        "max_cached_agents": max_cached_agents,
        "agent_cache": {
            "total_cached_agents": cache_stats["total_cached_agents"],
            "max_cached_agents": cache_stats["max_cached_agents"],
            "cached_agents": cache_stats["agents"]
        }
    }


@app.post("/system/cleanup-cache")
async def cleanup_cache():
    """Clear the agent cache"""
    try:
        # Clear the cached agent instances
        cleared_count = agent_manager.clear_cache()

        return {
            "message": "Cache cleared successfully",
            "cleared_agent_instances": cleared_count
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}")


@app.post("/system/cleanup-agent-cache")
async def cleanup_agent_cache():
    """Clear only the agent instance cache"""
    try:
        cleared_count = agent_manager.clear_cache()
        return {
            "message": "Agent instance cache cleared successfully",
            "cleared_agent_instances": cleared_count
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to clear agent instance cache: {str(e)}")


@app.get("/system/cached-projects")
async def get_cached_projects():
    """Get information about all cached projects"""
    try:
        cache_stats = agent_manager.get_cache_stats()

        return {
            "cache_stats": cache_stats
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get cached project info: {str(e)}")


@app.post("/system/remove-project-cache")
async def remove_project_cache(unique_id: str):
    """Remove the cache for a specific project"""
    try:
        success = agent_manager.remove_cache_by_unique_id(unique_id)
        if success:
            return {"message": f"Project cache removed successfully: {unique_id}"}
        else:
            return {"message": f"Project cache not found: {unique_id}", "removed": False}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to remove project cache: {str(e)}")


@app.get("/api/v1/files/{unique_id}/status")
async def get_files_processing_status(unique_id: str):
    """Get the file processing status of a project"""
    try:
        # Load processed files log
        processed_log = load_processed_files_log(unique_id)

        # Get project directory info
        project_dir = os.path.join("projects", unique_id)
        project_exists = os.path.exists(project_dir)

        # Collect document.txt files
        document_files = []
        if project_exists:
            for root, dirs, files in os.walk(project_dir):
                for file in files:
                    if file == "document.txt":
                        document_files.append(os.path.join(root, file))

        return {
            "unique_id": unique_id,
            "project_exists": project_exists,
            "processed_files_count": len(processed_log),
            "processed_files": processed_log,
            "document_files_count": len(document_files),
            "document_files": document_files,
            "log_file_exists": os.path.exists(os.path.join("projects", unique_id, "processed_files.json"))
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get file processing status: {str(e)}")


@app.post("/api/v1/files/{unique_id}/reset")
async def reset_files_processing(unique_id: str):
    """Reset a project's file processing state: delete the processing log and all files"""
    try:
        project_dir = os.path.join("projects", unique_id)
        log_file = os.path.join("projects", unique_id, "processed_files.json")

        # Load the processed log to know which files to remove
        processed_log = load_processed_files_log(unique_id)

        removed_files = []
        # Remove all processed files and their dataset directories
        for file_hash, file_info in processed_log.items():
            # Remove the local file in the files directory
            if 'local_path' in file_info:
                if remove_file_or_directory(file_info['local_path']):
                    removed_files.append(file_info['local_path'])

            # Handle the new key-based structure first
            if 'key' in file_info:
                # Remove the dataset directory by key
                key = file_info['key']
                if remove_dataset_directory_by_key(unique_id, key):
                    removed_files.append(f"dataset/{key}")
            elif 'filename' in file_info:
                # Fall back to the old filename-based structure
                filename_without_ext = os.path.splitext(file_info['filename'])[0]
                dataset_dir = os.path.join("projects", unique_id, "dataset", filename_without_ext)
                if remove_file_or_directory(dataset_dir):
                    removed_files.append(dataset_dir)

            # Also remove any specific dataset path if it exists (fallback)
            if 'dataset_path' in file_info:
                if remove_file_or_directory(file_info['dataset_path']):
                    removed_files.append(file_info['dataset_path'])

        # Remove the log file
        if remove_file_or_directory(log_file):
            removed_files.append(log_file)

        # Remove the entire files directory
        files_dir = os.path.join(project_dir, "files")
        if remove_file_or_directory(files_dir):
            removed_files.append(files_dir)

        # Also remove the entire dataset directory (clean up any remaining files)
        dataset_dir = os.path.join(project_dir, "dataset")
        if remove_file_or_directory(dataset_dir):
            removed_files.append(dataset_dir)

        # Remove README.md if it exists
        readme_file = os.path.join(project_dir, "README.md")
        if remove_file_or_directory(readme_file):
            removed_files.append(readme_file)

        return {
            "message": f"File processing state reset successfully: {unique_id}",
            "removed_files_count": len(removed_files),
            "removed_files": removed_files
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to reset file processing state: {str(e)}")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8001)
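
# To run the service locally (assuming this module is saved as app.py):
#   python app.py
# or, with auto-reload during development:
#   uvicorn app:app --host 0.0.0.0 --port 8001 --reload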