- Fix mem0 connection pool exhausted error with proper pooling - Convert memory operations to async tasks - Optimize docker-compose configuration - Add skill upload functionality - Reduce cache size for better performance - Update dependencies Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
632 lines
23 KiB
Python
632 lines
23 KiB
Python
import os
|
||
import re
|
||
import shutil
|
||
import zipfile
|
||
import logging
|
||
import asyncio
|
||
from typing import List, Optional
|
||
from fastapi import APIRouter, HTTPException, Query, UploadFile, File, Form
|
||
from pydantic import BaseModel
|
||
from utils.settings import SKILLS_DIR
|
||
import aiofiles
|
||
|
||
logger = logging.getLogger('app')
|
||
|
||
router = APIRouter()
|
||
|
||
|
||
class SkillItem(BaseModel):
    """A single skill entry returned by the skill-list API."""

    name: str  # skill name taken from the SKILL.md frontmatter
    description: str  # skill description taken from the SKILL.md frontmatter
    user_skill: bool = False  # True when uploaded by a user, False for official skills
|
||
|
||
|
||
class SkillListResponse(BaseModel):
    """Response payload for the skill-list endpoint."""

    skills: List[SkillItem]  # combined list (user uploads first, then official)
    total: int  # number of entries in `skills`
|
||
|
||
|
||
# ============ Security constants ============
MAX_FILE_SIZE = 50 * 1024 * 1024  # 50MB maximum upload size
MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024  # 500MB maximum total size after extraction
MAX_COMPRESSION_RATIO = 100  # maximum compression ratio 100:1 (zip-bomb guard)
MAX_ZIP_ENTRIES = 1000  # maximum number of entries allowed in a zip archive
|
||
|
||
|
||
def validate_bot_id(bot_id: str) -> str:
    """Validate the bot_id format to guard against path traversal.

    Raises HTTPException (400) for empty or traversal-capable values;
    a non-UUID shape is only logged, not rejected.
    """
    if not bot_id:
        raise HTTPException(status_code=400, detail="bot_id 不能为空")

    # Reject any token that could escape the per-bot directory.
    if any(token in bot_id for token in ('..', '/', '\\')):
        raise HTTPException(status_code=400, detail="bot_id 包含非法字符")

    # Loose UUID sanity check -- warn only, by design.
    if re.match(r'^[a-fA-F0-9-]{36}$', bot_id) is None:
        logger.warning(f"bot_id 格式可能无效: {bot_id}")

    return bot_id
|
||
|
||
|
||
def validate_skill_name(skill_name: str) -> str:
    """Validate the skill_name format to guard against path traversal.

    Raises HTTPException (400) for empty or traversal-capable values.
    """
    if not skill_name:
        raise HTTPException(status_code=400, detail="skill_name 不能为空")

    # Any traversal / separator character is rejected outright.
    if any(token in skill_name for token in ('..', '/', '\\')):
        raise HTTPException(status_code=400, detail="skill_name 包含非法字符")

    return skill_name
|
||
|
||
|
||
async def validate_upload_file_size(file: UploadFile) -> int:
    """Measure the uploaded file by streaming it in chunks.

    Returns the total size in bytes.  Raises HTTPException (413) when the
    size exceeds MAX_FILE_SIZE.  The stream is rewound to the start both
    on success and before raising, so callers can still read the content.
    """
    chunk_size = 8192
    total = 0

    # Start from the beginning so a previous read does not skew the count.
    await file.seek(0)

    while True:
        chunk = await file.read(chunk_size)
        if not chunk:
            break
        total += len(chunk)
        if total > MAX_FILE_SIZE:
            await file.seek(0)  # leave the stream usable for the caller
            raise HTTPException(
                status_code=413,
                detail=f"文件过大,最大允许 {MAX_FILE_SIZE // (1024*1024)}MB"
            )

    await file.seek(0)  # rewind so the caller can consume the content
    return total
