mcp dataset_dir support

This commit is contained in:
朱潮 2025-10-22 22:25:59 +08:00
parent 42a14088f8
commit 839f3c4b36
5 changed files with 24 additions and 61 deletions

View File

@ -15,27 +15,13 @@ from typing import Any, Dict, List, Optional, Union
import pandas as pd import pandas as pd
def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Return ``file_path`` as a normalized absolute path inside ``allowed_dir``.

    Both paths are normalized with ``os.path.abspath`` (which also collapses
    ``.`` and ``..`` segments) before they are compared, and containment is
    decided on whole path components rather than on a raw string prefix.

    Args:
        file_path: Path to check; a relative path is resolved against the
            current working directory.
        allowed_dir: Directory the path must be located in.

    Returns:
        The normalized absolute form of ``file_path``.

    Raises:
        ValueError: If the normalized path is not ``allowed_dir`` itself or
            something below it.

    Note:
        Symbolic links are not resolved; only the textual path is checked.
    """
    allowed_dir = os.path.abspath(allowed_dir)
    # Normalize unconditionally: an absolute input such as "/allowed/../etc"
    # must have its ".." segments collapsed before the containment check.
    file_path = os.path.abspath(file_path)

    try:
        common = os.path.commonpath([file_path, allowed_dir])
    except ValueError:
        # Raised when the two paths cannot share a prefix at all
        # (e.g. different drives on Windows).
        common = None

    # A plain startswith() would accept "/data_secret/x" for "/data".
    # commonpath() compares component by component, so sibling directories
    # sharing a name prefix are rejected.
    if common is None or os.path.normcase(common) != os.path.normcase(allowed_dir):
        raise ValueError(f"Access denied: path {file_path} is not within allowed directory {allowed_dir}")

    # No separate ".." scan is needed: after normalization any traversal has
    # either stayed inside allowed_dir or already been rejected above, and a
    # substring test would wrongly refuse names such as "report..v2.csv".
    return file_path
def get_allowed_directory(): def get_allowed_directory():
"""获取允许访问的目录""" """获取允许访问的目录"""
# 优先使用命令行参数传入的dataset_dir
if len(sys.argv) > 1:
dataset_dir = sys.argv[1]
return os.path.abspath(dataset_dir)
# 从环境变量读取项目数据目录 # 从环境变量读取项目数据目录
project_dir = os.getenv("PROJECT_DATA_DIR", "./projects") project_dir = os.getenv("PROJECT_DATA_DIR", "./projects")
return os.path.abspath(project_dir) return os.path.abspath(project_dir)
@ -700,4 +686,4 @@ async def main():
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@ -34,6 +34,11 @@ def validate_file_path(file_path: str, allowed_dir: str) -> str:
def get_allowed_directory(): def get_allowed_directory():
"""获取允许访问的目录""" """获取允许访问的目录"""
# 优先使用命令行参数传入的dataset_dir
if len(sys.argv) > 1:
dataset_dir = sys.argv[1]
return os.path.abspath(dataset_dir)
# 从环境变量读取项目数据目录 # 从环境变量读取项目数据目录
project_dir = os.getenv("PROJECT_DATA_DIR", "./projects") project_dir = os.getenv("PROJECT_DATA_DIR", "./projects")
return os.path.abspath(project_dir) return os.path.abspath(project_dir)

View File

@ -13,27 +13,13 @@ import re
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Return ``file_path`` as a normalized absolute path inside ``allowed_dir``.

    Both paths are normalized with ``os.path.abspath`` (which also collapses
    ``.`` and ``..`` segments) before they are compared, and containment is
    decided on whole path components rather than on a raw string prefix.

    Args:
        file_path: Path to check; a relative path is resolved against the
            current working directory.
        allowed_dir: Directory the path must be located in.

    Returns:
        The normalized absolute form of ``file_path``.

    Raises:
        ValueError: If the normalized path is not ``allowed_dir`` itself or
            something below it.

    Note:
        Symbolic links are not resolved; only the textual path is checked.
    """
    allowed_dir = os.path.abspath(allowed_dir)
    # Normalize unconditionally: an absolute input such as "/allowed/../etc"
    # must have its ".." segments collapsed before the containment check.
    file_path = os.path.abspath(file_path)

    try:
        common = os.path.commonpath([file_path, allowed_dir])
    except ValueError:
        # Raised when the two paths cannot share a prefix at all
        # (e.g. different drives on Windows).
        common = None

    # A plain startswith() would accept "/data_secret/x" for "/data".
    # commonpath() compares component by component, so sibling directories
    # sharing a name prefix are rejected.
    if common is None or os.path.normcase(common) != os.path.normcase(allowed_dir):
        raise ValueError(f"Access denied: path {file_path} is not within allowed directory {allowed_dir}")

    # No separate ".." scan is needed: after normalization any traversal has
    # either stayed inside allowed_dir or already been rejected above, and a
    # substring test would wrongly refuse names such as "report..v2.csv".
    return file_path
