Convert all Chinese comments, docstrings, logger/print output, HTTPException detail messages, and API response messages to English across the entire codebase. Functional zh/ja localized strings (e.g. prompt templates, timezone display names, date formats) are preserved as-is. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
251 lines
8.3 KiB
Python
251 lines
8.3 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Common utility functions for MCP servers
|
|
Provide shared utilities for path handling, file validation, and request processing
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import asyncio
|
|
from typing import Any, Dict, List, Optional, Union
|
|
import re
|
|
|
|
def get_allowed_directory():
    """Return the absolute path of the directory that file lookups are rooted at.

    The first command-line argument (``sys.argv[1]``) wins when one is given;
    otherwise the ``PROJECT_DATA_DIR`` environment variable is used, falling
    back to ``./projects/data`` when it is unset.
    """
    cli_args = sys.argv[1:]
    if cli_args:
        # A dataset directory passed on the command line takes precedence.
        return os.path.abspath(cli_args[0])

    # No CLI argument: fall back to the environment (or the built-in default).
    return os.path.abspath(os.getenv("PROJECT_DATA_DIR", "./projects/data"))
|
|
|
|
|
|
def resolve_file_path(file_path: str, default_subfolder: str = "default") -> str:
    """
    Resolve a file reference to an existing path inside the allowed directory.

    Both ``folder/document.txt`` and bare ``document.txt`` forms are accepted.
    Backslashes are treated as path separators, and a leading ``projects/`` or
    ``./projects/`` prefix is dropped before the lookup.

    Args:
        file_path: input file path
        default_subfolder: default subfolder name used when only a file name is provided

    Returns:
        resolved full file path

    Raises:
        FileNotFoundError: if no candidate exists inside the allowed directory.
            Candidates that would land outside that directory (for example via
            ``..`` segments) are treated as not found.
    """
    is_bare_name = '/' not in file_path and '\\' not in file_path

    if is_bare_name:
        # Only a file name was provided: look in the default subfolder first.
        clean_path = f"{default_subfolder}/{file_path}"
    else:
        clean_path = file_path.replace('\\', '/')
        if clean_path.startswith('projects/'):
            clean_path = clean_path[len('projects/'):]
        elif clean_path.startswith('./projects/'):
            clean_path = clean_path[len('./projects/'):]

    project_data_dir = get_allowed_directory()
    allowed_root = os.path.normcase(os.path.abspath(project_data_dir))

    def _is_inside_allowed_dir(candidate: str) -> bool:
        # Lexical containment check only: ".." is collapsed by abspath, but
        # symlinks are not resolved.
        resolved = os.path.normcase(os.path.abspath(candidate))
        return resolved == allowed_root or resolved.startswith(os.path.join(allowed_root, ''))

    # Drop only genuine "./" and "/" prefixes. str.lstrip('./') removes a
    # *character set*, which would also turn ".cache/doc.txt" into
    # "cache/doc.txt" and "../x" into "x".
    relative_path = clean_path
    while relative_path.startswith(('./', '/')):
        relative_path = relative_path[2:] if relative_path.startswith('./') else relative_path[1:]

    exact_path = os.path.join(project_data_dir, relative_path)
    if _is_inside_allowed_dir(exact_path) and os.path.exists(exact_path):
        return exact_path

    # If the direct path does not exist, fall back to the project search helper.
    found = find_file_in_project(clean_path, project_data_dir)
    if found and _is_inside_allowed_dir(found):
        return found

    # Lenient legacy lookup, kept so inputs that used to resolve thanks to the
    # character stripping (e.g. "../folder/doc.txt" -> "folder/doc.txt") still do.
    lenient_path = os.path.join(project_data_dir, clean_path.lstrip('./'))
    if _is_inside_allowed_dir(lenient_path) and os.path.exists(lenient_path):
        return lenient_path

    # A bare file name missing from the default subfolder may live in the root.
    if is_bare_name:
        root_path = os.path.join(project_data_dir, file_path)
        if _is_inside_allowed_dir(root_path) and os.path.exists(root_path):
            return root_path

    raise FileNotFoundError(f"File not found: {file_path} (searched in {project_data_dir})")
