qwen_agent/agent/custom_mcp_manager.py

475 lines
22 KiB
Python

# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import atexit
import datetime
import json
import threading
import time
import uuid
from contextlib import AsyncExitStack
from typing import Dict, Optional, Union
from dotenv import load_dotenv
from qwen_agent.log import logger
from qwen_agent.tools.base import BaseTool
class CustomMCPManager:
    """Process-wide singleton that owns every MCP client connection.

    All MCP I/O runs on a private asyncio event loop (``self.loop``) driven by a
    daemon thread (``self.loop_thread``). Synchronous callers submit coroutines to
    that loop with ``asyncio.run_coroutine_threadsafe`` and block on the result.
    """

    _instance = None  # Private class variable to store the unique instance

    # Error text (from anyio) seen when an MCP SSE connection is torn down from a
    # task other than the one that opened it; such loop errors are silenced.
    _CROSS_TASK_CANCEL_SCOPE_MSG = 'Attempted to exit cancel scope in a different task'

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # object.__new__ rejects extra arguments, so they are never forwarded.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every CustomMCPManager() call; the singleton is only
        # initialized the first time (detected by the presence of `clients`).
        if hasattr(self, 'clients'):
            return
        try:
            import mcp  # noqa
        except ImportError as e:
            raise ImportError('Could not import mcp. Please install mcp with `pip install -U mcp`.') from e

        load_dotenv()  # Load environment variables from .env file
        self.clients: dict = {}
        # Set a new event loop in a separate (daemon) thread
        self.loop = asyncio.new_event_loop()
        self.loop_thread = threading.Thread(target=self.start_loop, daemon=True)
        self.loop_thread.start()

        # A fallback way to terminate MCP tool processes after Qwen-Agent exits
        self.processes = []
        self.monkey_patch_mcp_create_platform_compatible_process()

    def monkey_patch_mcp_create_platform_compatible_process(self):
        """Wrap MCP's stdio process factory so every spawned server process is recorded.

        Recorded processes (``self.processes``) are terminated by ``shutdown()`` as a
        fallback when graceful cleanup does not finish.

        Raises:
            ImportError: if the installed MCP lacks
                ``mcp.client.stdio._create_platform_compatible_process``.
        """
        try:
            import mcp.client.stdio
            target = mcp.client.stdio._create_platform_compatible_process
        except (ModuleNotFoundError, AttributeError) as e:
            raise ImportError('Qwen-Agent needs to monkey patch MCP for process cleanup. '
                              'Please upgrade MCP to a higher version with `pip install -U mcp`.') from e

        async def _monkey_patched_create_platform_compatible_process(*args, **kwargs):
            process = await target(*args, **kwargs)
            self.processes.append(process)
            return process

        mcp.client.stdio._create_platform_compatible_process = _monkey_patched_create_platform_compatible_process

    @classmethod
    def _is_ignorable_loop_exception(cls, exception) -> bool:
        """Return True for the cross-task cancel-scope errors raised by MCP SSE teardown.

        Matches a ``RuntimeError`` or (Python 3.11+) ``BaseExceptionGroup`` whose text
        contains ``_CROSS_TASK_CANCEL_SCOPE_MSG``.
        """
        import builtins

        if exception is None:
            return False
        # BaseExceptionGroup is only a builtin on Python 3.11+; referencing it by
        # name would raise NameError inside the loop's exception handler on 3.10.
        group_type = getattr(builtins, 'BaseExceptionGroup', None)
        ignorable_types = (RuntimeError,) if group_type is None else (RuntimeError, group_type)
        return isinstance(exception, ignorable_types) and cls._CROSS_TASK_CANCEL_SCOPE_MSG in str(exception)

    def start_loop(self):
        """Thread target: bind ``self.loop`` to this thread and run it forever."""
        asyncio.set_event_loop(self.loop)

        # Set a global exception handler to silently handle cross-task exceptions from MCP SSE connections
        def exception_handler(loop, context):
            if self._is_ignorable_loop_exception(context.get('exception')):
                return  # Silently ignore this type of exception
            # Other exceptions are handled normally
            loop.default_exception_handler(context)

        self.loop.set_exception_handler(exception_handler)
        self.loop.run_forever()

    def is_valid_mcp_servers(self, config: dict):
        """Return True if ``config`` is a well-formed MCP servers configuration.

        Example of mcp servers configuration:
        {
            "mcpServers": {
                "memory": {
                    "command": "npx",
                    "args": ["-y", "@modelcontextprotocol/server-memory"]
                },
                "filesystem": {
                    "command": "npx",
                    "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/allowed/files"]
                }
            }
        }

        Each server must define a string ``command`` (with a list ``args``) and/or a
        string ``url`` (with optional dict ``headers``). ``env``, if present, must be
        a dict.
        """
        # Check if the top-level key "mcpServers" exists and its value is a dictionary
        if not isinstance(config, dict) or not isinstance(config.get('mcpServers'), dict):
            return False
        # Check each sub-item under "mcpServers"
        for server in config['mcpServers'].values():
            # Each sub-item must be a dictionary
            if not isinstance(server, dict):
                return False
            # connection_server needs either a url or a command to know how to connect
            if 'command' not in server and 'url' not in server:
                return False
            if 'command' in server:
                # "command" must be a string and "args" must be a list
                if not isinstance(server['command'], str) or not isinstance(server.get('args'), list):
                    return False
            if 'url' in server:
                # "url" must be a string
                if not isinstance(server['url'], str):
                    return False
                # "headers" must be a dictionary
                if 'headers' in server and not isinstance(server['headers'], dict):
                    return False
            # If the "env" key exists, it must be a dictionary
            if 'env' in server and not isinstance(server['env'], dict):
                return False
        return True

    def initConfig(self, config: Dict):
        """Connect to all servers in ``config`` and return the resulting tools.

        Blocks the calling thread until ``init_config_async`` finishes on ``self.loop``.

        Raises:
            ValueError: if ``config`` fails ``is_valid_mcp_servers``.
        """
        if not self.is_valid_mcp_servers(config):
            raise ValueError('Config of mcpservers is not valid')
        logger.info(f'Initializing MCP tools from mcp servers: {list(config["mcpServers"].keys())}')
        # Submit coroutine to the event loop and wait for the result
        future = asyncio.run_coroutine_threadsafe(self.init_config_async(config), self.loop)
        try:
            return future.result()  # You can specify a timeout if desired
        except Exception as e:
            logger.info(f'Failed in initializing MCP tools: {e}')
            raise

    @staticmethod
    def _build_tool_parameters(input_schema: dict) -> dict:
        """Reduce an MCP tool ``inputSchema`` to the OpenAI-style keys used by Qwen-Agent.

        MCP tool example:
        {
            "name": "read_query",
            "description": "Execute a SELECT query on the SQLite database",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "SELECT SQL query to execute"
                    }
                },
                "required": ["query"]
            }
        }

        Only ``type``, ``properties`` and ``required`` are kept; ``required`` may be
        absent and then defaults to ``[]``. ``input_schema`` itself is not modified.

        Raises:
            ValueError: if ``type`` or ``properties`` is missing.
        """
        parameters = dict(input_schema)  # shallow copy: do not mutate the tool's schema
        # The required field in inputSchema may be empty and needs to be initialized.
        parameters.setdefault('required', [])
        required_fields = {'type', 'properties', 'required'}
        missing_fields = required_fields - parameters.keys()
        if missing_fields:
            raise ValueError(f'Missing required fields in schema: {missing_fields}')
        # Keep only the necessary fields
        return {
            'type': parameters['type'],
            'properties': parameters['properties'],
            'required': parameters['required']
        }

    async def _create_resource_tools(self, server_name: str, client_id: str, client) -> list:
        """Build the ``list_resources`` and ``read_resource`` tools for one server.

        MCP resource example:
        {
            uri: string;           // Unique identifier for the resource
            name: string;          // Human-readable name
            description?: string;  // Optional description
            mimeType?: string;     // Optional MIME type
        }
        """
        # List resources
        list_resources_agent_tool = self.create_tool_class(
            register_name=server_name + '-' + 'list_resources',
            register_client_id=client_id,
            tool_name='list_resources',
            tool_desc='Servers expose a list of concrete resources through this tool. '
            'By invoking it, you can discover the available resources and obtain resource templates, which help clients understand how to construct valid URIs. '
            'These URI formats will be used as input parameters for the read_resource function. ',
            tool_parameters={'type': 'object', 'properties': {}, 'required': []})

        # Read resource: advertise the server's resource templates (if any) in its description
        resources_template_str = ''
        try:
            list_resource_templates = await client.session.list_resource_templates()
            if list_resource_templates.resourceTemplates:
                resources_template_str = '\n'.join(
                    str(template) for template in list_resource_templates.resourceTemplates)
        except Exception as e:
            logger.info(f'Failed in listing MCP resource templates: {e}')

        read_resource_params = {
            'type': 'object',
            'properties': {
                'uri': {
                    'type': 'string',
                    'description': 'The URI identifying the specific resource to access'
                }
            },
            'required': ['uri']
        }
        original_tool_desc = 'Request to access a resource provided by a connected MCP server. Resources represent data sources that can be used as context, such as files, API responses, or system information.'
        if resources_template_str:
            tool_desc = original_tool_desc + '\nResource Templates:\n' + resources_template_str
        else:
            tool_desc = original_tool_desc
        read_resource_agent_tool = self.create_tool_class(register_name=server_name + '-' + 'read_resource',
                                                          register_client_id=client_id,
                                                          tool_name='read_resource',
                                                          tool_desc=tool_desc,
                                                          tool_parameters=read_resource_params)
        return [list_resources_agent_tool, read_resource_agent_tool]

    async def init_config_async(self, config: Dict):
        """Connect to every configured server and build Qwen-Agent tools for it.

        Runs on ``self.loop``. Each connected client is registered in ``self.clients``
        under ``<server_name>_<uuid4>``.

        Returns:
            A list of tool instances: one per MCP tool (registered as
            ``<server_name>-<tool_name>``), plus ``list_resources``/``read_resource``
            tools for servers that reported resources.
        """
        tools: list = []
        for server_name, server in config['mcpServers'].items():
            client = CustomMCPClient()
            await client.connection_server(mcp_server_name=server_name,
                                           mcp_server=server)  # Attempt to connect to the server
            # The uuid suffix allows the same server name be used across different running agents
            client_id = server_name + '_' + str(uuid.uuid4())
            client.client_id = client_id  # Ensure client_id is set on the client instance
            self.clients[client_id] = client  # Add to clients dict after successful connection
            for tool in client.tools:
                tools.append(
                    self.create_tool_class(register_name=server_name + '-' + tool.name,
                                           register_client_id=client_id,
                                           tool_name=tool.name,
                                           tool_desc=tool.description,
                                           tool_parameters=self._build_tool_parameters(tool.inputSchema)))
            if client.resources:
                tools.extend(await self._create_resource_tools(server_name, client_id, client))
        return tools

    @staticmethod
    def _parse_tool_args(params: Union[str, dict]) -> dict:
        """Decode tool-call arguments into a dict.

        Accepts an already-decoded dict or a JSON string. ``None`` or a blank string
        (which LLMs may send for argument-less tools) becomes ``{}``.
        """
        if isinstance(params, dict):
            return params
        if not params or not params.strip():
            return {}
        return json.loads(params)

    def create_tool_class(self, register_name, register_client_id, tool_name, tool_desc, tool_parameters):
        """Create a ``BaseTool`` instance that forwards calls to one MCP tool.

        Args:
            register_name: Name the tool is registered under.
            register_client_id: Key of the owning client in ``self.clients``.
            tool_name: Tool name as known to the MCP server.
            tool_desc: Description shown to the LLM.
            tool_parameters: JSON schema of the tool arguments.
        """

        class ToolClass(BaseTool):
            name = register_name
            description = tool_desc
            parameters = tool_parameters
            client_id = register_client_id

            def call(self, params: Union[str, dict], **kwargs) -> str:
                tool_args = CustomMCPManager._parse_tool_args(params)
                # Submit coroutine to the event loop and wait for the result
                manager = CustomMCPManager()
                # Looked up per call: execute_function may replace the client after a reconnect.
                client = manager.clients[self.client_id]
                future = asyncio.run_coroutine_threadsafe(client.execute_function(tool_name, tool_args), manager.loop)
                try:
                    return future.result()
                except Exception as e:
                    logger.info(f'Failed in executing MCP tool: {e}')
                    raise

        ToolClass.__name__ = f'{register_name}_Class'
        return ToolClass()

    def shutdown(self):
        """Close all clients, stop the background loop and join its thread.

        Each client's ``cleanup()`` is scheduled on ``self.loop`` and the client is
        removed from ``self.clients``. After a fixed 1-second grace period, if the loop
        still has tasks, every recorded MCP subprocess is terminated as a fallback.
        """
        for client_id in list(self.clients):
            client: CustomMCPClient = self.clients.pop(client_id)
            asyncio.run_coroutine_threadsafe(client.cleanup(), self.loop)
        time.sleep(1)  # Wait for the graceful cleanups, otherwise fall back

        # fallback
        if asyncio.all_tasks(self.loop):
            logger.info(
                'There are still tasks in `CustomMCPManager().loop`, force terminating the MCP tool processes. There may be some exceptions.'
            )
            for process in self.processes:
                try:
                    process.terminate()
                except ProcessLookupError:
                    pass  # it's ok, the process may exit earlier
        self.loop.call_soon_threadsafe(self.loop.stop)
        self.loop_thread.join()
