From 839f3c4b36887811e6efae5b4d347cec781afb1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?= Date: Wed, 22 Oct 2025 22:25:59 +0800 Subject: [PATCH] mcp dataset_dir support --- mcp/excel_csv_operator_server.py | 26 ++++++-------------------- mcp/json_reader_server.py | 5 +++++ mcp/multi_keyword_search_server.py | 26 ++++++-------------------- mcp/semantic_search_server.py | 22 +++++----------------- utils/project_manager.py | 6 ++---- 5 files changed, 24 insertions(+), 61 deletions(-) diff --git a/mcp/excel_csv_operator_server.py b/mcp/excel_csv_operator_server.py index 667817c..36d676b 100644 --- a/mcp/excel_csv_operator_server.py +++ b/mcp/excel_csv_operator_server.py @@ -15,27 +15,13 @@ from typing import Any, Dict, List, Optional, Union import pandas as pd -def validate_file_path(file_path: str, allowed_dir: str) -> str: - """验证文件路径是否在允许的目录内""" - # 转换为绝对路径 - if not os.path.isabs(file_path): - file_path = os.path.abspath(file_path) - - allowed_dir = os.path.abspath(allowed_dir) - - # 检查路径是否在允许的目录内 - if not file_path.startswith(allowed_dir): - raise ValueError(f"Access denied: path {file_path} is not within allowed directory {allowed_dir}") - - # 检查路径遍历攻击 - if ".." in file_path: - raise ValueError(f"Access denied: path traversal attack detected") - - return file_path - - def get_allowed_directory(): """获取允许访问的目录""" + # 优先使用命令行参数传入的dataset_dir + if len(sys.argv) > 1: + dataset_dir = sys.argv[1] + return os.path.abspath(dataset_dir) + # 从环境变量读取项目数据目录 project_dir = os.getenv("PROJECT_DATA_DIR", "./projects") return os.path.abspath(project_dir) @@ -700,4 +686,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/mcp/json_reader_server.py b/mcp/json_reader_server.py index 52dddb1..3dce1b9 100644 --- a/mcp/json_reader_server.py +++ b/mcp/json_reader_server.py @@ -34,6 +34,11 @@ def validate_file_path(file_path: str, allowed_dir: str) -> str: def get_allowed_directory(): """获取允许访问的目录""" + # 优先使用命令行参数传入的dataset_dir + if len(sys.argv) > 1: + dataset_dir = sys.argv[1] + return os.path.abspath(dataset_dir) + # 从环境变量读取项目数据目录 project_dir = os.getenv("PROJECT_DATA_DIR", "./projects") return os.path.abspath(project_dir) diff --git a/mcp/multi_keyword_search_server.py b/mcp/multi_keyword_search_server.py index 78aa70c..e477576 100644 --- a/mcp/multi_keyword_search_server.py +++ b/mcp/multi_keyword_search_server.py @@ -13,27 +13,13 @@ import re from typing import Any, Dict, List, Optional, Union -def validate_file_path(file_path: str, allowed_dir: str) -> str: - """验证文件路径是否在允许的目录内""" - # 转换为绝对路径 - if not os.path.isabs(file_path): - file_path = os.path.abspath(file_path) - - allowed_dir = os.path.abspath(allowed_dir) - - # 检查路径是否在允许的目录内 - if not file_path.startswith(allowed_dir): - raise ValueError(f"Access denied: path {file_path} is not within allowed directory {allowed_dir}") - - # 检查路径遍历攻击 - if ".." in file_path: - raise ValueError(f"Access denied: path traversal attack detected") - - return file_path - - def get_allowed_directory(): """获取允许访问的目录""" + # 优先使用命令行参数传入的dataset_dir + if len(sys.argv) > 1: + dataset_dir = sys.argv[1] + return os.path.abspath(dataset_dir) + # 从环境变量读取项目数据目录 project_dir = os.getenv("PROJECT_DATA_DIR", "./projects") return os.path.abspath(project_dir) @@ -1155,4 +1141,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/mcp/semantic_search_server.py b/mcp/semantic_search_server.py index c989942..efe6fbb 100644 --- a/mcp/semantic_search_server.py +++ b/mcp/semantic_search_server.py @@ -48,27 +48,15 @@ def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual- return embedder -def validate_file_path(file_path: str, allowed_dir: str) -> str: - """验证文件路径是否在允许的目录内""" - # 转换为绝对路径 - if not os.path.isabs(file_path): - file_path = os.path.abspath(file_path) - - allowed_dir = os.path.abspath(allowed_dir) - - # 检查路径是否在允许的目录内 - if not file_path.startswith(allowed_dir): - raise ValueError(f"Access denied: path {file_path} is not within allowed directory {allowed_dir}") - - # 检查路径遍历攻击 - if ".." in file_path: - raise ValueError(f"Access denied: path traversal attack detected") - - return file_path def get_allowed_directory(): """获取允许访问的目录""" + # 优先使用命令行参数传入的dataset_dir + if len(sys.argv) > 1: + dataset_dir = sys.argv[1] + return os.path.abspath(dataset_dir) + # 从环境变量读取项目数据目录 project_dir = os.getenv("PROJECT_DATA_DIR", "./projects") return os.path.abspath(project_dir) diff --git a/utils/project_manager.py b/utils/project_manager.py index bd9a820..b3786de 100644 --- a/utils/project_manager.py +++ b/utils/project_manager.py @@ -71,9 +71,7 @@ def generate_directory_tree(project_dir: str, unique_id: str, max_depth: int = 3 tree_lines = [] if not os.path.exists(dataset_dir): - return "dataset/\n└── [No dataset directory found]" - - tree_lines.append("dataset/") + return "└── [No dataset directory found]" try: entries = sorted(os.listdir(dataset_dir)) @@ -347,4 +345,4 @@ def get_project_stats(unique_id: str) -> Dict: stats["embedding_files_count"] = len(embedding_files) stats["embedding_files_detail"] = embedding_files - return stats \ No newline at end of file + return stats