501 lines
23 KiB
Python
501 lines
23 KiB
Python
# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
import asyncio
|
||
import atexit
|
||
import datetime
|
||
import json
|
||
import threading
|
||
import time
|
||
import uuid
|
||
from contextlib import AsyncExitStack
|
||
from typing import Dict, Optional, Union
|
||
|
||
from dotenv import load_dotenv
|
||
|
||
from qwen_agent.log import logger
|
||
from qwen_agent.tools.base import BaseTool
|
||
|
||
|
||
class CustomMCPManager:
    """Process-wide singleton that owns MCP client connections and a private event loop.

    All MCP I/O runs on ``self.loop``, which is driven by a daemon thread started in
    ``__init__``. Synchronous entry points (``initConfig`` and the ``call`` method of the
    generated tools) submit coroutines to that loop with
    ``asyncio.run_coroutine_threadsafe`` and block on the result.
    """

    _instance = None  # Private class variable to store the unique instance

    # Message fragment of the cross-task error from MCP SSE connections that the loop's
    # exception handler silently ignores.
    _CROSS_TASK_ERROR_MARKER = 'Attempted to exit cancel scope in a different task'

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # object.__new__ rejects extra arguments once __new__ is overridden, so they
            # must not be forwarded.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every CustomMCPManager() call; the singleton is set up only once.
        if hasattr(self, 'clients'):
            return
        try:
            import mcp  # noqa
        except ImportError as e:
            raise ImportError('Could not import mcp. Please install mcp with `pip install -U mcp`.') from e

        load_dotenv()  # Load environment variables from .env file
        self.clients: dict = {}
        # Run a dedicated event loop in a background daemon thread.
        self.loop = asyncio.new_event_loop()
        self.loop_thread = threading.Thread(target=self.start_loop, daemon=True)
        self.loop_thread.start()

        # A fallback way to terminate MCP tool processes after Qwen-Agent exits
        self.processes = []
        self.monkey_patch_mcp_create_platform_compatible_process()

    def monkey_patch_mcp_create_platform_compatible_process(self):
        """Wrap ``mcp.client.stdio._create_platform_compatible_process`` to record processes.

        Every process created through the wrapper is appended to ``self.processes`` so
        ``shutdown`` can terminate it as a fallback.

        Raises:
            ImportError: if the installed mcp package lacks the private function being patched.
        """
        try:
            import mcp.client.stdio
            target = mcp.client.stdio._create_platform_compatible_process
        except (ModuleNotFoundError, AttributeError) as e:
            raise ImportError('Qwen-Agent needs to monkey patch MCP for process cleanup. '
                              'Please upgrade MCP to a higher version with `pip install -U mcp`.') from e

        async def _monkey_patched_create_platform_compatible_process(*args, **kwargs):
            process = await target(*args, **kwargs)
            self.processes.append(process)
            return process

        mcp.client.stdio._create_platform_compatible_process = _monkey_patched_create_platform_compatible_process

    @staticmethod
    def _is_ignorable_cross_task_exception(exception: BaseException) -> bool:
        """Return True for the cross-task error from MCP SSE connections that should be ignored.

        Matches a RuntimeError whose message contains the marker. Also matches an exception
        group whose own message contains the marker, or whose (non-empty) sub-exceptions all
        match recursively. ``str()`` of a group does not include its sub-exceptions' messages,
        so they are inspected directly. ``BaseExceptionGroup`` is a builtin only on Python
        3.11+, so it is looked up dynamically instead of being referenced by name (which
        would raise NameError on older interpreters).
        """
        import builtins

        marker = CustomMCPManager._CROSS_TASK_ERROR_MARKER
        if isinstance(exception, RuntimeError) and marker in str(exception):
            return True
        group_type = getattr(builtins, 'BaseExceptionGroup', None)
        if group_type is not None and isinstance(exception, group_type):
            if marker in str(exception):
                return True
            return bool(exception.exceptions) and all(
                CustomMCPManager._is_ignorable_cross_task_exception(sub) for sub in exception.exceptions)
        return False

    def start_loop(self):
        """Bind ``self.loop`` to the current thread, install the exception handler and run forever."""
        asyncio.set_event_loop(self.loop)

        # Global exception handler: silently drop cross-task exceptions from MCP SSE
        # connections; everything else goes to the default handler.
        def exception_handler(loop, context):
            exception = context.get('exception')
            if exception is not None and self._is_ignorable_cross_task_exception(exception):
                return  # Silently ignore this type of exception
            loop.default_exception_handler(context)

        self.loop.set_exception_handler(exception_handler)
        self.loop.run_forever()

    def is_valid_mcp_servers(self, config: dict):
        """Return True if ``config`` is a well-formed MCP servers configuration.

        Example of mcp servers configuration:
        {
            "mcpServers": {
                "memory": {
                    "command": "npx",
                    "args": ["-y", "@modelcontextprotocol/server-memory"]
                },
                "filesystem": {
                    "command": "npx",
                    "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/allowed/files"]
                }
            }
        }

        Rules checked:
        - ``config["mcpServers"]`` exists and is a dict of dicts.
        - If ``command`` is present, it must be a str and ``args`` must be a list.
        - If ``url`` is present, it must be a str and ``headers`` (if present) a dict.
        - ``env`` (if present) must be a dict.
        """
        if not isinstance(config, dict) or 'mcpServers' not in config or not isinstance(config['mcpServers'], dict):
            return False
        for server in config['mcpServers'].values():
            if not isinstance(server, dict):
                return False
            if 'command' in server:
                if not isinstance(server['command'], str):
                    return False
                if 'args' not in server or not isinstance(server['args'], list):
                    return False
            if 'url' in server:
                if not isinstance(server['url'], str):
                    return False
                if 'headers' in server and not isinstance(server['headers'], dict):
                    return False
            if 'env' in server and not isinstance(server['env'], dict):
                return False
        return True

    def initConfig(self, config: Dict):
        """Connect to every configured MCP server and return the generated tool instances.

        Blocks (without a timeout) until ``init_config_async`` finishes on ``self.loop``.

        Raises:
            ValueError: if ``config`` fails ``is_valid_mcp_servers``.
            Exception: any error raised by ``init_config_async`` is logged and re-raised.
        """
        if not self.is_valid_mcp_servers(config):
            raise ValueError('Config of mcpservers is not valid')
        logger.info(f'Initializing MCP tools from mcp servers: {list(config["mcpServers"].keys())}')
        # Submit coroutine to the event loop and wait for the result
        future = asyncio.run_coroutine_threadsafe(self.init_config_async(config), self.loop)
        try:
            return future.result()  # A timeout could be passed here if desired
        except Exception as e:
            logger.error(f'Failed in initializing MCP tools: {e}')
            raise

    @staticmethod
    def _clean_tool_parameters(input_schema: dict) -> dict:
        """Reduce an MCP tool ``inputSchema`` to the ``type``/``properties``/``required`` keys.

        MCP tool example:
        {
            "name": "read_query",
            "description": "Execute a SELECT query on the SQLite database",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "SELECT SQL query to execute"}
                },
                "required": ["query"]
            }
        }

        ``required`` defaults to ``[]`` and ``properties`` to ``{}`` when absent (tools that
        take no arguments may omit them). Other keys are dropped because they do not
        conform to the standard OpenAI schema. The input dict is not modified.

        Raises:
            ValueError: if the schema has no ``type``.
        """
        missing_fields = {'type'} - set(input_schema)
        if missing_fields:
            raise ValueError(f'Missing required fields in schema: {missing_fields}')
        return {
            'type': input_schema['type'],
            'properties': input_schema.get('properties', {}),
            'required': input_schema.get('required', []),
        }

    async def init_config_async(self, config: Dict):
        """Connect to all servers in ``config['mcpServers']`` concurrently and build their tools.

        Servers that fail to connect are logged and skipped. Each MCP tool is registered as
        ``<server_name>-<tool_name>``. Servers that report resources also get
        ``list_resources`` / ``read_resource`` tools.

        Returns:
            list: the tool instances produced by ``create_tool_class``.
        """
        mcp_servers = config['mcpServers']

        # Connect to all MCP servers concurrently
        connection_tasks = [
            self._connect_and_store_client(CustomMCPClient(), server_name, server)
            for server_name, server in mcp_servers.items()
        ]
        connected_clients = await asyncio.gather(*connection_tasks, return_exceptions=True)

        # Process the connection results and create tools for each connected client
        tools: list = []
        for result in connected_clients:
            # With return_exceptions=True, failures come back as values. CancelledError is
            # a BaseException (not an Exception), so check BaseException before unpacking.
            if isinstance(result, BaseException):
                logger.error(f'Failed to connect MCP server: {result}')
                continue

            client, server_name = result
            client_id = client.client_id
            for tool in client.tools:
                agent_tool = self.create_tool_class(register_name=server_name + '-' + tool.name,
                                                    register_client_id=client_id,
                                                    tool_name=tool.name,
                                                    tool_desc=tool.description,
                                                    tool_parameters=self._clean_tool_parameters(tool.inputSchema))
                tools.append(agent_tool)

            if client.resources:
                tools.extend(await self._create_resource_tools(client, server_name, client_id))

        return tools

    async def _create_resource_tools(self, client, server_name, client_id):
        """Build the ``<server>-list_resources`` and ``<server>-read_resource`` tools for ``client``.

        MCP resource example:
        {
            uri: string;           // Unique identifier for the resource
            name: string;          // Human-readable name
            description?: string;  // Optional description
            mimeType?: string;     // Optional MIME type
        }

        If the server lists resource templates, they are appended to the read_resource
        tool description. A failure to list them is logged and ignored.
        """
        # List resources
        list_resources_agent_tool = self.create_tool_class(
            register_name=server_name + '-' + 'list_resources',
            register_client_id=client_id,
            tool_name='list_resources',
            tool_desc='Servers expose a list of concrete resources through this tool. '
            'By invoking it, you can discover the available resources and obtain resource templates, which help clients understand how to construct valid URIs. '
            'These URI formats will be used as input parameters for the read_resource function. ',
            tool_parameters={'type': 'object', 'properties': {}, 'required': []})

        # Read resource: check whether the server has resource templates
        resources_template_str = ''
        try:
            list_resource_templates = await client.session.list_resource_templates()
            if list_resource_templates.resourceTemplates:
                resources_template_str = '\n'.join(
                    str(template) for template in list_resource_templates.resourceTemplates)
        except Exception as e:
            logger.info(f'Failed in listing MCP resource templates: {e}')

        read_resource_params = {
            'type': 'object',
            'properties': {
                'uri': {
                    'type': 'string',
                    'description': 'The URI identifying the specific resource to access'
                }
            },
            'required': ['uri']
        }
        tool_desc = 'Request to access a resource provided by a connected MCP server. Resources represent data sources that can be used as context, such as files, API responses, or system information.'
        if resources_template_str:
            tool_desc = tool_desc + '\nResource Templates:\n' + resources_template_str
        read_resource_agent_tool = self.create_tool_class(register_name=server_name + '-' + 'read_resource',
                                                          register_client_id=client_id,
                                                          tool_name='read_resource',
                                                          tool_desc=tool_desc,
                                                          tool_parameters=read_resource_params)
        return [list_resources_agent_tool, read_resource_agent_tool]

    async def _connect_and_store_client(self, client, server_name, server):
        """Helper: connect ``client`` to one MCP server and register it in ``self.clients``.

        Returns:
            tuple: ``(client, server_name)`` on success.

        Raises:
            Exception: the connection error, re-raised after a best-effort ``client.cleanup()``
            (run in this same task) to release any partially opened transports.
        """
        try:
            await client.connection_server(mcp_server_name=server_name, mcp_server=server)
        except Exception as e:
            logger.error(f'Failed to connect MCP server {server_name}: {e}')
            try:
                await client.cleanup()
            except Exception as cleanup_error:  # Best effort: keep the original error
                logger.warning(f'Failed to clean up MCP client {server_name}: {cleanup_error}')
            raise

        # A random suffix allows the same server name to be used across different running agents
        client_id = server_name + '_' + str(uuid.uuid4())
        client.client_id = client_id  # Ensure client_id is set on the client instance
        self.clients[client_id] = client  # Add to clients dict after successful connection
        return client, server_name

    def create_tool_class(self, register_name, register_client_id, tool_name, tool_desc, tool_parameters):
        """Create a BaseTool instance that forwards calls to MCP tool ``tool_name``.

        The tool is registered as ``register_name``. Each call looks up the client for
        ``register_client_id`` in the singleton's ``clients`` at call time, so a client
        replaced after a reconnect is picked up automatically.
        """

        class ToolClass(BaseTool):
            name = register_name
            description = tool_desc
            parameters = tool_parameters
            client_id = register_client_id

            def call(self, params: Union[str, dict], **kwargs) -> str:
                # Accept either a JSON string or an already-parsed dict, as the signature advertises.
                tool_args = json.loads(params) if isinstance(params, str) else params
                # Submit coroutine to the event loop and wait for the result
                manager = CustomMCPManager()
                client = manager.clients[self.client_id]
                future = asyncio.run_coroutine_threadsafe(client.execute_function(tool_name, tool_args), manager.loop)
                try:
                    return future.result()
                except Exception as e:
                    logger.info(f'Failed in executing MCP tool: {e}')
                    raise

        ToolClass.__name__ = f'{register_name}_Class'
        return ToolClass()

    def shutdown(self):
        """Clean up all clients, then stop the event loop and join its thread.

        Each client's ``cleanup()`` is submitted to ``self.loop`` and the client is removed
        from ``self.clients``. After a fixed 1 second wait, if tasks remain on the loop,
        every recorded MCP process is terminated as a fallback.
        """
        futures = []  # Hold references to the submitted cleanup futures
        for client_id in list(self.clients.keys()):
            client = self.clients.pop(client_id)
            futures.append(asyncio.run_coroutine_threadsafe(client.cleanup(), self.loop))
        time.sleep(1)  # Wait for the graceful cleanups, otherwise fall back

        # fallback
        if asyncio.all_tasks(self.loop):
            logger.info(
                'There are still tasks in `CustomMCPManager().loop`, force terminating the MCP tool processes. There may be some exceptions.'
            )
            for process in self.processes:
                try:
                    process.terminate()
                except ProcessLookupError:
                    pass  # it's ok, the process may exit earlier

        self.loop.call_soon_threadsafe(self.loop.stop)
        self.loop_thread.join()
