qwen_agent/agent/tool_error_recovery_middleware.py
朱潮 06974e9744 feat: recover tool exceptions into ToolMessage so SSE stream keeps flowing
Add ToolErrorRecoveryMiddleware as the outermost agent middleware so any
tool-call exception (notably MCP ToolException) is converted into a
ToolMessage with status="error" carrying the raw error text. The agent
can then loop once more and reply to the user in natural language about
what failed, instead of bubbling the exception up through agent.astream
and breaking the SSE response in routes/chat.py.

The recovery layer extracts the inner `text="..."` payload out of the MCP
TextContent repr when present, falling back to str(error) otherwise. It
deliberately re-raises asyncio.CancelledError so task cancellation still
propagates, and sits *outside* ToolMetricsMiddleware so the existing
status=error metric is still emitted before recovery kicks in.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-20 13:07:19 +08:00

99 lines
3.6 KiB
Python

"""Outermost middleware that converts tool exceptions into ToolMessage(status="error").

When a tool call raises (most commonly an MCP `ToolException` from
`langchain_mcp_adapters`), LangGraph's default handler re-raises and breaks the
agent stream, which in turn breaks the SSE response to the client. This
middleware sits as the outermost wrapper around every tool call and converts any
caught exception into a ToolMessage so the agent can keep looping and reply to
the user in natural language about what went wrong.

`asyncio.CancelledError` is intentionally not caught, because task cancellation
must propagate. Metric emission (`ToolMetricsMiddleware`) still observes the
inner `raise` because it sits *inside* this middleware in the chain.
"""

import asyncio
import logging
import re
from typing import Any, Callable

from langchain.agents.middleware import AgentMiddleware
from langchain.tools.tool_node import ToolCallRequest
from langchain_core.messages import ToolMessage

logger = logging.getLogger("app")

# Matches `text="..."` (or `text='...'`) inside MCP TextContent repr. Non-greedy
# so each TextContent in a list is captured separately.
_TEXT_CONTENT_PATTERN = re.compile(
    r"""TextContent\([^)]*?text=(?P<quote>['"])(?P<text>.*?)(?<!\\)(?P=quote)""",
    re.DOTALL,
)


class ToolErrorRecoveryMiddleware(AgentMiddleware):
    """Catch tool-call exceptions and return them as error ToolMessages."""

    def _extract_error_text(self, error: Exception) -> str:
        """Pull human-readable text out of an exception.

        MCP `ToolException` typically wraps a list of `TextContent` objects, so
        their string repr looks like `[TextContent(type='text', text="...", ...)]`.
        Strip the wrapper and keep just the inner `text` fields when present;
        otherwise fall back to `str(error)`.
        """
        raw = str(error)
        matches = _TEXT_CONTENT_PATTERN.findall(raw)
        if matches:
            # The pattern has two capturing groups, so findall returns
            # (quote, text) tuples; keep only the text (index 1).
            return "\n".join(text for _quote, text in matches if text)
        return raw

    def _build_error_message(
        self,
        request: ToolCallRequest,
        error: Exception,
    ) -> ToolMessage:
        """Build the error ToolMessage for a failed call and log a warning."""
        tool_call = request.tool_call or {}
        tool_name = tool_call.get("name") or "unknown_tool"
        tool_call_id = tool_call.get("id") or ""
        error_text = self._extract_error_text(error)
        content = f"Tool '{tool_name}' failed: {error_text}"
        logger.warning(
            "Tool error recovered as ToolMessage: tool_name=%s error_type=%s",
            tool_name,
            type(error).__name__,
        )
        return ToolMessage(
            content=content,
            tool_call_id=tool_call_id,
            name=tool_name,
            status="error",
        )

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Any],
    ) -> Any:
        """Sync path: run the tool and convert any Exception into an error ToolMessage."""
        try:
            return handler(request)
        except Exception as exc:  # noqa: BLE001 - outermost recovery
            return self._build_error_message(request, exc)

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Any],
    ) -> Any:
        """Async path: same recovery, but cancellation is re-raised."""
        try:
            return await handler(request)
        except asyncio.CancelledError:
            # Cancellation must propagate so the agent task can shut down cleanly.
            raise
        except Exception as exc:  # noqa: BLE001 - outermost recovery
            return self._build_error_message(request, exc)
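
The recovery behaviour described above can be tried without LangChain installed. The following is a minimal sketch, not the module itself: `extract_error_text`, `recover`, `failing_tool`, and `cancelled_tool` are hypothetical stand-ins, the error string is an invented example shaped like an MCP `TextContent` repr, and a plain dict takes the place of `ToolMessage`. Only the regular expression is copied from the file.

```python
import asyncio
import re

# Same pattern as in the module above.
_TEXT_CONTENT_PATTERN = re.compile(
    r"""TextContent\([^)]*?text=(?P<quote>['"])(?P<text>.*?)(?<!\\)(?P=quote)""",
    re.DOTALL,
)


def extract_error_text(raw: str) -> str:
    """Standalone copy of the extraction logic, taking the string directly."""
    matches = _TEXT_CONTENT_PATTERN.findall(raw)
    if matches:
        # Each match is a (quote, text) tuple; join the non-empty texts.
        return "\n".join(text for _quote, text in matches if text)
    return raw


async def recover(handler):
    """Hypothetical stand-in for awrap_tool_call; returns a dict, not a ToolMessage."""
    try:
        return await handler()
    except asyncio.CancelledError:
        # Cancellation is never converted into an error result.
        raise
    except Exception as exc:
        return {"status": "error", "content": extract_error_text(str(exc))}


async def failing_tool():
    # Invented error text shaped like an MCP TextContent repr (an assumption).
    raise RuntimeError(
        "[TextContent(type='text', text='quota exceeded', annotations=None)]"
    )


async def cancelled_tool():
    # Simulates cancellation arriving while the tool is running.
    raise asyncio.CancelledError()


async def main():
    recovered = await recover(failing_tool)
    try:
        await recover(cancelled_tool)
        cancelled = False
    except asyncio.CancelledError:
        cancelled = True
    return recovered, cancelled


recovered, cancelled = asyncio.run(main())
print(recovered)
print(cancelled)
```

Under these assumptions the ordinary failure comes back as an error result carrying only the inner `text` payload, while the simulated cancellation still propagates to the caller. A string with no `TextContent(...)` wrapper passes through `extract_error_text` unchanged, which mirrors the `str(error)` fallback.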