1. Fix "all environment values must be bytes or str" error in hook execution by ensuring all env values are converted to str (getattr may return None when attribute exists but is None). Also sanitize shell_env values. 2. Increase MEM0_POOL_SIZE default from 20 to 50 to address "connection pool exhausted" errors under high concurrency. Fixes: sparticleinc/felo-mygpt#2519 Co-authored-by: zhuchao <zhuchaowe@163.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
219 lines
7.3 KiB
Python
219 lines
7.3 KiB
Python
"""
|
||
Claude Plugins 模式的 Hook 加载器
|
||
|
||
支持通过 .claude-plugin/plugin.json 配置 hooks 和 mcpServers。
|
||
"""
|
||
import os
|
||
import json
|
||
import logging
|
||
import asyncio
|
||
import subprocess
|
||
from typing import List, Dict, Optional, Any
|
||
|
||
# Module-wide logger; every function in this file logs through the 'app' logger.
logger = logging.getLogger('app')

# Hook type definitions: maps each supported hook name to a short
# human-readable description of when that hook runs.
HOOK_TYPES = {
    # Inject content when the system prompt is loaded.
    'PrePrompt': '在system_prompt加载时注入内容',
    # Post-process after the agent has run.
    'PostAgent': '在agent执行后处理',
    # Process a message before it is saved.
    'PreSave': '在保存消息前处理',
    # Inject content when the memory-extraction prompt is loaded.
    'PreMemoryPrompt': '在记忆提取提示词加载时注入内容',
}
|
||
|
||
|
||
async def execute_hooks(hook_type: str, config, **kwargs) -> Any:
    """Run every configured hook of ``hook_type`` and combine the results.

    For each skill directory returned by ``_get_skill_dirs`` for the bot,
    every sub-folder that contains ``.claude-plugin/plugin.json`` is
    inspected. Entries under ``hooks[hook_type]`` whose ``type`` is
    ``"command"`` are executed through ``_execute_command`` with the skill
    folder as the working directory.

    Args:
        hook_type: Hook name (PrePrompt, PreMemoryPrompt, PostAgent, PreSave).
        config: AgentConfig object. Only ``bot_id`` is read here; the object
            is forwarded unchanged to ``_execute_command``.
        **kwargs: Hook-specific arguments, forwarded to ``_execute_command``:
            - PrePrompt / PreMemoryPrompt: no extra arguments
            - PostAgent: response (str), metadata (dict)
            - PreSave: content (str), role (str)

    Returns:
        - PrePrompt / PreMemoryPrompt: str, the non-empty hook outputs joined
          by a blank line ('' when no hook produced output)
        - PreSave: str, the output of the last hook that produced output, or
          the original ``content`` kwarg ('' if absent) when none did
        - any other hook type (e.g. PostAgent): None
    """
    hook_results = []
    bot_id = getattr(config, 'bot_id', '')

    for skill_dir in _get_skill_dirs(bot_id):
        # isdir (not just exists): a plain file here would make listdir raise.
        if not os.path.isdir(skill_dir):
            continue

        # os.listdir() yields entries in arbitrary order. Sorting makes the
        # hook execution order stable, which matters because PreSave keeps
        # only the LAST result and prompt hooks are concatenated in order.
        for skill_name in sorted(os.listdir(skill_dir)):
            skill_path = os.path.join(skill_dir, skill_name)
            if not os.path.isdir(skill_path):
                continue

            plugin_json = os.path.join(skill_path, '.claude-plugin', 'plugin.json')
            if not os.path.exists(plugin_json):
                continue

            try:
                plugin_config = _load_plugin_config(plugin_json)
                # "or" fallbacks: an explicit JSON null for "hooks" or for the
                # hook list means "no hooks", not a malformed plugin.
                hooks = (plugin_config.get('hooks') or {}).get(hook_type) or []

                for hook_config in hooks:
                    if hook_config.get('type') != 'command':
                        continue
                    command = hook_config.get('command')
                    if not command:
                        continue

                    # Run the command inside the skill directory.
                    result = await _execute_command(
                        skill_path, command, hook_type, config, **kwargs
                    )
                    if result:
                        hook_results.append(result)
                        logger.info(f"Executed {hook_type} hook from {skill_name}")

            except Exception as e:
                # Best effort: a malformed plugin must not break the others.
                logger.error(f"Failed to load hooks from {plugin_json}: {e}")

    # Shape the return value according to the hook type.
    if hook_type in ('PrePrompt', 'PreMemoryPrompt'):
        return "\n\n".join(hook_results)
    elif hook_type == 'PreSave':
        # Use the last hook's output if any hook returned content,
        # otherwise fall back to the original content.
        return hook_results[-1] if hook_results else kwargs.get('content', '')
    return None
