#!/usr/bin/env python3
"""
System prompt and MCP settings loader utilities
"""
import os
|
||
import json
|
||
import asyncio
|
||
from typing import List, Dict, Optional, Any
|
||
|
||
|
||
def safe_replace(text: str, placeholder: str, value: Any) -> str:
    """Replace every occurrence of ``placeholder`` in ``text`` with ``value``.

    Both ``text`` and ``value`` are coerced to strings first, so callers may
    pass arbitrary objects. ``None`` is treated as an empty string on either
    side, which keeps a missing template from turning into the literal text
    ``"None"``.

    Args:
        text: Source text. Non-string values are converted with ``str()``;
            ``None`` becomes ``""``.
        placeholder: Token to look for, e.g. ``'{user_identifier}'``. When it
            is empty (or otherwise falsy) no replacement is performed.
        value: Replacement value of any type. ``None`` becomes ``""``;
            anything else is converted with ``str()``.

    Returns:
        str: The text with all occurrences of the placeholder substituted.
    """
    if text is None:
        text = ""
    elif not isinstance(text, str):
        text = str(text)

    # str.replace("", x) would insert x between every character, so an empty
    # placeholder is treated as "nothing to do".
    if not placeholder:
        return text

    replacement = "" if value is None else str(value)
    return text.replace(placeholder, replacement)
|
||
|
||
|
||
async def load_system_prompt_async(project_dir: Optional[str], language: Optional[str] = None, system_prompt: Optional[str] = None, robot_type: str = "agent", bot_id: str = "", user_identifier: str = "") -> str:
    """Build the system prompt for a bot (async version).

    Resolution rules, in priority order:

    1. If ``system_prompt`` is non-empty and contains ``{language}``, it is
       used as the complete template.
    2. Otherwise, for ``robot_type`` ``"agent"`` / ``"catalog_agent"``, the
       default template ``prompt/system_prompt_{robot_type}.md`` is loaded
       through ``config_cache``; the project's ``README.md`` is substituted
       for ``{readme}`` and ``system_prompt`` for ``{extra_prompt}``.
    3. For any other ``robot_type``, ``system_prompt`` itself is the template
       (an absent one yields an empty prompt).

    In every case ``{language}``, ``{bot_id}`` and ``{user_identifier}`` are
    substituted.

    Args:
        project_dir: Project directory path; may be ``None``, in which case
            no README is read.
        language: Language code such as ``'zh'``, ``'en'``, ``'jp'``. Unknown
            codes are used verbatim; a missing code falls back to
            ``'English'``.
        system_prompt: Optional caller-supplied prompt (see rules above).
        robot_type: Robot type, ``agent`` / ``catalog_agent`` or other.
        bot_id: Bot ID substituted for ``{bot_id}``.
        user_identifier: User identifier substituted for
            ``{user_identifier}``.

    Returns:
        str: The resolved system prompt (possibly empty).
    """
    from .config_cache import config_cache

    # Map the language code to the name shown inside the prompt.
    language_display_map = {
        'zh': '中文',
        'en': 'English',
        'ja': '日本語',
        'jp': '日本語',
    }
    language_display = language_display_map.get(language, language if language else 'English')

    has_own_template = bool(system_prompt) and "{language}" in system_prompt
    uses_default_template = robot_type in ("agent", "catalog_agent")

    if has_own_template or not uses_default_template:
        # The caller's prompt is the template. A missing prompt must become
        # "" here; otherwise it would be stringified to the text "None".
        prompt = "" if system_prompt is None else system_prompt
        prompt = safe_replace(prompt, "{language}", language_display)
        prompt = safe_replace(prompt, '{bot_id}', bot_id)
        prompt = safe_replace(prompt, '{user_identifier}', user_identifier)
        return prompt or ""

    # Default-template path: prompt/system_prompt_{robot_type}.md, with the
    # project README injected into its {readme} slot.
    system_prompt_default = None
    try:
        default_prompt_file = os.path.join("prompt", f"system_prompt_{robot_type}.md")
        system_prompt_default = await config_cache.get_text_file(default_prompt_file)
        if system_prompt_default:
            print(f"Using cached default system prompt for {robot_type} from prompt folder")
    except Exception as e:
        # Best effort: a missing/unreadable default template yields an empty prompt.
        print(f"Failed to load default system prompt for {robot_type}: {str(e)}")
        system_prompt_default = None

    readme = ""
    # Only try to read README.md when a project directory was given.
    if project_dir is not None:
        readme_path = os.path.join(project_dir, "README.md")
        readme = await config_cache.get_text_file(readme_path) or ""

    if system_prompt_default:
        system_prompt_default = safe_replace(system_prompt_default, "{readme}", str(readme))

    # Substitution order matters: the README and the extra prompt are inserted
    # before the placeholders that follow them, so placeholders they contain
    # are expanded as well.
    prompt = system_prompt_default or ""
    prompt = safe_replace(prompt, "{language}", language_display)
    prompt = safe_replace(prompt, "{extra_prompt}", system_prompt or "")
    prompt = safe_replace(prompt, '{bot_id}', bot_id)
    prompt = safe_replace(prompt, '{user_identifier}', user_identifier)
    return prompt or ""
