#!/usr/bin/env python3
"""
Shared utility functions for MCP servers.

Provides path resolution confined to an allowed data directory, file lookup,
tool-definition loading, JSON-RPC response builders, search-pattern
compilation, and the standard stdio request loop.
"""

import asyncio
import json
import os
import re
import sys
from typing import Any, Dict, List, Optional, Union


def get_allowed_directory() -> str:
    """Return the absolute path of the directory file access is confined to.

    The first command-line argument (``sys.argv[1]``) takes precedence;
    otherwise the ``PROJECT_DATA_DIR`` environment variable is used, falling
    back to ``./projects``. Relative values are resolved against the current
    working directory.
    """
    if len(sys.argv) > 1:
        return os.path.abspath(sys.argv[1])
    return os.path.abspath(os.getenv("PROJECT_DATA_DIR", "./projects"))


def _is_within_directory(path: str, directory: str) -> bool:
    """Return True if *path* is *directory* itself or lies beneath it.

    The check is lexical (``os.path.abspath`` normalizes ``..`` segments);
    symlinks are deliberately not resolved, so a symlink placed inside the
    directory is still accepted.
    """
    abs_path = os.path.abspath(path)
    abs_dir = os.path.abspath(directory)
    try:
        return os.path.commonpath([abs_path, abs_dir]) == abs_dir
    except ValueError:
        # e.g. paths on different drives on Windows: cannot be inside.
        return False


def _strip_leading_dot_slash(path: str) -> str:
    """Remove leading ``./`` segments and ``/`` characters from *path*.

    ``str.lstrip('./')`` would strip the *character set* {'.', '/'} and so
    also eat the dots of ``../`` or of hidden names such as ``.env/``.
    """
    while True:
        if path.startswith('./'):
            path = path[2:]
        elif path.startswith('/'):
            path = path[1:]
        else:
            return path


def resolve_file_path(file_path: str, default_subfolder: str = "default") -> str:
    """
    Resolve a file path given as ``folder/document.txt`` or ``document.txt``.

    Args:
        file_path: The file path supplied by the caller. Backslashes are
            treated as separators, and a leading ``projects/`` or
            ``./projects/`` prefix is dropped.
        default_subfolder: Subfolder used when only a bare file name is given.

    Returns:
        The full path of an existing entry inside the allowed directory
        (see ``get_allowed_directory``).

    Raises:
        FileNotFoundError: If nothing matches, or if every candidate would
            point outside the allowed directory (e.g. via ``..``).
    """
    if '/' in file_path or '\\' in file_path:
        clean_path = file_path.replace('\\', '/')
        # Callers sometimes include the data directory's own name; drop it.
        for prefix in ('projects/', './projects/'):
            if clean_path.startswith(prefix):
                clean_path = clean_path[len(prefix):]
                break
    else:
        # Bare file name: look in the default subfolder first.
        clean_path = f"{default_subfolder}/{file_path}"

    project_data_dir = get_allowed_directory()

    # 1) Direct lookup relative to the allowed directory.
    full_path = os.path.join(project_data_dir, _strip_leading_dot_slash(clean_path))
    if _is_within_directory(full_path, project_data_dir) and os.path.exists(full_path):
        return full_path

    # 2) find_file_in_project: exact lookup when clean_path still contains
    #    '/', otherwise a walk of the whole tree. It enforces containment too.
    found = find_file_in_project(clean_path, project_data_dir)
    if found:
        return found

    # 3) A bare file name missing from the default subfolder: try the root.
    if '/' not in file_path and '\\' not in file_path:
        root_path = os.path.join(project_data_dir, file_path)
        if _is_within_directory(root_path, project_data_dir) and os.path.exists(root_path):
            return root_path

    raise FileNotFoundError(f"File not found: {file_path} (searched in {project_data_dir})")


def find_file_in_project(filename: str, project_dir: str) -> Optional[str]:
    """Locate *filename* inside *project_dir*.

    If *filename* contains '/', it is treated as a path relative to
    *project_dir* and only that exact location is checked. Otherwise the tree
    is walked with ``os.walk`` and the first directory (in walk order)
    containing a file of that name wins.

    Returns:
        The matching path, or None if nothing is found or if a relative path
        would point outside *project_dir*.
    """
    if '/' in filename:
        target_path = os.path.join(project_dir, *filename.split('/'))
        if _is_within_directory(target_path, project_dir) and os.path.exists(target_path):
            return target_path
        return None

    for root, _dirs, files in os.walk(project_dir):
        if filename in files:
            return os.path.join(root, filename)
    return None


