252 lines
8.1 KiB
Python
252 lines
8.1 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
MCP服务器通用工具函数
|
||
提供路径处理、文件验证、请求处理等公共功能
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
import sys
|
||
import asyncio
|
||
from typing import Any, Dict, List, Optional, Union
|
||
import re
|
||
|
||
|
||
def get_allowed_directory():
    """Return the absolute path of the directory the server is allowed to access.

    The first command-line argument (the dataset directory) wins when present;
    otherwise the ``PROJECT_DATA_DIR`` environment variable is used, defaulting
    to ``./projects/data``. Relative values are resolved against the current
    working directory.
    """
    cli_args = sys.argv[1:]
    chosen_dir = cli_args[0] if cli_args else os.getenv("PROJECT_DATA_DIR", "./projects/data")
    return os.path.abspath(chosen_dir)
|
||
|
||
|
||
def _is_within_directory(path: str, directory: str) -> bool:
    """Return True if *path* lexically resolves to *directory* or somewhere below it.

    ``..`` segments are normalized via ``os.path.abspath``. Symlinks are not
    resolved, so symlinked data folders inside the directory keep working.
    """
    base = os.path.abspath(directory)
    try:
        return os.path.commonpath([base, os.path.abspath(path)]) == base
    except ValueError:
        # e.g. paths on different drives (Windows): cannot be inside.
        return False


def resolve_file_path(file_path: str, default_subfolder: str = "default") -> str:
    """
    Resolve a file path given as either ``folder/document.txt`` or ``document.txt``.

    Lookup order (all relative to get_allowed_directory()):
      1. the cleaned path itself (a bare file name is looked up inside
         *default_subfolder*);
      2. find_file_in_project() on the cleaned path;
      3. for bare file names only, the file directly in the allowed directory.

    Backslashes are treated as folder separators and a leading ``projects/`` or
    ``./projects/`` prefix is dropped. Candidates that would resolve outside the
    allowed directory (e.g. through ``..`` segments) are ignored.

    Args:
        file_path: The input file path.
        default_subfolder: Sub-folder name used when only a file name is given.

    Returns:
        The resolved full file path.

    Raises:
        FileNotFoundError: If no existing candidate inside the allowed directory is found.
    """
    is_bare_name = '/' not in file_path and '\\' not in file_path

    if is_bare_name:
        # Only a file name: place it in the default sub-folder.
        clean_path = f"{default_subfolder}/{file_path}"
    else:
        clean_path = file_path.replace('\\', '/')
        # Drop a leading projects/ prefix (if present).
        if clean_path.startswith('projects/'):
            clean_path = clean_path[len('projects/'):]
        elif clean_path.startswith('./projects/'):
            clean_path = clean_path[len('./projects/'):]

    project_data_dir = get_allowed_directory()

    # Remove only leading "./" segments and "/" characters. lstrip('./') would
    # also eat the dot of hidden folders such as ".config/".
    relative_path = clean_path
    while relative_path.startswith(('./', '/')):
        relative_path = relative_path[2:] if relative_path.startswith('./') else relative_path[1:]

    # 1. Try the path directly inside the project directory.
    full_path = os.path.join(project_data_dir, relative_path)
    if _is_within_directory(full_path, project_data_dir) and os.path.exists(full_path):
        return full_path

    # 2. Fall back to find_file_in_project(); re-check containment on its result.
    found = find_file_in_project(clean_path, project_data_dir)
    if found and _is_within_directory(found, project_data_dir):
        return found

    # 3. Bare file name not present in the default sub-folder: try the root.
    if is_bare_name:
        root_path = os.path.join(project_data_dir, file_path)
        if _is_within_directory(root_path, project_data_dir) and os.path.exists(root_path):
            return root_path

    raise FileNotFoundError(f"File not found: {file_path} (searched in {project_data_dir})")