|
||
|
||
|
||
async def merge_skill_mcp_configs(bot_id: str) -> List[Dict]:
    """Collect and merge ``mcpServers`` from every skill's plugin.json.

    Scans each directory returned by ``_get_skill_dirs(bot_id)``; for every
    sub-folder that has ``.claude-plugin/plugin.json``, its ``mcpServers``
    mapping is merged into a single dict keyed by server name.

    Args:
        bot_id: Bot ID used to locate the skill directories.

    Returns:
        List[Dict]: ``[{"mcpServers": merged}]`` when at least one server was
        found, otherwise an empty list.
    """
    merged_servers = {}

    for skill_dir in _get_skill_dirs(bot_id):
        # isdir (not just exists): a plain file here would make listdir raise.
        if not os.path.isdir(skill_dir):
            continue

        # os.listdir() order is arbitrary; sort so that, when two skills
        # define the same server name, which one wins is deterministic
        # (the alphabetically later skill overrides the earlier one).
        for skill_name in sorted(os.listdir(skill_dir)):
            skill_path = os.path.join(skill_dir, skill_name)
            if not os.path.isdir(skill_path):
                continue

            plugin_json = os.path.join(skill_path, '.claude-plugin', 'plugin.json')
            if not os.path.exists(plugin_json):
                continue

            try:
                with open(plugin_json, 'r', encoding='utf-8') as f:
                    plugin_config = json.load(f)

                servers = plugin_config.get('mcpServers', {})
                if not servers:
                    continue
                if not isinstance(servers, dict):
                    # dict.update() would accept some non-dict iterables and
                    # merge garbage; reject them and let the handler log it.
                    raise TypeError("mcpServers must be a JSON object")

                duplicates = sorted(set(servers) & set(merged_servers))
                if duplicates:
                    logger.warning(
                        f"Skill {skill_name} overrides MCP servers: {duplicates}"
                    )
                merged_servers.update(servers)
                logger.info(f"Loaded MCP config from skill: {skill_name}")
            except Exception as e:
                # Best effort: one broken plugin.json must not hide the rest.
                logger.error(f"Failed to load mcpServers from {skill_name}: {e}")

    if merged_servers:
        return [{"mcpServers": merged_servers}]

    return []
|
||
|
||
|
||
def _load_plugin_config(plugin_json_path: str) -> Dict:
|
||
"""加载 plugin.json 配置"""
|
||
try:
|
||
with open(plugin_json_path, 'r', encoding='utf-8') as f:
|
||
return json.load(f)
|
||
except Exception as e:
|
||
logger.error(f"Failed to load plugin.json: {e}")
|
||
return {}
|
||
|
||
|
||
def _get_skill_dirs(bot_id: str) -> List[str]:
|
||
"""获取需要扫描的skill目录列表"""
|
||
dirs = []
|
||
|
||
# 用户上传的skills目录
|
||
if bot_id:
|
||
robot_skills = f"projects/robot/{bot_id}/skills"
|
||
if os.path.exists(robot_skills):
|
||
dirs.append(robot_skills)
|
||
|
||
return dirs
|
||
|
||
|
||
async def _execute_command(skill_path: str, command: str, hook_type: str, config, **kwargs) -> Optional[str]:
|
||
"""执行 hook 命令
|
||
|
||
Args:
|
||
skill_path: skill 目录路径,作为工作目录
|
||
command: 要执行的命令
|
||
hook_type: hook 类型
|
||
config: AgentConfig 对象
|
||
**kwargs: 额外参数
|
||
|
||
Returns:
|
||
str: 命令的 stdout 输出
|
||
"""
|
||
try:
|
||
# 设置环境变量,传递给子进程
|
||
# 注意:subprocess 要求所有 env 值必须是 str 类型,
|
||
# getattr 可能返回 None(属性存在但值为 None),需要确保转换为 str
|
||
env = os.environ.copy()
|
||
env['ASSISTANT_ID'] = str(getattr(config, 'bot_id', '') or '')
|
||
env['USER_IDENTIFIER'] = str(getattr(config, 'user_identifier', '') or '')
|
||
env['TRACE_ID'] = str(getattr(config, 'trace_id', '') or '')
|
||
env['SESSION_ID'] = str(getattr(config, 'session_id', '') or '')
|
||
env['LANGUAGE'] = str(getattr(config, 'language', '') or '')
|
||
env['HOOK_TYPE'] = hook_type
|
||
|
||
# 合并 config 中的自定义 shell 环境变量
|
||
shell_env = getattr(config, 'shell_env', None)
|
||
if shell_env:
|
||
# 确保所有自定义环境变量值也是字符串
|
||
env.update({k: str(v) if v is not None else '' for k, v in shell_env.items()})
|
||
|
||
# 对于 PreSave,传递 content
|
||
if hook_type == 'PreSave':
|
||
env['CONTENT'] = str(kwargs.get('content', '') or '')
|
||
env['ROLE'] = str(kwargs.get('role', '') or '')
|
||
|
||
# 对于 PostAgent,传递 response
|
||
if hook_type == 'PostAgent':
|
||
env['RESPONSE'] = str(kwargs.get('response', '') or '')
|
||
metadata = kwargs.get('metadata', {})
|
||
env['METADATA'] = json.dumps(metadata) if metadata else ''
|
||
|
||
# 使用 subprocess 执行命令,捕获 stdout
|
||
process = await asyncio.create_subprocess_shell(
|
||
command,
|
||
cwd=skill_path,
|
||
env=env,
|
||
stdout=subprocess.PIPE,
|
||
stderr=subprocess.PIPE
|
||
)
|
||
|
||
stdout, stderr = await process.communicate()
|
||
|
||
if stdout:
|
||
result = stdout.decode('utf-8').strip()
|
||
return result
|
||
|
||
if stderr and process.returncode != 0:
|
||
logger.warning(f"Hook command stderr: {stderr.decode('utf-8')}")
|
||
|
||
return None
|
||
|
||
except Exception as e:
|
||
logger.error(f"Error executing hook command '{command}': {e}")
|
||
return None
|
||
|