Add ToolErrorRecoveryMiddleware as the outermost agent middleware so any tool-call exception (notably MCP ToolException) is converted into a ToolMessage with status="error" carrying the raw error text. The agent can then loop once more and reply to the user in natural language about what failed, instead of bubbling the exception up through agent.astream and breaking the SSE response in routes/chat.py. The recovery layer extracts the inner `text="..."` payload out of the MCP TextContent repr when present, falling back to str(error) otherwise. It deliberately re-raises asyncio.CancelledError so task cancellation still propagates, and sits *outside* ToolMetricsMiddleware so the existing status=error metric is still emitted before recovery kicks in. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
99 lines
3.6 KiB
Python
99 lines
3.6 KiB
Python
"""Outermost middleware that converts tool exceptions into ToolMessage(status="error").
|
|
|
|
When a tool call raises (most commonly an MCP `ToolException` from
|
|
`langchain_mcp_adapters`), LangGraph's default handler re-raises and breaks the
|
|
agent stream, which in turn breaks the SSE response to the client. This
|
|
middleware sits as the outermost wrapper around every tool call and converts any
|
|
caught exception into a ToolMessage so the agent can keep looping and reply to
|
|
the user in natural language about what went wrong.
|
|
|
|
`asyncio.CancelledError` is intentionally not caught — task cancellation must
|
|
propagate. Metric emission (`ToolMetricsMiddleware`) still observes the inner
|
|
`raise` because it sits *inside* this middleware in the chain.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import re
|
|
from typing import Any, Callable
|
|
|
|
from langchain.agents.middleware import AgentMiddleware
|
|
from langchain.tools.tool_node import ToolCallRequest
|
|
from langchain_core.messages import ToolMessage
|
|
|
|
logger = logging.getLogger("app")
|
|
|
|
# Matches `text="..."` (or `text='...'`) inside MCP TextContent repr. Non-greedy
|
|
# so each TextContent in a list is captured separately.
|
|
_TEXT_CONTENT_PATTERN = re.compile(
|
|
r"""TextContent\([^)]*?text=(?P<quote>['"])(?P<text>.*?)(?<!\\)(?P=quote)""",
|
|
re.DOTALL,
|
|
)
|
|
|
|
|
|
class ToolErrorRecoveryMiddleware(AgentMiddleware):
    """Catch tool-call exceptions and return them as error ToolMessages.

    Meant to be registered as the outermost tool-call wrapper. Any ``Exception``
    raised further down the chain is turned into a
    ``ToolMessage(status="error")`` so the agent can keep looping, instead of
    the exception propagating out of the agent run.

    Two kinds of exception are deliberately NOT converted:

    * ``asyncio.CancelledError``: task cancellation must propagate.
    * LangGraph control-flow signals, i.e. subclasses of ``GraphBubbleUp``
      such as ``GraphInterrupt`` (raised by ``interrupt()``) and
      ``ParentCommand``. They subclass ``Exception`` but are not failures.
      Swallowing them would silently break human-in-the-loop interrupts and
      parent-graph commands.
    """

    # Class names that, anywhere in an exception's MRO, mark a LangGraph
    # control-flow signal that must be re-raised. Matched by name so this
    # module does not need a direct dependency on ``langgraph.errors``.
    _PASSTHROUGH_EXCEPTION_NAMES = frozenset({"GraphBubbleUp"})

    @classmethod
    def _should_propagate(cls, error: BaseException) -> bool:
        """Return True if *error* is a LangGraph control-flow signal to re-raise."""
        return any(
            klass.__name__ in cls._PASSTHROUGH_EXCEPTION_NAMES
            for klass in type(error).__mro__
        )

    def _extract_error_text(self, error: Exception) -> str:
        """Pull human-readable text out of an exception.

        MCP ``ToolException`` may wrap a list of ``TextContent`` objects, so its
        string form looks like ``[TextContent(type='text', text="...", ...)]``.
        When that wrapper is present, keep just the non-empty inner ``text``
        fields (joined by newlines). Each field's Python-repr escapes (escaped
        newlines, quotes, backslashes) are decoded back to the original
        characters. Falls back to ``str(error)`` when no wrapper is found or
        every captured ``text`` is empty.
        """
        import ast  # local: only needed on this error path

        raw = str(error)
        pieces = []
        for quote, text in _TEXT_CONTENT_PATTERN.findall(raw):
            if not text:
                continue
            try:
                # Rebuild the repr literal and decode it back to the real string.
                decoded = ast.literal_eval(f"{quote}{text}{quote}")
            except (ValueError, SyntaxError):
                decoded = text  # keep the escaped form rather than drop it
            if not isinstance(decoded, str):
                decoded = text
            pieces.append(decoded)
        return "\n".join(pieces) or raw

    def _build_error_message(
        self,
        request: ToolCallRequest,
        error: Exception,
    ) -> ToolMessage:
        """Wrap *error* in a ``ToolMessage(status="error")`` for the agent.

        Reuses the originating tool call's ``name`` and ``id`` so the agent can
        associate the message with that call. Missing values fall back to
        ``"unknown_tool"`` and ``""``. Only the tool name and exception type are
        logged, not the error text.
        """
        tool_call = request.tool_call or {}
        tool_name = tool_call.get("name") or "unknown_tool"
        tool_call_id = tool_call.get("id") or ""

        error_text = self._extract_error_text(error)
        content = f"Tool '{tool_name}' failed: {error_text}"

        logger.warning(
            "Tool error recovered as ToolMessage: tool_name=%s error_type=%s",
            tool_name,
            type(error).__name__,
        )

        return ToolMessage(
            content=content,
            tool_call_id=tool_call_id,
            name=tool_name,
            status="error",
        )

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Any],
    ) -> Any:
        """Run *handler* synchronously, converting failures to an error ToolMessage."""
        try:
            return handler(request)
        except Exception as exc:  # noqa: BLE001 — outermost recovery
            if self._should_propagate(exc):
                # LangGraph interrupt/command signal, not a failure.
                raise
            return self._build_error_message(request, exc)

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Any],
    ) -> Any:
        """Await *handler*, converting failures to an error ToolMessage."""
        try:
            return await handler(request)
        except asyncio.CancelledError:
            # Cancellation must propagate so the agent task can shut down cleanly.
            raise
        except Exception as exc:  # noqa: BLE001 — outermost recovery
            if self._should_propagate(exc):
                # LangGraph interrupt/command signal, not a failure.
                raise
            return self._build_error_message(request, exc)
|