From fe6c4f77d72ad3ff0145c474ab85ece4e2f4d890 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?= Date: Tue, 25 Nov 2025 22:34:44 +0800 Subject: [PATCH] update fastapi --- fastapi_app.py | 1835 +--------------------------------------- routes/__init__.py | 1 + routes/chat.py | 429 ++++++++++ routes/files.py | 467 ++++++++++ routes/projects.py | 173 ++++ routes/system.py | 272 ++++++ utils/fastapi_utils.py | 488 +++++++++++ 7 files changed, 1843 insertions(+), 1822 deletions(-) create mode 100644 routes/__init__.py create mode 100644 routes/chat.py create mode 100644 routes/files.py create mode 100644 routes/projects.py create mode 100644 routes/system.py create mode 100644 utils/fastapi_utils.py diff --git a/fastapi_app.py b/fastapi_app.py index f6ac710..93c20eb 100644 --- a/fastapi_app.py +++ b/fastapi_app.py @@ -1,239 +1,22 @@ import json import os -import tempfile -import shutil import uuid -import hashlib -import requests -import aiohttp -from typing import AsyncGenerator, Dict, List, Optional, Union, Any -from datetime import datetime -import re -import multiprocessing import time -import psutil +import multiprocessing +import sys import uvicorn -from fastapi import FastAPI, HTTPException, Depends, Header, UploadFile, File, Form -from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse -from utils.logger import logger +from fastapi import FastAPI from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware from file_manager_api import router as file_manager_router -from qwen_agent.llm.schema import ASSISTANT, FUNCTION -from qwen_agent.llm.oai import TextChatAtOAI -from pydantic import BaseModel, Field +from utils.logger import logger -# 导入语义检索服务 -from embedding import get_model_manager +# Import route modules +from routes import chat, files, projects, system -# Import utility modules -from utils import ( - # Models - Message, DatasetRequest, ChatRequest, ChatResponse, QueueTaskRequest, 
IncrementalTaskRequest, QueueTaskResponse, - QueueStatusResponse, TaskStatusResponse, - - # File utilities - download_file, remove_file_or_directory, get_document_preview, - load_processed_files_log, save_processed_files_log, get_file_hash, - - # Dataset management - download_dataset_files, generate_dataset_structure, - remove_dataset_directory, remove_dataset_directory_by_key, - - # Project management - generate_project_readme, save_project_readme, get_project_status, - remove_project, list_projects, get_project_stats, - - # Agent management - get_global_agent_manager, init_global_agent_manager, - - # Optimization modules - get_global_sharded_agent_manager, init_global_sharded_agent_manager, - get_global_connection_pool, init_global_connection_pool, - get_global_file_cache, init_global_file_cache, - setup_system_optimizations, get_optimized_worker_config -) - -# Import ChatRequestV2 directly from api_models -from utils.api_models import ChatRequestV2 - -# Import modified_assistant -from modified_assistant import update_agent_llm - -# Import queue manager -from task_queue.manager import queue_manager -from task_queue.integration_tasks import process_files_async, process_files_incremental_async, cleanup_project_async -from task_queue.task_status import task_status_store - -# 系统优化设置 -# os.environ["TOKENIZERS_PARALLELISM"] = "false" # 注释掉,使用优化的配置 - - -def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extension: str) -> tuple[str, int]: - """ - 获取带版本号的文件名,自动处理文件删除和版本递增 - - Args: - upload_dir: 上传目录路径 - name_without_ext: 不含扩展名的文件名 - file_extension: 文件扩展名(包含点号) - - Returns: - tuple[str, int]: (最终文件名, 版本号) - """ - # 检查原始文件是否存在 - original_file = os.path.join(upload_dir, name_without_ext + file_extension) - original_exists = os.path.exists(original_file) - - # 查找所有相关的版本化文件 - pattern = re.compile(re.escape(name_without_ext) + r'_(\d+)' + re.escape(file_extension) + r'$') - existing_versions = [] - files_to_delete = [] - - for filename in 
os.listdir(upload_dir): - # 检查是否是原始文件 - if filename == name_without_ext + file_extension: - files_to_delete.append(filename) - continue - - # 检查是否是版本化文件 - match = pattern.match(filename) - if match: - version_num = int(match.group(1)) - existing_versions.append(version_num) - files_to_delete.append(filename) - - # 如果没有任何相关文件存在,使用原始文件名(版本1) - if not original_exists and not existing_versions: - return name_without_ext + file_extension, 1 - - # 删除所有现有文件(原始文件和版本化文件) - for filename in files_to_delete: - file_to_delete = os.path.join(upload_dir, filename) - try: - os.remove(file_to_delete) - print(f"已删除文件: {file_to_delete}") - except OSError as e: - print(f"删除文件失败 {file_to_delete}: {e}") - - # 确定下一个版本号 - if existing_versions: - next_version = max(existing_versions) + 1 - else: - next_version = 2 - - # 生成带版本号的文件名 - versioned_filename = f"{name_without_ext}_{next_version}{file_extension}" - - return versioned_filename, next_version - - -# Models are now imported from utils module - - -# 语义检索请求模型 - - -# 编码请求和响应模型 -class EncodeRequest(BaseModel): - texts: List[str] = Field(..., description="要编码的文本列表") - batch_size: int = Field(default=32, description="批次大小", ge=1, le=128) - - -class EncodeResponse(BaseModel): - success: bool = Field(..., description="是否成功") - embeddings: List[List[float]] = Field(..., description="编码结果") - shape: List[int] = Field(..., description="embeddings 形状") - processing_time: float = Field(..., description="处理时间(秒)") - total_texts: int = Field(..., description="总文本数量") - error: Optional[str] = Field(None, description="错误信息") - - -# Custom version for qwen-agent messages - keep this function as it's specific to this app -def get_content_from_messages(messages: List[dict], tool_response: bool = True) -> str: - """Extract content from qwen-agent messages with special formatting""" - full_text = '' - content = [] - TOOL_CALL_S = '[TOOL_CALL]' - TOOL_RESULT_S = '[TOOL_RESPONSE]' - THOUGHT_S = '[THINK]' - ANSWER_S = '[ANSWER]' - - for msg in messages: - - 
if msg['role'] == ASSISTANT: - if msg.get('reasoning_content'): - assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages' - content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}') - if msg.get('content'): - assert isinstance(msg['content'], str), 'Now only supports text messages' - # 过滤掉流式输出中的不完整 tool_call 文本 - content_text = msg["content"] - - # 使用正则表达式替换不完整的 tool_call 模式为空字符串 - - # 匹配并替换不完整的 tool_call 模式 - content_text = re.sub(r' AsyncGenerator[str, None]: - """生成流式响应""" - accumulated_content = "" - chunk_id = 0 - try: - for response in agent.run(messages=messages): - previous_content = accumulated_content - accumulated_content = get_content_from_messages(response, tool_response=tool_response) - - # 计算新增的内容 - if accumulated_content.startswith(previous_content): - new_content = accumulated_content[len(previous_content):] - else: - new_content = accumulated_content - previous_content = "" - - # 只有当有新内容时才发送chunk - if new_content: - chunk_id += 1 - # 构造OpenAI格式的流式响应 - chunk_data = { - "id": f"chatcmpl-{chunk_id}", - "object": "chat.completion.chunk", - "created": int(__import__('time').time()), - "model": model, - "choices": [{ - "index": 0, - "delta": { - "content": new_content - }, - "finish_reason": None - }] - } - - yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n" - - # 发送最终完成标记 - final_chunk = { - "id": f"chatcmpl-{chunk_id + 1}", - "object": "chat.completion.chunk", - "created": int(__import__('time').time()), - "model": model, - "choices": [{ - "index": 0, - "delta": {}, - "finish_reason": "stop" - }] - } - yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n" - - # 发送结束标记 - yield "data: [DONE]\n\n" - except Exception as e: - import traceback - error_details = traceback.format_exc() - logger.error(f"Error in generate_stream_response: {str(e)}") - logger.error(f"Full traceback: {error_details}") - - error_data = { - "error": { - "message": f"Stream error: {str(e)}", - "type": "internal_error" - } - } - 
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n" - - -# Models are now imported from utils module - - -@app.post("/api/v1/files/process/async") -async def process_files_async_endpoint(request: QueueTaskRequest, authorization: Optional[str] = Header(None)): - """ - 异步处理文件的队列版本API - 与 /api/v1/files/process 功能相同,但使用队列异步处理 - - Args: - request: QueueTaskRequest containing dataset_id, files, system_prompt, mcp_settings, and queue options - authorization: Authorization header containing API key (Bearer ) - - Returns: - QueueTaskResponse: Processing result with task ID for tracking - """ - try: - dataset_id = request.dataset_id - if not dataset_id: - raise HTTPException(status_code=400, detail="dataset_id is required") - - # 估算处理时间(基于文件数量) - estimated_time = 0 - if request.upload_folder: - # 对于upload_folder,无法预先估算文件数量,使用默认时间 - estimated_time = 120 # 默认2分钟 - elif request.files: - total_files = sum(len(file_list) for file_list in request.files.values()) - estimated_time = max(30, total_files * 10) # 每个文件预估10秒,最少30秒 - - # 提交异步任务 - task_id = queue_manager.enqueue_multiple_files( - project_id=dataset_id, - file_paths=[], - original_filenames=[] - ) - - # 创建任务状态记录 - import uuid - task_id = str(uuid.uuid4()) - task_status_store.set_status( - task_id=task_id, - unique_id=dataset_id, - status="pending" - ) - - # 提交异步任务 - task = process_files_async( - dataset_id=dataset_id, - files=request.files, - upload_folder=request.upload_folder, - task_id=task_id - ) - - # 构建更详细的消息 - message = f"文件处理任务已提交到队列,项目ID: {dataset_id}" - if request.upload_folder: - group_count = len(request.upload_folder) - message += f",将从 {group_count} 个上传文件夹自动扫描文件" - elif request.files: - total_files = sum(len(file_list) for file_list in request.files.values()) - message += f",包含 {total_files} 个文件" - - return QueueTaskResponse( - success=True, - message=message, - dataset_id=dataset_id, - task_id=task_id, # 使用我们自己的task_id - task_status="pending", - estimated_processing_time=estimated_time - ) - - 
except HTTPException: - raise - except Exception as e: - print(f"Error submitting async file processing task: {str(e)}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - - -@app.post("/api/v1/files/process/incremental") -async def process_files_incremental_endpoint(request: IncrementalTaskRequest, authorization: Optional[str] = Header(None)): - """ - 增量处理文件的队列版本API - 支持添加和删除文件 - - Args: - request: IncrementalTaskRequest containing dataset_id, files_to_add, files_to_remove, system_prompt, mcp_settings, and queue options - authorization: Authorization header containing API key (Bearer ) - - Returns: - QueueTaskResponse: Processing result with task ID for tracking - """ - try: - dataset_id = request.dataset_id - if not dataset_id: - raise HTTPException(status_code=400, detail="dataset_id is required") - - # 验证至少有添加或删除操作 - if not request.files_to_add and not request.files_to_remove: - raise HTTPException(status_code=400, detail="At least one of files_to_add or files_to_remove must be provided") - - # 估算处理时间(基于文件数量) - estimated_time = 0 - total_add_files = sum(len(file_list) for file_list in (request.files_to_add or {}).values()) - total_remove_files = sum(len(file_list) for file_list in (request.files_to_remove or {}).values()) - total_files = total_add_files + total_remove_files - estimated_time = max(30, total_files * 10) # 每个文件预估10秒,最少30秒 - - # 创建任务状态记录 - import uuid - task_id = str(uuid.uuid4()) - task_status_store.set_status( - task_id=task_id, - unique_id=dataset_id, - status="pending" - ) - - # 提交增量异步任务 - task = process_files_incremental_async( - dataset_id=dataset_id, - files_to_add=request.files_to_add, - files_to_remove=request.files_to_remove, - system_prompt=request.system_prompt, - mcp_settings=request.mcp_settings, - task_id=task_id - ) - - return QueueTaskResponse( - success=True, - message=f"增量文件处理任务已提交到队列 - 添加 {total_add_files} 个文件,删除 {total_remove_files} 个文件,项目ID: {dataset_id}", - dataset_id=dataset_id, - 
task_id=task_id, - task_status="pending", - estimated_processing_time=estimated_time - ) - - except HTTPException: - raise - except Exception as e: - print(f"Error submitting incremental file processing task: {str(e)}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - - -@app.get("/api/v1/task/{task_id}/status") -async def get_task_status(task_id: str): - """获取任务状态 - 简单可靠""" - try: - status_data = task_status_store.get_status(task_id) - - if not status_data: - return { - "success": False, - "message": "任务不存在或已过期", - "task_id": task_id, - "status": "not_found" - } - - return { - "success": True, - "message": "任务状态获取成功", - "task_id": task_id, - "status": status_data["status"], - "unique_id": status_data["unique_id"], - "created_at": status_data["created_at"], - "updated_at": status_data["updated_at"], - "result": status_data.get("result"), - "error": status_data.get("error") - } - - except Exception as e: - print(f"Error getting task status: {str(e)}") - raise HTTPException(status_code=500, detail=f"获取任务状态失败: {str(e)}") - - -@app.delete("/api/v1/task/{task_id}") -async def delete_task(task_id: str): - """删除任务记录""" - try: - success = task_status_store.delete_status(task_id) - if success: - return { - "success": True, - "message": f"任务记录已删除: {task_id}", - "task_id": task_id - } - else: - return { - "success": False, - "message": f"任务记录不存在: {task_id}", - "task_id": task_id - } - except Exception as e: - print(f"Error deleting task: {str(e)}") - raise HTTPException(status_code=500, detail=f"删除任务记录失败: {str(e)}") - - -@app.get("/api/v1/tasks") -async def list_tasks(status: Optional[str] = None, dataset_id: Optional[str] = None, limit: int = 100): - """列出任务,支持筛选""" - try: - if status or dataset_id: - # 使用搜索功能 - tasks = task_status_store.search_tasks(status=status, unique_id=dataset_id, limit=limit) - else: - # 获取所有任务 - all_tasks = task_status_store.list_all() - tasks = list(all_tasks.values())[:limit] - - return { - "success": True, - 
"message": "任务列表获取成功", - "total_tasks": len(tasks), - "tasks": tasks, - "filters": { - "status": status, - "dataset_id": dataset_id, - "limit": limit - } - } - - except Exception as e: - print(f"Error listing tasks: {str(e)}") - raise HTTPException(status_code=500, detail=f"获取任务列表失败: {str(e)}") - - -@app.get("/api/v1/tasks/statistics") -async def get_task_statistics(): - """获取任务统计信息""" - try: - stats = task_status_store.get_statistics() - - return { - "success": True, - "message": "统计信息获取成功", - "statistics": stats - } - - except Exception as e: - print(f"Error getting statistics: {str(e)}") - raise HTTPException(status_code=500, detail=f"获取统计信息失败: {str(e)}") - - -@app.post("/api/v1/tasks/cleanup") -async def cleanup_tasks(older_than_days: int = 7): - """清理旧任务记录""" - try: - deleted_count = task_status_store.cleanup_old_tasks(older_than_days=older_than_days) - - return { - "success": True, - "message": f"已清理 {deleted_count} 条旧任务记录", - "deleted_count": deleted_count, - "older_than_days": older_than_days - } - - except Exception as e: - print(f"Error cleaning up tasks: {str(e)}") - raise HTTPException(status_code=500, detail=f"清理任务记录失败: {str(e)}") - - -@app.get("/api/v1/projects") -async def list_all_projects(): - """获取所有项目列表""" - try: - # 获取机器人项目(projects/robot) - robot_dir = "projects/robot" - robot_projects = [] - - if os.path.exists(robot_dir): - for item in os.listdir(robot_dir): - item_path = os.path.join(robot_dir, item) - if os.path.isdir(item_path): - try: - # 读取机器人配置文件 - config_path = os.path.join(item_path, "robot_config.json") - config_data = {} - if os.path.exists(config_path): - with open(config_path, 'r', encoding='utf-8') as f: - config_data = json.load(f) - - # 统计文件数量 - file_count = 0 - if os.path.exists(os.path.join(item_path, "dataset")): - for root, dirs, files in os.walk(os.path.join(item_path, "dataset")): - file_count += len(files) - - robot_projects.append({ - "id": item, - "name": config_data.get("name", item), - "type": "robot", - "status": 
config_data.get("status", "active"), - "file_count": file_count, - "config": config_data, - "created_at": os.path.getctime(item_path), - "updated_at": os.path.getmtime(item_path) - }) - except Exception as e: - print(f"Error reading robot project {item}: {str(e)}") - robot_projects.append({ - "id": item, - "name": item, - "type": "robot", - "status": "unknown", - "file_count": 0, - "created_at": os.path.getctime(item_path), - "updated_at": os.path.getmtime(item_path) - }) - - # 获取数据集(projects/data) - data_dir = "projects/data" - datasets = [] - - if os.path.exists(data_dir): - for item in os.listdir(data_dir): - item_path = os.path.join(data_dir, item) - if os.path.isdir(item_path): - try: - # 读取处理日志 - log_path = os.path.join(item_path, "processing_log.json") - log_data = {} - if os.path.exists(log_path): - with open(log_path, 'r', encoding='utf-8') as f: - log_data = json.load(f) - - # 统计文件数量 - file_count = 0 - for root, dirs, files in os.walk(item_path): - file_count += len([f for f in files if not f.endswith('.pkl')]) - - # 获取状态 - status = "active" - if log_data.get("status"): - status = log_data["status"] - elif os.path.exists(os.path.join(item_path, "processed")): - status = "completed" - - datasets.append({ - "id": item, - "name": f"数据集 - {item[:8]}...", - "type": "dataset", - "status": status, - "file_count": file_count, - "log_data": log_data, - "created_at": os.path.getctime(item_path), - "updated_at": os.path.getmtime(item_path) - }) - except Exception as e: - print(f"Error reading dataset {item}: {str(e)}") - datasets.append({ - "id": item, - "name": f"数据集 - {item[:8]}...", - "type": "dataset", - "status": "unknown", - "file_count": 0, - "created_at": os.path.getctime(item_path), - "updated_at": os.path.getmtime(item_path) - }) - - all_projects = robot_projects + datasets - - return { - "success": True, - "message": "项目列表获取成功", - "total_projects": len(all_projects), - "robot_projects": robot_projects, - "datasets": datasets, - "projects": all_projects # 
保持向后兼容 - } - - except Exception as e: - print(f"Error listing projects: {str(e)}") - raise HTTPException(status_code=500, detail=f"获取项目列表失败: {str(e)}") - - -@app.get("/api/v1/projects/robot") -async def list_robot_projects(): - """获取机器人项目列表""" - try: - response = await list_all_projects() - return { - "success": True, - "message": "机器人项目列表获取成功", - "total_projects": len(response["robot_projects"]), - "projects": response["robot_projects"] - } - except Exception as e: - print(f"Error listing robot projects: {str(e)}") - raise HTTPException(status_code=500, detail=f"获取机器人项目列表失败: {str(e)}") - - -@app.get("/api/v1/projects/datasets") -async def list_datasets(): - """获取数据集列表""" - try: - response = await list_all_projects() - return { - "success": True, - "message": "数据集列表获取成功", - "total_projects": len(response["datasets"]), - "projects": response["datasets"] - } - except Exception as e: - print(f"Error listing datasets: {str(e)}") - raise HTTPException(status_code=500, detail=f"获取数据集列表失败: {str(e)}") - - -@app.get("/api/v1/projects/{dataset_id}/tasks") -async def get_project_tasks(dataset_id: str): - """获取指定项目的所有任务""" - try: - tasks = task_status_store.get_by_unique_id(dataset_id) - - return { - "success": True, - "message": "项目任务获取成功", - "dataset_id": dataset_id, - "total_tasks": len(tasks), - "tasks": tasks - } - - except Exception as e: - print(f"Error getting project tasks: {str(e)}") - raise HTTPException(status_code=500, detail=f"获取项目任务失败: {str(e)}") - - -@app.post("/api/v1/files/{dataset_id}/cleanup/async") -async def cleanup_project_async_endpoint(dataset_id: str, remove_all: bool = False): - """异步清理项目文件""" - try: - task = cleanup_project_async(dataset_id=dataset_id, remove_all=remove_all) - - return { - "success": True, - "message": f"项目清理任务已提交到队列,项目ID: {dataset_id}", - "dataset_id": dataset_id, - "task_id": task.id, - "action": "remove_all" if remove_all else "cleanup_logs" - } - except Exception as e: - print(f"Error submitting cleanup task: {str(e)}") - raise 
HTTPException(status_code=500, detail=f"提交清理任务失败: {str(e)}") - - -@app.post("/api/v1/chat/completions") -async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)): - """ - Chat completions API similar to OpenAI, supports both streaming and non-streaming - - Args: - request: ChatRequest containing messages, model, optional dataset_ids list, required bot_id, system_prompt, mcp_settings, and files - authorization: Authorization header containing API key (Bearer ) - - Returns: - Union[ChatResponse, StreamingResponse]: Chat completion response or stream - - Notes: - - dataset_ids: 可选参数,当提供时必须是项目ID列表(单个项目也使用数组格式) - - bot_id: 必需参数,机器人ID - - 只有当 robot_type == "catalog_agent" 且 dataset_ids 为非空数组时才会创建机器人项目目录:projects/robot/{bot_id}/ - - robot_type 为其他值(包括默认的 "agent")时不创建任何目录 - - dataset_ids 为空数组 []、None 或未提供时不创建任何目录 - - 支持多知识库合并,自动处理文件夹重名冲突 - - Required Parameters: - - bot_id: str - 目标机器人ID - - messages: List[Message] - 对话消息列表 - Optional Parameters: - - dataset_ids: List[str] - 源知识库项目ID列表(单个项目也使用数组格式) - - robot_type: str - 机器人类型,默认为 "agent" - - Example: - {"bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]} - {"dataset_ids": ["project-123"], "bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]} - {"dataset_ids": ["project-123", "project-456"], "bot_id": "my-bot-002", "messages": [{"role": "user", "content": "Hello"}]} - {"dataset_ids": ["project-123"], "bot_id": "my-catalog-bot", "robot_type": "catalog_agent", "messages": [{"role": "user", "content": "Hello"}]} - """ - try: - # v1接口:从Authorization header中提取API key作为模型API密钥 - api_key = extract_api_key_from_auth(authorization) - - # 获取bot_id(必需参数) - bot_id = request.bot_id - if not bot_id: - raise HTTPException(status_code=400, detail="bot_id is required") - - # 创建项目目录(如果有dataset_ids且不是agent类型) - project_dir = create_project_directory(request.dataset_ids, bot_id, request.robot_type) - - # 收集额外参数作为 generate_cfg - exclude_fields = {'messages', 
'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings' ,'stream', 'robot_type', 'bot_id', 'user_identifier'} - generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields} - - # 处理消息 - messages = process_messages(request.messages, request.language) - - # 调用公共的agent创建和响应生成逻辑 - return await create_agent_and_generate_response( - bot_id=bot_id, - api_key=api_key, - messages=messages, - stream=request.stream, - tool_response=True, - model_name=request.model, - model_server=request.model_server, - language=request.language, - system_prompt=request.system_prompt, - mcp_settings=request.mcp_settings, - robot_type=request.robot_type, - project_dir=project_dir, - generate_cfg=generate_cfg, - user_identifier=request.user_identifier - ) - - except Exception as e: - import traceback - error_details = traceback.format_exc() - print(f"Error in chat_completions: {str(e)}") - print(f"Full traceback: {error_details}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - - -async def fetch_bot_config(bot_id: str) -> Dict[str, Any]: - """获取机器人配置从后端API""" - try: - backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai") - url = f"{backend_host}/v1/agent_bot_config/{bot_id}" - - auth_token = generate_v2_auth_token(bot_id) - headers = { - "content-type": "application/json", - "authorization": f"Bearer {auth_token}" - } - # 使用异步HTTP请求 - async with aiohttp.ClientSession() as session: - async with session.get(url, headers=headers, timeout=30) as response: - if response.status != 200: - raise HTTPException( - status_code=400, - detail=f"Failed to fetch bot config: API returned status code {response.status}" - ) - - # 解析响应 - response_data = await response.json() - - if not response_data.get("success"): - raise HTTPException( - status_code=400, - detail=f"Failed to fetch bot config: {response_data.get('message', 'Unknown error')}" - ) - - return response_data.get("data", {}) 
- - except aiohttp.ClientError as e: - raise HTTPException( - status_code=500, - detail=f"Failed to connect to backend API: {str(e)}" - ) - except Exception as e: - if isinstance(e, HTTPException): - raise - raise HTTPException( - status_code=500, - detail=f"Failed to fetch bot config: {str(e)}" - ) - - -def process_messages(messages: List[Message], language: Optional[str] = None) -> List[Dict[str, str]]: - """处理消息列表,包括[TOOL_CALL]|[TOOL_RESPONSE]|[ANSWER]分割和语言指令添加 - - 这是 get_content_from_messages 的逆运算,将包含 [TOOL_RESPONSE] 的消息重新组装回 - msg['role'] == 'function' 和 msg.get('function_call') 的格式。 - """ - processed_messages = [] - - # 收集所有ASSISTANT消息的索引 - assistant_indices = [i for i, msg in enumerate(messages) if msg.role == "assistant"] - total_assistant_messages = len(assistant_indices) - cutoff_point = max(0, total_assistant_messages - 5) - - # 处理每条消息 - for i, msg in enumerate(messages): - if msg.role == "assistant": - # 确定当前ASSISTANT消息在所有ASSISTANT消息中的位置(从0开始) - assistant_position = assistant_indices.index(i) - - # 使用正则表达式按照 [TOOL_CALL]|[TOOL_RESPONSE]|[ANSWER] 进行切割 - parts = re.split(r'\[(TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg.content) - - # 重新组装内容,根据消息位置决定处理方式 - filtered_content = "" - current_tag = None - is_recent_message = assistant_position >= cutoff_point # 最近10条消息 - - for i in range(0, len(parts)): - if i % 2 == 0: # 文本内容 - text = parts[i].strip() - if not text: - continue - - if current_tag == "TOOL_RESPONSE": - if is_recent_message: - # 最近10条ASSISTANT消息:保留完整TOOL_RESPONSE信息(使用简略模式) - if len(text) <= 500: - filtered_content += f"[TOOL_RESPONSE]\n{text}\n" - else: - # 截取前中后3段内容,每段250字 - first_part = text[:250] - middle_start = len(text) // 2 - 125 - middle_part = text[middle_start:middle_start + 250] - last_part = text[-250:] - - # 计算省略的字数 - omitted_count = len(text) - 750 - omitted_text = f"...此处省略{omitted_count}字..." 
- - # 拼接内容 - truncated_text = f"{first_part}\n{omitted_text}\n{middle_part}\n{omitted_text}\n{last_part}" - filtered_content += f"[TOOL_RESPONSE]\n{truncated_text}\n" - # 10条以上的消息:不保留TOOL_RESPONSE数据(完全跳过) - elif current_tag == "TOOL_CALL": - if is_recent_message: - # 最近10条ASSISTANT消息:保留TOOL_CALL信息 - filtered_content += f"[TOOL_CALL]\n{text}\n" - # 10条以上的消息:不保留TOOL_CALL数据(完全跳过) - elif current_tag == "ANSWER": - # 所有ASSISTANT消息都保留ANSWER数据 - filtered_content += f"[ANSWER]\n{text}\n" - else: - # 第一个标签之前的内容 - filtered_content += text + "\n" - else: # 标签 - current_tag = parts[i] - - # 取最终处理后的内容,去除首尾空白 - final_content = filtered_content.strip() - if final_content: - processed_messages.append({"role": msg.role, "content": final_content}) - else: - # 如果处理后为空,使用原内容 - processed_messages.append({"role": msg.role, "content": msg.content}) - else: - processed_messages.append({"role": msg.role, "content": msg.content}) - - # 逆运算:将包含 [TOOL_RESPONSE] 的消息重新组装回 msg['role'] == 'function' 和 msg.get('function_call') - # 这是 get_content_from_messages 的逆运算 - final_messages = [] - for msg in processed_messages: - if msg["role"] == ASSISTANT and "[TOOL_RESPONSE]" in msg["content"]: - # 分割消息内容 - parts = re.split(r'\[(TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg["content"]) - - current_tag = None - assistant_content = "" - function_calls = [] - tool_responses = [] - - for i in range(0, len(parts)): - if i % 2 == 0: # 文本内容 - text = parts[i].strip() - if not text: - continue - - if current_tag == "TOOL_RESPONSE": - # 解析 TOOL_RESPONSE 格式:[TOOL_RESPONSE] function_name\ncontent - lines = text.split('\n', 1) - function_name = lines[0].strip() if lines else "" - response_content = lines[1].strip() if len(lines) > 1 else "" - - tool_responses.append({ - "role": FUNCTION, - "name": function_name, - "content": response_content - }) - elif current_tag == "TOOL_CALL": - # 解析 TOOL_CALL 格式:[TOOL_CALL] function_name\narguments - lines = text.split('\n', 1) - function_name = lines[0].strip() if lines else "" - 
arguments = lines[1].strip() if len(lines) > 1 else "" - - function_calls.append({ - "name": function_name, - "arguments": arguments - }) - elif current_tag == "ANSWER": - assistant_content += text + "\n" - else: - # 第一个标签之前的内容也属于 assistant - assistant_content += text + "\n" - else: # 标签 - current_tag = parts[i] - - # 添加 assistant 消息(如果有内容) - if assistant_content.strip() or function_calls: - assistant_msg = {"role": ASSISTANT} - if assistant_content.strip(): - assistant_msg["content"] = assistant_content.strip() - if function_calls: - # 如果有多个 function_call,只取第一个(兼容原有逻辑) - assistant_msg["function_call"] = function_calls[0] - final_messages.append(assistant_msg) - - # 添加所有 tool_responses 作为 function 消息 - final_messages.extend(tool_responses) - else: - # 非 assistant 消息或不包含 [TOOL_RESPONSE] 的消息直接添加 - final_messages.append(msg) - - # 在最后一条消息的末尾追加回复语言 - if final_messages and language: - language_map = { - 'zh': '请用中文回复', - 'en': 'Please reply in English', - 'ja': '日本語で回答してください', - 'jp': '日本語で回答してください' - } - language_instruction = language_map.get(language.lower(), '') - if language_instruction: - # 在最后一条消息末尾追加语言指令 - final_messages[-1]['content'] = final_messages[-1]['content'] + f"\n\nlanguage:\n{language_instruction}。" - - return final_messages - - -def extract_guidelines_from_system_prompt(system_prompt: Optional[str]) -> tuple[str, str]: - """从system_prompt中提取```guideline内容并清理原提示词 - - Returns: - tuple[str, str]: (清理后的system_prompt, 提取的guidelines内容) - """ - if not system_prompt: - return "", "" - - # 使用正则表达式提取 ```guideline``` 包裹的内容 - pattern = r'```guideline\s*\n(.*?)\n```' - matches = re.findall(pattern, system_prompt, re.DOTALL) - - guidelines_text = "\n".join(matches).strip() - - # # 从原始system_prompt中删除 ```guideline``` 内容块 - # cleaned_prompt = re.sub(pattern, '', system_prompt, flags=re.DOTALL) - - # # 清理多余的空行 - # cleaned_prompt = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_prompt).strip() - - return guidelines_text - - -def format_messages_to_chat_history(messages: 
List[Dict[str, str]]) -> str: - """将messages格式化为纯文本聊天记录 - - Args: - messages: 消息列表 - - Returns: - str: 格式化的聊天记录 - """ - chat_history = [] - - for message in messages: - role = message.get('role', '') - content = message.get('content', '') - - if role == 'user': - chat_history.append(f"user: {content}") - elif role == 'assistant': - chat_history.append(f"assistant: {content}") - # 忽略其他角色(如function等) - - return "\n".join(chat_history) - - -async def call_guideline_llm(chat_history: str, guidelines_text: str, model_name: str, api_key: str, model_server: str) -> str: - """调用大语言模型处理guideline分析 - - Args: - chat_history: 聊天历史记录 - guidelines_text: 指导原则文本 - model_name: 模型名称 - api_key: API密钥 - model_server: 模型服务器地址 - - Returns: - str: 模型响应结果 - """ - # 读取guideline提示词模板 - try: - with open('./prompt/guideline_prompt.md', 'r', encoding='utf-8') as f: - guideline_template = f.read() - except Exception as e: - print(f"Error reading guideline prompt template: {e}") - return "" - - # 替换模板中的占位符 - system_prompt = guideline_template.replace('{chat_history}', chat_history).replace('{guidelines_text}', guidelines_text) - - # 配置LLM - llm_config = { - 'model': model_name, - 'api_key': api_key, - 'model_server': model_server, # 使用传入的model_server参数 - } - - # 创建LLM实例 - llm_instance = TextChatAtOAI(llm_config) - - # 调用模型 - messages = [{'role': 'user', 'content': system_prompt}] - - try: - # 设置stream=False来获取非流式响应 - response = llm_instance.chat(messages=messages, stream=False) - - # 处理响应 - if isinstance(response, list) and response: - # 如果返回的是Message列表,提取内容 - if hasattr(response[0], 'content'): - return response[0].content - elif isinstance(response[0], dict) and 'content' in response[0]: - return response[0]['content'] - - # 如果是字符串,直接返回 - if isinstance(response, str): - return response - - # 处理其他类型 - return str(response) if response else "" - - except Exception as e: - print(f"Error calling guideline LLM: {e}") - return "" - - -def _get_optimal_batch_size(guidelines_count: int) -> int: - 
"""根据guidelines数量决定最优批次数量(并发数)""" - if guidelines_count <= 10: - return 1 - elif guidelines_count <= 20: - return 2 - elif guidelines_count <= 30: - return 3 - else: - return 5 - - -async def process_guideline_batch( - guidelines_batch: List[str], - chat_history: str, - model_name: str, - api_key: str, - model_server: str -) -> str: - """处理单个guideline批次""" - try: - # 调用LLM分析这批guidelines - batch_guidelines_text = "\n".join(guidelines_batch) - batch_analysis = await call_guideline_llm(chat_history, batch_guidelines_text, model_name, api_key, model_server) - - return batch_analysis - except Exception as e: - print(f"Error processing guideline batch: {e}") - return "" - - - -async def create_agent_and_generate_response( - bot_id: str, - api_key: str, - messages: List[Dict[str, str]], - stream: bool, - tool_response: bool, - model_name: str, - model_server: str, - language: str, - system_prompt: Optional[str], - mcp_settings: Optional[List[Dict]], - robot_type: str, - project_dir: Optional[str] = None, - generate_cfg: Optional[Dict] = None, - user_identifier: Optional[str] = None -) -> Union[ChatResponse, StreamingResponse]: - """创建agent并生成响应的公共逻辑""" - if generate_cfg is None: - generate_cfg = {} - - # 1. 从system_prompt提取guideline内容 - guidelines_text = extract_guidelines_from_system_prompt(system_prompt) - print(f"guidelines_text: {guidelines_text}") - - # 2. 
如果有guideline内容,进行并发处理 - guideline_analysis = "" - if guidelines_text: - # 按换行符分割guidelines - guidelines_list = [g.strip() for g in guidelines_text.split('\n') if g.strip()] - guidelines_count = len(guidelines_list) - - if guidelines_count > 0: - # 获取最优批次数量(并发数) - batch_count = _get_optimal_batch_size(guidelines_count) - - # 计算每个批次应该包含多少条guideline - guidelines_per_batch = max(1, guidelines_count // batch_count) - - # 分批处理guidelines - batches = [] - for i in range(0, guidelines_count, guidelines_per_batch): - batch = guidelines_list[i:i + guidelines_per_batch] - batches.append(batch) - - # 确保批次数量不超过要求的并发数 - while len(batches) > batch_count: - # 将最后一个批次合并到倒数第二个批次 - batches[-2].extend(batches[-1]) - batches.pop() - - print(f"Processing {guidelines_count} guidelines in {len(batches)} batches with {batch_count} concurrent batches") - - # 准备chat_history - chat_history = format_messages_to_chat_history(messages) - - # 并发执行所有任务:guideline批次处理 + agent创建 - import asyncio - tasks = [] - - # 添加所有guideline批次任务 - for batch in batches: - task = process_guideline_batch( - guidelines_batch=batch, - chat_history=chat_history, - model_name=model_name, - api_key=api_key, - model_server=model_server - ) - tasks.append(task) - - # 添加agent创建任务 - agent_task = agent_manager.get_or_create_agent( - bot_id=bot_id, - project_dir=project_dir, - model_name=model_name, - api_key=api_key, - model_server=model_server, - generate_cfg=generate_cfg, - language=language, - system_prompt=system_prompt, - mcp_settings=mcp_settings, - robot_type=robot_type, - user_identifier=user_identifier - ) - tasks.append(agent_task) - - # 等待所有任务完成 - all_results = await asyncio.gather(*tasks, return_exceptions=True) - - # 处理结果:最后一个结果是agent,前面的是guideline批次结果 - agent = all_results[-1] # agent创建的结果 - batch_results = all_results[:-1] # guideline批次的结果 - - # 合并guideline分析结果 - valid_results = [] - for i, result in enumerate(batch_results): - if isinstance(result, Exception): - print(f"Guideline batch {i} failed: {result}") - 
continue - if result and result.strip(): - valid_results.append(result.strip()) - - if valid_results: - guideline_analysis = "\n\n".join(valid_results) - print(f"Merged guideline analysis result: {guideline_analysis}") - - # 将分析结果添加到最后一个消息的内容中 - if guideline_analysis and messages: - last_message = messages[-1] - if last_message.get('role') == 'user': - messages[-1]['content'] += f"\n\nActive Guidelines:\n{guideline_analysis}\nPlease follow these guidelines in your response." - else: - # 3. 从全局管理器获取或创建助手实例 - agent = await agent_manager.get_or_create_agent( - bot_id=bot_id, - project_dir=project_dir, - model_name=model_name, - api_key=api_key, - model_server=model_server, - generate_cfg=generate_cfg, - language=language, - system_prompt=system_prompt, - mcp_settings=mcp_settings, - robot_type=robot_type, - user_identifier=user_identifier - ) - - # 根据stream参数决定返回流式还是非流式响应 - if stream: - return StreamingResponse( - generate_stream_response(agent, messages, tool_response, model_name), - media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "Connection": "keep-alive"} - ) - else: - # 非流式响应 - final_responses = agent.run_nonstream(messages) - - if final_responses and len(final_responses) > 0: - # 使用 get_content_from_messages 处理响应,支持 tool_response 参数 - content = get_content_from_messages(final_responses, tool_response=tool_response) - - # 构造OpenAI格式的响应 - return ChatResponse( - choices=[{ - "index": 0, - "message": { - "role": "assistant", - "content": content - }, - "finish_reason": "stop" - }], - usage={ - "prompt_tokens": sum(len(msg.get("content", "")) for msg in messages), - "completion_tokens": len(content), - "total_tokens": sum(len(msg.get("content", "")) for msg in messages) + len(content) - } - ) - else: - raise HTTPException(status_code=500, detail="No response from agent") - - -def create_project_directory(dataset_ids: Optional[List[str]], bot_id: str, robot_type: str = "general_agent") -> Optional[str]: - """创建项目目录的公共逻辑""" - # 只有当 robot_type == 
"catalog_agent" 且 dataset_ids 不为空时才创建目录 - if robot_type != "catalog_agent" or not dataset_ids or len(dataset_ids) == 0: - return None - - try: - from utils.multi_project_manager import create_robot_project - return create_robot_project(dataset_ids, bot_id) - except Exception as e: - print(f"Error creating project directory: {e}") - return None - - -def extract_api_key_from_auth(authorization: Optional[str]) -> Optional[str]: - """从Authorization header中提取API key""" - if not authorization: - return None - - # 移除 "Bearer " 前缀 - if authorization.startswith("Bearer "): - return authorization[7:] - else: - return authorization - - -def generate_v2_auth_token(bot_id: str) -> str: - """生成v2接口的认证token""" - masterkey = os.getenv("MASTERKEY", "master") - token_input = f"{masterkey}:{bot_id}" - return hashlib.md5(token_input.encode()).hexdigest() - - -@app.post("/api/v2/chat/completions") -async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[str] = Header(None)): - """ - Chat completions API v2 with simplified parameters. - Only requires messages, stream, tool_response, bot_id, and language parameters. - Other parameters are fetched from the backend bot configuration API. 
- - Args: - request: ChatRequestV2 containing only essential parameters - authorization: Authorization header for authentication (different from v1) - - Returns: - Union[ChatResponse, StreamingResponse]: Chat completion response or stream - - Required Parameters: - - bot_id: str - 目标机器人ID - - messages: List[Message] - 对话消息列表 - - Optional Parameters: - - stream: bool - 是否流式输出,默认false - - tool_response: bool - 是否包含工具响应,默认false - - language: str - 回复语言,默认"ja" - - Authentication: - - Requires valid MD5 hash token: MD5(MASTERKEY:bot_id) - - Authorization header should contain: Bearer {token} - - Uses MD5 hash of MASTERKEY:bot_id for backend API authentication - - Optionally uses API key from bot config for model access - """ - try: - # 获取bot_id(必需参数) - bot_id = request.bot_id - if not bot_id: - raise HTTPException(status_code=400, detail="bot_id is required") - - # v2接口鉴权验证 - expected_token = generate_v2_auth_token(bot_id) - provided_token = extract_api_key_from_auth(authorization) - - if not provided_token: - raise HTTPException( - status_code=401, - detail="Authorization header is required for v2 API" - ) - - if provided_token != expected_token: - raise HTTPException( - status_code=403, - detail=f"Invalid authentication token. Expected: {expected_token[:8]}..., Provided: {provided_token[:8]}..." 
- ) - - # 从后端API获取机器人配置(使用v2的鉴权方式) - bot_config = await fetch_bot_config(bot_id) - - # v2接口:API密钥优先从后端配置获取,其次才从Authorization header获取 - # 注意:这里的Authorization header已经用于鉴权,不再作为API key使用 - api_key = bot_config.get("api_key") - - # 创建项目目录(从后端配置获取dataset_ids) - project_dir = create_project_directory( - bot_config.get("dataset_ids", []), - bot_id, - bot_config.get("robot_type", "general_agent") - ) - - # 处理消息 - messages = process_messages(request.messages, request.language) - - # 调用公共的agent创建和响应生成逻辑 - return await create_agent_and_generate_response( - bot_id=bot_id, - api_key=api_key, - messages=messages, - stream=request.stream, - tool_response=request.tool_response, - model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"), - model_server=bot_config.get("model_server", ""), - language=request.language or bot_config.get("language", "ja"), - system_prompt=bot_config.get("system_prompt"), - mcp_settings=bot_config.get("mcp_settings", []), - robot_type=bot_config.get("robot_type", "general_agent"), - project_dir=project_dir, - generate_cfg={}, # v2接口不传递额外的generate_cfg - user_identifier=request.user_identifier - ) - - except HTTPException: - raise - except Exception as e: - import traceback - error_details = traceback.format_exc() - print(f"Error in chat_completions_v2: {str(e)}") - print(f"Full traceback: {error_details}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - - -@app.post("/api/v1/upload") -async def upload_file(file: UploadFile = File(...), folder: Optional[str] = Form(None)): - """ - 文件上传API接口,上传文件到 ./projects/uploads/ 目录下 - - 可以指定自定义文件夹名,如果不指定则使用日期文件夹 - 指定文件夹时使用原始文件名并支持版本控制 - - Args: - file: 上传的文件 - folder: 可选的自定义文件夹名 - - Returns: - dict: 包含文件路径和文件夹信息的响应 - """ - try: - # 调试信息 - print(f"Received folder parameter: {folder}") - print(f"File received: {file.filename if file else 'None'}") - - # 确定上传文件夹 - if folder: - # 使用指定的自定义文件夹 - target_folder = folder - # 安全性检查:防止路径遍历攻击 - target_folder = 
os.path.basename(target_folder) - else: - # 获取当前日期并格式化为年月日 - current_date = datetime.now() - target_folder = current_date.strftime("%Y%m%d") - - # 创建上传目录 - upload_dir = os.path.join("projects", "uploads", target_folder) - os.makedirs(upload_dir, exist_ok=True) - - # 处理文件名 - if not file.filename: - raise HTTPException(status_code=400, detail="文件名不能为空") - - # 解析文件名和扩展名 - original_filename = file.filename - name_without_ext, file_extension = os.path.splitext(original_filename) - - # 根据是否指定文件夹决定命名策略 - if folder: - # 使用原始文件名,支持版本控制 - final_filename, version = get_versioned_filename(upload_dir, name_without_ext, file_extension) - file_path = os.path.join(upload_dir, final_filename) - - # 保存文件 - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - return { - "success": True, - "message": f"文件上传成功{' (版本: ' + str(version) + ')' if version > 1 else ''}", - "file_path": file_path, - "folder": target_folder, - "original_filename": original_filename, - "version": version - } - else: - # 使用UUID唯一文件名(原有逻辑) - unique_filename = f"{uuid.uuid4()}{file_extension}" - file_path = os.path.join(upload_dir, unique_filename) - - # 保存文件 - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - return { - "success": True, - "message": "文件上传成功", - "file_path": file_path, - "folder": target_folder, - "original_filename": original_filename - } - - except HTTPException: - raise - except Exception as e: - print(f"Error uploading file: {str(e)}") - raise HTTPException(status_code=500, detail=f"文件上传失败: {str(e)}") - - -@app.get("/api/health") -async def health_check(): - """Health check endpoint""" - return {"message": "Database Assistant API is running"} - - -@app.get("/api/v1/system/performance") -async def get_performance_stats(): - """获取系统性能统计信息""" - try: - # 获取agent管理器统计 - agent_stats = agent_manager.get_cache_stats() - - # 获取连接池统计(简化版) - pool_stats = { - "connection_pool": "active", - "max_connections_per_host": 100, - 
"max_connections_total": 500, - "keepalive_timeout": 30 - } - - # 获取文件缓存统计 - file_cache_stats = { - "cache_size": len(file_cache._cache) if hasattr(file_cache, '_cache') else 0, - "max_cache_size": file_cache.cache_size if hasattr(file_cache, 'cache_size') else 1000, - "ttl": file_cache.ttl if hasattr(file_cache, 'ttl') else 300 - } - - # 系统资源信息 - import psutil - system_stats = { - "cpu_count": multiprocessing.cpu_count(), - "memory_total_gb": round(psutil.virtual_memory().total / (1024**3), 2), - "memory_available_gb": round(psutil.virtual_memory().available / (1024**3), 2), - "memory_percent": psutil.virtual_memory().percent, - "disk_usage_percent": psutil.disk_usage('/').percent - } - - return { - "success": True, - "timestamp": int(time.time()), - "performance": { - "agent_manager": agent_stats, - "connection_pool": pool_stats, - "file_cache": file_cache_stats, - "system": system_stats - } - } - - except Exception as e: - print(f"Error getting performance stats: {str(e)}") - raise HTTPException(status_code=500, detail=f"获取性能统计失败: {str(e)}") - - -@app.post("/api/v1/system/optimize") -async def optimize_system(profile: str = "balanced"): - """应用系统优化配置""" - try: - # 应用优化配置 - config = apply_optimization_profile(profile) - - return { - "success": True, - "message": f"已应用 {profile} 优化配置", - "config": config - } - - except Exception as e: - print(f"Error applying optimization profile: {str(e)}") - raise HTTPException(status_code=500, detail=f"应用优化配置失败: {str(e)}") - - -@app.post("/api/v1/system/clear-cache") -async def clear_system_cache(cache_type: Optional[str] = None): - """清理系统缓存""" - try: - cleared_counts = {} - - if cache_type is None or cache_type == "agent": - # 清理agent缓存 - agent_count = agent_manager.clear_cache() - cleared_counts["agent_cache"] = agent_count - - if cache_type is None or cache_type == "file": - # 清理文件缓存 - if hasattr(file_cache, '_cache'): - file_count = len(file_cache._cache) - file_cache._cache.clear() - cleared_counts["file_cache"] = 
file_count - - return { - "success": True, - "message": f"已清理指定类型的缓存", - "cleared_counts": cleared_counts - } - - except Exception as e: - print(f"Error clearing cache: {str(e)}") - raise HTTPException(status_code=500, detail=f"清理缓存失败: {str(e)}") - - -@app.get("/api/v1/system/config") -async def get_system_config(): - """获取当前系统配置""" - try: - return { - "success": True, - "config": { - "max_cached_agents": max_cached_agents, - "shard_count": shard_count, - "tokenizer_parallelism": os.getenv("TOKENIZERS_PARALLELISM", "true"), - "max_connections_per_host": os.getenv("MAX_CONNECTIONS_PER_HOST", "100"), - "max_connections_total": os.getenv("MAX_CONNECTIONS_TOTAL", "500"), - "file_cache_size": os.getenv("FILE_CACHE_SIZE", "1000"), - "file_cache_ttl": os.getenv("FILE_CACHE_TTL", "300") - } - } - - except Exception as e: - print(f"Error getting system config: {str(e)}") - raise HTTPException(status_code=500, detail=f"获取系统配置失败: {str(e)}") - - -@app.post("/system/remove-project-cache") -async def remove_project_cache(dataset_id: str): - """移除特定项目的缓存""" - try: - removed_count = agent_manager.remove_cache_by_unique_id(dataset_id) - if removed_count > 0: - return {"message": f"项目缓存移除成功: {dataset_id}", "removed_count": removed_count} - else: - return {"message": f"未找到项目缓存: {dataset_id}", "removed_count": 0} - except Exception as e: - raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}") - - -@app.get("/api/v1/files/{dataset_id}/status") -async def get_files_processing_status(dataset_id: str): - """获取项目的文件处理状态""" - try: - # Load processed files log - processed_log = load_processed_files_log(dataset_id) - - # Get project directory info - project_dir = os.path.join("projects", "data", dataset_id) - project_exists = os.path.exists(project_dir) - - # Collect document.txt files - document_files = [] - if project_exists: - for root, dirs, files in os.walk(project_dir): - for file in files: - if file == "document.txt": - document_files.append(os.path.join(root, file)) - - 
return { - "dataset_id": dataset_id, - "project_exists": project_exists, - "processed_files_count": len(processed_log), - "processed_files": processed_log, - "document_files_count": len(document_files), - "document_files": document_files, - "log_file_exists": os.path.exists(os.path.join("projects", "data", dataset_id, "processed_files.json")) - } - except Exception as e: - raise HTTPException(status_code=500, detail=f"获取文件处理状态失败: {str(e)}") - - -@app.post("/api/v1/files/{dataset_id}/reset") -async def reset_files_processing(dataset_id: str): - """重置项目的文件处理状态,删除处理日志和所有文件""" - try: - project_dir = os.path.join("projects", "data", dataset_id) - log_file = os.path.join("projects", "data", dataset_id, "processed_files.json") - - # Load processed log to know what files to remove - processed_log = load_processed_files_log(dataset_id) - - removed_files = [] - # Remove all processed files and their dataset directories - for file_hash, file_info in processed_log.items(): - # Remove local file in files directory - if 'local_path' in file_info: - if remove_file_or_directory(file_info['local_path']): - removed_files.append(file_info['local_path']) - - # Handle new key-based structure first - if 'key' in file_info: - # Remove dataset directory by key - key = file_info['key'] - if remove_dataset_directory_by_key(dataset_id, key): - removed_files.append(f"dataset/{key}") - elif 'filename' in file_info: - # Fallback to old filename-based structure - filename_without_ext = os.path.splitext(file_info['filename'])[0] - dataset_dir = os.path.join("projects", "data", dataset_id, "dataset", filename_without_ext) - if remove_file_or_directory(dataset_dir): - removed_files.append(dataset_dir) - - # Also remove any specific dataset path if exists (fallback) - if 'dataset_path' in file_info: - if remove_file_or_directory(file_info['dataset_path']): - removed_files.append(file_info['dataset_path']) - - # Remove the log file - if remove_file_or_directory(log_file): - 
removed_files.append(log_file) - - # Remove the entire files directory - files_dir = os.path.join(project_dir, "files") - if remove_file_or_directory(files_dir): - removed_files.append(files_dir) - - # Also remove the entire dataset directory (clean up any remaining files) - dataset_dir = os.path.join(project_dir, "dataset") - if remove_file_or_directory(dataset_dir): - removed_files.append(dataset_dir) - - # Remove README.md if exists - readme_file = os.path.join(project_dir, "README.md") - if remove_file_or_directory(readme_file): - removed_files.append(readme_file) - - return { - "message": f"文件处理状态重置成功: {dataset_id}", - "removed_files_count": len(removed_files), - "removed_files": removed_files - } - except Exception as e: - raise HTTPException(status_code=500, detail=f"重置文件处理状态失败: {str(e)}") - - -@app.post("/api/v1/embedding/encode", response_model=EncodeResponse) -async def encode_texts(request: EncodeRequest): - """ - 文本编码 API - - Args: - request: 包含 texts 和 batch_size 的编码请求 - - Returns: - 编码结果 - """ - try: - model_manager = get_model_manager() - - if not request.texts: - return EncodeResponse( - success=False, - embeddings=[], - shape=[0, 0], - processing_time=0.0, - total_texts=0, - error="texts 不能为空" - ) - - start_time = time.time() - - # 使用模型管理器编码文本 - embeddings = await model_manager.encode_texts( - request.texts, - batch_size=request.batch_size - ) - - processing_time = time.time() - start_time - - # 转换为列表格式 - embeddings_list = embeddings.tolist() - - return EncodeResponse( - success=True, - embeddings=embeddings_list, - shape=list(embeddings.shape), - processing_time=processing_time, - total_texts=len(request.texts) - ) - - except Exception as e: - logger.error(f"文本编码 API 错误: {e}") - return EncodeResponse( - success=False, - embeddings=[], - shape=[0, 0], - processing_time=0.0, - total_texts=len(request.texts) if request else 0, - error=f"编码失败: {str(e)}" - ) +# Include all route modules +app.include_router(chat.router) +app.include_router(files.router) 
+app.include_router(projects.router) +app.include_router(system.router) # 注册文件管理API路由 app.include_router(file_manager_router) diff --git a/routes/__init__.py b/routes/__init__.py new file mode 100644 index 0000000..8ca12ba --- /dev/null +++ b/routes/__init__.py @@ -0,0 +1 @@ +# Routes package initialization \ No newline at end of file diff --git a/routes/chat.py b/routes/chat.py new file mode 100644 index 0000000..7cee7c8 --- /dev/null +++ b/routes/chat.py @@ -0,0 +1,429 @@ +import json +import os +import asyncio +from typing import Union, Optional +from fastapi import APIRouter, HTTPException, Header +from fastapi.responses import StreamingResponse +from pydantic import BaseModel + +from utils import ( + Message, ChatRequest, ChatResponse, + get_global_agent_manager, init_global_sharded_agent_manager +) +from utils.api_models import ChatRequestV2 +from utils.fastapi_utils import ( + process_messages, extract_guidelines_from_system_prompt, format_messages_to_chat_history, + create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config, + call_guideline_llm, _get_optimal_batch_size, process_guideline_batch, get_content_from_messages +) + +router = APIRouter() + +# 初始化全局助手管理器 +agent_manager = init_global_sharded_agent_manager( + max_cached_agents=int(os.getenv("MAX_CACHED_AGENTS", "50")), + shard_count=int(os.getenv("SHARD_COUNT", "16")) +) + + +async def generate_stream_response(agent, messages, tool_response: bool, model: str): + """生成流式响应""" + accumulated_content = "" + chunk_id = 0 + try: + for response in agent.run(messages=messages): + previous_content = accumulated_content + accumulated_content = get_content_from_messages(response, tool_response=tool_response) + + # 计算新增的内容 + if accumulated_content.startswith(previous_content): + new_content = accumulated_content[len(previous_content):] + else: + new_content = accumulated_content + previous_content = "" + + # 只有当有新内容时才发送chunk + if new_content: + chunk_id += 1 + # 构造OpenAI格式的流式响应 
+ chunk_data = { + "id": f"chatcmpl-{chunk_id}", + "object": "chat.completion.chunk", + "created": int(__import__('time').time()), + "model": model, + "choices": [{ + "index": 0, + "delta": { + "content": new_content + }, + "finish_reason": None + }] + } + + yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n" + + # 发送最终完成标记 + final_chunk = { + "id": f"chatcmpl-{chunk_id + 1}", + "object": "chat.completion.chunk", + "created": int(__import__('time').time()), + "model": model, + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": "stop" + }] + } + yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n" + + # 发送结束标记 + yield "data: [DONE]\n\n" + except Exception as e: + import traceback + error_details = traceback.format_exc() + from utils.logger import logger + logger.error(f"Error in generate_stream_response: {str(e)}") + logger.error(f"Full traceback: {error_details}") + + error_data = { + "error": { + "message": f"Stream error: {str(e)}", + "type": "internal_error" + } + } + yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n" + + +async def create_agent_and_generate_response( + bot_id: str, + api_key: str, + messages: list, + stream: bool, + tool_response: bool, + model_name: str, + model_server: str, + language: str, + system_prompt: Optional[str], + mcp_settings: Optional[list], + robot_type: str, + project_dir: Optional[str] = None, + generate_cfg: Optional[dict] = None, + user_identifier: Optional[str] = None +) -> Union[ChatResponse, StreamingResponse]: + """创建agent并生成响应的公共逻辑""" + if generate_cfg is None: + generate_cfg = {} + + # 1. 从system_prompt提取guideline内容 + system_prompt, guidelines_text = extract_guidelines_from_system_prompt(system_prompt) + print(f"guidelines_text: {guidelines_text}") + + # 2. 
如果有guideline内容,进行并发处理 + guideline_analysis = "" + if guidelines_text: + # 按换行符分割guidelines + guidelines_list = [g.strip() for g in guidelines_text.split('\n') if g.strip()] + guidelines_count = len(guidelines_list) + + if guidelines_count > 0: + # 获取最优批次数量(并发数) + batch_count = _get_optimal_batch_size(guidelines_count) + + # 计算每个批次应该包含多少条guideline + guidelines_per_batch = max(1, guidelines_count // batch_count) + + # 分批处理guidelines + batches = [] + for i in range(0, guidelines_count, guidelines_per_batch): + batch = guidelines_list[i:i + guidelines_per_batch] + batches.append(batch) + + # 确保批次数量不超过要求的并发数 + while len(batches) > batch_count: + # 将最后一个批次合并到倒数第二个批次 + batches[-2].extend(batches[-1]) + batches.pop() + + print(f"Processing {guidelines_count} guidelines in {len(batches)} batches with {batch_count} concurrent batches") + + # 准备chat_history + chat_history = format_messages_to_chat_history(messages) + + # 并发执行所有任务:guideline批次处理 + agent创建 + tasks = [] + + # 添加所有guideline批次任务 + for batch in batches: + task = process_guideline_batch( + guidelines_batch=batch, + chat_history=chat_history, + model_name=model_name, + api_key=api_key, + model_server=model_server + ) + tasks.append(task) + + # 添加agent创建任务 + agent_task = agent_manager.get_or_create_agent( + bot_id=bot_id, + project_dir=project_dir, + model_name=model_name, + api_key=api_key, + model_server=model_server, + generate_cfg=generate_cfg, + language=language, + system_prompt=system_prompt, + mcp_settings=mcp_settings, + robot_type=robot_type, + user_identifier=user_identifier + ) + tasks.append(agent_task) + + # 等待所有任务完成 + all_results = await asyncio.gather(*tasks, return_exceptions=True) + + # 处理结果:最后一个结果是agent,前面的是guideline批次结果 + agent = all_results[-1] # agent创建的结果 + batch_results = all_results[:-1] # guideline批次的结果 + + # 合并guideline分析结果 + valid_results = [] + for i, result in enumerate(batch_results): + if isinstance(result, Exception): + print(f"Guideline batch {i} failed: {result}") + continue + if 
result and result.strip(): + valid_results.append(result.strip()) + + if valid_results: + guideline_analysis = "\n\n".join(valid_results) + print(f"Merged guideline analysis result: {guideline_analysis}") + + # 将分析结果添加到最后一个消息的内容中 + if guideline_analysis and messages: + last_message = messages[-1] + if last_message.get('role') == 'user': + messages[-1]['content'] += f"\n\nActive Guidelines:\n{guideline_analysis}\nPlease follow these guidelines in your response." + else: + # 3. 从全局管理器获取或创建助手实例 + agent = await agent_manager.get_or_create_agent( + bot_id=bot_id, + project_dir=project_dir, + model_name=model_name, + api_key=api_key, + model_server=model_server, + generate_cfg=generate_cfg, + language=language, + system_prompt=system_prompt, + mcp_settings=mcp_settings, + robot_type=robot_type, + user_identifier=user_identifier + ) + + # 根据stream参数决定返回流式还是非流式响应 + if stream: + return StreamingResponse( + generate_stream_response(agent, messages, tool_response, model_name), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"} + ) + else: + # 非流式响应 + final_responses = agent.run_nonstream(messages) + + if final_responses and len(final_responses) > 0: + # 使用 get_content_from_messages 处理响应,支持 tool_response 参数 + content = get_content_from_messages(final_responses, tool_response=tool_response) + + # 构造OpenAI格式的响应 + return ChatResponse( + choices=[{ + "index": 0, + "message": { + "role": "assistant", + "content": content + }, + "finish_reason": "stop" + }], + usage={ + "prompt_tokens": sum(len(msg.get("content", "")) for msg in messages), + "completion_tokens": len(content), + "total_tokens": sum(len(msg.get("content", "")) for msg in messages) + len(content) + } + ) + else: + raise HTTPException(status_code=500, detail="No response from agent") + + +@router.post("/api/v1/chat/completions") +async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)): + """ + Chat completions API similar to OpenAI, 
supports both streaming and non-streaming + + Args: + request: ChatRequest containing messages, model, optional dataset_ids list, required bot_id, system_prompt, mcp_settings, and files + authorization: Authorization header containing API key (Bearer ) + + Returns: + Union[ChatResponse, StreamingResponse]: Chat completion response or stream + + Notes: + - dataset_ids: 可选参数,当提供时必须是项目ID列表(单个项目也使用数组格式) + - bot_id: 必需参数,机器人ID + - 只有当 robot_type == "catalog_agent" 且 dataset_ids 为非空数组时才会创建机器人项目目录:projects/robot/{bot_id}/ + - robot_type 为其他值(包括默认的 "agent")时不创建任何目录 + - dataset_ids 为空数组 []、None 或未提供时不创建任何目录 + - 支持多知识库合并,自动处理文件夹重名冲突 + + Required Parameters: + - bot_id: str - 目标机器人ID + - messages: List[Message] - 对话消息列表 + Optional Parameters: + - dataset_ids: List[str] - 源知识库项目ID列表(单个项目也使用数组格式) + - robot_type: str - 机器人类型,默认为 "agent" + + Example: + {"bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]} + {"dataset_ids": ["project-123"], "bot_id": "my-bot-001", "messages": [{"role": "user", "content": "Hello"}]} + {"dataset_ids": ["project-123", "project-456"], "bot_id": "my-bot-002", "messages": [{"role": "user", "content": "Hello"}]} + {"dataset_ids": ["project-123"], "bot_id": "my-catalog-bot", "robot_type": "catalog_agent", "messages": [{"role": "user", "content": "Hello"}]} + """ + try: + # v1接口:从Authorization header中提取API key作为模型API密钥 + api_key = extract_api_key_from_auth(authorization) + + # 获取bot_id(必需参数) + bot_id = request.bot_id + if not bot_id: + raise HTTPException(status_code=400, detail="bot_id is required") + + # 创建项目目录(如果有dataset_ids且不是agent类型) + project_dir = create_project_directory(request.dataset_ids, bot_id, request.robot_type) + + # 收集额外参数作为 generate_cfg + exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings' ,'stream', 'robot_type', 'bot_id', 'user_identifier'} + generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields} + + # 
处理消息 + messages = process_messages(request.messages, request.language) + + # 调用公共的agent创建和响应生成逻辑 + return await create_agent_and_generate_response( + bot_id=bot_id, + api_key=api_key, + messages=messages, + stream=request.stream, + tool_response=True, + model_name=request.model, + model_server=request.model_server, + language=request.language, + system_prompt=request.system_prompt, + mcp_settings=request.mcp_settings, + robot_type=request.robot_type, + project_dir=project_dir, + generate_cfg=generate_cfg, + user_identifier=request.user_identifier + ) + + except Exception as e: + import traceback + error_details = traceback.format_exc() + print(f"Error in chat_completions: {str(e)}") + print(f"Full traceback: {error_details}") + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + + +@router.post("/api/v2/chat/completions") +async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[str] = Header(None)): + """ + Chat completions API v2 with simplified parameters. + Only requires messages, stream, tool_response, bot_id, and language parameters. + Other parameters are fetched from the backend bot configuration API. 
+ + Args: + request: ChatRequestV2 containing only essential parameters + authorization: Authorization header for authentication (different from v1) + + Returns: + Union[ChatResponse, StreamingResponse]: Chat completion response or stream + + Required Parameters: + - bot_id: str - 目标机器人ID + - messages: List[Message] - 对话消息列表 + + Optional Parameters: + - stream: bool - 是否流式输出,默认false + - tool_response: bool - 是否包含工具响应,默认false + - language: str - 回复语言,默认"ja" + + Authentication: + - Requires valid MD5 hash token: MD5(MASTERKEY:bot_id) + - Authorization header should contain: Bearer {token} + - Uses MD5 hash of MASTERKEY:bot_id for backend API authentication + - Optionally uses API key from bot config for model access + """ + try: + # 获取bot_id(必需参数) + bot_id = request.bot_id + if not bot_id: + raise HTTPException(status_code=400, detail="bot_id is required") + + # v2接口鉴权验证 + expected_token = generate_v2_auth_token(bot_id) + provided_token = extract_api_key_from_auth(authorization) + + if not provided_token: + raise HTTPException( + status_code=401, + detail="Authorization header is required for v2 API" + ) + + if provided_token != expected_token: + raise HTTPException( + status_code=403, + detail=f"Invalid authentication token. Expected: {expected_token[:8]}..., Provided: {provided_token[:8]}..." 
+ ) + + # 从后端API获取机器人配置(使用v2的鉴权方式) + bot_config = await fetch_bot_config(bot_id) + + # v2接口:API密钥优先从后端配置获取,其次才从Authorization header获取 + # 注意:这里的Authorization header已经用于鉴权,不再作为API key使用 + api_key = bot_config.get("api_key") + + # 创建项目目录(从后端配置获取dataset_ids) + project_dir = create_project_directory( + bot_config.get("dataset_ids", []), + bot_id, + bot_config.get("robot_type", "general_agent") + ) + + # 处理消息 + messages = process_messages(request.messages, request.language) + + # 调用公共的agent创建和响应生成逻辑 + return await create_agent_and_generate_response( + bot_id=bot_id, + api_key=api_key, + messages=messages, + stream=request.stream, + tool_response=request.tool_response, + model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"), + model_server=bot_config.get("model_server", ""), + language=request.language or bot_config.get("language", "ja"), + system_prompt=bot_config.get("system_prompt"), + mcp_settings=bot_config.get("mcp_settings", []), + robot_type=bot_config.get("robot_type", "general_agent"), + project_dir=project_dir, + generate_cfg={}, # v2接口不传递额外的generate_cfg + user_identifier=request.user_identifier + ) + + except HTTPException: + raise + except Exception as e: + import traceback + error_details = traceback.format_exc() + print(f"Error in chat_completions_v2: {str(e)}") + print(f"Full traceback: {error_details}") + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") diff --git a/routes/files.py b/routes/files.py new file mode 100644 index 0000000..b65f6c3 --- /dev/null +++ b/routes/files.py @@ -0,0 +1,467 @@ +import os +import uuid +import shutil +from datetime import datetime +from typing import Optional, List +from fastapi import APIRouter, HTTPException, Header, UploadFile, File, Form +from pydantic import BaseModel + +from utils import ( + DatasetRequest, QueueTaskRequest, IncrementalTaskRequest, QueueTaskResponse, + load_processed_files_log, remove_file_or_directory, remove_dataset_directory_by_key +) +from 
utils.fastapi_utils import get_versioned_filename +from task_queue.manager import queue_manager +from task_queue.integration_tasks import process_files_async, process_files_incremental_async, cleanup_project_async +from task_queue.task_status import task_status_store + +router = APIRouter() + + +@router.post("/api/v1/files/process/async") +async def process_files_async_endpoint(request: QueueTaskRequest, authorization: Optional[str] = Header(None)): + """ + 异步处理文件的队列版本API + 与 /api/v1/files/process 功能相同,但使用队列异步处理 + + Args: + request: QueueTaskRequest containing dataset_id, files, system_prompt, mcp_settings, and queue options + authorization: Authorization header containing API key (Bearer ) + + Returns: + QueueTaskResponse: Processing result with task ID for tracking + """ + try: + dataset_id = request.dataset_id + if not dataset_id: + raise HTTPException(status_code=400, detail="dataset_id is required") + + # 估算处理时间(基于文件数量) + estimated_time = 0 + if request.upload_folder: + # 对于upload_folder,无法预先估算文件数量,使用默认时间 + estimated_time = 120 # 默认2分钟 + elif request.files: + total_files = sum(len(file_list) for file_list in request.files.values()) + estimated_time = max(30, total_files * 10) # 每个文件预估10秒,最少30秒 + + # 创建任务状态记录 + import uuid + task_id = str(uuid.uuid4()) + task_status_store.set_status( + task_id=task_id, + unique_id=dataset_id, + status="pending" + ) + + # 提交异步任务 + task = process_files_async( + dataset_id=dataset_id, + files=request.files, + upload_folder=request.upload_folder, + task_id=task_id + ) + + # 构建更详细的消息 + message = f"文件处理任务已提交到队列,项目ID: {dataset_id}" + if request.upload_folder: + group_count = len(request.upload_folder) + message += f",将从 {group_count} 个上传文件夹自动扫描文件" + elif request.files: + total_files = sum(len(file_list) for file_list in request.files.values()) + message += f",包含 {total_files} 个文件" + + return QueueTaskResponse( + success=True, + message=message, + dataset_id=dataset_id, + task_id=task_id, # 使用我们自己的task_id + task_status="pending", + 
estimated_processing_time=estimated_time + ) + + except HTTPException: + raise + except Exception as e: + print(f"Error submitting async file processing task: {str(e)}") + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + + +@router.post("/api/v1/files/process/incremental") +async def process_files_incremental_endpoint(request: IncrementalTaskRequest, authorization: Optional[str] = Header(None)): + """ + 增量处理文件的队列版本API - 支持添加和删除文件 + + Args: + request: IncrementalTaskRequest containing dataset_id, files_to_add, files_to_remove, system_prompt, mcp_settings, and queue options + authorization: Authorization header containing API key (Bearer ) + + Returns: + QueueTaskResponse: Processing result with task ID for tracking + """ + try: + dataset_id = request.dataset_id + if not dataset_id: + raise HTTPException(status_code=400, detail="dataset_id is required") + + # 验证至少有添加或删除操作 + if not request.files_to_add and not request.files_to_remove: + raise HTTPException(status_code=400, detail="At least one of files_to_add or files_to_remove must be provided") + + # 估算处理时间(基于文件数量) + estimated_time = 0 + total_add_files = sum(len(file_list) for file_list in (request.files_to_add or {}).values()) + total_remove_files = sum(len(file_list) for file_list in (request.files_to_remove or {}).values()) + total_files = total_add_files + total_remove_files + estimated_time = max(30, total_files * 10) # 每个文件预估10秒,最少30秒 + + # 创建任务状态记录 + import uuid + task_id = str(uuid.uuid4()) + task_status_store.set_status( + task_id=task_id, + unique_id=dataset_id, + status="pending" + ) + + # 提交增量异步任务 + task = process_files_incremental_async( + dataset_id=dataset_id, + files_to_add=request.files_to_add, + files_to_remove=request.files_to_remove, + system_prompt=request.system_prompt, + mcp_settings=request.mcp_settings, + task_id=task_id + ) + + return QueueTaskResponse( + success=True, + message=f"增量文件处理任务已提交到队列 - 添加 {total_add_files} 个文件,删除 {total_remove_files} 个文件,项目ID: 
{dataset_id}", + dataset_id=dataset_id, + task_id=task_id, + task_status="pending", + estimated_processing_time=estimated_time + ) + + except HTTPException: + raise + except Exception as e: + print(f"Error submitting incremental file processing task: {str(e)}") + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + + +@router.get("/api/v1/files/{dataset_id}/status") +async def get_files_processing_status(dataset_id: str): + """获取项目的文件处理状态""" + try: + # Load processed files log + processed_log = load_processed_files_log(dataset_id) + + # Get project directory info + project_dir = os.path.join("projects", "data", dataset_id) + project_exists = os.path.exists(project_dir) + + # Collect document.txt files + document_files = [] + if project_exists: + for root, dirs, files in os.walk(project_dir): + for file in files: + if file == "document.txt": + document_files.append(os.path.join(root, file)) + + return { + "dataset_id": dataset_id, + "project_exists": project_exists, + "processed_files_count": len(processed_log), + "processed_files": processed_log, + "document_files_count": len(document_files), + "document_files": document_files, + "log_file_exists": os.path.exists(os.path.join("projects", "data", dataset_id, "processed_files.json")) + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"获取文件处理状态失败: {str(e)}") + + +@router.post("/api/v1/files/{dataset_id}/reset") +async def reset_files_processing(dataset_id: str): + """重置项目的文件处理状态,删除处理日志和所有文件""" + try: + project_dir = os.path.join("projects", "data", dataset_id) + log_file = os.path.join("projects", "data", dataset_id, "processed_files.json") + + # Load processed log to know what files to remove + processed_log = load_processed_files_log(dataset_id) + + removed_files = [] + # Remove all processed files and their dataset directories + for file_hash, file_info in processed_log.items(): + # Remove local file in files directory + if 'local_path' in file_info: + if 
remove_file_or_directory(file_info['local_path']): + removed_files.append(file_info['local_path']) + + # Handle new key-based structure first + if 'key' in file_info: + # Remove dataset directory by key + key = file_info['key'] + if remove_dataset_directory_by_key(dataset_id, key): + removed_files.append(f"dataset/{key}") + elif 'filename' in file_info: + # Fallback to old filename-based structure + filename_without_ext = os.path.splitext(file_info['filename'])[0] + dataset_dir = os.path.join("projects", "data", dataset_id, "dataset", filename_without_ext) + if remove_file_or_directory(dataset_dir): + removed_files.append(dataset_dir) + + # Also remove any specific dataset path if exists (fallback) + if 'dataset_path' in file_info: + if remove_file_or_directory(file_info['dataset_path']): + removed_files.append(file_info['dataset_path']) + + # Remove the log file + if remove_file_or_directory(log_file): + removed_files.append(log_file) + + # Remove the entire files directory + files_dir = os.path.join(project_dir, "files") + if remove_file_or_directory(files_dir): + removed_files.append(files_dir) + + # Also remove the entire dataset directory (clean up any remaining files) + dataset_dir = os.path.join(project_dir, "dataset") + if remove_file_or_directory(dataset_dir): + removed_files.append(dataset_dir) + + # Remove README.md if exists + readme_file = os.path.join(project_dir, "README.md") + if remove_file_or_directory(readme_file): + removed_files.append(readme_file) + + return { + "message": f"文件处理状态重置成功: {dataset_id}", + "removed_files_count": len(removed_files), + "removed_files": removed_files + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"重置文件处理状态失败: {str(e)}") + + +@router.post("/api/v1/files/{dataset_id}/cleanup/async") +async def cleanup_project_async_endpoint(dataset_id: str, remove_all: bool = False): + """异步清理项目文件""" + try: + task = cleanup_project_async(dataset_id=dataset_id, remove_all=remove_all) + + return { + 
"success": True, + "message": f"项目清理任务已提交到队列,项目ID: {dataset_id}", + "dataset_id": dataset_id, + "task_id": task.id, + "action": "remove_all" if remove_all else "cleanup_logs" + } + except Exception as e: + print(f"Error submitting cleanup task: {str(e)}") + raise HTTPException(status_code=500, detail=f"提交清理任务失败: {str(e)}") + + +@router.post("/api/v1/upload") +async def upload_file(file: UploadFile = File(...), folder: Optional[str] = Form(None)): + """ + 文件上传API接口,上传文件到 ./projects/uploads/ 目录下 + + 可以指定自定义文件夹名,如果不指定则使用日期文件夹 + 指定文件夹时使用原始文件名并支持版本控制 + + Args: + file: 上传的文件 + folder: 可选的自定义文件夹名 + + Returns: + dict: 包含文件路径和文件夹信息的响应 + """ + try: + # 调试信息 + print(f"Received folder parameter: {folder}") + print(f"File received: {file.filename if file else 'None'}") + + # 确定上传文件夹 + if folder: + # 使用指定的自定义文件夹 + target_folder = folder + # 安全性检查:防止路径遍历攻击 + target_folder = os.path.basename(target_folder) + else: + # 获取当前日期并格式化为年月日 + current_date = datetime.now() + target_folder = current_date.strftime("%Y%m%d") + + # 创建上传目录 + upload_dir = os.path.join("projects", "uploads", target_folder) + os.makedirs(upload_dir, exist_ok=True) + + # 处理文件名 + if not file.filename: + raise HTTPException(status_code=400, detail="文件名不能为空") + + # 解析文件名和扩展名 + original_filename = file.filename + name_without_ext, file_extension = os.path.splitext(original_filename) + + # 根据是否指定文件夹决定命名策略 + if folder: + # 使用原始文件名,支持版本控制 + final_filename, version = get_versioned_filename(upload_dir, name_without_ext, file_extension) + file_path = os.path.join(upload_dir, final_filename) + + # 保存文件 + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + return { + "success": True, + "message": f"文件上传成功{' (版本: ' + str(version) + ')' if version > 1 else ''}", + "file_path": file_path, + "folder": target_folder, + "original_filename": original_filename, + "version": version + } + else: + # 使用UUID唯一文件名(原有逻辑) + unique_filename = f"{uuid.uuid4()}{file_extension}" + file_path = 
os.path.join(upload_dir, unique_filename) + + # 保存文件 + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + return { + "success": True, + "message": "文件上传成功", + "file_path": file_path, + "folder": target_folder, + "original_filename": original_filename + } + + except HTTPException: + raise + except Exception as e: + print(f"Error uploading file: {str(e)}") + raise HTTPException(status_code=500, detail=f"文件上传失败: {str(e)}") + + +# Task management routes that are related to file processing +@router.get("/api/v1/task/{task_id}/status") +async def get_task_status(task_id: str): + """获取任务状态 - 简单可靠""" + try: + status_data = task_status_store.get_status(task_id) + + if not status_data: + return { + "success": False, + "message": "任务不存在或已过期", + "task_id": task_id, + "status": "not_found" + } + + return { + "success": True, + "message": "任务状态获取成功", + "task_id": task_id, + "status": status_data["status"], + "unique_id": status_data["unique_id"], + "created_at": status_data["created_at"], + "updated_at": status_data["updated_at"], + "result": status_data.get("result"), + "error": status_data.get("error") + } + + except Exception as e: + print(f"Error getting task status: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取任务状态失败: {str(e)}") + + +@router.delete("/api/v1/task/{task_id}") +async def delete_task(task_id: str): + """删除任务记录""" + try: + success = task_status_store.delete_status(task_id) + if success: + return { + "success": True, + "message": f"任务记录已删除: {task_id}", + "task_id": task_id + } + else: + return { + "success": False, + "message": f"任务记录不存在: {task_id}", + "task_id": task_id + } + except Exception as e: + print(f"Error deleting task: {str(e)}") + raise HTTPException(status_code=500, detail=f"删除任务记录失败: {str(e)}") + + +@router.get("/api/v1/tasks") +async def list_tasks(status: Optional[str] = None, dataset_id: Optional[str] = None, limit: int = 100): + """列出任务,支持筛选""" + try: + if status or dataset_id: + # 使用搜索功能 + tasks = 
task_status_store.search_tasks(status=status, unique_id=dataset_id, limit=limit) + else: + # 获取所有任务 + all_tasks = task_status_store.list_all() + tasks = list(all_tasks.values())[:limit] + + return { + "success": True, + "message": "任务列表获取成功", + "total_tasks": len(tasks), + "tasks": tasks, + "filters": { + "status": status, + "dataset_id": dataset_id, + "limit": limit + } + } + + except Exception as e: + print(f"Error listing tasks: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取任务列表失败: {str(e)}") + + +@router.get("/api/v1/tasks/statistics") +async def get_task_statistics(): + """获取任务统计信息""" + try: + stats = task_status_store.get_statistics() + + return { + "success": True, + "message": "统计信息获取成功", + "statistics": stats + } + + except Exception as e: + print(f"Error getting statistics: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取统计信息失败: {str(e)}") + + +@router.post("/api/v1/tasks/cleanup") +async def cleanup_tasks(older_than_days: int = 7): + """清理旧任务记录""" + try: + deleted_count = task_status_store.cleanup_old_tasks(older_than_days=older_than_days) + + return { + "success": True, + "message": f"已清理 {deleted_count} 条旧任务记录", + "deleted_count": deleted_count, + "older_than_days": older_than_days + } + + except Exception as e: + print(f"Error cleaning up tasks: {str(e)}") + raise HTTPException(status_code=500, detail=f"清理任务记录失败: {str(e)}") \ No newline at end of file diff --git a/routes/projects.py b/routes/projects.py new file mode 100644 index 0000000..c068133 --- /dev/null +++ b/routes/projects.py @@ -0,0 +1,173 @@ +import os +import json +from typing import Optional +from fastapi import APIRouter, HTTPException + +from task_queue.task_status import task_status_store + +router = APIRouter() + + +@router.get("/api/v1/projects") +async def list_all_projects(): + """获取所有项目列表""" + try: + # 获取机器人项目(projects/robot) + robot_dir = "projects/robot" + robot_projects = [] + + if os.path.exists(robot_dir): + for item in os.listdir(robot_dir): + 
item_path = os.path.join(robot_dir, item) + if os.path.isdir(item_path): + try: + # 读取机器人配置文件 + config_path = os.path.join(item_path, "robot_config.json") + config_data = {} + if os.path.exists(config_path): + with open(config_path, 'r', encoding='utf-8') as f: + config_data = json.load(f) + + # 统计文件数量 + file_count = 0 + if os.path.exists(os.path.join(item_path, "dataset")): + for root, dirs, files in os.walk(os.path.join(item_path, "dataset")): + file_count += len(files) + + robot_projects.append({ + "id": item, + "name": config_data.get("name", item), + "type": "robot", + "status": config_data.get("status", "active"), + "file_count": file_count, + "config": config_data, + "created_at": os.path.getctime(item_path), + "updated_at": os.path.getmtime(item_path) + }) + except Exception as e: + print(f"Error reading robot project {item}: {str(e)}") + robot_projects.append({ + "id": item, + "name": item, + "type": "robot", + "status": "unknown", + "file_count": 0, + "created_at": os.path.getctime(item_path), + "updated_at": os.path.getmtime(item_path) + }) + + # 获取数据集(projects/data) + data_dir = "projects/data" + datasets = [] + + if os.path.exists(data_dir): + for item in os.listdir(data_dir): + item_path = os.path.join(data_dir, item) + if os.path.isdir(item_path): + try: + # 读取处理日志 + log_path = os.path.join(item_path, "processing_log.json") + log_data = {} + if os.path.exists(log_path): + with open(log_path, 'r', encoding='utf-8') as f: + log_data = json.load(f) + + # 统计文件数量 + file_count = 0 + for root, dirs, files in os.walk(item_path): + file_count += len([f for f in files if not f.endswith('.pkl')]) + + # 获取状态 + status = "active" + if log_data.get("status"): + status = log_data["status"] + elif os.path.exists(os.path.join(item_path, "processed")): + status = "completed" + + datasets.append({ + "id": item, + "name": f"数据集 - {item[:8]}...", + "type": "dataset", + "status": status, + "file_count": file_count, + "log_data": log_data, + "created_at": 
os.path.getctime(item_path), + "updated_at": os.path.getmtime(item_path) + }) + except Exception as e: + print(f"Error reading dataset {item}: {str(e)}") + datasets.append({ + "id": item, + "name": f"数据集 - {item[:8]}...", + "type": "dataset", + "status": "unknown", + "file_count": 0, + "created_at": os.path.getctime(item_path), + "updated_at": os.path.getmtime(item_path) + }) + + all_projects = robot_projects + datasets + + return { + "success": True, + "message": "项目列表获取成功", + "total_projects": len(all_projects), + "robot_projects": robot_projects, + "datasets": datasets, + "projects": all_projects # 保持向后兼容 + } + + except Exception as e: + print(f"Error listing projects: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取项目列表失败: {str(e)}") + + +@router.get("/api/v1/projects/robot") +async def list_robot_projects(): + """获取机器人项目列表""" + try: + response = await list_all_projects() + return { + "success": True, + "message": "机器人项目列表获取成功", + "total_projects": len(response["robot_projects"]), + "projects": response["robot_projects"] + } + except Exception as e: + print(f"Error listing robot projects: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取机器人项目列表失败: {str(e)}") + + +@router.get("/api/v1/projects/datasets") +async def list_datasets(): + """获取数据集列表""" + try: + response = await list_all_projects() + return { + "success": True, + "message": "数据集列表获取成功", + "total_projects": len(response["datasets"]), + "projects": response["datasets"] + } + except Exception as e: + print(f"Error listing datasets: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取数据集列表失败: {str(e)}") + + +@router.get("/api/v1/projects/{dataset_id}/tasks") +async def get_project_tasks(dataset_id: str): + """获取指定项目的所有任务""" + try: + tasks = task_status_store.get_by_unique_id(dataset_id) + + return { + "success": True, + "message": "项目任务获取成功", + "dataset_id": dataset_id, + "total_tasks": len(tasks), + "tasks": tasks + } + + except Exception as e: + print(f"Error getting 
project tasks: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取项目任务失败: {str(e)}") \ No newline at end of file diff --git a/routes/system.py b/routes/system.py new file mode 100644 index 0000000..13d3ac0 --- /dev/null +++ b/routes/system.py @@ -0,0 +1,272 @@ +import os +import time +import multiprocessing +from typing import Optional +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from utils import ( + get_global_agent_manager, init_global_sharded_agent_manager, + get_global_connection_pool, init_global_connection_pool, + get_global_file_cache, init_global_file_cache, + setup_system_optimizations +) +try: + from utils.system_optimizer import apply_optimization_profile +except ImportError: + def apply_optimization_profile(profile): + return {"profile": profile, "status": "system_optimizer not available"} +from utils.fastapi_utils import get_content_from_messages +from embedding import get_model_manager +from pydantic import BaseModel + +router = APIRouter() + + +class EncodeRequest(BaseModel): + texts: list[str] + batch_size: int = 32 + + +class EncodeResponse(BaseModel): + success: bool + embeddings: list[list[float]] + shape: list[int] + processing_time: float + total_texts: int + error: Optional[str] = None + + +# 系统优化设置初始化 +print("正在初始化系统优化...") +system_optimizer = setup_system_optimizations() + +# 全局助手管理器配置(使用优化后的配置) +max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "50")) # 增加缓存大小 +shard_count = int(os.getenv("SHARD_COUNT", "16")) # 分片数量 + +# 初始化优化的全局助手管理器 +agent_manager = init_global_sharded_agent_manager( + max_cached_agents=max_cached_agents, + shard_count=shard_count +) + +# 初始化连接池 +connection_pool = init_global_connection_pool( + max_connections_per_host=int(os.getenv("MAX_CONNECTIONS_PER_HOST", "100")), + max_connections_total=int(os.getenv("MAX_CONNECTIONS_TOTAL", "500")), + keepalive_timeout=int(os.getenv("KEEPALIVE_TIMEOUT", "30")), + connect_timeout=int(os.getenv("CONNECT_TIMEOUT", "10")), + 
total_timeout=int(os.getenv("TOTAL_TIMEOUT", "60")) +) + +# 初始化文件缓存 +file_cache = init_global_file_cache( + cache_size=int(os.getenv("FILE_CACHE_SIZE", "1000")), + ttl=int(os.getenv("FILE_CACHE_TTL", "300")) +) + +print("系统优化初始化完成") +print(f"- 分片Agent管理器: {shard_count} 个分片,最多缓存 {max_cached_agents} 个agent") +print(f"- 连接池: 每主机100连接,总计500连接") +print(f"- 文件缓存: 1000个文件,TTL 300秒") + + +@router.get("/api/health") +async def health_check(): + """Health check endpoint""" + return {"message": "Database Assistant API is running"} + + +@router.get("/api/v1/system/performance") +async def get_performance_stats(): + """获取系统性能统计信息""" + try: + # 获取agent管理器统计 + agent_stats = agent_manager.get_cache_stats() + + # 获取连接池统计(简化版) + pool_stats = { + "connection_pool": "active", + "max_connections_per_host": 100, + "max_connections_total": 500, + "keepalive_timeout": 30 + } + + # 获取文件缓存统计 + file_cache_stats = { + "cache_size": len(file_cache._cache) if hasattr(file_cache, '_cache') else 0, + "max_cache_size": file_cache.cache_size if hasattr(file_cache, 'cache_size') else 1000, + "ttl": file_cache.ttl if hasattr(file_cache, 'ttl') else 300 + } + + # 系统资源信息 + try: + import psutil + system_stats = { + "cpu_count": multiprocessing.cpu_count(), + "memory_total_gb": round(psutil.virtual_memory().total / (1024**3), 2), + "memory_available_gb": round(psutil.virtual_memory().available / (1024**3), 2), + "memory_percent": psutil.virtual_memory().percent, + "disk_usage_percent": psutil.disk_usage('/').percent + } + except ImportError: + system_stats = { + "cpu_count": multiprocessing.cpu_count(), + "memory_info": "psutil not available" + } + + return { + "success": True, + "timestamp": int(time.time()), + "performance": { + "agent_manager": agent_stats, + "connection_pool": pool_stats, + "file_cache": file_cache_stats, + "system": system_stats + } + } + + except Exception as e: + print(f"Error getting performance stats: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取性能统计失败: 
{str(e)}") + + +@router.post("/api/v1/system/optimize") +async def optimize_system(profile: str = "balanced"): + """应用系统优化配置""" + try: + # 应用优化配置 + config = apply_optimization_profile(profile) + + return { + "success": True, + "message": f"已应用 {profile} 优化配置", + "config": config + } + + except Exception as e: + print(f"Error applying optimization profile: {str(e)}") + raise HTTPException(status_code=500, detail=f"应用优化配置失败: {str(e)}") + + +@router.post("/api/v1/system/clear-cache") +async def clear_system_cache(cache_type: Optional[str] = None): + """清理系统缓存""" + try: + cleared_counts = {} + + if cache_type is None or cache_type == "agent": + # 清理agent缓存 + agent_count = agent_manager.clear_cache() + cleared_counts["agent_cache"] = agent_count + + if cache_type is None or cache_type == "file": + # 清理文件缓存 + if hasattr(file_cache, '_cache'): + file_count = len(file_cache._cache) + file_cache._cache.clear() + cleared_counts["file_cache"] = file_count + + return { + "success": True, + "message": f"已清理指定类型的缓存", + "cleared_counts": cleared_counts + } + + except Exception as e: + print(f"Error clearing cache: {str(e)}") + raise HTTPException(status_code=500, detail=f"清理缓存失败: {str(e)}") + + +@router.get("/api/v1/system/config") +async def get_system_config(): + """获取当前系统配置""" + try: + return { + "success": True, + "config": { + "max_cached_agents": max_cached_agents, + "shard_count": shard_count, + "tokenizer_parallelism": os.getenv("TOKENIZERS_PARALLELISM", "true"), + "max_connections_per_host": os.getenv("MAX_CONNECTIONS_PER_HOST", "100"), + "max_connections_total": os.getenv("MAX_CONNECTIONS_TOTAL", "500"), + "file_cache_size": os.getenv("FILE_CACHE_SIZE", "1000"), + "file_cache_ttl": os.getenv("FILE_CACHE_TTL", "300") + } + } + + except Exception as e: + print(f"Error getting system config: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取系统配置失败: {str(e)}") + + +@router.post("/system/remove-project-cache") +async def remove_project_cache(dataset_id: str): + 
"""移除特定项目的缓存""" + try: + removed_count = agent_manager.remove_cache_by_unique_id(dataset_id) + if removed_count > 0: + return {"message": f"项目缓存移除成功: {dataset_id}", "removed_count": removed_count} + else: + return {"message": f"未找到项目缓存: {dataset_id}", "removed_count": 0} + except Exception as e: + raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}") + + +@router.post("/api/v1/embedding/encode", response_model=EncodeResponse) +async def encode_texts(request: EncodeRequest): + """ + 文本编码 API + + Args: + request: 包含 texts 和 batch_size 的编码请求 + + Returns: + 编码结果 + """ + try: + model_manager = get_model_manager() + + if not request.texts: + return EncodeResponse( + success=False, + embeddings=[], + shape=[0, 0], + processing_time=0.0, + total_texts=0, + error="texts 不能为空" + ) + + start_time = time.time() + + # 使用模型管理器编码文本 + embeddings = await model_manager.encode_texts( + request.texts, + batch_size=request.batch_size + ) + + processing_time = time.time() - start_time + + # 转换为列表格式 + embeddings_list = embeddings.tolist() + + return EncodeResponse( + success=True, + embeddings=embeddings_list, + shape=list(embeddings.shape), + processing_time=processing_time, + total_texts=len(request.texts) + ) + + except Exception as e: + from utils.logger import logger + logger.error(f"文本编码 API 错误: {e}") + return EncodeResponse( + success=False, + embeddings=[], + shape=[0, 0], + processing_time=0.0, + total_texts=len(request.texts) if request else 0, + error=f"编码失败: {str(e)}" + ) \ No newline at end of file diff --git a/utils/fastapi_utils.py b/utils/fastapi_utils.py new file mode 100644 index 0000000..3f2094c --- /dev/null +++ b/utils/fastapi_utils.py @@ -0,0 +1,488 @@ +import os +import re +import hashlib +import asyncio +from typing import List, Dict, Optional, Union, Any +import aiohttp +from qwen_agent.llm.schema import ASSISTANT, FUNCTION +from qwen_agent.llm.oai import TextChatAtOAI +from fastapi import HTTPException +from utils.logger import logger + + +def 
get_versioned_filename(upload_dir: str, name_without_ext: str, file_extension: str) -> tuple[str, int]: + """ + 获取带版本号的文件名,自动处理文件删除和版本递增 + + Args: + upload_dir: 上传目录路径 + name_without_ext: 不含扩展名的文件名 + file_extension: 文件扩展名(包含点号) + + Returns: + tuple[str, int]: (最终文件名, 版本号) + """ + # 检查原始文件是否存在 + original_file = os.path.join(upload_dir, name_without_ext + file_extension) + original_exists = os.path.exists(original_file) + + # 查找所有相关的版本化文件 + pattern = re.compile(re.escape(name_without_ext) + r'_(\d+)' + re.escape(file_extension) + r'$') + existing_versions = [] + files_to_delete = [] + + for filename in os.listdir(upload_dir): + # 检查是否是原始文件 + if filename == name_without_ext + file_extension: + files_to_delete.append(filename) + continue + + # 检查是否是版本化文件 + match = pattern.match(filename) + if match: + version_num = int(match.group(1)) + existing_versions.append(version_num) + files_to_delete.append(filename) + + # 如果没有任何相关文件存在,使用原始文件名(版本1) + if not original_exists and not existing_versions: + return name_without_ext + file_extension, 1 + + # 删除所有现有文件(原始文件和版本化文件) + for filename in files_to_delete: + file_to_delete = os.path.join(upload_dir, filename) + try: + os.remove(file_to_delete) + print(f"已删除文件: {file_to_delete}") + except OSError as e: + print(f"删除文件失败 {file_to_delete}: {e}") + + # 确定下一个版本号 + if existing_versions: + next_version = max(existing_versions) + 1 + else: + next_version = 2 + + # 生成带版本号的文件名 + versioned_filename = f"{name_without_ext}_{next_version}{file_extension}" + + return versioned_filename, next_version + + +def get_content_from_messages(messages: List[dict], tool_response: bool = True) -> str: + """Extract content from qwen-agent messages with special formatting""" + full_text = '' + content = [] + TOOL_CALL_S = '[TOOL_CALL]' + TOOL_RESULT_S = '[TOOL_RESPONSE]' + THOUGHT_S = '[THINK]' + ANSWER_S = '[ANSWER]' + + for msg in messages: + + if msg['role'] == ASSISTANT: + if msg.get('reasoning_content'): + assert isinstance(msg['reasoning_content'], 
str), 'Now only supports text messages' + content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}') + if msg.get('content'): + assert isinstance(msg['content'], str), 'Now only supports text messages' + # 过滤掉流式输出中的不完整 tool_call 文本 + content_text = msg["content"] + + # 使用正则表达式替换不完整的 tool_call 模式为空字符串 + + # 匹配并替换不完整的 tool_call 模式 + content_text = re.sub(r' List[Dict[str, str]]: + """处理消息列表,包括[TOOL_CALL]|[TOOL_RESPONSE]|[ANSWER]分割和语言指令添加 + + 这是 get_content_from_messages 的逆运算,将包含 [TOOL_RESPONSE] 的消息重新组装回 + msg['role'] == 'function' 和 msg.get('function_call') 的格式。 + """ + processed_messages = [] + + # 收集所有ASSISTANT消息的索引 + assistant_indices = [i for i, msg in enumerate(messages) if msg.role == "assistant"] + total_assistant_messages = len(assistant_indices) + cutoff_point = max(0, total_assistant_messages - 5) + + # 处理每条消息 + for i, msg in enumerate(messages): + if msg.role == "assistant": + # 确定当前ASSISTANT消息在所有ASSISTANT消息中的位置(从0开始) + assistant_position = assistant_indices.index(i) + + # 使用正则表达式按照 [TOOL_CALL]|[TOOL_RESPONSE]|[ANSWER] 进行切割 + parts = re.split(r'\[(TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg.content) + + # 重新组装内容,根据消息位置决定处理方式 + filtered_content = "" + current_tag = None + is_recent_message = assistant_position >= cutoff_point # 最近10条消息 + + for i in range(0, len(parts)): + if i % 2 == 0: # 文本内容 + text = parts[i].strip() + if not text: + continue + + if current_tag == "TOOL_RESPONSE": + if is_recent_message: + # 最近10条ASSISTANT消息:保留完整TOOL_RESPONSE信息(使用简略模式) + if len(text) <= 500: + filtered_content += f"[TOOL_RESPONSE]\n{text}\n" + else: + # 截取前中后3段内容,每段250字 + first_part = text[:250] + middle_start = len(text) // 2 - 125 + middle_part = text[middle_start:middle_start + 250] + last_part = text[-250:] + + # 计算省略的字数 + omitted_count = len(text) - 750 + omitted_text = f"...此处省略{omitted_count}字..." 
+ + # 拼接内容 + truncated_text = f"{first_part}\n{omitted_text}\n{middle_part}\n{omitted_text}\n{last_part}" + filtered_content += f"[TOOL_RESPONSE]\n{truncated_text}\n" + # 10条以上的消息:不保留TOOL_RESPONSE数据(完全跳过) + elif current_tag == "TOOL_CALL": + if is_recent_message: + # 最近10条ASSISTANT消息:保留TOOL_CALL信息 + filtered_content += f"[TOOL_CALL]\n{text}\n" + # 10条以上的消息:不保留TOOL_CALL数据(完全跳过) + elif current_tag == "ANSWER": + # 所有ASSISTANT消息都保留ANSWER数据 + filtered_content += f"[ANSWER]\n{text}\n" + else: + # 第一个标签之前的内容 + filtered_content += text + "\n" + else: # 标签 + current_tag = parts[i] + + # 取最终处理后的内容,去除首尾空白 + final_content = filtered_content.strip() + if final_content: + processed_messages.append({"role": msg.role, "content": final_content}) + else: + # 如果处理后为空,使用原内容 + processed_messages.append({"role": msg.role, "content": msg.content}) + else: + processed_messages.append({"role": msg.role, "content": msg.content}) + + # 逆运算:将包含 [TOOL_RESPONSE] 的消息重新组装回 msg['role'] == 'function' 和 msg.get('function_call') + # 这是 get_content_from_messages 的逆运算 + final_messages = [] + for msg in processed_messages: + if msg["role"] == ASSISTANT and "[TOOL_RESPONSE]" in msg["content"]: + # 分割消息内容 + parts = re.split(r'\[(TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg["content"]) + + current_tag = None + assistant_content = "" + function_calls = [] + tool_responses = [] + + for i in range(0, len(parts)): + if i % 2 == 0: # 文本内容 + text = parts[i].strip() + if not text: + continue + + if current_tag == "TOOL_RESPONSE": + # 解析 TOOL_RESPONSE 格式:[TOOL_RESPONSE] function_name\ncontent + lines = text.split('\n', 1) + function_name = lines[0].strip() if lines else "" + response_content = lines[1].strip() if len(lines) > 1 else "" + + tool_responses.append({ + "role": FUNCTION, + "name": function_name, + "content": response_content + }) + elif current_tag == "TOOL_CALL": + # 解析 TOOL_CALL 格式:[TOOL_CALL] function_name\narguments + lines = text.split('\n', 1) + function_name = lines[0].strip() if lines else "" + 
arguments = lines[1].strip() if len(lines) > 1 else "" + + function_calls.append({ + "name": function_name, + "arguments": arguments + }) + elif current_tag == "ANSWER": + assistant_content += text + "\n" + else: + # 第一个标签之前的内容也属于 assistant + assistant_content += text + "\n" + else: # 标签 + current_tag = parts[i] + + # 添加 assistant 消息(如果有内容) + if assistant_content.strip() or function_calls: + assistant_msg = {"role": ASSISTANT} + if assistant_content.strip(): + assistant_msg["content"] = assistant_content.strip() + if function_calls: + # 如果有多个 function_call,只取第一个(兼容原有逻辑) + assistant_msg["function_call"] = function_calls[0] + final_messages.append(assistant_msg) + + # 添加所有 tool_responses 作为 function 消息 + final_messages.extend(tool_responses) + else: + # 非 assistant 消息或不包含 [TOOL_RESPONSE] 的消息直接添加 + final_messages.append(msg) + + # 在最后一条消息的末尾追加回复语言 + if final_messages and language: + language_map = { + 'zh': '请用中文回复', + 'en': 'Please reply in English', + 'ja': '日本語で回答してください', + 'jp': '日本語で回答してください' + } + language_instruction = language_map.get(language.lower(), '') + if language_instruction: + # 在最后一条消息末尾追加语言指令 + final_messages[-1]['content'] = final_messages[-1]['content'] + f"\n\nlanguage:\n{language_instruction}。" + + return final_messages + + +def extract_guidelines_from_system_prompt(system_prompt: Optional[str]) -> tuple[str, str]: + """从system_prompt中提取```guideline内容并清理原提示词 + + Returns: + tuple[str, str]: (清理后的system_prompt, 提取的guidelines内容) + """ + if not system_prompt: + return "", "" + + # 使用正则表达式提取 ```guideline``` 包裹的内容 + pattern = r'```guideline\s*\n(.*?)\n```' + matches = re.findall(pattern, system_prompt, re.DOTALL) + + guidelines_text = "\n".join(matches).strip() + + # # 从原始system_prompt中删除 ```guideline``` 内容块 + # cleaned_prompt = re.sub(pattern, '', system_prompt, flags=re.DOTALL) + + # # 清理多余的空行 + # cleaned_prompt = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_prompt).strip() + + return system_prompt, guidelines_text + + +def 
def format_messages_to_chat_history(messages: List[Dict[str, str]]) -> str:
    """Render a message list as a plain-text chat transcript.

    Only ``user`` and ``assistant`` turns are kept; every other role
    (``function``, ``system``, ...) is silently dropped.

    Args:
        messages: Message dicts with ``role`` and ``content`` keys.

    Returns:
        str: One ``role: content`` line per kept message, newline-joined.
    """
    visible_roles = ('user', 'assistant')
    transcript = [
        f"{message.get('role', '')}: {message.get('content', '')}"
        for message in messages
        if message.get('role', '') in visible_roles
    ]
    return "\n".join(transcript)


def create_project_directory(dataset_ids: Optional[List[str]], bot_id: str, robot_type: str = "general_agent") -> Optional[str]:
    """Create a project directory for a catalog-agent bot.

    Returns:
        Optional[str]: The created project path, or None when no directory
        is needed (non-catalog bots, empty dataset list) or creation failed.
    """
    # Only catalog agents with at least one dataset get a project directory.
    if robot_type != "catalog_agent" or not dataset_ids:
        return None

    try:
        from utils.multi_project_manager import create_robot_project
        return create_robot_project(dataset_ids, bot_id)
    except Exception as e:
        # Best-effort: directory creation failure must not break the caller.
        print(f"Error creating project directory: {e}")
        return None


def extract_api_key_from_auth(authorization: Optional[str]) -> Optional[str]:
    """Pull the raw API key out of an Authorization header value.

    Accepts both ``"Bearer <key>"`` and a bare key; returns None when the
    header is missing or empty.
    """
    if not authorization:
        return None
    # removeprefix is a no-op when the "Bearer " prefix is absent.
    return authorization.removeprefix("Bearer ")


def generate_v2_auth_token(bot_id: str) -> str:
    """Derive the deterministic v2 auth token for a bot.

    Token = md5("<MASTERKEY>:<bot_id>"); MASTERKEY defaults to "master".
    """
    secret = os.getenv("MASTERKEY", "master")
    return hashlib.md5(f"{secret}:{bot_id}".encode()).hexdigest()
async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
    """Fetch a bot's configuration from the backend API.

    Args:
        bot_id: Identifier of the agent bot.

    Returns:
        Dict[str, Any]: The ``data`` payload of the backend response
        (``{}`` when absent).

    Raises:
        HTTPException: 400 when the backend answers with a non-200 status
            or reports ``success == False``; 500 on connection errors or
            any other unexpected failure.
    """
    try:
        backend_host = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
        url = f"{backend_host}/v1/agent_bot_config/{bot_id}"

        # The backend authenticates with the deterministic per-bot token.
        auth_token = generate_v2_auth_token(bot_id)
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {auth_token}"
        }
        # Async HTTP request. ClientTimeout is the canonical way to cap the
        # total request time (passing a bare number is deprecated in aiohttp).
        timeout = aiohttp.ClientTimeout(total=30)
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers, timeout=timeout) as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch bot config: API returned status code {response.status}"
                    )

                response_data = await response.json()

                if not response_data.get("success"):
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch bot config: {response_data.get('message', 'Unknown error')}"
                    )

                return response_data.get("data", {})

    except HTTPException:
        # Our own 400s above must pass through untouched (previously done
        # via an isinstance() check inside the generic handler).
        raise
    except aiohttp.ClientError as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to connect to backend API: {str(e)}"
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch bot config: {str(e)}"
        )
response[0]['content'] + + # 如果是字符串,直接返回 + if isinstance(response, str): + return response + + # 处理其他类型 + return str(response) if response else "" + + except Exception as e: + print(f"Error calling guideline LLM: {e}") + return "" + + +def _get_optimal_batch_size(guidelines_count: int) -> int: + """根据guidelines数量决定最优批次数量(并发数)""" + if guidelines_count <= 10: + return 1 + elif guidelines_count <= 20: + return 2 + elif guidelines_count <= 30: + return 3 + else: + return 5 + + +async def process_guideline_batch( + guidelines_batch: List[str], + chat_history: str, + model_name: str, + api_key: str, + model_server: str +) -> str: + """处理单个guideline批次""" + try: + # 调用LLM分析这批guidelines + batch_guidelines_text = "\n".join(guidelines_batch) + batch_analysis = await call_guideline_llm(chat_history, batch_guidelines_text, model_name, api_key, model_server) + + return batch_analysis + except Exception as e: + print(f"Error processing guideline batch: {e}") + return ""