"""Outermost middleware that converts tool exceptions into ToolMessage(status="error").

When a tool call raises (most commonly an MCP `ToolException` from
`langchain_mcp_adapters`), LangGraph's default handler re-raises and breaks the
agent stream, which in turn breaks the SSE response to the client. This
middleware sits as the outermost wrapper around every tool call and converts any
caught exception into a ToolMessage so the agent can keep looping and reply to
the user in natural language about what went wrong.

Exceptions that are intentionally NOT converted:

* `asyncio.CancelledError` — task cancellation must propagate.
* LangGraph control-flow signals (subclasses of `langgraph.errors.GraphBubbleUp`,
  e.g. `GraphInterrupt` raised by `interrupt()`) — converting them would break
  human-in-the-loop pauses; LangGraph's own ToolNode re-raises them too.

Metric emission (`ToolMetricsMiddleware`) still observes the inner `raise`
because it sits *inside* this middleware in the chain.
"""

import ast
import asyncio
import logging
import re
from typing import Any, Callable

from langchain.agents.middleware import AgentMiddleware
from langchain.tools.tool_node import ToolCallRequest
from langchain_core.messages import ToolMessage

logger = logging.getLogger("app")

# Matches `text="..."` (or `text='...'`) inside an MCP TextContent repr.
# The `text` group consumes the literal one character -- or one backslash
# escape pair -- at a time, so escaped quotes and a trailing escaped backslash
# (`text='C:\\'`) neither end the match early nor let it run into the next
# TextContent in a list.
_TEXT_CONTENT_PATTERN = re.compile(
    r"TextContent\([^)]*?text="
    r"""(?P<quote>['"])"""
    r"(?P<text>(?:\\.|(?!(?P=quote))[^\\])*)"
    r"(?P=quote)",
    re.DOTALL,
)

# Name of LangGraph's base class for control-flow exceptions that must bubble up.
_GRAPH_BUBBLE_UP_CLASS_NAME = "GraphBubbleUp"


def _is_graph_control_flow(exc: BaseException) -> bool:
    """Return True if `exc` is a LangGraph control-flow signal (a GraphBubbleUp).

    Matched by class name and defining package along the MRO so this module
    does not need a direct `langgraph` import.
    """
    return any(
        cls.__name__ == _GRAPH_BUBBLE_UP_CLASS_NAME
        and cls.__module__.startswith("langgraph")
        for cls in type(exc).__mro__
    )


def _text_contents_from_args(error: BaseException) -> list[str]:
    """Collect non-empty `.text` values from TextContent objects in `error.args`.

    Handles both `Exc(content)` and `Exc([content, ...])`. Objects are
    recognized by class name so this module does not import `mcp` directly.
    """
    texts: list[str] = []
    for arg in error.args:
        items = arg if isinstance(arg, (list, tuple)) else (arg,)
        for item in items:
            if type(item).__name__ != "TextContent":
                continue
            text = getattr(item, "text", None)
            if isinstance(text, str) and text:
                texts.append(text)
    return texts


def _decode_repr_literal(quote: str, body: str) -> str:
    """Turn the inside of a Python string-literal repr back into real text.

    `repr()` escapes newlines, quotes and backslashes; re-wrapping the body in
    its original quote and evaluating it as a literal undoes that. Returns
    `body` unchanged if it is not a valid string literal.
    """
    try:
        value = ast.literal_eval(f"{quote}{body}{quote}")
    except Exception:  # noqa: BLE001 — best effort; the recovery path must never raise
        return body
    return value if isinstance(value, str) else body


# NOTE(review): the original class header was lost from the SOURCE this edit was
# made from; the class name below is reconstructed — confirm it against importers.
class ToolErrorRecoveryMiddleware(AgentMiddleware):
    """Convert exceptions raised by tool calls into error ToolMessages.

    Must be registered as the outermost tool-call middleware so every inner
    wrapper (metrics, retries, ...) still sees the original exception.
    """

    @staticmethod
    def _extract_error_text(error: Exception) -> str:
        """Pull human-readable text out of an exception.

        MCP `ToolException` typically wraps a list of `TextContent` objects, so
        its string form looks like `[TextContent(type='text', text="...", ...)]`.
        Resolution order:

        1. `.text` of TextContent objects passed directly as exception args.
        2. `text=` fields parsed out of TextContent reprs in `str(error)`,
           with repr escapes decoded.
        3. `str(error)` itself, or the exception type name when it is empty.
        """
        texts = _text_contents_from_args(error)
        if texts:
            return "\n".join(texts)

        raw = str(error)
        decoded = (
            _decode_repr_literal(match.group("quote"), match.group("text"))
            for match in _TEXT_CONTENT_PATTERN.finditer(raw)
        )
        texts = [text for text in decoded if text]
        if texts:
            return "\n".join(texts)

        # Exceptions like `TimeoutError()` stringify to "" — give the model
        # something to report instead of an empty reason.
        return raw or type(error).__name__

    def _build_error_message(
        self,
        request: ToolCallRequest,
        error: Exception,
    ) -> ToolMessage:
        """Build the `status="error"` ToolMessage that replaces a failed tool call.

        Args:
            request: The tool call request whose handler raised.
            error: The exception raised by the handler.

        Returns:
            A ToolMessage answering the original tool call id, so the agent's
            message history stays consistent and the model can react to it.
        """
        tool_call = request.tool_call or {}
        tool_name = tool_call.get("name") or "unknown_tool"
        # ToolMessage requires a string id; an empty id is better than raising here.
        tool_call_id = tool_call.get("id") or ""
        error_text = self._extract_error_text(error)
        content = f"Tool '{tool_name}' failed: {error_text}"
        logger.warning(
            "Tool error recovered as ToolMessage: tool_name=%s error_type=%s",
            tool_name,
            type(error).__name__,
        )
        # The exception is swallowed below; keep its traceback reachable for debugging.
        logger.debug(
            "Recovered tool exception traceback: tool_name=%s",
            tool_name,
            exc_info=error,
        )
        return ToolMessage(
            content=content,
            tool_call_id=tool_call_id,
            name=tool_name,
            status="error",
        )

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Any],
    ) -> Any:
        """Run `handler` and convert any ordinary exception into an error ToolMessage.

        LangGraph control-flow signals (GraphBubbleUp subclasses) are re-raised.
        """
        try:
            return handler(request)
        except Exception as exc:  # noqa: BLE001 — outermost recovery
            if _is_graph_control_flow(exc):
                raise
            return self._build_error_message(request, exc)

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Any],
    ) -> Any:
        """Async variant of `wrap_tool_call`.

        Re-raises cancellation and LangGraph control-flow signals; converts
        every other exception into an error ToolMessage.
        """
        try:
            return await handler(request)
        except asyncio.CancelledError:
            # Cancellation must propagate so the agent task can shut down cleanly.
            # (CancelledError is a BaseException on 3.8+; kept explicit for intent.)
            raise
        except Exception as exc:  # noqa: BLE001 — outermost recovery
            if _is_graph_control_flow(exc):
                raise
            return self._build_error_message(request, exc)