import json
import os
import tempfile
import shutil
import uuid
import hashlib
import requests
import aiohttp
from typing import AsyncGenerator, Dict, List, Optional, Union, Any
from datetime import datetime
import re
import multiprocessing
import time
import psutil

import uvicorn
from fastapi import FastAPI, HTTPException, Depends, Header, UploadFile, File, Form
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from utils.logger import logger
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from file_manager_api import router as file_manager_router
from qwen_agent.llm.schema import ASSISTANT, FUNCTION
from pydantic import BaseModel, Field

# Import the semantic search service
from embedding import get_search_service

# Import utility modules
from utils import (
    # Models
    Message, DatasetRequest, ChatRequest, ChatResponse, QueueTaskRequest, IncrementalTaskRequest, QueueTaskResponse,
    QueueStatusResponse, TaskStatusResponse,

    # File utilities
    download_file, remove_file_or_directory, get_document_preview,
    load_processed_files_log, save_processed_files_log, get_file_hash,

    # Dataset management
    download_dataset_files, generate_dataset_structure,
    remove_dataset_directory, remove_dataset_directory_by_key,

    # Project management
    generate_project_readme, save_project_readme, get_project_status,
    remove_project, list_projects, get_project_stats,

    # Agent management
    get_global_agent_manager, init_global_agent_manager,

    # Optimization modules
    get_global_sharded_agent_manager, init_global_sharded_agent_manager,
    get_global_connection_pool, init_global_connection_pool,
    get_global_file_cache, init_global_file_cache,
    setup_system_optimizations, get_optimized_worker_config
)

# Import ChatRequestV2 directly from api_models
from utils.api_models import ChatRequestV2

# Import modified_assistant
from modified_assistant import update_agent_llm

# Import queue manager
from task_queue.manager import queue_manager
from task_queue.integration_tasks import process_files_async, process_files_incremental_async, cleanup_project_async
from task_queue.task_status import task_status_store

# System optimization settings
# os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Disabled; the optimized configuration below is used instead


def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extension: str) -> tuple[str, int]:
    """
    Build a versioned filename, deleting stale copies and bumping the version number.

    Args:
        upload_dir: Path to the upload directory
        name_without_ext: Filename without its extension
        file_extension: File extension (including the dot)

    Returns:
        tuple[str, int]: (final filename, version number)
    """
    # Check whether the original file exists
    original_file = os.path.join(upload_dir, name_without_ext + file_extension)
    original_exists = os.path.exists(original_file)

    # Find all related versioned files
    pattern = re.compile(re.escape(name_without_ext) + r'_(\d+)' + re.escape(file_extension) + r'$')
    existing_versions = []
    files_to_delete = []

    for filename in os.listdir(upload_dir):
        # Is this the original file?
        if filename == name_without_ext + file_extension:
            files_to_delete.append(filename)
            continue

        # Is this a versioned file?
        match = pattern.match(filename)
        if match:
            version_num = int(match.group(1))
            existing_versions.append(version_num)
            files_to_delete.append(filename)

    # If no related file exists, use the original filename (version 1)
    if not original_exists and not existing_versions:
        return name_without_ext + file_extension, 1

    # Delete all existing copies (the original file and any versioned files)
    for filename in files_to_delete:
        file_to_delete = os.path.join(upload_dir, filename)
        try:
            os.remove(file_to_delete)
            print(f"Deleted file: {file_to_delete}")
        except OSError as e:
            print(f"Failed to delete file {file_to_delete}: {e}")

    # Determine the next version number
    if existing_versions:
        next_version = max(existing_versions) + 1
    else:
        next_version = 2

    # Build the versioned filename
    versioned_filename = f"{name_without_ext}_{next_version}{file_extension}"

    return versioned_filename, next_version
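

# Illustrative example (not executed): if the upload dir already holds
# "report.pdf" and "report_2.pdf", a re-upload of "report.pdf" deletes both and
# returns ("report_3.pdf", 3); on an empty directory it returns ("report.pdf", 1).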


# Models are now imported from the utils module


# Semantic search request models


# Encode request/response models
class EncodeRequest(BaseModel):
    texts: List[str] = Field(..., description="List of texts to encode")
    batch_size: int = Field(default=32, description="Batch size", ge=1, le=128)


class EncodeResponse(BaseModel):
    success: bool = Field(..., description="Whether the call succeeded")
    embeddings: List[List[float]] = Field(..., description="Encoded embeddings")
    shape: List[int] = Field(..., description="Shape of the embeddings")
    processing_time: float = Field(..., description="Processing time in seconds")
    total_texts: int = Field(..., description="Total number of texts")
    error: Optional[str] = Field(None, description="Error message, if any")
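

# Illustrative payload for the encode endpoint defined further below
# (POST /api/v1/embedding/encode):
#   {"texts": ["hello", "world"], "batch_size": 32}
# expected response shape: {"success": true, "embeddings": [[...], [...]],
#   "shape": [2, <embedding_dim>], "processing_time": ..., "total_texts": 2}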


# Custom version for qwen-agent messages - keep this function as it's specific to this app
def get_content_from_messages(messages: List[dict], tool_response: bool = True) -> str:
    """Extract content from qwen-agent messages with special formatting"""
    full_text = ''
    content = []
    TOOL_CALL_S = '[TOOL_CALL]'
    TOOL_RESULT_S = '[TOOL_RESPONSE]'
    THOUGHT_S = '[THINK]'
    ANSWER_S = '[ANSWER]'

    for msg in messages:

        if msg['role'] == ASSISTANT:
            if msg.get('reasoning_content'):
                assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
                content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
            if msg.get('content'):
                assert isinstance(msg['content'], str), 'Now only supports text messages'
                # Filter out incomplete tool_call fragments produced by streaming output
                content_text = msg["content"]

                # Strip a trailing, partially streamed "<tool_call" prefix
                content_text = re.sub(r'<t?o?o?l?_?c?a?l?l?$', '', content_text)
                # Only append if the cleaned content is non-empty
                if content_text.strip():
                    content.append(f'{ANSWER_S}\n{content_text}')
            if msg.get('function_call'):
                content_text = msg["function_call"]["arguments"]
                content_text = re.sub(r'}\n<\/?t?o?o?l?_?c?a?l?l?$', '', content_text)
                if content_text.strip():
                    content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{content_text}')
        elif msg['role'] == FUNCTION:
            if tool_response:
                content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
        else:
            raise TypeError

    if content:
        full_text = '\n'.join(content)

    return full_text
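

# Illustrative output of the flattening above for one tool-using turn:
#   [TOOL_CALL] search
#   {"query": "..."}
#   [TOOL_RESPONSE] search
#   ...raw tool output...
#   [ANSWER]
#   final answer text
# process_messages() below performs the inverse transformation.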


# Helper functions are now imported from the utils module


# Initialize system optimizations
print("Initializing system optimizations...")
system_optimizer = setup_system_optimizations()

# Global assistant manager configuration (using the optimized settings)
max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "50"))  # Larger cache size
shard_count = int(os.getenv("SHARD_COUNT", "16"))  # Number of shards

# Initialize the optimized global assistant manager
agent_manager = init_global_sharded_agent_manager(
    max_cached_agents=max_cached_agents,
    shard_count=shard_count
)

# Initialize the connection pool
connection_pool = init_global_connection_pool(
    max_connections_per_host=int(os.getenv("MAX_CONNECTIONS_PER_HOST", "100")),
    max_connections_total=int(os.getenv("MAX_CONNECTIONS_TOTAL", "500")),
    keepalive_timeout=int(os.getenv("KEEPALIVE_TIMEOUT", "30")),
    connect_timeout=int(os.getenv("CONNECT_TIMEOUT", "10")),
    total_timeout=int(os.getenv("TOTAL_TIMEOUT", "60"))
)

# Initialize the file cache
file_cache = init_global_file_cache(
    cache_size=int(os.getenv("FILE_CACHE_SIZE", "1000")),
    ttl=int(os.getenv("FILE_CACHE_TTL", "300"))
)

print("System optimization initialized")
print(f"- Sharded agent manager: {shard_count} shards, up to {max_cached_agents} cached agents")
print(f"- Connection pool: {os.getenv('MAX_CONNECTIONS_PER_HOST', '100')} connections per host, {os.getenv('MAX_CONNECTIONS_TOTAL', '500')} total")
print(f"- File cache: {os.getenv('FILE_CACHE_SIZE', '1000')} files, TTL {os.getenv('FILE_CACHE_TTL', '300')}s")


app = FastAPI(title="Database Assistant API", version="1.0.0")

# Serve the public folder as static files
app.mount("/public", StaticFiles(directory="public"), name="static")

# Add CORS middleware for the frontend pages
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production this should be restricted to the actual frontend domains
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
    allow_headers=[
        "Authorization", "Content-Type", "Accept", "Origin", "User-Agent",
        "DNT", "Cache-Control", "Range", "X-Requested-With"
    ],
)


# Models are now imported from the utils module


async def generate_stream_response(agent, messages, tool_response: bool, model: str) -> AsyncGenerator[str, None]:
    """Generate a streaming response"""
    accumulated_content = ""
    chunk_id = 0
    try:
        for response in agent.run(messages=messages):
            previous_content = accumulated_content
            accumulated_content = get_content_from_messages(response, tool_response=tool_response)

            # Compute the newly added content
            if accumulated_content.startswith(previous_content):
                new_content = accumulated_content[len(previous_content):]
            else:
                new_content = accumulated_content
                previous_content = ""

            # Only emit a chunk when there is new content
            if new_content:
                chunk_id += 1
                # Build an OpenAI-style streaming chunk
                chunk_data = {
                    "id": f"chatcmpl-{chunk_id}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": new_content
                        },
                        "finish_reason": None
                    }]
                }

                yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"

        # Send the final completion chunk
        final_chunk = {
            "id": f"chatcmpl-{chunk_id + 1}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }]
        }
        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"

        # Send the terminating marker
        yield "data: [DONE]\n\n"
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error in generate_stream_response: {str(e)}")
        logger.error(f"Full traceback: {error_details}")

        error_data = {
            "error": {
                "message": f"Stream error: {str(e)}",
                "type": "internal_error"
            }
        }
        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"


# Models are now imported from the utils module


@app.post("/api/v1/files/process/async")
async def process_files_async_endpoint(request: QueueTaskRequest, authorization: Optional[str] = Header(None)):
    """
    Queue-backed asynchronous variant of the file-processing API.
    Same behavior as /api/v1/files/process, but the work is processed asynchronously via the queue.

    Args:
        request: QueueTaskRequest containing dataset_id, files, system_prompt, mcp_settings, and queue options
        authorization: Authorization header containing API key (Bearer <API_KEY>)

    Returns:
        QueueTaskResponse: Processing result with task ID for tracking
    """
    try:
        dataset_id = request.dataset_id
        if not dataset_id:
            raise HTTPException(status_code=400, detail="dataset_id is required")

        # Estimate the processing time (based on the number of files)
        estimated_time = 0
        if request.upload_folder:
            # With upload_folder the file count is unknown up front; use a default estimate
            estimated_time = 120  # 2 minutes by default
        elif request.files:
            total_files = sum(len(file_list) for file_list in request.files.values())
            estimated_time = max(30, total_files * 10)  # ~10s per file, at least 30s

        # Create the task status record
        task_id = str(uuid.uuid4())
        task_status_store.set_status(
            task_id=task_id,
            unique_id=dataset_id,
            status="pending"
        )

        # Submit the asynchronous task
        task = process_files_async(
            dataset_id=dataset_id,
            files=request.files,
            upload_folder=request.upload_folder,
            task_id=task_id
        )

        # Build a more detailed message
        message = f"File processing task queued for project ID: {dataset_id}"
        if request.upload_folder:
            group_count = len(request.upload_folder)
            message += f"; files will be auto-discovered from {group_count} upload folder(s)"
        elif request.files:
            total_files = sum(len(file_list) for file_list in request.files.values())
            message += f"; {total_files} file(s) included"

        return QueueTaskResponse(
            success=True,
            message=message,
            dataset_id=dataset_id,
            task_id=task_id,  # the locally generated task_id
            task_status="pending",
            estimated_processing_time=estimated_time
        )

    except HTTPException:
        raise
    except Exception as e:
        print(f"Error submitting async file processing task: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
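

# Illustrative request (field names as defined on QueueTaskRequest):
#   POST /api/v1/files/process/async
#   {"dataset_id": "project-123",
#    "files": {"docs": ["https://example.com/a.pdf", "https://example.com/b.pdf"]}}
# -> {"success": true, "task_id": "<uuid>", "task_status": "pending",
#     "estimated_processing_time": 30, ...}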


@app.post("/api/v1/files/process/incremental")
async def process_files_incremental_endpoint(request: IncrementalTaskRequest, authorization: Optional[str] = Header(None)):
    """
    Queue-backed incremental file-processing API - supports adding and removing files.

    Args:
        request: IncrementalTaskRequest containing dataset_id, files_to_add, files_to_remove, system_prompt, mcp_settings, and queue options
        authorization: Authorization header containing API key (Bearer <API_KEY>)

    Returns:
        QueueTaskResponse: Processing result with task ID for tracking
    """
    try:
        dataset_id = request.dataset_id
        if not dataset_id:
            raise HTTPException(status_code=400, detail="dataset_id is required")

        # Require at least one add or remove operation
        if not request.files_to_add and not request.files_to_remove:
            raise HTTPException(status_code=400, detail="At least one of files_to_add or files_to_remove must be provided")

        # Estimate the processing time (based on the number of files)
        total_add_files = sum(len(file_list) for file_list in (request.files_to_add or {}).values())
        total_remove_files = sum(len(file_list) for file_list in (request.files_to_remove or {}).values())
        total_files = total_add_files + total_remove_files
        estimated_time = max(30, total_files * 10)  # ~10s per file, at least 30s

        # Create the task status record
        task_id = str(uuid.uuid4())
        task_status_store.set_status(
            task_id=task_id,
            unique_id=dataset_id,
            status="pending"
        )

        # Submit the incremental asynchronous task
        task = process_files_incremental_async(
            dataset_id=dataset_id,
            files_to_add=request.files_to_add,
            files_to_remove=request.files_to_remove,
            system_prompt=request.system_prompt,
            mcp_settings=request.mcp_settings,
            task_id=task_id
        )

        return QueueTaskResponse(
            success=True,
            message=f"Incremental file processing task queued - adding {total_add_files} file(s), removing {total_remove_files} file(s), project ID: {dataset_id}",
            dataset_id=dataset_id,
            task_id=task_id,
            task_status="pending",
            estimated_processing_time=estimated_time
        )

    except HTTPException:
        raise
    except Exception as e:
        print(f"Error submitting incremental file processing task: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


@app.get("/api/v1/task/{task_id}/status")
async def get_task_status(task_id: str):
    """Get the status of a task - simple and reliable"""
    try:
        status_data = task_status_store.get_status(task_id)

        if not status_data:
            return {
                "success": False,
                "message": "Task does not exist or has expired",
                "task_id": task_id,
                "status": "not_found"
            }

        return {
            "success": True,
            "message": "Task status fetched successfully",
            "task_id": task_id,
            "status": status_data["status"],
            "unique_id": status_data["unique_id"],
            "created_at": status_data["created_at"],
            "updated_at": status_data["updated_at"],
            "result": status_data.get("result"),
            "error": status_data.get("error")
        }

    except Exception as e:
        print(f"Error getting task status: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get task status: {str(e)}")
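

# Typical polling flow (illustrative): submit a job via
# /api/v1/files/process/async, then GET /api/v1/task/{task_id}/status until the
# status leaves "pending"; an expired or unknown task comes back as "not_found".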


@app.delete("/api/v1/task/{task_id}")
async def delete_task(task_id: str):
    """Delete a task record"""
    try:
        success = task_status_store.delete_status(task_id)
        if success:
            return {
                "success": True,
                "message": f"Task record deleted: {task_id}",
                "task_id": task_id
            }
        else:
            return {
                "success": False,
                "message": f"Task record not found: {task_id}",
                "task_id": task_id
            }
    except Exception as e:
        print(f"Error deleting task: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to delete task record: {str(e)}")


@app.get("/api/v1/tasks")
async def list_tasks(status: Optional[str] = None, dataset_id: Optional[str] = None, limit: int = 100):
    """List tasks, with optional filtering"""
    try:
        if status or dataset_id:
            # Use the search facility
            tasks = task_status_store.search_tasks(status=status, unique_id=dataset_id, limit=limit)
        else:
            # Fetch all tasks
            all_tasks = task_status_store.list_all()
            tasks = list(all_tasks.values())[:limit]

        return {
            "success": True,
            "message": "Task list fetched successfully",
            "total_tasks": len(tasks),
            "tasks": tasks,
            "filters": {
                "status": status,
                "dataset_id": dataset_id,
                "limit": limit
            }
        }

    except Exception as e:
        print(f"Error listing tasks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to list tasks: {str(e)}")


@app.get("/api/v1/tasks/statistics")
async def get_task_statistics():
    """Get task statistics"""
    try:
        stats = task_status_store.get_statistics()

        return {
            "success": True,
            "message": "Statistics fetched successfully",
            "statistics": stats
        }

    except Exception as e:
        print(f"Error getting statistics: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}")


@app.post("/api/v1/tasks/cleanup")
async def cleanup_tasks(older_than_days: int = 7):
    """Clean up old task records"""
    try:
        deleted_count = task_status_store.cleanup_old_tasks(older_than_days=older_than_days)

        return {
            "success": True,
            "message": f"Cleaned up {deleted_count} old task record(s)",
            "deleted_count": deleted_count,
            "older_than_days": older_than_days
        }

    except Exception as e:
        print(f"Error cleaning up tasks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to clean up task records: {str(e)}")


@app.get("/api/v1/projects")
async def list_all_projects():
    """List all projects"""
    try:
        # Robot projects (projects/robot)
        robot_dir = "projects/robot"
        robot_projects = []

        if os.path.exists(robot_dir):
            for item in os.listdir(robot_dir):
                item_path = os.path.join(robot_dir, item)
                if os.path.isdir(item_path):
                    try:
                        # Read the robot configuration file
                        config_path = os.path.join(item_path, "robot_config.json")
                        config_data = {}
                        if os.path.exists(config_path):
                            with open(config_path, 'r', encoding='utf-8') as f:
                                config_data = json.load(f)

                        # Count the files
                        file_count = 0
                        if os.path.exists(os.path.join(item_path, "dataset")):
                            for root, dirs, files in os.walk(os.path.join(item_path, "dataset")):
                                file_count += len(files)

                        robot_projects.append({
                            "id": item,
                            "name": config_data.get("name", item),
                            "type": "robot",
                            "status": config_data.get("status", "active"),
                            "file_count": file_count,
                            "config": config_data,
                            "created_at": os.path.getctime(item_path),
                            "updated_at": os.path.getmtime(item_path)
                        })
                    except Exception as e:
                        print(f"Error reading robot project {item}: {str(e)}")
                        robot_projects.append({
                            "id": item,
                            "name": item,
                            "type": "robot",
                            "status": "unknown",
                            "file_count": 0,
                            "created_at": os.path.getctime(item_path),
                            "updated_at": os.path.getmtime(item_path)
                        })

        # Datasets (projects/data)
        data_dir = "projects/data"
        datasets = []

        if os.path.exists(data_dir):
            for item in os.listdir(data_dir):
                item_path = os.path.join(data_dir, item)
                if os.path.isdir(item_path):
                    try:
                        # Read the processing log
                        log_path = os.path.join(item_path, "processing_log.json")
                        log_data = {}
                        if os.path.exists(log_path):
                            with open(log_path, 'r', encoding='utf-8') as f:
                                log_data = json.load(f)

                        # Count the files (excluding .pkl artifacts)
                        file_count = 0
                        for root, dirs, files in os.walk(item_path):
                            file_count += len([f for f in files if not f.endswith('.pkl')])

                        # Determine the status
                        status = "active"
                        if log_data.get("status"):
                            status = log_data["status"]
                        elif os.path.exists(os.path.join(item_path, "processed")):
                            status = "completed"

                        datasets.append({
                            "id": item,
                            "name": f"Dataset - {item[:8]}...",
                            "type": "dataset",
                            "status": status,
                            "file_count": file_count,
                            "log_data": log_data,
                            "created_at": os.path.getctime(item_path),
                            "updated_at": os.path.getmtime(item_path)
                        })
                    except Exception as e:
                        print(f"Error reading dataset {item}: {str(e)}")
                        datasets.append({
                            "id": item,
                            "name": f"Dataset - {item[:8]}...",
                            "type": "dataset",
                            "status": "unknown",
                            "file_count": 0,
                            "created_at": os.path.getctime(item_path),
                            "updated_at": os.path.getmtime(item_path)
                        })

        all_projects = robot_projects + datasets

        return {
            "success": True,
            "message": "Project list fetched successfully",
            "total_projects": len(all_projects),
            "robot_projects": robot_projects,
            "datasets": datasets,
            "projects": all_projects  # kept for backward compatibility
        }

    except Exception as e:
        print(f"Error listing projects: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to list projects: {str(e)}")


@app.get("/api/v1/projects/robot")
async def list_robot_projects():
    """List robot projects"""
    try:
        response = await list_all_projects()
        return {
            "success": True,
            "message": "Robot project list fetched successfully",
            "total_projects": len(response["robot_projects"]),
            "projects": response["robot_projects"]
        }
    except Exception as e:
        print(f"Error listing robot projects: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to list robot projects: {str(e)}")


@app.get("/api/v1/projects/datasets")
async def list_datasets():
    """List datasets"""
    try:
        response = await list_all_projects()
        return {
            "success": True,
            "message": "Dataset list fetched successfully",
            "total_projects": len(response["datasets"]),
            "projects": response["datasets"]
        }
    except Exception as e:
        print(f"Error listing datasets: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to list datasets: {str(e)}")


@app.get("/api/v1/projects/{dataset_id}/tasks")
async def get_project_tasks(dataset_id: str):
    """Get all tasks for the given project"""
    try:
        tasks = task_status_store.get_by_unique_id(dataset_id)

        return {
            "success": True,
            "message": "Project tasks fetched successfully",
            "dataset_id": dataset_id,
            "total_tasks": len(tasks),
            "tasks": tasks
        }

    except Exception as e:
        print(f"Error getting project tasks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get project tasks: {str(e)}")


@app.post("/api/v1/files/{dataset_id}/cleanup/async")
async def cleanup_project_async_endpoint(dataset_id: str, remove_all: bool = False):
    """Asynchronously clean up project files"""
    try:
        task = cleanup_project_async(dataset_id=dataset_id, remove_all=remove_all)

        return {
            "success": True,
            "message": f"Project cleanup task queued for project ID: {dataset_id}",
            "dataset_id": dataset_id,
            "task_id": task.id,
            "action": "remove_all" if remove_all else "cleanup_logs"
        }
    except Exception as e:
        print(f"Error submitting cleanup task: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to submit cleanup task: {str(e)}")


@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)):
    """
    Chat completions API similar to OpenAI, supports both streaming and non-streaming

    Args:
        request: ChatRequest containing messages, model, optional dataset_ids list, required bot_id, system_prompt, mcp_settings, and files
        authorization: Authorization header containing API key (Bearer <API_KEY>)

    Returns:
        Union[ChatResponse, StreamingResponse]: Chat completion response or stream

    Notes:
        - dataset_ids: optional; when provided it must be a list of project IDs (use the array form even for a single project)
        - bot_id: required; the target robot ID
        - A robot project directory projects/robot/{bot_id}/ is created only when robot_type == "catalog_agent" and dataset_ids is a non-empty list
        - Any other robot_type (including the default "agent") creates no directory
        - dataset_ids of [], None, or omitted creates no directory
        - Supports merging multiple knowledge bases and automatically resolves folder name collisions

    Required Parameters:
        - bot_id: str - target robot ID
        - messages: List[Message] - conversation messages
    Optional Parameters:
        - dataset_ids: List[str] - source knowledge-base project IDs (use the array form even for a single project)
        - robot_type: str - robot type, defaults to "agent"

    Example:
        {"bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]}
        {"dataset_ids": ["project-123"], "bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]}
        {"dataset_ids": ["project-123", "project-456"], "bot_id": "my-bot-002", "messages": [{"role": "user", "content": "Hello"}]}
        {"dataset_ids": ["project-123"], "bot_id": "my-catalog-bot", "robot_type": "catalog_agent", "messages": [{"role": "user", "content": "Hello"}]}
    """
    try:
        # v1 API: extract the API key from the Authorization header and use it as the model API key
        api_key = extract_api_key_from_auth(authorization)

        # Get bot_id (required)
        bot_id = request.bot_id
        if not bot_id:
            raise HTTPException(status_code=400, detail="bot_id is required")

        # Create the project directory (only for catalog_agent with non-empty dataset_ids)
        project_dir = create_project_directory(request.dataset_ids, bot_id, request.robot_type)

        # Collect the remaining request fields as generate_cfg
        exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier'}
        generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}

        # Process the messages
        messages = process_messages(request.messages, request.language)

        # Delegate to the shared agent-creation and response-generation logic
        return await create_agent_and_generate_response(
            bot_id=bot_id,
            api_key=api_key,
            messages=messages,
            stream=request.stream,
            tool_response=True,
            model_name=request.model,
            model_server=request.model_server,
            language=request.language,
            system_prompt=request.system_prompt,
            mcp_settings=request.mcp_settings,
            robot_type=request.robot_type,
            project_dir=project_dir,
            generate_cfg=generate_cfg,
            user_identifier=request.user_identifier
        )

    except HTTPException:
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in chat_completions: {str(e)}")
        print(f"Full traceback: {error_details}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
    """Fetch the robot configuration from the backend API"""
    try:
        backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
        url = f"{backend_host}/v1/agent_bot_config/{bot_id}"

        auth_token = generate_v2_auth_token(bot_id)
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {auth_token}"
        }
        # Use an asynchronous HTTP request
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers, timeout=30) as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch bot config: API returned status code {response.status}"
                    )

                # Parse the response
                response_data = await response.json()

                if not response_data.get("success"):
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch bot config: {response_data.get('message', 'Unknown error')}"
                    )

                return response_data.get("data", {})

    except aiohttp.ClientError as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to connect to backend API: {str(e)}"
        )
    except Exception as e:
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch bot config: {str(e)}"
        )


def process_messages(messages: List[Message], language: Optional[str] = None) -> List[Dict[str, str]]:
    """Process the message list: split on [TOOL_CALL]|[TOOL_RESPONSE]|[ANSWER] tags and append the language instruction.

    This is the inverse of get_content_from_messages: messages containing
    [TOOL_RESPONSE] are reassembled into msg['role'] == 'function' and
    msg.get('function_call') form.
    """
    processed_messages = []

    # Collect the indices of all ASSISTANT messages
    assistant_indices = [i for i, msg in enumerate(messages) if msg.role == "assistant"]
    total_assistant_messages = len(assistant_indices)
    cutoff_point = max(0, total_assistant_messages - 5)

    # Process each message
    for i, msg in enumerate(messages):
        if msg.role == "assistant":
            # Position of this ASSISTANT message among all ASSISTANT messages (0-based)
            assistant_position = assistant_indices.index(i)

            # Split on the [TOOL_CALL]|[TOOL_RESPONSE]|[ANSWER] tags
            parts = re.split(r'\[(TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg.content)

            # Reassemble the content; treatment depends on how recent the message is
            filtered_content = ""
            current_tag = None
            is_recent_message = assistant_position >= cutoff_point  # among the most recent 5 assistant messages

            for part_idx in range(0, len(parts)):
                if part_idx % 2 == 0:  # text content
                    text = parts[part_idx].strip()
                    if not text:
                        continue

                    if current_tag == "TOOL_RESPONSE":
                        if is_recent_message:
                            # Recent assistant messages: keep the TOOL_RESPONSE (abridged when long)
                            if len(text) <= 500:
                                filtered_content += f"[TOOL_RESPONSE]\n{text}\n"
                            else:
                                # Keep three 250-character excerpts: start, middle, end
                                first_part = text[:250]
                                middle_start = len(text) // 2 - 125
                                middle_part = text[middle_start:middle_start + 250]
                                last_part = text[-250:]

                                # Count the omitted characters
                                omitted_count = len(text) - 750
                                omitted_text = f"...{omitted_count} characters omitted..."

                                # Stitch the excerpts together
                                truncated_text = f"{first_part}\n{omitted_text}\n{middle_part}\n{omitted_text}\n{last_part}"
                                filtered_content += f"[TOOL_RESPONSE]\n{truncated_text}\n"
                        # Older messages: drop the TOOL_RESPONSE data entirely
                    elif current_tag == "TOOL_CALL":
                        if is_recent_message:
                            # Recent assistant messages: keep the TOOL_CALL
                            filtered_content += f"[TOOL_CALL]\n{text}\n"
                        # Older messages: drop the TOOL_CALL data entirely
                    elif current_tag == "ANSWER":
                        # ANSWER data is kept for all assistant messages
                        filtered_content += f"[ANSWER]\n{text}\n"
                    else:
                        # Content before the first tag
                        filtered_content += text + "\n"
                else:  # tag
                    current_tag = parts[part_idx]

            # Use the processed content, trimmed of surrounding whitespace
            final_content = filtered_content.strip()
            if final_content:
                processed_messages.append({"role": msg.role, "content": final_content})
            else:
                # Fall back to the original content if processing emptied it
                processed_messages.append({"role": msg.role, "content": msg.content})
        else:
            processed_messages.append({"role": msg.role, "content": msg.content})

    # Inverse step: reassemble messages containing [TOOL_RESPONSE] into
    # msg['role'] == 'function' and msg.get('function_call') form
    # (the inverse of get_content_from_messages)
    final_messages = []
    for msg in processed_messages:
        if msg["role"] == ASSISTANT and "[TOOL_RESPONSE]" in msg["content"]:
            # Split the message content
            parts = re.split(r'\[(TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg["content"])

            current_tag = None
            assistant_content = ""
            function_calls = []
            tool_responses = []

            for part_idx in range(0, len(parts)):
                if part_idx % 2 == 0:  # text content
                    text = parts[part_idx].strip()
                    if not text:
                        continue

                    if current_tag == "TOOL_RESPONSE":
                        # Parse the TOOL_RESPONSE format: [TOOL_RESPONSE] function_name\ncontent
                        lines = text.split('\n', 1)
                        function_name = lines[0].strip() if lines else ""
                        response_content = lines[1].strip() if len(lines) > 1 else ""

                        tool_responses.append({
                            "role": FUNCTION,
                            "name": function_name,
                            "content": response_content
                        })
                    elif current_tag == "TOOL_CALL":
                        # Parse the TOOL_CALL format: [TOOL_CALL] function_name\narguments
                        lines = text.split('\n', 1)
                        function_name = lines[0].strip() if lines else ""
                        arguments = lines[1].strip() if len(lines) > 1 else ""

                        function_calls.append({
                            "name": function_name,
                            "arguments": arguments
                        })
                    elif current_tag == "ANSWER":
                        assistant_content += text + "\n"
                    else:
                        # Content before the first tag also belongs to the assistant
                        assistant_content += text + "\n"
                else:  # tag
                    current_tag = parts[part_idx]

            # Append the assistant message (if it has any content)
            if assistant_content.strip() or function_calls:
                assistant_msg = {"role": ASSISTANT}
                if assistant_content.strip():
                    assistant_msg["content"] = assistant_content.strip()
                if function_calls:
                    # With multiple function_calls, keep only the first (matches the original logic)
                    assistant_msg["function_call"] = function_calls[0]
                final_messages.append(assistant_msg)

            # Append all tool_responses as function messages
            final_messages.extend(tool_responses)
        else:
            # Non-assistant messages, or messages without [TOOL_RESPONSE], pass through unchanged
            final_messages.append(msg)

    # Append the reply-language instruction to the last message
    if final_messages and language:
        language_map = {
            'zh': '请用中文回复',
            'en': 'Please reply in English',
            'ja': '日本語で回答してください',
            'jp': '日本語で回答してください'
        }
        language_instruction = language_map.get(language.lower(), '')
        if language_instruction:
            # Append the language instruction to the end of the last message
            final_messages[-1]['content'] = final_messages[-1]['content'] + f"\n\n{language_instruction}。"

    return final_messages
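

# Illustrative round-trip: an assistant message whose content is
#   "[TOOL_CALL] search\n{\"query\": \"x\"}\n[TOOL_RESPONSE] search\nresults...\n[ANSWER]\nDone."
# is split back into an assistant message with function_call={"name": "search",
# "arguments": "{\"query\": \"x\"}"} plus the "Done." answer text, followed by a
# function message named "search" carrying the response text.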


async def create_agent_and_generate_response(
    bot_id: str,
    api_key: str,
    messages: List[Dict[str, str]],
    stream: bool,
    tool_response: bool,
    model_name: str,
    model_server: str,
    language: str,
    system_prompt: Optional[str],
    mcp_settings: Optional[List[Dict]],
    robot_type: str,
    project_dir: Optional[str] = None,
    generate_cfg: Optional[Dict] = None,
    user_identifier: Optional[str] = None
) -> Union[ChatResponse, StreamingResponse]:
    """Shared logic for creating an agent and generating a response"""
    if generate_cfg is None:
        generate_cfg = {}

    # Get or create the assistant instance from the global manager
    agent = await agent_manager.get_or_create_agent(
        bot_id=bot_id,
        project_dir=project_dir,
        model_name=model_name,
        api_key=api_key,
        model_server=model_server,
        generate_cfg=generate_cfg,
        language=language,
        system_prompt=system_prompt,
        mcp_settings=mcp_settings,
        robot_type=robot_type,
        user_identifier=user_identifier
    )

    # Return a streaming or non-streaming response depending on the stream flag
    if stream:
        return StreamingResponse(
            generate_stream_response(agent, messages, tool_response, model_name),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
        )
    else:
        # Non-streaming response
        final_responses = agent.run_nonstream(messages)

        if final_responses and len(final_responses) > 0:
            # Render the responses with get_content_from_messages, honoring tool_response
            content = get_content_from_messages(final_responses, tool_response=tool_response)

            # Build an OpenAI-style response
            return ChatResponse(
                choices=[{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": content
                    },
                    "finish_reason": "stop"
                }],
                usage={
                    # Approximate: character counts, not true token counts
                    "prompt_tokens": sum(len(msg.get("content", "")) for msg in messages),
                    "completion_tokens": len(content),
                    "total_tokens": sum(len(msg.get("content", "")) for msg in messages) + len(content)
                }
            )
        else:
            raise HTTPException(status_code=500, detail="No response from agent")


def create_project_directory(dataset_ids: Optional[List[str]], bot_id: str, robot_type: str = "general_agent") -> Optional[str]:
    """Shared logic for creating the project directory"""
    # Only create a directory when robot_type == "catalog_agent" and dataset_ids is non-empty
    if robot_type != "catalog_agent" or not dataset_ids or len(dataset_ids) == 0:
        return None

    try:
        from utils.multi_project_manager import create_robot_project
        return create_robot_project(dataset_ids, bot_id)
    except Exception as e:
        print(f"Error creating project directory: {e}")
        return None


def extract_api_key_from_auth(authorization: Optional[str]) -> Optional[str]:
    """Extract the API key from the Authorization header"""
    if not authorization:
        return None

    # Strip the "Bearer " prefix
    if authorization.startswith("Bearer "):
        return authorization[7:]
    else:
        return authorization


def generate_v2_auth_token(bot_id: str) -> str:
    """Generate the authentication token for the v2 API"""
    masterkey = os.getenv("MASTERKEY", "master")
    token_input = f"{masterkey}:{bot_id}"
    return hashlib.md5(token_input.encode()).hexdigest()
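

# Illustrative: with MASTERKEY="master" and bot_id="my-bot-001", the v2 token is
# hashlib.md5(b"master:my-bot-001").hexdigest(), sent as "Authorization: Bearer <token>".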


@app.post("/api/v2/chat/completions")
async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[str] = Header(None)):
    """
    Chat completions API v2 with simplified parameters.
    Only requires messages, stream, tool_response, bot_id, and language parameters.
    Other parameters are fetched from the backend bot configuration API.

    Args:
        request: ChatRequestV2 containing only essential parameters
        authorization: Authorization header for authentication (different from v1)

    Returns:
        Union[ChatResponse, StreamingResponse]: Chat completion response or stream

    Required Parameters:
        - bot_id: str - target robot ID
        - messages: List[Message] - conversation messages

    Optional Parameters:
        - stream: bool - whether to stream the output, defaults to false
        - tool_response: bool - whether to include tool responses, defaults to false
        - language: str - reply language, defaults to "ja"

    Authentication:
        - Requires valid MD5 hash token: MD5(MASTERKEY:bot_id)
        - Authorization header should contain: Bearer {token}
        - Uses MD5 hash of MASTERKEY:bot_id for backend API authentication
        - Optionally uses API key from bot config for model access
    """
    try:
        # Get bot_id (required)
        bot_id = request.bot_id
        if not bot_id:
            raise HTTPException(status_code=400, detail="bot_id is required")

        # v2 API authentication check
        expected_token = generate_v2_auth_token(bot_id)
        provided_token = extract_api_key_from_auth(authorization)

        if not provided_token:
            raise HTTPException(
                status_code=401,
                detail="Authorization header is required for v2 API"
            )

        if provided_token != expected_token:
            raise HTTPException(
                status_code=403,
                detail="Invalid authentication token"
            )

        # Fetch the robot configuration from the backend API (using the v2 authentication scheme)
        bot_config = await fetch_bot_config(bot_id)

        # v2 API: the model API key comes from the backend configuration.
        # Note: the Authorization header has already been consumed for authentication and is not reused as the API key.
        api_key = bot_config.get("api_key")

        # Create the project directory (dataset_ids come from the backend configuration)
        project_dir = create_project_directory(
            bot_config.get("dataset_ids", []),
            bot_id,
            bot_config.get("robot_type", "general_agent")
        )

        # Process the messages
        messages = process_messages(request.messages, request.language)

        # Delegate to the shared agent-creation and response-generation logic
        return await create_agent_and_generate_response(
            bot_id=bot_id,
            api_key=api_key,
            messages=messages,
            stream=request.stream,
            tool_response=request.tool_response,
            model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
            model_server=bot_config.get("model_server", ""),
            language=request.language or bot_config.get("language", "ja"),
            system_prompt=bot_config.get("system_prompt"),
            mcp_settings=bot_config.get("mcp_settings", []),
            robot_type=bot_config.get("robot_type", "general_agent"),
            project_dir=project_dir,
            generate_cfg={},  # the v2 API passes no extra generate_cfg
            user_identifier=request.user_identifier
        )

    except HTTPException:
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in chat_completions_v2: {str(e)}")
        print(f"Full traceback: {error_details}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
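

# Illustrative v2 call (token as computed by generate_v2_auth_token above):
#   curl -X POST http://localhost:8001/api/v2/chat/completions \
#     -H "Content-Type: application/json" \
#     -H "Authorization: Bearer <md5 of MASTERKEY:bot_id>" \
#     -d '{"bot_id": "my-bot-001", "stream": false,
#          "messages": [{"role": "user", "content": "Hello"}]}'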


@app.post("/api/v1/upload")
async def upload_file(file: UploadFile = File(...), folder: Optional[str] = Form(None)):
    """
    File upload API; stores files under the ./projects/uploads/ directory.

    A custom folder name can be supplied; otherwise a date-based folder is used.
    When a folder is specified, the original filename is kept and versioning applies.

    Args:
        file: The uploaded file
        folder: Optional custom folder name

    Returns:
        dict: Response containing the file path and folder information
    """
    try:
        # Debug information
        print(f"Received folder parameter: {folder}")
        print(f"File received: {file.filename if file else 'None'}")

        # Determine the upload folder
        if folder:
            # Use the supplied custom folder
            target_folder = folder
            # Security check: prevent path traversal attacks
            target_folder = os.path.basename(target_folder)
        else:
            # Use the current date, formatted as YYYYMMDD
            current_date = datetime.now()
            target_folder = current_date.strftime("%Y%m%d")

        # Create the upload directory
        upload_dir = os.path.join("projects", "uploads", target_folder)
        os.makedirs(upload_dir, exist_ok=True)

        # Handle the filename
        if not file.filename:
            raise HTTPException(status_code=400, detail="Filename must not be empty")

        # Split the filename and extension
        original_filename = file.filename
        name_without_ext, file_extension = os.path.splitext(original_filename)

        # Naming strategy depends on whether a folder was specified
        if folder:
            # Keep the original filename, with version control
            final_filename, version = get_versioned_filename(upload_dir, name_without_ext, file_extension)
            file_path = os.path.join(upload_dir, final_filename)

            # Save the file
            with open(file_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

            return {
                "success": True,
                "message": f"File uploaded successfully{' (version: ' + str(version) + ')' if version > 1 else ''}",
                "file_path": file_path,
                "folder": target_folder,
                "original_filename": original_filename,
                "version": version
            }
        else:
            # Use a unique UUID-based filename (original behavior)
            unique_filename = f"{uuid.uuid4()}{file_extension}"
            file_path = os.path.join(upload_dir, unique_filename)

            # Save the file
            with open(file_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

            return {
                "success": True,
                "message": "File uploaded successfully",
                "file_path": file_path,
                "folder": target_folder,
                "original_filename": original_filename
            }

    except HTTPException:
        raise
    except Exception as e:
        print(f"Error uploading file: {str(e)}")
        raise HTTPException(status_code=500, detail=f"File upload failed: {str(e)}")
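

# Illustrative upload:
#   curl -X POST http://localhost:8001/api/v1/upload \
#     -F "file=@report.pdf" -F "folder=my-folder"
# Re-uploading the same name into the same folder bumps the version
# ("report_2.pdf", "report_3.pdf", ...), per get_versioned_filename above.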


@app.get("/api/health")
async def health_check():
    """Health check endpoint"""
    return {"message": "Database Assistant API is running"}


@app.get("/api/v1/system/performance")
async def get_performance_stats():
    """Get system performance statistics"""
    try:
        # Agent manager statistics
        agent_stats = agent_manager.get_cache_stats()

        # Connection pool statistics (simplified)
        pool_stats = {
            "connection_pool": "active",
            "max_connections_per_host": 100,
            "max_connections_total": 500,
            "keepalive_timeout": 30
        }

        # File cache statistics
        file_cache_stats = {
            "cache_size": len(file_cache._cache) if hasattr(file_cache, '_cache') else 0,
            "max_cache_size": file_cache.cache_size if hasattr(file_cache, 'cache_size') else 1000,
            "ttl": file_cache.ttl if hasattr(file_cache, 'ttl') else 300
        }

        # System resource information
        system_stats = {
            "cpu_count": multiprocessing.cpu_count(),
            "memory_total_gb": round(psutil.virtual_memory().total / (1024**3), 2),
            "memory_available_gb": round(psutil.virtual_memory().available / (1024**3), 2),
            "memory_percent": psutil.virtual_memory().percent,
            "disk_usage_percent": psutil.disk_usage('/').percent
        }

        return {
            "success": True,
            "timestamp": int(time.time()),
            "performance": {
                "agent_manager": agent_stats,
                "connection_pool": pool_stats,
                "file_cache": file_cache_stats,
                "system": system_stats
            }
        }

    except Exception as e:
        print(f"Error getting performance stats: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get performance statistics: {str(e)}")


@app.post("/api/v1/system/optimize")
async def optimize_system(profile: str = "balanced"):
    """Apply a system optimization profile"""
    try:
        # Apply the optimization profile
        # NOTE: apply_optimization_profile is referenced here but is not imported above;
        # it is presumably meant to come from the optimization helpers in utils.
        config = apply_optimization_profile(profile)

        return {
            "success": True,
            "message": f"Applied the {profile} optimization profile",
            "config": config
        }

    except Exception as e:
        print(f"Error applying optimization profile: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to apply optimization profile: {str(e)}")


@app.post("/api/v1/system/clear-cache")
async def clear_system_cache(cache_type: Optional[str] = None):
    """Clear system caches"""
    try:
        cleared_counts = {}

        if cache_type is None or cache_type == "agent":
            # Clear the agent cache
            agent_count = agent_manager.clear_cache()
            cleared_counts["agent_cache"] = agent_count

        if cache_type is None or cache_type == "file":
            # Clear the file cache
            if hasattr(file_cache, '_cache'):
                file_count = len(file_cache._cache)
                file_cache._cache.clear()
                cleared_counts["file_cache"] = file_count

        return {
            "success": True,
            "message": "Cleared the requested cache types",
            "cleared_counts": cleared_counts
        }

    except Exception as e:
        print(f"Error clearing cache: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}")


@app.get("/api/v1/system/config")
async def get_system_config():
    """Get the current system configuration"""
    try:
        return {
            "success": True,
            "config": {
                "max_cached_agents": max_cached_agents,
                "shard_count": shard_count,
                "tokenizer_parallelism": os.getenv("TOKENIZERS_PARALLELISM", "true"),
                "max_connections_per_host": os.getenv("MAX_CONNECTIONS_PER_HOST", "100"),
                "max_connections_total": os.getenv("MAX_CONNECTIONS_TOTAL", "500"),
                "file_cache_size": os.getenv("FILE_CACHE_SIZE", "1000"),
                "file_cache_ttl": os.getenv("FILE_CACHE_TTL", "300")
            }
        }

    except Exception as e:
        print(f"Error getting system config: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get system configuration: {str(e)}")


@app.post("/system/remove-project-cache")
async def remove_project_cache(dataset_id: str):
    """Remove the cache for a specific project"""
    try:
        removed_count = agent_manager.remove_cache_by_unique_id(dataset_id)
        if removed_count > 0:
            return {"message": f"Project cache removed: {dataset_id}", "removed_count": removed_count}
        else:
            return {"message": f"No project cache found: {dataset_id}", "removed_count": 0}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to remove project cache: {str(e)}")


@app.get("/api/v1/files/{dataset_id}/status")
async def get_files_processing_status(dataset_id: str):
    """Get the file-processing status for a project"""
    try:
        # Load processed files log
        processed_log = load_processed_files_log(dataset_id)

        # Get project directory info
        project_dir = os.path.join("projects", "data", dataset_id)
        project_exists = os.path.exists(project_dir)

        # Collect document.txt files
        document_files = []
        if project_exists:
            for root, dirs, files in os.walk(project_dir):
                for file in files:
                    if file == "document.txt":
                        document_files.append(os.path.join(root, file))

        return {
            "dataset_id": dataset_id,
            "project_exists": project_exists,
            "processed_files_count": len(processed_log),
            "processed_files": processed_log,
            "document_files_count": len(document_files),
            "document_files": document_files,
            "log_file_exists": os.path.exists(os.path.join("projects", "data", dataset_id, "processed_files.json"))
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get file processing status: {str(e)}")


@app.post("/api/v1/files/{dataset_id}/reset")
async def reset_files_processing(dataset_id: str):
    """Reset a project's file-processing state: remove the processing log and all files"""
    try:
        project_dir = os.path.join("projects", "data", dataset_id)
        log_file = os.path.join("projects", "data", dataset_id, "processed_files.json")

        # Load processed log to know what files to remove
        processed_log = load_processed_files_log(dataset_id)

        removed_files = []
        # Remove all processed files and their dataset directories
        for file_hash, file_info in processed_log.items():
            # Remove local file in files directory
            if 'local_path' in file_info:
                if remove_file_or_directory(file_info['local_path']):
                    removed_files.append(file_info['local_path'])

            # Handle new key-based structure first
            if 'key' in file_info:
                # Remove dataset directory by key
                key = file_info['key']
                if remove_dataset_directory_by_key(dataset_id, key):
                    removed_files.append(f"dataset/{key}")
            elif 'filename' in file_info:
                # Fall back to the old filename-based structure
                filename_without_ext = os.path.splitext(file_info['filename'])[0]
                dataset_dir = os.path.join("projects", "data", dataset_id, "dataset", filename_without_ext)
                if remove_file_or_directory(dataset_dir):
                    removed_files.append(dataset_dir)

            # Also remove any specific dataset path if it exists (fallback)
            if 'dataset_path' in file_info:
                if remove_file_or_directory(file_info['dataset_path']):
                    removed_files.append(file_info['dataset_path'])

        # Remove the log file
        if remove_file_or_directory(log_file):
            removed_files.append(log_file)

        # Remove the entire files directory
        files_dir = os.path.join(project_dir, "files")
        if remove_file_or_directory(files_dir):
            removed_files.append(files_dir)

        # Also remove the entire dataset directory (clean up any remaining files)
        dataset_dir = os.path.join(project_dir, "dataset")
        if remove_file_or_directory(dataset_dir):
            removed_files.append(dataset_dir)

        # Remove README.md if it exists
        readme_file = os.path.join(project_dir, "README.md")
        if remove_file_or_directory(readme_file):
            removed_files.append(readme_file)

        return {
            "message": f"File processing state reset successfully: {dataset_id}",
            "removed_files_count": len(removed_files),
            "removed_files": removed_files
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to reset file processing state: {str(e)}")


@app.post("/api/v1/embedding/encode", response_model=EncodeResponse)
async def encode_texts(request: EncodeRequest):
    """
    Text encoding API.

    Args:
        request: Encode request containing texts and batch_size

    Returns:
        The encoding result
    """
    try:
        search_service = get_search_service()

        if not request.texts:
            return EncodeResponse(
                success=False,
                embeddings=[],
                shape=[0, 0],
                processing_time=0.0,
                total_texts=0,
                error="texts must not be empty"
            )

        start_time = time.time()

        # Encode the texts via the model manager
        embeddings = await search_service.model_manager.encode_texts(
            request.texts,
            batch_size=request.batch_size
        )

        processing_time = time.time() - start_time

        # Convert to a plain list
        embeddings_list = embeddings.tolist()

        return EncodeResponse(
            success=True,
            embeddings=embeddings_list,
            shape=list(embeddings.shape),
            processing_time=processing_time,
            total_texts=len(request.texts)
        )

    except Exception as e:
        logger.error(f"Text encoding API error: {e}")
        return EncodeResponse(
            success=False,
            embeddings=[],
            shape=[0, 0],
            processing_time=0.0,
            total_texts=len(request.texts) if request else 0,
            error=f"Encoding failed: {str(e)}"
        )

# Register the file manager API routes
app.include_router(file_manager_router)


if __name__ == "__main__":
    # Start the FastAPI application
    print("Starting FastAPI server...")
    print("File Manager API available at: http://localhost:8001/api/v1/files")
    print("Web Interface available at: http://localhost:8001/public/file-manager.html")
    uvicorn.run(app, host="0.0.0.0", port=8001)