新增agent文件夹,修改import引用,增加custom_mcp_manager
This commit is contained in:
parent
7002019229
commit
58ac6e3024
0
agent/__init__.py
Normal file
0
agent/__init__.py
Normal file
474
agent/custom_mcp_manager.py
Normal file
474
agent/custom_mcp_manager.py
Normal file
@ -0,0 +1,474 @@
|
||||
# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import datetime
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from qwen_agent.log import logger
|
||||
from qwen_agent.tools.base import BaseTool
|
||||
|
||||
|
||||
class CustomMCPManager:
|
||||
_instance = None # Private class variable to store the unique instance
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(CustomMCPManager, cls).__new__(cls, *args, **kwargs)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, 'clients'): # The singleton should only be inited once
|
||||
"""Set a new event loop in a separate thread"""
|
||||
try:
|
||||
import mcp # noqa
|
||||
except ImportError as e:
|
||||
raise ImportError('Could not import mcp. Please install mcp with `pip install -U mcp`.') from e
|
||||
|
||||
load_dotenv() # Load environment variables from .env file
|
||||
self.clients: dict = {}
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.loop_thread = threading.Thread(target=self.start_loop, daemon=True)
|
||||
self.loop_thread.start()
|
||||
|
||||
# A fallback way to terminate MCP tool processes after Qwen-Agent exits
|
||||
self.processes = []
|
||||
self.monkey_patch_mcp_create_platform_compatible_process()
|
||||
|
||||
def monkey_patch_mcp_create_platform_compatible_process(self):
|
||||
try:
|
||||
import mcp.client.stdio
|
||||
target = mcp.client.stdio._create_platform_compatible_process
|
||||
except (ModuleNotFoundError, AttributeError) as e:
|
||||
raise ImportError('Qwen-Agent needs to monkey patch MCP for process cleanup. '
|
||||
'Please upgrade MCP to a higher version with `pip install -U mcp`.') from e
|
||||
|
||||
async def _monkey_patched_create_platform_compatible_process(*args, **kwargs):
|
||||
process = await target(*args, **kwargs)
|
||||
self.processes.append(process)
|
||||
return process
|
||||
|
||||
mcp.client.stdio._create_platform_compatible_process = _monkey_patched_create_platform_compatible_process
|
||||
|
||||
def start_loop(self):
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
# Set a global exception handler to silently handle cross-task exceptions from MCP SSE connections
|
||||
def exception_handler(loop, context):
|
||||
exception = context.get('exception')
|
||||
if exception:
|
||||
# Silently handle cross-task exceptions from MCP SSE connections
|
||||
if (isinstance(exception, RuntimeError) and
|
||||
'Attempted to exit cancel scope in a different task' in str(exception)):
|
||||
return # Silently ignore this type of exception
|
||||
if (isinstance(exception, BaseExceptionGroup) and # noqa
|
||||
'Attempted to exit cancel scope in a different task' in str(exception)): # noqa
|
||||
return # Silently ignore this type of exception
|
||||
|
||||
# Other exceptions are handled normally
|
||||
loop.default_exception_handler(context)
|
||||
|
||||
self.loop.set_exception_handler(exception_handler)
|
||||
self.loop.run_forever()
|
||||
|
||||
def is_valid_mcp_servers(self, config: dict):
|
||||
"""Example of mcp servers configuration:
|
||||
{
|
||||
"mcpServers": {
|
||||
"memory": {
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-memory"]
|
||||
},
|
||||
"filesystem": {
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/allowed/files"]
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# Check if the top-level key "mcpServers" exists and its value is a dictionary
|
||||
if not isinstance(config, dict) or 'mcpServers' not in config or not isinstance(config['mcpServers'], dict):
|
||||
return False
|
||||
mcp_servers = config['mcpServers']
|
||||
# Check each sub-item under "mcpServers"
|
||||
for key in mcp_servers:
|
||||
server = mcp_servers[key]
|
||||
# Each sub-item must be a dictionary
|
||||
if not isinstance(server, dict):
|
||||
return False
|
||||
if 'command' in server:
|
||||
# "command" must be a string
|
||||
if not isinstance(server['command'], str):
|
||||
return False
|
||||
# "args" must be a list
|
||||
if 'args' not in server or not isinstance(server['args'], list):
|
||||
return False
|
||||
if 'url' in server:
|
||||
# "url" must be a string
|
||||
if not isinstance(server['url'], str):
|
||||
return False
|
||||
# "headers" must be a dictionary
|
||||
if 'headers' in server and not isinstance(server['headers'], dict):
|
||||
return False
|
||||
# If the "env" key exists, it must be a dictionary
|
||||
if 'env' in server and not isinstance(server['env'], dict):
|
||||
return False
|
||||
return True
|
||||
|
||||
def initConfig(self, config: Dict):
|
||||
if not self.is_valid_mcp_servers(config):
|
||||
raise ValueError('Config of mcpservers is not valid')
|
||||
logger.info(f'Initializing MCP tools from mcp servers: {list(config["mcpServers"].keys())}')
|
||||
# Submit coroutine to the event loop and wait for the result
|
||||
future = asyncio.run_coroutine_threadsafe(self.init_config_async(config), self.loop)
|
||||
try:
|
||||
result = future.result() # You can specify a timeout if desired
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.info(f'Failed in initializing MCP tools: {e}')
|
||||
raise e
|
||||
|
||||
async def init_config_async(self, config: Dict):
|
||||
tools: list = []
|
||||
mcp_servers = config['mcpServers']
|
||||
for server_name in mcp_servers:
|
||||
client = CustomMCPClient()
|
||||
server = mcp_servers[server_name]
|
||||
await client.connection_server(mcp_server_name=server_name,
|
||||
mcp_server=server) # Attempt to connect to the server
|
||||
|
||||
client_id = server_name + '_' + str(
|
||||
uuid.uuid4()) # To allow the same server name be used across different running agents
|
||||
client.client_id = client_id # Ensure client_id is set on the client instance
|
||||
self.clients[client_id] = client # Add to clients dict after successful connection
|
||||
for tool in client.tools:
|
||||
"""MCP tool example:
|
||||
{
|
||||
"name": "read_query",
|
||||
"description": "Execute a SELECT query on the SQLite database",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "SELECT SQL query to execute"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
"""
|
||||
parameters = tool.inputSchema
|
||||
# The required field in inputSchema may be empty and needs to be initialized.
|
||||
if 'required' not in parameters:
|
||||
parameters['required'] = []
|
||||
# Remove keys from parameters that do not conform to the standard OpenAI schema
|
||||
# Check if the required fields exist
|
||||
required_fields = {'type', 'properties', 'required'}
|
||||
missing_fields = required_fields - parameters.keys()
|
||||
if missing_fields:
|
||||
raise ValueError(f'Missing required fields in schema: {missing_fields}')
|
||||
|
||||
# Keep only the necessary fields
|
||||
cleaned_parameters = {
|
||||
'type': parameters['type'],
|
||||
'properties': parameters['properties'],
|
||||
'required': parameters['required']
|
||||
}
|
||||
register_name = server_name + '-' + tool.name
|
||||
agent_tool = self.create_tool_class(register_name=register_name,
|
||||
register_client_id=client_id,
|
||||
tool_name=tool.name,
|
||||
tool_desc=tool.description,
|
||||
tool_parameters=cleaned_parameters)
|
||||
tools.append(agent_tool)
|
||||
|
||||
if client.resources:
|
||||
"""MCP resource example:
|
||||
{
|
||||
uri: string; // Unique identifier for the resource
|
||||
name: string; // Human-readable name
|
||||
description?: string; // Optional description
|
||||
mimeType?: string; // Optional MIME type
|
||||
}
|
||||
"""
|
||||
# List resources
|
||||
list_resources_tool_name = server_name + '-' + 'list_resources'
|
||||
list_resources_params = {'type': 'object', 'properties': {}, 'required': []}
|
||||
list_resources_agent_tool = self.create_tool_class(
|
||||
register_name=list_resources_tool_name,
|
||||
register_client_id=client_id,
|
||||
tool_name='list_resources',
|
||||
tool_desc='Servers expose a list of concrete resources through this tool. '
|
||||
'By invoking it, you can discover the available resources and obtain resource templates, which help clients understand how to construct valid URIs. '
|
||||
'These URI formats will be used as input parameters for the read_resource function. ',
|
||||
tool_parameters=list_resources_params)
|
||||
tools.append(list_resources_agent_tool)
|
||||
|
||||
# Read resource
|
||||
resources_template_str = '' # Check if there are resource templates
|
||||
try:
|
||||
list_resource_templates = await client.session.list_resource_templates(
|
||||
) # Check if the server has resources tesmplate
|
||||
if list_resource_templates.resourceTemplates:
|
||||
resources_template_str = '\n'.join(
|
||||
str(template) for template in list_resource_templates.resourceTemplates)
|
||||
|
||||
except Exception as e:
|
||||
logger.info(f'Failed in listing MCP resource templates: {e}')
|
||||
|
||||
read_resource_tool_name = server_name + '-' + 'read_resource'
|
||||
read_resource_params = {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'uri': {
|
||||
'type': 'string',
|
||||
'description': 'The URI identifying the specific resource to access'
|
||||
}
|
||||
},
|
||||
'required': ['uri']
|
||||
}
|
||||
original_tool_desc = 'Request to access a resource provided by a connected MCP server. Resources represent data sources that can be used as context, such as files, API responses, or system information.'
|
||||
if resources_template_str:
|
||||
tool_desc = original_tool_desc + '\nResource Templates:\n' + resources_template_str
|
||||
else:
|
||||
tool_desc = original_tool_desc
|
||||
read_resource_agent_tool = self.create_tool_class(register_name=read_resource_tool_name,
|
||||
register_client_id=client_id,
|
||||
tool_name='read_resource',
|
||||
tool_desc=tool_desc,
|
||||
tool_parameters=read_resource_params)
|
||||
tools.append(read_resource_agent_tool)
|
||||
|
||||
return tools
|
||||
|
||||
def create_tool_class(self, register_name, register_client_id, tool_name, tool_desc, tool_parameters):
|
||||
|
||||
class ToolClass(BaseTool):
|
||||
name = register_name
|
||||
description = tool_desc
|
||||
parameters = tool_parameters
|
||||
client_id = register_client_id
|
||||
|
||||
def call(self, params: Union[str, dict], **kwargs) -> str:
|
||||
tool_args = json.loads(params)
|
||||
# Submit coroutine to the event loop and wait for the result
|
||||
manager = CustomMCPManager()
|
||||
client = manager.clients[self.client_id]
|
||||
future = asyncio.run_coroutine_threadsafe(client.execute_function(tool_name, tool_args), manager.loop)
|
||||
try:
|
||||
result = future.result()
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.info(f'Failed in executing MCP tool: {e}')
|
||||
raise e
|
||||
|
||||
ToolClass.__name__ = f'{register_name}_Class'
|
||||
return ToolClass()
|
||||
|
||||
def shutdown(self):
|
||||
futures = []
|
||||
for client_id in list(self.clients.keys()):
|
||||
client: CustomMCPClient = self.clients[client_id]
|
||||
future = asyncio.run_coroutine_threadsafe(client.cleanup(), self.loop)
|
||||
futures.append(future)
|
||||
del self.clients[client_id]
|
||||
time.sleep(1) # Wait for the graceful cleanups, otherwise fall back
|
||||
|
||||
# fallback
|
||||
if asyncio.all_tasks(self.loop):
|
||||
logger.info(
|
||||
'There are still tasks in `CustomMCPManager().loop`, force terminating the MCP tool processes. There may be some exceptions.'
|
||||
)
|
||||
for process in self.processes:
|
||||
try:
|
||||
process.terminate()
|
||||
except ProcessLookupError:
|
||||
pass # it's ok, the process may exit earlier
|
||||
|
||||
self.loop.call_soon_threadsafe(self.loop.stop)
|
||||
self.loop_thread.join()
|
||||
|
||||
|
||||
class CustomMCPClient:
|
||||
|
||||
def __init__(self):
|
||||
from mcp import ClientSession
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.tools: list = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self.resources: bool = False
|
||||
self._last_mcp_server_name = None
|
||||
self._last_mcp_server = None
|
||||
self.client_id = None # For replacing in MCPManager.clients
|
||||
|
||||
async def connection_server(self, mcp_server_name, mcp_server):
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
"""Connect to an MCP server and retrieve the available tools."""
|
||||
# Save parameters
|
||||
self._last_mcp_server_name = mcp_server_name
|
||||
self._last_mcp_server = mcp_server
|
||||
|
||||
try:
|
||||
if 'url' in mcp_server:
|
||||
url = mcp_server.get('url')
|
||||
sse_read_timeout = mcp_server.get('sse_read_timeout', 300)
|
||||
logger.info(f'{mcp_server_name} sse_read_timeout: {sse_read_timeout}s')
|
||||
if mcp_server.get('type', 'sse') == 'streamable-http':
|
||||
# streamable-http mode
|
||||
"""streamable-http mode mcp example:
|
||||
{"mcpServers": {
|
||||
"streamable-mcp-server": {
|
||||
"type": "streamable-http",
|
||||
"url":"http://0.0.0.0:8000/mcp"
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
headers = mcp_server.get('headers', {})
|
||||
self._streams_context = streamablehttp_client(
|
||||
url=url, headers=headers, sse_read_timeout=datetime.timedelta(seconds=sse_read_timeout))
|
||||
read_stream, write_stream, get_session_id = await self.exit_stack.enter_async_context(
|
||||
self._streams_context)
|
||||
self._session_context = ClientSession(read_stream, write_stream)
|
||||
self.session = await self.exit_stack.enter_async_context(self._session_context)
|
||||
else:
|
||||
# sse mode
|
||||
headers = mcp_server.get('headers', {'Accept': 'text/event-stream'})
|
||||
self._streams_context = sse_client(url, headers, sse_read_timeout=sse_read_timeout)
|
||||
streams = await self.exit_stack.enter_async_context(self._streams_context)
|
||||
self._session_context = ClientSession(*streams)
|
||||
self.session = await self.exit_stack.enter_async_context(self._session_context)
|
||||
else:
|
||||
server_params = StdioServerParameters(command=mcp_server['command'],
|
||||
args=mcp_server['args'],
|
||||
env=mcp_server.get('env', None))
|
||||
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
|
||||
self.stdio, self.write = stdio_transport
|
||||
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
|
||||
logger.info(
|
||||
f'Initializing a MCP stdio_client, if this takes forever, please check the config of this mcp server: {mcp_server_name}'
|
||||
)
|
||||
|
||||
await self.session.initialize()
|
||||
list_tools = await self.session.list_tools()
|
||||
self.tools = list_tools.tools
|
||||
try:
|
||||
list_resources = await self.session.list_resources() # Check if the server has resources
|
||||
if list_resources.resources:
|
||||
self.resources = True
|
||||
except Exception:
|
||||
# logger.info(f"No list resources: {e}")
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f'Failed in connecting to MCP server: {e}')
|
||||
raise e
|
||||
|
||||
async def reconnect(self):
|
||||
# Create a new MCPClient and connect
|
||||
if self.client_id is None:
|
||||
raise RuntimeError(
|
||||
'Cannot reconnect: client_id is None. This usually means the client was not properly registered in MCPManager.'
|
||||
)
|
||||
new_client = CustomMCPClient()
|
||||
new_client.client_id = self.client_id
|
||||
await new_client.connection_server(self._last_mcp_server_name, self._last_mcp_server)
|
||||
return new_client
|
||||
|
||||
async def execute_function(self, tool_name, tool_args: dict):
|
||||
from mcp.types import TextResourceContents
|
||||
|
||||
# Check if session is alive
|
||||
try:
|
||||
await self.session.send_ping()
|
||||
except Exception as e:
|
||||
logger.info(f"Session is not alive, please increase 'sse_read_timeout' in the config, try reconnect: {e}")
|
||||
# Auto reconnect
|
||||
try:
|
||||
manager = CustomMCPManager()
|
||||
if self.client_id is not None:
|
||||
manager.clients[self.client_id] = await self.reconnect()
|
||||
return await manager.clients[self.client_id].execute_function(tool_name, tool_args)
|
||||
else:
|
||||
logger.info('Reconnect failed: client_id is None')
|
||||
return 'Session reconnect (client creation) exception: client_id is None'
|
||||
except Exception as e3:
|
||||
logger.info(f'Reconnect (client creation) exception type: {type(e3)}, value: {repr(e3)}')
|
||||
return f'Session reconnect (client creation) exception: {e3}'
|
||||
if tool_name == 'list_resources':
|
||||
try:
|
||||
list_resources = await self.session.list_resources()
|
||||
if list_resources.resources:
|
||||
resources_str = '\n\n'.join(str(resource) for resource in list_resources.resources)
|
||||
else:
|
||||
resources_str = 'No resources found'
|
||||
return resources_str
|
||||
except Exception as e:
|
||||
logger.info(f'No list resources: {e}')
|
||||
return f'Error: {e}'
|
||||
elif tool_name == 'read_resource':
|
||||
try:
|
||||
uri = tool_args.get('uri')
|
||||
if not uri:
|
||||
raise ValueError('URI is required for read_resource')
|
||||
read_resource = await self.session.read_resource(uri)
|
||||
texts = []
|
||||
for resource in read_resource.contents:
|
||||
if isinstance(resource, TextResourceContents):
|
||||
texts.append(resource.text)
|
||||
# if isinstance(resource, BlobResourceContents):
|
||||
# texts.append(resource.blob)
|
||||
if texts:
|
||||
return '\n\n'.join(texts)
|
||||
else:
|
||||
return 'Failed to read resource'
|
||||
except Exception as e:
|
||||
logger.info(f'Failed to read resource: {e}')
|
||||
return f'Error: {e}'
|
||||
else:
|
||||
response = await self.session.call_tool(tool_name, tool_args)
|
||||
texts = []
|
||||
for content in response.content:
|
||||
if content.type == 'text':
|
||||
texts.append(content.text)
|
||||
if texts:
|
||||
return '\n\n'.join(texts)
|
||||
else:
|
||||
return 'execute error'
|
||||
|
||||
async def cleanup(self):
|
||||
await self.exit_stack.aclose()
|
||||
|
||||
|
||||
def _cleanup_mcp(_sig_num=None, _frame=None):
|
||||
if CustomMCPManager._instance is None:
|
||||
return
|
||||
manager = CustomMCPManager()
|
||||
manager.shutdown()
|
||||
|
||||
|
||||
# Make sure all subprocesses are terminated even if killed abnormally
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
atexit.register(_cleanup_mcp)
|
||||
@ -23,8 +23,8 @@ from typing import Dict, List, Optional
|
||||
from qwen_agent.agents import Assistant
|
||||
from qwen_agent.log import logger
|
||||
|
||||
from modified_assistant import init_modified_agent_service_with_files, update_agent_llm
|
||||
from .prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
||||
from agent.modified_assistant import init_modified_agent_service_with_files, update_agent_llm
|
||||
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
||||
|
||||
|
||||
class FileLoadedAgentManager:
|
||||
@ -14,6 +14,7 @@
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Iterator, List, Literal, Optional, Union
|
||||
@ -21,7 +22,11 @@ from typing import Dict, Iterator, List, Literal, Optional, Union
|
||||
from qwen_agent.agents import Assistant
|
||||
from qwen_agent.llm.schema import ASSISTANT, FUNCTION, Message
|
||||
from qwen_agent.llm.oai import TextChatAtOAI
|
||||
from utils.logger import tool_logger
|
||||
from qwen_agent.tools import BaseTool
|
||||
from agent.custom_mcp_manager import CustomMCPManager
|
||||
|
||||
# 设置工具日志记录器
|
||||
tool_logger = logging.getLogger('tool_logger')
|
||||
|
||||
class ModifiedAssistant(Assistant):
|
||||
"""
|
||||
@ -94,6 +99,26 @@ class ModifiedAssistant(Assistant):
|
||||
error_message = f'An error occurred when calling tool {tool_name}: {type(ex).__name__}: {str(ex)}'
|
||||
return error_message
|
||||
|
||||
def _init_tool(self, tool: Union[str, Dict, BaseTool]):
|
||||
"""重写工具初始化方法,使用CustomMCPManager处理MCP服务器配置"""
|
||||
if isinstance(tool, BaseTool):
|
||||
# 处理BaseTool实例
|
||||
tool_name = tool.name
|
||||
if tool_name in self.function_map:
|
||||
tool_logger.warning(f'Repeatedly adding tool {tool_name}, will use the newest tool in function list')
|
||||
self.function_map[tool_name] = tool
|
||||
elif isinstance(tool, dict) and 'mcpServers' in tool:
|
||||
# 使用CustomMCPManager处理MCP服务器配置,支持headers
|
||||
tools = CustomMCPManager().initConfig(tool)
|
||||
for tool in tools:
|
||||
tool_name = tool.name
|
||||
if tool_name in self.function_map:
|
||||
tool_logger.warning(f'Repeatedly adding tool {tool_name}, will use the newest tool in function list')
|
||||
self.function_map[tool_name] = tool
|
||||
else:
|
||||
# 调用父类的处理方法
|
||||
super()._init_tool(tool)
|
||||
|
||||
def _call_llm_with_retry(self, messages: List[Message], functions=None, extra_generate_cfg=None, max_retries: int = 5) -> Iterator:
|
||||
"""带重试机制的LLM调用
|
||||
|
||||
@ -109,7 +109,7 @@ async def load_system_prompt_async(project_dir: str, language: str = None, syste
|
||||
Returns:
|
||||
str: 加载到的系统提示词内容
|
||||
"""
|
||||
from .config_cache import config_cache
|
||||
from agent.config_cache import config_cache
|
||||
|
||||
# 获取语言显示名称
|
||||
language_display_map = {
|
||||
@ -235,7 +235,7 @@ async def load_mcp_settings_async(project_dir: str, mcp_settings: list=None, bot
|
||||
支持在 mcp_settings.json 的 args 中使用 {dataset_dir} 占位符,
|
||||
会在 init_modified_agent_service_with_files 中被替换为实际的路径。
|
||||
"""
|
||||
from .config_cache import config_cache
|
||||
from agent.config_cache import config_cache
|
||||
|
||||
# 1. 首先读取默认MCP设置
|
||||
default_mcp_settings = []
|
||||
@ -26,8 +26,8 @@ from collections import defaultdict
|
||||
from qwen_agent.agents import Assistant
|
||||
from qwen_agent.log import logger
|
||||
|
||||
from modified_assistant import init_modified_agent_service_with_files, update_agent_llm
|
||||
from .prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
||||
from agent.modified_assistant import init_modified_agent_service_with_files, update_agent_llm
|
||||
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
||||
|
||||
|
||||
class ShardedAgentManager:
|
||||
@ -1,691 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
文件管理API - WebDAV的HTTP API替代方案
|
||||
提供RESTful接口来管理projects文件夹
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
import mimetypes
|
||||
import json
|
||||
import zipfile
|
||||
import tempfile
|
||||
import io
|
||||
import math
|
||||
|
||||
router = APIRouter(prefix="/api/v1/files", tags=["file_management"])
|
||||
|
||||
PROJECTS_DIR = Path("projects")
|
||||
PROJECTS_DIR.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
async def list_files(path: str = "", recursive: bool = False):
|
||||
"""
|
||||
列出目录内容
|
||||
|
||||
Args:
|
||||
path: 相对路径,空字符串表示根目录
|
||||
recursive: 是否递归列出所有子目录
|
||||
"""
|
||||
try:
|
||||
target_path = PROJECTS_DIR / path
|
||||
|
||||
if not target_path.exists():
|
||||
raise HTTPException(status_code=404, detail="路径不存在")
|
||||
|
||||
if not target_path.is_dir():
|
||||
raise HTTPException(status_code=400, detail="路径不是目录")
|
||||
|
||||
def scan_directory(directory: Path, base_path: Path = PROJECTS_DIR) -> List[Dict[str, Any]]:
|
||||
items = []
|
||||
try:
|
||||
for item in directory.iterdir():
|
||||
# 跳过隐藏文件
|
||||
if item.name.startswith('.'):
|
||||
continue
|
||||
|
||||
relative_path = item.relative_to(base_path)
|
||||
stat = item.stat()
|
||||
|
||||
item_info = {
|
||||
"name": item.name,
|
||||
"path": str(relative_path),
|
||||
"type": "directory" if item.is_dir() else "file",
|
||||
"size": stat.st_size if item.is_file() else 0,
|
||||
"modified": stat.st_mtime,
|
||||
"created": stat.st_ctime
|
||||
}
|
||||
|
||||
items.append(item_info)
|
||||
|
||||
# 递归扫描子目录
|
||||
if recursive and item.is_dir():
|
||||
items.extend(scan_directory(item, base_path))
|
||||
|
||||
except PermissionError:
|
||||
pass
|
||||
return items
|
||||
|
||||
items = scan_directory(target_path)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"path": path,
|
||||
"items": items,
|
||||
"total": len(items)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"列出目录失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...), path: str = Form("")):
|
||||
"""
|
||||
上传文件
|
||||
|
||||
Args:
|
||||
file: 上传的文件
|
||||
path: 目标路径(相对于projects目录)
|
||||
"""
|
||||
try:
|
||||
target_dir = PROJECTS_DIR / path
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_path = target_dir / file.filename
|
||||
|
||||
# 如果文件已存在,检查是否覆盖
|
||||
if file_path.exists():
|
||||
# 可以添加版本控制或重命名逻辑
|
||||
pass
|
||||
|
||||
with open(file_path, "wb") as buffer:
|
||||
content = await file.read()
|
||||
buffer.write(content)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "文件上传成功",
|
||||
"filename": file.filename,
|
||||
"path": str(Path(path) / file.filename),
|
||||
"size": len(content)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"文件上传失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/download/{file_path:path}")
|
||||
async def download_file(file_path: str):
|
||||
"""
|
||||
下载文件
|
||||
|
||||
Args:
|
||||
file_path: 文件相对路径
|
||||
"""
|
||||
try:
|
||||
target_path = PROJECTS_DIR / file_path
|
||||
|
||||
if not target_path.exists():
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
|
||||
if not target_path.is_file():
|
||||
raise HTTPException(status_code=400, detail="不是文件")
|
||||
|
||||
# 猜测MIME类型
|
||||
mime_type, _ = mimetypes.guess_type(str(target_path))
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
return FileResponse(
|
||||
path=str(target_path),
|
||||
filename=target_path.name,
|
||||
media_type=mime_type
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"文件下载失败: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/delete")
|
||||
async def delete_item(path: str):
|
||||
"""
|
||||
删除文件或目录
|
||||
|
||||
Args:
|
||||
path: 要删除的路径
|
||||
"""
|
||||
try:
|
||||
target_path = PROJECTS_DIR / path
|
||||
|
||||
if not target_path.exists():
|
||||
raise HTTPException(status_code=404, detail="路径不存在")
|
||||
|
||||
if target_path.is_file():
|
||||
target_path.unlink()
|
||||
elif target_path.is_dir():
|
||||
shutil.rmtree(target_path)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"{'文件' if target_path.is_file() else '目录'}删除成功",
|
||||
"path": path
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/create-folder")
|
||||
async def create_folder(path: str, name: str):
|
||||
"""
|
||||
创建文件夹
|
||||
|
||||
Args:
|
||||
path: 父目录路径
|
||||
name: 新文件夹名称
|
||||
"""
|
||||
try:
|
||||
parent_path = PROJECTS_DIR / path
|
||||
parent_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
new_folder = parent_path / name
|
||||
|
||||
if new_folder.exists():
|
||||
raise HTTPException(status_code=400, detail="文件夹已存在")
|
||||
|
||||
new_folder.mkdir()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "文件夹创建成功",
|
||||
"path": str(Path(path) / name)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"创建文件夹失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/rename")
|
||||
async def rename_item(old_path: str, new_name: str):
|
||||
"""
|
||||
重命名文件或文件夹
|
||||
|
||||
Args:
|
||||
old_path: 原路径
|
||||
new_name: 新名称
|
||||
"""
|
||||
try:
|
||||
old_full_path = PROJECTS_DIR / old_path
|
||||
|
||||
if not old_full_path.exists():
|
||||
raise HTTPException(status_code=404, detail="文件或目录不存在")
|
||||
|
||||
new_full_path = old_full_path.parent / new_name
|
||||
|
||||
if new_full_path.exists():
|
||||
raise HTTPException(status_code=400, detail="目标名称已存在")
|
||||
|
||||
old_full_path.rename(new_full_path)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "重命名成功",
|
||||
"old_path": old_path,
|
||||
"new_path": str(new_full_path.relative_to(PROJECTS_DIR))
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"重命名失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/move")
|
||||
async def move_item(source_path: str, target_path: str):
|
||||
"""
|
||||
移动文件或文件夹
|
||||
|
||||
Args:
|
||||
source_path: 源路径
|
||||
target_path: 目标路径
|
||||
"""
|
||||
try:
|
||||
source_full_path = PROJECTS_DIR / source_path
|
||||
target_full_path = PROJECTS_DIR / target_path
|
||||
|
||||
if not source_full_path.exists():
|
||||
raise HTTPException(status_code=404, detail="源文件或目录不存在")
|
||||
|
||||
# 确保目标目录存在
|
||||
target_full_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
shutil.move(str(source_full_path), str(target_full_path))
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "移动成功",
|
||||
"source_path": source_path,
|
||||
"target_path": target_path
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"移动失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/copy")
|
||||
async def copy_item(source_path: str, target_path: str):
|
||||
"""
|
||||
复制文件或文件夹
|
||||
|
||||
Args:
|
||||
source_path: 源路径
|
||||
target_path: 目标路径
|
||||
"""
|
||||
try:
|
||||
source_full_path = PROJECTS_DIR / source_path
|
||||
target_full_path = PROJECTS_DIR / target_path
|
||||
|
||||
if not source_full_path.exists():
|
||||
raise HTTPException(status_code=404, detail="源文件或目录不存在")
|
||||
|
||||
# 确保目标目录存在
|
||||
target_full_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if source_full_path.is_file():
|
||||
shutil.copy2(str(source_full_path), str(target_full_path))
|
||||
else:
|
||||
shutil.copytree(str(source_full_path), str(target_full_path))
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "复制成功",
|
||||
"source_path": source_path,
|
||||
"target_path": target_path
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"复制失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/search")
|
||||
async def search_files(query: str, path: str = "", file_type: Optional[str] = None):
|
||||
"""
|
||||
搜索文件
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
path: 搜索路径
|
||||
file_type: 文件类型过滤
|
||||
"""
|
||||
try:
|
||||
search_path = PROJECTS_DIR / path
|
||||
|
||||
if not search_path.exists():
|
||||
raise HTTPException(status_code=404, detail="搜索路径不存在")
|
||||
|
||||
results = []
|
||||
|
||||
def scan_for_files(directory: Path):
|
||||
try:
|
||||
for item in directory.iterdir():
|
||||
if item.name.startswith('.'):
|
||||
continue
|
||||
|
||||
# 检查文件名是否包含关键词
|
||||
if query.lower() in item.name.lower():
|
||||
# 检查文件类型过滤
|
||||
if file_type:
|
||||
if item.suffix.lower() == file_type.lower():
|
||||
results.append({
|
||||
"name": item.name,
|
||||
"path": str(item.relative_to(PROJECTS_DIR)),
|
||||
"type": "directory" if item.is_dir() else "file"
|
||||
})
|
||||
else:
|
||||
results.append({
|
||||
"name": item.name,
|
||||
"path": str(item.relative_to(PROJECTS_DIR)),
|
||||
"type": "directory" if item.is_dir() else "file"
|
||||
})
|
||||
|
||||
# 递归搜索子目录
|
||||
if item.is_dir():
|
||||
scan_for_files(item)
|
||||
|
||||
except PermissionError:
|
||||
pass
|
||||
|
||||
scan_for_files(search_path)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"query": query,
|
||||
"path": path,
|
||||
"results": results,
|
||||
"total": len(results)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"搜索失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/info/{file_path:path}")
|
||||
async def get_file_info(file_path: str):
|
||||
"""
|
||||
获取文件或文件夹详细信息
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
"""
|
||||
try:
|
||||
target_path = PROJECTS_DIR / file_path
|
||||
|
||||
if not target_path.exists():
|
||||
raise HTTPException(status_code=404, detail="路径不存在")
|
||||
|
||||
stat = target_path.stat()
|
||||
|
||||
info = {
|
||||
"name": target_path.name,
|
||||
"path": file_path,
|
||||
"type": "directory" if target_path.is_dir() else "file",
|
||||
"size": stat.st_size if target_path.is_file() else 0,
|
||||
"modified": stat.st_mtime,
|
||||
"created": stat.st_ctime,
|
||||
"permissions": oct(stat.st_mode)[-3:]
|
||||
}
|
||||
|
||||
# 如果是文件,添加额外信息
|
||||
if target_path.is_file():
|
||||
mime_type, _ = mimetypes.guess_type(str(target_path))
|
||||
info["mime_type"] = mime_type or "unknown"
|
||||
|
||||
# 读取文件内容预览(仅对小文件)
|
||||
if stat.st_size < 1024 * 1024: # 小于1MB
|
||||
try:
|
||||
with open(target_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read(1000) # 读取前1000字符
|
||||
info["preview"] = content
|
||||
except:
|
||||
info["preview"] = "[二进制文件,无法预览]"
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"info": info
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"获取文件信息失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/download-folder-zip")
|
||||
async def download_folder_as_zip(request: Dict[str, str]):
|
||||
"""
|
||||
将文件夹压缩为ZIP并下载
|
||||
|
||||
Args:
|
||||
request: 包含path字段的JSON对象
|
||||
"""
|
||||
try:
|
||||
folder_path = request.get("path", "")
|
||||
|
||||
if not folder_path:
|
||||
raise HTTPException(status_code=400, detail="路径不能为空")
|
||||
|
||||
target_path = PROJECTS_DIR / folder_path
|
||||
|
||||
if not target_path.exists():
|
||||
raise HTTPException(status_code=404, detail="文件夹不存在")
|
||||
|
||||
if not target_path.is_dir():
|
||||
raise HTTPException(status_code=400, detail="路径不是文件夹")
|
||||
|
||||
# 计算文件夹大小,检查是否过大
|
||||
total_size = 0
|
||||
file_count = 0
|
||||
for file_path in target_path.rglob('*'):
|
||||
if file_path.is_file():
|
||||
total_size += file_path.stat().st_size
|
||||
file_count += 1
|
||||
|
||||
# 限制最大500MB
|
||||
max_size = 500 * 1024 * 1024
|
||||
if total_size > max_size:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"文件夹过大 ({formatFileSize(total_size)}),最大支持 {formatFileSize(max_size)}"
|
||||
)
|
||||
|
||||
# 限制文件数量
|
||||
max_files = 10000
|
||||
if file_count > max_files:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"文件数量过多 ({file_count}),最大支持 {max_files} 个文件"
|
||||
)
|
||||
|
||||
# 创建ZIP文件名
|
||||
folder_name = target_path.name
|
||||
zip_filename = f"{folder_name}.zip"
|
||||
|
||||
# 创建ZIP文件
|
||||
zip_buffer = io.BytesIO()
|
||||
|
||||
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED, compresslevel=6) as zipf:
|
||||
# 添加文件夹中的所有文件
|
||||
for file_path in target_path.rglob('*'):
|
||||
if file_path.is_file():
|
||||
try:
|
||||
# 计算相对路径,保持文件夹结构
|
||||
arcname = file_path.relative_to(target_path)
|
||||
zipf.write(file_path, arcname)
|
||||
except (OSError, IOError) as e:
|
||||
# 跳过无法读取的文件,但记录警告
|
||||
print(f"Warning: Skipping file {file_path}: {e}")
|
||||
continue
|
||||
|
||||
zip_buffer.seek(0)
|
||||
|
||||
# 设置响应头
|
||||
headers = {
|
||||
'Content-Disposition': f'attachment; filename="{zip_filename}"',
|
||||
'Content-Type': 'application/zip',
|
||||
'Content-Length': str(len(zip_buffer.getvalue())),
|
||||
'Cache-Control': 'no-cache'
|
||||
}
|
||||
|
||||
# 创建流式响应
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
async def generate_zip():
|
||||
zip_buffer.seek(0)
|
||||
yield zip_buffer.getvalue()
|
||||
|
||||
return StreamingResponse(
|
||||
generate_zip(),
|
||||
media_type="application/zip",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"压缩文件夹失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/download-multiple-zip")
|
||||
async def download_multiple_items_as_zip(request: Dict[str, Any]):
|
||||
"""
|
||||
将多个文件和文件夹压缩为ZIP并下载
|
||||
|
||||
Args:
|
||||
request: 包含paths和filename字段的JSON对象
|
||||
"""
|
||||
try:
|
||||
paths = request.get("paths", [])
|
||||
filename = request.get("filename", "batch_download.zip")
|
||||
|
||||
if not paths:
|
||||
raise HTTPException(status_code=400, detail="请选择要下载的文件")
|
||||
|
||||
# 验证所有路径
|
||||
valid_paths = []
|
||||
total_size = 0
|
||||
file_count = 0
|
||||
|
||||
for path in paths:
|
||||
target_path = PROJECTS_DIR / path
|
||||
|
||||
if not target_path.exists():
|
||||
continue # 跳过不存在的文件
|
||||
|
||||
if target_path.is_file():
|
||||
total_size += target_path.stat().st_size
|
||||
file_count += 1
|
||||
valid_paths.append(path)
|
||||
elif target_path.is_dir():
|
||||
# 计算文件夹大小
|
||||
for file_path in target_path.rglob('*'):
|
||||
if file_path.is_file():
|
||||
total_size += file_path.stat().st_size
|
||||
file_count += 1
|
||||
valid_paths.append(path)
|
||||
|
||||
if not valid_paths:
|
||||
raise HTTPException(status_code=404, detail="没有找到有效的文件")
|
||||
|
||||
# 限制大小
|
||||
max_size = 500 * 1024 * 1024
|
||||
if total_size > max_size:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"选中文件过大 ({formatFileSize(total_size)}),最大支持 {formatFileSize(max_size)}"
|
||||
)
|
||||
|
||||
# 限制文件数量
|
||||
max_files = 10000
|
||||
if file_count > max_files:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"文件数量过多 ({file_count}),最大支持 {max_files} 个文件"
|
||||
)
|
||||
|
||||
# 创建ZIP文件
|
||||
zip_buffer = io.BytesIO()
|
||||
|
||||
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED, compresslevel=6) as zipf:
|
||||
for path in valid_paths:
|
||||
target_path = PROJECTS_DIR / path
|
||||
|
||||
if target_path.is_file():
|
||||
# 单个文件
|
||||
try:
|
||||
zipf.write(target_path, target_path.name)
|
||||
except (OSError, IOError) as e:
|
||||
print(f"Warning: Skipping file {target_path}: {e}")
|
||||
continue
|
||||
elif target_path.is_dir():
|
||||
# 文件夹
|
||||
for file_path in target_path.rglob('*'):
|
||||
if file_path.is_file():
|
||||
try:
|
||||
# 保持相对路径结构
|
||||
arcname = f"{target_path.name}/{file_path.relative_to(target_path)}"
|
||||
zipf.write(file_path, arcname)
|
||||
except (OSError, IOError) as e:
|
||||
print(f"Warning: Skipping file {file_path}: {e}")
|
||||
continue
|
||||
|
||||
zip_buffer.seek(0)
|
||||
|
||||
# 设置响应头
|
||||
headers = {
|
||||
'Content-Disposition': f'attachment; filename="{filename}"',
|
||||
'Content-Type': 'application/zip',
|
||||
'Content-Length': str(len(zip_buffer.getvalue())),
|
||||
'Cache-Control': 'no-cache'
|
||||
}
|
||||
|
||||
# 创建流式响应
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
async def generate_zip():
|
||||
zip_buffer.seek(0)
|
||||
yield zip_buffer.getvalue()
|
||||
|
||||
return StreamingResponse(
|
||||
generate_zip(),
|
||||
media_type="application/zip",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"批量压缩失败: {str(e)}")
|
||||
|
||||
|
||||
def formatFileSize(bytes_size: int) -> str:
|
||||
"""格式化文件大小"""
|
||||
if bytes_size == 0:
|
||||
return "0 B"
|
||||
|
||||
k = 1024
|
||||
sizes = ["B", "KB", "MB", "GB", "TB"]
|
||||
i = int(math.floor(math.log(bytes_size, k)))
|
||||
|
||||
if i >= len(sizes):
|
||||
i = len(sizes) - 1
|
||||
|
||||
return f"{bytes_size / math.pow(k, i):.1f} {sizes[i]}"
|
||||
|
||||
|
||||
@router.post("/batch-operation")
|
||||
async def batch_operation(operations: List[Dict[str, Any]]):
|
||||
"""
|
||||
批量操作多个文件
|
||||
|
||||
Args:
|
||||
operations: 操作列表,每个操作包含type和相应的参数
|
||||
"""
|
||||
try:
|
||||
results = []
|
||||
|
||||
for op in operations:
|
||||
op_type = op.get("type")
|
||||
op_result = {"type": op_type, "success": False}
|
||||
|
||||
try:
|
||||
if op_type == "delete":
|
||||
await delete_item(op["path"])
|
||||
op_result["success"] = True
|
||||
op_result["message"] = "删除成功"
|
||||
elif op_type == "move":
|
||||
await move_item(op["source_path"], op["target_path"])
|
||||
op_result["success"] = True
|
||||
op_result["message"] = "移动成功"
|
||||
elif op_type == "copy":
|
||||
await copy_item(op["source_path"], op["target_path"])
|
||||
op_result["success"] = True
|
||||
op_result["message"] = "复制成功"
|
||||
else:
|
||||
op_result["error"] = f"不支持的操作类型: {op_type}"
|
||||
|
||||
except Exception as e:
|
||||
op_result["error"] = str(e)
|
||||
|
||||
results.append(op_result)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"results": results,
|
||||
"total": len(operations),
|
||||
"successful": sum(1 for r in results if r["success"])
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"批量操作失败: {str(e)}")
|
||||
@ -7,9 +7,9 @@ from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from utils import (
|
||||
Message, ChatRequest, ChatResponse,
|
||||
get_global_agent_manager, init_global_sharded_agent_manager
|
||||
Message, ChatRequest, ChatResponse
|
||||
)
|
||||
from agent.sharded_agent_manager import init_global_sharded_agent_manager
|
||||
from utils.api_models import ChatRequestV2
|
||||
from utils.fastapi_utils import (
|
||||
process_messages, extract_guidelines_from_system_prompt, format_messages_to_chat_history,
|
||||
@ -205,6 +205,7 @@ async def create_agent_and_generate_response(
|
||||
# 处理结果:最后一个结果是agent,前面的是guideline批次结果
|
||||
agent = all_results[-1] # agent创建的结果
|
||||
batch_results = all_results[:-1] # guideline批次的结果
|
||||
print(f"batch_results:{batch_results}")
|
||||
|
||||
# 合并guideline分析结果,使用JSON格式的checks数组
|
||||
all_checks = []
|
||||
|
||||
@ -6,11 +6,11 @@ from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from utils import (
|
||||
get_global_agent_manager, init_global_sharded_agent_manager,
|
||||
get_global_connection_pool, init_global_connection_pool,
|
||||
get_global_file_cache, init_global_file_cache,
|
||||
setup_system_optimizations
|
||||
)
|
||||
from agent.sharded_agent_manager import init_global_sharded_agent_manager
|
||||
try:
|
||||
from utils.system_optimizer import apply_optimization_profile
|
||||
except ImportError:
|
||||
|
||||
@ -32,17 +32,19 @@ from .project_manager import (
|
||||
)
|
||||
|
||||
# Import agent management modules
|
||||
from .file_loaded_agent_manager import (
|
||||
get_global_agent_manager,
|
||||
init_global_agent_manager
|
||||
)
|
||||
# Note: These have been moved to agent package
|
||||
# from .file_loaded_agent_manager import (
|
||||
# get_global_agent_manager,
|
||||
# init_global_agent_manager
|
||||
# )
|
||||
|
||||
# Import optimized modules
|
||||
from .sharded_agent_manager import (
|
||||
ShardedAgentManager,
|
||||
get_global_sharded_agent_manager,
|
||||
init_global_sharded_agent_manager
|
||||
)
|
||||
# Import optimized modules
|
||||
# Note: These have been moved to agent package
|
||||
# from .sharded_agent_manager import (
|
||||
# ShardedAgentManager,
|
||||
# get_global_sharded_agent_manager,
|
||||
# init_global_sharded_agent_manager
|
||||
# )
|
||||
|
||||
from .connection_pool import (
|
||||
HTTPConnectionPool,
|
||||
@ -78,10 +80,11 @@ from .system_optimizer import (
|
||||
)
|
||||
|
||||
# Import config cache module
|
||||
from .config_cache import (
|
||||
config_cache,
|
||||
ConfigFileCache
|
||||
)
|
||||
# Note: This has been moved to agent package
|
||||
# from .config_cache import (
|
||||
# config_cache,
|
||||
# ConfigFileCache
|
||||
# )
|
||||
|
||||
from .agent_pool import (
|
||||
AgentPool,
|
||||
@ -123,9 +126,10 @@ from .api_models import (
|
||||
create_chat_response
|
||||
)
|
||||
|
||||
from .prompt_loader import (
|
||||
load_system_prompt,
|
||||
)
|
||||
# Note: This has been moved to agent package
|
||||
# from .prompt_loader import (
|
||||
# load_system_prompt,
|
||||
# )
|
||||
|
||||
from .multi_project_manager import (
|
||||
create_robot_project,
|
||||
@ -162,9 +166,9 @@ __all__ = [
|
||||
'list_projects',
|
||||
'get_project_stats',
|
||||
|
||||
# file_loaded_agent_manager
|
||||
'get_global_agent_manager',
|
||||
'init_global_agent_manager',
|
||||
# file_loaded_agent_manager (moved to agent package)
|
||||
# 'get_global_agent_manager',
|
||||
# 'init_global_agent_manager',
|
||||
|
||||
# agent_pool
|
||||
'AgentPool',
|
||||
@ -202,8 +206,8 @@ __all__ = [
|
||||
'create_error_response',
|
||||
'create_chat_response',
|
||||
|
||||
# prompt_loader
|
||||
'load_system_prompt',
|
||||
# prompt_loader (moved to agent package)
|
||||
# 'load_system_prompt',
|
||||
|
||||
# multi_project_manager
|
||||
'create_robot_project',
|
||||
|
||||
Loading…
Reference in New Issue
Block a user