|
||
|
||
|
||
def detect_zip_has_top_level_dirs(zip_path: str) -> bool:
    """Detect whether the zip archive contains top-level directories
    (as opposed to files sitting directly at the archive root).

    Args:
        zip_path: Path to the zip file.

    Returns:
        bool: True if at least one entry lives under a top-level directory.
        False for flat archives, and also on read errors (which are logged,
        never raised).
    """
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            names = zip_ref.namelist()

            # First path components, kept for the same diagnostic log the
            # original emitted.
            top_level_paths = {
                name.split('/')[0]
                for name in names
                if name and name != '/' and name.split('/')[0]
            }

            logger.info(f"Zip top-level paths: {top_level_paths}")

            # An entry of the form "<dir>/<anything>" with a non-empty first
            # component proves a top-level directory exists.  Single O(n)
            # pass; the previous implementation rescanned namelist() for
            # every top-level path (O(n^2) for large archives).
            return any('/' in name and name.split('/', 1)[0] for name in names)

    except Exception as e:
        logger.warning(f"Error detecting zip structure: {e}")
        return False
|
||
|
||
|
||
async def safe_extract_zip(zip_path: str, extract_dir: str) -> None:
    """Safely extract a zip file, guarding against ZipSlip and zip bombs.

    Archive-level checks: entry count (MAX_ZIP_ENTRIES), total uncompressed
    size (MAX_UNCOMPRESSED_SIZE), compression ratio (MAX_COMPRESSION_RATIO).
    Per-entry checks: path traversal ('..', absolute paths, escapes of
    extract_dir after realpath resolution) and symlink entries.

    Args:
        zip_path: Path to the zip file.
        extract_dir: Target directory for extraction.

    Raises:
        HTTPException: 400 when the archive is invalid or malicious.
    """
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            # Limit the number of entries.
            file_list = zip_ref.infolist()
            if len(file_list) > MAX_ZIP_ENTRIES:
                raise zipfile.BadZipFile(f"zip 文件包含过多文件: {len(file_list)}")

            # BUGFIX: ZipInfo.compress_size is the stored (compressed) size
            # and ZipInfo.file_size is the uncompressed size.  The original
            # code had the two sums swapped, which inverted both zip-bomb
            # checks below (a real bomb would slip through; huge legitimate
            # archives could be falsely rejected).
            compressed_size = sum(z.compress_size for z in file_list)
            uncompressed_size = sum(z.file_size for z in file_list)

            if uncompressed_size > MAX_UNCOMPRESSED_SIZE:
                raise zipfile.BadZipFile(
                    f"解压后大小 {uncompressed_size // (1024*1024)}MB 超过限制 "
                    f"{MAX_UNCOMPRESSED_SIZE // (1024*1024)}MB"
                )

            # Compression-ratio check (zip-bomb guard).
            if compressed_size > 0:
                ratio = uncompressed_size / compressed_size
                if ratio > MAX_COMPRESSION_RATIO:
                    raise zipfile.BadZipFile(
                        f"压缩比例 {ratio:.1f}:1 超过限制 {MAX_COMPRESSION_RATIO}:1,"
                        f"可能是 zip 炸弹攻击"
                    )

            # Canonical target directory for the containment check below.
            extract_dir_real = os.path.realpath(extract_dir)

            # Extract each entry only after it passes every check.
            for zip_info in file_list:
                # Reject obvious traversal attempts early.
                if '..' in zip_info.filename or zip_info.filename.startswith('/'):
                    raise zipfile.BadZipFile(
                        f"检测到路径遍历攻击: {zip_info.filename}"
                    )

                # Resolve the final on-disk path of this entry.
                target_path = os.path.realpath(os.path.join(extract_dir, zip_info.filename))

                # The resolved path must stay inside the extraction dir.
                if not target_path.startswith(extract_dir_real + os.sep):
                    if target_path != extract_dir_real:  # the directory itself is fine
                        raise zipfile.BadZipFile(
                            f"文件将被解压到目标目录之外: {zip_info.filename}"
                        )

                # Reject symlink entries.  is_symlink() is not available on
                # older Pythons, so also inspect the Unix mode bits stored in
                # external_attr (S_IFLNK == 0o120000).
                is_symlink = (
                    hasattr(zip_info, 'is_symlink') and zip_info.is_symlink()
                ) or (
                    (zip_info.external_attr >> 16) & 0o170000 == 0o120000
                )
                if is_symlink:
                    raise zipfile.BadZipFile(
                        f"不允许符号链接: {zip_info.filename}"
                    )

                # Extract in a worker thread to keep the event loop free.
                await asyncio.to_thread(zip_ref.extract, zip_info, extract_dir)

    except zipfile.BadZipFile as e:
        raise HTTPException(status_code=400, detail=f"无效的 zip 文件: {str(e)}")