|
||
|
||
|
||
def load_system_prompt(project_dir: Optional[str], language: Optional[str] = None, system_prompt: Optional[str] = None, robot_type: str = "agent", bot_id: str = "", user_identifier: str = "") -> str:
    """Synchronous wrapper around ``load_system_prompt_async``.

    Kept for backward compatibility. Each call runs the coroutine on a fresh
    event loop via ``asyncio.run``, so the function can be called repeatedly
    from the same thread.

    Args:
        project_dir: Project directory path; may be ``None``.
        language: Language code such as ``'zh'``, ``'en'``, ``'jp'``.
        system_prompt: Optional caller-supplied prompt.
        robot_type: Robot type, e.g. ``agent`` / ``catalog_agent``.
        bot_id: Bot ID.
        user_identifier: User identifier.

    Returns:
        str: The resolved system prompt.

    Raises:
        RuntimeError: If called from a thread whose event loop is already
            running; await ``load_system_prompt_async`` there instead.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_is_running = False
    else:
        loop_is_running = True

    # Check before creating the coroutine so nothing is left un-awaited.
    if loop_is_running:
        raise RuntimeError(
            "load_system_prompt() cannot be called from a running event loop; "
            "await load_system_prompt_async() instead"
        )

    return asyncio.run(
        load_system_prompt_async(project_dir, language, system_prompt, robot_type, bot_id, user_identifier)
    )
|
||
|
||
|
||
|
||
def replace_mcp_placeholders(mcp_settings: List[Dict], dataset_dir: str, bot_id: str) -> List[Dict]:
    """Return a copy of the MCP settings with placeholders filled in.

    Every string value found anywhere inside the (arbitrarily nested) dicts
    and lists has ``{dataset_dir}`` and then ``{bot_id}`` substituted.
    Dictionary keys and non-string leaves are left untouched. The input is
    not modified; new dicts and lists are built instead.

    Args:
        mcp_settings: List of MCP setting dicts. A falsy or non-list value is
            returned unchanged.
        dataset_dir: Replacement for ``{dataset_dir}``; ``None`` is treated
            as an empty string.
        bot_id: Replacement for ``{bot_id}``; ``None`` is treated as an
            empty string.

    Returns:
        List[Dict]: The settings with placeholders replaced.
    """
    if not mcp_settings or not isinstance(mcp_settings, list):
        return mcp_settings

    # Convert the replacement values once instead of once per string.
    dataset_text = "" if dataset_dir is None else str(dataset_dir)
    bot_text = "" if bot_id is None else str(bot_id)

    def substitute(node):
        """Recursively rebuild ``node`` with placeholders replaced."""
        if isinstance(node, str):
            return node.replace('{dataset_dir}', dataset_text).replace('{bot_id}', bot_text)
        if isinstance(node, dict):
            return {key: substitute(value) for key, value in node.items()}
        if isinstance(node, list):
            return [substitute(item) for item in node]
        return node

    return substitute(mcp_settings)
|
||
|
||
async def load_mcp_settings_async(project_dir: Optional[str], mcp_settings: Optional[list] = None, bot_id: str = "", robot_type: str = "agent") -> List[Dict]:
    """Load and merge MCP settings (async version).

    Steps:

    1. Read the defaults from ``mcp/mcp_settings_{robot_type}.json`` through
       ``config_cache`` (failures are logged and treated as "no defaults").
    2. Add a ``BACKEND_HOST`` entry to the ``env`` of every default server.
    3. Merge the ``mcpServers`` of the first element of ``mcp_settings`` into
       the defaults; servers with the same name override the default ones.
       Without defaults, the passed settings are used as they are.
    4. Replace the ``{dataset_dir}`` and ``{bot_id}`` placeholders.

    All work is done on deep copies, so neither the object returned by
    ``config_cache`` nor the caller's ``mcp_settings`` is modified.

    Args:
        project_dir: Project directory path. ``{dataset_dir}`` becomes
            ``<project_dir>/dataset``, or an empty string when this is
            ``None``.
        mcp_settings: Optional MCP settings to merge with the defaults. A
            non-list value is wrapped in a list.
        bot_id: Bot ID substituted for ``{bot_id}``.
        robot_type: Robot type, ``agent`` / ``catalog_agent``.

    Returns:
        List[Dict]: The merged MCP settings list (possibly empty).
    """
    import copy

    from .config_cache import config_cache

    # 1. Read the default MCP settings.
    default_mcp_settings = []
    try:
        default_mcp_file = os.path.join("mcp", f"mcp_settings_{robot_type}.json")
        default_mcp_settings = await config_cache.get_json_file(default_mcp_file) or []
        if default_mcp_settings:
            print(f"Using cached default mcp_settings_{robot_type} from mcp folder")
        else:
            print(f"No default mcp_settings_{robot_type} found, using empty default settings")
    except Exception as e:
        print(f"Failed to load default mcp_settings_{robot_type}: {str(e)}")
        default_mcp_settings = []

    # A defaults file holding a single settings object is accepted the same
    # way a non-list `mcp_settings` argument is: wrapped in a list.
    if isinstance(default_mcp_settings, dict):
        default_mcp_settings = [default_mcp_settings]

    # NOTE(review): config_cache presumably returns the same cached object on
    # every call - confirm. Everything below mutates the settings, so work on
    # a private copy; otherwise one request's servers and substituted
    # placeholders would leak into later requests.
    merged_settings = copy.deepcopy(default_mcp_settings)

    # 2. Give every default server the BACKEND_HOST environment variable.
    if merged_settings:
        backend_host = os.environ.get('BACKEND_HOST', 'https://api-dev.gptbase.ai')
        mcp_servers = merged_settings[0].get('mcpServers', {})
        for server_config in mcp_servers.values():
            if isinstance(server_config, dict):
                server_config.setdefault('env', {})['BACKEND_HOST'] = backend_host

    # 3. Normalise the caller-supplied settings to a list.
    input_mcp_settings = []
    if mcp_settings is not None:
        if isinstance(mcp_settings, list):
            input_mcp_settings = mcp_settings
        elif mcp_settings:
            input_mcp_settings = [mcp_settings]
            print("Warning: mcp_settings is not a list, converting to list format")
    # Copy so the caller's objects are never shared with the result.
    input_mcp_settings = copy.deepcopy(input_mcp_settings)

    # 4. Merge: the defaults are the base, input servers override same-named ones.
    if merged_settings:
        if input_mcp_settings:
            default_mcp_servers = merged_settings[0].get('mcpServers', {})
            input_mcp_servers = input_mcp_settings[0].get('mcpServers', {})
            default_mcp_servers.update(input_mcp_servers)
            merged_settings[0]['mcpServers'] = default_mcp_servers
            print(f"Merged mcpServers: default + {len(input_mcp_servers)} input servers")
    elif input_mcp_settings:
        # No defaults available: use the passed settings directly.
        merged_settings = input_mcp_settings

    # 5. Fill in the placeholders; without a project directory {dataset_dir}
    #    is replaced by an empty string.
    dataset_dir = os.path.join(project_dir, "dataset") if project_dir is not None else ""
    return replace_mcp_placeholders(merged_settings, dataset_dir, bot_id)
|
||
|
||
|
||
def load_mcp_settings(project_dir: Optional[str], mcp_settings: Optional[list] = None, bot_id: str = "", robot_type: str = "agent") -> List[Dict]:
    """Synchronous wrapper around ``load_mcp_settings_async``.

    Kept for backward compatibility. Each call runs the coroutine on a fresh
    event loop via ``asyncio.run``, so the function can be called repeatedly
    from the same thread.

    Args:
        project_dir: Project directory path; may be ``None``.
        mcp_settings: Optional MCP settings to merge with the defaults.
        bot_id: Bot ID.
        robot_type: Robot type, ``agent`` / ``catalog_agent``.

    Returns:
        List[Dict]: The merged MCP settings list.

    Raises:
        RuntimeError: If called from a thread whose event loop is already
            running; await ``load_mcp_settings_async`` there instead.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_is_running = False
    else:
        loop_is_running = True

    # Check before creating the coroutine so nothing is left un-awaited.
    if loop_is_running:
        raise RuntimeError(
            "load_mcp_settings() cannot be called from a running event loop; "
            "await load_mcp_settings_async() instead"
        )

    return asyncio.run(
        load_mcp_settings_async(project_dir, mcp_settings, bot_id, robot_type)
    )
|
||
|