qwen_agent/mcp/mcp_wrapper.py
2025-10-07 12:25:41 +08:00

308 lines
9.0 KiB
Python

#!/usr/bin/env python3
"""
通用MCP工具包装器
为所有MCP工具提供目录访问控制和安全限制
"""
import os
import sys
import json
import asyncio
import subprocess
from typing import Dict, Any, Optional
from pathlib import Path
class MCPSecurityWrapper:
    """Sandbox that confines path access and subprocess execution to one directory.

    ``allowed_directory`` holds the absolute sandbox root; ``project_id`` is
    read from the ``PROJECT_ID`` environment variable (``"default"`` when
    unset) and is forwarded to child processes.
    """

    def __init__(self, allowed_directory: str):
        self.allowed_directory = os.path.abspath(allowed_directory)
        self.project_id = os.getenv("PROJECT_ID", "default")

    def validate_path(self, path: str) -> str:
        """Return the normalized absolute form of *path* if it lies inside the sandbox.

        Relative paths are resolved against the current working directory
        (not against ``allowed_directory``).

        Raises:
            ValueError: if the path resolves outside ``allowed_directory``.
        """
        # abspath() also normalizes, so ".." components are collapsed here;
        # escapes via ".." are caught by the containment check below.
        path = os.path.abspath(path)
        # Compare fully resolved paths so a symlink inside the sandbox cannot
        # redirect the check to a location outside it.
        real_root = os.path.realpath(self.allowed_directory)
        real_path = os.path.realpath(path)
        try:
            # Component-wise comparison: "/data2" must not count as being
            # inside "/data", which a plain string-prefix test would accept.
            inside = os.path.commonpath([real_root, real_path]) == real_root
        except ValueError:
            # commonpath() raises when the paths share no root (e.g. different
            # drives on Windows) -- that is definitely outside.
            inside = False
        if not inside:
            raise ValueError(f"访问被拒绝: {path} 不在允许的目录 {self.allowed_directory}")
        return path

    def safe_execute_command(self, command: list, cwd: Optional[str] = None,
                             timeout: Optional[float] = 30) -> Dict[str, Any]:
        """Run *command* inside the sandbox and capture its output.

        Args:
            command: argv list handed to ``subprocess.run`` (no shell).
            cwd: working directory; defaults to ``allowed_directory`` and is
                checked with ``validate_path`` otherwise.
            timeout: seconds before the child is killed (default 30;
                ``None`` disables the limit).

        Returns:
            dict with ``success`` (exit status 0), ``stdout``, ``stderr`` and
            ``returncode``. ``returncode`` is -1 when the command timed out or
            raised while being launched/run.

        Raises:
            ValueError: if *cwd* lies outside the sandbox.
        """
        if cwd is None:
            cwd = self.allowed_directory
        else:
            cwd = self.validate_path(cwd)
        # Expose the sandbox location to the child; PWD mirrors the real cwd.
        env = os.environ.copy()
        env["PWD"] = cwd
        env["PROJECT_DATA_DIR"] = self.allowed_directory
        env["PROJECT_ID"] = self.project_id
        try:
            result = subprocess.run(
                command,
                cwd=cwd,
                env=env,
                capture_output=True,
                text=True,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            return {
                "success": False,
                "stdout": "",
                "stderr": "命令执行超时",
                "returncode": -1,
            }
        except Exception as e:
            # Best effort: launch failures (e.g. missing executable) are
            # reported to the caller instead of being raised.
            return {
                "success": False,
                "stdout": "",
                "stderr": str(e),
                "returncode": -1,
            }
        return {
            "success": result.returncode == 0,
            "stdout": result.stdout,
            "stderr": result.stderr,
            "returncode": result.returncode,
        }
async def handle_wrapped_request(request: Dict[str, Any], tool_name: str) -> Dict[str, Any]:
    """Handle one JSON-RPC 2.0 request addressed to the wrapped MCP tool.

    Supports ``initialize``, ``ping``, ``tools/list`` and ``tools/call``;
    any other method yields error -32601.

    Args:
        request: decoded JSON-RPC request object.
        tool_name: wrapper identifier (e.g. ``"ripgrep-wrapper"``) selecting
            the advertised tools and the tool implementation.

    Returns:
        A JSON-RPC response dict. A payload that is not a JSON object gets
        error -32600 (Invalid Request); unexpected exceptions become -32603.
    """
    if not isinstance(request, dict):
        # Without an object there is no method or id to read; answering here
        # also keeps the error path below from calling .get() on a non-dict.
        return {
            "jsonrpc": "2.0",
            "id": None,
            "error": {
                "code": -32600,
                "message": "Invalid Request"
            }
        }
    request_id = request.get("id")
    try:
        method = request.get("method")
        # "params" may be omitted or sent as an explicit null.
        params = request.get("params") or {}
        if method == "initialize":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {
                        "tools": {}
                    },
                    "serverInfo": {
                        "name": f"{tool_name}-wrapper",
                        "version": "1.0.0"
                    }
                }
            }
        if method == "ping":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "pong": True
                }
            }
        if method == "tools/list":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "tools": get_tool_definitions(tool_name)
                }
            }
        if method == "tools/call":
            # The sandbox is only needed when a tool actually runs.
            wrapper = MCPSecurityWrapper(get_allowed_directory())
            return await execute_tool_call(wrapper, tool_name, params, request_id)
        # NOTE(review): JSON-RPC notifications (no "id", e.g.
        # "notifications/initialized") also land here and get an error reply;
        # confirm the caller suppresses responses to notifications.
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {
                "code": -32601,
                "message": f"Unknown method: {method}"
            }
        }
    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {
                "code": -32603,
                "message": f"Internal error: {str(e)}"
            }
        }
