mcp dataset_dir support
This commit is contained in:
parent
42a14088f8
commit
839f3c4b36
@ -15,27 +15,13 @@ from typing import Any, Dict, List, Optional, Union
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Validate that ``file_path`` is located inside ``allowed_dir``.

    Both paths are made absolute and normalized (``.`` and ``..`` segments
    are collapsed) before they are compared, and containment is decided per
    path component instead of by string prefix. A sibling directory such as
    ``/data/projects_evil`` is therefore rejected for ``/data/projects``,
    while a file name that merely contains two dots (``report..v2.csv``)
    is accepted. Symbolic links are not resolved.

    Args:
        file_path: Path to check; a relative path is resolved against the
            current working directory.
        allowed_dir: Directory that ``file_path`` must be equal to or inside.

    Returns:
        The absolute, normalized form of ``file_path``.

    Raises:
        ValueError: If the normalized path lies outside ``allowed_dir``.
    """
    normalized_path = os.path.abspath(file_path)
    allowed_dir = os.path.abspath(allowed_dir)

    try:
        common = os.path.commonpath([normalized_path, allowed_dir])
        inside = os.path.normcase(common) == os.path.normcase(allowed_dir)
    except ValueError:
        # commonpath refuses paths that cannot share a root
        # (for example two different drives on Windows).
        inside = False

    if not inside:
        raise ValueError(f"Access denied: path {normalized_path} is not within allowed directory {allowed_dir}")

    return normalized_path
|
||||
|
||||
|
||||
def get_allowed_directory() -> str:
    """Return the absolute path of the directory that file access is limited to.

    Lookup order:
      1. the first command-line argument (``sys.argv[1]``), i.e. dataset_dir;
      2. the ``PROJECT_DATA_DIR`` environment variable;
      3. ``./projects`` relative to the current working directory.

    Empty values are skipped: ``os.path.abspath("")`` is the current working
    directory, so an empty argument or variable would otherwise silently
    widen the allowed root.

    Returns:
        Absolute path of the allowed directory.
    """
    # Prefer the dataset_dir passed on the command line.
    if len(sys.argv) > 1 and sys.argv[1].strip():
        return os.path.abspath(sys.argv[1])

    # Otherwise read the project data directory from the environment.
    project_dir = os.getenv("PROJECT_DATA_DIR") or "./projects"
    return os.path.abspath(project_dir)
|
||||
@ -700,4 +686,4 @@ async def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the async main() coroutine exactly once.
    asyncio.run(main())
|
||||
|
||||
@ -34,6 +34,11 @@ def validate_file_path(file_path: str, allowed_dir: str) -> str:
|
||||
|
||||
def get_allowed_directory() -> str:
    """Return the absolute path of the directory that file access is limited to.

    Lookup order:
      1. the first command-line argument (``sys.argv[1]``), i.e. dataset_dir;
      2. the ``PROJECT_DATA_DIR`` environment variable;
      3. ``./projects`` relative to the current working directory.

    Empty values are skipped: ``os.path.abspath("")`` is the current working
    directory, so an empty argument or variable would otherwise silently
    widen the allowed root.

    Returns:
        Absolute path of the allowed directory.
    """
    # Prefer the dataset_dir passed on the command line.
    if len(sys.argv) > 1 and sys.argv[1].strip():
        return os.path.abspath(sys.argv[1])

    # Otherwise read the project data directory from the environment.
    project_dir = os.getenv("PROJECT_DATA_DIR") or "./projects"
    return os.path.abspath(project_dir)
|
||||
|
||||
@ -13,27 +13,13 @@ import re
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
|
||||
def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Validate that ``file_path`` is located inside ``allowed_dir``.

    Both paths are made absolute and normalized (``.`` and ``..`` segments
    are collapsed) before they are compared, and containment is decided per
    path component instead of by string prefix. A sibling directory such as
    ``/data/projects_evil`` is therefore rejected for ``/data/projects``,
    while a file name that merely contains two dots (``report..v2.csv``)
    is accepted. Symbolic links are not resolved.

    Args:
        file_path: Path to check; a relative path is resolved against the
            current working directory.
        allowed_dir: Directory that ``file_path`` must be equal to or inside.

    Returns:
        The absolute, normalized form of ``file_path``.

    Raises:
        ValueError: If the normalized path lies outside ``allowed_dir``.
    """
    normalized_path = os.path.abspath(file_path)
    allowed_dir = os.path.abspath(allowed_dir)

    try:
        common = os.path.commonpath([normalized_path, allowed_dir])
        inside = os.path.normcase(common) == os.path.normcase(allowed_dir)
    except ValueError:
        # commonpath refuses paths that cannot share a root
        # (for example two different drives on Windows).
        inside = False

    if not inside:
        raise ValueError(f"Access denied: path {normalized_path} is not within allowed directory {allowed_dir}")

    return normalized_path
