add subagent_context_middleware
This commit is contained in:
parent
e2827c6a47
commit
73042c57a6
95
agent/subagent_context_middleware.py
Normal file
95
agent/subagent_context_middleware.py
Normal file
@ -0,0 +1,95 @@
|
||||
"""Middleware that tags logs with the currently executing subagent name.
|
||||
|
||||
Each subagent receives its own instance of this middleware (carrying its name).
|
||||
The middleware writes the name into the request-scoped GlobalContext (`g.subagent`)
|
||||
for the duration of every model call and tool call, so the log Formatter can render
|
||||
which subagent produced each log line. The previous value is restored afterwards so
|
||||
that nested/parallel subagents and the main agent are not affected.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain.tools.tool_node import ToolCallRequest
|
||||
|
||||
from utils.log_util.context import g
|
||||
|
||||
logger = logging.getLogger("app")
|
||||
|
||||
# Context key consumed by utils/log_util/logger.py Formatter.
|
||||
_SUBAGENT_KEY = "subagent"
|
||||
|
||||
|
||||
class SubagentContextMiddleware(AgentMiddleware):
|
||||
"""Set `g.subagent` while this subagent's model/tool calls execute."""
|
||||
|
||||
def __init__(self, subagent_name: str) -> None:
|
||||
super().__init__()
|
||||
self._subagent_name = subagent_name
|
||||
|
||||
def _enter(self) -> dict:
|
||||
# Shallow-copy the whole context dict and rebind a PRIVATE copy for this
|
||||
# context. This is load-bearing: GlobalContext mutates a shared dict in
|
||||
# place, and asyncio task copies share that reference, so a plain
|
||||
# `g.subagent = name` would leak across parallel sibling subagents and
|
||||
# race on restore. Replacing the reference isolates each context.
|
||||
try:
|
||||
prev = dict(g.get_context())
|
||||
except LookupError:
|
||||
prev = {}
|
||||
new_ctx = dict(prev)
|
||||
new_ctx[_SUBAGENT_KEY] = self._subagent_name
|
||||
g.update_context(new_ctx)
|
||||
return prev
|
||||
|
||||
def _exit(self, prev: dict) -> None:
|
||||
# Restore by rebinding the previous dict (also a private copy).
|
||||
g.update_context(dict(prev))
|
||||
|
||||
# ----- model call -----
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
prev = self._enter()
|
||||
try:
|
||||
return handler(request)
|
||||
finally:
|
||||
self._exit(prev)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
prev = self._enter()
|
||||
try:
|
||||
return await handler(request)
|
||||
finally:
|
||||
self._exit(prev)
|
||||
|
||||
# ----- tool call -----
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Any],
|
||||
) -> Any:
|
||||
prev = self._enter()
|
||||
try:
|
||||
return handler(request)
|
||||
finally:
|
||||
self._exit(prev)
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[Any]],
|
||||
) -> Any:
|
||||
prev = self._enter()
|
||||
try:
|
||||
return await handler(request)
|
||||
finally:
|
||||
self._exit(prev)
|
||||
@ -25,6 +25,7 @@ from langchain.tools import BaseTool
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from agent.plugin_hook_loader import _get_skill_dirs
|
||||
from agent.subagent_context_middleware import SubagentContextMiddleware
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
@ -181,6 +182,8 @@ async def load_subagents(
|
||||
"system_prompt": parsed["system_prompt"],
|
||||
"model": model,
|
||||
"tools": filtered_tools,
|
||||
# Tag this subagent's model/tool logs with its name.
|
||||
"middleware": [SubagentContextMiddleware(name)],
|
||||
}
|
||||
subagents.append(subagent)
|
||||
logger.info(f"Loaded sub-agent '{name}' with {len(filtered_tools)} tools from {parsed['source']}")
|
||||
|
||||
@ -51,6 +51,13 @@ class Formatter(logging.Formatter):
|
||||
record.trace_id = getattr(g, "trace_id")
|
||||
except LookupError:
|
||||
record.trace_id = "N/A"
|
||||
# Handle subagent - default to "main" for the orchestrator / no-context paths.
|
||||
# Catch KeyError too: GlobalContext.__getattr__ raises KeyError on a missing key.
|
||||
if not hasattr(record, "subagent"):
|
||||
try:
|
||||
record.subagent = getattr(g, "subagent")
|
||||
except (KeyError, LookupError):
|
||||
record.subagent = "main"
|
||||
# Handle user_id
|
||||
# if not hasattr(record, "user_id"):
|
||||
# record.user_id = getattr(g, "user_id")
|
||||
@ -65,7 +72,7 @@ class Formatter(logging.Formatter):
|
||||
def init_logger_once(name,level):
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(level=level)
|
||||
formatter = Formatter("%(timestamp)s | %(levelname)-5s | %(trace_id)s | %(name)s:%(funcName)s:%(lineno)s - %(message)s", datefmt='%Y-%m-%d %H:%M:%S.%f')
|
||||
formatter = Formatter("%(timestamp)s | %(levelname)-5s | %(trace_id)s | %(subagent)s | %(name)s:%(funcName)s:%(lineno)s - %(message)s", datefmt='%Y-%m-%d %H:%M:%S.%f')
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user