"""Middleware that tags logs with the currently executing subagent name.

Each subagent receives its own instance of this middleware (carrying its
name). The middleware writes the name into the request-scoped GlobalContext
(`g.subagent`) for the duration of every model call and tool call, so the log
Formatter can render which subagent produced each log line. The previous
context is restored afterwards so that nested/parallel subagents and the main
agent are not affected.
"""

import logging
from contextlib import contextmanager
from typing import Any, Awaitable, Callable, Iterator

from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain.tools.tool_node import ToolCallRequest

from utils.log_util.context import g

logger = logging.getLogger("app")

# Context key consumed by utils/log_util/logger.py Formatter.
_SUBAGENT_KEY = "subagent"


class SubagentContextMiddleware(AgentMiddleware):
    """Set `g.subagent` while this subagent's model/tool calls execute.

    Args:
        subagent_name: Value stored under the ``"subagent"`` context key for
            the duration of each wrapped model or tool call.
    """

    def __init__(self, subagent_name: str) -> None:
        super().__init__()
        self._subagent_name = subagent_name

    def _enter(self) -> dict[str, Any]:
        """Bind a private context dict that carries this subagent's name.

        Returns:
            The context object that was bound before this call (the same
            object, not a copy), or a fresh empty dict when no context was
            bound. Pass it to `_exit` to restore the previous state.
        """
        # Rebinding a PRIVATE copy for the wrapped call is load-bearing:
        # GlobalContext mutates a shared dict in place, and asyncio task
        # copies share that reference, so a plain `g.subagent = name` would
        # leak across parallel sibling subagents and race on restore.
        # Replacing the reference isolates each context.
        try:
            prev = g.get_context()
        except LookupError:
            prev = {}
        new_ctx = dict(prev)
        new_ctx[_SUBAGENT_KEY] = self._subagent_name
        g.update_context(new_ctx)
        return prev

    def _exit(self, prev: dict[str, Any]) -> None:
        """Rebind the context object returned by the matching `_enter`."""
        # Restore the original object itself rather than a copy. Rebinding a
        # copy would silently detach the caller (e.g. the main agent) from the
        # dict it shares with its parent, so its later in-place writes to `g`
        # would no longer be visible there. The dict mutated during the
        # wrapped call was `_enter`'s private copy, so `prev` is untouched by it.
        g.update_context(prev)

    @contextmanager
    def _subagent_scope(self) -> Iterator[None]:
        """Tag the context with this subagent's name for the `with` body.

        The previous context is restored even if the body raises or (in the
        async wrappers) the awaiting task is cancelled.
        """
        prev = self._enter()
        try:
            yield
        finally:
            self._exit(prev)

    # ----- model call -----

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        """Run the model call with `g.subagent` set to this subagent's name."""
        with self._subagent_scope():
            return handler(request)

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        """Async variant of `wrap_model_call`."""
        with self._subagent_scope():
            return await handler(request)

    # ----- tool call -----

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Any],
    ) -> Any:
        """Run the tool call with `g.subagent` set to this subagent's name."""
        with self._subagent_scope():
            return handler(request)

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[Any]],
    ) -> Any:
        """Async variant of `wrap_tool_call`."""
        with self._subagent_scope():
            return await handler(request)