|
||
|
||
|
||
class CustomMCPClient:
    """A connection to one MCP server (stdio, SSE or streamable-http) plus the tools it lists.

    Async context managers opened while connecting are registered on ``self.exit_stack``
    so ``cleanup()`` closes them together.
    """

    def __init__(self):
        from mcp import ClientSession

        self.session: Optional[ClientSession] = None
        self.tools: Optional[list] = None  # Filled by connection_server() from session.list_tools()
        self.exit_stack = AsyncExitStack()
        self.resources: bool = False  # Set to True if the server lists at least one resource
        # Remembered by connection_server() so reconnect() can repeat the connection
        self._last_mcp_server_name = None
        self._last_mcp_server = None
        self.client_id = None  # For replacing in MCPManager.clients

    async def connection_server(self, mcp_server_name, mcp_server):
        """Connect to an MCP server and retrieve the available tools.

        Transport selection:
        - ``url`` present and ``type == 'streamable-http'``: streamable-http client, e.g.
          {"mcpServers": {"streamable-mcp-server": {"type": "streamable-http", "url": "http://0.0.0.0:8000/mcp"}}}
        - ``url`` present otherwise: SSE client (``headers`` defaults to
          ``{'Accept': 'text/event-stream'}``).
        - no ``url``: stdio client started from ``command`` / ``args`` / ``env``.
        ``sse_read_timeout`` (seconds, default 300) applies to both URL transports.

        After connecting, the session is initialized, ``self.tools`` is filled, and
        ``self.resources`` is set if the server lists any resources.

        Raises:
            Exception: any connection or initialization error is logged and re-raised.
        """
        from mcp import ClientSession, StdioServerParameters
        from mcp.client.sse import sse_client
        from mcp.client.stdio import stdio_client
        from mcp.client.streamable_http import streamablehttp_client

        # Save parameters for reconnect()
        self._last_mcp_server_name = mcp_server_name
        self._last_mcp_server = mcp_server

        try:
            if 'url' in mcp_server:
                url = mcp_server.get('url')
                sse_read_timeout = mcp_server.get('sse_read_timeout', 300)
                logger.info(f'{mcp_server_name} sse_read_timeout: {sse_read_timeout}s')
                if mcp_server.get('type', 'sse') == 'streamable-http':
                    # streamable-http mode
                    headers = mcp_server.get('headers', {})
                    self._streams_context = streamablehttp_client(
                        url=url, headers=headers, sse_read_timeout=datetime.timedelta(seconds=sse_read_timeout))
                    # The third value (a session-id getter) is not used here
                    read_stream, write_stream, _ = await self.exit_stack.enter_async_context(
                        self._streams_context)
                    self._session_context = ClientSession(read_stream, write_stream)
                    self.session = await self.exit_stack.enter_async_context(self._session_context)
                else:
                    # sse mode
                    headers = mcp_server.get('headers', {'Accept': 'text/event-stream'})
                    self._streams_context = sse_client(url, headers, sse_read_timeout=sse_read_timeout)
                    streams = await self.exit_stack.enter_async_context(self._streams_context)
                    self._session_context = ClientSession(*streams)
                    self.session = await self.exit_stack.enter_async_context(self._session_context)
            else:
                server_params = StdioServerParameters(command=mcp_server['command'],
                                                      args=mcp_server['args'],
                                                      env=mcp_server.get('env', None))
                stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
                self.stdio, self.write = stdio_transport
                self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
                logger.info(
                    f'Initializing a MCP stdio_client, if this takes forever, please check the config of this mcp server: {mcp_server_name}'
                )

            await self.session.initialize()
            list_tools = await self.session.list_tools()
            self.tools = list_tools.tools
            try:
                list_resources = await self.session.list_resources()  # Check if the server has resources
                if list_resources.resources:
                    self.resources = True
            except Exception:
                # Resource listing is optional: a server that fails it is treated as having none.
                pass
        except Exception as e:
            logger.warning(f'Failed in connecting to MCP server: {e}')
            raise

    async def reconnect(self):
        """Create and return a new client connected with this client's last connection parameters.

        The new client keeps this client's ``client_id``. This client is left untouched.

        Raises:
            RuntimeError: if ``client_id`` is None (the client was never registered).
        """
        if self.client_id is None:
            raise RuntimeError(
                'Cannot reconnect: client_id is None. This usually means the client was not properly registered in MCPManager.'
            )
        new_client = CustomMCPClient()
        new_client.client_id = self.client_id
        await new_client.connection_server(self._last_mcp_server_name, self._last_mcp_server)
        return new_client

    async def _reconnect_and_execute(self, tool_name, tool_args: dict):
        """Replace this client in the manager with a freshly connected one and run the tool there.

        Any failure (including one raised by the tool call on the new client) is logged and
        returned as an error string.
        """
        try:
            manager = CustomMCPManager()
            if self.client_id is None:
                logger.info('Reconnect failed: client_id is None')
                return 'Session reconnect (client creation) exception: client_id is None'
            new_client = await self.reconnect()
            manager.clients[self.client_id] = new_client
            # The new session was just initialized, so skip its liveness ping. Pinging again
            # here could reconnect recursively without bound if pings keep failing.
            return await new_client.execute_function(tool_name, tool_args, _check_alive=False)
        except Exception as e3:
            logger.info(f'Reconnect (client creation) exception type: {type(e3)}, value: {repr(e3)}')
            return f'Session reconnect (client creation) exception: {e3}'

    async def execute_function(self, tool_name, tool_args: dict, _check_alive: bool = True):
        """Run ``tool_name`` against this client's session and return its text output.

        ``list_resources`` and ``read_resource`` are served by the session's resource APIs.
        Errors in those two are returned as ``'Error: ...'`` strings. Any other name is
        forwarded to ``session.call_tool``, and the text parts of the response are joined
        (``'execute error'`` if there are none).

        When ``_check_alive`` is true, the session is pinged first. If the ping fails, a new
        client is connected, stored in the manager, and used for this call.
        """
        from mcp.types import TextResourceContents

        # Check if session is alive
        if _check_alive:
            try:
                await self.session.send_ping()
            except Exception as e:
                logger.info(f"Session is not alive, please increase 'sse_read_timeout' in the config, try reconnect: {e}")
                return await self._reconnect_and_execute(tool_name, tool_args)

        if tool_name == 'list_resources':
            try:
                list_resources = await self.session.list_resources()
                if not list_resources.resources:
                    return 'No resources found'
                return '\n\n'.join(str(resource) for resource in list_resources.resources)
            except Exception as e:
                logger.info(f'No list resources: {e}')
                return f'Error: {e}'

        if tool_name == 'read_resource':
            try:
                uri = tool_args.get('uri')
                if not uri:
                    raise ValueError('URI is required for read_resource')
                read_resource = await self.session.read_resource(uri)
                # Only text contents are returned; other content kinds (e.g. blobs) are skipped.
                texts = [
                    resource.text for resource in read_resource.contents if isinstance(resource, TextResourceContents)
                ]
                return '\n\n'.join(texts) if texts else 'Failed to read resource'
            except Exception as e:
                logger.info(f'Failed to read resource: {e}')
                return f'Error: {e}'

        response = await self.session.call_tool(tool_name, tool_args)
        texts = [content.text for content in response.content if content.type == 'text']
        return '\n\n'.join(texts) if texts else 'execute error'

    async def cleanup(self):
        """Close every async context entered on this client's exit stack."""
        await self.exit_stack.aclose()
|
||
|
||
|
||
def _cleanup_mcp(_sig_num=None, _frame=None):
    """Shut down the CustomMCPManager singleton, but only if one was ever created.

    ``_sig_num`` and ``_frame`` are ignored. They presumably exist so the function also
    matches a ``signal`` handler's ``(signum, frame)`` signature — confirm before relying on it.
    """
    if CustomMCPManager._instance is not None:
        # Going through the constructor returns the existing singleton.
        CustomMCPManager().shutdown()
|
||
|
||
|
||
# Run the MCP manager shutdown (which force-terminates leftover MCP subprocesses as a
# fallback) at normal interpreter exit. atexit handlers do NOT run when the process is
# killed by a signal Python does not handle or when os._exit() is called. Registration
# only happens when this module is imported from the main thread.
if threading.main_thread() is threading.current_thread():
    atexit.register(_cleanup_mcp)
|