def get_allowed_directory(): def get_allowed_directory():
"""获取允许访问的目录""" """获取允许访问的目录"""
# 优先使用命令行参数传入的dataset_dir
if len(sys.argv) > 1:
dataset_dir = sys.argv[1]
return os.path.abspath(dataset_dir)
# 从环境变量读取项目数据目录 # 从环境变量读取项目数据目录
project_dir = os.getenv("PROJECT_DATA_DIR", "./projects") project_dir = os.getenv("PROJECT_DATA_DIR", "./projects")
return os.path.abspath(project_dir) return os.path.abspath(project_dir)
@ -1155,4 +1141,4 @@ async def main():
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@ -48,27 +48,15 @@ def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-
return embedder return embedder
def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Return ``file_path`` as a normalized absolute path inside ``allowed_dir``.

    Both paths are normalized with ``os.path.abspath`` (which also collapses
    ``.`` and ``..`` segments) before they are compared, and containment is
    decided on whole path components rather than on a raw string prefix.

    Args:
        file_path: Path to check; a relative path is resolved against the
            current working directory.
        allowed_dir: Directory the path must be located in.

    Returns:
        The normalized absolute form of ``file_path``.

    Raises:
        ValueError: If the normalized path is not ``allowed_dir`` itself or
            something below it.

    Note:
        Symbolic links are not resolved; only the textual path is checked.
    """
    allowed_dir = os.path.abspath(allowed_dir)
    # Normalize unconditionally: an absolute input such as "/allowed/../etc"
    # must have its ".." segments collapsed before the containment check.
    file_path = os.path.abspath(file_path)

    try:
        common = os.path.commonpath([file_path, allowed_dir])
    except ValueError:
        # Raised when the two paths cannot share a prefix at all
        # (e.g. different drives on Windows).
        common = None

    # A plain startswith() would accept "/data_secret/x" for "/data".
    # commonpath() compares component by component, so sibling directories
    # sharing a name prefix are rejected.
    if common is None or os.path.normcase(common) != os.path.normcase(allowed_dir):
        raise ValueError(f"Access denied: path {file_path} is not within allowed directory {allowed_dir}")

    # No separate ".." scan is needed: after normalization any traversal has
    # either stayed inside allowed_dir or already been rejected above, and a
    # substring test would wrongly refuse names such as "report..v2.csv".
    return file_path
def get_allowed_directory(): def get_allowed_directory():
"""获取允许访问的目录""" """获取允许访问的目录"""
# 优先使用命令行参数传入的dataset_dir
if len(sys.argv) > 1:
dataset_dir = sys.argv[1]
return os.path.abspath(dataset_dir)
# 从环境变量读取项目数据目录 # 从环境变量读取项目数据目录
project_dir = os.getenv("PROJECT_DATA_DIR", "./projects") project_dir = os.getenv("PROJECT_DATA_DIR", "./projects")
return os.path.abspath(project_dir) return os.path.abspath(project_dir)

View File

@ -71,9 +71,7 @@ def generate_directory_tree(project_dir: str, unique_id: str, max_depth: int = 3
tree_lines = [] tree_lines = []
if not os.path.exists(dataset_dir): if not os.path.exists(dataset_dir):
return "dataset/\n└── [No dataset directory found]" return "└── [No dataset directory found]"
tree_lines.append("dataset/")
try: try:
entries = sorted(os.listdir(dataset_dir)) entries = sorted(os.listdir(dataset_dir))
@ -347,4 +345,4 @@ def get_project_stats(unique_id: str) -> Dict:
stats["embedding_files_count"] = len(embedding_files) stats["embedding_files_count"] = len(embedding_files)
stats["embedding_files_detail"] = embedding_files stats["embedding_files_detail"] = embedding_files
return stats return stats