316 lines
13 KiB
Python
316 lines
13 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
System prompt and MCP settings loader utilities
|
||
"""
|
||
import os
|
||
import json
|
||
import asyncio
|
||
from typing import List, Dict, Optional, Any
|
||
from datetime import datetime, timezone, timedelta
|
||
import logging
|
||
from utils.settings import BACKEND_HOST, MASTERKEY
|
||
logger = logging.getLogger('app')
|
||
from .plugin_hook_loader import execute_hooks, merge_skill_mcp_configs
|
||
|
||
def format_datetime_by_language(language: str, now: Optional[datetime] = None) -> str:
    """Format the current time as a localized string with timezone information.

    The wall-clock time is derived from UTC using a fixed offset per language:
    'zh' -> UTC+8 (Beijing time), 'ja'/'jp' -> UTC+9 (JST), anything else -> UTC.

    Args:
        language: Language code such as 'zh', 'en', 'ja' or 'jp'. Unknown codes,
            None and non-string values use UTC and the generic numeric format.
        now: Optional reference instant (mainly for tests). A naive datetime is
            interpreted as UTC. Defaults to the current time.

    Returns:
        str: for example
            'zh'       -> "2024年01月15日 14:30 (UTC+8 北京时间)"
            'ja'/'jp'  -> "2024年01月15日 15:30 (JST UTC+9)"
            'en'       -> "January 15, 2024 06:30 UTC (UTC+0)"
            otherwise  -> "2024-01-15 06:30:45 (UTC+0)"
    """
    # The strings are assembled with f-strings instead of strftime: strftime with
    # non-ASCII literals (年/月/日) and %B month names depend on the process
    # locale, and the previous "fallback" repeated exactly those calls.
    english_months = (
        "January", "February", "March", "April", "May", "June",
        "July", "August", "September", "October", "November", "December",
    )
    timezone_table = {
        'zh': {'offset': 8, 'name': 'CST', 'display': '北京时间'},
        'ja': {'offset': 9, 'name': 'JST', 'display': '日本時間'},
        'jp': {'offset': 9, 'name': 'JST', 'display': '日本時間'},
        'en': {'offset': 0, 'name': 'UTC', 'display': 'UTC'},  # English users are global: use UTC
    }

    utc_now = datetime.now(timezone.utc) if now is None else now
    if utc_now.tzinfo is None:
        utc_now = utc_now.replace(tzinfo=timezone.utc)

    # Unknown, missing or unhashable language codes fall back to UTC.
    tz_info = timezone_table.get(language) if isinstance(language, str) else None
    if tz_info is None:
        tz_info = timezone_table['en']
    offset_hours = tz_info['offset']
    local_time = utc_now.astimezone(timezone(timedelta(hours=offset_hours)))

    utc_label = f"UTC{offset_hours:+d}"
    cjk_date_time = (
        f"{local_time.year}年{local_time.month:02d}月{local_time.day:02d}日 "
        f"{local_time.hour:02d}:{local_time.minute:02d}"
    )

    if language == 'zh':
        return f"{cjk_date_time} ({utc_label} {tz_info['display']})"
    if language in ('ja', 'jp'):
        return f"{cjk_date_time} ({tz_info['name']} {utc_label})"
    if language == 'en':
        month_name = english_months[local_time.month - 1]
        return (
            f"{month_name} {local_time.day:02d}, {local_time.year} "
            f"{local_time.hour:02d}:{local_time.minute:02d} {tz_info['name']} ({utc_label})"
        )
    return (
        f"{local_time.year}-{local_time.month:02d}-{local_time.day:02d} "
        f"{local_time.hour:02d}:{local_time.minute:02d}:{local_time.second:02d} ({utc_label})"
    )
|
||
|
||
|
||
async def load_system_prompt_async(config) -> str:
    """Asynchronously build the agent's system prompt.

    Reads ``prompt/system_prompt.md`` through ``config_cache`` and fills its
    ``str.format`` placeholders: readme, extra_prompt, language,
    user_identifier, datetime, agent_dir_path and trace_id. When
    ``project_dir`` is set, that directory's ``README.md`` is used as
    ``readme``. Any content returned by the ``PrePrompt`` hooks is appended
    under a "## Context from Skills" heading.

    Args:
        config: AgentConfig object. Reads ``project_dir``, ``language``,
            ``system_prompt``, ``user_identifier`` and ``trace_id``; the whole
            object is also passed to ``execute_hooks``.

    Returns:
        str: The rendered system prompt ("" if nothing could be produced).
    """
    from agent.config_cache import config_cache

    # Parameters taken from config
    project_dir = getattr(config, 'project_dir', None)
    language = getattr(config, 'language', None)
    system_prompt = getattr(config, 'system_prompt', None)
    user_identifier = getattr(config, 'user_identifier', '')
    trace_id = getattr(config, 'trace_id', '')

    # Human-readable language name; unknown codes are shown as-is
    language_display_map = {
        'zh': '中文',
        'en': 'English',
        'ja': '日本語',
        'jp': '日本語'
    }
    language_display = language_display_map.get(language, language if language else 'English')

    datetime_str = format_datetime_by_language(language) if language else format_datetime_by_language('en')

    system_prompt_default = ""
    try:
        # Default prompt template, read through the cache
        default_prompt_file = os.path.join("prompt", "system_prompt.md")
        # The cache can evidently return None (the README read below also guards
        # with `or ""`); normalise it so `.format()` cannot hit AttributeError.
        system_prompt_default = await config_cache.get_text_file(default_prompt_file) or ""
        if system_prompt_default:
            logger.info(f"Using cached default system prompt ")
    except Exception as e:
        logger.error(f"Failed to load default system prompt: {str(e)}")
        system_prompt_default = ""

    readme = ""
    # README.md is only read when a project directory is configured
    if project_dir is not None:
        readme_path = os.path.join(project_dir, "README.md")
        readme = await config_cache.get_text_file(readme_path) or ""

    # agent_dir_path is the path shown to the model; it maps onto project_dir
    prompt = system_prompt_default.format(
        readme=str(readme),
        extra_prompt=system_prompt or "",
        language=language_display,
        # Normalise None the same way trace_id is, so the prompt never shows "None"
        user_identifier=user_identifier or "",
        datetime=datetime_str,
        agent_dir_path=".",
        trace_id=trace_id or ""
    )

    # ============ PrePrompt hooks ============
    hook_content = await execute_hooks('PrePrompt', config)
    if hook_content:
        # Append hook output at the end of the prompt
        prompt = f"{prompt}\n\n## Context from Skills\n\n{hook_content}"
    return prompt or ""
|
||
|
||
|
||
|
||
def replace_mcp_placeholders(mcp_settings: List[Dict], dataset_dir: str, bot_id: str, dataset_ids: List[str]) -> List[Dict]:
    """Substitute ``{dataset_dir}``, ``{bot_id}`` and ``{dataset_ids}`` in MCP settings.

    Every string value anywhere in the structure (dict values and list items,
    recursively) is rendered with ``str.format``, so ``{{``/``}}`` escapes keep
    working. Strings containing braces that are not these placeholders (inline
    JSON, shell ``{}``, other templates) no longer abort the whole load. Only
    the known placeholders are replaced in them, and everything else is left
    untouched. Dict keys are never modified.

    Args:
        mcp_settings: List of MCP settings dicts. A falsy or non-list value is
            returned unchanged.
        dataset_dir: Value for ``{dataset_dir}``.
        bot_id: Value for ``{bot_id}``.
        dataset_ids: IDs joined with ',' for ``{dataset_ids}``. None is
            treated as empty.

    Returns:
        List[Dict]: A new structure with placeholders replaced. The input is
        not mutated.
    """
    if not mcp_settings or not isinstance(mcp_settings, list):
        return mcp_settings

    # Computed once instead of re-joining the IDs for every string.
    values = {
        'dataset_dir': dataset_dir,
        'bot_id': bot_id,
        'dataset_ids': ','.join(str(item) for item in (dataset_ids or [])),
    }

    def substitute(text: str) -> str:
        """Render one string, tolerating braces that are not our placeholders."""
        try:
            return text.format(**values)
        except (KeyError, IndexError, ValueError, AttributeError, TypeError):
            for name, value in values.items():
                text = text.replace('{' + name + '}', str(value))
            return text

    def walk(obj: Any) -> Any:
        """Rebuild obj with every string substituted (recursively)."""
        if isinstance(obj, dict):
            return {key: walk(value) for key, value in obj.items()}
        if isinstance(obj, list):
            return [walk(item) for item in obj]
        if isinstance(obj, str):
            return substitute(obj)
        return obj

    return walk(mcp_settings)
|
||
|
||
def _as_settings_list(value: Any) -> List[Dict]:
    """Normalise an MCP settings value to a list ([] for empty/None, [dict] for a bare dict)."""
    if not value:
        return []
    if isinstance(value, list):
        return value
    if isinstance(value, dict):
        return [value]
    return []


def _get_mcp_servers(entry: Any) -> Dict:
    """Return ``entry['mcpServers']``, or a fresh ``{}`` when it is missing or malformed."""
    if not isinstance(entry, dict):
        return {}
    servers = entry.get('mcpServers', {})
    return servers if isinstance(servers, dict) else {}


async def load_mcp_settings_async(config) -> List[Dict]:
    """Asynchronously build the effective MCP settings for an agent.

    Sources, in increasing priority:
      1. MCP configs merged from the skills' plugin.json files
         (``merge_skill_mcp_configs``; not cached, so edits take effect at once).
      2. The default ``mcp/mcp_settings.json`` read through ``config_cache``.
         A default server is only added when no skill defines the same name.
      3. ``config.mcp_settings``; its ``mcpServers`` override everything above.

    Servers coming from sources 1-2 that have a ``command`` get ``BACKEND_HOST``
    and ``MASTERKEY`` added to their ``env``. Finally ``{dataset_dir}``,
    ``{bot_id}`` and ``{dataset_ids}`` placeholders are substituted here via
    ``replace_mcp_placeholders`` (dataset_dir is ``<project_dir>/dataset``, or
    "" without a project_dir).

    All sources are deep-copied before being merged, so the cached default
    settings and the caller's ``config.mcp_settings`` are never mutated by env
    injection, overrides or placeholder substitution.

    Args:
        config: AgentConfig object. Reads ``project_dir``, ``mcp_settings``,
            ``bot_id`` and ``dataset_ids``.

    Returns:
        List[Dict]: The merged MCP settings list ([] when no source provided any).
    """
    import copy
    from agent.config_cache import config_cache

    project_dir = getattr(config, 'project_dir', None)
    mcp_settings = getattr(config, 'mcp_settings', None)
    bot_id = getattr(config, 'bot_id', '')
    # A None value would break ','.join() in the placeholder step
    dataset_ids = getattr(config, 'dataset_ids', None) or []

    # 1. Skill plugin.json configs (not cached). Deep copy: everything below
    #    writes into merged_settings.
    merged_settings = copy.deepcopy(_as_settings_list(await merge_skill_mcp_configs(bot_id)))
    if merged_settings:
        skill_mcp_servers = _get_mcp_servers(merged_settings[0])
        logger.info(f"Loaded {len(skill_mcp_servers)} MCP servers from skills")

    # 2. Default MCP settings (cached)
    default_mcp_settings = []
    try:
        default_mcp_file = os.path.join("mcp", "mcp_settings.json")
        default_mcp_settings = _as_settings_list(await config_cache.get_json_file(default_mcp_file))
        if default_mcp_settings:
            logger.info(f"Using cached default mcp_settings from mcp folder")
    except Exception as e:
        logger.error(f"Failed to load default mcp_settings: {str(e)}")
        default_mcp_settings = []

    # 3. Merge defaults; skill definitions win on name clashes
    if default_mcp_settings:
        # Copy so later mutation never reaches the object held by the cache
        default_mcp_settings = copy.deepcopy(default_mcp_settings)
        if merged_settings:
            target = merged_settings[0]
            skill_mcp_servers = _get_mcp_servers(target)
            for server_name, server_config in _get_mcp_servers(default_mcp_settings[0]).items():
                skill_mcp_servers.setdefault(server_name, server_config)
            if isinstance(target, dict):
                # Re-attach: the lookup returns a detached {} when the skill entry
                # had no mcpServers, which previously dropped the defaults.
                target['mcpServers'] = skill_mcp_servers
        else:
            merged_settings = default_mcp_settings

    # Inject backend credentials into every command-based server
    if merged_settings:
        for server_config in _get_mcp_servers(merged_settings[0]).values():
            if isinstance(server_config, dict) and 'command' in server_config:
                env = server_config.get('env')
                if not isinstance(env, dict):
                    env = {}
                    server_config['env'] = env
                env['BACKEND_HOST'] = BACKEND_HOST
                env['MASTERKEY'] = MASTERKEY

    # 4. Caller-supplied mcp_settings (highest priority)
    input_mcp_settings = []
    if mcp_settings is not None:
        if isinstance(mcp_settings, list):
            input_mcp_settings = mcp_settings
        elif mcp_settings:
            input_mcp_settings = [mcp_settings]
            logger.warning(f"Warning: mcp_settings is not a list, converting to list format")
    # Copy so config.mcp_settings itself is left untouched
    input_mcp_settings = copy.deepcopy(input_mcp_settings)

    # 5. Caller's mcpServers override existing entries of the same name
    if input_mcp_settings and merged_settings:
        target = merged_settings[0]
        merged_mcp_servers = _get_mcp_servers(target)
        input_mcp_servers = _get_mcp_servers(input_mcp_settings[0])
        merged_mcp_servers.update(input_mcp_servers)
        if isinstance(target, dict):
            target['mcpServers'] = merged_mcp_servers
        logger.info(f"Merged mcpServers: existing + {len(input_mcp_servers)} input servers")
    elif input_mcp_settings:
        # Nothing else configured: use the caller's settings as-is
        merged_settings = input_mcp_settings

    # Substitute placeholders; dataset_dir is only derived when project_dir is set
    dataset_dir = os.path.join(project_dir, "dataset") if project_dir is not None else ""
    return replace_mcp_placeholders(merged_settings, dataset_dir, bot_id, dataset_ids)
|
||
|
||
|
||
def load_guideline_prompt(chat_history:str, memory_text: str, guidelines_text: str, tools: str, scenarios: str, language: str, user_identifier: str = "") -> str:
    """Render the guideline prompt template with the given conversation context.

    Reads ``prompt/guideline_prompt.md`` (relative to the working directory)
    and fills its ``str.format`` placeholders.

    Args:
        chat_history: Chat history text.
        memory_text: Memory text.
        guidelines_text: Guideline text.
        tools: Tool description text.
        scenarios: Scenario description text.
        language: Language code such as 'zh', 'en', 'ja' or 'jp'. Unknown
            codes are shown verbatim; an empty value is shown as 'English'.
        user_identifier: User identifier, empty by default.

    Returns:
        str: The rendered guideline prompt.

    Raises:
        OSError: If the template file cannot be read.
    """
    template_path = os.path.join("prompt", "guideline_prompt.md")
    with open(template_path, 'r', encoding='utf-8') as template_file:
        template = template_file.read()

    display_names = {'zh': '中文', 'en': 'English', 'ja': '日本語', 'jp': '日本語'}
    if language in display_names:
        language_label = display_names[language]
    else:
        language_label = language or 'English'

    current_time = format_datetime_by_language(language if language else 'en')

    fields = dict(
        chat_history=chat_history,
        guidelines_text=guidelines_text,
        tools=tools,
        scenarios=scenarios,
        language=language_label,
        user_identifier=user_identifier,
        datetime=current_time,
        memory_text=memory_text,
    )
    return template.format(**fields)
|
||
|