def get_allowed_directory():
    """Resolve the sandbox root for this process.

    Reads ``PROJECT_DATA_DIR`` and falls back to ``./data``; relative values
    are resolved against the current working directory.

    Returns:
        The absolute, normalized directory path.
    """
    return os.path.abspath(os.environ.get("PROJECT_DATA_DIR", "./data"))
def get_tool_definitions(tool_name: str) -> list:
    """Describe the MCP tools offered by the wrapper named *tool_name*.

    Args:
        tool_name: wrapper identifier, ``"ripgrep-wrapper"`` or
            ``"directory-tree-wrapper"``.

    Returns:
        A freshly built list holding one MCP tool descriptor (``name``,
        ``description``, ``inputSchema``) for a known wrapper; an empty list
        for any other name.
    """
    if tool_name == "ripgrep-wrapper":
        name, summary = "ripgrep_search", "在项目目录内搜索文本"
        properties = {
            "pattern": {"type": "string", "description": "搜索模式"},
            "path": {"type": "string", "description": "搜索路径(相对于项目目录)"},
            "maxResults": {"type": "integer", "default": 100},
        }
        schema = {"type": "object", "properties": properties, "required": ["pattern"]}
    elif tool_name == "directory-tree-wrapper":
        name, summary = "get_directory_tree", "获取项目目录结构"
        properties = {
            "path": {"type": "string", "description": "目录路径(相对于项目目录)"},
            "max_depth": {"type": "integer", "default": 3},
        }
        schema = {"type": "object", "properties": properties}
    else:
        return []
    return [{"name": name, "description": summary, "inputSchema": schema}]
async def execute_tool_call(wrapper: MCPSecurityWrapper, tool_name: str, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
    """Run the wrapper's tool for a ``tools/call`` request.

    MCP clients send ``{"name": ..., "arguments": {...}}`` as the
    ``tools/call`` params, so the tool inputs live under ``"arguments"``;
    they are unwrapped here before dispatch. A flat params dict without an
    ``"arguments"`` key is forwarded unchanged, so callers that already pass
    tool inputs at the top level keep working.

    Args:
        wrapper: sandbox passed through to the tool implementation.
        tool_name: wrapper identifier selecting the implementation.
        params: ``tools/call`` params (``None`` is treated as empty).
        request_id: JSON-RPC id echoed back in the response.

    Returns:
        The tool's JSON-RPC response, error -32601 for an unknown
        *tool_name*, or error -32603 with the exception text if the tool raises.
    """
    if tool_name == "ripgrep-wrapper":
        handler = execute_ripgrep_search
    elif tool_name == "directory-tree-wrapper":
        handler = execute_directory_tree
    else:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {
                "code": -32601,
                "message": f"Unknown tool: {tool_name}"
            }
        }
    try:
        params = params or {}
        # MCP puts tool inputs under "arguments" (which may itself be null);
        # fall back to the flat dict when that key is absent.
        arguments = params.get("arguments", params)
        if arguments is None:
            arguments = {}
        return await handler(wrapper, arguments, request_id)
    except Exception as e:
        # Path-validation failures and other tool errors become JSON-RPC errors.
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {
                "code": -32603,
                "message": str(e)
            }
        }