|
||
|
||
|
||
async def validate_and_rename_skill_folder(
    extract_dir: str,
    has_top_level_dirs: bool
) -> str:
    """Validate and, when needed, rename the extracted skill folder.

    Compares each extracted folder name against the 'name' field of its
    SKILL.md frontmatter and renames the folder when the two differ.

    Args:
        extract_dir: Directory the archive was extracted into.
        has_top_level_dirs: Whether the zip contained top-level directories.

    Returns:
        str: Final extraction path (may change when a flat archive's folder
        is renamed).  Failures are logged and never raised, so the upload
        always continues.
    """
    try:
        if has_top_level_dirs:
            # Archive shipped its own skill folders: inspect each directory
            # directly under extract_dir.
            for folder_name in os.listdir(extract_dir):
                folder_path = os.path.join(extract_dir, folder_name)
                if os.path.isdir(folder_path):
                    skill_md_path = os.path.join(folder_path, 'SKILL.md')
                    if os.path.exists(skill_md_path):
                        # Frontmatter parsing does blocking file I/O; run it
                        # off the event loop.
                        metadata = await asyncio.to_thread(
                            parse_skill_frontmatter, skill_md_path
                        )
                        if metadata and 'name' in metadata:
                            expected_name = metadata['name']
                            if folder_name != expected_name:
                                # NOTE(review): if a folder named expected_name
                                # already exists, shutil.move nests the source
                                # inside it rather than replacing it -- confirm
                                # this is intended.
                                new_folder_path = os.path.join(extract_dir, expected_name)
                                await asyncio.to_thread(
                                    shutil.move, folder_path, new_folder_path
                                )
                                logger.info(
                                    f"Renamed skill folder: {folder_name} -> {expected_name}"
                                )
            return extract_dir
        else:
            # Flat archive: the SKILL.md sits in extract_dir itself, so the
            # folder that may need renaming is extract_dir.
            skill_md_path = os.path.join(extract_dir, 'SKILL.md')
            if os.path.exists(skill_md_path):
                metadata = await asyncio.to_thread(
                    parse_skill_frontmatter, skill_md_path
                )
                if metadata and 'name' in metadata:
                    expected_name = metadata['name']
                    # Current folder name is the last path component.
                    current_name = os.path.basename(extract_dir)
                    if current_name != expected_name:
                        parent_dir = os.path.dirname(extract_dir)
                        new_folder_path = os.path.join(parent_dir, expected_name)
                        await asyncio.to_thread(
                            shutil.move, extract_dir, new_folder_path
                        )
                        logger.info(
                            f"Renamed skill folder: {current_name} -> {expected_name}"
                        )
                        return new_folder_path
            return extract_dir

    except Exception as e:
        logger.warning(f"Failed to validate/rename skill folder: {e}")
        # Deliberate best-effort: never abort the upload over a rename failure.
        return extract_dir
|
||
|
||
|
||
async def save_upload_file_async(file: UploadFile, destination: str) -> None:
    """Stream the uploaded file to *destination* without blocking the event loop."""
    chunk_size = 8192
    async with aiofiles.open(destination, 'wb') as out:
        while True:
            data = await file.read(chunk_size)
            if not data:
                break
            await out.write(data)