|
||||
|
||||
|
||||
def get_allowed_directory() -> str:
    """Return the absolute path of the directory that file access is limited to.

    Lookup order:
      1. the first command-line argument (``sys.argv[1]``), i.e. dataset_dir;
      2. the ``PROJECT_DATA_DIR`` environment variable;
      3. ``./projects`` relative to the current working directory.

    Empty values are skipped: ``os.path.abspath("")`` is the current working
    directory, so an empty argument or variable would otherwise silently
    widen the allowed root.

    Returns:
        Absolute path of the allowed directory.
    """
    # Prefer the dataset_dir passed on the command line.
    if len(sys.argv) > 1 and sys.argv[1].strip():
        return os.path.abspath(sys.argv[1])

    # Otherwise read the project data directory from the environment.
    project_dir = os.getenv("PROJECT_DATA_DIR") or "./projects"
    return os.path.abspath(project_dir)
|
||||
@ -1155,4 +1141,4 @@ async def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the async main() coroutine exactly once.
    asyncio.run(main())
|
||||
|
||||
@ -48,27 +48,15 @@ def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-
|
||||
return embedder
|
||||
|
||||
|
||||
def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Validate that ``file_path`` is located inside ``allowed_dir``.

    Both paths are made absolute and normalized (``.`` and ``..`` segments
    are collapsed) before they are compared, and containment is decided per
    path component instead of by string prefix. A sibling directory such as
    ``/data/projects_evil`` is therefore rejected for ``/data/projects``,
    while a file name that merely contains two dots (``report..v2.csv``)
    is accepted. Symbolic links are not resolved.

    Args:
        file_path: Path to check; a relative path is resolved against the
            current working directory.
        allowed_dir: Directory that ``file_path`` must be equal to or inside.

    Returns:
        The absolute, normalized form of ``file_path``.

    Raises:
        ValueError: If the normalized path lies outside ``allowed_dir``.
    """
    normalized_path = os.path.abspath(file_path)
    allowed_dir = os.path.abspath(allowed_dir)

    try:
        common = os.path.commonpath([normalized_path, allowed_dir])
        inside = os.path.normcase(common) == os.path.normcase(allowed_dir)
    except ValueError:
        # commonpath refuses paths that cannot share a root
        # (for example two different drives on Windows).
        inside = False

    if not inside:
        raise ValueError(f"Access denied: path {normalized_path} is not within allowed directory {allowed_dir}")

    return normalized_path
|
||||
|
||||
|
||||
def get_allowed_directory() -> str:
    """Return the absolute path of the directory that file access is limited to.

    Lookup order:
      1. the first command-line argument (``sys.argv[1]``), i.e. dataset_dir;
      2. the ``PROJECT_DATA_DIR`` environment variable;
      3. ``./projects`` relative to the current working directory.

    Empty values are skipped: ``os.path.abspath("")`` is the current working
    directory, so an empty argument or variable would otherwise silently
    widen the allowed root.

    Returns:
        Absolute path of the allowed directory.
    """
    # Prefer the dataset_dir passed on the command line.
    if len(sys.argv) > 1 and sys.argv[1].strip():
        return os.path.abspath(sys.argv[1])

    # Otherwise read the project data directory from the environment.
    project_dir = os.getenv("PROJECT_DATA_DIR") or "./projects"
    return os.path.abspath(project_dir)
|
||||
|
||||
@ -71,9 +71,7 @@ def generate_directory_tree(project_dir: str, unique_id: str, max_depth: int = 3
|
||||
tree_lines = []
|
||||
|
||||
if not os.path.exists(dataset_dir):
|
||||
return "dataset/\n└── [No dataset directory found]"
|
||||
|
||||
tree_lines.append("dataset/")
|
||||
return "└── [No dataset directory found]"
|
||||
|
||||
try:
|
||||
entries = sorted(os.listdir(dataset_dir))
|
||||
@ -347,4 +345,4 @@ def get_project_stats(unique_id: str) -> Dict:
|
||||
stats["embedding_files_count"] = len(embedding_files)
|
||||
stats["embedding_files_detail"] = embedding_files
|
||||
|
||||
return stats
|
||||
return stats
|
||||
|
||||
Loading…
Reference in New Issue
Block a user