def load_tools_from_json(tools_file_name: str) -> List[Dict[str, Any]]:
    """Load tool definitions from ``tools/<tools_file_name>`` next to this module.

    Returns an empty list when the file does not exist or cannot be read or
    parsed. The warning goes to stderr because stdout carries the JSON-RPC
    stream written by ``handle_mcp_streaming``; any other text there corrupts
    the protocol.
    """
    try:
        tools_file = os.path.join(os.path.dirname(__file__), "tools", tools_file_name)
        if not os.path.exists(tools_file):
            # No JSON file: fall back to an empty definition list.
            return []
        with open(tools_file, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        print(f"Warning: Unable to load tool definition JSON file: {str(e)}", file=sys.stderr)
        return []


def create_error_response(request_id: Any, code: int, message: str) -> Dict[str, Any]:
    """Build a JSON-RPC 2.0 error response."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {
            "code": code,
            "message": message
        }
    }


def create_success_response(request_id: Any, result: Any) -> Dict[str, Any]:
    """Build a JSON-RPC 2.0 success response carrying *result*."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "result": result
    }


def create_initialize_response(request_id: Any, server_name: str, server_version: str = "1.0.0") -> Dict[str, Any]:
    """Build the response to an MCP ``initialize`` request.

    Advertises protocol version 2024-11-05 and the ``tools`` capability.
    """
    return create_success_response(request_id, {
        "protocolVersion": "2024-11-05",
        "capabilities": {
            "tools": {}
        },
        "serverInfo": {
            "name": server_name,
            "version": server_version
        }
    })


def create_ping_response(request_id: Any) -> Dict[str, Any]:
    """Build the response to a ``ping`` request (result ``{"pong": true}``)."""
    return create_success_response(request_id, {"pong": True})


def create_tools_list_response(request_id: Any, tools: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Build the response to a tools-list request wrapping *tools*."""
    return create_success_response(request_id, {"tools": tools})


# Characters whose presence makes a plain string count as a regex.
# NOTE: '.' is included, so ordinary file names like "a.txt" are treated as
# regexes (where '.' matches any character).
_REGEX_SPECIAL_CHARS = frozenset('*+?|()[]{}^$\\.')


def is_regex_pattern(pattern: str) -> bool:
    """Return True if *pattern* should be interpreted as a regular expression.

    Recognized forms: ``/body/`` (longer than 2 chars), ``r"body"`` /
    ``r'body'`` (longer than 3 chars), or any string containing one of
    ``_REGEX_SPECIAL_CHARS``.
    """
    if pattern.startswith('/') and pattern.endswith('/') and len(pattern) > 2:
        return True
    if pattern.startswith(('r"', "r'")) and pattern.endswith(('"', "'")) and len(pattern) > 3:
        return True
    return any(char in _REGEX_SPECIAL_CHARS for char in pattern)


def compile_pattern(pattern: str) -> Union[re.Pattern, str, None]:
    """Compile *pattern* if it looks like a regex, otherwise return it unchanged.

    For the ``/body/`` and ``r"body"`` forms only the body is compiled.

    Returns:
        A compiled ``re.Pattern``, the original string for plain text, or None
        if the regex fails to compile (a warning is written to stderr so the
        stdout JSON-RPC stream stays clean).
    """
    if not is_regex_pattern(pattern):
        return pattern

    try:
        if pattern.startswith('/') and pattern.endswith('/'):
            return re.compile(pattern[1:-1])
        if pattern.startswith(('r"', "r'")) and pattern.endswith(('"', "'")):
            return re.compile(pattern[2:-1])
        # Plain string containing regex special characters.
        return re.compile(pattern)
    except re.error as e:
        print(f"Warning: Regular expression '{pattern}' compilation failed: {e}", file=sys.stderr)
        return None


def _write_message(message: Dict[str, Any]) -> None:
    """Serialize *message* as a single JSON line on stdout and flush it."""
    sys.stdout.write(json.dumps(message, ensure_ascii=False) + "\n")
    sys.stdout.flush()


async def handle_mcp_streaming(request_handler):
    """Run the standard MCP main loop over newline-delimited JSON on stdio.

    Each non-empty stdin line is parsed as JSON and passed to the coroutine
    function ``request_handler``. Its return value is written to stdout as one
    JSON line. A ``None`` return produces no output, since writing it would
    emit the invalid message ``null`` (e.g. notifications must not be
    answered). The loop ends at EOF on stdin or on KeyboardInterrupt.

    Errors are reported as JSON-RPC error objects:
      * -32700 "Parse error" when a line is not valid JSON. ``id`` is null,
        as JSON-RPC 2.0 requires when the id cannot be determined.
      * -32603 "Internal error: ..." for any other exception, echoing the
        request's ``id`` when it is available.
    """
    loop = asyncio.get_running_loop()
    try:
        while True:
            # stdin.readline blocks, so run it off the event-loop thread.
            line = await loop.run_in_executor(None, sys.stdin.readline)
            if not line:
                break  # EOF
            line = line.strip()
            if not line:
                continue

            request = None
            try:
                request = json.loads(line)
                response = await request_handler(request)
                if response is not None:
                    _write_message(response)
            except json.JSONDecodeError:
                _write_message({
                    "jsonrpc": "2.0",
                    "id": None,
                    "error": {
                        "code": -32700,
                        "message": "Parse error"
                    }
                })
            except Exception as e:
                request_id = request.get("id") if isinstance(request, dict) else None
                _write_message({
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {
                        "code": -32603,
                        "message": f"Internal error: {str(e)}"
                    }
                })
    except KeyboardInterrupt:
        pass