|
||
|
||
|
||
def parse_skill_frontmatter(skill_md_path: str) -> Optional[dict]:
    """Parse the YAML frontmatter block of a SKILL.md file.

    Args:
        skill_md_path: Path to the SKILL.md file.

    Returns:
        {'name': ..., 'description': ...} when both keys are present in the
        frontmatter, otherwise None.  All failures are logged and reported
        as None rather than raised.
    """
    try:
        with open(skill_md_path, 'r', encoding='utf-8') as fh:
            text = fh.read()

        # Frontmatter is the text between the leading pair of '---' lines.
        match = re.match(r'^---\s*\n(.*?)\n---', text, re.DOTALL)
        if match is None:
            logger.warning(f"No frontmatter found in {skill_md_path}")
            return None

        # Naive flat "key: value" parsing -- sufficient for this frontmatter.
        metadata = {}
        for raw_line in match.group(1).split('\n'):
            raw_line = raw_line.strip()
            if ':' not in raw_line:
                continue
            key, _, value = raw_line.partition(':')
            metadata[key.strip()] = value.strip()

        if 'name' in metadata and 'description' in metadata:
            return {
                'name': metadata['name'],
                'description': metadata['description']
            }

        logger.warning(f"Missing name or description in {skill_md_path}")
        return None

    except Exception as e:
        logger.error(f"Error parsing {skill_md_path}: {e}")
        return None
|
||
|
||
|
||
def _scan_skills_dir(skills_dir: str, user_skill: bool, label: str) -> List[SkillItem]:
    """Scan a directory of skill folders and build SkillItem entries.

    A skill is any subdirectory containing a SKILL.md whose frontmatter
    provides both 'name' and 'description'; folders without valid metadata
    are skipped (parse failures are logged by parse_skill_frontmatter).

    Args:
        skills_dir: Directory whose immediate subdirectories are skills.
        user_skill: Value for SkillItem.user_skill on every entry.
        label: "official" or "user"; used only in the debug log line.
    """
    skills = []
    for skill_name in os.listdir(skills_dir):
        skill_path = os.path.join(skills_dir, skill_name)
        if not os.path.isdir(skill_path):
            continue
        skill_md_path = os.path.join(skill_path, 'SKILL.md')
        if not os.path.exists(skill_md_path):
            continue
        metadata = parse_skill_frontmatter(skill_md_path)
        if metadata:
            skills.append(SkillItem(
                name=metadata['name'],
                description=metadata['description'],
                user_skill=user_skill
            ))
            logger.debug(f"Found {label} skill: {metadata['name']}")
    return skills


def get_official_skills(base_dir: str) -> List[SkillItem]:
    """Get all official skills from the skills directory.

    Args:
        base_dir: Base directory of the project.

    Returns:
        List of SkillItem objects (empty when the directory is missing).
    """
    # SKILLS_DIR may be absolute or relative to the project base directory.
    if os.path.isabs(SKILLS_DIR):
        official_skills_dir = SKILLS_DIR
    else:
        official_skills_dir = os.path.join(base_dir, SKILLS_DIR)

    if not os.path.exists(official_skills_dir):
        logger.warning(f"Official skills directory not found: {official_skills_dir}")
        return []

    return _scan_skills_dir(official_skills_dir, user_skill=False, label="official")


def get_user_skills(base_dir: str, bot_id: str) -> List[SkillItem]:
    """Get all user-uploaded skills for a specific bot.

    Args:
        base_dir: Base directory of the project.
        bot_id: Bot ID to look up user skills for.

    Returns:
        List of SkillItem objects (empty when the bot has no uploads).
    """
    user_skills_dir = os.path.join(base_dir, 'projects', 'uploads', bot_id, 'skills')

    if not os.path.exists(user_skills_dir):
        logger.info(f"No user skills directory found for bot {bot_id}: {user_skills_dir}")
        return []

    return _scan_skills_dir(user_skills_dir, user_skill=True, label="user")
