96 lines
3.1 KiB
Python
96 lines
3.1 KiB
Python
"""Middleware that tags logs with the name of the currently executing subagent.

Each subagent receives its own instance of this middleware (carrying its name).
The middleware writes the name into the request-scoped GlobalContext (`g.subagent`)
for the duration of every model call and tool call, so the log Formatter can render
which subagent produced each log line. The previous context is restored afterwards
so that nested or parallel subagents and the main agent are not affected.
"""
|
|
|
|
import logging
|
|
from typing import Any, Awaitable, Callable
|
|
|
|
from langchain.agents.middleware import AgentMiddleware
|
|
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
|
from langchain.tools.tool_node import ToolCallRequest
|
|
|
|
from utils.log_util.context import g
|
|
|
|
# Logger obtained under the fixed name "app" (not this module's __name__).
logger: logging.Logger = logging.getLogger("app")

# Context key under which the subagent name is stored; it is consumed by the
# Formatter in utils/log_util/logger.py.
_SUBAGENT_KEY: str = "subagent"
|
|
|
|
|
|
class SubagentContextMiddleware(AgentMiddleware):
    """Expose the owning subagent's name as ``g.subagent`` during its calls.

    One instance is created per subagent. Every model call and tool call that
    passes through this middleware runs with the subagent name stored in the
    context under ``_SUBAGENT_KEY``. When the call returns or raises, the
    context snapshot taken beforehand is bound again.
    """

    def __init__(self, subagent_name: str) -> None:
        super().__init__()
        # Name written into the context for the duration of each wrapped call.
        self._subagent_name = subagent_name

    def _enter(self) -> dict:
        """Bind a private, name-tagged copy of the current context.

        Returns:
            A copy of the context that was current on entry, or ``{}`` when
            ``g.get_context()`` raised ``LookupError``. Hand it to ``_exit``.
        """
        # A whole new dict is bound instead of writing ``g.subagent = name``.
        # GlobalContext mutates its dict in place, and asyncio task copies share
        # that same reference, so an in-place write would leak across parallel
        # sibling subagents and race on restore. Rebinding isolates each context.
        try:
            snapshot = dict(g.get_context())
        except LookupError:
            snapshot = {}
        g.update_context({**snapshot, _SUBAGENT_KEY: self._subagent_name})
        return snapshot

    def _exit(self, prev: dict) -> None:
        """Bind a fresh copy of the snapshot previously returned by ``_enter``."""
        g.update_context(dict(prev))

    def _run_tagged(self, handler: Callable[[Any], Any], request: Any) -> Any:
        """Invoke a synchronous handler while the subagent tag is installed."""
        saved = self._enter()
        try:
            return handler(request)
        finally:
            # Runs on success and on exception, so the tag never outlives the call.
            self._exit(saved)

    async def _arun_tagged(
        self,
        handler: Callable[[Any], Awaitable[Any]],
        request: Any,
    ) -> Any:
        """Await an asynchronous handler while the subagent tag is installed."""
        saved = self._enter()
        try:
            return await handler(request)
        finally:
            self._exit(saved)

    # ----- model call -----
    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        """Run a synchronous model call tagged with this subagent's name."""
        return self._run_tagged(handler, request)

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        """Run an asynchronous model call tagged with this subagent's name."""
        return await self._arun_tagged(handler, request)

    # ----- tool call -----
    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Any],
    ) -> Any:
        """Run a synchronous tool call tagged with this subagent's name."""
        return self._run_tagged(handler, request)

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[Any]],
    ) -> Any:
        """Run an asynchronous tool call tagged with this subagent's name."""
        return await self._arun_tagged(handler, request)
|