# catalog-agent/fastapi_app.py
import json
import os
import shutil
import tempfile
import time
import traceback
import uuid
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
import uvicorn
from fastapi import FastAPI, HTTPException, Depends, Header
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from qwen_agent.llm.schema import ASSISTANT, FUNCTION
from pydantic import BaseModel, Field
# Import utility modules
from utils import (
# Models
Message, DatasetRequest, ChatRequest, FileProcessRequest,
FileProcessResponse, ChatResponse, QueueTaskRequest, QueueTaskResponse,
QueueStatusResponse, TaskStatusResponse,
# File utilities
download_file, remove_file_or_directory, get_document_preview,
load_processed_files_log, save_processed_files_log, get_file_hash,
# Dataset management
download_dataset_files, generate_dataset_structure,
remove_dataset_directory, remove_dataset_directory_by_key,
# Project management
generate_project_readme, save_project_readme, get_project_status,
remove_project, list_projects, get_project_stats,
# Agent management
get_global_agent_manager, init_global_agent_manager
)
# Import gbase_agent
from gbase_agent import update_agent_llm
# Import queue manager
from task_queue.manager import queue_manager
from task_queue.integration_tasks import process_files_async, cleanup_project_async
from task_queue.task_status import task_status_store
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Custom renderer for qwen-agent messages; kept here because it is specific to this app
def get_content_from_messages(messages: List[dict]) -> str:
"""Extract content from qwen-agent messages with special formatting"""
full_text = ''
content = []
TOOL_CALL_S = '[TOOL_CALL]'
TOOL_RESULT_S = '[TOOL_RESPONSE]'
THOUGHT_S = '[THINK]'
ANSWER_S = '[ANSWER]'
for msg in messages:
        if msg['role'] == ASSISTANT:
            if msg.get('reasoning_content'):
                assert isinstance(msg['reasoning_content'], str), 'Only text messages are currently supported'
                content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
            if msg.get('content'):
                assert isinstance(msg['content'], str), 'Only text messages are currently supported'
                content.append(f'{ANSWER_S}\n{msg["content"]}')
            if msg.get('function_call'):
                content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{msg["function_call"]["arguments"]}')
        elif msg['role'] == FUNCTION:
            content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
        else:
            raise TypeError(f'Unsupported message role: {msg["role"]}')
if content:
full_text = '\n'.join(content)
return full_text
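
# Illustrative usage (hypothetical messages; assumes ASSISTANT == 'assistant' and
# FUNCTION == 'function' in qwen_agent.llm.schema):
#
#     get_content_from_messages([
#         {'role': 'assistant', 'reasoning_content': 'Check the schema first.',
#          'function_call': {'name': 'run_sql', 'arguments': '{"sql": "SELECT 1"}'}},
#         {'role': 'function', 'name': 'run_sql', 'content': '1'},
#         {'role': 'assistant', 'content': 'The result is 1.'},
#     ])
#
# would return:
#
#     [THINK]
#     Check the schema first.
#     [TOOL_CALL] run_sql
#     {"sql": "SELECT 1"}
#     [TOOL_RESPONSE] run_sql
#     1
#     [ANSWER]
#     The result is 1.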
# Helper functions are imported from the utils module
# Global agent manager configuration
max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "20"))
# Initialize the global agent manager
agent_manager = init_global_agent_manager(max_cached_agents=max_cached_agents)
app = FastAPI(title="Database Assistant API", version="1.0.0")
# Serve the public folder as static files
app.mount("/public", StaticFiles(directory="public"), name="static")
# Add CORS middleware so frontend pages can call the API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # restrict to specific frontend domains in production
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
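# Production-oriented CORS sketch (hypothetical domain; a replacement for the
# permissive config above, not an addition to it):
#
#     app.add_middleware(
#         CORSMiddleware,
#         allow_origins=["https://app.example.com"],
#         allow_credentials=True,
#         allow_methods=["GET", "POST", "DELETE"],
#         allow_headers=["Authorization", "Content-Type"],
#     )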
# Models are imported from the utils module
async def generate_stream_response(agent, messages, request) -> AsyncGenerator[str, None]:
"""生成流式响应"""
accumulated_content = ""
chunk_id = 0
try:
for response in agent.run(messages=messages):
previous_content = accumulated_content
accumulated_content = get_content_from_messages(response)
            # Compute the newly appended content (the delta)
            if accumulated_content.startswith(previous_content):
                new_content = accumulated_content[len(previous_content):]
            else:
                new_content = accumulated_content
                previous_content = ""
            # Only emit a chunk when there is new content
if new_content:
chunk_id += 1
                # Build an OpenAI-style streaming chunk
                chunk_data = {
                    "id": f"chatcmpl-{chunk_id}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
"model": request.model,
"choices": [{
"index": 0,
"delta": {
"content": new_content
},
"finish_reason": None
}]
}
yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
        # Send the final chunk that marks completion
        final_chunk = {
            "id": f"chatcmpl-{chunk_id + 1}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
"model": request.model,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "stop"
}]
}
yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
        # Send the end-of-stream marker
yield "data: [DONE]\n\n"
except Exception as e:
error_details = traceback.format_exc()
print(f"Error in generate_stream_response: {str(e)}")
print(f"Full traceback: {error_details}")
error_data = {
"error": {
"message": f"Stream error: {str(e)}",
"type": "internal_error"
}
}
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
# Models are imported from the utils module
@app.post("/api/v1/files/process")
async def process_files(request: FileProcessRequest, authorization: Optional[str] = Header(None)):
"""
Process dataset files for a given unique_id.
Files are organized by key groups, and each group is combined into a single document.txt file.
Supports zip files which will be extracted and their txt/md contents combined.
Args:
request: FileProcessRequest containing unique_id, files (key-grouped dict), system_prompt, and mcp_settings
authorization: Authorization header containing API key (Bearer <API_KEY>)
Returns:
FileProcessResponse: Processing result with file list
"""
try:
unique_id = request.unique_id
if not unique_id:
raise HTTPException(status_code=400, detail="unique_id is required")
        # Process files using the key-grouped format
        processed_files_by_key = {}
        if request.files:
            # Use the key-grouped files from the request
processed_files_by_key = await download_dataset_files(unique_id, request.files)
total_files = sum(len(files) for files in processed_files_by_key.values())
print(f"Processed {total_files} dataset files across {len(processed_files_by_key)} keys for unique_id: {unique_id}")
else:
print(f"No files provided in request for unique_id: {unique_id}")
        # Resolve the project directory from unique_id
project_dir = os.path.join("projects", unique_id)
if not os.path.exists(project_dir):
raise HTTPException(status_code=400, detail=f"Project directory not found for unique_id: {unique_id}")
        # Collect all document.txt files under the project directory
document_files = []
for root, dirs, files in os.walk(project_dir):
for file in files:
if file == "document.txt":
document_files.append(os.path.join(root, file))
        # Merge all processed files, including the newly key-grouped ones
all_files = document_files.copy()
for key, files in processed_files_by_key.items():
all_files.extend(files)
if not all_files:
print(f"警告: 项目目录 {project_dir} 中未找到任何 document.txt 文件")
# 保存system_prompt和mcp_settings到项目目录如果提供
if request.system_prompt:
system_prompt_file = os.path.join(project_dir, "system_prompt.md")
with open(system_prompt_file, 'w', encoding='utf-8') as f:
f.write(request.system_prompt)
print(f"Saved system_prompt for unique_id: {unique_id}")
if request.mcp_settings:
mcp_settings_file = os.path.join(project_dir, "mcp_settings.json")
with open(mcp_settings_file, 'w', encoding='utf-8') as f:
json.dump(request.mcp_settings, f, ensure_ascii=False, indent=2)
print(f"Saved mcp_settings for unique_id: {unique_id}")
        # Generate the project README.md
try:
save_project_readme(unique_id)
print(f"Generated README.md for unique_id: {unique_id}")
except Exception as e:
print(f"Failed to generate README.md for unique_id: {unique_id}, error: {str(e)}")
            # Non-fatal; continue with the main processing flow
        # Build the result, including the key-grouped file info
result_files = []
for key in processed_files_by_key.keys():
            # Add the corresponding dataset document.txt path
document_path = os.path.join("projects", unique_id, "dataset", key, "document.txt")
if os.path.exists(document_path):
result_files.append(document_path)
        # Also include existing document.txt files that are not in processed_files_by_key
        existing_document_paths = set(result_files)  # avoid duplicates
for doc_file in document_files:
if doc_file not in existing_document_paths:
result_files.append(doc_file)
return FileProcessResponse(
success=True,
message=f"Successfully processed {len(result_files)} document files across {len(processed_files_by_key)} keys",
unique_id=unique_id,
processed_files=result_files
)
except HTTPException:
raise
except Exception as e:
print(f"Error processing files: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
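
# Illustrative request body for /api/v1/files/process (hypothetical URLs and keys):
#
#     {
#         "unique_id": "demo-project",
#         "files": {
#             "manuals": ["https://example.com/guide.zip"],
#             "faq": ["https://example.com/faq.md"]
#         },
#         "system_prompt": "You are a database assistant.",
#         "mcp_settings": {}
#     }
#
# Each key maps to a list of downloadable file URLs; the files under a key are
# combined into projects/<unique_id>/dataset/<key>/document.txt.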
@app.post("/api/v1/files/process/async")
async def process_files_async_endpoint(request: QueueTaskRequest, authorization: Optional[str] = Header(None)):
"""
    Queue-backed asynchronous version of the file-processing API.
    Functionally the same as /api/v1/files/process, but processed via a task queue.
Args:
request: QueueTaskRequest containing unique_id, files, system_prompt, mcp_settings, and queue options
authorization: Authorization header containing API key (Bearer <API_KEY>)
Returns:
QueueTaskResponse: Processing result with task ID for tracking
"""
try:
unique_id = request.unique_id
if not unique_id:
raise HTTPException(status_code=400, detail="unique_id is required")
        # Estimate processing time based on the number of files
        estimated_time = 0
        if request.files:
            total_files = sum(len(file_list) for file_list in request.files.values())
            estimated_time = max(30, total_files * 10)  # ~10 seconds per file, at least 30 seconds
        # Create the task status record with a locally generated task ID
        task_id = str(uuid.uuid4())
        task_status_store.set_status(
            task_id=task_id,
            unique_id=unique_id,
            status="pending"
        )
        # Submit the async task
task = process_files_async(
unique_id=unique_id,
files=request.files,
system_prompt=request.system_prompt,
mcp_settings=request.mcp_settings,
task_id=task_id
)
        return QueueTaskResponse(
            success=True,
            message=f"File processing task queued for project ID: {unique_id}",
            unique_id=unique_id,
            task_id=task_id,  # our locally generated task_id
            task_status="pending",
            estimated_processing_time=estimated_time
        )
except HTTPException:
raise
except Exception as e:
print(f"Error submitting async file processing task: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
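
# Minimal polling sketch (assumes localhost:8001 and the `requests` package; the
# terminal status values shown are assumptions about task_status_store):
#
#     import time
#     import requests
#     submit = requests.post(
#         "http://localhost:8001/api/v1/files/process/async",
#         json={"unique_id": "demo-project", "files": {"faq": ["https://example.com/faq.md"]}},
#     ).json()
#     task_id = submit["task_id"]
#     while True:
#         status = requests.get(f"http://localhost:8001/api/v1/task/{task_id}/status").json()
#         if status["status"] in ("completed", "failed", "not_found"):
#             break
#         time.sleep(2)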
@app.get("/api/v1/task/{task_id}/status")
async def get_task_status(task_id: str):
"""获取任务状态 - 简单可靠"""
try:
status_data = task_status_store.get_status(task_id)
if not status_data:
return {
"success": False,
"message": "任务不存在或已过期",
"task_id": task_id,
"status": "not_found"
}
return {
"success": True,
"message": "任务状态获取成功",
"task_id": task_id,
"status": status_data["status"],
"unique_id": status_data["unique_id"],
"created_at": status_data["created_at"],
"updated_at": status_data["updated_at"],
"result": status_data.get("result"),
"error": status_data.get("error")
}
except Exception as e:
print(f"Error getting task status: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取任务状态失败: {str(e)}")
@app.delete("/api/v1/task/{task_id}")
async def delete_task(task_id: str):
"""删除任务记录"""
try:
success = task_status_store.delete_status(task_id)
if success:
return {
"success": True,
"message": f"任务记录已删除: {task_id}",
"task_id": task_id
}
else:
return {
"success": False,
"message": f"任务记录不存在: {task_id}",
"task_id": task_id
}
except Exception as e:
print(f"Error deleting task: {str(e)}")
raise HTTPException(status_code=500, detail=f"删除任务记录失败: {str(e)}")
@app.get("/api/v1/tasks")
async def list_tasks(status: Optional[str] = None, unique_id: Optional[str] = None, limit: int = 100):
"""列出任务,支持筛选"""
try:
if status or unique_id:
# 使用搜索功能
tasks = task_status_store.search_tasks(status=status, unique_id=unique_id, limit=limit)
else:
# 获取所有任务
all_tasks = task_status_store.list_all()
tasks = list(all_tasks.values())[:limit]
return {
"success": True,
"message": "任务列表获取成功",
"total_tasks": len(tasks),
"tasks": tasks,
"filters": {
"status": status,
"unique_id": unique_id,
"limit": limit
}
}
except Exception as e:
print(f"Error listing tasks: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取任务列表失败: {str(e)}")
@app.get("/api/v1/tasks/statistics")
async def get_task_statistics():
"""获取任务统计信息"""
try:
stats = task_status_store.get_statistics()
return {
"success": True,
"message": "统计信息获取成功",
"statistics": stats
}
except Exception as e:
print(f"Error getting statistics: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取统计信息失败: {str(e)}")
@app.post("/api/v1/tasks/cleanup")
async def cleanup_tasks(older_than_days: int = 7):
"""清理旧任务记录"""
try:
deleted_count = task_status_store.cleanup_old_tasks(older_than_days=older_than_days)
return {
"success": True,
"message": f"已清理 {deleted_count} 条旧任务记录",
"deleted_count": deleted_count,
"older_than_days": older_than_days
}
except Exception as e:
print(f"Error cleaning up tasks: {str(e)}")
raise HTTPException(status_code=500, detail=f"清理任务记录失败: {str(e)}")
@app.get("/api/v1/projects/{unique_id}/tasks")
async def get_project_tasks(unique_id: str):
"""获取指定项目的所有任务"""
try:
tasks = task_status_store.get_by_unique_id(unique_id)
return {
"success": True,
"message": "项目任务获取成功",
"unique_id": unique_id,
"total_tasks": len(tasks),
"tasks": tasks
}
except Exception as e:
print(f"Error getting project tasks: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取项目任务失败: {str(e)}")
@app.post("/api/v1/files/{unique_id}/cleanup/async")
async def cleanup_project_async_endpoint(unique_id: str, remove_all: bool = False):
"""异步清理项目文件"""
try:
task = cleanup_project_async(unique_id=unique_id, remove_all=remove_all)
return {
"success": True,
"message": f"项目清理任务已提交到队列项目ID: {unique_id}",
"unique_id": unique_id,
"task_id": task.id,
"action": "remove_all" if remove_all else "cleanup_logs"
}
except Exception as e:
print(f"Error submitting cleanup task: {str(e)}")
raise HTTPException(status_code=500, detail=f"提交清理任务失败: {str(e)}")
@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)):
"""
Chat completions API similar to OpenAI, supports both streaming and non-streaming
Args:
request: ChatRequest containing messages, model, dataset with unique_id, system_prompt, mcp_settings, and files
authorization: Authorization header containing API key (Bearer <API_KEY>)
Returns:
Union[ChatResponse, StreamingResponse]: Chat completion response or stream
"""
try:
        # Extract the API key from the Authorization header
api_key = None
if authorization:
            # Strip the "Bearer " prefix if present
if authorization.startswith("Bearer "):
api_key = authorization[7:]
else:
api_key = authorization
        # Get the unique_id
unique_id = request.unique_id
if not unique_id:
raise HTTPException(status_code=400, detail="unique_id is required")
        # Resolve the project directory from unique_id
project_dir = os.path.join("projects", unique_id)
if not os.path.exists(project_dir):
raise HTTPException(status_code=400, detail=f"Project directory not found for unique_id: {unique_id}")
        # Collect extra request parameters as generate_cfg
exclude_fields = {'messages', 'model', 'model_server', 'unique_id', 'stream'}
generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
        # Get or create the agent instance from the global manager (config loading is handled inside agent_manager)
agent = await agent_manager.get_or_create_agent(
unique_id=unique_id,
project_dir=project_dir,
model_name=request.model,
api_key=api_key,
model_server=request.model_server,
generate_cfg=generate_cfg
)
        # Build the message context with project information
messages = []
for msg in request.messages:
if msg.role == "assistant":
# 对assistant消息进行[ANSWER]分割处理,只保留最后一段
content_parts = msg.content.split("[ANSWER]")
if content_parts:
# 取最后一段非空文本
last_part = content_parts[-1].strip()
messages.append({"role": msg.role, "content": last_part})
else:
messages.append({"role": msg.role, "content": msg.content})
else:
messages.append({"role": msg.role, "content": msg.content})
        # Return a streaming or non-streaming response depending on the stream flag
if request.stream:
return StreamingResponse(
generate_stream_response(agent, messages, request),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
)
else:
            # Non-streaming response
final_responses = agent.run_nonstream(messages)
if final_responses and len(final_responses) > 0:
                # Take the last response
final_response = final_responses[-1]
                # Convert a Message object to a dict if needed
if hasattr(final_response, 'model_dump'):
final_response = final_response.model_dump()
elif hasattr(final_response, 'dict'):
final_response = final_response.dict()
content = final_response.get("content", "")
                # Build an OpenAI-style response
return ChatResponse(
choices=[{
"index": 0,
"message": {
"role": "assistant",
"content": content
},
"finish_reason": "stop"
}],
usage={
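                        # Rough approximation: character counts, not model token counts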
"prompt_tokens": sum(len(msg.content) for msg in request.messages),
"completion_tokens": len(content),
"total_tokens": sum(len(msg.content) for msg in request.messages) + len(content)
}
)
else:
raise HTTPException(status_code=500, detail="No response from agent")
except Exception as e:
error_details = traceback.format_exc()
print(f"Error in chat_completions: {str(e)}")
print(f"Full traceback: {error_details}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
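
# Non-streaming usage sketch (assumes localhost:8001 and the `requests` package;
# the API key, model name, and unique_id are placeholders):
#
#     import requests
#     resp = requests.post(
#         "http://localhost:8001/api/v1/chat/completions",
#         headers={"Authorization": "Bearer <API_KEY>"},
#         json={"model": "qwen-max", "unique_id": "demo-project", "stream": False,
#               "messages": [{"role": "user", "content": "List the tables."}]},
#     )
#     print(resp.json()["choices"][0]["message"]["content"])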
@app.get("/api/health")
async def health_check():
"""Health check endpoint"""
return {"message": "Database Assistant API is running"}
@app.get("/system/status")
async def system_status():
"""获取系统状态信息"""
# 获取助手缓存统计
cache_stats = agent_manager.get_cache_stats()
return {
"status": "running",
"storage_type": "File-Loaded Agent Manager",
"max_cached_agents": max_cached_agents,
"agent_cache": {
"total_cached_agents": cache_stats["total_cached_agents"],
"max_cached_agents": cache_stats["max_cached_agents"],
"cached_agents": cache_stats["agents"]
}
}
@app.post("/system/cleanup-cache")
async def cleanup_cache():
"""清理助手缓存"""
try:
# 清理助手实例缓存
cleared_count = agent_manager.clear_cache()
return {
"message": "缓存清理成功",
"cleared_agent_instances": cleared_count
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"缓存清理失败: {str(e)}")
@app.post("/system/cleanup-agent-cache")
async def cleanup_agent_cache():
"""仅清理助手实例缓存"""
try:
cleared_count = agent_manager.clear_cache()
return {
"message": "助手实例缓存清理成功",
"cleared_agent_instances": cleared_count
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"助手实例缓存清理失败: {str(e)}")
@app.get("/system/cached-projects")
async def get_cached_projects():
"""获取所有缓存的项目信息"""
try:
cache_stats = agent_manager.get_cache_stats()
return {
"cache_stats": cache_stats
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"获取缓存项目信息失败: {str(e)}")
@app.post("/system/remove-project-cache")
async def remove_project_cache(unique_id: str):
"""移除特定项目的缓存"""
try:
success = agent_manager.remove_cache_by_unique_id(unique_id)
if success:
return {"message": f"项目缓存移除成功: {unique_id}"}
else:
return {"message": f"未找到项目缓存: {unique_id}", "removed": False}
except Exception as e:
raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}")
@app.get("/api/v1/files/{unique_id}/status")
async def get_files_processing_status(unique_id: str):
"""获取项目的文件处理状态"""
try:
# Load processed files log
processed_log = load_processed_files_log(unique_id)
# Get project directory info
project_dir = os.path.join("projects", unique_id)
project_exists = os.path.exists(project_dir)
# Collect document.txt files
document_files = []
if project_exists:
for root, dirs, files in os.walk(project_dir):
for file in files:
if file == "document.txt":
document_files.append(os.path.join(root, file))
return {
"unique_id": unique_id,
"project_exists": project_exists,
"processed_files_count": len(processed_log),
"processed_files": processed_log,
"document_files_count": len(document_files),
"document_files": document_files,
"log_file_exists": os.path.exists(os.path.join("projects", unique_id, "processed_files.json"))
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"获取文件处理状态失败: {str(e)}")
@app.post("/api/v1/files/{unique_id}/reset")
async def reset_files_processing(unique_id: str):
"""重置项目的文件处理状态,删除处理日志和所有文件"""
try:
project_dir = os.path.join("projects", unique_id)
log_file = os.path.join("projects", unique_id, "processed_files.json")
# Load processed log to know what files to remove
processed_log = load_processed_files_log(unique_id)
removed_files = []
# Remove all processed files and their dataset directories
for file_hash, file_info in processed_log.items():
# Remove local file in files directory
if 'local_path' in file_info:
if remove_file_or_directory(file_info['local_path']):
removed_files.append(file_info['local_path'])
# Handle new key-based structure first
if 'key' in file_info:
# Remove dataset directory by key
key = file_info['key']
if remove_dataset_directory_by_key(unique_id, key):
removed_files.append(f"dataset/{key}")
elif 'filename' in file_info:
# Fallback to old filename-based structure
filename_without_ext = os.path.splitext(file_info['filename'])[0]
dataset_dir = os.path.join("projects", unique_id, "dataset", filename_without_ext)
if remove_file_or_directory(dataset_dir):
removed_files.append(dataset_dir)
# Also remove any specific dataset path if exists (fallback)
if 'dataset_path' in file_info:
if remove_file_or_directory(file_info['dataset_path']):
removed_files.append(file_info['dataset_path'])
# Remove the log file
if remove_file_or_directory(log_file):
removed_files.append(log_file)
# Remove the entire files directory
files_dir = os.path.join(project_dir, "files")
if remove_file_or_directory(files_dir):
removed_files.append(files_dir)
# Also remove the entire dataset directory (clean up any remaining files)
dataset_dir = os.path.join(project_dir, "dataset")
if remove_file_or_directory(dataset_dir):
removed_files.append(dataset_dir)
# Remove README.md if exists
readme_file = os.path.join(project_dir, "README.md")
if remove_file_or_directory(readme_file):
removed_files.append(readme_file)
return {
"message": f"文件处理状态重置成功: {unique_id}",
"removed_files_count": len(removed_files),
"removed_files": removed_files
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"重置文件处理状态失败: {str(e)}")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8001)