|
||
|
||
|
||
@router.get("/api/v1/skill/list", response_model=SkillListResponse)
async def list_skills(
    bot_id: str = Query(..., description="Bot ID to fetch user skills for")
):
    """
    Return every available skill: official skills plus the uploads that
    belong to *bot_id*.

    Args:
        bot_id: Bot ID whose uploaded skills should be included.

    Returns:
        SkillListResponse with user skills listed before official skills.

    Notes:
        - Official skills come from the /skills directory.
        - User skills come from /projects/uploads/{bot_id}/skills.
        - User skills are marked with user_skill: true.
    """
    try:
        # Project root is one level above this module's directory.
        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

        official = get_official_skills(base_dir)
        uploaded = get_user_skills(base_dir, bot_id)

        # User uploads take precedence in the listing order.
        combined = uploaded + official

        logger.info(f"Found {len(official)} official skills and {len(uploaded)} user skills for bot {bot_id}")

        return SkillListResponse(
            skills=combined,
            total=len(combined)
        )

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error in list_skills: {str(e)}")
        logger.error(f"Full traceback: {error_details}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||
|
||
|
||
@router.post("/api/v1/skill/upload")
async def upload_skill(file: UploadFile = File(...), bot_id: Optional[str] = Form(None)):
    """
    Skill upload endpoint: accepts a zip archive, stores it under
    projects/uploads/{bot_id}/skill_zip/ and extracts it into
    projects/uploads/{bot_id}/skills/.

    Security hardening:
    - P1-001: ZipSlip path-traversal protection (per-entry checks in safe_extract_zip)
    - P1-004: upload size limit (50MB)
    - P1-005: zip-bomb protection (max 100:1 ratio, max 500MB uncompressed)
    - P1-006: untrusted bot_id and client filename are sanitized
    - P1-008: async I/O via aiofiles / asyncio.to_thread

    Args:
        file: Uploaded zip file.
        bot_id: Bot ID used to build the per-user skills directory.

    Returns:
        dict: success flag, stored zip path, extraction path, original
        filename and final skill name.

    Notes:
        - Only .zip archives are accepted.
        - Archives are extracted to projects/uploads/{bot_id}/skills/{skill_name}/.
        - Upload limit: 50MB; uncompressed limit: 500MB.
    """
    file_path = None  # set before any I/O so the except blocks can clean up

    try:
        # Validate bot_id (P1-006 path-traversal protection).
        if not bot_id:
            raise HTTPException(status_code=400, detail="bot_id 不能为空")
        bot_id = validate_bot_id(bot_id)

        # The upload must carry a filename.
        if not file.filename:
            raise HTTPException(status_code=400, detail="文件名不能为空")

        logger.info(f"Skill upload - bot_id: {bot_id}, filename: {file.filename}")

        # SECURITY FIX: file.filename is client-controlled and was previously
        # used verbatim in os.path.join below, so a name like "../../x.zip"
        # could write the zip (and choose an extract folder) outside the
        # upload directory.  Strip every directory component first; the
        # backslash replacement also covers Windows-style separators.
        original_filename = os.path.basename(file.filename.replace('\\', '/'))
        name_without_ext, file_extension = os.path.splitext(original_filename)

        # Only .zip archives are accepted.
        if file_extension.lower() != '.zip':
            raise HTTPException(status_code=400, detail="仅支持上传.zip格式的skill文件")

        # P1-004: measure the upload with chunked async reads (non-blocking).
        file_size = await validate_upload_file_size(file)
        logger.info(f"File size: {file_size // 1024}KB")

        # Flat archives are extracted into a folder named after the zip.
        folder_name = name_without_ext

        # Create the upload directory and persist the zip file first.
        upload_dir = os.path.join("projects", "uploads", bot_id, "skill_zip")
        await asyncio.to_thread(os.makedirs, upload_dir, exist_ok=True)

        # Destination of the stored zip (filename sanitized above).
        file_path = os.path.join(upload_dir, original_filename)

        # P1-008: stream the upload to disk with aiofiles (non-blocking).
        await save_upload_file_async(file, file_path)
        logger.info(f"Saved zip file: {file_path}")

        # Determine the archive layout: top-level directories or loose files.
        has_top_level_dirs = await asyncio.to_thread(
            detect_zip_has_top_level_dirs, file_path
        )
        logger.info(f"Zip contains top-level directories: {has_top_level_dirs}")

        # Pick the extraction target based on the layout.
        if has_top_level_dirs:
            # Archive ships its own skill folders (a-skill/, b-skill/, ...).
            extract_target = os.path.join("projects", "uploads", bot_id, "skills")
            logger.info(f"Detected directories in zip, extracting to: {extract_target}")
        else:
            # Flat archive: wrap its files in skills/{folder_name}/.
            extract_target = os.path.join("projects", "uploads", bot_id, "skills", folder_name)
            logger.info(f"No directories in zip, extracting to: {extract_target}")

        # Create the target directory in a worker thread.
        await asyncio.to_thread(os.makedirs, extract_target, exist_ok=True)

        # P1-001, P1-005: hardened extraction (ZipSlip / zip-bomb checks).
        await safe_extract_zip(file_path, extract_target)
        logger.info(f"Extracted to: {extract_target}")

        # Rename the extracted folder if SKILL.md declares a different name.
        final_extract_path = await validate_and_rename_skill_folder(
            extract_target, has_top_level_dirs
        )

        # Work out the final skill name for the response.
        if has_top_level_dirs:
            final_skill_name = folder_name
        else:
            final_skill_name = os.path.basename(final_extract_path)

        return {
            "success": True,
            "message": f"Skill文件上传并解压成功",
            "file_path": file_path,
            "extract_path": final_extract_path,
            "original_filename": original_filename,
            "skill_name": final_skill_name
        }

    except HTTPException:
        # Remove the stored zip before re-raising the client error.
        if file_path and os.path.exists(file_path):
            try:
                await asyncio.to_thread(os.remove, file_path)
                logger.info(f"Cleaned up file: {file_path}")
            except Exception as cleanup_error:
                logger.error(f"Failed to cleanup file: {cleanup_error}")
        raise

    except Exception as e:
        # Remove the stored zip, then report a generic failure.
        if file_path and os.path.exists(file_path):
            try:
                await asyncio.to_thread(os.remove, file_path)
                logger.info(f"Cleaned up file: {file_path}")
            except Exception as cleanup_error:
                logger.error(f"Failed to cleanup file: {cleanup_error}")

        logger.error(f"Error uploading skill file: {str(e)}")
        # Do not leak internal details to the client (security).
        raise HTTPException(status_code=500, detail="Skill文件上传失败")
|
||
|
||
|
||
@router.delete("/api/v1/skill/remove")
async def remove_skill(
    bot_id: str = Query(..., description="Bot ID"),
    skill_name: str = Query(..., description="Skill name to remove")
):
    """
    Delete a user-uploaded skill.

    Args:
        bot_id: Bot ID.
        skill_name: Name of the skill to delete.

    Returns:
        dict: Deletion result.

    Notes:
        - Only user-uploaded skills can be deleted, never official ones.
        - Deletion path: projects/uploads/{bot_id}/skills/{skill_name}
    """
    try:
        # Reject path-traversal attempts in both parameters.
        bot_id = validate_bot_id(bot_id)
        skill_name = validate_skill_name(skill_name)

        logger.info(f"Skill remove - bot_id: {bot_id}, skill_name: {skill_name}")

        # Resolve the target directory under the project root.
        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        target = os.path.join(base_dir, "projects", "uploads", bot_id, "skills", skill_name)

        # Canonicalize and double-check containment inside the bot's skills dir.
        skill_dir_real = os.path.realpath(target)
        allowed_base = os.path.realpath(os.path.join(base_dir, "projects", "uploads", bot_id, "skills"))

        if not skill_dir_real.startswith(allowed_base + os.sep):
            raise HTTPException(status_code=403, detail="非法的删除路径")

        if not os.path.exists(skill_dir_real):
            raise HTTPException(status_code=404, detail="Skill 不存在")

        if not os.path.isdir(skill_dir_real):
            raise HTTPException(status_code=400, detail="目标路径不是目录")

        # Delete in a worker thread so the event loop stays responsive.
        await asyncio.to_thread(shutil.rmtree, skill_dir_real)

        logger.info(f"Successfully removed skill directory: {skill_dir_real}")

        return {
            "success": True,
            "message": f"Skill '{skill_name}' 删除成功",
            "bot_id": bot_id,
            "skill_name": skill_name
        }

    except HTTPException:
        raise
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error removing skill: {str(e)}")
        logger.error(f"Full traceback: {error_details}")
        raise HTTPException(status_code=500, detail="删除 Skill 失败")
|