#!/usr/bin/env python3
"""
Common utility functions for MCP servers.

Provide shared utilities for path handling, file validation, and request
processing.
"""

import asyncio
import json
import os
import re
import sys
from typing import Any, Dict, List, Optional, Union


def get_allowed_directory() -> str:
    """Return the absolute path of the directory that file access is rooted at.

    The first command-line argument wins when one is present; otherwise the
    ``PROJECT_DATA_DIR`` environment variable is used, falling back to
    ``./projects/data``.
    """
    # Prefer dataset_dir passed in through command-line arguments.
    if len(sys.argv) > 1:
        return os.path.abspath(sys.argv[1])

    # Otherwise read the project data directory from the environment.
    return os.path.abspath(os.getenv("PROJECT_DATA_DIR", "./projects/data"))


def _is_within_directory(path: str, directory: str) -> bool:
    """Return True when *path*, once normalized, lies inside *directory*."""
    base = os.path.abspath(directory)
    target = os.path.abspath(path)
    try:
        common = os.path.commonpath([base, target])
    except ValueError:
        # Raised e.g. for paths on different Windows drives: not contained.
        return False
    return os.path.normcase(common) == os.path.normcase(base)


def _strip_leading_relative_segments(path: str) -> str:
    """Drop leading empty, ``.`` and ``..`` segments from a '/'-separated path.

    Unlike ``str.lstrip('./')`` this works on whole segments, so a name such
    as ``.hidden/file.txt`` keeps its leading dot.
    """
    segments = path.split('/')
    while segments and segments[0] in ('', '.', '..'):
        segments.pop(0)
    return '/'.join(segments)


def resolve_file_path(file_path: str, default_subfolder: str = "default") -> str:
    """
    Resolve file paths, supporting both folder/document.txt and document.txt formats.

    Args:
        file_path: input file path; backslashes are treated as separators and
            a leading ``projects/`` or ``./projects/`` prefix is removed
        default_subfolder: default subfolder name used when only a file name
            is provided

    Returns:
        resolved full file path inside the allowed directory

    Raises:
        FileNotFoundError: if no candidate exists, or the only existing
            candidates lie outside the allowed directory
    """
    has_separator = '/' in file_path or '\\' in file_path

    if has_separator:
        # The path names a folder: normalize separators and use it directly.
        clean_path = file_path.replace('\\', '/')

        # Remove the projects/ prefix if present.
        if clean_path.startswith('projects/'):
            clean_path = clean_path[len('projects/'):]
        elif clean_path.startswith('./projects/'):
            clean_path = clean_path[len('./projects/'):]
    else:
        # Only a file name was provided: prepend the default subfolder.
        clean_path = f"{default_subfolder}/{file_path}"

    project_data_dir = get_allowed_directory()

    # Try the direct location in the project directory first.
    full_path = os.path.join(
        project_data_dir, _strip_leading_relative_segments(clean_path)
    )
    if os.path.exists(full_path) and _is_within_directory(full_path, project_data_dir):
        return full_path

    # Fall back to the project search helper. Its result is re-checked because
    # '..' segments in clean_path could otherwise point outside the directory.
    found = find_file_in_project(clean_path, project_data_dir)
    if found and _is_within_directory(found, project_data_dir):
        return found

    # A bare file name missing from the default subfolder may sit in the root.
    if not has_separator:
        root_path = os.path.join(project_data_dir, file_path)
        if os.path.exists(root_path) and _is_within_directory(root_path, project_data_dir):
            return root_path

    raise FileNotFoundError(f"File not found: {file_path} (searched in {project_data_dir})")


def find_file_in_project(filename: str, project_dir: str) -> Optional[str]:
    """Search for a file in the project directory.

    Args:
        filename: either a '/'-separated relative path or a bare file name
        project_dir: directory the search starts from

    Returns:
        The path of the match, or None. A name containing '/' is only looked
        up at exactly that location under *project_dir*; a bare name is
        searched recursively and the first match found by ``os.walk`` wins.
    """
    if '/' in filename:
        # The name includes a path: look only at that location.
        parts = filename.split('/')
        search_dir = os.path.join(project_dir, *parts[:-1])
        if os.path.exists(search_dir):
            target_path = os.path.join(search_dir, parts[-1])
            if os.path.exists(target_path):
                return target_path
    else:
        # Bare file name: walk the entire project directory.
        for root, _dirs, files in os.walk(project_dir):
            if filename in files:
                return os.path.join(root, filename)

    return None


