add subagent_context_middleware
This commit is contained in:
parent
e2827c6a47
commit
73042c57a6
95
agent/subagent_context_middleware.py
Normal file
95
agent/subagent_context_middleware.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
"""Middleware that tags logs with the currently executing subagent name.
|
||||||
|
|
||||||
|
Each subagent receives its own instance of this middleware (carrying its name).
|
||||||
|
The middleware writes the name into the request-scoped GlobalContext (`g.subagent`)
|
||||||
|
for the duration of every model call and tool call, so the log Formatter can render
|
||||||
|
which subagent produced each log line. The previous value is restored afterwards so
|
||||||
|
that nested/parallel subagents and the main agent are not affected.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Awaitable, Callable
|
||||||
|
|
||||||
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||||
|
from langchain.tools.tool_node import ToolCallRequest
|
||||||
|
|
||||||
|
from utils.log_util.context import g
|
||||||
|
|
||||||
|
logger = logging.getLogger("app")
|
||||||
|
|
||||||
|
# Context key consumed by utils/log_util/logger.py Formatter.
|
||||||
|
_SUBAGENT_KEY = "subagent"
|
||||||
|
|
||||||
|
|
||||||
|
class SubagentContextMiddleware(AgentMiddleware):
|
||||||
|
"""Set `g.subagent` while this subagent's model/tool calls execute."""
|
||||||
|
|
||||||
|
def __init__(self, subagent_name: str) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._subagent_name = subagent_name
|
||||||
|
|
||||||
|
def _enter(self) -> dict:
|
||||||
|
# Shallow-copy the whole context dict and rebind a PRIVATE copy for this
|
||||||
|
# context. This is load-bearing: GlobalContext mutates a shared dict in
|
||||||
|
# place, and asyncio task copies share that reference, so a plain
|
||||||
|
# `g.subagent = name` would leak across parallel sibling subagents and
|
||||||
|
# race on restore. Replacing the reference isolates each context.
|
||||||
|
try:
|
||||||
|
prev = dict(g.get_context())
|
||||||
|
except LookupError:
|
||||||
|
prev = {}
|
||||||
|
new_ctx = dict(prev)
|
||||||
|
new_ctx[_SUBAGENT_KEY] = self._subagent_name
|
||||||
|
g.update_context(new_ctx)
|
||||||
|
return prev
|
||||||
|
|
||||||
|
def _exit(self, prev: dict) -> None:
|
||||||
|
# Restore by rebinding the previous dict (also a private copy).
|
||||||
|
g.update_context(dict(prev))
|
||||||
|
|
||||||
|
# ----- model call -----
|
||||||
|
def wrap_model_call(
|
||||||
|
self,
|
||||||
|
request: ModelRequest,
|
||||||
|
handler: Callable[[ModelRequest], ModelResponse],
|
||||||
|
) -> ModelResponse:
|
||||||
|
prev = self._enter()
|
||||||
|
try:
|
||||||
|
return handler(request)
|
||||||
|
finally:
|
||||||
|
self._exit(prev)
|
||||||
|
|
||||||
|
async def awrap_model_call(
|
||||||
|
self,
|
||||||
|
request: ModelRequest,
|
||||||
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||||
|
) -> ModelResponse:
|
||||||
|
prev = self._enter()
|
||||||
|
try:
|
||||||
|
return await handler(request)
|
||||||
|
finally:
|
||||||
|
self._exit(prev)
|
||||||
|
|
||||||
|
# ----- tool call -----
|
||||||
|
def wrap_tool_call(
|
||||||
|
self,
|
||||||
|
request: ToolCallRequest,
|
||||||
|
handler: Callable[[ToolCallRequest], Any],
|
||||||
|
) -> Any:
|
||||||
|
prev = self._enter()
|
||||||
|
try:
|
||||||
|
return handler(request)
|
||||||
|
finally:
|
||||||
|
self._exit(prev)
|
||||||
|
|
||||||
|
async def awrap_tool_call(
|
||||||
|
self,
|
||||||
|
request: ToolCallRequest,
|
||||||
|
handler: Callable[[ToolCallRequest], Awaitable[Any]],
|
||||||
|
) -> Any:
|
||||||
|
prev = self._enter()
|
||||||
|
try:
|
||||||
|
return await handler(request)
|
||||||
|
finally:
|
||||||
|
self._exit(prev)
|
||||||
@ -25,6 +25,7 @@ from langchain.tools import BaseTool
|
|||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
from agent.plugin_hook_loader import _get_skill_dirs
|
from agent.plugin_hook_loader import _get_skill_dirs
|
||||||
|
from agent.subagent_context_middleware import SubagentContextMiddleware
|
||||||
|
|
||||||
logger = logging.getLogger('app')
|
logger = logging.getLogger('app')
|
||||||
|
|
||||||
@ -181,6 +182,8 @@ async def load_subagents(
|
|||||||
"system_prompt": parsed["system_prompt"],
|
"system_prompt": parsed["system_prompt"],
|
||||||
"model": model,
|
"model": model,
|
||||||
"tools": filtered_tools,
|
"tools": filtered_tools,
|
||||||
|
# Tag this subagent's model/tool logs with its name.
|
||||||
|
"middleware": [SubagentContextMiddleware(name)],
|
||||||
}
|
}
|
||||||
subagents.append(subagent)
|
subagents.append(subagent)
|
||||||
logger.info(f"Loaded sub-agent '{name}' with {len(filtered_tools)} tools from {parsed['source']}")
|
logger.info(f"Loaded sub-agent '{name}' with {len(filtered_tools)} tools from {parsed['source']}")
|
||||||
|
|||||||
@ -51,6 +51,13 @@ class Formatter(logging.Formatter):
|
|||||||
record.trace_id = getattr(g, "trace_id")
|
record.trace_id = getattr(g, "trace_id")
|
||||||
except LookupError:
|
except LookupError:
|
||||||
record.trace_id = "N/A"
|
record.trace_id = "N/A"
|
||||||
|
# Handle subagent - default to "main" for the orchestrator / no-context paths.
|
||||||
|
# Catch KeyError too: GlobalContext.__getattr__ raises KeyError on a missing key.
|
||||||
|
if not hasattr(record, "subagent"):
|
||||||
|
try:
|
||||||
|
record.subagent = getattr(g, "subagent")
|
||||||
|
except (KeyError, LookupError):
|
||||||
|
record.subagent = "main"
|
||||||
# Handle user_id
|
# Handle user_id
|
||||||
# if not hasattr(record, "user_id"):
|
# if not hasattr(record, "user_id"):
|
||||||
# record.user_id = getattr(g, "user_id")
|
# record.user_id = getattr(g, "user_id")
|
||||||
@ -65,7 +72,7 @@ class Formatter(logging.Formatter):
|
|||||||
def init_logger_once(name,level):
|
def init_logger_once(name,level):
|
||||||
logger = logging.getLogger(name)
|
logger = logging.getLogger(name)
|
||||||
logger.setLevel(level=level)
|
logger.setLevel(level=level)
|
||||||
formatter = Formatter("%(timestamp)s | %(levelname)-5s | %(trace_id)s | %(name)s:%(funcName)s:%(lineno)s - %(message)s", datefmt='%Y-%m-%d %H:%M:%S.%f')
|
formatter = Formatter("%(timestamp)s | %(levelname)-5s | %(trace_id)s | %(subagent)s | %(name)s:%(funcName)s:%(lineno)s - %(message)s", datefmt='%Y-%m-%d %H:%M:%S.%f')
|
||||||
handler = logging.StreamHandler()
|
handler = logging.StreamHandler()
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user