Add the agent folder, update import references, and add custom_mcp_manager

This commit is contained in:
朱潮 2025-11-26 17:23:02 +08:00
parent 7002019229
commit 58ac6e3024
11 changed files with 536 additions and 723 deletions

0 agent/__init__.py Normal file

474 agent/custom_mcp_manager.py Normal file

@@ -0,0 +1,474 @@
# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import atexit
import datetime
import json
import threading
import time
import uuid
from contextlib import AsyncExitStack
from typing import Dict, Optional, Union
from dotenv import load_dotenv
from qwen_agent.log import logger
from qwen_agent.tools.base import BaseTool
class CustomMCPManager:
_instance = None # Private class variable to store the unique instance
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(CustomMCPManager, cls).__new__(cls)  # object.__new__() takes no extra arguments
return cls._instance
def __init__(self):
if not hasattr(self, 'clients'): # The singleton should only be inited once
"""Set a new event loop in a separate thread"""
try:
import mcp # noqa
except ImportError as e:
raise ImportError('Could not import mcp. Please install mcp with `pip install -U mcp`.') from e
load_dotenv() # Load environment variables from .env file
self.clients: dict = {}
self.loop = asyncio.new_event_loop()
self.loop_thread = threading.Thread(target=self.start_loop, daemon=True)
self.loop_thread.start()
# A fallback way to terminate MCP tool processes after Qwen-Agent exits
self.processes = []
self.monkey_patch_mcp_create_platform_compatible_process()
def monkey_patch_mcp_create_platform_compatible_process(self):
try:
import mcp.client.stdio
target = mcp.client.stdio._create_platform_compatible_process
except (ModuleNotFoundError, AttributeError) as e:
raise ImportError('Qwen-Agent needs to monkey patch MCP for process cleanup. '
'Please upgrade MCP to a higher version with `pip install -U mcp`.') from e
async def _monkey_patched_create_platform_compatible_process(*args, **kwargs):
process = await target(*args, **kwargs)
self.processes.append(process)
return process
mcp.client.stdio._create_platform_compatible_process = _monkey_patched_create_platform_compatible_process
def start_loop(self):
asyncio.set_event_loop(self.loop)
# Set a global exception handler to silently handle cross-task exceptions from MCP SSE connections
def exception_handler(loop, context):
exception = context.get('exception')
if exception:
# Silently handle cross-task exceptions from MCP SSE connections
if (isinstance(exception, RuntimeError) and
'Attempted to exit cancel scope in a different task' in str(exception)):
return # Silently ignore this type of exception
if (isinstance(exception, BaseExceptionGroup) and # noqa
'Attempted to exit cancel scope in a different task' in str(exception)): # noqa
return # Silently ignore this type of exception
# Other exceptions are handled normally
loop.default_exception_handler(context)
self.loop.set_exception_handler(exception_handler)
self.loop.run_forever()
def is_valid_mcp_servers(self, config: dict):
"""Example of mcp servers configuration:
{
"mcpServers": {
"memory": {
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-memory"]
},
"filesystem": {
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/allowed/files"]
}
}
}
"""
# Check if the top-level key "mcpServers" exists and its value is a dictionary
if not isinstance(config, dict) or 'mcpServers' not in config or not isinstance(config['mcpServers'], dict):
return False
mcp_servers = config['mcpServers']
# Check each sub-item under "mcpServers"
for key in mcp_servers:
server = mcp_servers[key]
# Each sub-item must be a dictionary
if not isinstance(server, dict):
return False
if 'command' in server:
# "command" must be a string
if not isinstance(server['command'], str):
return False
# "args" must be a list
if 'args' not in server or not isinstance(server['args'], list):
return False
if 'url' in server:
# "url" must be a string
if not isinstance(server['url'], str):
return False
# "headers" must be a dictionary
if 'headers' in server and not isinstance(server['headers'], dict):
return False
# If the "env" key exists, it must be a dictionary
if 'env' in server and not isinstance(server['env'], dict):
return False
return True
def initConfig(self, config: Dict):
if not self.is_valid_mcp_servers(config):
raise ValueError('Config of mcpServers is not valid')
logger.info(f'Initializing MCP tools from mcp servers: {list(config["mcpServers"].keys())}')
# Submit coroutine to the event loop and wait for the result
future = asyncio.run_coroutine_threadsafe(self.init_config_async(config), self.loop)
try:
result = future.result() # You can specify a timeout if desired
return result
except Exception as e:
logger.info(f'Failed in initializing MCP tools: {e}')
raise e
async def init_config_async(self, config: Dict):
tools: list = []
mcp_servers = config['mcpServers']
for server_name in mcp_servers:
client = CustomMCPClient()
server = mcp_servers[server_name]
await client.connection_server(mcp_server_name=server_name,
mcp_server=server) # Attempt to connect to the server
client_id = server_name + '_' + str(
uuid.uuid4())  # Allow the same server name to be used across different running agents
client.client_id = client_id # Ensure client_id is set on the client instance
self.clients[client_id] = client # Add to clients dict after successful connection
for tool in client.tools:
"""MCP tool example:
{
"name": "read_query",
"description": "Execute a SELECT query on the SQLite database",
"inputSchema": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "SELECT SQL query to execute"
}
},
"required": ["query"]
}
"""
parameters = tool.inputSchema
# The required field in inputSchema may be empty and needs to be initialized.
if 'required' not in parameters:
parameters['required'] = []
# Remove keys from parameters that do not conform to the standard OpenAI schema
# Check if the required fields exist
required_fields = {'type', 'properties', 'required'}
missing_fields = required_fields - parameters.keys()
if missing_fields:
raise ValueError(f'Missing required fields in schema: {missing_fields}')
# Keep only the necessary fields
cleaned_parameters = {
'type': parameters['type'],
'properties': parameters['properties'],
'required': parameters['required']
}
register_name = server_name + '-' + tool.name
agent_tool = self.create_tool_class(register_name=register_name,
register_client_id=client_id,
tool_name=tool.name,
tool_desc=tool.description,
tool_parameters=cleaned_parameters)
tools.append(agent_tool)
if client.resources:
"""MCP resource example:
{
uri: string; // Unique identifier for the resource
name: string; // Human-readable name
description?: string; // Optional description
mimeType?: string; // Optional MIME type
}
"""
# List resources
list_resources_tool_name = server_name + '-' + 'list_resources'
list_resources_params = {'type': 'object', 'properties': {}, 'required': []}
list_resources_agent_tool = self.create_tool_class(
register_name=list_resources_tool_name,
register_client_id=client_id,
tool_name='list_resources',
tool_desc='Servers expose a list of concrete resources through this tool. '
'By invoking it, you can discover the available resources and obtain resource templates, which help clients understand how to construct valid URIs. '
'These URI formats will be used as input parameters for the read_resource function. ',
tool_parameters=list_resources_params)
tools.append(list_resources_agent_tool)
# Read resource
resources_template_str = '' # Check if there are resource templates
try:
list_resource_templates = await client.session.list_resource_templates(
) # Check if the server has resource templates
if list_resource_templates.resourceTemplates:
resources_template_str = '\n'.join(
str(template) for template in list_resource_templates.resourceTemplates)
except Exception as e:
logger.info(f'Failed in listing MCP resource templates: {e}')
read_resource_tool_name = server_name + '-' + 'read_resource'
read_resource_params = {
'type': 'object',
'properties': {
'uri': {
'type': 'string',
'description': 'The URI identifying the specific resource to access'
}
},
'required': ['uri']
}
original_tool_desc = 'Request to access a resource provided by a connected MCP server. Resources represent data sources that can be used as context, such as files, API responses, or system information.'
if resources_template_str:
tool_desc = original_tool_desc + '\nResource Templates:\n' + resources_template_str
else:
tool_desc = original_tool_desc
read_resource_agent_tool = self.create_tool_class(register_name=read_resource_tool_name,
register_client_id=client_id,
tool_name='read_resource',
tool_desc=tool_desc,
tool_parameters=read_resource_params)
tools.append(read_resource_agent_tool)
return tools
def create_tool_class(self, register_name, register_client_id, tool_name, tool_desc, tool_parameters):
class ToolClass(BaseTool):
name = register_name
description = tool_desc
parameters = tool_parameters
client_id = register_client_id
def call(self, params: Union[str, dict], **kwargs) -> str:
tool_args = json.loads(params) if isinstance(params, str) else params  # params may be a JSON string or already a dict
# Submit coroutine to the event loop and wait for the result
manager = CustomMCPManager()
client = manager.clients[self.client_id]
future = asyncio.run_coroutine_threadsafe(client.execute_function(tool_name, tool_args), manager.loop)
try:
result = future.result()
return result
except Exception as e:
logger.info(f'Failed in executing MCP tool: {e}')
raise e
ToolClass.__name__ = f'{register_name}_Class'
return ToolClass()
def shutdown(self):
futures = []
for client_id in list(self.clients.keys()):
client: CustomMCPClient = self.clients[client_id]
future = asyncio.run_coroutine_threadsafe(client.cleanup(), self.loop)
futures.append(future)
del self.clients[client_id]
time.sleep(1) # Wait for the graceful cleanups, otherwise fall back
# fallback
if asyncio.all_tasks(self.loop):
logger.info(
'There are still tasks in `CustomMCPManager().loop`, force terminating the MCP tool processes. There may be some exceptions.'
)
for process in self.processes:
try:
process.terminate()
except ProcessLookupError:
pass # it's ok, the process may exit earlier
self.loop.call_soon_threadsafe(self.loop.stop)
self.loop_thread.join()
class CustomMCPClient:
def __init__(self):
from mcp import ClientSession
self.session: Optional[ClientSession] = None
self.tools: Optional[list] = None
self.exit_stack = AsyncExitStack()
self.resources: bool = False
self._last_mcp_server_name = None
self._last_mcp_server = None
self.client_id = None # For replacing in MCPManager.clients
async def connection_server(self, mcp_server_name, mcp_server):
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
"""Connect to an MCP server and retrieve the available tools."""
# Save parameters
self._last_mcp_server_name = mcp_server_name
self._last_mcp_server = mcp_server
try:
if 'url' in mcp_server:
url = mcp_server.get('url')
sse_read_timeout = mcp_server.get('sse_read_timeout', 300)
logger.info(f'{mcp_server_name} sse_read_timeout: {sse_read_timeout}s')
if mcp_server.get('type', 'sse') == 'streamable-http':
# streamable-http mode
"""streamable-http mode mcp example:
{"mcpServers": {
"streamable-mcp-server": {
"type": "streamable-http",
"url":"http://0.0.0.0:8000/mcp"
}
}
}
"""
headers = mcp_server.get('headers', {})
self._streams_context = streamablehttp_client(
url=url, headers=headers, sse_read_timeout=datetime.timedelta(seconds=sse_read_timeout))
read_stream, write_stream, get_session_id = await self.exit_stack.enter_async_context(
self._streams_context)
self._session_context = ClientSession(read_stream, write_stream)
self.session = await self.exit_stack.enter_async_context(self._session_context)
else:
# sse mode
headers = mcp_server.get('headers', {'Accept': 'text/event-stream'})
self._streams_context = sse_client(url, headers, sse_read_timeout=sse_read_timeout)
streams = await self.exit_stack.enter_async_context(self._streams_context)
self._session_context = ClientSession(*streams)
self.session = await self.exit_stack.enter_async_context(self._session_context)
else:
server_params = StdioServerParameters(command=mcp_server['command'],
args=mcp_server['args'],
env=mcp_server.get('env', None))
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
self.stdio, self.write = stdio_transport
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
logger.info(
f'Initializing an MCP stdio_client; if this takes forever, check the config of this MCP server: {mcp_server_name}'
)
await self.session.initialize()
list_tools = await self.session.list_tools()
self.tools = list_tools.tools
try:
list_resources = await self.session.list_resources() # Check if the server has resources
if list_resources.resources:
self.resources = True
except Exception:
# logger.info(f"No list resources: {e}")
pass
except Exception as e:
logger.warning(f'Failed in connecting to MCP server: {e}')
raise e
async def reconnect(self):
# Create a new MCPClient and connect
if self.client_id is None:
raise RuntimeError(
'Cannot reconnect: client_id is None. This usually means the client was not properly registered in MCPManager.'
)
new_client = CustomMCPClient()
new_client.client_id = self.client_id
await new_client.connection_server(self._last_mcp_server_name, self._last_mcp_server)
return new_client
async def execute_function(self, tool_name, tool_args: dict):
from mcp.types import TextResourceContents
# Check if session is alive
try:
await self.session.send_ping()
except Exception as e:
logger.info(f"Session is not alive, please increase 'sse_read_timeout' in the config, try reconnect: {e}")
# Auto reconnect
try:
manager = CustomMCPManager()
if self.client_id is not None:
manager.clients[self.client_id] = await self.reconnect()
return await manager.clients[self.client_id].execute_function(tool_name, tool_args)
else:
logger.info('Reconnect failed: client_id is None')
return 'Session reconnect (client creation) exception: client_id is None'
except Exception as e3:
logger.info(f'Reconnect (client creation) exception type: {type(e3)}, value: {repr(e3)}')
return f'Session reconnect (client creation) exception: {e3}'
if tool_name == 'list_resources':
try:
list_resources = await self.session.list_resources()
if list_resources.resources:
resources_str = '\n\n'.join(str(resource) for resource in list_resources.resources)
else:
resources_str = 'No resources found'
return resources_str
except Exception as e:
logger.info(f'No list resources: {e}')
return f'Error: {e}'
elif tool_name == 'read_resource':
try:
uri = tool_args.get('uri')
if not uri:
raise ValueError('URI is required for read_resource')
read_resource = await self.session.read_resource(uri)
texts = []
for resource in read_resource.contents:
if isinstance(resource, TextResourceContents):
texts.append(resource.text)
# if isinstance(resource, BlobResourceContents):
# texts.append(resource.blob)
if texts:
return '\n\n'.join(texts)
else:
return 'Failed to read resource'
except Exception as e:
logger.info(f'Failed to read resource: {e}')
return f'Error: {e}'
else:
response = await self.session.call_tool(tool_name, tool_args)
texts = []
for content in response.content:
if content.type == 'text':
texts.append(content.text)
if texts:
return '\n\n'.join(texts)
else:
return 'execute error'
async def cleanup(self):
await self.exit_stack.aclose()
def _cleanup_mcp(_sig_num=None, _frame=None):
if CustomMCPManager._instance is None:
return
manager = CustomMCPManager()
manager.shutdown()
# Make sure all subprocesses are terminated even if killed abnormally
if threading.current_thread() is threading.main_thread():
atexit.register(_cleanup_mcp)
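
A minimal usage sketch of the manager defined above (the server entries, URL, and bearer token below are illustrative placeholders, not part of this commit):

from agent.custom_mcp_manager import CustomMCPManager

config = {
    'mcpServers': {
        'memory': {
            'command': 'npx',
            'args': ['-y', '@modelcontextprotocol/server-memory']
        },
        'remote-search': {  # hypothetical streamable-http server
            'type': 'streamable-http',
            'url': 'http://localhost:8000/mcp',
            'headers': {'Authorization': 'Bearer <token>'}
        }
    }
}

# initConfig blocks until every server is connected and returns one BaseTool per MCP tool
tools = CustomMCPManager().initConfig(config)
for tool in tools:
    print(tool.name, tool.description)

# Each tool's call() accepts JSON-encoded arguments and runs on the manager's event loop:
# tools[0].call('{"key": "value"}')

CustomMCPManager().shutdown()  # graceful cleanup; the atexit hook is only a fallback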


@@ -23,8 +23,8 @@ from typing import Dict, List, Optional
from qwen_agent.agents import Assistant
from qwen_agent.log import logger
from modified_assistant import init_modified_agent_service_with_files, update_agent_llm
from .prompt_loader import load_system_prompt_async, load_mcp_settings_async
from agent.modified_assistant import init_modified_agent_service_with_files, update_agent_llm
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
class FileLoadedAgentManager:


@@ -14,6 +14,7 @@
import copy
import json
import logging
import os
import time
from typing import Dict, Iterator, List, Literal, Optional, Union
@@ -21,7 +22,11 @@ from typing import Dict, Iterator, List, Literal, Optional, Union
from qwen_agent.agents import Assistant
from qwen_agent.llm.schema import ASSISTANT, FUNCTION, Message
from qwen_agent.llm.oai import TextChatAtOAI
from utils.logger import tool_logger
from qwen_agent.tools import BaseTool
from agent.custom_mcp_manager import CustomMCPManager
# Set up the tool logger
tool_logger = logging.getLogger('tool_logger')
class ModifiedAssistant(Assistant):
"""
@@ -94,6 +99,26 @@ class ModifiedAssistant(Assistant):
error_message = f'An error occurred when calling tool {tool_name}: {type(ex).__name__}: {str(ex)}'
return error_message
def _init_tool(self, tool: Union[str, Dict, BaseTool]):
"""重写工具初始化方法使用CustomMCPManager处理MCP服务器配置"""
if isinstance(tool, BaseTool):
# 处理BaseTool实例
tool_name = tool.name
if tool_name in self.function_map:
tool_logger.warning(f'Repeatedly adding tool {tool_name}, will use the newest tool in function list')
self.function_map[tool_name] = tool
elif isinstance(tool, dict) and 'mcpServers' in tool:
# Use CustomMCPManager to handle MCP server configs (supports headers)
tools = CustomMCPManager().initConfig(tool)
for tool in tools:
tool_name = tool.name
if tool_name in self.function_map:
tool_logger.warning(f'Repeatedly adding tool {tool_name}, will use the newest tool in function list')
self.function_map[tool_name] = tool
else:
# Fall back to the parent class handler
super()._init_tool(tool)
def _call_llm_with_retry(self, messages: List[Message], functions=None, extra_generate_cfg=None, max_retries: int = 5) -> Iterator:
"""带重试机制的LLM调用


@@ -109,7 +109,7 @@ async def load_system_prompt_async(project_dir: str, language: str = None, syste
Returns:
str: the loaded system prompt content
"""
from .config_cache import config_cache
from agent.config_cache import config_cache
# Get the language display name
language_display_map = {
@@ -235,7 +235,7 @@ async def load_mcp_settings_async(project_dir: str, mcp_settings: list=None, bot
Supports the {dataset_dir} placeholder in mcp_settings.json args;
it is replaced with the actual path in init_modified_agent_service_with_files
"""
from .config_cache import config_cache
from agent.config_cache import config_cache
# 1. Read the default MCP settings first
default_mcp_settings = []


@@ -26,8 +26,8 @@ from collections import defaultdict
from qwen_agent.agents import Assistant
from qwen_agent.log import logger
from modified_assistant import init_modified_agent_service_with_files, update_agent_llm
from .prompt_loader import load_system_prompt_async, load_mcp_settings_async
from agent.modified_assistant import init_modified_agent_service_with_files, update_agent_llm
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
class ShardedAgentManager:


@@ -1,691 +0,0 @@
#!/usr/bin/env python3
"""
File management API - an HTTP API alternative to WebDAV
Provides RESTful endpoints for managing the projects folder
"""
import os
import shutil
from pathlib import Path
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from fastapi.responses import FileResponse, StreamingResponse
import mimetypes
import json
import zipfile
import tempfile
import io
import math
router = APIRouter(prefix="/api/v1/files", tags=["file_management"])
PROJECTS_DIR = Path("projects")
PROJECTS_DIR.mkdir(exist_ok=True)
@router.get("/list")
async def list_files(path: str = "", recursive: bool = False):
"""
List directory contents
Args:
path: relative path; an empty string means the root directory
recursive: whether to recursively list all subdirectories
"""
try:
target_path = PROJECTS_DIR / path
if not target_path.exists():
raise HTTPException(status_code=404, detail="Path does not exist")
if not target_path.is_dir():
raise HTTPException(status_code=400, detail="Path is not a directory")
def scan_directory(directory: Path, base_path: Path = PROJECTS_DIR) -> List[Dict[str, Any]]:
items = []
try:
for item in directory.iterdir():
# Skip hidden files
if item.name.startswith('.'):
continue
relative_path = item.relative_to(base_path)
stat = item.stat()
item_info = {
"name": item.name,
"path": str(relative_path),
"type": "directory" if item.is_dir() else "file",
"size": stat.st_size if item.is_file() else 0,
"modified": stat.st_mtime,
"created": stat.st_ctime
}
items.append(item_info)
# Recursively scan subdirectories
if recursive and item.is_dir():
items.extend(scan_directory(item, base_path))
except PermissionError:
pass
return items
items = scan_directory(target_path)
return {
"success": True,
"path": path,
"items": items,
"total": len(items)
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to list directory: {str(e)}")
@router.post("/upload")
async def upload_file(file: UploadFile = File(...), path: str = Form("")):
"""
Upload a file
Args:
file: the uploaded file
path: target path, relative to the projects directory
"""
try:
target_dir = PROJECTS_DIR / path
target_dir.mkdir(parents=True, exist_ok=True)
file_path = target_dir / file.filename
# If the file already exists, decide whether to overwrite it
if file_path.exists():
# Versioning or renaming logic could be added here
pass
with open(file_path, "wb") as buffer:
content = await file.read()
buffer.write(content)
return {
"success": True,
"message": "文件上传成功",
"filename": file.filename,
"path": str(Path(path) / file.filename),
"size": len(content)
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"File upload failed: {str(e)}")
@router.get("/download/{file_path:path}")
async def download_file(file_path: str):
"""
Download a file
Args:
file_path: the file path relative to the projects directory
"""
try:
target_path = PROJECTS_DIR / file_path
if not target_path.exists():
raise HTTPException(status_code=404, detail="File does not exist")
if not target_path.is_file():
raise HTTPException(status_code=400, detail="Not a file")
# Guess the MIME type
mime_type, _ = mimetypes.guess_type(str(target_path))
if mime_type is None:
mime_type = "application/octet-stream"
return FileResponse(
path=str(target_path),
filename=target_path.name,
media_type=mime_type
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"File download failed: {str(e)}")
@router.delete("/delete")
async def delete_item(path: str):
"""
Delete a file or directory
Args:
path: the path to delete
"""
try:
target_path = PROJECTS_DIR / path
if not target_path.exists():
raise HTTPException(status_code=404, detail="Path does not exist")
was_file = target_path.is_file() # capture before deletion; is_file() is False once the path is gone
if was_file:
target_path.unlink()
elif target_path.is_dir():
shutil.rmtree(target_path)
return {
"success": True,
"message": f"{'File' if was_file else 'Directory'} deleted successfully",
"path": path
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Deletion failed: {str(e)}")
@router.post("/create-folder")
async def create_folder(path: str, name: str):
"""
Create a folder
Args:
path: the parent directory path
name: the name of the new folder
"""
try:
parent_path = PROJECTS_DIR / path
parent_path.mkdir(parents=True, exist_ok=True)
new_folder = parent_path / name
if new_folder.exists():
raise HTTPException(status_code=400, detail="Folder already exists")
new_folder.mkdir()
return {
"success": True,
"message": "Folder created successfully",
"path": str(Path(path) / name)
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to create folder: {str(e)}")
@router.post("/rename")
async def rename_item(old_path: str, new_name: str):
"""
Rename a file or folder
Args:
old_path: the original path
new_name: the new name
"""
try:
old_full_path = PROJECTS_DIR / old_path
if not old_full_path.exists():
raise HTTPException(status_code=404, detail="File or directory does not exist")
new_full_path = old_full_path.parent / new_name
if new_full_path.exists():
raise HTTPException(status_code=400, detail="Target name already exists")
old_full_path.rename(new_full_path)
return {
"success": True,
"message": "重命名成功",
"old_path": old_path,
"new_path": str(new_full_path.relative_to(PROJECTS_DIR))
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Rename failed: {str(e)}")
@router.post("/move")
async def move_item(source_path: str, target_path: str):
"""
Move a file or folder
Args:
source_path: the source path
target_path: the target path
"""
try:
source_full_path = PROJECTS_DIR / source_path
target_full_path = PROJECTS_DIR / target_path
if not source_full_path.exists():
raise HTTPException(status_code=404, detail="Source file or directory does not exist")
# Make sure the target directory exists
target_full_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(source_full_path), str(target_full_path))
return {
"success": True,
"message": "移动成功",
"source_path": source_path,
"target_path": target_path
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Move failed: {str(e)}")
@router.post("/copy")
async def copy_item(source_path: str, target_path: str):
"""
Copy a file or folder
Args:
source_path: the source path
target_path: the target path
"""
try:
source_full_path = PROJECTS_DIR / source_path
target_full_path = PROJECTS_DIR / target_path
if not source_full_path.exists():
raise HTTPException(status_code=404, detail="Source file or directory does not exist")
# Make sure the target directory exists
target_full_path.parent.mkdir(parents=True, exist_ok=True)
if source_full_path.is_file():
shutil.copy2(str(source_full_path), str(target_full_path))
else:
shutil.copytree(str(source_full_path), str(target_full_path))
return {
"success": True,
"message": "复制成功",
"source_path": source_path,
"target_path": target_path
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Copy failed: {str(e)}")
@router.get("/search")
async def search_files(query: str, path: str = "", file_type: Optional[str] = None):
"""
Search for files
Args:
query: the search keyword
path: the path to search
file_type: optional file type filter
"""
try:
search_path = PROJECTS_DIR / path
if not search_path.exists():
raise HTTPException(status_code=404, detail="Search path does not exist")
results = []
def scan_for_files(directory: Path):
try:
for item in directory.iterdir():
if item.name.startswith('.'):
continue
# Check whether the file name contains the keyword
if query.lower() in item.name.lower():
# Apply the file type filter
if file_type:
if item.suffix.lower() == file_type.lower():
results.append({
"name": item.name,
"path": str(item.relative_to(PROJECTS_DIR)),
"type": "directory" if item.is_dir() else "file"
})
else:
results.append({
"name": item.name,
"path": str(item.relative_to(PROJECTS_DIR)),
"type": "directory" if item.is_dir() else "file"
})
# Recursively search subdirectories
if item.is_dir():
scan_for_files(item)
except PermissionError:
pass
scan_for_files(search_path)
return {
"success": True,
"query": query,
"path": path,
"results": results,
"total": len(results)
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
@router.get("/info/{file_path:path}")
async def get_file_info(file_path: str):
"""
Get detailed information about a file or folder
Args:
file_path: the file path
"""
try:
target_path = PROJECTS_DIR / file_path
if not target_path.exists():
raise HTTPException(status_code=404, detail="Path does not exist")
stat = target_path.stat()
info = {
"name": target_path.name,
"path": file_path,
"type": "directory" if target_path.is_dir() else "file",
"size": stat.st_size if target_path.is_file() else 0,
"modified": stat.st_mtime,
"created": stat.st_ctime,
"permissions": oct(stat.st_mode)[-3:]
}
# If it is a file, add extra information
if target_path.is_file():
mime_type, _ = mimetypes.guess_type(str(target_path))
info["mime_type"] = mime_type or "unknown"
# Read a content preview (small files only)
if stat.st_size < 1024 * 1024: # under 1 MB
try:
with open(target_path, 'r', encoding='utf-8') as f:
content = f.read(1000) # read the first 1000 characters
info["preview"] = content
except (UnicodeDecodeError, OSError):
info["preview"] = "[Binary file, preview unavailable]"
return {
"success": True,
"info": info
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to get file info: {str(e)}")
@router.post("/download-folder-zip")
async def download_folder_as_zip(request: Dict[str, str]):
"""
Compress a folder into a ZIP archive for download
Args:
request: a JSON object containing a path field
"""
try:
folder_path = request.get("path", "")
if not folder_path:
raise HTTPException(status_code=400, detail="Path must not be empty")
target_path = PROJECTS_DIR / folder_path
if not target_path.exists():
raise HTTPException(status_code=404, detail="Folder does not exist")
if not target_path.is_dir():
raise HTTPException(status_code=400, detail="Path is not a folder")
# Compute the folder size and check whether it is too large
total_size = 0
file_count = 0
for file_path in target_path.rglob('*'):
if file_path.is_file():
total_size += file_path.stat().st_size
file_count += 1
# Cap at 500 MB
max_size = 500 * 1024 * 1024
if total_size > max_size:
raise HTTPException(
status_code=413,
detail=f"文件夹过大 ({formatFileSize(total_size)}),最大支持 {formatFileSize(max_size)}"
)
# Cap the number of files
max_files = 10000
if file_count > max_files:
raise HTTPException(
status_code=413,
detail=f"文件数量过多 ({file_count}),最大支持 {max_files} 个文件"
)
# Build the ZIP file name
folder_name = target_path.name
zip_filename = f"{folder_name}.zip"
# Create the ZIP archive
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED, compresslevel=6) as zipf:
# Add every file in the folder
for file_path in target_path.rglob('*'):
if file_path.is_file():
try:
# Compute the relative path to preserve the folder structure
arcname = file_path.relative_to(target_path)
zipf.write(file_path, arcname)
except (OSError, IOError) as e:
# Skip unreadable files, but log a warning
print(f"Warning: Skipping file {file_path}: {e}")
continue
zip_buffer.seek(0)
# Set the response headers
headers = {
'Content-Disposition': f'attachment; filename="{zip_filename}"',
'Content-Type': 'application/zip',
'Content-Length': str(len(zip_buffer.getvalue())),
'Cache-Control': 'no-cache'
}
# Create a streaming response
from fastapi.responses import StreamingResponse
async def generate_zip():
zip_buffer.seek(0)
yield zip_buffer.getvalue()
return StreamingResponse(
generate_zip(),
media_type="application/zip",
headers=headers
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to zip folder: {str(e)}")
@router.post("/download-multiple-zip")
async def download_multiple_items_as_zip(request: Dict[str, Any]):
"""
Compress multiple files and folders into a ZIP archive for download
Args:
request: a JSON object containing paths and filename fields
"""
try:
paths = request.get("paths", [])
filename = request.get("filename", "batch_download.zip")
if not paths:
raise HTTPException(status_code=400, detail="Please select files to download")
# Validate all paths
valid_paths = []
total_size = 0
file_count = 0
for path in paths:
target_path = PROJECTS_DIR / path
if not target_path.exists():
continue # Skip files that do not exist
if target_path.is_file():
total_size += target_path.stat().st_size
file_count += 1
valid_paths.append(path)
elif target_path.is_dir():
# Compute the folder size
for file_path in target_path.rglob('*'):
if file_path.is_file():
total_size += file_path.stat().st_size
file_count += 1
valid_paths.append(path)
if not valid_paths:
raise HTTPException(status_code=404, detail="No valid files found")
# Cap the total size
max_size = 500 * 1024 * 1024
if total_size > max_size:
raise HTTPException(
status_code=413,
detail=f"选中文件过大 ({formatFileSize(total_size)}),最大支持 {formatFileSize(max_size)}"
)
# Cap the number of files
max_files = 10000
if file_count > max_files:
raise HTTPException(
status_code=413,
detail=f"文件数量过多 ({file_count}),最大支持 {max_files} 个文件"
)
# Create the ZIP archive
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED, compresslevel=6) as zipf:
for path in valid_paths:
target_path = PROJECTS_DIR / path
if target_path.is_file():
# A single file
try:
zipf.write(target_path, target_path.name)
except (OSError, IOError) as e:
print(f"Warning: Skipping file {target_path}: {e}")
continue
elif target_path.is_dir():
# A folder
for file_path in target_path.rglob('*'):
if file_path.is_file():
try:
# Preserve the relative path structure
arcname = f"{target_path.name}/{file_path.relative_to(target_path)}"
zipf.write(file_path, arcname)
except (OSError, IOError) as e:
print(f"Warning: Skipping file {file_path}: {e}")
continue
zip_buffer.seek(0)
# Set the response headers
headers = {
'Content-Disposition': f'attachment; filename="{filename}"',
'Content-Type': 'application/zip',
'Content-Length': str(len(zip_buffer.getvalue())),
'Cache-Control': 'no-cache'
}
# Create a streaming response
from fastapi.responses import StreamingResponse
async def generate_zip():
zip_buffer.seek(0)
yield zip_buffer.getvalue()
return StreamingResponse(
generate_zip(),
media_type="application/zip",
headers=headers
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Batch compression failed: {str(e)}")
def formatFileSize(bytes_size: int) -> str:
"""格式化文件大小"""
if bytes_size == 0:
return "0 B"
k = 1024
sizes = ["B", "KB", "MB", "GB", "TB"]
i = int(math.floor(math.log(bytes_size, k)))
if i >= len(sizes):
i = len(sizes) - 1
return f"{bytes_size / math.pow(k, i):.1f} {sizes[i]}"
@router.post("/batch-operation")
async def batch_operation(operations: List[Dict[str, Any]]):
"""
Batch-operate on multiple files
Args:
operations: a list of operations, each with a type and the corresponding parameters
"""
try:
results = []
for op in operations:
op_type = op.get("type")
op_result = {"type": op_type, "success": False}
try:
if op_type == "delete":
await delete_item(op["path"])
op_result["success"] = True
op_result["message"] = "删除成功"
elif op_type == "move":
await move_item(op["source_path"], op["target_path"])
op_result["success"] = True
op_result["message"] = "移动成功"
elif op_type == "copy":
await copy_item(op["source_path"], op["target_path"])
op_result["success"] = True
op_result["message"] = "复制成功"
else:
op_result["error"] = f"不支持的操作类型: {op_type}"
except Exception as e:
op_result["error"] = str(e)
results.append(op_result)
return {
"success": True,
"results": results,
"total": len(operations),
"successful": sum(1 for r in results if r["success"])
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Batch operation failed: {str(e)}")
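
The router above exposes plain REST endpoints; a minimal client-side sketch using requests, assuming the router is mounted on an app served at http://localhost:8000 (host and port are assumptions):

import requests

BASE = 'http://localhost:8000/api/v1/files'  # assumed host and mount point

# Upload a file into projects/docs/
with open('report.pdf', 'rb') as f:
    resp = requests.post(f'{BASE}/upload', files={'file': f}, data={'path': 'docs'})
print(resp.json())

# List the directory recursively
print(requests.get(f'{BASE}/list', params={'path': 'docs', 'recursive': True}).json())

# Download the folder as a ZIP archive
resp = requests.post(f'{BASE}/download-folder-zip', json={'path': 'docs'})
with open('docs.zip', 'wb') as out:
    out.write(resp.content)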


@@ -7,9 +7,9 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from utils import (
Message, ChatRequest, ChatResponse,
get_global_agent_manager, init_global_sharded_agent_manager
Message, ChatRequest, ChatResponse
)
from agent.sharded_agent_manager import init_global_sharded_agent_manager
from utils.api_models import ChatRequestV2
from utils.fastapi_utils import (
process_messages, extract_guidelines_from_system_prompt, format_messages_to_chat_history,
@@ -205,6 +205,7 @@
# Process the results: the last one is the agent, the preceding ones are guideline batch results
agent = all_results[-1] # the created agent
batch_results = all_results[:-1] # the guideline batch results
print(f"batch_results:{batch_results}")
# Merge the guideline analysis results (using a JSON-format checks array)
all_checks = []


@@ -6,11 +6,11 @@ from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from utils import (
get_global_agent_manager, init_global_sharded_agent_manager,
get_global_connection_pool, init_global_connection_pool,
get_global_file_cache, init_global_file_cache,
setup_system_optimizations
)
from agent.sharded_agent_manager import init_global_sharded_agent_manager
try:
from utils.system_optimizer import apply_optimization_profile
except ImportError:


@@ -32,17 +32,19 @@ from .project_manager import (
)
# Import agent management modules
from .file_loaded_agent_manager import (
get_global_agent_manager,
init_global_agent_manager
)
# Note: These have been moved to agent package
# from .file_loaded_agent_manager import (
# get_global_agent_manager,
# init_global_agent_manager
# )
# Import optimized modules
from .sharded_agent_manager import (
ShardedAgentManager,
get_global_sharded_agent_manager,
init_global_sharded_agent_manager
)
# Note: These have been moved to agent package
# from .sharded_agent_manager import (
# ShardedAgentManager,
# get_global_sharded_agent_manager,
# init_global_sharded_agent_manager
# )
from .connection_pool import (
HTTPConnectionPool,
@@ -78,10 +80,11 @@ from .system_optimizer import (
)
# Import config cache module
from .config_cache import (
config_cache,
ConfigFileCache
)
# Note: This has been moved to agent package
# from .config_cache import (
# config_cache,
# ConfigFileCache
# )
from .agent_pool import (
AgentPool,
@@ -123,9 +126,10 @@ from .api_models import (
create_chat_response
)
from .prompt_loader import (
load_system_prompt,
)
# Note: This has been moved to agent package
# from .prompt_loader import (
# load_system_prompt,
# )
from .multi_project_manager import (
create_robot_project,
@@ -162,9 +166,9 @@ __all__ = [
'list_projects',
'get_project_stats',
# file_loaded_agent_manager
'get_global_agent_manager',
'init_global_agent_manager',
# file_loaded_agent_manager (moved to agent package)
# 'get_global_agent_manager',
# 'init_global_agent_manager',
# agent_pool
'AgentPool',
@@ -202,8 +206,8 @@ __all__ = [
'create_error_response',
'create_chat_response',
# prompt_loader
'load_system_prompt',
# prompt_loader (moved to agent package)
# 'load_system_prompt',
# multi_project_manager
'create_robot_project',