qwen_agent/agent/prompt_loader.py
2025-12-03 14:13:39 +08:00

377 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
System prompt and MCP settings loader utilities
"""
import asyncio
import copy
import json
import logging
import os
from datetime import datetime, timezone, timedelta
from typing import List, Dict, Optional, Any
# Module-level logger. It uses the fixed name 'app' rather than __name__,
# presumably so records reach handlers the application configures on the
# 'app' logger -- confirm against the logging setup.
logger = logging.getLogger('app')
def safe_replace(text: str, placeholder: str, value: Any) -> str:
    """Replace every occurrence of ``placeholder`` in ``text`` with ``value``.

    Both ``text`` and ``value`` are coerced to ``str`` so callers can pass
    arbitrary objects; a ``None`` value is substituted as the empty string.

    Args:
        text: Source text. Non-string input is converted with ``str()``
            (so ``None`` becomes ``"None"``).
        placeholder: Literal token to replace, e.g. ``'{user_identifier}'``.
            An empty/falsy placeholder disables replacement.
        value: Replacement value of any type.

    Returns:
        str: ``text`` with all occurrences of ``placeholder`` replaced.
    """
    source = text if isinstance(text, str) else str(text)
    # str.replace with an empty needle would insert the value between every
    # character, so an empty placeholder means "leave the text alone".
    if not placeholder:
        return source
    replacement = "" if value is None else str(value)
    return source.replace(placeholder, replacement)
def format_datetime_by_language(language: str) -> str:
    """Return the current time formatted for ``language``, with timezone info.

    The current UTC time is shifted by a fixed per-language offset (no DST
    handling): 'zh' -> UTC+8 (Beijing time), 'ja'/'jp' -> UTC+9 (JST), and
    every other code -> UTC.

    Args:
        language: Language code such as 'zh', 'en', 'ja' or 'jp'. Unknown
            codes (including None) use UTC and an ISO-like format.

    Returns:
        str: e.g. '2024年01月15日 14:30 (UTC+8 北京时间)' for 'zh',
        '2024年01月15日 14:30 (JST UTC+9)' for 'ja'/'jp',
        'January 15, 2024 14:30 UTC (UTC+0)' for 'en', and
        '2024-01-15 14:30:30 (UTC+0)' for anything else.
    """
    # Fixed offsets per language. 'en' maps to UTC because English-speaking
    # users are spread across many timezones; it is also the fallback entry.
    language_timezone_map = {
        'zh': {'offset': 8, 'name': 'CST', 'display': '北京时间'},
        'ja': {'offset': 9, 'name': 'JST', 'display': '日本時間'},
        'jp': {'offset': 9, 'name': 'JST', 'display': '日本時間'},
        'en': {'offset': 0, 'name': 'UTC', 'display': 'UTC'},
        # Further languages can be added here.
    }
    # Chinese/Japanese date format. The day must be followed by "日 ";
    # without it the day and hour run together ("2024年01月1514:30").
    cjk_format = "%Y年%m月%d日 %H:%M"
    try:
        utc_now = datetime.now(timezone.utc)
        tz_info = language_timezone_map.get(language, language_timezone_map['en'])
        offset_hours = tz_info['offset']
        tz_name = tz_info['name']
        local_time = utc_now + timedelta(hours=offset_hours)
        if language == 'zh':
            return local_time.strftime(cjk_format) + f" (UTC{offset_hours:+d} {tz_info['display']})"
        if language in ('ja', 'jp'):
            return local_time.strftime(cjk_format) + f" ({tz_name} UTC{offset_hours:+d})"
        if language == 'en':
            return local_time.strftime("%B %d, %Y %H:%M") + f" {tz_name} (UTC{offset_hours:+d})"
        return local_time.strftime("%Y-%m-%d %H:%M:%S") + f" (UTC{offset_hours:+d})"
    except Exception:
        # Fallback (e.g. an unhashable language code makes the dict lookup
        # raise): format plain UTC time using the same per-language layouts.
        utc_now = datetime.now(timezone.utc)
        if language == 'zh' or language in ('ja', 'jp'):
            return utc_now.strftime(cjk_format) + " UTC"
        if language == 'en':
            return utc_now.strftime("%B %d, %Y %H:%M") + " UTC"
        return utc_now.strftime("%Y-%m-%d %H:%M:%S") + " UTC"
def _fill_prompt_placeholders(prompt: str, language_display: str, bot_id: str, user_identifier: str, datetime_str: str, extra_prompt: Optional[str] = None) -> str:
    """Substitute the standard system-prompt placeholders, in a fixed order.

    ``{language}`` is replaced first. Then ``{extra_prompt}`` is replaced, but
    only when ``extra_prompt`` is given. Then ``{bot_id}``,
    ``{user_identifier}`` and ``{datetime}`` follow. Because later placeholders
    are replaced after earlier values are inserted, those later placeholders
    are also filled inside the inserted extra prompt.
    """
    prompt = safe_replace(prompt, "{language}", language_display)
    if extra_prompt is not None:
        prompt = safe_replace(prompt, "{extra_prompt}", extra_prompt)
    prompt = safe_replace(prompt, '{bot_id}', bot_id)
    prompt = safe_replace(prompt, '{user_identifier}', user_identifier)
    prompt = safe_replace(prompt, '{datetime}', datetime_str)
    return prompt


async def load_system_prompt_async(project_dir: str, language: str = None, system_prompt: str=None, robot_type: str = "general_agent", bot_id: str="", user_identifier: str = "") -> str:
    """Asynchronously build the system prompt.

    Resolution order:
      1. If ``system_prompt`` contains ``{language}``, it is treated as a
         complete template and only its placeholders are filled.
      2. For robot_type 'general_agent' / 'catalog_agent', the template
         ``prompt/system_prompt_{robot_type}.md`` is loaded via config_cache.
         ``{readme}`` is filled with ``<project_dir>/README.md`` and
         ``{extra_prompt}`` with ``system_prompt``.
      3. Otherwise ``system_prompt`` (or "" if None) is used as-is, apart from
         placeholder substitution.

    Args:
        project_dir: Project directory path; may be None (no README is read).
        language: Language code such as 'zh', 'en', 'ja', 'jp'.
        system_prompt: Optional caller-supplied prompt (see resolution order).
        robot_type: Robot type; 'general_agent' and 'catalog_agent' use the
            bundled template files.
        bot_id: Bot ID substituted for ``{bot_id}``.
        user_identifier: Substituted for ``{user_identifier}``.

    Returns:
        str: The resulting system prompt ("" if nothing could be loaded).
    """
    from agent.config_cache import config_cache

    language_display_map = {
        'zh': '中文',
        'en': 'English',
        'ja': '日本語',
        'jp': '日本語'
    }
    language_display = language_display_map.get(language, language if language else 'English')
    datetime_str = format_datetime_by_language(language or 'en')

    # A caller prompt that uses {language} is a full template on its own.
    if system_prompt and "{language}" in system_prompt:
        prompt = _fill_prompt_placeholders(system_prompt, language_display, bot_id, user_identifier, datetime_str)
        return prompt or ""

    if robot_type in ("general_agent", "catalog_agent"):
        # Load the bundled template for this robot type and embed the
        # project's README.md into it via {readme}.
        # NOTE(review): if the template is missing, the README and
        # system_prompt are dropped and "" is returned -- confirm intended.
        system_prompt_default = None
        try:
            default_prompt_file = os.path.join("prompt", f"system_prompt_{robot_type}.md")
            system_prompt_default = await config_cache.get_text_file(default_prompt_file)
            if system_prompt_default:
                logger.info(f"Using cached default system prompt for {robot_type} from prompt folder")
        except Exception as e:
            logger.error(f"Failed to load default system prompt for {robot_type}: {str(e)}")
            system_prompt_default = None

        readme = ""
        # The README can only be located when a project directory is given.
        if project_dir is not None:
            readme_path = os.path.join(project_dir, "README.md")
            readme = await config_cache.get_text_file(readme_path) or ""

        prompt = system_prompt_default or ""
        if prompt:
            prompt = safe_replace(prompt, "{readme}", str(readme))
        prompt = _fill_prompt_placeholders(
            prompt, language_display, bot_id, user_identifier, datetime_str,
            extra_prompt=system_prompt or "",
        )
        return prompt or ""

    # Any other robot type: use the caller's prompt directly. Coerce None to
    # "" first -- safe_replace(None, ...) would turn it into the text "None".
    prompt = _fill_prompt_placeholders(system_prompt or "", language_display, bot_id, user_identifier, datetime_str)
    return prompt or ""
def load_system_prompt(project_dir: str, language: str = None, system_prompt: str=None, robot_type: str = "general_agent", bot_id: str="", user_identifier: str = "") -> str:
    """Synchronous wrapper around load_system_prompt_async (backward compatibility).

    Arguments and return value are the same as for load_system_prompt_async.

    Raises:
        RuntimeError: When called from a thread whose event loop is already
            running; await load_system_prompt_async there instead.
    """
    # The previous implementation fetched the thread's current event loop and
    # closed it afterwards. get_event_loop() then kept returning that closed
    # loop, so every subsequent call failed with "Event loop is closed".
    # asyncio.run creates a new loop for this call and closes only that loop.
    return asyncio.run(
        load_system_prompt_async(project_dir, language, system_prompt, robot_type, bot_id, user_identifier)
    )
def replace_mcp_placeholders(mcp_settings: List[Dict], dataset_dir: str, bot_id: str) -> List[Dict]:
    """Replace ``{dataset_dir}`` and ``{bot_id}`` in every string of an MCP config.

    Walks nested lists and dicts. Lists are rebuilt as new lists, while dicts
    are updated IN PLACE, so the caller's dict objects are modified. Inside an
    ``args`` list only direct string items are substituted; nested containers
    in ``args`` are left untouched.

    Args:
        mcp_settings: MCP settings list. Anything that is not a non-empty list
            is returned unchanged.
        dataset_dir: Value substituted for ``{dataset_dir}``.
        bot_id: Value substituted for ``{bot_id}``.

    Returns:
        The settings with placeholders substituted.
    """
    if not isinstance(mcp_settings, list) or not mcp_settings:
        return mcp_settings

    def fill(text):
        # Same two substitutions, in the same order, for every string.
        with_dataset = safe_replace(text, '{dataset_dir}', dataset_dir)
        return safe_replace(with_dataset, '{bot_id}', bot_id)

    def walk(node):
        if isinstance(node, list):
            rebuilt = []
            for element in node:
                if isinstance(element, (dict, list)):
                    rebuilt.append(walk(element))
                elif isinstance(element, str):
                    rebuilt.append(fill(element))
                else:
                    rebuilt.append(element)
            return rebuilt
        if isinstance(node, dict):
            for field, field_value in node.items():
                if field == 'args' and isinstance(field_value, list):
                    node[field] = [fill(arg) if isinstance(arg, str) else arg for arg in field_value]
                elif isinstance(field_value, (dict, list)):
                    node[field] = walk(field_value)
                elif isinstance(field_value, str):
                    node[field] = fill(field_value)
        return node

    return walk(mcp_settings)
async def load_mcp_settings_async(project_dir: str, mcp_settings: list=None, bot_id: str="", robot_type: str = "general_agent") -> List[Dict]:
    """Asynchronously load the default MCP settings and merge caller settings in.

    Steps:
      1. Load ``mcp/mcp_settings_{robot_type}.json`` through config_cache and
         set ``env['BACKEND_HOST']`` on every default server.
      2. Normalize ``mcp_settings`` into a list.
      3. If both exist, merge the input's first entry ``mcpServers`` into the
         default's first entry (input wins on name clashes); otherwise use
         whichever one is present.
      4. Substitute ``{dataset_dir}`` (``<project_dir>/dataset``, or "" when
         ``project_dir`` is None) and ``{bot_id}`` throughout the result.

    Both the loaded defaults and ``mcp_settings`` are deep-copied first, so
    neither the cached defaults nor the caller's objects are modified.

    Args:
        project_dir: Project directory path, or None.
        mcp_settings: Optional settings. A list, or a single truthy object
            that is wrapped in a list.
        bot_id: Bot ID substituted for ``{bot_id}``.
        robot_type: Selects the default settings file, e.g. 'general_agent'.

    Returns:
        List[Dict]: The merged settings list (possibly empty).
    """
    from agent.config_cache import config_cache

    # 1. Default settings for this robot type.
    default_mcp_settings = []
    try:
        default_mcp_file = os.path.join("mcp", f"mcp_settings_{robot_type}.json")
        default_mcp_settings = await config_cache.get_json_file(default_mcp_file) or []
        if default_mcp_settings:
            logger.info(f"Using cached default mcp_settings_{robot_type} from mcp folder")
        else:
            logger.warning(f"No default mcp_settings_{robot_type} found, using empty default settings")
    except Exception as e:
        logger.error(f"Failed to load default mcp_settings_{robot_type}: {str(e)}")
        default_mcp_settings = []

    # Everything below mutates the settings in place (env injection, server
    # merge, placeholder substitution). If config_cache hands back a shared
    # cached object, mutating it would leak one call's bot_id / dataset_dir /
    # input servers into later calls, so work on a private copy.
    default_mcp_settings = copy.deepcopy(default_mcp_settings)
    if isinstance(default_mcp_settings, dict):
        # A single top-level JSON object: treat it as a one-element list.
        default_mcp_settings = [default_mcp_settings]
    elif not isinstance(default_mcp_settings, list):
        logger.warning(f"Ignoring default mcp_settings_{robot_type}: expected a list")
        default_mcp_settings = []

    # Give every default server the backend host in its environment.
    if default_mcp_settings:
        backend_host = os.environ.get('BACKEND_HOST', 'https://api-dev.gptbase.ai')
        for server_config in default_mcp_settings[0].get('mcpServers', {}).values():
            if isinstance(server_config, dict):
                server_config.setdefault('env', {})['BACKEND_HOST'] = backend_host

    # 2. Caller-supplied settings, copied so the substitution below does not
    # rewrite the caller's objects.
    input_mcp_settings = []
    if isinstance(mcp_settings, list):
        input_mcp_settings = copy.deepcopy(mcp_settings)
    elif mcp_settings:
        input_mcp_settings = [copy.deepcopy(mcp_settings)]
        logger.warning(f"Warning: mcp_settings is not a list, converting to list format")

    # 3. Merge: defaults are the base; input servers override same-named ones.
    merged_settings = default_mcp_settings
    if merged_settings and input_mcp_settings:
        default_mcp_servers = merged_settings[0].get('mcpServers', {})
        input_mcp_servers = input_mcp_settings[0].get('mcpServers', {})
        default_mcp_servers.update(input_mcp_servers)
        merged_settings[0]['mcpServers'] = default_mcp_servers
        logger.info(f"Merged mcpServers: default + {len(input_mcp_servers)} input servers")
    elif not merged_settings:
        merged_settings = input_mcp_settings

    # 4. Fill {dataset_dir} / {bot_id}; without a project there is no dataset dir.
    dataset_dir = os.path.join(project_dir, "dataset") if project_dir is not None else ""
    return replace_mcp_placeholders(merged_settings, dataset_dir, bot_id)
def load_mcp_settings(project_dir: str, mcp_settings: list=None, bot_id: str="", robot_type: str = "general_agent") -> List[Dict]:
    """Synchronous wrapper around load_mcp_settings_async (backward compatibility).

    Arguments and return value are the same as for load_mcp_settings_async.

    Raises:
        RuntimeError: When called from a thread whose event loop is already
            running; await load_mcp_settings_async there instead.
    """
    # The previous implementation fetched the thread's current event loop and
    # closed it afterwards. get_event_loop() then kept returning that closed
    # loop, so every subsequent call failed with "Event loop is closed".
    # asyncio.run creates a new loop for this call and closes only that loop.
    return asyncio.run(
        load_mcp_settings_async(project_dir, mcp_settings, bot_id, robot_type)
    )
def load_guideline_prompt(chat_history: str, guidelines_text: str, tools: str, scenarios: str, terms: str, language: str, user_identifier: str = "") -> str:
    """Render the guideline prompt from ``./prompt/guideline_prompt.md``.

    The template path is relative to the current working directory.

    Args:
        chat_history: Text substituted for ``{chat_history}``.
        guidelines_text: Text substituted for ``{guidelines_text}``.
        tools: Text substituted for ``{tools}``.
        scenarios: Text substituted for ``{scenarios}``.
        terms: Text substituted for ``{terms}``.
        language: Language code such as 'zh', 'en', 'ja', 'jp'. It selects the
            ``{language}`` display name and the ``{datetime}`` format.
        user_identifier: Text substituted for ``{user_identifier}``.

    Returns:
        str: The rendered prompt, or "" if the template cannot be read.
    """
    try:
        with open('./prompt/guideline_prompt.md', 'r', encoding='utf-8') as template_file:
            rendered = template_file.read()
    except Exception as e:
        logger.error(f"Error reading guideline prompt template: {e}")
        return ""

    display_names = {
        'zh': '中文',
        'en': 'English',
        'ja': '日本語',
        'jp': '日本語'
    }
    language_display = display_names.get(language, language or 'English')
    datetime_str = format_datetime_by_language(language or 'en')

    # Order matters: each value is inserted before the later placeholders are
    # replaced, so later placeholders that appear inside earlier values
    # (e.g. in chat_history) are expanded too.
    substitutions = (
        ('{chat_history}', chat_history),
        ('{guidelines_text}', guidelines_text),
        ('{terms}', terms),
        ('{tools}', tools),
        ('{scenarios}', scenarios),
        ('{language}', language_display),
        ('{user_identifier}', user_identifier),
        ('{datetime}', datetime_str),
    )
    for placeholder, value in substitutions:
        rendered = safe_replace(rendered, placeholder, value)
    return rendered