# qwen_agent/fastapi_app.py
import json
import os
import tempfile
import shutil
import uuid
import hashlib
import requests
import aiohttp
from typing import AsyncGenerator, Dict, List, Optional, Union, Any
from datetime import datetime
import re
import multiprocessing
import time
import psutil
import uvicorn
from fastapi import FastAPI, HTTPException, Depends, Header, UploadFile, File, Form
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from utils.logger import logger
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from file_manager_api import router as file_manager_router
from qwen_agent.llm.schema import ASSISTANT, FUNCTION
from pydantic import BaseModel, Field
# Import the semantic retrieval service
from embedding import get_search_service
# Import utility modules
from utils import (
# Models
Message, DatasetRequest, ChatRequest, ChatResponse, QueueTaskRequest, IncrementalTaskRequest, QueueTaskResponse,
QueueStatusResponse, TaskStatusResponse,
# File utilities
download_file, remove_file_or_directory, get_document_preview,
load_processed_files_log, save_processed_files_log, get_file_hash,
# Dataset management
download_dataset_files, generate_dataset_structure,
remove_dataset_directory, remove_dataset_directory_by_key,
# Project management
generate_project_readme, save_project_readme, get_project_status,
remove_project, list_projects, get_project_stats,
# Agent management
get_global_agent_manager, init_global_agent_manager,
# Optimization modules
get_global_sharded_agent_manager, init_global_sharded_agent_manager,
get_global_connection_pool, init_global_connection_pool,
get_global_file_cache, init_global_file_cache,
setup_system_optimizations, get_optimized_worker_config
)
# Import ChatRequestV2 directly from api_models
from utils.api_models import ChatRequestV2
# Import modified_assistant
from modified_assistant import update_agent_llm
# Import queue manager
from task_queue.manager import queue_manager
from task_queue.integration_tasks import process_files_async, process_files_incremental_async, cleanup_project_async
from task_queue.task_status import task_status_store
# System optimization settings
# os.environ["TOKENIZERS_PARALLELISM"] = "false"  # commented out; the optimized configuration is used instead
def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extension: str) -> tuple[str, int]:
    """
    Get a versioned filename, automatically deleting stale copies and bumping the version number.
    Args:
        upload_dir: upload directory path
        name_without_ext: filename without extension
        file_extension: file extension (including the dot)
    Returns:
        tuple[str, int]: (final filename, version number)
    """
    # Check whether the original (unversioned) file exists
    original_file = os.path.join(upload_dir, name_without_ext + file_extension)
    original_exists = os.path.exists(original_file)
    # Find all related versioned files
    pattern = re.compile(re.escape(name_without_ext) + r'_(\d+)' + re.escape(file_extension) + r'$')
    existing_versions = []
    files_to_delete = []
    for filename in os.listdir(upload_dir):
        # Is this the original file?
        if filename == name_without_ext + file_extension:
            files_to_delete.append(filename)
            continue
        # Is this a versioned file?
        match = pattern.match(filename)
        if match:
            version_num = int(match.group(1))
            existing_versions.append(version_num)
            files_to_delete.append(filename)
    # If no related file exists, use the original filename with version 1
    if not original_exists and not existing_versions:
        return name_without_ext + file_extension, 1
    # Delete all existing files (the original and all versioned copies)
    for filename in files_to_delete:
        file_to_delete = os.path.join(upload_dir, filename)
        try:
            os.remove(file_to_delete)
            print(f"Deleted file: {file_to_delete}")
        except OSError as e:
            print(f"Failed to delete file {file_to_delete}: {e}")
    # Determine the next version number
    if existing_versions:
        next_version = max(existing_versions) + 1
    else:
        next_version = 2
    # Build the versioned filename
    versioned_filename = f"{name_without_ext}_{next_version}{file_extension}"
    return versioned_filename, next_version
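# A minimal usage sketch for get_versioned_filename (hypothetical directory and
# filenames; kept as a comment so module behaviour is unchanged). If the upload
# dir already holds "report.pdf" and "report_2.pdf", both are deleted and the
# next version is returned:
#
#   filename, version = get_versioned_filename("projects/uploads/demo", "report", ".pdf")
#   # -> ("report_3.pdf", 3); with an empty directory it returns ("report.pdf", 1)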
# Models are now imported from utils module
# Semantic search request models
class SemanticSearchRequest(BaseModel):
    embedding_file: str = Field(..., description="Path to the embedding.pkl file")
    query: str = Field(..., description="Search query")
    top_k: int = Field(default=20, description="Number of results to return", ge=1, le=100)
    min_score: float = Field(default=0.0, description="Minimum similarity threshold", ge=0.0, le=1.0)
class BatchSearchRequest(BaseModel):
    requests: List[SemanticSearchRequest] = Field(..., description="List of search requests")
# Semantic search response models
class SearchResult(BaseModel):
    rank: int = Field(..., description="Rank")
    score: float = Field(..., description="Similarity score")
    content: str = Field(..., description="Matched content")
    content_preview: str = Field(..., description="Content preview")
class SemanticSearchResponse(BaseModel):
    success: bool = Field(..., description="Whether the request succeeded")
    query: str = Field(..., description="Search query")
    embedding_file: str = Field(..., description="Path to the embedding file")
    processing_time: float = Field(..., description="Processing time in seconds")
    total_chunks: int = Field(..., description="Total number of document chunks")
    chunking_strategy: str = Field(..., description="Chunking strategy")
    results: List[SearchResult] = Field(..., description="Search results")
    cache_stats: Optional[Dict[str, Any]] = Field(None, description="Cache statistics")
    error: Optional[str] = Field(None, description="Error message")
# Encode request and response models
class EncodeRequest(BaseModel):
    texts: List[str] = Field(..., description="Texts to encode")
    batch_size: int = Field(default=32, description="Batch size", ge=1, le=128)
class EncodeResponse(BaseModel):
    success: bool = Field(..., description="Whether the request succeeded")
    embeddings: List[List[float]] = Field(..., description="Encoded embeddings")
    shape: List[int] = Field(..., description="Shape of the embeddings")
    processing_time: float = Field(..., description="Processing time in seconds")
    total_texts: int = Field(..., description="Total number of texts")
    error: Optional[str] = Field(None, description="Error message")
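# Example payloads for the request models above (illustrative values only):
#
#   SemanticSearchRequest: {"embedding_file": "projects/data/<id>/embedding.pkl",
#                           "query": "payment flow", "top_k": 5, "min_score": 0.3}
#   EncodeRequest:         {"texts": ["hello", "world"], "batch_size": 32}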
# Custom version for qwen-agent messages - keep this function as it's specific to this app
def get_content_from_messages(messages: List[dict], tool_response: bool = True) -> str:
"""Extract content from qwen-agent messages with special formatting"""
full_text = ''
content = []
TOOL_CALL_S = '[TOOL_CALL]'
TOOL_RESULT_S = '[TOOL_RESPONSE]'
THOUGHT_S = '[THINK]'
ANSWER_S = '[ANSWER]'
for msg in messages:
if msg['role'] == ASSISTANT:
if msg.get('reasoning_content'):
assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
if msg.get('content'):
assert isinstance(msg['content'], str), 'Now only supports text messages'
                # Filter incomplete tool_call fragments out of streamed output
                content_text = msg["content"]
                # Strip a partial trailing "<tool_call" prefix left over mid-stream
                content_text = re.sub(r'<t?o?o?l?_?c?a?l?l?$', '', content_text)
                # Only append when something remains after filtering
                if content_text.strip():
                    content.append(f'{ANSWER_S}\n{content_text}')
if msg.get('function_call'):
content_text = msg["function_call"]["arguments"]
content_text = re.sub(r'}\n<\/?t?o?o?l?_?c?a?l?l?$', '', content_text)
if content_text.strip():
content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{content_text}')
elif msg['role'] == FUNCTION:
if tool_response:
content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
else:
raise TypeError
if content:
full_text = '\n'.join(content)
return full_text
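# Sketch of the flattening performed by get_content_from_messages (message
# content is hypothetical). An assistant turn with a function_call, followed by
# the corresponding function result:
#
#   [{'role': 'assistant', 'content': 'Checking...',
#     'function_call': {'name': 'search', 'arguments': '{"q": "cats"}'}},
#    {'role': 'function', 'name': 'search', 'content': '3 results'}]
#
# flattens to:
#
#   [ANSWER]
#   Checking...
#   [TOOL_CALL] search
#   {"q": "cats"}
#   [TOOL_RESPONSE] search
#   3 results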
# Helper functions are now imported from utils module
# Initialize system-level optimizations
print("Initializing system optimizations...")
system_optimizer = setup_system_optimizations()
# Global agent manager configuration (uses the optimized settings)
max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "50"))  # enlarged cache size
shard_count = int(os.getenv("SHARD_COUNT", "16"))  # number of shards
# Initialize the optimized global agent manager
agent_manager = init_global_sharded_agent_manager(
    max_cached_agents=max_cached_agents,
    shard_count=shard_count
)
# Initialize the connection pool
connection_pool = init_global_connection_pool(
    max_connections_per_host=int(os.getenv("MAX_CONNECTIONS_PER_HOST", "100")),
    max_connections_total=int(os.getenv("MAX_CONNECTIONS_TOTAL", "500")),
    keepalive_timeout=int(os.getenv("KEEPALIVE_TIMEOUT", "30")),
    connect_timeout=int(os.getenv("CONNECT_TIMEOUT", "10")),
    total_timeout=int(os.getenv("TOTAL_TIMEOUT", "60"))
)
# Initialize the file cache
file_cache = init_global_file_cache(
    cache_size=int(os.getenv("FILE_CACHE_SIZE", "1000")),
    ttl=int(os.getenv("FILE_CACHE_TTL", "300"))
)
print("System optimization initialization complete")
print(f"- Sharded agent manager: {shard_count} shards, up to {max_cached_agents} cached agents")
print(f"- Connection pool: {os.getenv('MAX_CONNECTIONS_PER_HOST', '100')} connections per host, {os.getenv('MAX_CONNECTIONS_TOTAL', '500')} total")
print(f"- File cache: {os.getenv('FILE_CACHE_SIZE', '1000')} files, TTL {os.getenv('FILE_CACHE_TTL', '300')}s")
app = FastAPI(title="Database Assistant API", version="1.0.0")
# Mount the public folder as a static file service
app.mount("/public", StaticFiles(directory="public"), name="static")
# Add CORS middleware to support the front-end pages
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # in production this should be restricted to the actual front-end domains
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
allow_headers=[
"Authorization", "Content-Type", "Accept", "Origin", "User-Agent",
"DNT", "Cache-Control", "Range", "X-Requested-With"
],
)
# Models are now imported from utils module
async def generate_stream_response(agent, messages, tool_response: bool, model: str) -> AsyncGenerator[str, None]:
    """Generate a streaming (SSE) response"""
    accumulated_content = ""
    chunk_id = 0
    try:
        for response in agent.run(messages=messages):
            previous_content = accumulated_content
            accumulated_content = get_content_from_messages(response, tool_response=tool_response)
            # Compute the newly added content
            if accumulated_content.startswith(previous_content):
                new_content = accumulated_content[len(previous_content):]
            else:
                new_content = accumulated_content
            # Only emit a chunk when there is new content
            if new_content:
                chunk_id += 1
                # Build an OpenAI-style streaming chunk
                chunk_data = {
                    "id": f"chatcmpl-{chunk_id}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": new_content
                        },
                        "finish_reason": None
                    }]
                }
                yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
        # Send the final completion chunk
        final_chunk = {
            "id": f"chatcmpl-{chunk_id + 1}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }]
        }
        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
        # Send the stream terminator
        yield "data: [DONE]\n\n"
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error in generate_stream_response: {str(e)}")
        logger.error(f"Full traceback: {error_details}")
        error_data = {
            "error": {
                "message": f"Stream error: {str(e)}",
                "type": "internal_error"
            }
        }
        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
# Models are now imported from utils module
@app.post("/api/v1/files/process/async")
async def process_files_async_endpoint(request: QueueTaskRequest, authorization: Optional[str] = Header(None)):
"""
异步处理文件的队列版本API
与 /api/v1/files/process 功能相同,但使用队列异步处理
Args:
request: QueueTaskRequest containing dataset_id, files, system_prompt, mcp_settings, and queue options
authorization: Authorization header containing API key (Bearer <API_KEY>)
Returns:
QueueTaskResponse: Processing result with task ID for tracking
"""
try:
dataset_id = request.dataset_id
if not dataset_id:
raise HTTPException(status_code=400, detail="dataset_id is required")
# 估算处理时间(基于文件数量)
estimated_time = 0
if request.upload_folder:
# 对于upload_folder无法预先估算文件数量使用默认时间
estimated_time = 120 # 默认2分钟
elif request.files:
total_files = sum(len(file_list) for file_list in request.files.values())
estimated_time = max(30, total_files * 10) # 每个文件预估10秒最少30秒
# 提交异步任务
task_id = queue_manager.enqueue_multiple_files(
project_id=dataset_id,
file_paths=[],
original_filenames=[]
)
# 创建任务状态记录
import uuid
task_id = str(uuid.uuid4())
task_status_store.set_status(
task_id=task_id,
unique_id=dataset_id,
status="pending"
)
# 提交异步任务
task = process_files_async(
dataset_id=dataset_id,
files=request.files,
upload_folder=request.upload_folder,
task_id=task_id
)
# 构建更详细的消息
message = f"文件处理任务已提交到队列项目ID: {dataset_id}"
if request.upload_folder:
group_count = len(request.upload_folder)
message += f",将从 {group_count} 个上传文件夹自动扫描文件"
elif request.files:
total_files = sum(len(file_list) for file_list in request.files.values())
message += f",包含 {total_files} 个文件"
return QueueTaskResponse(
success=True,
message=message,
dataset_id=dataset_id,
task_id=task_id, # 使用我们自己的task_id
task_status="pending",
estimated_processing_time=estimated_time
)
except HTTPException:
raise
except Exception as e:
print(f"Error submitting async file processing task: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
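# Example call (hypothetical dataset ID and file URL; the exact QueueTaskRequest
# fields are defined in the utils models):
#
#   curl -X POST http://localhost:8001/api/v1/files/process/async \
#        -H "Content-Type: application/json" \
#        -d '{"dataset_id": "demo-123",
#             "files": {"docs": ["https://example.com/a.pdf"]}}'
#
# The returned task_id can then be polled via GET /api/v1/task/{task_id}/status.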
@app.post("/api/v1/files/process/incremental")
async def process_files_incremental_endpoint(request: IncrementalTaskRequest, authorization: Optional[str] = Header(None)):
"""
    Queue-based incremental file-processing API - supports adding and removing files.
Args:
request: IncrementalTaskRequest containing dataset_id, files_to_add, files_to_remove, system_prompt, mcp_settings, and queue options
authorization: Authorization header containing API key (Bearer <API_KEY>)
Returns:
QueueTaskResponse: Processing result with task ID for tracking
"""
try:
dataset_id = request.dataset_id
if not dataset_id:
raise HTTPException(status_code=400, detail="dataset_id is required")
        # Validate that at least one add or remove operation is present
        if not request.files_to_add and not request.files_to_remove:
            raise HTTPException(status_code=400, detail="At least one of files_to_add or files_to_remove must be provided")
        # Estimate processing time (based on the number of files)
        estimated_time = 0
        total_add_files = sum(len(file_list) for file_list in (request.files_to_add or {}).values())
        total_remove_files = sum(len(file_list) for file_list in (request.files_to_remove or {}).values())
        total_files = total_add_files + total_remove_files
        estimated_time = max(30, total_files * 10)  # ~10s per file, at least 30s
        # Create the task status record
        task_id = str(uuid.uuid4())
        task_status_store.set_status(
            task_id=task_id,
            unique_id=dataset_id,
            status="pending"
        )
        # Submit the incremental asynchronous task
task = process_files_incremental_async(
dataset_id=dataset_id,
files_to_add=request.files_to_add,
files_to_remove=request.files_to_remove,
system_prompt=request.system_prompt,
mcp_settings=request.mcp_settings,
task_id=task_id
)
        return QueueTaskResponse(
            success=True,
            message=f"Incremental file processing task submitted to the queue - adding {total_add_files} files, removing {total_remove_files} files, project ID: {dataset_id}",
            dataset_id=dataset_id,
            task_id=task_id,
            task_status="pending",
            estimated_processing_time=estimated_time
        )
except HTTPException:
raise
except Exception as e:
print(f"Error submitting incremental file processing task: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.get("/api/v1/task/{task_id}/status")
async def get_task_status(task_id: str):
    """Get task status - simple and reliable"""
    try:
        status_data = task_status_store.get_status(task_id)
        if not status_data:
            return {
                "success": False,
                "message": "Task not found or expired",
                "task_id": task_id,
                "status": "not_found"
            }
        return {
            "success": True,
            "message": "Task status retrieved successfully",
            "task_id": task_id,
            "status": status_data["status"],
            "unique_id": status_data["unique_id"],
            "created_at": status_data["created_at"],
            "updated_at": status_data["updated_at"],
            "result": status_data.get("result"),
            "error": status_data.get("error")
        }
    except Exception as e:
        print(f"Error getting task status: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get task status: {str(e)}")
@app.delete("/api/v1/task/{task_id}")
async def delete_task(task_id: str):
    """Delete a task record"""
    try:
        success = task_status_store.delete_status(task_id)
        if success:
            return {
                "success": True,
                "message": f"Task record deleted: {task_id}",
                "task_id": task_id
            }
        else:
            return {
                "success": False,
                "message": f"Task record not found: {task_id}",
                "task_id": task_id
            }
    except Exception as e:
        print(f"Error deleting task: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to delete task record: {str(e)}")
@app.get("/api/v1/tasks")
async def list_tasks(status: Optional[str] = None, dataset_id: Optional[str] = None, limit: int = 100):
    """List tasks, with optional filtering"""
    try:
        if status or dataset_id:
            # Use the search facility
            tasks = task_status_store.search_tasks(status=status, unique_id=dataset_id, limit=limit)
        else:
            # Fetch all tasks
            all_tasks = task_status_store.list_all()
            tasks = list(all_tasks.values())[:limit]
        return {
            "success": True,
            "message": "Task list retrieved successfully",
            "total_tasks": len(tasks),
            "tasks": tasks,
            "filters": {
                "status": status,
                "dataset_id": dataset_id,
                "limit": limit
            }
        }
    except Exception as e:
        print(f"Error listing tasks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to list tasks: {str(e)}")
@app.get("/api/v1/tasks/statistics")
async def get_task_statistics():
    """Get task statistics"""
    try:
        stats = task_status_store.get_statistics()
        return {
            "success": True,
            "message": "Statistics retrieved successfully",
            "statistics": stats
        }
    except Exception as e:
        print(f"Error getting statistics: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}")
@app.post("/api/v1/tasks/cleanup")
async def cleanup_tasks(older_than_days: int = 7):
    """Clean up old task records"""
    try:
        deleted_count = task_status_store.cleanup_old_tasks(older_than_days=older_than_days)
        return {
            "success": True,
            "message": f"Cleaned up {deleted_count} old task records",
            "deleted_count": deleted_count,
            "older_than_days": older_than_days
        }
    except Exception as e:
        print(f"Error cleaning up tasks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to clean up task records: {str(e)}")
@app.get("/api/v1/projects")
async def list_all_projects():
"""获取所有项目列表"""
try:
        # Robot projects (projects/robot)
robot_dir = "projects/robot"
robot_projects = []
if os.path.exists(robot_dir):
for item in os.listdir(robot_dir):
item_path = os.path.join(robot_dir, item)
if os.path.isdir(item_path):
try:
                        # Read the robot configuration file
config_path = os.path.join(item_path, "robot_config.json")
config_data = {}
if os.path.exists(config_path):
with open(config_path, 'r', encoding='utf-8') as f:
config_data = json.load(f)
                        # Count the files
file_count = 0
if os.path.exists(os.path.join(item_path, "dataset")):
for root, dirs, files in os.walk(os.path.join(item_path, "dataset")):
file_count += len(files)
robot_projects.append({
"id": item,
"name": config_data.get("name", item),
"type": "robot",
"status": config_data.get("status", "active"),
"file_count": file_count,
"config": config_data,
"created_at": os.path.getctime(item_path),
"updated_at": os.path.getmtime(item_path)
})
except Exception as e:
print(f"Error reading robot project {item}: {str(e)}")
robot_projects.append({
"id": item,
"name": item,
"type": "robot",
"status": "unknown",
"file_count": 0,
"created_at": os.path.getctime(item_path),
"updated_at": os.path.getmtime(item_path)
})
        # Datasets (projects/data)
data_dir = "projects/data"
datasets = []
if os.path.exists(data_dir):
for item in os.listdir(data_dir):
item_path = os.path.join(data_dir, item)
if os.path.isdir(item_path):
try:
                        # Read the processing log
log_path = os.path.join(item_path, "processing_log.json")
log_data = {}
if os.path.exists(log_path):
with open(log_path, 'r', encoding='utf-8') as f:
log_data = json.load(f)
                        # Count the files (excluding .pkl artifacts)
file_count = 0
for root, dirs, files in os.walk(item_path):
file_count += len([f for f in files if not f.endswith('.pkl')])
                        # Determine the status
status = "active"
if log_data.get("status"):
status = log_data["status"]
elif os.path.exists(os.path.join(item_path, "processed")):
status = "completed"
datasets.append({
"id": item,
"name": f"数据集 - {item[:8]}...",
"type": "dataset",
"status": status,
"file_count": file_count,
"log_data": log_data,
"created_at": os.path.getctime(item_path),
"updated_at": os.path.getmtime(item_path)
})
except Exception as e:
print(f"Error reading dataset {item}: {str(e)}")
datasets.append({
"id": item,
"name": f"数据集 - {item[:8]}...",
"type": "dataset",
"status": "unknown",
"file_count": 0,
"created_at": os.path.getctime(item_path),
"updated_at": os.path.getmtime(item_path)
})
all_projects = robot_projects + datasets
        return {
            "success": True,
            "message": "Project list retrieved successfully",
            "total_projects": len(all_projects),
            "robot_projects": robot_projects,
            "datasets": datasets,
            "projects": all_projects  # kept for backward compatibility
        }
except Exception as e:
print(f"Error listing projects: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取项目列表失败: {str(e)}")
@app.get("/api/v1/projects/robot")
async def list_robot_projects():
    """List robot projects"""
    try:
        response = await list_all_projects()
        return {
            "success": True,
            "message": "Robot project list retrieved successfully",
            "total_projects": len(response["robot_projects"]),
            "projects": response["robot_projects"]
        }
    except Exception as e:
        print(f"Error listing robot projects: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to list robot projects: {str(e)}")
@app.get("/api/v1/projects/datasets")
async def list_datasets():
    """List datasets"""
    try:
        response = await list_all_projects()
        return {
            "success": True,
            "message": "Dataset list retrieved successfully",
            "total_projects": len(response["datasets"]),
            "projects": response["datasets"]
        }
    except Exception as e:
        print(f"Error listing datasets: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to list datasets: {str(e)}")
@app.get("/api/v1/projects/{dataset_id}/tasks")
async def get_project_tasks(dataset_id: str):
    """Get all tasks for a given project"""
    try:
        tasks = task_status_store.get_by_unique_id(dataset_id)
        return {
            "success": True,
            "message": "Project tasks retrieved successfully",
            "dataset_id": dataset_id,
            "total_tasks": len(tasks),
            "tasks": tasks
        }
    except Exception as e:
        print(f"Error getting project tasks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get project tasks: {str(e)}")
@app.post("/api/v1/files/{dataset_id}/cleanup/async")
async def cleanup_project_async_endpoint(dataset_id: str, remove_all: bool = False):
    """Asynchronously clean up project files"""
    try:
        task = cleanup_project_async(dataset_id=dataset_id, remove_all=remove_all)
        return {
            "success": True,
            "message": f"Project cleanup task submitted to the queue, project ID: {dataset_id}",
            "dataset_id": dataset_id,
            "task_id": task.id,
            "action": "remove_all" if remove_all else "cleanup_logs"
        }
    except Exception as e:
        print(f"Error submitting cleanup task: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to submit cleanup task: {str(e)}")
@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)):
"""
Chat completions API similar to OpenAI, supports both streaming and non-streaming
Args:
request: ChatRequest containing messages, model, optional dataset_ids list, required bot_id, system_prompt, mcp_settings, and files
authorization: Authorization header containing API key (Bearer <API_KEY>)
Returns:
Union[ChatResponse, StreamingResponse]: Chat completion response or stream
    Notes:
        - dataset_ids: optional; when provided it must be a list of project IDs (use a list even for a single project)
        - bot_id: required; the target robot ID
        - A robot project directory (projects/robot/{bot_id}/) is created only when robot_type == "catalog_agent" and dataset_ids is a non-empty list
        - For any other robot_type (including the default "agent"), no directory is created
        - When dataset_ids is an empty list [], None, or omitted, no directory is created
        - Multiple knowledge bases can be merged; folder-name conflicts are resolved automatically
    Required Parameters:
        - bot_id: str - target robot ID
        - messages: List[Message] - conversation message list
    Optional Parameters:
        - dataset_ids: List[str] - source knowledge-base project IDs (use a list even for a single project)
        - robot_type: str - robot type, defaults to "agent"
Example:
{"bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]}
{"dataset_ids": ["project-123"], "bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]}
{"dataset_ids": ["project-123", "project-456"], "bot_id": "my-bot-002", "messages": [{"role": "user", "content": "Hello"}]}
{"dataset_ids": ["project-123"], "bot_id": "my-catalog-bot", "robot_type": "catalog_agent", "messages": [{"role": "user", "content": "Hello"}]}
"""
    try:
        # v1 API: extract the API key from the Authorization header to use as the model API key
        api_key = extract_api_key_from_auth(authorization)
        # Get bot_id (required)
        bot_id = request.bot_id
        if not bot_id:
            raise HTTPException(status_code=400, detail="bot_id is required")
        # Create the project directory (only for catalog_agent with non-empty dataset_ids)
        project_dir = create_project_directory(request.dataset_ids, bot_id, request.robot_type)
        # Collect the remaining request fields as generate_cfg
        exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier'}
        generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
        # Process the messages
        messages = process_messages(request.messages, request.language)
        # Invoke the shared agent-creation / response-generation logic
return await create_agent_and_generate_response(
bot_id=bot_id,
api_key=api_key,
messages=messages,
stream=request.stream,
tool_response=True,
model_name=request.model,
model_server=request.model_server,
language=request.language,
system_prompt=request.system_prompt,
mcp_settings=request.mcp_settings,
robot_type=request.robot_type,
project_dir=project_dir,
generate_cfg=generate_cfg,
user_identifier=request.user_identifier
)
except Exception as e:
import traceback
error_details = traceback.format_exc()
print(f"Error in chat_completions: {str(e)}")
print(f"Full traceback: {error_details}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
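# Example v1 request (streaming; the Bearer token is forwarded as the model API key):
#
#   curl -N -X POST http://localhost:8001/api/v1/chat/completions \
#        -H "Authorization: Bearer <MODEL_API_KEY>" \
#        -H "Content-Type: application/json" \
#        -d '{"bot_id": "my-bot-001", "stream": true,
#             "messages": [{"role": "user", "content": "Hello"}]}'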
async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
"""获取机器人配置从后端API"""
try:
backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
url = f"{backend_host}/v1/agent_bot_config/{bot_id}"
auth_token = generate_v2_auth_token(bot_id)
headers = {
"content-type": "application/json",
"authorization": f"Bearer {auth_token}"
}
        # Use an asynchronous HTTP request
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers, timeout=30) as response:
if response.status != 200:
raise HTTPException(
status_code=400,
detail=f"Failed to fetch bot config: API returned status code {response.status}"
)
                # Parse the response
response_data = await response.json()
if not response_data.get("success"):
raise HTTPException(
status_code=400,
detail=f"Failed to fetch bot config: {response_data.get('message', 'Unknown error')}"
)
return response_data.get("data", {})
except aiohttp.ClientError as e:
raise HTTPException(
status_code=500,
detail=f"Failed to connect to backend API: {str(e)}"
)
except Exception as e:
if isinstance(e, HTTPException):
raise
raise HTTPException(
status_code=500,
detail=f"Failed to fetch bot config: {str(e)}"
)
def process_messages(messages: List[Message], language: Optional[str] = None) -> List[Dict[str, str]]:
"""处理消息列表,包括[TOOL_CALL]|[TOOL_RESPONSE]|[ANSWER]分割和语言指令添加
这是 get_content_from_messages 的逆运算,将包含 [TOOL_RESPONSE] 的消息重新组装回
msg['role'] == 'function' 和 msg.get('function_call') 的格式。
"""
processed_messages = []
# 收集所有ASSISTANT消息的索引
assistant_indices = [i for i, msg in enumerate(messages) if msg.role == "assistant"]
total_assistant_messages = len(assistant_indices)
cutoff_point = max(0, total_assistant_messages - 5)
# 处理每条消息
for i, msg in enumerate(messages):
if msg.role == "assistant":
# 确定当前ASSISTANT消息在所有ASSISTANT消息中的位置从0开始
assistant_position = assistant_indices.index(i)
# 使用正则表达式按照 [TOOL_CALL]|[TOOL_RESPONSE]|[ANSWER] 进行切割
parts = re.split(r'\[(TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg.content)
# 重新组装内容,根据消息位置决定处理方式
filtered_content = ""
current_tag = None
is_recent_message = assistant_position >= cutoff_point # 最近10条消息
for i in range(0, len(parts)):
if i % 2 == 0: # 文本内容
text = parts[i].strip()
if not text:
continue
if current_tag == "TOOL_RESPONSE":
if is_recent_message:
# 最近10条ASSISTANT消息保留完整TOOL_RESPONSE信息使用简略模式
if len(text) <= 500:
filtered_content += f"[TOOL_RESPONSE]\n{text}\n"
else:
# 截取前中后3段内容每段250字
first_part = text[:250]
middle_start = len(text) // 2 - 125
middle_part = text[middle_start:middle_start + 250]
last_part = text[-250:]
# 计算省略的字数
omitted_count = len(text) - 750
omitted_text = f"...此处省略{omitted_count}字..."
# 拼接内容
truncated_text = f"{first_part}\n{omitted_text}\n{middle_part}\n{omitted_text}\n{last_part}"
filtered_content += f"[TOOL_RESPONSE]\n{truncated_text}\n"
# 10条以上的消息不保留TOOL_RESPONSE数据完全跳过
elif current_tag == "TOOL_CALL":
if is_recent_message:
# 最近10条ASSISTANT消息保留TOOL_CALL信息
filtered_content += f"[TOOL_CALL]\n{text}\n"
# 10条以上的消息不保留TOOL_CALL数据完全跳过
elif current_tag == "ANSWER":
# 所有ASSISTANT消息都保留ANSWER数据
filtered_content += f"[ANSWER]\n{text}\n"
else:
# 第一个标签之前的内容
filtered_content += text + "\n"
else: # 标签
current_tag = parts[i]
# 取最终处理后的内容,去除首尾空白
final_content = filtered_content.strip()
if final_content:
processed_messages.append({"role": msg.role, "content": final_content})
else:
# 如果处理后为空,使用原内容
processed_messages.append({"role": msg.role, "content": msg.content})
else:
processed_messages.append({"role": msg.role, "content": msg.content})
    # Inverse transform: reassemble messages containing [TOOL_RESPONSE] back into
    # msg['role'] == 'function' and msg.get('function_call') form
    # (this is the inverse of get_content_from_messages)
    final_messages = []
    for msg in processed_messages:
        if msg["role"] == ASSISTANT and "[TOOL_RESPONSE]" in msg["content"]:
            # Split the message content
            parts = re.split(r'\[(TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg["content"])
            current_tag = None
            assistant_content = ""
            function_calls = []
            tool_responses = []
            for j in range(0, len(parts)):
                if j % 2 == 0:  # text content
                    text = parts[j].strip()
                    if not text:
                        continue
                    if current_tag == "TOOL_RESPONSE":
                        # Parse the TOOL_RESPONSE format: [TOOL_RESPONSE] function_name\ncontent
                        lines = text.split('\n', 1)
                        function_name = lines[0].strip() if lines else ""
                        response_content = lines[1].strip() if len(lines) > 1 else ""
                        tool_responses.append({
                            "role": FUNCTION,
                            "name": function_name,
                            "content": response_content
                        })
                    elif current_tag == "TOOL_CALL":
                        # Parse the TOOL_CALL format: [TOOL_CALL] function_name\narguments
                        lines = text.split('\n', 1)
                        function_name = lines[0].strip() if lines else ""
                        arguments = lines[1].strip() if len(lines) > 1 else ""
                        function_calls.append({
                            "name": function_name,
                            "arguments": arguments
                        })
                    elif current_tag == "ANSWER":
                        assistant_content += text + "\n"
                    else:
                        # Content before the first tag also belongs to the assistant
                        assistant_content += text + "\n"
                else:  # tag
                    current_tag = parts[j]
            # Add the assistant message (if it has content)
            if assistant_content.strip() or function_calls:
                assistant_msg = {"role": ASSISTANT}
                if assistant_content.strip():
                    assistant_msg["content"] = assistant_content.strip()
                if function_calls:
                    # If there are multiple function_calls, keep only the first (matches the original logic)
                    assistant_msg["function_call"] = function_calls[0]
                final_messages.append(assistant_msg)
            # Add all tool_responses as function messages
            final_messages.extend(tool_responses)
        else:
            # Non-assistant messages, or messages without [TOOL_RESPONSE], pass through unchanged
            final_messages.append(msg)
    # Append the reply-language instruction to the last message
    if final_messages and language:
        language_map = {
            'zh': '请用中文回复',
            'en': 'Please reply in English',
            'ja': '日本語で回答してください',
            'jp': '日本語で回答してください'
        }
        language_instruction = language_map.get(language.lower(), '')
        if language_instruction:
            # Append the language instruction to the end of the last message
            final_messages[-1]['content'] = final_messages[-1]['content'] + f"\n\n{language_instruction}"
return final_messages
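# Round-trip sketch for process_messages (hypothetical content). A flattened
# assistant turn such as:
#
#   [ANSWER]
#   Done.
#   [TOOL_CALL] search
#   {"q": "cats"}
#   [TOOL_RESPONSE] search
#   3 results
#
# is split back into an assistant message
#   {"role": "assistant", "content": "Done.",
#    "function_call": {"name": "search", "arguments": '{"q": "cats"}'}}
# followed by a function message
#   {"role": "function", "name": "search", "content": "3 results"}.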
async def create_agent_and_generate_response(
bot_id: str,
api_key: str,
messages: List[Dict[str, str]],
stream: bool,
tool_response: bool,
model_name: str,
model_server: str,
language: str,
system_prompt: Optional[str],
mcp_settings: Optional[List[Dict]],
robot_type: str,
project_dir: Optional[str] = None,
generate_cfg: Optional[Dict] = None,
user_identifier: Optional[str] = None
) -> Union[ChatResponse, StreamingResponse]:
"""创建agent并生成响应的公共逻辑"""
if generate_cfg is None:
generate_cfg = {}
# 从全局管理器获取或创建助手实例
agent = await agent_manager.get_or_create_agent(
bot_id=bot_id,
project_dir=project_dir,
model_name=model_name,
api_key=api_key,
model_server=model_server,
generate_cfg=generate_cfg,
language=language,
system_prompt=system_prompt,
mcp_settings=mcp_settings,
robot_type=robot_type,
user_identifier=user_identifier
)
    # Return a streaming or a non-streaming response depending on the stream flag
if stream:
return StreamingResponse(
generate_stream_response(agent, messages, tool_response, model_name),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
)
    else:
        # Non-streaming response
        final_responses = agent.run_nonstream(messages)
        if final_responses and len(final_responses) > 0:
            # Format the response with get_content_from_messages, honouring tool_response
            content = get_content_from_messages(final_responses, tool_response=tool_response)
            # Build an OpenAI-style response
return ChatResponse(
choices=[{
"index": 0,
"message": {
"role": "assistant",
"content": content
},
"finish_reason": "stop"
}],
usage={
"prompt_tokens": sum(len(msg.get("content", "")) for msg in messages),
"completion_tokens": len(content),
"total_tokens": sum(len(msg.get("content", "")) for msg in messages) + len(content)
}
)
else:
raise HTTPException(status_code=500, detail="No response from agent")
def create_project_directory(dataset_ids: Optional[List[str]], bot_id: str, robot_type: str = "general_agent") -> Optional[str]:
"""创建项目目录的公共逻辑"""
# 只有当 robot_type == "catalog_agent" 且 dataset_ids 不为空时才创建目录
if robot_type != "catalog_agent" or not dataset_ids or len(dataset_ids) == 0:
return None
try:
from utils.multi_project_manager import create_robot_project
return create_robot_project(dataset_ids, bot_id)
except Exception as e:
print(f"Error creating project directory: {e}")
return None
def extract_api_key_from_auth(authorization: Optional[str]) -> Optional[str]:
"""从Authorization header中提取API key"""
if not authorization:
return None
# 移除 "Bearer " 前缀
if authorization.startswith("Bearer "):
return authorization[7:]
else:
return authorization
def generate_v2_auth_token(bot_id: str) -> str:
"""生成v2接口的认证token"""
masterkey = os.getenv("MASTERKEY", "master")
token_input = f"{masterkey}:{bot_id}"
return hashlib.md5(token_input.encode()).hexdigest()
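# Client-side reproduction of the v2 token (assumes MASTERKEY is the default
# "master" and a hypothetical bot ID; use the deployed values in practice):
#
#   import hashlib
#   token = hashlib.md5("master:my-bot-001".encode()).hexdigest()
#   headers = {"Authorization": f"Bearer {token}"}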
@app.post("/api/v2/chat/completions")
async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[str] = Header(None)):
"""
Chat completions API v2 with simplified parameters.
Only requires messages, stream, tool_response, bot_id, and language parameters.
Other parameters are fetched from the backend bot configuration API.
Args:
request: ChatRequestV2 containing only essential parameters
authorization: Authorization header for authentication (different from v1)
Returns:
Union[ChatResponse, StreamingResponse]: Chat completion response or stream
Required Parameters:
    - bot_id: str - target robot ID
    - messages: List[Message] - conversation message list
    Optional Parameters:
    - stream: bool - whether to stream the output, defaults to false
    - tool_response: bool - whether to include tool responses, defaults to false
    - language: str - reply language, defaults to "ja"
Authentication:
- Requires valid MD5 hash token: MD5(MASTERKEY:bot_id)
- Authorization header should contain: Bearer {token}
- Uses MD5 hash of MASTERKEY:bot_id for backend API authentication
- Optionally uses API key from bot config for model access
"""
try:
        # Get bot_id (required)
        bot_id = request.bot_id
        if not bot_id:
            raise HTTPException(status_code=400, detail="bot_id is required")
        # v2 API authentication check
        expected_token = generate_v2_auth_token(bot_id)
        provided_token = extract_api_key_from_auth(authorization)
if not provided_token:
raise HTTPException(
status_code=401,
detail="Authorization header is required for v2 API"
)
        if provided_token != expected_token:
            # Do not echo any part of the expected token back to the caller
            raise HTTPException(
                status_code=403,
                detail="Invalid authentication token"
            )
        # Fetch the bot configuration from the backend API (using v2 authentication)
        bot_config = await fetch_bot_config(bot_id)
        # v2 API: the model API key comes from the backend configuration; the
        # Authorization header has already been consumed for authentication and is not reused as an API key
        api_key = bot_config.get("api_key")
        # Create the project directory (dataset_ids come from the backend configuration)
project_dir = create_project_directory(
bot_config.get("dataset_ids", []),
bot_id,
bot_config.get("robot_type", "general_agent")
)
        # Process the messages
        messages = process_messages(request.messages, request.language)
        # Invoke the shared agent-creation / response-generation logic
return await create_agent_and_generate_response(
bot_id=bot_id,
api_key=api_key,
messages=messages,
stream=request.stream,
tool_response=request.tool_response,
model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
model_server=bot_config.get("model_server", ""),
language=request.language or bot_config.get("language", "ja"),
system_prompt=bot_config.get("system_prompt"),
mcp_settings=bot_config.get("mcp_settings", []),
robot_type=bot_config.get("robot_type", "general_agent"),
project_dir=project_dir,
            generate_cfg={},  # the v2 API does not pass extra generate_cfg
user_identifier=request.user_identifier
)
except HTTPException:
raise
except Exception as e:
import traceback
error_details = traceback.format_exc()
print(f"Error in chat_completions_v2: {str(e)}")
print(f"Full traceback: {error_details}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.post("/api/v1/upload")
async def upload_file(file: UploadFile = File(...), folder: Optional[str] = Form(None)):
"""
文件上传API接口上传文件到 ./projects/uploads/ 目录下
可以指定自定义文件夹名,如果不指定则使用日期文件夹
指定文件夹时使用原始文件名并支持版本控制
Args:
file: 上传的文件
folder: 可选的自定义文件夹名
Returns:
dict: 包含文件路径和文件夹信息的响应
"""
try:
# 调试信息
print(f"Received folder parameter: {folder}")
print(f"File received: {file.filename if file else 'None'}")
# 确定上传文件夹
if folder:
# 使用指定的自定义文件夹
target_folder = folder
# 安全性检查:防止路径遍历攻击
target_folder = os.path.basename(target_folder)
else:
# 获取当前日期并格式化为年月日
current_date = datetime.now()
target_folder = current_date.strftime("%Y%m%d")
# 创建上传目录
upload_dir = os.path.join("projects", "uploads", target_folder)
os.makedirs(upload_dir, exist_ok=True)
# 处理文件名
if not file.filename:
raise HTTPException(status_code=400, detail="文件名不能为空")
# 解析文件名和扩展名
original_filename = file.filename
name_without_ext, file_extension = os.path.splitext(original_filename)
# 根据是否指定文件夹决定命名策略
if folder:
# 使用原始文件名,支持版本控制
final_filename, version = get_versioned_filename(upload_dir, name_without_ext, file_extension)
file_path = os.path.join(upload_dir, final_filename)
# 保存文件
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
return {
"success": True,
"message": f"文件上传成功{' (版本: ' + str(version) + ')' if version > 1 else ''}",
"file_path": file_path,
"folder": target_folder,
"original_filename": original_filename,
"version": version
}
else:
# 使用UUID唯一文件名原有逻辑
unique_filename = f"{uuid.uuid4()}{file_extension}"
file_path = os.path.join(upload_dir, unique_filename)
# 保存文件
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
return {
"success": True,
"message": "文件上传成功",
"file_path": file_path,
"folder": target_folder,
"original_filename": original_filename
}
except HTTPException:
raise
except Exception as e:
print(f"Error uploading file: {str(e)}")
raise HTTPException(status_code=500, detail=f"文件上传失败: {str(e)}")
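# Example upload (hypothetical file). With "folder" set, the original filename is
# kept and versioned; without it, the file lands in a YYYYMMDD folder under a UUID name:
#
#   curl -X POST http://localhost:8001/api/v1/upload \
#        -F "file=@manual.pdf" -F "folder=demo-docs"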
@app.get("/api/health")
async def health_check():
"""Health check endpoint"""
return {"message": "Database Assistant API is running"}
@app.get("/api/v1/system/performance")
async def get_performance_stats():
"""获取系统性能统计信息"""
try:
# 获取agent管理器统计
agent_stats = agent_manager.get_cache_stats()
# 获取连接池统计(简化版)
pool_stats = {
"connection_pool": "active",
"max_connections_per_host": 100,
"max_connections_total": 500,
"keepalive_timeout": 30
}
        # File cache statistics
file_cache_stats = {
"cache_size": len(file_cache._cache) if hasattr(file_cache, '_cache') else 0,
"max_cache_size": file_cache.cache_size if hasattr(file_cache, 'cache_size') else 1000,
"ttl": file_cache.ttl if hasattr(file_cache, 'ttl') else 300
}
        # System resource information (psutil is imported at module level)
system_stats = {
"cpu_count": multiprocessing.cpu_count(),
"memory_total_gb": round(psutil.virtual_memory().total / (1024**3), 2),
"memory_available_gb": round(psutil.virtual_memory().available / (1024**3), 2),
"memory_percent": psutil.virtual_memory().percent,
"disk_usage_percent": psutil.disk_usage('/').percent
}
return {
"success": True,
"timestamp": int(time.time()),
"performance": {
"agent_manager": agent_stats,
"connection_pool": pool_stats,
"file_cache": file_cache_stats,
"system": system_stats
}
}
except Exception as e:
print(f"Error getting performance stats: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取性能统计失败: {str(e)}")
@app.post("/api/v1/system/optimize")
async def optimize_system(profile: str = "balanced"):
"""应用系统优化配置"""
try:
# 应用优化配置
config = apply_optimization_profile(profile)
return {
"success": True,
"message": f"已应用 {profile} 优化配置",
"config": config
}
except Exception as e:
print(f"Error applying optimization profile: {str(e)}")
raise HTTPException(status_code=500, detail=f"应用优化配置失败: {str(e)}")
@app.post("/api/v1/system/clear-cache")
async def clear_system_cache(cache_type: Optional[str] = None):
"""清理系统缓存"""
try:
cleared_counts = {}
if cache_type is None or cache_type == "agent":
# 清理agent缓存
agent_count = agent_manager.clear_cache()
cleared_counts["agent_cache"] = agent_count
if cache_type is None or cache_type == "file":
# 清理文件缓存
if hasattr(file_cache, '_cache'):
file_count = len(file_cache._cache)
file_cache._cache.clear()
cleared_counts["file_cache"] = file_count
return {
"success": True,
"message": f"已清理指定类型的缓存",
"cleared_counts": cleared_counts
}
except Exception as e:
print(f"Error clearing cache: {str(e)}")
raise HTTPException(status_code=500, detail=f"清理缓存失败: {str(e)}")
@app.get("/api/v1/system/config")
async def get_system_config():
"""获取当前系统配置"""
try:
return {
"success": True,
"config": {
"max_cached_agents": max_cached_agents,
"shard_count": shard_count,
"tokenizer_parallelism": os.getenv("TOKENIZERS_PARALLELISM", "true"),
"max_connections_per_host": os.getenv("MAX_CONNECTIONS_PER_HOST", "100"),
"max_connections_total": os.getenv("MAX_CONNECTIONS_TOTAL", "500"),
"file_cache_size": os.getenv("FILE_CACHE_SIZE", "1000"),
"file_cache_ttl": os.getenv("FILE_CACHE_TTL", "300")
}
}
except Exception as e:
print(f"Error getting system config: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取系统配置失败: {str(e)}")
@app.post("/system/remove-project-cache")
async def remove_project_cache(dataset_id: str):
"""移除特定项目的缓存"""
try:
removed_count = agent_manager.remove_cache_by_unique_id(dataset_id)
if removed_count > 0:
return {"message": f"项目缓存移除成功: {dataset_id}", "removed_count": removed_count}
else:
return {"message": f"未找到项目缓存: {dataset_id}", "removed_count": 0}
except Exception as e:
raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}")
@app.get("/api/v1/files/{dataset_id}/status")
async def get_files_processing_status(dataset_id: str):
"""获取项目的文件处理状态"""
try:
# Load processed files log
processed_log = load_processed_files_log(dataset_id)
# Get project directory info
project_dir = os.path.join("projects", "data", dataset_id)
project_exists = os.path.exists(project_dir)
# Collect document.txt files
document_files = []
if project_exists:
for root, dirs, files in os.walk(project_dir):
for file in files:
if file == "document.txt":
document_files.append(os.path.join(root, file))
return {
"dataset_id": dataset_id,
"project_exists": project_exists,
"processed_files_count": len(processed_log),
"processed_files": processed_log,
"document_files_count": len(document_files),
"document_files": document_files,
"log_file_exists": os.path.exists(os.path.join("projects", "data", dataset_id, "processed_files.json"))
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"获取文件处理状态失败: {str(e)}")
@app.post("/api/v1/files/{dataset_id}/reset")
async def reset_files_processing(dataset_id: str):
"""重置项目的文件处理状态,删除处理日志和所有文件"""
try:
project_dir = os.path.join("projects", "data", dataset_id)
log_file = os.path.join("projects", "data", dataset_id, "processed_files.json")
# Load processed log to know what files to remove
processed_log = load_processed_files_log(dataset_id)
removed_files = []
# Remove all processed files and their dataset directories
for file_hash, file_info in processed_log.items():
# Remove local file in files directory
if 'local_path' in file_info:
if remove_file_or_directory(file_info['local_path']):
removed_files.append(file_info['local_path'])
# Handle new key-based structure first
if 'key' in file_info:
# Remove dataset directory by key
key = file_info['key']
if remove_dataset_directory_by_key(dataset_id, key):
removed_files.append(f"dataset/{key}")
elif 'filename' in file_info:
# Fallback to old filename-based structure
filename_without_ext = os.path.splitext(file_info['filename'])[0]
dataset_dir = os.path.join("projects", "data", dataset_id, "dataset", filename_without_ext)
if remove_file_or_directory(dataset_dir):
removed_files.append(dataset_dir)
# Also remove any specific dataset path if exists (fallback)
if 'dataset_path' in file_info:
if remove_file_or_directory(file_info['dataset_path']):
removed_files.append(file_info['dataset_path'])
# Remove the log file
if remove_file_or_directory(log_file):
removed_files.append(log_file)
# Remove the entire files directory
files_dir = os.path.join(project_dir, "files")
if remove_file_or_directory(files_dir):
removed_files.append(files_dir)
# Also remove the entire dataset directory (clean up any remaining files)
dataset_dir = os.path.join(project_dir, "dataset")
if remove_file_or_directory(dataset_dir):
removed_files.append(dataset_dir)
# Remove README.md if exists
readme_file = os.path.join(project_dir, "README.md")
if remove_file_or_directory(readme_file):
removed_files.append(readme_file)
        return {
            "message": f"File-processing state reset successfully: {dataset_id}",
            "removed_files_count": len(removed_files),
            "removed_files": removed_files
        }
except Exception as e:
raise HTTPException(status_code=500, detail=f"重置文件处理状态失败: {str(e)}")
# ============ Semantic Search API Endpoints ============
@app.post("/api/v1/semantic-search", response_model=SemanticSearchResponse)
async def semantic_search(request: SemanticSearchRequest):
"""
语义搜索 API
Args:
request: 包含 embedding_file 和 query 的搜索请求
Returns:
语义搜索结果
"""
try:
search_service = get_search_service()
result = await search_service.semantic_search(
embedding_file=request.embedding_file,
query=request.query,
top_k=request.top_k,
min_score=request.min_score
)
if result["success"]:
return SemanticSearchResponse(
success=True,
query=result["query"],
embedding_file=result["embedding_file"],
processing_time=result["processing_time"],
total_chunks=result["total_chunks"],
chunking_strategy=result["chunking_strategy"],
results=[
SearchResult(
rank=r["rank"],
score=r["score"],
content=r["content"],
content_preview=r["content_preview"]
)
for r in result["results"]
],
cache_stats=result.get("cache_stats")
)
else:
return SemanticSearchResponse(
success=False,
query=request.query,
embedding_file=request.embedding_file,
processing_time=0.0,
total_chunks=0,
chunking_strategy="",
results=[],
error=result.get("error", "未知错误")
)
    except Exception as e:
        logger.error(f"Semantic search API error: {e}")
        raise HTTPException(status_code=500, detail=f"Semantic search failed: {str(e)}")
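# Example search request (hypothetical embedding file path):
#
#   curl -X POST http://localhost:8001/api/v1/semantic-search \
#        -H "Content-Type: application/json" \
#        -d '{"embedding_file": "projects/data/demo/embedding.pkl",
#             "query": "refund policy", "top_k": 5}'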
@app.post("/api/v1/semantic-search/batch")
async def batch_semantic_search(request: BatchSearchRequest):
"""
批量语义搜索 API
Args:
request: 包含多个搜索请求的批量请求
Returns:
批量搜索结果
"""
try:
search_service = get_search_service()
        # Convert the request format
search_requests = [
{
"embedding_file": req.embedding_file,
"query": req.query,
"top_k": req.top_k,
"min_score": req.min_score
}
for req in request.requests
]
results = await search_service.batch_search(search_requests)
return {
"success": True,
"total_requests": len(request.requests),
"results": results
}
    except Exception as e:
        logger.error(f"Batch semantic search API error: {e}")
        raise HTTPException(status_code=500, detail=f"Batch semantic search failed: {str(e)}")
@app.get("/api/v1/semantic-search/stats")
async def get_semantic_search_stats():
"""
获取语义搜索服务统计信息
Returns:
服务统计信息
"""
try:
search_service = get_search_service()
stats = search_service.get_service_stats()
return {
"success": True,
"timestamp": int(time.time()),
"stats": stats
}
    except Exception as e:
        logger.error(f"Failed to get semantic search statistics: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}")
@app.post("/api/v1/semantic-search/clear-cache")
async def clear_semantic_search_cache():
"""
清空语义搜索缓存
Returns:
清理结果
"""
try:
from manager import get_cache_manager
cache_manager = get_cache_manager()
cache_manager.clear_cache()
        return {
            "success": True,
            "message": "Cache cleared"
        }
    except Exception as e:
        logger.error(f"Failed to clear semantic search cache: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}")
@app.post("/api/v1/embedding/encode", response_model=EncodeResponse)
async def encode_texts(request: EncodeRequest):
"""
文本编码 API
Args:
request: 包含 texts 和 batch_size 的编码请求
Returns:
编码结果
"""
try:
search_service = get_search_service()
        if not request.texts:
            return EncodeResponse(
                success=False,
                embeddings=[],
                shape=[0, 0],
                processing_time=0.0,
                total_texts=0,
                error="texts must not be empty"
            )
        start_time = time.time()
        # Encode the texts via the model manager
embeddings = await search_service.model_manager.encode_texts(
request.texts,
batch_size=request.batch_size
)
processing_time = time.time() - start_time
        # Convert to a plain list format
        embeddings_list = embeddings.tolist()
return EncodeResponse(
success=True,
embeddings=embeddings_list,
shape=list(embeddings.shape),
processing_time=processing_time,
total_texts=len(request.texts)
)
    except Exception as e:
        logger.error(f"Text encoding API error: {e}")
        return EncodeResponse(
            success=False,
            embeddings=[],
            shape=[0, 0],
            processing_time=0.0,
            total_texts=len(request.texts) if request else 0,
            error=f"Encoding failed: {str(e)}"
        )
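# Example encode request; the response carries the embedding matrix and its shape,
# e.g. shape [2, 768] for two texts (the dimension depends on the loaded model):
#
#   curl -X POST http://localhost:8001/api/v1/embedding/encode \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["hello", "world"], "batch_size": 32}'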
# Register the file manager API routes
app.include_router(file_manager_router)
if __name__ == "__main__":
# 启动 FastAPI 应用
print("Starting FastAPI server...")
print("File Manager API available at: http://localhost:8001/api/v1/files")
print("Web Interface available at: http://localhost:8001/public/file-manager.html")
uvicorn.run(app, host="0.0.0.0", port=8001)