add subagent_context_middleware

This commit is contained in:
朱潮 2026-06-12 15:56:26 +08:00
parent e2827c6a47
commit 73042c57a6
3 changed files with 106 additions and 1 deletions

View File

@ -0,0 +1,95 @@
"""Middleware that tags logs with the currently executing subagent name.
Each subagent receives its own instance of this middleware (carrying its name).
The middleware writes the name into the request-scoped GlobalContext (`g.subagent`)
for the duration of every model call and tool call, so the log Formatter can render
which subagent produced each log line. The previous value is restored afterwards so
that nested/parallel subagents and the main agent are not affected.
"""
import logging
from typing import Any, Awaitable, Callable
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain.tools.tool_node import ToolCallRequest
from utils.log_util.context import g
logger = logging.getLogger("app")
# Context key consumed by utils/log_util/logger.py Formatter.
_SUBAGENT_KEY = "subagent"
class SubagentContextMiddleware(AgentMiddleware):
"""Set `g.subagent` while this subagent's model/tool calls execute."""
def __init__(self, subagent_name: str) -> None:
super().__init__()
self._subagent_name = subagent_name
def _enter(self) -> dict:
# Shallow-copy the whole context dict and rebind a PRIVATE copy for this
# context. This is load-bearing: GlobalContext mutates a shared dict in
# place, and asyncio task copies share that reference, so a plain
# `g.subagent = name` would leak across parallel sibling subagents and
# race on restore. Replacing the reference isolates each context.
try:
prev = dict(g.get_context())
except LookupError:
prev = {}
new_ctx = dict(prev)
new_ctx[_SUBAGENT_KEY] = self._subagent_name
g.update_context(new_ctx)
return prev
def _exit(self, prev: dict) -> None:
# Restore by rebinding the previous dict (also a private copy).
g.update_context(dict(prev))
# ----- model call -----
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
prev = self._enter()
try:
return handler(request)
finally:
self._exit(prev)
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse:
prev = self._enter()
try:
return await handler(request)
finally:
self._exit(prev)
# ----- tool call -----
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Any],
) -> Any:
prev = self._enter()
try:
return handler(request)
finally:
self._exit(prev)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[Any]],
) -> Any:
prev = self._enter()
try:
return await handler(request)
finally:
self._exit(prev)

View File

@ -25,6 +25,7 @@ from langchain.tools import BaseTool
from langchain_core.language_models import BaseChatModel
from agent.plugin_hook_loader import _get_skill_dirs
from agent.subagent_context_middleware import SubagentContextMiddleware
logger = logging.getLogger('app')
@ -181,6 +182,8 @@ async def load_subagents(
"system_prompt": parsed["system_prompt"],
"model": model,
"tools": filtered_tools,
# Tag this subagent's model/tool logs with its name.
"middleware": [SubagentContextMiddleware(name)],
}
subagents.append(subagent)
logger.info(f"Loaded sub-agent '{name}' with {len(filtered_tools)} tools from {parsed['source']}")

View File

@ -51,6 +51,13 @@ class Formatter(logging.Formatter):
record.trace_id = getattr(g, "trace_id")
except LookupError:
record.trace_id = "N/A"
# Handle subagent - default to "main" for the orchestrator / no-context paths.
# Catch KeyError too: GlobalContext.__getattr__ raises KeyError on a missing key.
if not hasattr(record, "subagent"):
try:
record.subagent = getattr(g, "subagent")
except (KeyError, LookupError):
record.subagent = "main"
# Handle user_id
# if not hasattr(record, "user_id"):
# record.user_id = getattr(g, "user_id")
@ -65,7 +72,7 @@ class Formatter(logging.Formatter):
def init_logger_once(name,level):
logger = logging.getLogger(name)
logger.setLevel(level=level)
formatter = Formatter("%(timestamp)s | %(levelname)-5s | %(trace_id)s | %(name)s:%(funcName)s:%(lineno)s - %(message)s", datefmt='%Y-%m-%d %H:%M:%S.%f')
formatter = Formatter("%(timestamp)s | %(levelname)-5s | %(trace_id)s | %(subagent)s | %(name)s:%(funcName)s:%(lineno)s - %(message)s", datefmt='%Y-%m-%d %H:%M:%S.%f')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)