|
||
|
||
|
||
def find_file_in_project(filename: str, project_dir: str) -> Optional[str]:
    """Look up *filename* inside *project_dir*.

    If *filename* contains ``/`` it is treated as a relative path, and only that
    exact location is checked. Otherwise the whole tree is walked with
    ``os.walk``, and the first directory (in walk order) containing a file of
    that name wins.

    Paths that would resolve outside *project_dir* (e.g. via ``..`` segments)
    are rejected.

    Returns:
        The joined path of the match, or None if nothing was found.
    """
    base = os.path.abspath(project_dir)

    def _is_inside(path: str) -> bool:
        # Lexical containment check: normalizes "..", does not follow symlinks.
        try:
            return os.path.commonpath([base, os.path.abspath(path)]) == base
        except ValueError:
            return False

    if '/' in filename:
        # Path given: only check the specified location.
        *folders, target_file = filename.split('/')
        search_dir = os.path.join(project_dir, *folders)
        target_path = os.path.join(search_dir, target_file)
        if _is_inside(target_path) and os.path.exists(target_path):
            return target_path
        return None

    # Bare file name: search the whole project tree.
    for root, _dirs, files in os.walk(project_dir):
        if filename in files:
            return os.path.join(root, filename)
    return None
|
||
|
||
|
||
def load_tools_from_json(tools_file_name: str) -> List[Dict[str, Any]]:
    """Load tool definitions from ``tools/<tools_file_name>`` next to this module.

    Returns the parsed JSON content as-is (presumably a list of tool dicts).
    Returns an empty list if the file does not exist, or if it cannot be read
    or parsed.

    Failures other than a missing file are reported as a warning on stderr.
    Stdout is reserved for the JSON-RPC message stream, so nothing is printed
    there.
    """
    try:
        tools_file = os.path.join(os.path.dirname(__file__), "tools", tools_file_name)
        with open(tools_file, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        # No JSON definition file: fall back to the default (empty) definition.
        return []
    except Exception as e:
        # Best-effort: a broken tools file must not crash the server.
        print(f"Warning: Unable to load tool definition JSON file: {str(e)}", file=sys.stderr)
        return []
|
||
|
||
|
||
def create_error_response(request_id: Any, code: int, message: str) -> Dict[str, Any]:
    """Build a JSON-RPC 2.0 error response.

    Args:
        request_id: The id of the request being answered (echoed back verbatim).
        code: Numeric JSON-RPC error code.
        message: Human-readable error description.

    Returns:
        A dict with ``jsonrpc``, ``id`` and ``error`` (``code``/``message``) keys.
    """
    error_body = {"code": code, "message": message}
    return {"jsonrpc": "2.0", "id": request_id, "error": error_body}
|
||
|
||
|
||
def create_success_response(request_id: Any, result: Any) -> Dict[str, Any]:
    """Build a JSON-RPC 2.0 success response wrapping *result* unchanged.

    Args:
        request_id: The id of the request being answered.
        result: The payload placed under the ``result`` key (not copied).

    Returns:
        A dict with ``jsonrpc``, ``id`` and ``result`` keys.
    """
    return dict(jsonrpc="2.0", id=request_id, result=result)
|
||
|
||
|
||
def create_initialize_response(request_id: Any, server_name: str, server_version: str = "1.0.0") -> Dict[str, Any]:
    """Build the JSON-RPC response to an MCP ``initialize`` request.

    Advertises protocol version ``2024-11-05`` and an (empty) ``tools``
    capability, and identifies the server by name and version.

    Args:
        request_id: The id of the initialize request.
        server_name: Name reported in ``serverInfo``.
        server_version: Version reported in ``serverInfo``.

    Returns:
        A JSON-RPC success response dict.
    """
    server_info = {"name": server_name, "version": server_version}
    init_result = {
        "protocolVersion": "2024-11-05",
        "capabilities": {"tools": {}},
        "serverInfo": server_info,
    }
    return {"jsonrpc": "2.0", "id": request_id, "result": init_result}
|
||
|
||
|
||
def create_ping_response(request_id: Any) -> Dict[str, Any]:
    """Build the JSON-RPC response to a ``ping`` request (result ``{"pong": True}``).

    Args:
        request_id: The id of the ping request.

    Returns:
        A JSON-RPC success response dict.
    """
    pong_payload = dict(pong=True)
    return dict(jsonrpc="2.0", id=request_id, result=pong_payload)
|
||
|
||
|
||
def create_tools_list_response(request_id: Any, tools: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Build the JSON-RPC response to a ``tools/list`` request.

    Args:
        request_id: The id of the tools/list request.
        tools: Tool definitions; placed under ``result.tools`` without copying.

    Returns:
        A JSON-RPC success response dict.
    """
    listing = {"tools": tools}
    return {"jsonrpc": "2.0", "id": request_id, "result": listing}
|
||
|
||
|
||
def is_regex_pattern(pattern: str) -> bool:
    """Heuristically decide whether *pattern* should be treated as a regular expression.

    True when any of the following holds:
      * ``/body/`` form with a non-empty body;
      * ``r"body"`` / ``r'body'`` form with a non-empty body;
      * the string contains any regex metacharacter
        (``* + ? | ( ) [ ] { } ^ $ \\ .``).

    Note that ``.`` counts as a metacharacter, so plain names such as
    ``file.txt`` are classified as regexes.
    """
    slash_delimited = pattern.startswith('/') and pattern.endswith('/') and len(pattern) > 2
    raw_string_style = (
        pattern.startswith(('r"', "r'"))
        and pattern.endswith(('"', "'"))
        and len(pattern) > 3
    )
    has_metachar = not frozenset('*+?|()[]{}^$\\.').isdisjoint(pattern)
    return slash_delimited or raw_string_style or has_metachar
|
||
|
||
|
||
def compile_pattern(pattern: str) -> Union[re.Pattern, str, None]:
    """Compile *pattern* if it looks like a regex, otherwise return it unchanged.

    Supported forms (see is_regex_pattern()):
      * ``/body/``      -> ``re.compile(body)``
      * ``r"body"``/``r'body'`` -> ``re.compile(body)``
      * any other string containing regex metacharacters -> compiled as-is.

    Returns:
        A compiled ``re.Pattern``, the original string for plain text, or
        None if the regex is invalid. A warning is written to stderr in that
        case; stdout carries the JSON-RPC stream and must stay clean.
    """
    if not is_regex_pattern(pattern):
        return pattern

    try:
        # /pattern/ form
        if pattern.startswith('/') and pattern.endswith('/'):
            return re.compile(pattern[1:-1])

        # r"pattern" or r'pattern' form
        if pattern.startswith(('r"', "r'")) and pattern.endswith(('"', "'")):
            return re.compile(pattern[2:-1])

        # Plain string containing regex metacharacters.
        return re.compile(pattern)
    except re.error as e:
        # Invalid regex: signal it with None rather than raising.
        print(f"Warning: Regular expression '{pattern}' compilation failed: {e}", file=sys.stderr)
        return None
|
||
|
||
|
||
def _write_jsonrpc_message(message: Any) -> None:
    """Serialize *message* as a single JSON line on stdout and flush it immediately."""
    sys.stdout.write(json.dumps(message, ensure_ascii=False) + "\n")
    sys.stdout.flush()


async def handle_mcp_streaming(request_handler):
    """Run the standard MCP stdio main loop (one JSON-RPC message per line).

    Reads newline-delimited JSON from stdin and awaits
    ``request_handler(request)`` for each parsed message. The returned
    response is written as one JSON line to stdout. stdin is read in the
    default executor so the event loop is not blocked.

    - Blank lines are skipped; EOF on stdin ends the loop.
    - A ``None`` result from the handler writes nothing. JSON-RPC
      notifications must not be answered, and ``null`` is not a valid message.
    - Unparseable input yields a -32700 "Parse error" response with
      ``"id": null``, as JSON-RPC 2.0 requires.
    - Any other exception yields a -32603 "Internal error" response. It carries
      the request's id when one is available, so the client can correlate it.
    - KeyboardInterrupt ends the loop quietly.

    Args:
        request_handler: Async callable receiving the parsed request and
            returning a JSON-serializable response, or None for no response.
    """
    loop = asyncio.get_running_loop()
    try:
        while True:
            # readline() blocks, so run it off the event loop.
            line = await loop.run_in_executor(None, sys.stdin.readline)
            if not line:
                break  # EOF
            line = line.strip()
            if not line:
                continue

            request = None
            try:
                try:
                    request = json.loads(line)
                except json.JSONDecodeError:
                    _write_jsonrpc_message({
                        "jsonrpc": "2.0",
                        "id": None,
                        "error": {"code": -32700, "message": "Parse error"},
                    })
                    continue

                response = await request_handler(request)
                if response is not None:
                    _write_jsonrpc_message(response)
            except Exception as e:
                request_id = request.get("id") if isinstance(request, dict) else None
                _write_jsonrpc_message({
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {"code": -32603, "message": f"Internal error: {str(e)}"},
                })
    except KeyboardInterrupt:
        pass