180 lines
6.7 KiB
Python
180 lines
6.7 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
System prompt and MCP settings loader utilities
|
||
"""
|
||
import os
|
||
import json
|
||
from typing import List, Dict, Optional
|
||
|
||
|
||
def load_system_prompt(project_dir: str, language: Optional[str] = None,
                       system_prompt: Optional[str] = None) -> str:
    """Load the system prompt template and fill its ``{readme}``/``{language}`` placeholders.

    Template resolution order:

    1. ``system_prompt`` when a non-empty string is passed in;
    2. ``<project_dir>/system_prompt.md`` when it exists and is non-empty;
    3. ``prompt/system_prompt_default.md`` (relative to the current working
       directory).

    Args:
        project_dir: Project directory path; ``README.md`` is also read from here
            and substituted for ``{readme}`` (stripped; empty if absent).
        language: Language code such as 'zh', 'en', 'ja' or 'jp'. It does not
            select a prompt file; it only fills the ``{language}`` placeholder
            with a display name. Unknown codes are inserted verbatim and a
            missing/empty code becomes 'English'.
        system_prompt: Optional template that overrides both prompt files.

    Returns:
        str: The rendered prompt, or an empty string if no template could be loaded.
    """
    # 1. Prefer an explicit template, then the project's own system_prompt.md.
    if not system_prompt:
        system_prompt_file = os.path.join(project_dir, "system_prompt.md")
        if os.path.exists(system_prompt_file):
            try:
                with open(system_prompt_file, 'r', encoding='utf-8') as f:
                    system_prompt = f.read()
                print("Using project-specific system prompt")
            except (OSError, UnicodeDecodeError) as e:
                print(f"Failed to load project system prompt: {str(e)}")
                system_prompt = None

    # 2. Fall back to the bundled default prompt.
    if not system_prompt:
        try:
            default_prompt_file = os.path.join("prompt", "system_prompt_default.md")
            with open(default_prompt_file, 'r', encoding='utf-8') as f:
                system_prompt = f.read()
            print("Using default system prompt from prompt folder")
        except (OSError, UnicodeDecodeError) as e:
            print(f"Failed to load default system prompt: {str(e)}")
            system_prompt = None

    # Nothing usable was found: honour the documented contract instead of
    # crashing on None.replace(...) below.
    if not system_prompt:
        return ""

    readme = ""
    readme_path = os.path.join(project_dir, "README.md")
    if os.path.exists(readme_path):
        try:
            with open(readme_path, "r", encoding="utf-8") as f:
                readme = f.read().strip()
        except (OSError, UnicodeDecodeError) as e:
            # Best effort, like the prompt files: a broken README should not
            # prevent the prompt from being produced.
            print(f"Failed to load project README: {str(e)}")
            readme = ""

    # Map language codes to the display name inserted into the prompt.
    language_display_map = {
        'zh': '中文',
        'en': 'English',
        'ja': '日本語',
        'jp': '日本語'
    }
    language_display = language_display_map.get(language, language if language else 'English')

    # {readme} is substituted first, so a literal "{language}" inside the
    # README is also replaced (kept for compatibility with previous output).
    return system_prompt.replace("{readme}", readme).replace("{language}", language_display)
|
||
|
||
|
||
def get_available_prompt_languages() -> list:
    """List the language codes that have a prompt file in the ``prompt`` directory.

    Every regular file named ``system_prompt_<code>.md`` inside ``prompt``
    (resolved relative to the current working directory) contributes ``<code>``.

    Returns:
        list: Language codes sorted alphabetically, e.g. ['en', 'jp', 'zh'];
        an empty list if ``prompt`` is missing or is not a directory.
    """
    prompt_dir = "prompt"
    prefix = "system_prompt_"
    suffix = ".md"

    # os.path.exists() alone would let os.listdir() raise when "prompt" is a file.
    if not os.path.isdir(prompt_dir):
        return []

    available_languages = []
    # os.listdir() order is arbitrary; sort so callers get a stable result.
    for filename in sorted(os.listdir(prompt_dir)):
        if not (filename.startswith(prefix) and filename.endswith(suffix)):
            continue
        # e.g. "system_prompt_zh.md" -> "zh"
        language = filename[len(prefix):-len(suffix)]
        # Skip "system_prompt_.md" (empty code) and directories that merely
        # match the naming pattern.
        if not language or not os.path.isfile(os.path.join(prompt_dir, filename)):
            continue
        # NOTE(review): system_prompt_default.md also matches, so "default" is
        # reported as a language; confirm whether callers expect that.
        available_languages.append(language)

    return available_languages
|
||
|
||
|
||
def replace_mcp_placeholders(mcp_settings: List[Dict], dataset_dir: str) -> List[Dict]:
    """Return a copy of the MCP settings with ``{dataset_dir}`` placeholders resolved.

    Every string value found anywhere in the nested dict/list structure
    (including inside ``args`` lists and any containers nested in them) has
    ``{dataset_dir}`` replaced by ``dataset_dir``. Dict keys and non-string
    scalars are left as they are.

    The input is not modified: a fresh structure is built. Mutating in place
    would bake one project's path into a settings object that a caller may
    reuse for another project.

    Args:
        mcp_settings: List of MCP server configuration dicts.
        dataset_dir: Path substituted for the ``{dataset_dir}`` placeholder.

    Returns:
        List[Dict]: The substituted copy. If ``mcp_settings`` is empty or not a
        list, it is returned unchanged.
    """
    if not mcp_settings or not isinstance(mcp_settings, list):
        return mcp_settings

    placeholder = '{dataset_dir}'

    def _substitute(obj):
        """Recursively copy ``obj`` with the placeholder replaced in strings."""
        if isinstance(obj, str):
            return obj.replace(placeholder, dataset_dir)
        if isinstance(obj, dict):
            return {key: _substitute(value) for key, value in obj.items()}
        if isinstance(obj, list):
            return [_substitute(item) for item in obj]
        return obj

    return _substitute(mcp_settings)
|
||
|
||
def load_mcp_settings(project_dir: str, mcp_settings: Optional[list] = None) -> List[Dict]:
    """Load MCP server settings and resolve their ``{dataset_dir}`` placeholders.

    Settings resolution order:

    1. ``mcp_settings`` when the caller passes anything other than None;
    2. ``<project_dir>/mcp_settings.json`` when it exists and parses;
    3. ``mcp/mcp_settings.json`` relative to the current working directory;
    4. an empty list.

    Args:
        project_dir: Project directory path. ``<project_dir>/dataset`` is the
            value substituted for every ``{dataset_dir}`` placeholder.
        mcp_settings: Optional pre-loaded settings that bypass both files.

    Returns:
        List[Dict]: The settings as a list, with ``{dataset_dir}`` placeholders
        replaced by ``replace_mcp_placeholders``. A non-list value is wrapped
        in a list, or becomes an empty list if it is falsy.
    """
    # 1. Prefer the project's own mcp_settings.json.
    if mcp_settings is None:
        mcp_settings_file = os.path.join(project_dir, "mcp_settings.json")
        if os.path.exists(mcp_settings_file):
            try:
                with open(mcp_settings_file, 'r', encoding='utf-8') as f:
                    mcp_settings = json.load(f)
                print("Using project-specific mcp_settings")
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"Failed to load project mcp_settings: {str(e)}")
                mcp_settings = None

    # 2. Fall back to the default MCP settings.
    if mcp_settings is None:
        try:
            default_mcp_file = os.path.join("mcp", "mcp_settings.json")
            if os.path.exists(default_mcp_file):
                with open(default_mcp_file, 'r', encoding='utf-8') as f:
                    mcp_settings = json.load(f)
                print("Using default mcp_settings from mcp folder")
            else:
                mcp_settings = []
                print("No default mcp_settings found, using empty list")
        except (OSError, ValueError) as e:
            print(f"Failed to load default mcp_settings: {str(e)}")
            mcp_settings = []

    # Normalise to a list (a JSON file may hold null or a single object).
    if mcp_settings is None:
        mcp_settings = []
    elif not isinstance(mcp_settings, list):
        print("Warning: mcp_settings is not a list, converting to list format")
        mcp_settings = [mcp_settings] if mcp_settings else []

    # Substitute the project's dataset directory for {dataset_dir}.
    dataset_dir = os.path.join(project_dir, "dataset")
    # The full settings are deliberately not printed: they commonly carry
    # env vars / credentials that must not end up in stdout logs.
    return replace_mcp_placeholders(mcp_settings, dataset_dir)
|
||
|
||
|
||
def is_language_available(language: str) -> bool:
    """Tell whether a prompt file exists for the given language code.

    Checks for ``prompt/system_prompt_<language>.md`` relative to the current
    working directory.

    Args:
        language: Language code, e.g. 'zh' or 'en'.

    Returns:
        bool: True if something exists at that path, otherwise False.
    """
    candidate_name = f"system_prompt_{language}.md"
    candidate_path = os.path.join("prompt", candidate_name)
    found = os.path.exists(candidate_path)
    return found
|