|
|
|
|
|
|
def find_file_in_project(filename: str, project_dir: str) -> Optional[str]:
    """Locate ``filename`` under ``project_dir`` and return its path, or ``None``.

    A bare name is searched for through the whole directory tree and the first
    match met by ``os.walk`` is returned. A name containing ``/`` is checked
    only at that exact relative location.
    """
    if '/' not in filename:
        # Bare file name: walk the entire project tree.
        for current_dir, _subdirs, names in os.walk(project_dir):
            if filename in names:
                return os.path.join(current_dir, filename)
        return None

    # Name carries a path: inspect only that folder.
    *folders, leaf = filename.split('/')
    folder_path = os.path.join(project_dir, *folders)
    candidate = os.path.join(folder_path, leaf)
    if os.path.exists(folder_path) and os.path.exists(candidate):
        return candidate
    return None
|
|
|
|
|
|
def load_tools_from_json(tools_file_name: str) -> List[Dict[str, Any]]:
    """Load tool definitions from a JSON file in the ``tools`` folder next to this module.

    Args:
        tools_file_name: file name inside the ``tools`` directory.

    Returns:
        The parsed JSON content, or an empty list when the file does not exist
        or cannot be read or parsed.
    """
    try:
        tools_file = os.path.join(os.path.dirname(__file__), "tools", tools_file_name)
        if not os.path.exists(tools_file):
            # If the JSON file does not exist, use the default (empty) definition
            return []
        with open(tools_file, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        # Deliberately best-effort: a broken definition file must not take the
        # server down. The warning goes to stderr because stdout carries the
        # JSON-RPC message stream and must not receive free-form text.
        print(f"Warning: Unable to load tool definition JSON file: {str(e)}", file=sys.stderr)
        return []
|
|
|
|
|
|
def create_error_response(request_id: Any, code: int, message: str) -> Dict[str, Any]:
    """Build a JSON-RPC 2.0 error envelope for ``request_id``.

    The returned dict carries ``code`` and ``message`` under the ``error`` key.
    """
    error_detail = {"code": code, "message": message}
    return {"jsonrpc": "2.0", "id": request_id, "error": error_detail}
|
|
|
|
|
|
def create_success_response(request_id: Any, result: Any) -> Dict[str, Any]:
    """Build a JSON-RPC 2.0 success envelope that carries ``result`` unchanged."""
    envelope: Dict[str, Any] = {"jsonrpc": "2.0"}
    envelope["id"] = request_id
    envelope["result"] = result
    return envelope
|
|
|
|
|
|
def create_initialize_response(request_id: Any, server_name: str, server_version: str = "1.0.0") -> Dict[str, Any]:
    """Build the JSON-RPC 2.0 reply to an ``initialize`` request.

    The result reports protocol version ``2024-11-05``, an empty ``tools``
    capability, and the given server name and version.
    """
    server_info = {"name": server_name, "version": server_version}
    capabilities = {"tools": {}}
    result = {
        "protocolVersion": "2024-11-05",
        "capabilities": capabilities,
        "serverInfo": server_info,
    }
    return {"jsonrpc": "2.0", "id": request_id, "result": result}
|
|
|
|
|
|
def create_ping_response(request_id: Any) -> Dict[str, Any]:
    """Build the JSON-RPC 2.0 reply to a ping: a result of ``{"pong": True}``."""
    pong = {"pong": True}
    return {"jsonrpc": "2.0", "id": request_id, "result": pong}
|
|
|
|
|
|
def create_tools_list_response(request_id: Any, tools: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Build the JSON-RPC 2.0 reply that lists ``tools`` under ``result.tools``."""
    listing = {"tools": tools}
    return {"jsonrpc": "2.0", "id": request_id, "result": listing}
|
|
|
|
|
|
def is_regex_pattern(pattern: str) -> bool:
    """Heuristically decide whether ``pattern`` should be treated as a regex.

    A string counts as a regex when it is written as ``/body/``, as
    ``r"body"`` / ``r'body'``, or contains any regex metacharacter
    (including ``.`` and the backslash).
    """
    # /pattern/ form with a non-empty body
    slash_delimited = len(pattern) > 2 and pattern[0] == '/' and pattern[-1] == '/'
    if slash_delimited:
        return True

    # r"pattern" or r'pattern' form with a non-empty body
    raw_quoted = (
        len(pattern) > 3
        and pattern[:2] in ('r"', "r'")
        and pattern[-1] in ('"', "'")
    )
    if raw_quoted:
        return True

    # Otherwise any single metacharacter is enough
    metacharacters = {'*', '+', '?', '|', '(', ')', '[', ']', '{', '}', '^', '$', '\\', '.'}
    return not metacharacters.isdisjoint(pattern)
|
|
|
|
|
|
def compile_pattern(pattern: str) -> Union[re.Pattern, str, None]:
    """Compile a regex pattern, or return the original string if it is not regex.

    Args:
        pattern: plain text, or a regex written as ``/body/``, ``r"body"``,
            ``r'body'``, or any string ``is_regex_pattern`` accepts.

    Returns:
        ``pattern`` itself when it is not regex-like, a compiled
        ``re.Pattern`` when it is, or ``None`` when compilation fails.
    """
    if not is_regex_pattern(pattern):
        return pattern

    # Strip the delimiters of the two wrapped notations; anything else is
    # compiled verbatim.
    if pattern.startswith('/') and pattern.endswith('/'):
        regex_body = pattern[1:-1]
    elif pattern.startswith(('r"', "r'")) and pattern.endswith(('"', "'")):
        regex_body = pattern[2:-1]
    else:
        regex_body = pattern

    try:
        return re.compile(regex_body)
    except re.error as e:
        # None signals an invalid regex to the caller. The warning goes to
        # stderr because stdout carries the JSON-RPC message stream.
        print(f"Warning: Regular expression '{pattern}' compilation failed: {e}", file=sys.stderr)
        return None
|
|
|
|
|
|
async def handle_mcp_streaming(request_handler):
    """Handle the standard main loop for MCP requests.

    Reads newline-delimited JSON requests from stdin until EOF, awaits
    ``request_handler(request)`` for each one, and writes the returned
    response to stdout as a single JSON line. Blank lines are skipped.

    Args:
        request_handler: async callable taking the decoded request and
            returning a JSON-serializable response, or ``None`` when no reply
            should be sent.

    Error handling:
        * Undecodable input produces a ``-32700`` "Parse error" response with
          ``id`` set to null.
        * Any exception from the handler, or from serializing its response,
          produces a ``-32603`` "Internal error" response that echoes the
          request's ``id`` when the request is a JSON object.
        * ``KeyboardInterrupt`` ends the loop silently.
    """

    def _send(message):
        # One JSON document per line, flushed immediately so the peer is not
        # left waiting on a buffered reply.
        sys.stdout.write(json.dumps(message, ensure_ascii=False) + "\n")
        sys.stdout.flush()

    def _error(request_id, code, message):
        # JSON-RPC 2.0 error responses always carry an "id" member (null when
        # the request id could not be determined).
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {"code": code, "message": message},
        }

    loop = asyncio.get_running_loop()
    try:
        while True:
            # Blocking readline runs in the default executor so the event
            # loop stays responsive.
            line = await loop.run_in_executor(None, sys.stdin.readline)
            if not line:
                break  # EOF

            line = line.strip()
            if not line:
                continue

            try:
                request = json.loads(line)
            except json.JSONDecodeError:
                _send(_error(None, -32700, "Parse error"))
                continue

            request_id = request.get("id") if isinstance(request, dict) else None
            try:
                response = await request_handler(request)
                if response is None:
                    # Nothing to reply with; a bare "null" line would not be
                    # a valid JSON-RPC message.
                    continue
                _send(response)
            except Exception as e:
                _send(_error(request_id, -32603, f"Internal error: {str(e)}"))

    except KeyboardInterrupt:
        pass