class CustomMCPClient:
    """A connection to a single MCP server over stdio, SSE or streamable-http.

    ``CustomMCPManager`` schedules this class's coroutines on its background event loop.
    """

    def __init__(self):
        from mcp import ClientSession
        self.session: Optional[ClientSession] = None
        self.tools: Optional[list] = None  # MCP tool descriptors, filled in by connection_server
        self.exit_stack = AsyncExitStack()  # owns the transport and session contexts
        self.resources: bool = False  # True once the server reports at least one resource
        # Connection parameters remembered for reconnect()
        self._last_mcp_server_name = None
        self._last_mcp_server = None
        self.client_id = None  # For replacing in MCPManager.clients

    async def _close_after_failed_connect(self):
        """Best-effort close of whatever contexts connection_server managed to enter."""
        try:
            await self.exit_stack.aclose()
        except Exception as close_error:  # cleanup must not mask the original failure
            logger.info(f'Failed in cleaning up after MCP connection failure: {close_error}')

    async def connection_server(self, mcp_server_name, mcp_server):
        """Connect to an MCP server and retrieve the available tools.

        With a ``url`` the transport is SSE, or streamable-http when
        ``type == 'streamable-http'``; otherwise a stdio subprocess is started from
        ``command``/``args``/``env``. On success ``self.session`` is initialized,
        ``self.tools`` is loaded and ``self.resources`` is set.

        On failure, any transport/session already entered on ``self.exit_stack`` is
        closed before the exception is re-raised.
        """
        from mcp import ClientSession, StdioServerParameters
        from mcp.client.sse import sse_client
        from mcp.client.stdio import stdio_client
        from mcp.client.streamable_http import streamablehttp_client

        # Save parameters
        self._last_mcp_server_name = mcp_server_name
        self._last_mcp_server = mcp_server
        try:
            if 'url' in mcp_server:
                url = mcp_server.get('url')
                sse_read_timeout = mcp_server.get('sse_read_timeout', 300)
                logger.info(f'{mcp_server_name} sse_read_timeout: {sse_read_timeout}s')
                if mcp_server.get('type', 'sse') == 'streamable-http':
                    # streamable-http mode, e.g.
                    # {"mcpServers": {"streamable-mcp-server": {"type": "streamable-http",
                    #                                            "url": "http://0.0.0.0:8000/mcp"}}}
                    headers = mcp_server.get('headers', {})
                    self._streams_context = streamablehttp_client(
                        url=url, headers=headers, sse_read_timeout=datetime.timedelta(seconds=sse_read_timeout))
                    read_stream, write_stream, _ = await self.exit_stack.enter_async_context(self._streams_context)
                    self._session_context = ClientSession(read_stream, write_stream)
                    self.session = await self.exit_stack.enter_async_context(self._session_context)
                else:
                    # sse mode
                    headers = mcp_server.get('headers', {'Accept': 'text/event-stream'})
                    self._streams_context = sse_client(url, headers, sse_read_timeout=sse_read_timeout)
                    streams = await self.exit_stack.enter_async_context(self._streams_context)
                    self._session_context = ClientSession(*streams)
                    self.session = await self.exit_stack.enter_async_context(self._session_context)
            else:
                server_params = StdioServerParameters(command=mcp_server['command'],
                                                      args=mcp_server['args'],
                                                      env=mcp_server.get('env', None))
                stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
                self.stdio, self.write = stdio_transport
                self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
                logger.info(
                    f'Initializing a MCP stdio_client, if this takes forever, please check the config of this mcp server: {mcp_server_name}'
                )
            await self.session.initialize()
            list_tools = await self.session.list_tools()
            self.tools = list_tools.tools
            try:
                list_resources = await self.session.list_resources()  # Check if the server has resources
                if list_resources.resources:
                    self.resources = True
            except Exception:
                pass  # Best effort: a server without resource support simply exposes none
        except Exception as e:
            logger.warning(f'Failed in connecting to MCP server: {e}')
            # Do not leak the transport (e.g. a spawned stdio subprocess) of a failed connection.
            await self._close_after_failed_connect()
            raise e

    async def reconnect(self):
        """Create, connect and return a new client using the last connection parameters.

        Raises:
            RuntimeError: if ``client_id`` is None (the client was never registered).
        """
        if self.client_id is None:
            raise RuntimeError(
                'Cannot reconnect: client_id is None. This usually means the client was not properly registered in MCPManager.'
            )
        new_client = CustomMCPClient()
        new_client.client_id = self.client_id
        await new_client.connection_server(self._last_mcp_server_name, self._last_mcp_server)
        return new_client

    async def execute_function(self, tool_name, tool_args: dict):
        """Run ``tool_name`` with ``tool_args`` and return the result text.

        The session is pinged first. If the ping fails, the client reconnects once,
        stores the new client in ``CustomMCPManager().clients`` under the same
        ``client_id`` and runs the call on it directly. The new session is not pinged
        again, so a server that cannot answer pings cannot trigger endless reconnects.
        Failures during the reconnect path are returned as an error string.
        """
        # Check if session is alive
        try:
            await self.session.send_ping()
        except Exception as e:
            logger.info(f"Session is not alive, please increase 'sse_read_timeout' in the config, try reconnect: {e}")
            # Auto reconnect
            try:
                manager = CustomMCPManager()
                if self.client_id is not None:
                    # NOTE(review): this client's exit_stack is not closed here; closing it from
                    # another task may hit anyio's cross-task cancel-scope error — confirm before adding.
                    new_client = await self.reconnect()
                    manager.clients[self.client_id] = new_client
                    return await new_client._dispatch(tool_name, tool_args)
                else:
                    logger.info('Reconnect failed: client_id is None')
                    return 'Session reconnect (client creation) exception: client_id is None'
            except Exception as e3:
                logger.info(f'Reconnect (client creation) exception type: {type(e3)}, value: {repr(e3)}')
                return f'Session reconnect (client creation) exception: {e3}'
        return await self._dispatch(tool_name, tool_args)

    async def _dispatch(self, tool_name, tool_args: dict):
        """Execute ``tool_name`` on the current session without a liveness check.

        ``list_resources`` and ``read_resource`` map to the MCP resource APIs and
        return ``'Error: ...'`` strings on failure. Any other name is sent to
        ``call_tool``, and its text contents are joined with blank lines.
        """
        if tool_name == 'list_resources':
            try:
                list_resources = await self.session.list_resources()
                if list_resources.resources:
                    resources_str = '\n\n'.join(str(resource) for resource in list_resources.resources)
                else:
                    resources_str = 'No resources found'
                return resources_str
            except Exception as e:
                logger.info(f'No list resources: {e}')
                return f'Error: {e}'
        elif tool_name == 'read_resource':
            from mcp.types import TextResourceContents
            try:
                uri = tool_args.get('uri')
                if not uri:
                    raise ValueError('URI is required for read_resource')
                read_resource = await self.session.read_resource(uri)
                # Only text contents are returned; binary (blob) contents are skipped.
                texts = [resource.text for resource in read_resource.contents if isinstance(resource, TextResourceContents)]
                if texts:
                    return '\n\n'.join(texts)
                else:
                    return 'Failed to read resource'
            except Exception as e:
                logger.info(f'Failed to read resource: {e}')
                return f'Error: {e}'
        else:
            response = await self.session.call_tool(tool_name, tool_args)
            # Only text content items are returned; other content types are dropped.
            texts = [content.text for content in response.content if content.type == 'text']
            if texts:
                return '\n\n'.join(texts)
            else:
                return 'execute error'

    async def cleanup(self):
        """Close the session and transport entered on ``self.exit_stack``."""
        await self.exit_stack.aclose()
def _cleanup_mcp(_sig_num=None, _frame=None):
    """Shut down the CustomMCPManager singleton, if one has been created.

    ``_sig_num`` and ``_frame`` are ignored; presumably they exist so this function
    can also serve as a signal handler — confirm with callers.
    """
    if CustomMCPManager._instance is not None:
        CustomMCPManager().shutdown()
# Register _cleanup_mcp to run at interpreter exit (only when this module is
# imported from the main thread). Note that atexit hooks do not run if the process
# is killed by a signal Python does not handle (e.g. SIGKILL).
if threading.main_thread() is threading.current_thread():
    atexit.register(_cleanup_mcp)