import json
import os
import re
import shutil
import tempfile
import time
import uuid
import hashlib

import requests
import aiohttp
from typing import AsyncGenerator, Dict, List, Optional, Union, Any
from datetime import datetime

import uvicorn
from fastapi import FastAPI, HTTPException, Depends, Header, UploadFile, File
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from qwen_agent.llm.schema import ASSISTANT, FUNCTION
from pydantic import BaseModel, Field

# Import utility modules
from utils import (
    # Models
    Message, DatasetRequest, ChatRequest, ChatResponse, QueueTaskRequest, IncrementalTaskRequest, QueueTaskResponse,
    QueueStatusResponse, TaskStatusResponse,

    # File utilities
    download_file, remove_file_or_directory, get_document_preview,
    load_processed_files_log, save_processed_files_log, get_file_hash,

    # Dataset management
    download_dataset_files, generate_dataset_structure,
    remove_dataset_directory, remove_dataset_directory_by_key,

    # Project management
    generate_project_readme, save_project_readme, get_project_status,
    remove_project, list_projects, get_project_stats,

    # Agent management
    get_global_agent_manager, init_global_agent_manager
)

# Import ChatRequestV2 directly from api_models
from utils.api_models import ChatRequestV2

# Import modified_assistant
from modified_assistant import update_agent_llm

# Import queue manager
from task_queue.manager import queue_manager
from task_queue.integration_tasks import process_files_async, process_files_incremental_async, cleanup_project_async
from task_queue.task_status import task_status_store

os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Custom version for qwen-agent messages - keep this function as it's specific to this app
def get_content_from_messages(messages: List[dict], tool_response: bool = True) -> str:
    """Extract content from qwen-agent messages with special formatting"""
    full_text = ''
    content = []
    TOOL_CALL_S = '[TOOL_CALL]'
    TOOL_RESULT_S = '[TOOL_RESPONSE]'
    THOUGHT_S = '[THINK]'
    ANSWER_S = '[ANSWER]'

    for msg in messages:
        if msg['role'] == ASSISTANT:
            if msg.get('reasoning_content'):
                assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
                content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
            if msg.get('content'):
                assert isinstance(msg['content'], str), 'Now only supports text messages'
                # Filter out incomplete tool_call fragments left over from streaming output
                content_text = msg["content"]

                # Strip an incomplete "<tool_call" prefix dangling at the end of the text
                content_text = re.sub(r'<t?o?o?l?_?c?a?l?l?$', '', content_text)
                # Only append when the cleaned content is non-empty
                if content_text.strip():
                    content.append(f'{ANSWER_S}\n{content_text}')
            if msg.get('function_call'):
                content_text = msg["function_call"]["arguments"]
                content_text = re.sub(r'}\n</?t?o?o?l?_?c?a?l?l?$', '', content_text)
                if content_text.strip():
                    content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{content_text}')
        elif msg['role'] == FUNCTION:
            if tool_response:
                content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
        else:
            raise TypeError(f'Unsupported message role: {msg["role"]}')

    if content:
        full_text = '\n'.join(content)

    return full_text
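
# Illustrative example of the formatting above (hypothetical messages; the role
# constants are assumed to be the usual 'assistant'/'function' strings):
#
#   get_content_from_messages([
#       {'role': 'assistant', 'reasoning_content': 'Check the docs first.',
#        'content': '', 'function_call': {'name': 'search', 'arguments': '{"q": "faq"}'}},
#       {'role': 'function', 'name': 'search', 'content': 'No results.'},
#       {'role': 'assistant', 'content': 'Nothing found.'},
#   ])
#
# would yield, with tool_response=True, a transcript like:
#
#   [THINK]
#   Check the docs first.
#   [TOOL_CALL] search
#   {"q": "faq"}
#   [TOOL_RESPONSE] search
#   No results.
#   [ANSWER]
#   Nothing found.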


# Helper functions are now imported from utils module


# Global agent manager configuration
max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "20"))

# Initialize the global agent manager
agent_manager = init_global_agent_manager(max_cached_agents=max_cached_agents)

app = FastAPI(title="Database Assistant API", version="1.0.0")

# Mount the public folder as a static file service
app.mount("/public", StaticFiles(directory="public"), name="static")

# Add CORS middleware to support the frontend pages
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # in production this should be restricted to the actual frontend domains
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Models are now imported from utils module


async def generate_stream_response(agent, messages, tool_response: bool, model: str) -> AsyncGenerator[str, None]:
    """Generate a streaming response"""
    accumulated_content = ""
    chunk_id = 0
    try:
        for response in agent.run(messages=messages):
            previous_content = accumulated_content
            accumulated_content = get_content_from_messages(response, tool_response=tool_response)

            # Compute the newly added content
            if accumulated_content.startswith(previous_content):
                new_content = accumulated_content[len(previous_content):]
            else:
                new_content = accumulated_content
                previous_content = ""

            # Only emit a chunk when there is new content
            if new_content:
                chunk_id += 1
                # Build a streaming chunk in the OpenAI format
                chunk_data = {
                    "id": f"chatcmpl-{chunk_id}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": new_content
                        },
                        "finish_reason": None
                    }]
                }

                yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"

        # Send the final completion chunk
        final_chunk = {
            "id": f"chatcmpl-{chunk_id + 1}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }]
        }
        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"

        # Send the end-of-stream marker
        yield "data: [DONE]\n\n"

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in generate_stream_response: {str(e)}")
        print(f"Full traceback: {error_details}")

        error_data = {
            "error": {
                "message": f"Stream error: {str(e)}",
                "type": "internal_error"
            }
        }
        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
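
# The wire format produced above is Server-Sent Events; each frame looks like
# this (illustrative values, wrapped here for readability - each "data:" payload
# is emitted on a single line):
#
#   data: {"id": "chatcmpl-1", "object": "chat.completion.chunk",
#          "created": 1700000000, "model": "my-model",
#          "choices": [{"index": 0, "delta": {"content": "Hello"},
#                       "finish_reason": null}]}
#
#   data: [DONE]
#
# Any OpenAI-compatible streaming client can consume this stream.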
@app.post("/api/v1/files/process/async")
|
||
async def process_files_async_endpoint(request: QueueTaskRequest, authorization: Optional[str] = Header(None)):
|
||
"""
|
||
异步处理文件的队列版本API
|
||
与 /api/v1/files/process 功能相同,但使用队列异步处理
|
||
|
||
Args:
|
||
request: QueueTaskRequest containing dataset_id, files, system_prompt, mcp_settings, and queue options
|
||
authorization: Authorization header containing API key (Bearer <API_KEY>)
|
||
|
||
Returns:
|
||
QueueTaskResponse: Processing result with task ID for tracking
|
||
"""
|
||
try:
|
||
dataset_id = request.dataset_id
|
||
if not dataset_id:
|
||
raise HTTPException(status_code=400, detail="dataset_id is required")
|
||
|
||
# 估算处理时间(基于文件数量)
|
||
estimated_time = 0
|
||
if request.files:
|
||
total_files = sum(len(file_list) for file_list in request.files.values())
|
||
estimated_time = max(30, total_files * 10) # 每个文件预估10秒,最少30秒
|
||
|
||
# 提交异步任务
|
||
task_id = queue_manager.enqueue_multiple_files(
|
||
project_id=dataset_id,
|
||
file_paths=[],
|
||
original_filenames=[]
|
||
)
|
||
|
||
# 创建任务状态记录
|
||
import uuid
|
||
task_id = str(uuid.uuid4())
|
||
task_status_store.set_status(
|
||
task_id=task_id,
|
||
unique_id=dataset_id,
|
||
status="pending"
|
||
)
|
||
|
||
# 提交异步任务
|
||
task = process_files_async(
|
||
unique_id=dataset_id,
|
||
files=request.files,
|
||
system_prompt=request.system_prompt,
|
||
mcp_settings=request.mcp_settings,
|
||
task_id=task_id
|
||
)
|
||
|
||
return QueueTaskResponse(
|
||
success=True,
|
||
message=f"文件处理任务已提交到队列,项目ID: {dataset_id}",
|
||
unique_id=dataset_id,
|
||
task_id=task_id, # 使用我们自己的task_id
|
||
task_status="pending",
|
||
estimated_processing_time=estimated_time
|
||
)
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
print(f"Error submitting async file processing task: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||
|
||
|
||
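
# Example request (illustrative only; the dataset ID, file grouping, and URL
# are hypothetical):
#
#   POST /api/v1/files/process/async
#   {
#       "dataset_id": "project-123",
#       "files": {"docs": ["https://example.com/a.pdf"]},
#       "system_prompt": "You are a helpful assistant.",
#       "mcp_settings": []
#   }
#
# The returned task_id can then be polled via GET /api/v1/task/{task_id}/status.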
@app.post("/api/v1/files/process/incremental")
|
||
async def process_files_incremental_endpoint(request: IncrementalTaskRequest, authorization: Optional[str] = Header(None)):
|
||
"""
|
||
增量处理文件的队列版本API - 支持添加和删除文件
|
||
|
||
Args:
|
||
request: IncrementalTaskRequest containing dataset_id, files_to_add, files_to_remove, system_prompt, mcp_settings, and queue options
|
||
authorization: Authorization header containing API key (Bearer <API_KEY>)
|
||
|
||
Returns:
|
||
QueueTaskResponse: Processing result with task ID for tracking
|
||
"""
|
||
try:
|
||
dataset_id = request.dataset_id
|
||
if not dataset_id:
|
||
raise HTTPException(status_code=400, detail="dataset_id is required")
|
||
|
||
# 验证至少有添加或删除操作
|
||
if not request.files_to_add and not request.files_to_remove:
|
||
raise HTTPException(status_code=400, detail="At least one of files_to_add or files_to_remove must be provided")
|
||
|
||
# 估算处理时间(基于文件数量)
|
||
estimated_time = 0
|
||
total_add_files = sum(len(file_list) for file_list in (request.files_to_add or {}).values())
|
||
total_remove_files = sum(len(file_list) for file_list in (request.files_to_remove or {}).values())
|
||
total_files = total_add_files + total_remove_files
|
||
estimated_time = max(30, total_files * 10) # 每个文件预估10秒,最少30秒
|
||
|
||
# 创建任务状态记录
|
||
import uuid
|
||
task_id = str(uuid.uuid4())
|
||
task_status_store.set_status(
|
||
task_id=task_id,
|
||
unique_id=dataset_id,
|
||
status="pending"
|
||
)
|
||
|
||
# 提交增量异步任务
|
||
task = process_files_incremental_async(
|
||
dataset_id=dataset_id,
|
||
files_to_add=request.files_to_add,
|
||
files_to_remove=request.files_to_remove,
|
||
system_prompt=request.system_prompt,
|
||
mcp_settings=request.mcp_settings,
|
||
task_id=task_id
|
||
)
|
||
|
||
return QueueTaskResponse(
|
||
success=True,
|
||
message=f"增量文件处理任务已提交到队列 - 添加 {total_add_files} 个文件,删除 {total_remove_files} 个文件,项目ID: {dataset_id}",
|
||
unique_id=dataset_id,
|
||
task_id=task_id,
|
||
task_status="pending",
|
||
estimated_processing_time=estimated_time
|
||
)
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
print(f"Error submitting incremental file processing task: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||
|
||
|
||
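
# Example incremental request (illustrative only; IDs and URLs are hypothetical):
#
#   POST /api/v1/files/process/incremental
#   {
#       "dataset_id": "project-123",
#       "files_to_add": {"docs": ["https://example.com/new.pdf"]},
#       "files_to_remove": {"docs": ["https://example.com/old.pdf"]}
#   }
#
# At least one of files_to_add / files_to_remove must be non-empty, as
# validated above.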
@app.get("/api/v1/task/{task_id}/status")
|
||
async def get_task_status(task_id: str):
|
||
"""获取任务状态 - 简单可靠"""
|
||
try:
|
||
status_data = task_status_store.get_status(task_id)
|
||
|
||
if not status_data:
|
||
return {
|
||
"success": False,
|
||
"message": "任务不存在或已过期",
|
||
"task_id": task_id,
|
||
"status": "not_found"
|
||
}
|
||
|
||
return {
|
||
"success": True,
|
||
"message": "任务状态获取成功",
|
||
"task_id": task_id,
|
||
"status": status_data["status"],
|
||
"unique_id": status_data["unique_id"],
|
||
"created_at": status_data["created_at"],
|
||
"updated_at": status_data["updated_at"],
|
||
"result": status_data.get("result"),
|
||
"error": status_data.get("error")
|
||
}
|
||
|
||
except Exception as e:
|
||
print(f"Error getting task status: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"获取任务状态失败: {str(e)}")
|
||
|
||
|
||
@app.delete("/api/v1/task/{task_id}")
|
||
async def delete_task(task_id: str):
|
||
"""删除任务记录"""
|
||
try:
|
||
success = task_status_store.delete_status(task_id)
|
||
if success:
|
||
return {
|
||
"success": True,
|
||
"message": f"任务记录已删除: {task_id}",
|
||
"task_id": task_id
|
||
}
|
||
else:
|
||
return {
|
||
"success": False,
|
||
"message": f"任务记录不存在: {task_id}",
|
||
"task_id": task_id
|
||
}
|
||
except Exception as e:
|
||
print(f"Error deleting task: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"删除任务记录失败: {str(e)}")
|
||
|
||
|
||
@app.get("/api/v1/tasks")
|
||
async def list_tasks(status: Optional[str] = None, dataset_id: Optional[str] = None, limit: int = 100):
|
||
"""列出任务,支持筛选"""
|
||
try:
|
||
if status or dataset_id:
|
||
# 使用搜索功能
|
||
tasks = task_status_store.search_tasks(status=status, unique_id=dataset_id, limit=limit)
|
||
else:
|
||
# 获取所有任务
|
||
all_tasks = task_status_store.list_all()
|
||
tasks = list(all_tasks.values())[:limit]
|
||
|
||
return {
|
||
"success": True,
|
||
"message": "任务列表获取成功",
|
||
"total_tasks": len(tasks),
|
||
"tasks": tasks,
|
||
"filters": {
|
||
"status": status,
|
||
"dataset_id": dataset_id,
|
||
"limit": limit
|
||
}
|
||
}
|
||
|
||
except Exception as e:
|
||
print(f"Error listing tasks: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"获取任务列表失败: {str(e)}")
|
||
|
||
|
||
@app.get("/api/v1/tasks/statistics")
|
||
async def get_task_statistics():
|
||
"""获取任务统计信息"""
|
||
try:
|
||
stats = task_status_store.get_statistics()
|
||
|
||
return {
|
||
"success": True,
|
||
"message": "统计信息获取成功",
|
||
"statistics": stats
|
||
}
|
||
|
||
except Exception as e:
|
||
print(f"Error getting statistics: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"获取统计信息失败: {str(e)}")
|
||
|
||
|
||
@app.post("/api/v1/tasks/cleanup")
|
||
async def cleanup_tasks(older_than_days: int = 7):
|
||
"""清理旧任务记录"""
|
||
try:
|
||
deleted_count = task_status_store.cleanup_old_tasks(older_than_days=older_than_days)
|
||
|
||
return {
|
||
"success": True,
|
||
"message": f"已清理 {deleted_count} 条旧任务记录",
|
||
"deleted_count": deleted_count,
|
||
"older_than_days": older_than_days
|
||
}
|
||
|
||
except Exception as e:
|
||
print(f"Error cleaning up tasks: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"清理任务记录失败: {str(e)}")
|
||
|
||
|
||
@app.get("/api/v1/projects/{dataset_id}/tasks")
|
||
async def get_project_tasks(dataset_id: str):
|
||
"""获取指定项目的所有任务"""
|
||
try:
|
||
tasks = task_status_store.get_by_unique_id(dataset_id)
|
||
|
||
return {
|
||
"success": True,
|
||
"message": "项目任务获取成功",
|
||
"dataset_id": dataset_id,
|
||
"total_tasks": len(tasks),
|
||
"tasks": tasks
|
||
}
|
||
|
||
except Exception as e:
|
||
print(f"Error getting project tasks: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"获取项目任务失败: {str(e)}")
|
||
|
||
|
||
@app.post("/api/v1/files/{dataset_id}/cleanup/async")
|
||
async def cleanup_project_async_endpoint(dataset_id: str, remove_all: bool = False):
|
||
"""异步清理项目文件"""
|
||
try:
|
||
task = cleanup_project_async(unique_id=dataset_id, remove_all=remove_all)
|
||
|
||
return {
|
||
"success": True,
|
||
"message": f"项目清理任务已提交到队列,项目ID: {dataset_id}",
|
||
"dataset_id": dataset_id,
|
||
"task_id": task.id,
|
||
"action": "remove_all" if remove_all else "cleanup_logs"
|
||
}
|
||
except Exception as e:
|
||
print(f"Error submitting cleanup task: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"提交清理任务失败: {str(e)}")
|
||
|
||
|
||
@app.post("/api/v1/chat/completions")
|
||
async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)):
|
||
"""
|
||
Chat completions API similar to OpenAI, supports both streaming and non-streaming
|
||
|
||
Args:
|
||
request: ChatRequest containing messages, model, dataset_ids (optional list), required bot_id, system_prompt, mcp_settings, and files
|
||
authorization: Authorization header containing API key (Bearer <API_KEY>)
|
||
|
||
Returns:
|
||
Union[ChatResponse, StreamingResponse]: Chat completion response or stream
|
||
|
||
Notes:
|
||
- dataset_ids: 可选参数,当提供时必须是项目ID列表(单个项目也使用数组格式)
|
||
- bot_id: 必需参数,机器人ID,用于创建项目目录
|
||
- 只有当提供 dataset_ids 时才会创建机器人项目目录:projects/robot/{bot_id}/
|
||
- 支持多知识库合并,自动处理文件夹重名冲突
|
||
|
||
Required Parameters:
|
||
- bot_id: str - 目标机器人项目ID
|
||
Optional Parameters:
|
||
- dataset_ids: List[str] - 源知识库项目ID列表(单个项目也使用数组格式)
|
||
|
||
Example:
|
||
{"bot_id": "my-bot-001"}
|
||
{"dataset_ids": ["project-123"], "bot_id": "my-bot-001"}
|
||
{"dataset_ids": ["project-123", "project-456"], "bot_id": "my-bot-002"}
|
||
"""
|
||
try:
|
||
# v1接口:从Authorization header中提取API key作为模型API密钥
|
||
api_key = extract_api_key_from_auth(authorization)
|
||
|
||
# 获取bot_id(必需参数)
|
||
bot_id = request.bot_id
|
||
if not bot_id:
|
||
raise HTTPException(status_code=400, detail="bot_id is required")
|
||
|
||
# 创建项目目录(如果有dataset_ids)
|
||
project_dir = create_project_directory(request.dataset_ids, bot_id)
|
||
|
||
# 收集额外参数作为 generate_cfg
|
||
exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings' ,'stream', 'robot_type', 'bot_id'}
|
||
generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
|
||
|
||
# 处理消息
|
||
messages = process_messages(request.messages, request.language)
|
||
|
||
# 调用公共的agent创建和响应生成逻辑
|
||
return await create_agent_and_generate_response(
|
||
bot_id=bot_id,
|
||
api_key=api_key,
|
||
messages=messages,
|
||
stream=request.stream,
|
||
tool_response=request.tool_response,
|
||
model_name=request.model,
|
||
model_server=request.model_server,
|
||
language=request.language,
|
||
system_prompt=request.system_prompt,
|
||
mcp_settings=request.mcp_settings,
|
||
robot_type=request.robot_type,
|
||
project_dir=project_dir,
|
||
generate_cfg=generate_cfg
|
||
)
|
||
|
||
except Exception as e:
|
||
import traceback
|
||
error_details = traceback.format_exc()
|
||
print(f"Error in chat_completions: {str(e)}")
|
||
print(f"Full traceback: {error_details}")
|
||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||
|
||
|
||
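
# Example v1 call (illustrative only; the API key, bot ID, and dataset IDs are
# hypothetical):
#
#   POST /api/v1/chat/completions
#   Authorization: Bearer sk-...
#   {
#       "bot_id": "my-bot-001",
#       "dataset_ids": ["project-123"],
#       "model": "qwen-plus",
#       "messages": [{"role": "user", "content": "Hello"}],
#       "stream": true
#   }
#
# Any extra body fields not listed in exclude_fields above (e.g. temperature)
# are forwarded to the agent as generate_cfg.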


async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
    """Fetch the bot configuration from the backend API"""
    try:
        backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
        url = f"{backend_host}/v1/agent_bot_config/{bot_id}"

        auth_token = generate_v2_auth_token(bot_id)
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {auth_token}"
        }
        print(url, headers)
        # Use an asynchronous HTTP request
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=30)) as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch bot config: API returned status code {response.status}"
                    )

                # Parse the response
                response_data = await response.json()

                if not response_data.get("success"):
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch bot config: {response_data.get('message', 'Unknown error')}"
                    )

                return response_data.get("data", {})

    except aiohttp.ClientError as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to connect to backend API: {str(e)}"
        )
    except Exception as e:
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch bot config: {str(e)}"
        )


def process_messages(messages: List[Message], language: Optional[str] = None) -> List[Dict[str, str]]:
    """Process the message list: split assistant messages on [ANSWER] and append a language instruction"""
    processed_messages = []

    # Process each message
    for msg in messages:
        if msg.role == "assistant":
            # Split assistant messages on [ANSWER] and keep only the last segment
            # (str.split always returns at least one element)
            last_part = msg.content.split("[ANSWER]")[-1].strip()
            processed_messages.append({"role": msg.role, "content": last_part})
        else:
            processed_messages.append({"role": msg.role, "content": msg.content})

    # Append the reply-language instruction to the end of the last message
    if processed_messages and language:
        # The instructions are deliberately written in their target languages
        language_map = {
            'zh': '请用中文回复',
            'en': 'Please reply in English',
            'ja': '日本語で回答してください',
            'jp': '日本語で回答してください'
        }
        language_instruction = language_map.get(language.lower(), '')
        if language_instruction:
            # Append the language instruction to the last message
            processed_messages[-1]['content'] = processed_messages[-1]['content'] + f"\n\n{language_instruction}。"

    return processed_messages
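
# Illustrative example: an assistant turn whose stored transcript is
# "[THINK]\n...\n[ANSWER]\nFinal answer" is reduced to just "Final answer"
# before being replayed to the model, so intermediate reasoning and tool
# traces are not fed back into the context.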


async def create_agent_and_generate_response(
    bot_id: str,
    api_key: str,
    messages: List[Dict[str, str]],
    stream: bool,
    tool_response: bool,
    model_name: str,
    model_server: str,
    language: str,
    system_prompt: Optional[str],
    mcp_settings: Optional[List[Dict]],
    robot_type: str,
    project_dir: Optional[str] = None,
    generate_cfg: Optional[Dict] = None
) -> Union[ChatResponse, StreamingResponse]:
    """Shared logic for creating the agent and generating the response"""
    if generate_cfg is None:
        generate_cfg = {}

    # Get or create the assistant instance from the global manager
    agent = await agent_manager.get_or_create_agent(
        bot_id=bot_id,
        project_dir=project_dir,
        model_name=model_name,
        api_key=api_key,
        model_server=model_server,
        generate_cfg=generate_cfg,
        language=language,
        system_prompt=system_prompt,
        mcp_settings=mcp_settings,
        robot_type=robot_type
    )

    # Return a streaming or non-streaming response depending on the stream parameter
    if stream:
        return StreamingResponse(
            generate_stream_response(agent, messages, tool_response, model_name),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
        )
    else:
        # Non-streaming response
        final_responses = agent.run_nonstream(messages)

        if final_responses and len(final_responses) > 0:
            # Use get_content_from_messages to process the response, honoring tool_response
            content = get_content_from_messages(final_responses, tool_response=tool_response)

            # Build a response in the OpenAI format
            # (the usage counts below are rough character-length approximations, not real token counts)
            return ChatResponse(
                choices=[{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": content
                    },
                    "finish_reason": "stop"
                }],
                usage={
                    "prompt_tokens": sum(len(msg.get("content", "")) for msg in messages),
                    "completion_tokens": len(content),
                    "total_tokens": sum(len(msg.get("content", "")) for msg in messages) + len(content)
                }
            )
        else:
            raise HTTPException(status_code=500, detail="No response from agent")


def create_project_directory(dataset_ids: List[str], bot_id: str) -> Optional[str]:
    """Shared logic for creating the project directory"""
    if not dataset_ids:
        return None

    try:
        from utils.multi_project_manager import create_robot_project
        return create_robot_project(dataset_ids, bot_id)
    except Exception as e:
        print(f"Error creating project directory: {e}")
        return None


def extract_api_key_from_auth(authorization: Optional[str]) -> Optional[str]:
    """Extract the API key from the Authorization header"""
    if not authorization:
        return None

    # Strip the "Bearer " prefix
    if authorization.startswith("Bearer "):
        return authorization[7:]
    else:
        return authorization


def generate_v2_auth_token(bot_id: str) -> str:
    """Generate the authentication token for the v2 API"""
    masterkey = os.getenv("MASTERKEY", "master")
    token_input = f"{masterkey}:{bot_id}"
    return hashlib.md5(token_input.encode()).hexdigest()
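
# Token derivation example (hypothetical values, shown for illustration):
#
#   MASTERKEY=master, bot_id="my-bot-001"
#   token = hashlib.md5(b"master:my-bot-001").hexdigest()
#
# A v2 client must send this digest as "Authorization: Bearer {token}".
# MD5 here acts as a shared-secret handshake rather than a cryptographically
# strong signature.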
@app.post("/api/v2/chat/completions")
|
||
async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[str] = Header(None)):
|
||
"""
|
||
Chat completions API v2 with simplified parameters.
|
||
Only requires messages, stream, tool_response, bot_id, and language parameters.
|
||
Other parameters are fetched from the backend bot configuration API.
|
||
|
||
Args:
|
||
request: ChatRequestV2 containing only essential parameters
|
||
authorization: Authorization header for authentication (different from v1)
|
||
|
||
Returns:
|
||
Union[ChatResponse, StreamingResponse]: Chat completion response or stream
|
||
|
||
Required Parameters:
|
||
- bot_id: str - 目标机器人ID
|
||
- messages: List[Message] - 对话消息列表
|
||
|
||
Optional Parameters:
|
||
- stream: bool - 是否流式输出,默认false
|
||
- tool_response: bool - 是否包含工具响应,默认false
|
||
- language: str - 回复语言,默认"ja"
|
||
|
||
Authentication:
|
||
- Requires valid MD5 hash token: MD5(MASTERKEY:bot_id)
|
||
- Authorization header should contain: Bearer {token}
|
||
- Uses MD5 hash of MASTERKEY:bot_id for backend API authentication
|
||
- Optionally uses API key from bot config for model access
|
||
"""
|
||
try:
|
||
# 获取bot_id(必需参数)
|
||
bot_id = request.bot_id
|
||
if not bot_id:
|
||
raise HTTPException(status_code=400, detail="bot_id is required")
|
||
|
||
# v2接口鉴权验证
|
||
expected_token = generate_v2_auth_token(bot_id)
|
||
provided_token = extract_api_key_from_auth(authorization)
|
||
|
||
if not provided_token:
|
||
raise HTTPException(
|
||
status_code=401,
|
||
detail="Authorization header is required for v2 API"
|
||
)
|
||
|
||
if provided_token != expected_token:
|
||
raise HTTPException(
|
||
status_code=403,
|
||
detail=f"Invalid authentication token. Expected: {expected_token[:8]}..., Provided: {provided_token[:8]}..."
|
||
)
|
||
|
||
# 从后端API获取机器人配置(使用v2的鉴权方式)
|
||
bot_config = await fetch_bot_config(bot_id)
|
||
|
||
# v2接口:API密钥优先从后端配置获取,其次才从Authorization header获取
|
||
# 注意:这里的Authorization header已经用于鉴权,不再作为API key使用
|
||
api_key = bot_config.get("api_key")
|
||
|
||
# 创建项目目录(从后端配置获取dataset_ids)
|
||
project_dir = create_project_directory(bot_config.get("dataset_ids", []), bot_id)
|
||
|
||
# 处理消息
|
||
messages = process_messages(request.messages, request.language)
|
||
|
||
# 调用公共的agent创建和响应生成逻辑
|
||
return await create_agent_and_generate_response(
|
||
bot_id=bot_id,
|
||
api_key=api_key,
|
||
messages=messages,
|
||
stream=request.stream,
|
||
tool_response=request.tool_response,
|
||
model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
|
||
model_server=bot_config.get("model_server", ""),
|
||
language=request.language or bot_config.get("language", "ja"),
|
||
system_prompt=bot_config.get("system_prompt"),
|
||
mcp_settings=bot_config.get("mcp_settings", []),
|
||
robot_type=bot_config.get("robot_type", "agent"),
|
||
project_dir=project_dir,
|
||
generate_cfg={} # v2接口不传递额外的generate_cfg
|
||
)
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
import traceback
|
||
error_details = traceback.format_exc()
|
||
print(f"Error in chat_completions_v2: {str(e)}")
|
||
print(f"Full traceback: {error_details}")
|
||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||
|
||
|
||
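
# Example v2 call (illustrative only; the bot ID and token are hypothetical):
#
#   token = hashlib.md5(f"{MASTERKEY}:my-bot-001".encode()).hexdigest()
#
#   POST /api/v2/chat/completions
#   Authorization: Bearer {token}
#   {
#       "bot_id": "my-bot-001",
#       "messages": [{"role": "user", "content": "Hello"}],
#       "stream": false
#   }
#
# Model, system prompt, dataset_ids, and MCP settings all come from the
# backend bot configuration rather than from the request body.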
@app.post("/api/v1/upload")
|
||
async def upload_file(file: UploadFile = File(...)):
|
||
"""
|
||
文件上传API接口,上传文件到 ./projects/uploads 目录
|
||
|
||
Args:
|
||
file: 上传的文件
|
||
|
||
Returns:
|
||
dict: 包含文件路径和文件名的响应
|
||
"""
|
||
try:
|
||
# 确保上传目录存在
|
||
upload_dir = os.path.join("projects", "uploads")
|
||
os.makedirs(upload_dir, exist_ok=True)
|
||
|
||
# 生成唯一文件名
|
||
file_extension = os.path.splitext(file.filename)[1] if file.filename else ""
|
||
unique_filename = f"{uuid.uuid4()}{file_extension}"
|
||
file_path = os.path.join(upload_dir, unique_filename)
|
||
|
||
# 保存文件
|
||
with open(file_path, "wb") as buffer:
|
||
shutil.copyfileobj(file.file, buffer)
|
||
|
||
return {
|
||
"success": True,
|
||
"message": "文件上传成功",
|
||
"filename": unique_filename,
|
||
"original_filename": file.filename,
|
||
"file_path": file_path
|
||
}
|
||
|
||
except Exception as e:
|
||
print(f"Error uploading file: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"文件上传失败: {str(e)}")
|
||
|
||
|
||
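
# Example upload (illustrative; the file name and port are hypothetical):
#
#   curl -X POST http://localhost:8001/api/v1/upload \
#        -F "file=@./manual.pdf"
#
# The response returns the generated unique filename and its path under
# projects/uploads/.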
@app.get("/api/health")
|
||
async def health_check():
|
||
"""Health check endpoint"""
|
||
return {"message": "Database Assistant API is running"}
|
||
|
||
|
||
@app.post("/system/remove-project-cache")
|
||
async def remove_project_cache(dataset_id: str):
|
||
"""移除特定项目的缓存"""
|
||
try:
|
||
removed_count = agent_manager.remove_cache_by_unique_id(dataset_id)
|
||
if removed_count > 0:
|
||
return {"message": f"项目缓存移除成功: {dataset_id}", "removed_count": removed_count}
|
||
else:
|
||
return {"message": f"未找到项目缓存: {dataset_id}", "removed_count": 0}
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}")
|
||
|
||
|
||
@app.get("/api/v1/files/{dataset_id}/status")
|
||
async def get_files_processing_status(dataset_id: str):
|
||
"""获取项目的文件处理状态"""
|
||
try:
|
||
# Load processed files log
|
||
processed_log = load_processed_files_log(dataset_id)
|
||
|
||
# Get project directory info
|
||
project_dir = os.path.join("projects", "data", dataset_id)
|
||
project_exists = os.path.exists(project_dir)
|
||
|
||
# Collect document.txt files
|
||
document_files = []
|
||
if project_exists:
|
||
for root, dirs, files in os.walk(project_dir):
|
||
for file in files:
|
||
if file == "document.txt":
|
||
document_files.append(os.path.join(root, file))
|
||
|
||
return {
|
||
"dataset_id": dataset_id,
|
||
"project_exists": project_exists,
|
||
"processed_files_count": len(processed_log),
|
||
"processed_files": processed_log,
|
||
"document_files_count": len(document_files),
|
||
"document_files": document_files,
|
||
"log_file_exists": os.path.exists(os.path.join("projects", "data", dataset_id, "processed_files.json"))
|
||
}
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=f"获取文件处理状态失败: {str(e)}")
|
||
|
||
|
||
@app.post("/api/v1/files/{dataset_id}/reset")
|
||
async def reset_files_processing(dataset_id: str):
|
||
"""重置项目的文件处理状态,删除处理日志和所有文件"""
|
||
try:
|
||
project_dir = os.path.join("projects", "data", dataset_id)
|
||
log_file = os.path.join("projects", "data", dataset_id, "processed_files.json")
|
||
|
||
# Load processed log to know what files to remove
|
||
processed_log = load_processed_files_log(dataset_id)
|
||
|
||
removed_files = []
|
||
# Remove all processed files and their dataset directories
|
||
for file_hash, file_info in processed_log.items():
|
||
# Remove local file in files directory
|
||
if 'local_path' in file_info:
|
||
if remove_file_or_directory(file_info['local_path']):
|
||
removed_files.append(file_info['local_path'])
|
||
|
||
# Handle new key-based structure first
|
||
if 'key' in file_info:
|
||
# Remove dataset directory by key
|
||
key = file_info['key']
|
||
if remove_dataset_directory_by_key(dataset_id, key):
|
||
removed_files.append(f"dataset/{key}")
|
||
elif 'filename' in file_info:
|
||
# Fallback to old filename-based structure
|
||
filename_without_ext = os.path.splitext(file_info['filename'])[0]
|
||
dataset_dir = os.path.join("projects", "data", dataset_id, "dataset", filename_without_ext)
|
||
if remove_file_or_directory(dataset_dir):
|
||
removed_files.append(dataset_dir)
|
||
|
||
# Also remove any specific dataset path if exists (fallback)
|
||
if 'dataset_path' in file_info:
|
||
if remove_file_or_directory(file_info['dataset_path']):
|
||
removed_files.append(file_info['dataset_path'])
|
||
|
||
# Remove the log file
|
||
if remove_file_or_directory(log_file):
|
||
removed_files.append(log_file)
|
||
|
||
# Remove the entire files directory
|
||
files_dir = os.path.join(project_dir, "files")
|
||
if remove_file_or_directory(files_dir):
|
||
removed_files.append(files_dir)
|
||
|
||
# Also remove the entire dataset directory (clean up any remaining files)
|
||
dataset_dir = os.path.join(project_dir, "dataset")
|
||
if remove_file_or_directory(dataset_dir):
|
||
removed_files.append(dataset_dir)
|
||
|
||
# Remove README.md if exists
|
||
readme_file = os.path.join(project_dir, "README.md")
|
||
if remove_file_or_directory(readme_file):
|
||
removed_files.append(readme_file)
|
||
|
||
return {
|
||
"message": f"文件处理状态重置成功: {dataset_id}",
|
||
"removed_files_count": len(removed_files),
|
||
"removed_files": removed_files
|
||
}
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=f"重置文件处理状态失败: {str(e)}")
|
||
|
||
|
||


def build_directory_tree(path: str, relative_path: str = "", max_depth: int = 10) -> dict:
    """Build a directory tree structure, recursing at most max_depth levels"""
    if not os.path.exists(path):
        return {}

    tree = {
        "name": os.path.basename(path) or "projects",
        "path": relative_path,
        "type": "directory",
        "children": [],
        "size": 0,
        "modified_time": os.path.getmtime(path)
    }

    try:
        entries = os.listdir(path)
        entries.sort()

        for entry in entries:
            entry_path = os.path.join(path, entry)
            entry_relative_path = os.path.join(relative_path, entry) if relative_path else entry

            if os.path.isdir(entry_path):
                # Stop recursing once the depth budget is exhausted
                if max_depth > 0:
                    tree["children"].append(build_directory_tree(entry_path, entry_relative_path, max_depth - 1))
            else:
                try:
                    file_size = os.path.getsize(entry_path)
                    file_modified = os.path.getmtime(entry_path)
                    tree["children"].append({
                        "name": entry,
                        "path": entry_relative_path,
                        "type": "file",
                        "size": file_size,
                        "modified_time": file_modified
                    })
                    tree["size"] += file_size
                except (OSError, IOError):
                    tree["children"].append({
                        "name": entry,
                        "path": entry_relative_path,
                        "type": "file",
                        "size": 0,
                        "modified_time": 0
                    })
    except (OSError, IOError) as e:
        print(f"Error reading directory {path}: {e}")

    return tree
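
# Shape of a returned node (illustrative values only):
#
#   {
#       "name": "data",
#       "path": "data",
#       "type": "directory",
#       "size": 1024,
#       "modified_time": 1700000000.0,
#       "children": [
#           {"name": "document.txt", "path": "data/document.txt",
#            "type": "file", "size": 1024, "modified_time": 1700000000.0}
#       ]
#   }
#
# Note that "size" on a directory only sums the files directly inside it.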
@app.get("/api/v1/projects/tree")
|
||
async def get_projects_tree(
|
||
include_files: bool = True,
|
||
max_depth: int = 10,
|
||
filter_type: Optional[str] = None
|
||
):
|
||
"""
|
||
获取projects文件夹的目录树结构
|
||
|
||
Args:
|
||
include_files: 是否包含文件,false时只显示目录
|
||
max_depth: 最大深度限制
|
||
filter_type: 过滤类型 ('data', 'robot', 'uploads')
|
||
|
||
Returns:
|
||
dict: 包含目录树结构的响应
|
||
"""
|
||
try:
|
||
projects_dir = "projects"
|
||
|
||
if not os.path.exists(projects_dir):
|
||
return {
|
||
"success": False,
|
||
"message": "projects目录不存在",
|
||
"tree": {}
|
||
}
|
||
|
||
tree = build_directory_tree(projects_dir)
|
||
|
||
# 根据filter_type过滤
|
||
if filter_type and filter_type in ['data', 'robot', 'uploads']:
|
||
filtered_children = []
|
||
for child in tree.get("children", []):
|
||
if child["name"] == filter_type:
|
||
filtered_children.append(child)
|
||
tree["children"] = filtered_children
|
||
|
||
# 如果不包含文件,移除所有文件节点
|
||
if not include_files:
|
||
tree = filter_directories_only(tree)
|
||
|
||
# 计算统计信息
|
||
stats = calculate_tree_stats(tree)
|
||
|
||
return {
|
||
"success": True,
|
||
"message": "目录树获取成功",
|
||
"tree": tree,
|
||
"stats": stats,
|
||
"filters": {
|
||
"include_files": include_files,
|
||
"max_depth": max_depth,
|
||
"filter_type": filter_type
|
||
}
|
||
}
|
||
|
||
except Exception as e:
|
||
print(f"Error getting projects tree: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"获取目录树失败: {str(e)}")
|
||
|
||
|
||


def filter_directories_only(tree: dict) -> dict:
    """Filter out files, keeping only directories"""
    if tree["type"] != "directory":
        return tree

    filtered_children = []
    for child in tree.get("children", []):
        if child["type"] == "directory":
            filtered_children.append(filter_directories_only(child))

    tree["children"] = filtered_children
    return tree


def calculate_tree_stats(tree: dict) -> dict:
    """Compute statistics for the directory tree"""
    stats = {
        "total_directories": 0,
        "total_files": 0,
        "total_size": 0
    }

    def traverse(node):
        if node["type"] == "directory":
            stats["total_directories"] += 1
            for child in node.get("children", []):
                traverse(child)
        else:
            stats["total_files"] += 1
            stats["total_size"] += node.get("size", 0)

    traverse(tree)
    return stats
@app.get("/api/v1/projects/subtree/{sub_path:path}")
|
||
async def get_projects_subtree(
|
||
sub_path: str,
|
||
include_files: bool = True,
|
||
max_depth: int = 5
|
||
):
|
||
"""
|
||
获取projects子目录的树结构
|
||
|
||
Args:
|
||
sub_path: 子目录路径,如 'data/1624be71-5432-40bf-9758-f4aecffd4e9c'
|
||
include_files: 是否包含文件
|
||
max_depth: 最大深度
|
||
|
||
Returns:
|
||
dict: 包含子目录树结构的响应
|
||
"""
|
||
try:
|
||
full_path = os.path.join("projects", sub_path)
|
||
|
||
if not os.path.exists(full_path):
|
||
return {
|
||
"success": False,
|
||
"message": f"路径不存在: {sub_path}",
|
||
"tree": {}
|
||
}
|
||
|
||
if not os.path.isdir(full_path):
|
||
return {
|
||
"success": False,
|
||
"message": f"路径不是目录: {sub_path}",
|
||
"tree": {}
|
||
}
|
||
|
||
tree = build_directory_tree(full_path, sub_path)
|
||
|
||
# 如果不包含文件,移除所有文件节点
|
||
if not include_files:
|
||
tree = filter_directories_only(tree)
|
||
|
||
# 计算统计信息
|
||
stats = calculate_tree_stats(tree)
|
||
|
||
return {
|
||
"success": True,
|
||
"message": "子目录树获取成功",
|
||
"sub_path": sub_path,
|
||
"tree": tree,
|
||
"stats": stats,
|
||
"filters": {
|
||
"include_files": include_files,
|
||
"max_depth": max_depth
|
||
}
|
||
}
|
||
|
||
except Exception as e:
|
||
print(f"Error getting projects subtree: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"获取子目录树失败: {str(e)}")
|
||
|
||
|
||
|
||
|
||


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8001)