async def execute_ripgrep_search(wrapper: MCPSecurityWrapper, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
    """Search the project directory with ripgrep (``rg --json``).

    Args:
        wrapper: sandbox used to validate the search path and run ``rg``.
        params: tool arguments -- ``pattern`` (default ""), ``path`` relative
            to the project directory (default "."), ``maxResults``
            (default 100).
        request_id: JSON-RPC id echoed back in the response.

    Returns:
        A JSON-RPC result whose single text item is rg's ``--json`` output
        (also when nothing matched), or a -32603 error carrying rg's stderr.

    Raises:
        ValueError: if the path leaves the sandbox or ``maxResults`` is not a
            positive integer.
    """
    pattern = params.get("pattern", "")
    # Treat an explicit null like a missing value.
    path = params.get("path") or "."
    max_results = params.get("maxResults")
    max_results = 100 if max_results is None else int(max_results)
    if max_results < 1:
        raise ValueError(f"maxResults 必须为正整数: {max_results}")
    # NOTE(review): rg's --max-count limits matching lines *per file*, not the
    # total number of results that the schema name "maxResults" suggests.

    # Validate and build the search path (an absolute "path" replaces the
    # project root in os.path.join and is then rejected by validate_path).
    search_path = wrapper.validate_path(os.path.join(wrapper.allowed_directory, path))
    command = [
        "rg",
        "--json",
        "--max-count", str(max_results),
        # "--" ends option parsing: a user pattern such as "--pre=cmd" must be
        # searched for, not interpreted as an rg flag (--pre runs a program).
        "--",
        pattern,
        search_path,
    ]
    result = wrapper.safe_execute_command(command)
    # rg exit status: 0 = matches found, 1 = no matches, 2 = error.
    # "No matches" is a valid empty result; -1 (timeout / launch failure
    # reported by safe_execute_command) and 2 remain errors.
    if result["returncode"] in (0, 1):
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "content": [
                    {
                        "type": "text",
                        "text": result["stdout"]
                    }
                ]
            }
        }
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {
            "code": -32603,
            "message": f"搜索失败: {result['stderr']}"
        }
    }
async def execute_directory_tree(wrapper: MCPSecurityWrapper, params: Dict[str, Any], request_id: Any) -> Dict[str, Any]:
    """List directories under a project path with ``find``.

    Args:
        wrapper: sandbox used to validate the path and run ``find``.
        params: tool arguments -- ``path`` relative to the project directory
            (default ".") and ``max_depth`` (default 3).
        request_id: JSON-RPC id echoed back in the response.

    Returns:
        A JSON-RPC result whose single text item is find's newline-separated
        directory list, or a -32603 error carrying find's stderr.

    Raises:
        ValueError: if the path leaves the sandbox or ``max_depth`` is not a
            non-negative integer.
    """
    # Treat explicit nulls like missing values.
    path = params.get("path") or "."
    max_depth = params.get("max_depth")
    max_depth = 3 if max_depth is None else int(max_depth)
    if max_depth < 0:
        raise ValueError(f"max_depth 必须为非负整数: {max_depth}")

    # Validate and build the directory path.
    dir_path = wrapper.validate_path(os.path.join(wrapper.allowed_directory, path))
    command = [
        "find",
        dir_path,
        # -maxdepth is a global option: GNU find warns on stderr when it
        # follows a test such as -type, so it goes first.
        "-maxdepth", str(max_depth),
        "-type", "d",
    ]
    result = wrapper.safe_execute_command(command)
    if result["success"]:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "content": [
                    {
                        "type": "text",
                        "text": result["stdout"]
                    }
                ]
            }
        }
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {
            "code": -32603,
            "message": f"获取目录树失败: {result['stderr']}"
        }
    }