def load_tools_from_json(tools_file_name: str) -> List[Dict[str, Any]]:
    """Load tool definitions from a JSON file in the sibling ``tools`` folder.

    Args:
        tools_file_name: file name inside the ``tools`` directory located
            next to this module

    Returns:
        The parsed JSON content, or an empty list when the file is missing
        or cannot be read/parsed (a warning is emitted in the latter case).
    """
    try:
        tools_file = os.path.join(os.path.dirname(__file__), "tools", tools_file_name)
        with open(tools_file, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        # If the JSON file does not exist, use the default definition.
        return []
    except Exception as e:
        # Deliberately best-effort. The warning goes to stderr because stdout
        # carries the JSON-RPC stream written by handle_mcp_streaming.
        print(f"Warning: Unable to load tool definition JSON file: {str(e)}", file=sys.stderr)
        return []


def create_error_response(request_id: Any, code: int, message: str) -> Dict[str, Any]:
    """Create a standardized error response."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {
            "code": code,
            "message": message,
        },
    }


def create_success_response(request_id: Any, result: Any) -> Dict[str, Any]:
    """Create a standardized success response."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "result": result,
    }


def create_initialize_response(
    request_id: Any, server_name: str, server_version: str = "1.0.0"
) -> Dict[str, Any]:
    """Create a standardized initialize response."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "result": {
            "protocolVersion": "2024-11-05",
            "capabilities": {
                "tools": {},
            },
            "serverInfo": {
                "name": server_name,
                "version": server_version,
            },
        },
    }


def create_ping_response(request_id: Any) -> Dict[str, Any]:
    """Create a standardized ping response."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "result": {
            "pong": True,
        },
    }


def create_tools_list_response(request_id: Any, tools: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Create a standardized tool list response."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "result": {
            "tools": tools,
        },
    }


def is_regex_pattern(pattern: str) -> bool:
    """Detect whether a string is a regular expression pattern.

    A string counts as a regex when it is written as ``/pattern/``, as
    ``r"pattern"`` / ``r'pattern'``, or when it contains any regex
    metacharacter. Note that ``.`` is in that set, so a plain name such as
    ``file.txt`` is reported as a regex.
    """
    # /pattern/ format
    if pattern.startswith('/') and pattern.endswith('/') and len(pattern) > 2:
        return True

    # r"pattern" or r'pattern' format
    if pattern.startswith(('r"', "r'")) and pattern.endswith(('"', "'")) and len(pattern) > 3:
        return True

    # Anything containing a regex metacharacter
    regex_chars = {'*', '+', '?', '|', '(', ')', '[', ']', '{', '}', '^', '$', '\\', '.'}
    return any(char in pattern for char in regex_chars)


def compile_pattern(pattern: str) -> Union[re.Pattern, str, None]:
    """Compile a regex pattern, or return the original string if it is not regex.

    Returns:
        The unchanged string when ``is_regex_pattern`` rejects it, a compiled
        pattern otherwise, or None when compilation fails (a warning is
        emitted in that case).
    """
    if not is_regex_pattern(pattern):
        return pattern

    try:
        # /pattern/ format: compile what is between the slashes.
        if pattern.startswith('/') and pattern.endswith('/'):
            return re.compile(pattern[1:-1])

        # r"pattern" or r'pattern' format: drop the prefix and closing quote.
        if pattern.startswith(('r"', "r'")) and pattern.endswith(('"', "'")):
            return re.compile(pattern[2:-1])

        # Directly compile strings containing regex characters.
        return re.compile(pattern)
    except re.error as e:
        # None signals an invalid regex. The warning goes to stderr because
        # stdout carries the JSON-RPC stream written by handle_mcp_streaming.
        print(f"Warning: Regular expression '{pattern}' compilation failed: {e}", file=sys.stderr)
        return None


async def handle_mcp_streaming(request_handler):
    """Handle the standard main loop for MCP requests.

    Reads one JSON document per line from stdin until EOF, awaits
    ``request_handler(request)`` for each, and writes the JSON-encoded
    response as one line on stdout. Blank lines are skipped.

    Error handling:
        * a line that is not valid JSON yields a ``-32700`` error with
          ``id`` set to null;
        * any exception raised by the handler or while encoding its
          response yields a ``-32603`` error carrying the request's ``id``
          when the request is a JSON object, else null.

    Args:
        request_handler: async callable taking the decoded request and
            returning a JSON-serializable response.
    """
    loop = asyncio.get_running_loop()
    try:
        while True:
            # stdin.readline blocks, so run it in the default executor.
            line = await loop.run_in_executor(None, sys.stdin.readline)
            if not line:
                break  # EOF

            line = line.strip()
            if not line:
                continue

            try:
                request = json.loads(line)
            except json.JSONDecodeError:
                # The id cannot be determined from unparseable input.
                payload = json.dumps(
                    create_error_response(None, -32700, "Parse error"),
                    ensure_ascii=False,
                )
            else:
                try:
                    response = await request_handler(request)
                    payload = json.dumps(response, ensure_ascii=False)
                except Exception as e:
                    request_id = request.get("id") if isinstance(request, dict) else None
                    payload = json.dumps(
                        create_error_response(
                            request_id, -32603, f"Internal error: {str(e)}"
                        ),
                        ensure_ascii=False,
                    )

            # Write to stdout, one message per line.
            sys.stdout.write(payload + "\n")
            sys.stdout.flush()

    except KeyboardInterrupt:
        pass