新增agent文件夹,修改import引用,增加custom_mcp_manager
This commit is contained in:
parent
7002019229
commit
58ac6e3024
0
agent/__init__.py
Normal file
0
agent/__init__.py
Normal file
474
agent/custom_mcp_manager.py
Normal file
474
agent/custom_mcp_manager.py
Normal file
@ -0,0 +1,474 @@
|
|||||||
|
# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import atexit
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from contextlib import AsyncExitStack
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from qwen_agent.log import logger
|
||||||
|
from qwen_agent.tools.base import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class CustomMCPManager:
|
||||||
|
_instance = None # Private class variable to store the unique instance
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(CustomMCPManager, cls).__new__(cls, *args, **kwargs)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if not hasattr(self, 'clients'): # The singleton should only be inited once
|
||||||
|
"""Set a new event loop in a separate thread"""
|
||||||
|
try:
|
||||||
|
import mcp # noqa
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError('Could not import mcp. Please install mcp with `pip install -U mcp`.') from e
|
||||||
|
|
||||||
|
load_dotenv() # Load environment variables from .env file
|
||||||
|
self.clients: dict = {}
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
self.loop_thread = threading.Thread(target=self.start_loop, daemon=True)
|
||||||
|
self.loop_thread.start()
|
||||||
|
|
||||||
|
# A fallback way to terminate MCP tool processes after Qwen-Agent exits
|
||||||
|
self.processes = []
|
||||||
|
self.monkey_patch_mcp_create_platform_compatible_process()
|
||||||
|
|
||||||
|
def monkey_patch_mcp_create_platform_compatible_process(self):
|
||||||
|
try:
|
||||||
|
import mcp.client.stdio
|
||||||
|
target = mcp.client.stdio._create_platform_compatible_process
|
||||||
|
except (ModuleNotFoundError, AttributeError) as e:
|
||||||
|
raise ImportError('Qwen-Agent needs to monkey patch MCP for process cleanup. '
|
||||||
|
'Please upgrade MCP to a higher version with `pip install -U mcp`.') from e
|
||||||
|
|
||||||
|
async def _monkey_patched_create_platform_compatible_process(*args, **kwargs):
|
||||||
|
process = await target(*args, **kwargs)
|
||||||
|
self.processes.append(process)
|
||||||
|
return process
|
||||||
|
|
||||||
|
mcp.client.stdio._create_platform_compatible_process = _monkey_patched_create_platform_compatible_process
|
||||||
|
|
||||||
|
def start_loop(self):
|
||||||
|
asyncio.set_event_loop(self.loop)
|
||||||
|
|
||||||
|
# Set a global exception handler to silently handle cross-task exceptions from MCP SSE connections
|
||||||
|
def exception_handler(loop, context):
|
||||||
|
exception = context.get('exception')
|
||||||
|
if exception:
|
||||||
|
# Silently handle cross-task exceptions from MCP SSE connections
|
||||||
|
if (isinstance(exception, RuntimeError) and
|
||||||
|
'Attempted to exit cancel scope in a different task' in str(exception)):
|
||||||
|
return # Silently ignore this type of exception
|
||||||
|
if (isinstance(exception, BaseExceptionGroup) and # noqa
|
||||||
|
'Attempted to exit cancel scope in a different task' in str(exception)): # noqa
|
||||||
|
return # Silently ignore this type of exception
|
||||||
|
|
||||||
|
# Other exceptions are handled normally
|
||||||
|
loop.default_exception_handler(context)
|
||||||
|
|
||||||
|
self.loop.set_exception_handler(exception_handler)
|
||||||
|
self.loop.run_forever()
|
||||||
|
|
||||||
|
def is_valid_mcp_servers(self, config: dict):
|
||||||
|
"""Example of mcp servers configuration:
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"memory": {
|
||||||
|
"command": "npx",
|
||||||
|
"args": ["-y", "@modelcontextprotocol/server-memory"]
|
||||||
|
},
|
||||||
|
"filesystem": {
|
||||||
|
"command": "npx",
|
||||||
|
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/allowed/files"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Check if the top-level key "mcpServers" exists and its value is a dictionary
|
||||||
|
if not isinstance(config, dict) or 'mcpServers' not in config or not isinstance(config['mcpServers'], dict):
|
||||||
|
return False
|
||||||
|
mcp_servers = config['mcpServers']
|
||||||
|
# Check each sub-item under "mcpServers"
|
||||||
|
for key in mcp_servers:
|
||||||
|
server = mcp_servers[key]
|
||||||
|
# Each sub-item must be a dictionary
|
||||||
|
if not isinstance(server, dict):
|
||||||
|
return False
|
||||||
|
if 'command' in server:
|
||||||
|
# "command" must be a string
|
||||||
|
if not isinstance(server['command'], str):
|
||||||
|
return False
|
||||||
|
# "args" must be a list
|
||||||
|
if 'args' not in server or not isinstance(server['args'], list):
|
||||||
|
return False
|
||||||
|
if 'url' in server:
|
||||||
|
# "url" must be a string
|
||||||
|
if not isinstance(server['url'], str):
|
||||||
|
return False
|
||||||
|
# "headers" must be a dictionary
|
||||||
|
if 'headers' in server and not isinstance(server['headers'], dict):
|
||||||
|
return False
|
||||||
|
# If the "env" key exists, it must be a dictionary
|
||||||
|
if 'env' in server and not isinstance(server['env'], dict):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def initConfig(self, config: Dict):
|
||||||
|
if not self.is_valid_mcp_servers(config):
|
||||||
|
raise ValueError('Config of mcpservers is not valid')
|
||||||
|
logger.info(f'Initializing MCP tools from mcp servers: {list(config["mcpServers"].keys())}')
|
||||||
|
# Submit coroutine to the event loop and wait for the result
|
||||||
|
future = asyncio.run_coroutine_threadsafe(self.init_config_async(config), self.loop)
|
||||||
|
try:
|
||||||
|
result = future.result() # You can specify a timeout if desired
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f'Failed in initializing MCP tools: {e}')
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def init_config_async(self, config: Dict):
|
||||||
|
tools: list = []
|
||||||
|
mcp_servers = config['mcpServers']
|
||||||
|
for server_name in mcp_servers:
|
||||||
|
client = CustomMCPClient()
|
||||||
|
server = mcp_servers[server_name]
|
||||||
|
await client.connection_server(mcp_server_name=server_name,
|
||||||
|
mcp_server=server) # Attempt to connect to the server
|
||||||
|
|
||||||
|
client_id = server_name + '_' + str(
|
||||||
|
uuid.uuid4()) # To allow the same server name be used across different running agents
|
||||||
|
client.client_id = client_id # Ensure client_id is set on the client instance
|
||||||
|
self.clients[client_id] = client # Add to clients dict after successful connection
|
||||||
|
for tool in client.tools:
|
||||||
|
"""MCP tool example:
|
||||||
|
{
|
||||||
|
"name": "read_query",
|
||||||
|
"description": "Execute a SELECT query on the SQLite database",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "SELECT SQL query to execute"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
parameters = tool.inputSchema
|
||||||
|
# The required field in inputSchema may be empty and needs to be initialized.
|
||||||
|
if 'required' not in parameters:
|
||||||
|
parameters['required'] = []
|
||||||
|
# Remove keys from parameters that do not conform to the standard OpenAI schema
|
||||||
|
# Check if the required fields exist
|
||||||
|
required_fields = {'type', 'properties', 'required'}
|
||||||
|
missing_fields = required_fields - parameters.keys()
|
||||||
|
if missing_fields:
|
||||||
|
raise ValueError(f'Missing required fields in schema: {missing_fields}')
|
||||||
|
|
||||||
|
# Keep only the necessary fields
|
||||||
|
cleaned_parameters = {
|
||||||
|
'type': parameters['type'],
|
||||||
|
'properties': parameters['properties'],
|
||||||
|
'required': parameters['required']
|
||||||
|
}
|
||||||
|
register_name = server_name + '-' + tool.name
|
||||||
|
agent_tool = self.create_tool_class(register_name=register_name,
|
||||||
|
register_client_id=client_id,
|
||||||
|
tool_name=tool.name,
|
||||||
|
tool_desc=tool.description,
|
||||||
|
tool_parameters=cleaned_parameters)
|
||||||
|
tools.append(agent_tool)
|
||||||
|
|
||||||
|
if client.resources:
|
||||||
|
"""MCP resource example:
|
||||||
|
{
|
||||||
|
uri: string; // Unique identifier for the resource
|
||||||
|
name: string; // Human-readable name
|
||||||
|
description?: string; // Optional description
|
||||||
|
mimeType?: string; // Optional MIME type
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# List resources
|
||||||
|
list_resources_tool_name = server_name + '-' + 'list_resources'
|
||||||
|
list_resources_params = {'type': 'object', 'properties': {}, 'required': []}
|
||||||
|
list_resources_agent_tool = self.create_tool_class(
|
||||||
|
register_name=list_resources_tool_name,
|
||||||
|
register_client_id=client_id,
|
||||||
|
tool_name='list_resources',
|
||||||
|
tool_desc='Servers expose a list of concrete resources through this tool. '
|
||||||
|
'By invoking it, you can discover the available resources and obtain resource templates, which help clients understand how to construct valid URIs. '
|
||||||
|
'These URI formats will be used as input parameters for the read_resource function. ',
|
||||||
|
tool_parameters=list_resources_params)
|
||||||
|
tools.append(list_resources_agent_tool)
|
||||||
|
|
||||||
|
# Read resource
|
||||||
|
resources_template_str = '' # Check if there are resource templates
|
||||||
|
try:
|
||||||
|
list_resource_templates = await client.session.list_resource_templates(
|
||||||
|
) # Check if the server has resources tesmplate
|
||||||
|
if list_resource_templates.resourceTemplates:
|
||||||
|
resources_template_str = '\n'.join(
|
||||||
|
str(template) for template in list_resource_templates.resourceTemplates)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f'Failed in listing MCP resource templates: {e}')
|
||||||
|
|
||||||
|
read_resource_tool_name = server_name + '-' + 'read_resource'
|
||||||
|
read_resource_params = {
|
||||||
|
'type': 'object',
|
||||||
|
'properties': {
|
||||||
|
'uri': {
|
||||||
|
'type': 'string',
|
||||||
|
'description': 'The URI identifying the specific resource to access'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'required': ['uri']
|
||||||
|
}
|
||||||
|
original_tool_desc = 'Request to access a resource provided by a connected MCP server. Resources represent data sources that can be used as context, such as files, API responses, or system information.'
|
||||||
|
if resources_template_str:
|
||||||
|
tool_desc = original_tool_desc + '\nResource Templates:\n' + resources_template_str
|
||||||
|
else:
|
||||||
|
tool_desc = original_tool_desc
|
||||||
|
read_resource_agent_tool = self.create_tool_class(register_name=read_resource_tool_name,
|
||||||
|
register_client_id=client_id,
|
||||||
|
tool_name='read_resource',
|
||||||
|
tool_desc=tool_desc,
|
||||||
|
tool_parameters=read_resource_params)
|
||||||
|
tools.append(read_resource_agent_tool)
|
||||||
|
|
||||||
|
return tools
|
||||||
|
|
||||||
|
def create_tool_class(self, register_name, register_client_id, tool_name, tool_desc, tool_parameters):
|
||||||
|
|
||||||
|
class ToolClass(BaseTool):
|
||||||
|
name = register_name
|
||||||
|
description = tool_desc
|
||||||
|
parameters = tool_parameters
|
||||||
|
client_id = register_client_id
|
||||||
|
|
||||||
|
def call(self, params: Union[str, dict], **kwargs) -> str:
|
||||||
|
tool_args = json.loads(params)
|
||||||
|
# Submit coroutine to the event loop and wait for the result
|
||||||
|
manager = CustomMCPManager()
|
||||||
|
client = manager.clients[self.client_id]
|
||||||
|
future = asyncio.run_coroutine_threadsafe(client.execute_function(tool_name, tool_args), manager.loop)
|
||||||
|
try:
|
||||||
|
result = future.result()
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f'Failed in executing MCP tool: {e}')
|
||||||
|
raise e
|
||||||
|
|
||||||
|
ToolClass.__name__ = f'{register_name}_Class'
|
||||||
|
return ToolClass()
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
futures = []
|
||||||
|
for client_id in list(self.clients.keys()):
|
||||||
|
client: CustomMCPClient = self.clients[client_id]
|
||||||
|
future = asyncio.run_coroutine_threadsafe(client.cleanup(), self.loop)
|
||||||
|
futures.append(future)
|
||||||
|
del self.clients[client_id]
|
||||||
|
time.sleep(1) # Wait for the graceful cleanups, otherwise fall back
|
||||||
|
|
||||||
|
# fallback
|
||||||
|
if asyncio.all_tasks(self.loop):
|
||||||
|
logger.info(
|
||||||
|
'There are still tasks in `CustomMCPManager().loop`, force terminating the MCP tool processes. There may be some exceptions.'
|
||||||
|
)
|
||||||
|
for process in self.processes:
|
||||||
|
try:
|
||||||
|
process.terminate()
|
||||||
|
except ProcessLookupError:
|
||||||
|
pass # it's ok, the process may exit earlier
|
||||||
|
|
||||||
|
self.loop.call_soon_threadsafe(self.loop.stop)
|
||||||
|
self.loop_thread.join()
|
||||||
|
|
||||||
|
|
||||||
|
class CustomMCPClient:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
from mcp import ClientSession
|
||||||
|
self.session: Optional[ClientSession] = None
|
||||||
|
self.tools: list = None
|
||||||
|
self.exit_stack = AsyncExitStack()
|
||||||
|
self.resources: bool = False
|
||||||
|
self._last_mcp_server_name = None
|
||||||
|
self._last_mcp_server = None
|
||||||
|
self.client_id = None # For replacing in MCPManager.clients
|
||||||
|
|
||||||
|
async def connection_server(self, mcp_server_name, mcp_server):
|
||||||
|
from mcp import ClientSession, StdioServerParameters
|
||||||
|
from mcp.client.sse import sse_client
|
||||||
|
from mcp.client.stdio import stdio_client
|
||||||
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
|
"""Connect to an MCP server and retrieve the available tools."""
|
||||||
|
# Save parameters
|
||||||
|
self._last_mcp_server_name = mcp_server_name
|
||||||
|
self._last_mcp_server = mcp_server
|
||||||
|
|
||||||
|
try:
|
||||||
|
if 'url' in mcp_server:
|
||||||
|
url = mcp_server.get('url')
|
||||||
|
sse_read_timeout = mcp_server.get('sse_read_timeout', 300)
|
||||||
|
logger.info(f'{mcp_server_name} sse_read_timeout: {sse_read_timeout}s')
|
||||||
|
if mcp_server.get('type', 'sse') == 'streamable-http':
|
||||||
|
# streamable-http mode
|
||||||
|
"""streamable-http mode mcp example:
|
||||||
|
{"mcpServers": {
|
||||||
|
"streamable-mcp-server": {
|
||||||
|
"type": "streamable-http",
|
||||||
|
"url":"http://0.0.0.0:8000/mcp"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
headers = mcp_server.get('headers', {})
|
||||||
|
self._streams_context = streamablehttp_client(
|
||||||
|
url=url, headers=headers, sse_read_timeout=datetime.timedelta(seconds=sse_read_timeout))
|
||||||
|
read_stream, write_stream, get_session_id = await self.exit_stack.enter_async_context(
|
||||||
|
self._streams_context)
|
||||||
|
self._session_context = ClientSession(read_stream, write_stream)
|
||||||
|
self.session = await self.exit_stack.enter_async_context(self._session_context)
|
||||||
|
else:
|
||||||
|
# sse mode
|
||||||
|
headers = mcp_server.get('headers', {'Accept': 'text/event-stream'})
|
||||||
|
self._streams_context = sse_client(url, headers, sse_read_timeout=sse_read_timeout)
|
||||||
|
streams = await self.exit_stack.enter_async_context(self._streams_context)
|
||||||
|
self._session_context = ClientSession(*streams)
|
||||||
|
self.session = await self.exit_stack.enter_async_context(self._session_context)
|
||||||
|
else:
|
||||||
|
server_params = StdioServerParameters(command=mcp_server['command'],
|
||||||
|
args=mcp_server['args'],
|
||||||
|
env=mcp_server.get('env', None))
|
||||||
|
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
|
||||||
|
self.stdio, self.write = stdio_transport
|
||||||
|
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
|
||||||
|
logger.info(
|
||||||
|
f'Initializing a MCP stdio_client, if this takes forever, please check the config of this mcp server: {mcp_server_name}'
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.session.initialize()
|
||||||
|
list_tools = await self.session.list_tools()
|
||||||
|
self.tools = list_tools.tools
|
||||||
|
try:
|
||||||
|
list_resources = await self.session.list_resources() # Check if the server has resources
|
||||||
|
if list_resources.resources:
|
||||||
|
self.resources = True
|
||||||
|
except Exception:
|
||||||
|
# logger.info(f"No list resources: {e}")
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f'Failed in connecting to MCP server: {e}')
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def reconnect(self):
|
||||||
|
# Create a new MCPClient and connect
|
||||||
|
if self.client_id is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
'Cannot reconnect: client_id is None. This usually means the client was not properly registered in MCPManager.'
|
||||||
|
)
|
||||||
|
new_client = CustomMCPClient()
|
||||||
|
new_client.client_id = self.client_id
|
||||||
|
await new_client.connection_server(self._last_mcp_server_name, self._last_mcp_server)
|
||||||
|
return new_client
|
||||||
|
|
||||||
|
async def execute_function(self, tool_name, tool_args: dict):
|
||||||
|
from mcp.types import TextResourceContents
|
||||||
|
|
||||||
|
# Check if session is alive
|
||||||
|
try:
|
||||||
|
await self.session.send_ping()
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f"Session is not alive, please increase 'sse_read_timeout' in the config, try reconnect: {e}")
|
||||||
|
# Auto reconnect
|
||||||
|
try:
|
||||||
|
manager = CustomMCPManager()
|
||||||
|
if self.client_id is not None:
|
||||||
|
manager.clients[self.client_id] = await self.reconnect()
|
||||||
|
return await manager.clients[self.client_id].execute_function(tool_name, tool_args)
|
||||||
|
else:
|
||||||
|
logger.info('Reconnect failed: client_id is None')
|
||||||
|
return 'Session reconnect (client creation) exception: client_id is None'
|
||||||
|
except Exception as e3:
|
||||||
|
logger.info(f'Reconnect (client creation) exception type: {type(e3)}, value: {repr(e3)}')
|
||||||
|
return f'Session reconnect (client creation) exception: {e3}'
|
||||||
|
if tool_name == 'list_resources':
|
||||||
|
try:
|
||||||
|
list_resources = await self.session.list_resources()
|
||||||
|
if list_resources.resources:
|
||||||
|
resources_str = '\n\n'.join(str(resource) for resource in list_resources.resources)
|
||||||
|
else:
|
||||||
|
resources_str = 'No resources found'
|
||||||
|
return resources_str
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f'No list resources: {e}')
|
||||||
|
return f'Error: {e}'
|
||||||
|
elif tool_name == 'read_resource':
|
||||||
|
try:
|
||||||
|
uri = tool_args.get('uri')
|
||||||
|
if not uri:
|
||||||
|
raise ValueError('URI is required for read_resource')
|
||||||
|
read_resource = await self.session.read_resource(uri)
|
||||||
|
texts = []
|
||||||
|
for resource in read_resource.contents:
|
||||||
|
if isinstance(resource, TextResourceContents):
|
||||||
|
texts.append(resource.text)
|
||||||
|
# if isinstance(resource, BlobResourceContents):
|
||||||
|
# texts.append(resource.blob)
|
||||||
|
if texts:
|
||||||
|
return '\n\n'.join(texts)
|
||||||
|
else:
|
||||||
|
return 'Failed to read resource'
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f'Failed to read resource: {e}')
|
||||||
|
return f'Error: {e}'
|
||||||
|
else:
|
||||||
|
response = await self.session.call_tool(tool_name, tool_args)
|
||||||
|
texts = []
|
||||||
|
for content in response.content:
|
||||||
|
if content.type == 'text':
|
||||||
|
texts.append(content.text)
|
||||||
|
if texts:
|
||||||
|
return '\n\n'.join(texts)
|
||||||
|
else:
|
||||||
|
return 'execute error'
|
||||||
|
|
||||||
|
async def cleanup(self):
|
||||||
|
await self.exit_stack.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_mcp(_sig_num=None, _frame=None):
|
||||||
|
if CustomMCPManager._instance is None:
|
||||||
|
return
|
||||||
|
manager = CustomMCPManager()
|
||||||
|
manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
# Make sure all subprocesses are terminated even if killed abnormally
|
||||||
|
if threading.current_thread() is threading.main_thread():
|
||||||
|
atexit.register(_cleanup_mcp)
|
||||||
@ -23,8 +23,8 @@ from typing import Dict, List, Optional
|
|||||||
from qwen_agent.agents import Assistant
|
from qwen_agent.agents import Assistant
|
||||||
from qwen_agent.log import logger
|
from qwen_agent.log import logger
|
||||||
|
|
||||||
from modified_assistant import init_modified_agent_service_with_files, update_agent_llm
|
from agent.modified_assistant import init_modified_agent_service_with_files, update_agent_llm
|
||||||
from .prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
||||||
|
|
||||||
|
|
||||||
class FileLoadedAgentManager:
|
class FileLoadedAgentManager:
|
||||||
@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Dict, Iterator, List, Literal, Optional, Union
|
from typing import Dict, Iterator, List, Literal, Optional, Union
|
||||||
@ -21,7 +22,11 @@ from typing import Dict, Iterator, List, Literal, Optional, Union
|
|||||||
from qwen_agent.agents import Assistant
|
from qwen_agent.agents import Assistant
|
||||||
from qwen_agent.llm.schema import ASSISTANT, FUNCTION, Message
|
from qwen_agent.llm.schema import ASSISTANT, FUNCTION, Message
|
||||||
from qwen_agent.llm.oai import TextChatAtOAI
|
from qwen_agent.llm.oai import TextChatAtOAI
|
||||||
from utils.logger import tool_logger
|
from qwen_agent.tools import BaseTool
|
||||||
|
from agent.custom_mcp_manager import CustomMCPManager
|
||||||
|
|
||||||
|
# 设置工具日志记录器
|
||||||
|
tool_logger = logging.getLogger('tool_logger')
|
||||||
|
|
||||||
class ModifiedAssistant(Assistant):
|
class ModifiedAssistant(Assistant):
|
||||||
"""
|
"""
|
||||||
@ -94,6 +99,26 @@ class ModifiedAssistant(Assistant):
|
|||||||
error_message = f'An error occurred when calling tool {tool_name}: {type(ex).__name__}: {str(ex)}'
|
error_message = f'An error occurred when calling tool {tool_name}: {type(ex).__name__}: {str(ex)}'
|
||||||
return error_message
|
return error_message
|
||||||
|
|
||||||
|
def _init_tool(self, tool: Union[str, Dict, BaseTool]):
|
||||||
|
"""重写工具初始化方法,使用CustomMCPManager处理MCP服务器配置"""
|
||||||
|
if isinstance(tool, BaseTool):
|
||||||
|
# 处理BaseTool实例
|
||||||
|
tool_name = tool.name
|
||||||
|
if tool_name in self.function_map:
|
||||||
|
tool_logger.warning(f'Repeatedly adding tool {tool_name}, will use the newest tool in function list')
|
||||||
|
self.function_map[tool_name] = tool
|
||||||
|
elif isinstance(tool, dict) and 'mcpServers' in tool:
|
||||||
|
# 使用CustomMCPManager处理MCP服务器配置,支持headers
|
||||||
|
tools = CustomMCPManager().initConfig(tool)
|
||||||
|
for tool in tools:
|
||||||
|
tool_name = tool.name
|
||||||
|
if tool_name in self.function_map:
|
||||||
|
tool_logger.warning(f'Repeatedly adding tool {tool_name}, will use the newest tool in function list')
|
||||||
|
self.function_map[tool_name] = tool
|
||||||
|
else:
|
||||||
|
# 调用父类的处理方法
|
||||||
|
super()._init_tool(tool)
|
||||||
|
|
||||||
def _call_llm_with_retry(self, messages: List[Message], functions=None, extra_generate_cfg=None, max_retries: int = 5) -> Iterator:
|
def _call_llm_with_retry(self, messages: List[Message], functions=None, extra_generate_cfg=None, max_retries: int = 5) -> Iterator:
|
||||||
"""带重试机制的LLM调用
|
"""带重试机制的LLM调用
|
||||||
|
|
||||||
@ -109,7 +109,7 @@ async def load_system_prompt_async(project_dir: str, language: str = None, syste
|
|||||||
Returns:
|
Returns:
|
||||||
str: 加载到的系统提示词内容
|
str: 加载到的系统提示词内容
|
||||||
"""
|
"""
|
||||||
from .config_cache import config_cache
|
from agent.config_cache import config_cache
|
||||||
|
|
||||||
# 获取语言显示名称
|
# 获取语言显示名称
|
||||||
language_display_map = {
|
language_display_map = {
|
||||||
@ -235,7 +235,7 @@ async def load_mcp_settings_async(project_dir: str, mcp_settings: list=None, bot
|
|||||||
支持在 mcp_settings.json 的 args 中使用 {dataset_dir} 占位符,
|
支持在 mcp_settings.json 的 args 中使用 {dataset_dir} 占位符,
|
||||||
会在 init_modified_agent_service_with_files 中被替换为实际的路径。
|
会在 init_modified_agent_service_with_files 中被替换为实际的路径。
|
||||||
"""
|
"""
|
||||||
from .config_cache import config_cache
|
from agent.config_cache import config_cache
|
||||||
|
|
||||||
# 1. 首先读取默认MCP设置
|
# 1. 首先读取默认MCP设置
|
||||||
default_mcp_settings = []
|
default_mcp_settings = []
|
||||||
@ -26,8 +26,8 @@ from collections import defaultdict
|
|||||||
from qwen_agent.agents import Assistant
|
from qwen_agent.agents import Assistant
|
||||||
from qwen_agent.log import logger
|
from qwen_agent.log import logger
|
||||||
|
|
||||||
from modified_assistant import init_modified_agent_service_with_files, update_agent_llm
|
from agent.modified_assistant import init_modified_agent_service_with_files, update_agent_llm
|
||||||
from .prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
||||||
|
|
||||||
|
|
||||||
class ShardedAgentManager:
|
class ShardedAgentManager:
|
||||||
@ -1,691 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
文件管理API - WebDAV的HTTP API替代方案
|
|
||||||
提供RESTful接口来管理projects文件夹
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Optional, Dict, Any
|
|
||||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
|
|
||||||
from fastapi.responses import FileResponse, StreamingResponse
|
|
||||||
import mimetypes
|
|
||||||
import json
|
|
||||||
import zipfile
|
|
||||||
import tempfile
|
|
||||||
import io
|
|
||||||
import math
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/files", tags=["file_management"])
|
|
||||||
|
|
||||||
PROJECTS_DIR = Path("projects")
|
|
||||||
PROJECTS_DIR.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/list")
|
|
||||||
async def list_files(path: str = "", recursive: bool = False):
|
|
||||||
"""
|
|
||||||
列出目录内容
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: 相对路径,空字符串表示根目录
|
|
||||||
recursive: 是否递归列出所有子目录
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
target_path = PROJECTS_DIR / path
|
|
||||||
|
|
||||||
if not target_path.exists():
|
|
||||||
raise HTTPException(status_code=404, detail="路径不存在")
|
|
||||||
|
|
||||||
if not target_path.is_dir():
|
|
||||||
raise HTTPException(status_code=400, detail="路径不是目录")
|
|
||||||
|
|
||||||
def scan_directory(directory: Path, base_path: Path = PROJECTS_DIR) -> List[Dict[str, Any]]:
|
|
||||||
items = []
|
|
||||||
try:
|
|
||||||
for item in directory.iterdir():
|
|
||||||
# 跳过隐藏文件
|
|
||||||
if item.name.startswith('.'):
|
|
||||||
continue
|
|
||||||
|
|
||||||
relative_path = item.relative_to(base_path)
|
|
||||||
stat = item.stat()
|
|
||||||
|
|
||||||
item_info = {
|
|
||||||
"name": item.name,
|
|
||||||
"path": str(relative_path),
|
|
||||||
"type": "directory" if item.is_dir() else "file",
|
|
||||||
"size": stat.st_size if item.is_file() else 0,
|
|
||||||
"modified": stat.st_mtime,
|
|
||||||
"created": stat.st_ctime
|
|
||||||
}
|
|
||||||
|
|
||||||
items.append(item_info)
|
|
||||||
|
|
||||||
# 递归扫描子目录
|
|
||||||
if recursive and item.is_dir():
|
|
||||||
items.extend(scan_directory(item, base_path))
|
|
||||||
|
|
||||||
except PermissionError:
|
|
||||||
pass
|
|
||||||
return items
|
|
||||||
|
|
||||||
items = scan_directory(target_path)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"path": path,
|
|
||||||
"items": items,
|
|
||||||
"total": len(items)
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"列出目录失败: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/upload")
|
|
||||||
async def upload_file(file: UploadFile = File(...), path: str = Form("")):
|
|
||||||
"""
|
|
||||||
上传文件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file: 上传的文件
|
|
||||||
path: 目标路径(相对于projects目录)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
target_dir = PROJECTS_DIR / path
|
|
||||||
target_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
file_path = target_dir / file.filename
|
|
||||||
|
|
||||||
# 如果文件已存在,检查是否覆盖
|
|
||||||
if file_path.exists():
|
|
||||||
# 可以添加版本控制或重命名逻辑
|
|
||||||
pass
|
|
||||||
|
|
||||||
with open(file_path, "wb") as buffer:
|
|
||||||
content = await file.read()
|
|
||||||
buffer.write(content)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message": "文件上传成功",
|
|
||||||
"filename": file.filename,
|
|
||||||
"path": str(Path(path) / file.filename),
|
|
||||||
"size": len(content)
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"文件上传失败: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/download/{file_path:path}")
|
|
||||||
async def download_file(file_path: str):
|
|
||||||
"""
|
|
||||||
下载文件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: 文件相对路径
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
target_path = PROJECTS_DIR / file_path
|
|
||||||
|
|
||||||
if not target_path.exists():
|
|
||||||
raise HTTPException(status_code=404, detail="文件不存在")
|
|
||||||
|
|
||||||
if not target_path.is_file():
|
|
||||||
raise HTTPException(status_code=400, detail="不是文件")
|
|
||||||
|
|
||||||
# 猜测MIME类型
|
|
||||||
mime_type, _ = mimetypes.guess_type(str(target_path))
|
|
||||||
if mime_type is None:
|
|
||||||
mime_type = "application/octet-stream"
|
|
||||||
|
|
||||||
return FileResponse(
|
|
||||||
path=str(target_path),
|
|
||||||
filename=target_path.name,
|
|
||||||
media_type=mime_type
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"文件下载失败: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/delete")
|
|
||||||
async def delete_item(path: str):
|
|
||||||
"""
|
|
||||||
删除文件或目录
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: 要删除的路径
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
target_path = PROJECTS_DIR / path
|
|
||||||
|
|
||||||
if not target_path.exists():
|
|
||||||
raise HTTPException(status_code=404, detail="路径不存在")
|
|
||||||
|
|
||||||
if target_path.is_file():
|
|
||||||
target_path.unlink()
|
|
||||||
elif target_path.is_dir():
|
|
||||||
shutil.rmtree(target_path)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message": f"{'文件' if target_path.is_file() else '目录'}删除成功",
|
|
||||||
"path": path
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/create-folder")
|
|
||||||
async def create_folder(path: str, name: str):
|
|
||||||
"""
|
|
||||||
创建文件夹
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: 父目录路径
|
|
||||||
name: 新文件夹名称
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
parent_path = PROJECTS_DIR / path
|
|
||||||
parent_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
new_folder = parent_path / name
|
|
||||||
|
|
||||||
if new_folder.exists():
|
|
||||||
raise HTTPException(status_code=400, detail="文件夹已存在")
|
|
||||||
|
|
||||||
new_folder.mkdir()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message": "文件夹创建成功",
|
|
||||||
"path": str(Path(path) / name)
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"创建文件夹失败: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/rename")
|
|
||||||
async def rename_item(old_path: str, new_name: str):
|
|
||||||
"""
|
|
||||||
重命名文件或文件夹
|
|
||||||
|
|
||||||
Args:
|
|
||||||
old_path: 原路径
|
|
||||||
new_name: 新名称
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
old_full_path = PROJECTS_DIR / old_path
|
|
||||||
|
|
||||||
if not old_full_path.exists():
|
|
||||||
raise HTTPException(status_code=404, detail="文件或目录不存在")
|
|
||||||
|
|
||||||
new_full_path = old_full_path.parent / new_name
|
|
||||||
|
|
||||||
if new_full_path.exists():
|
|
||||||
raise HTTPException(status_code=400, detail="目标名称已存在")
|
|
||||||
|
|
||||||
old_full_path.rename(new_full_path)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message": "重命名成功",
|
|
||||||
"old_path": old_path,
|
|
||||||
"new_path": str(new_full_path.relative_to(PROJECTS_DIR))
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"重命名失败: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/move")
|
|
||||||
async def move_item(source_path: str, target_path: str):
|
|
||||||
"""
|
|
||||||
移动文件或文件夹
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_path: 源路径
|
|
||||||
target_path: 目标路径
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
source_full_path = PROJECTS_DIR / source_path
|
|
||||||
target_full_path = PROJECTS_DIR / target_path
|
|
||||||
|
|
||||||
if not source_full_path.exists():
|
|
||||||
raise HTTPException(status_code=404, detail="源文件或目录不存在")
|
|
||||||
|
|
||||||
# 确保目标目录存在
|
|
||||||
target_full_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
shutil.move(str(source_full_path), str(target_full_path))
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message": "移动成功",
|
|
||||||
"source_path": source_path,
|
|
||||||
"target_path": target_path
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"移动失败: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/copy")
|
|
||||||
async def copy_item(source_path: str, target_path: str):
|
|
||||||
"""
|
|
||||||
复制文件或文件夹
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_path: 源路径
|
|
||||||
target_path: 目标路径
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
source_full_path = PROJECTS_DIR / source_path
|
|
||||||
target_full_path = PROJECTS_DIR / target_path
|
|
||||||
|
|
||||||
if not source_full_path.exists():
|
|
||||||
raise HTTPException(status_code=404, detail="源文件或目录不存在")
|
|
||||||
|
|
||||||
# 确保目标目录存在
|
|
||||||
target_full_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
if source_full_path.is_file():
|
|
||||||
shutil.copy2(str(source_full_path), str(target_full_path))
|
|
||||||
else:
|
|
||||||
shutil.copytree(str(source_full_path), str(target_full_path))
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message": "复制成功",
|
|
||||||
"source_path": source_path,
|
|
||||||
"target_path": target_path
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"复制失败: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/search")
|
|
||||||
async def search_files(query: str, path: str = "", file_type: Optional[str] = None):
|
|
||||||
"""
|
|
||||||
搜索文件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 搜索关键词
|
|
||||||
path: 搜索路径
|
|
||||||
file_type: 文件类型过滤
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
search_path = PROJECTS_DIR / path
|
|
||||||
|
|
||||||
if not search_path.exists():
|
|
||||||
raise HTTPException(status_code=404, detail="搜索路径不存在")
|
|
||||||
|
|
||||||
results = []
|
|
||||||
|
|
||||||
def scan_for_files(directory: Path):
|
|
||||||
try:
|
|
||||||
for item in directory.iterdir():
|
|
||||||
if item.name.startswith('.'):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 检查文件名是否包含关键词
|
|
||||||
if query.lower() in item.name.lower():
|
|
||||||
# 检查文件类型过滤
|
|
||||||
if file_type:
|
|
||||||
if item.suffix.lower() == file_type.lower():
|
|
||||||
results.append({
|
|
||||||
"name": item.name,
|
|
||||||
"path": str(item.relative_to(PROJECTS_DIR)),
|
|
||||||
"type": "directory" if item.is_dir() else "file"
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
results.append({
|
|
||||||
"name": item.name,
|
|
||||||
"path": str(item.relative_to(PROJECTS_DIR)),
|
|
||||||
"type": "directory" if item.is_dir() else "file"
|
|
||||||
})
|
|
||||||
|
|
||||||
# 递归搜索子目录
|
|
||||||
if item.is_dir():
|
|
||||||
scan_for_files(item)
|
|
||||||
|
|
||||||
except PermissionError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
scan_for_files(search_path)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"query": query,
|
|
||||||
"path": path,
|
|
||||||
"results": results,
|
|
||||||
"total": len(results)
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"搜索失败: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/info/{file_path:path}")
|
|
||||||
async def get_file_info(file_path: str):
|
|
||||||
"""
|
|
||||||
获取文件或文件夹详细信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: 文件路径
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
target_path = PROJECTS_DIR / file_path
|
|
||||||
|
|
||||||
if not target_path.exists():
|
|
||||||
raise HTTPException(status_code=404, detail="路径不存在")
|
|
||||||
|
|
||||||
stat = target_path.stat()
|
|
||||||
|
|
||||||
info = {
|
|
||||||
"name": target_path.name,
|
|
||||||
"path": file_path,
|
|
||||||
"type": "directory" if target_path.is_dir() else "file",
|
|
||||||
"size": stat.st_size if target_path.is_file() else 0,
|
|
||||||
"modified": stat.st_mtime,
|
|
||||||
"created": stat.st_ctime,
|
|
||||||
"permissions": oct(stat.st_mode)[-3:]
|
|
||||||
}
|
|
||||||
|
|
||||||
# 如果是文件,添加额外信息
|
|
||||||
if target_path.is_file():
|
|
||||||
mime_type, _ = mimetypes.guess_type(str(target_path))
|
|
||||||
info["mime_type"] = mime_type or "unknown"
|
|
||||||
|
|
||||||
# 读取文件内容预览(仅对小文件)
|
|
||||||
if stat.st_size < 1024 * 1024: # 小于1MB
|
|
||||||
try:
|
|
||||||
with open(target_path, 'r', encoding='utf-8') as f:
|
|
||||||
content = f.read(1000) # 读取前1000字符
|
|
||||||
info["preview"] = content
|
|
||||||
except:
|
|
||||||
info["preview"] = "[二进制文件,无法预览]"
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"info": info
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"获取文件信息失败: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/download-folder-zip")
|
|
||||||
async def download_folder_as_zip(request: Dict[str, str]):
|
|
||||||
"""
|
|
||||||
将文件夹压缩为ZIP并下载
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: 包含path字段的JSON对象
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
folder_path = request.get("path", "")
|
|
||||||
|
|
||||||
if not folder_path:
|
|
||||||
raise HTTPException(status_code=400, detail="路径不能为空")
|
|
||||||
|
|
||||||
target_path = PROJECTS_DIR / folder_path
|
|
||||||
|
|
||||||
if not target_path.exists():
|
|
||||||
raise HTTPException(status_code=404, detail="文件夹不存在")
|
|
||||||
|
|
||||||
if not target_path.is_dir():
|
|
||||||
raise HTTPException(status_code=400, detail="路径不是文件夹")
|
|
||||||
|
|
||||||
# 计算文件夹大小,检查是否过大
|
|
||||||
total_size = 0
|
|
||||||
file_count = 0
|
|
||||||
for file_path in target_path.rglob('*'):
|
|
||||||
if file_path.is_file():
|
|
||||||
total_size += file_path.stat().st_size
|
|
||||||
file_count += 1
|
|
||||||
|
|
||||||
# 限制最大500MB
|
|
||||||
max_size = 500 * 1024 * 1024
|
|
||||||
if total_size > max_size:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=413,
|
|
||||||
detail=f"文件夹过大 ({formatFileSize(total_size)}),最大支持 {formatFileSize(max_size)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 限制文件数量
|
|
||||||
max_files = 10000
|
|
||||||
if file_count > max_files:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=413,
|
|
||||||
detail=f"文件数量过多 ({file_count}),最大支持 {max_files} 个文件"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建ZIP文件名
|
|
||||||
folder_name = target_path.name
|
|
||||||
zip_filename = f"{folder_name}.zip"
|
|
||||||
|
|
||||||
# 创建ZIP文件
|
|
||||||
zip_buffer = io.BytesIO()
|
|
||||||
|
|
||||||
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED, compresslevel=6) as zipf:
|
|
||||||
# 添加文件夹中的所有文件
|
|
||||||
for file_path in target_path.rglob('*'):
|
|
||||||
if file_path.is_file():
|
|
||||||
try:
|
|
||||||
# 计算相对路径,保持文件夹结构
|
|
||||||
arcname = file_path.relative_to(target_path)
|
|
||||||
zipf.write(file_path, arcname)
|
|
||||||
except (OSError, IOError) as e:
|
|
||||||
# 跳过无法读取的文件,但记录警告
|
|
||||||
print(f"Warning: Skipping file {file_path}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
zip_buffer.seek(0)
|
|
||||||
|
|
||||||
# 设置响应头
|
|
||||||
headers = {
|
|
||||||
'Content-Disposition': f'attachment; filename="{zip_filename}"',
|
|
||||||
'Content-Type': 'application/zip',
|
|
||||||
'Content-Length': str(len(zip_buffer.getvalue())),
|
|
||||||
'Cache-Control': 'no-cache'
|
|
||||||
}
|
|
||||||
|
|
||||||
# 创建流式响应
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
|
|
||||||
async def generate_zip():
|
|
||||||
zip_buffer.seek(0)
|
|
||||||
yield zip_buffer.getvalue()
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
generate_zip(),
|
|
||||||
media_type="application/zip",
|
|
||||||
headers=headers
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"压缩文件夹失败: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/download-multiple-zip")
|
|
||||||
async def download_multiple_items_as_zip(request: Dict[str, Any]):
|
|
||||||
"""
|
|
||||||
将多个文件和文件夹压缩为ZIP并下载
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: 包含paths和filename字段的JSON对象
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
paths = request.get("paths", [])
|
|
||||||
filename = request.get("filename", "batch_download.zip")
|
|
||||||
|
|
||||||
if not paths:
|
|
||||||
raise HTTPException(status_code=400, detail="请选择要下载的文件")
|
|
||||||
|
|
||||||
# 验证所有路径
|
|
||||||
valid_paths = []
|
|
||||||
total_size = 0
|
|
||||||
file_count = 0
|
|
||||||
|
|
||||||
for path in paths:
|
|
||||||
target_path = PROJECTS_DIR / path
|
|
||||||
|
|
||||||
if not target_path.exists():
|
|
||||||
continue # 跳过不存在的文件
|
|
||||||
|
|
||||||
if target_path.is_file():
|
|
||||||
total_size += target_path.stat().st_size
|
|
||||||
file_count += 1
|
|
||||||
valid_paths.append(path)
|
|
||||||
elif target_path.is_dir():
|
|
||||||
# 计算文件夹大小
|
|
||||||
for file_path in target_path.rglob('*'):
|
|
||||||
if file_path.is_file():
|
|
||||||
total_size += file_path.stat().st_size
|
|
||||||
file_count += 1
|
|
||||||
valid_paths.append(path)
|
|
||||||
|
|
||||||
if not valid_paths:
|
|
||||||
raise HTTPException(status_code=404, detail="没有找到有效的文件")
|
|
||||||
|
|
||||||
# 限制大小
|
|
||||||
max_size = 500 * 1024 * 1024
|
|
||||||
if total_size > max_size:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=413,
|
|
||||||
detail=f"选中文件过大 ({formatFileSize(total_size)}),最大支持 {formatFileSize(max_size)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 限制文件数量
|
|
||||||
max_files = 10000
|
|
||||||
if file_count > max_files:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=413,
|
|
||||||
detail=f"文件数量过多 ({file_count}),最大支持 {max_files} 个文件"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建ZIP文件
|
|
||||||
zip_buffer = io.BytesIO()
|
|
||||||
|
|
||||||
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED, compresslevel=6) as zipf:
|
|
||||||
for path in valid_paths:
|
|
||||||
target_path = PROJECTS_DIR / path
|
|
||||||
|
|
||||||
if target_path.is_file():
|
|
||||||
# 单个文件
|
|
||||||
try:
|
|
||||||
zipf.write(target_path, target_path.name)
|
|
||||||
except (OSError, IOError) as e:
|
|
||||||
print(f"Warning: Skipping file {target_path}: {e}")
|
|
||||||
continue
|
|
||||||
elif target_path.is_dir():
|
|
||||||
# 文件夹
|
|
||||||
for file_path in target_path.rglob('*'):
|
|
||||||
if file_path.is_file():
|
|
||||||
try:
|
|
||||||
# 保持相对路径结构
|
|
||||||
arcname = f"{target_path.name}/{file_path.relative_to(target_path)}"
|
|
||||||
zipf.write(file_path, arcname)
|
|
||||||
except (OSError, IOError) as e:
|
|
||||||
print(f"Warning: Skipping file {file_path}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
zip_buffer.seek(0)
|
|
||||||
|
|
||||||
# 设置响应头
|
|
||||||
headers = {
|
|
||||||
'Content-Disposition': f'attachment; filename="{filename}"',
|
|
||||||
'Content-Type': 'application/zip',
|
|
||||||
'Content-Length': str(len(zip_buffer.getvalue())),
|
|
||||||
'Cache-Control': 'no-cache'
|
|
||||||
}
|
|
||||||
|
|
||||||
# 创建流式响应
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
|
|
||||||
async def generate_zip():
|
|
||||||
zip_buffer.seek(0)
|
|
||||||
yield zip_buffer.getvalue()
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
generate_zip(),
|
|
||||||
media_type="application/zip",
|
|
||||||
headers=headers
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"批量压缩失败: {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
def formatFileSize(bytes_size: int) -> str:
|
|
||||||
"""格式化文件大小"""
|
|
||||||
if bytes_size == 0:
|
|
||||||
return "0 B"
|
|
||||||
|
|
||||||
k = 1024
|
|
||||||
sizes = ["B", "KB", "MB", "GB", "TB"]
|
|
||||||
i = int(math.floor(math.log(bytes_size, k)))
|
|
||||||
|
|
||||||
if i >= len(sizes):
|
|
||||||
i = len(sizes) - 1
|
|
||||||
|
|
||||||
return f"{bytes_size / math.pow(k, i):.1f} {sizes[i]}"
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/batch-operation")
|
|
||||||
async def batch_operation(operations: List[Dict[str, Any]]):
|
|
||||||
"""
|
|
||||||
批量操作多个文件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
operations: 操作列表,每个操作包含type和相应的参数
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for op in operations:
|
|
||||||
op_type = op.get("type")
|
|
||||||
op_result = {"type": op_type, "success": False}
|
|
||||||
|
|
||||||
try:
|
|
||||||
if op_type == "delete":
|
|
||||||
await delete_item(op["path"])
|
|
||||||
op_result["success"] = True
|
|
||||||
op_result["message"] = "删除成功"
|
|
||||||
elif op_type == "move":
|
|
||||||
await move_item(op["source_path"], op["target_path"])
|
|
||||||
op_result["success"] = True
|
|
||||||
op_result["message"] = "移动成功"
|
|
||||||
elif op_type == "copy":
|
|
||||||
await copy_item(op["source_path"], op["target_path"])
|
|
||||||
op_result["success"] = True
|
|
||||||
op_result["message"] = "复制成功"
|
|
||||||
else:
|
|
||||||
op_result["error"] = f"不支持的操作类型: {op_type}"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
op_result["error"] = str(e)
|
|
||||||
|
|
||||||
results.append(op_result)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"results": results,
|
|
||||||
"total": len(operations),
|
|
||||||
"successful": sum(1 for r in results if r["success"])
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"批量操作失败: {str(e)}")
|
|
||||||
@ -7,9 +7,9 @@ from fastapi.responses import StreamingResponse
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from utils import (
|
from utils import (
|
||||||
Message, ChatRequest, ChatResponse,
|
Message, ChatRequest, ChatResponse
|
||||||
get_global_agent_manager, init_global_sharded_agent_manager
|
|
||||||
)
|
)
|
||||||
|
from agent.sharded_agent_manager import init_global_sharded_agent_manager
|
||||||
from utils.api_models import ChatRequestV2
|
from utils.api_models import ChatRequestV2
|
||||||
from utils.fastapi_utils import (
|
from utils.fastapi_utils import (
|
||||||
process_messages, extract_guidelines_from_system_prompt, format_messages_to_chat_history,
|
process_messages, extract_guidelines_from_system_prompt, format_messages_to_chat_history,
|
||||||
@ -205,6 +205,7 @@ async def create_agent_and_generate_response(
|
|||||||
# 处理结果:最后一个结果是agent,前面的是guideline批次结果
|
# 处理结果:最后一个结果是agent,前面的是guideline批次结果
|
||||||
agent = all_results[-1] # agent创建的结果
|
agent = all_results[-1] # agent创建的结果
|
||||||
batch_results = all_results[:-1] # guideline批次的结果
|
batch_results = all_results[:-1] # guideline批次的结果
|
||||||
|
print(f"batch_results:{batch_results}")
|
||||||
|
|
||||||
# 合并guideline分析结果,使用JSON格式的checks数组
|
# 合并guideline分析结果,使用JSON格式的checks数组
|
||||||
all_checks = []
|
all_checks = []
|
||||||
|
|||||||
@ -6,11 +6,11 @@ from fastapi import APIRouter, HTTPException
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from utils import (
|
from utils import (
|
||||||
get_global_agent_manager, init_global_sharded_agent_manager,
|
|
||||||
get_global_connection_pool, init_global_connection_pool,
|
get_global_connection_pool, init_global_connection_pool,
|
||||||
get_global_file_cache, init_global_file_cache,
|
get_global_file_cache, init_global_file_cache,
|
||||||
setup_system_optimizations
|
setup_system_optimizations
|
||||||
)
|
)
|
||||||
|
from agent.sharded_agent_manager import init_global_sharded_agent_manager
|
||||||
try:
|
try:
|
||||||
from utils.system_optimizer import apply_optimization_profile
|
from utils.system_optimizer import apply_optimization_profile
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|||||||
@ -32,17 +32,19 @@ from .project_manager import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Import agent management modules
|
# Import agent management modules
|
||||||
from .file_loaded_agent_manager import (
|
# Note: These have been moved to agent package
|
||||||
get_global_agent_manager,
|
# from .file_loaded_agent_manager import (
|
||||||
init_global_agent_manager
|
# get_global_agent_manager,
|
||||||
)
|
# init_global_agent_manager
|
||||||
|
# )
|
||||||
|
|
||||||
# Import optimized modules
|
# Import optimized modules
|
||||||
from .sharded_agent_manager import (
|
# Note: These have been moved to agent package
|
||||||
ShardedAgentManager,
|
# from .sharded_agent_manager import (
|
||||||
get_global_sharded_agent_manager,
|
# ShardedAgentManager,
|
||||||
init_global_sharded_agent_manager
|
# get_global_sharded_agent_manager,
|
||||||
)
|
# init_global_sharded_agent_manager
|
||||||
|
# )
|
||||||
|
|
||||||
from .connection_pool import (
|
from .connection_pool import (
|
||||||
HTTPConnectionPool,
|
HTTPConnectionPool,
|
||||||
@ -78,10 +80,11 @@ from .system_optimizer import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Import config cache module
|
# Import config cache module
|
||||||
from .config_cache import (
|
# Note: This has been moved to agent package
|
||||||
config_cache,
|
# from .config_cache import (
|
||||||
ConfigFileCache
|
# config_cache,
|
||||||
)
|
# ConfigFileCache
|
||||||
|
# )
|
||||||
|
|
||||||
from .agent_pool import (
|
from .agent_pool import (
|
||||||
AgentPool,
|
AgentPool,
|
||||||
@ -123,9 +126,10 @@ from .api_models import (
|
|||||||
create_chat_response
|
create_chat_response
|
||||||
)
|
)
|
||||||
|
|
||||||
from .prompt_loader import (
|
# Note: This has been moved to agent package
|
||||||
load_system_prompt,
|
# from .prompt_loader import (
|
||||||
)
|
# load_system_prompt,
|
||||||
|
# )
|
||||||
|
|
||||||
from .multi_project_manager import (
|
from .multi_project_manager import (
|
||||||
create_robot_project,
|
create_robot_project,
|
||||||
@ -162,9 +166,9 @@ __all__ = [
|
|||||||
'list_projects',
|
'list_projects',
|
||||||
'get_project_stats',
|
'get_project_stats',
|
||||||
|
|
||||||
# file_loaded_agent_manager
|
# file_loaded_agent_manager (moved to agent package)
|
||||||
'get_global_agent_manager',
|
# 'get_global_agent_manager',
|
||||||
'init_global_agent_manager',
|
# 'init_global_agent_manager',
|
||||||
|
|
||||||
# agent_pool
|
# agent_pool
|
||||||
'AgentPool',
|
'AgentPool',
|
||||||
@ -202,8 +206,8 @@ __all__ = [
|
|||||||
'create_error_response',
|
'create_error_response',
|
||||||
'create_chat_response',
|
'create_chat_response',
|
||||||
|
|
||||||
# prompt_loader
|
# prompt_loader (moved to agent package)
|
||||||
'load_system_prompt',
|
# 'load_system_prompt',
|
||||||
|
|
||||||
# multi_project_manager
|
# multi_project_manager
|
||||||
'create_robot_project',
|
'create_robot_project',
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user