shell_env support

This commit is contained in:
朱潮 2026-03-24 00:12:19 +08:00
parent 29da20fa22
commit e13405ba29

View File

@ -138,30 +138,56 @@ async def load_system_prompt_async(config) -> str:
def replace_mcp_placeholders(mcp_settings: List[Dict], dataset_dir: str, bot_id: str, dataset_ids: List[str]) -> List[Dict]:
def replace_mcp_placeholders(mcp_settings: List[Dict], dataset_dir: str, bot_id: str, dataset_ids: List[str], shell_env: Optional[Dict[str, str]] = None) -> List[Dict]:
"""
替换 MCP 配置中的占位符
支持的占位符来源优先级从高到低
1. 内置变量: {dataset_dir}, {bot_id}, {dataset_ids}
2. shell_env 中的自定义环境变量
3. 系统环境变量 (os.environ)
"""
if not mcp_settings or not isinstance(mcp_settings, list):
return mcp_settings
dataset_id_str = ','.join(dataset_ids) if dataset_ids else ''
# 构建占位符映射:系统环境变量 < shell_env < 内置变量(优先级递增)
import re
placeholders = {}
placeholders.update(os.environ)
if shell_env:
placeholders.update(shell_env)
placeholders.update({
'dataset_dir': dataset_dir,
'bot_id': bot_id,
'dataset_ids': dataset_id_str,
})
def _safe_format(s: str) -> str:
"""安全地替换字符串中的占位符,未匹配的占位符保持原样"""
try:
def _replacer(match):
key = match.group(1)
return placeholders.get(key, match.group(0))
return re.sub(r'\{(\w+)\}', _replacer, s)
except Exception:
return s
def replace_placeholders_in_obj(obj):
"""递归替换对象中的占位符"""
if isinstance(obj, dict):
for key, value in obj.items():
if key == 'args' and isinstance(value, list):
# 特别处理 args 列表
obj[key] = [item.format(dataset_dir=dataset_dir, bot_id=bot_id, dataset_ids=dataset_id_str) if isinstance(item, str) else item
obj[key] = [_safe_format(item) if isinstance(item, str) else item
for item in value]
elif isinstance(value, (dict, list)):
obj[key] = replace_placeholders_in_obj(value)
elif isinstance(value, str):
obj[key] = value.format(dataset_dir=dataset_dir, bot_id=bot_id, dataset_ids=dataset_id_str)
obj[key] = _safe_format(value)
elif isinstance(obj, list):
return [replace_placeholders_in_obj(item) if isinstance(item, (dict, list)) else
item.format(dataset_dir=dataset_dir, bot_id=bot_id, dataset_ids=dataset_id_str) if isinstance(item, str) else item
_safe_format(item) if isinstance(item, str) else item
for item in obj]
return obj
@ -269,7 +295,8 @@ async def load_mcp_settings_async(config) -> List[Dict]:
# 替换 MCP 配置中的 {dataset_dir} 占位符
if dataset_dir is None:
dataset_dir = ""
merged_settings = replace_mcp_placeholders(merged_settings, dataset_dir, bot_id, dataset_ids)
shell_env = getattr(config, 'shell_env', None) or {}
merged_settings = replace_mcp_placeholders(merged_settings, dataset_dir, bot_id, dataset_ids, shell_env)
return merged_settings