qwen_agent/utils/prompt_loader.py
2025-10-23 16:31:37 +08:00

183 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
System prompt and MCP settings loader utilities
"""
import json
import os
import re
from typing import Dict, List, Optional
# Human-readable names substituted for the {language} placeholder.
_LANGUAGE_DISPLAY_NAMES = {
    'zh': '中文',
    'en': 'English',
    'ja': '日本語',
    'jp': '日本語',
}

# Placeholders recognised in the CATALOG_AGENT prompt template.
_PROMPT_PLACEHOLDER_RE = re.compile(r"\{(readme|language|extra_prompt)\}")


def _load_prompt_template(project_dir: str) -> Optional[str]:
    """Return the CATALOG_AGENT prompt template, or None if none could be read.

    Prefers ``<project_dir>/system_prompt.md``; falls back to
    ``prompt/system_prompt_default.md`` (relative to the current working
    directory) when the project file is missing, unreadable, or empty.
    """
    template = None
    project_file = os.path.join(project_dir, "system_prompt.md")
    if os.path.exists(project_file):
        try:
            with open(project_file, 'r', encoding='utf-8') as f:
                template = f.read()
            print("Using project-specific system prompt")
        except Exception as e:
            # Best effort: a broken project file falls back to the default.
            print(f"Failed to load project system prompt: {str(e)}")
            template = None
    # An empty project file is treated the same as a missing one.
    if not template:
        try:
            default_file = os.path.join("prompt", "system_prompt_default.md")
            with open(default_file, 'r', encoding='utf-8') as f:
                template = f.read()
            print("Using default system prompt from prompt folder")
        except Exception as e:
            print(f"Failed to load default system prompt: {str(e)}")
            template = None
    return template


def _load_readme(project_dir: str) -> str:
    """Return the stripped contents of ``<project_dir>/README.md``, or ``""``.

    A missing or unreadable README is not fatal: the prompt is still built,
    just without README content.
    """
    readme_path = os.path.join(project_dir, "README.md")
    if not os.path.exists(readme_path):
        return ""
    try:
        with open(readme_path, "r", encoding="utf-8") as f:
            return f.read().strip()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Failed to load README: {str(e)}")
        return ""


def load_system_prompt(project_dir: str, language: Optional[str] = None, system_prompt: Optional[str] = None, robot_type: str = "AGENT") -> str:
    """Build the system prompt for an agent.

    For ``robot_type == "CATALOG_AGENT"`` a prompt template is loaded.
    ``<project_dir>/system_prompt.md`` is used if it exists and is non-empty;
    otherwise ``prompt/system_prompt_default.md``, relative to the current
    working directory. Its placeholders are then filled in a single pass:

    * ``{readme}``       -> stripped contents of ``<project_dir>/README.md`` (or "")
    * ``{language}``     -> display name of *language* (e.g. 'zh' -> '中文');
      unknown codes are used verbatim and a missing code becomes 'English'
    * ``{extra_prompt}`` -> *system_prompt* (or "")

    Because substitution is single-pass, placeholder-like text inside the
    README or the extra prompt is inserted literally, not substituted again.

    For any other ``robot_type`` (including the default "AGENT") the
    caller-supplied *system_prompt* is returned as-is.

    Args:
        project_dir: Project directory path.
        language: Language code such as 'zh', 'en', 'ja'/'jp'; only used to
            fill the ``{language}`` placeholder.
        system_prompt: Caller-supplied prompt. Returned directly for
            non-catalog agents; inserted as ``{extra_prompt}`` otherwise.
        robot_type: Agent type, "AGENT" or "CATALOG_AGENT".

    Returns:
        The prompt text. It is "" when no template could be loaded
        (CATALOG_AGENT) or when *system_prompt* is empty/None (other types).
    """
    if robot_type != "CATALOG_AGENT":
        return system_prompt or ""

    template = _load_prompt_template(project_dir)
    if not template:
        # Neither template was readable: honour the documented "" result
        # instead of calling .replace() on None.
        return ""

    values = {
        "readme": _load_readme(project_dir),
        "language": _LANGUAGE_DISPLAY_NAMES.get(language, language if language else 'English'),
        "extra_prompt": system_prompt or "",
    }
    # A callable replacement inserts values literally (no backslash escapes).
    return _PROMPT_PLACEHOLDER_RE.sub(lambda m: values[m.group(1)], template)
def get_available_prompt_languages(prompt_dir: str = "prompt") -> List[str]:
    """List the language codes that have a per-language prompt file.

    A language ``xx`` is available when ``<prompt_dir>/system_prompt_xx.md``
    exists as a regular file. ``system_prompt_default.md`` is the shared
    fallback template and is not reported as a language. Files with an
    empty code (``system_prompt_.md``) and directories are skipped as well.

    Args:
        prompt_dir: Directory holding the prompt files; defaults to
            ``prompt`` relative to the current working directory.

    Returns:
        Sorted list of language codes, e.g. ``['en', 'jp', 'zh']``; empty
        when *prompt_dir* does not exist or is not a directory.
    """
    if not os.path.isdir(prompt_dir):
        return []

    prefix, suffix = "system_prompt_", ".md"
    languages = []
    for filename in os.listdir(prompt_dir):
        if not (filename.startswith(prefix) and filename.endswith(suffix)):
            continue
        # e.g. "system_prompt_zh.md" -> "zh"
        language = filename[len(prefix):-len(suffix)]
        if not language or language == "default":
            continue
        if os.path.isfile(os.path.join(prompt_dir, filename)):
            languages.append(language)
    # os.listdir order is arbitrary; sort for a deterministic result.
    return sorted(languages)
def replace_mcp_placeholders(mcp_settings: List[Dict], dataset_dir: str) -> List[Dict]:
    """Return a copy of *mcp_settings* with ``{dataset_dir}`` substituted.

    Every string found anywhere inside the (arbitrarily nested) dicts and
    lists has the literal ``{dataset_dir}`` replaced by *dataset_dir*. This
    includes entries of each server's ``args`` list and anything nested
    within them. Dict keys and non-string scalars are left untouched.

    The input is not modified. A new structure is built so that a cached or
    caller-owned config keeps its placeholders and can be reused with a
    different *dataset_dir*.

    Args:
        mcp_settings: List of MCP server configurations. Anything that is not
            a non-empty list is returned unchanged.
        dataset_dir: Path substituted for ``{dataset_dir}``.

    Returns:
        The substituted copy, or *mcp_settings* itself when it is empty or
        not a list.
    """
    if not mcp_settings or not isinstance(mcp_settings, list):
        return mcp_settings

    def _substitute(value):
        """Recursively copy *value*, replacing the placeholder in strings."""
        if isinstance(value, dict):
            return {key: _substitute(item) for key, item in value.items()}
        if isinstance(value, list):
            return [_substitute(item) for item in value]
        if isinstance(value, str):
            return value.replace('{dataset_dir}', dataset_dir)
        return value

    return _substitute(mcp_settings)
def load_mcp_settings(project_dir: str, mcp_settings: list = None) -> List[Dict]:
    """Resolve the MCP server settings for a project.

    Resolution order:
      1. ``mcp_settings`` if the caller passed one (anything other than None);
      2. ``<project_dir>/mcp_settings.json`` if it exists and parses to a
         non-None value;
      3. ``mcp/mcp_settings.json`` relative to the current working directory;
      4. otherwise an empty list.
    A read or parse failure at step 2 or 3 is printed and the next step is used.

    Args:
        project_dir: Project directory path.
        mcp_settings: Optional settings supplied by the caller; when given,
            no file is read.

    Returns:
        List[Dict]: the settings as a list. A truthy non-list value is wrapped
        in a list; a falsy one becomes ``[]``. Every ``{dataset_dir}``
        placeholder is replaced here, via ``replace_mcp_placeholders``, with
        ``<project_dir>/dataset``.
    """
    settings = mcp_settings

    # Project-level file, only consulted when the caller gave nothing.
    if settings is None:
        project_file = os.path.join(project_dir, "mcp_settings.json")
        if os.path.exists(project_file):
            try:
                with open(project_file, 'r', encoding='utf-8') as fp:
                    settings = json.load(fp)
                print(f"Using project-specific mcp_settings")
            except Exception as err:
                print(f"Failed to load project mcp_settings: {str(err)}")
                settings = None

    # Shared default under ./mcp; a missing file simply means "no servers".
    if settings is None:
        try:
            fallback_file = os.path.join("mcp", "mcp_settings.json")
            if not os.path.exists(fallback_file):
                settings = []
                print(f"No default mcp_settings found, using empty list")
            else:
                with open(fallback_file, 'r', encoding='utf-8') as fp:
                    settings = json.load(fp)
                print(f"Using default mcp_settings from mcp folder")
        except Exception as err:
            print(f"Failed to load default mcp_settings: {str(err)}")
            settings = []

    # Normalise to a list (a JSON file may hold null or a single object).
    if settings is None:
        settings = []
    elif not isinstance(settings, list):
        print(f"Warning: mcp_settings is not a list, converting to list format")
        settings = [settings] if settings else []

    return replace_mcp_placeholders(settings, os.path.join(project_dir, "dataset"))
def is_language_available(language: str, prompt_dir: str = "prompt") -> bool:
    """Check whether a per-language prompt file exists for *language*.

    A language is available when ``<prompt_dir>/system_prompt_<language>.md``
    exists as a regular file (a directory with that name does not count).
    ``"default"`` is rejected because ``system_prompt_default.md`` is the
    shared fallback template, not a language-specific prompt. An empty or
    None language is also rejected.

    Args:
        language: Language code such as ``'zh'`` or ``'en'``.
        prompt_dir: Directory holding the prompt files; defaults to
            ``prompt`` relative to the current working directory.

    Returns:
        bool: True if the prompt file is present, otherwise False.
    """
    if not language or language == "default":
        return False
    prompt_file = os.path.join(prompt_dir, f"system_prompt_{language}.md")
    return os.path.isfile(prompt_file)