diff --git a/agent/subagent_context_middleware.py b/agent/subagent_context_middleware.py new file mode 100644 index 0000000..98d1422 --- /dev/null +++ b/agent/subagent_context_middleware.py @@ -0,0 +1,95 @@ +"""Middleware that tags logs with the currently executing subagent name. + +Each subagent receives its own instance of this middleware (carrying its name). +The middleware writes the name into the request-scoped GlobalContext (`g.subagent`) +for the duration of every model call and tool call, so the log Formatter can render +which subagent produced each log line. The previous value is restored afterwards so +that nested/parallel subagents and the main agent are not affected. +""" + +import logging +from typing import Any, Awaitable, Callable + +from langchain.agents.middleware import AgentMiddleware +from langchain.agents.middleware.types import ModelRequest, ModelResponse +from langchain.tools.tool_node import ToolCallRequest + +from utils.log_util.context import g + +logger = logging.getLogger("app") + +# Context key consumed by utils/log_util/logger.py Formatter. +_SUBAGENT_KEY = "subagent" + + +class SubagentContextMiddleware(AgentMiddleware): + """Set `g.subagent` while this subagent's model/tool calls execute.""" + + def __init__(self, subagent_name: str) -> None: + super().__init__() + self._subagent_name = subagent_name + + def _enter(self) -> dict: + # Shallow-copy the whole context dict and rebind a PRIVATE copy for this + # context. This is load-bearing: GlobalContext mutates a shared dict in + # place, and asyncio task copies share that reference, so a plain + # `g.subagent = name` would leak across parallel sibling subagents and + # race on restore. Replacing the reference isolates each context. + try: + prev = dict(g.get_context()) + except LookupError: + prev = {} + new_ctx = dict(prev) + new_ctx[_SUBAGENT_KEY] = self._subagent_name + g.update_context(new_ctx) + return prev + + def _exit(self, prev: dict) -> None: + # Restore by rebinding the previous dict (also a private copy). + g.update_context(dict(prev)) + + # ----- model call ----- + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + prev = self._enter() + try: + return handler(request) + finally: + self._exit(prev) + + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelResponse: + prev = self._enter() + try: + return await handler(request) + finally: + self._exit(prev) + + # ----- tool call ----- + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Any], + ) -> Any: + prev = self._enter() + try: + return handler(request) + finally: + self._exit(prev) + + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[Any]], + ) -> Any: + prev = self._enter() + try: + return await handler(request) + finally: + self._exit(prev) diff --git a/agent/subagent_loader.py b/agent/subagent_loader.py index ed18fb9..f27d051 100644 --- a/agent/subagent_loader.py +++ b/agent/subagent_loader.py @@ -25,6 +25,7 @@ from langchain.tools import BaseTool from langchain_core.language_models import BaseChatModel from agent.plugin_hook_loader import _get_skill_dirs +from agent.subagent_context_middleware import SubagentContextMiddleware logger = logging.getLogger('app') @@ -181,6 +182,8 @@ async def load_subagents( "system_prompt": parsed["system_prompt"], "model": model, "tools": filtered_tools, + # Tag this subagent's model/tool logs with its name. + "middleware": [SubagentContextMiddleware(name)], } subagents.append(subagent) logger.info(f"Loaded sub-agent '{name}' with {len(filtered_tools)} tools from {parsed['source']}") diff --git a/utils/log_util/logger.py b/utils/log_util/logger.py index b6e98ba..a88f5ae 100644 --- a/utils/log_util/logger.py +++ b/utils/log_util/logger.py @@ -51,6 +51,13 @@ class Formatter(logging.Formatter): record.trace_id = getattr(g, "trace_id") except LookupError: record.trace_id = "N/A" + # Handle subagent - default to "main" for the orchestrator / no-context paths. + # Catch KeyError too: GlobalContext.__getattr__ raises KeyError on a missing key. + if not hasattr(record, "subagent"): + try: + record.subagent = getattr(g, "subagent") + except (KeyError, LookupError): + record.subagent = "main" # Handle user_id # if not hasattr(record, "user_id"): # record.user_id = getattr(g, "user_id") @@ -65,7 +72,7 @@ class Formatter(logging.Formatter): def init_logger_once(name,level): logger = logging.getLogger(name) logger.setLevel(level=level) - formatter = Formatter("%(timestamp)s | %(levelname)-5s | %(trace_id)s | %(name)s:%(funcName)s:%(lineno)s - %(message)s", datefmt='%Y-%m-%d %H:%M:%S.%f') + formatter = Formatter("%(timestamp)s | %(levelname)-5s | %(trace_id)s | %(subagent)s | %(name)s:%(funcName)s:%(lineno)s - %(message)s", datefmt='%Y-%m-%d %H:%M:%S.%f') handler = logging.StreamHandler() handler.setFormatter(formatter) logger.addHandler(handler)