"""
Tool Use Cleanup Middleware for LangGraph Agents

This middleware removes tool_use blocks that don't have corresponding
tool_result blocks, preventing errors like:
`tool_use` ids were found without `tool_result` blocks immediately after.
"""

import logging
from typing import Any, Callable

from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse
from langchain_core.messages import AIMessage, AnyMessage, ToolMessage

logger = logging.getLogger('app')


class ToolUseCleanupMiddleware(AgentMiddleware):
    """
    Middleware to clean up orphaned tool_use blocks in messages.

    Ensures that every tool_use has a corresponding tool_result. A tool call
    is considered answered when one of the ToolMessages that directly follow
    the AIMessage carries its id; every other tool call is removed from the
    AIMessage before the request is handed to the model.

    Examples:
    - AIMessage(tool_calls=[tool_1]) + ToolMessage(tool_call_id=tool_1) -> keep
    - AIMessage(tool_calls=[tool_1]) + AIMessage(content="Hello") -> remove tool_1
    - AIMessage(tool_calls=[tool_1], content="Please wait") + AIMessage(content="Done")
      -> keep content, remove tool_1

    Attributes:
        stats: Cumulative counters (since construction or the last
            ``reset_stats`` call) with the keys ``removed_ai_messages``,
            ``cleaned_ai_messages`` and ``removed_tool_calls``.
    """

    def __init__(self):
        """Initialise the cumulative cleanup counters."""
        super().__init__()
        self.stats = self._empty_stats()
        logger.info("ToolUseCleanupMiddleware initialized")

    @staticmethod
    def _empty_stats() -> dict[str, int]:
        """Return a fresh, zeroed counter dict."""
        return {
            'removed_ai_messages': 0,
            'cleaned_ai_messages': 0,
            'removed_tool_calls': 0,
        }

    def _has_meaningful_content(self, message: AIMessage) -> bool:
        """
        Check if an AIMessage has meaningful content (text) besides tool_calls.

        Args:
            message: The AIMessage to check

        Returns:
            True if the message has non-empty content, False otherwise
        """
        content = message.content

        if content is None:
            return False
        if isinstance(content, str):
            return bool(content.strip())
        if isinstance(content, list):
            # Content is a list of content blocks; only non-blank text counts.
            for block in content:
                if isinstance(block, dict):
                    if block.get('type') != 'text':
                        continue
                    text = block.get('text')
                    # 'text' may be present but None; treat that as empty
                    # instead of failing on None.strip().
                    if isinstance(text, str) and text.strip():
                        return True
                elif isinstance(block, str) and block.strip():
                    return True
            return False
        # Other types: treat as having content if truthy.
        return bool(content)

    @staticmethod
    def _drop_orphaned_tool_use_blocks(content: Any, answered_ids: set) -> Any:
        """
        Remove ``tool_use`` content blocks whose id has no tool result.

        Non-list content and every block that is not a dict of type
        ``tool_use`` is returned untouched.
        """
        if not isinstance(content, list):
            return content
        return [
            block
            for block in content
            if not (
                isinstance(block, dict)
                and block.get('type') == 'tool_use'
                and block.get('id') not in answered_ids
            )
        ]

    @staticmethod
    def _drop_orphaned_raw_tool_calls(additional_kwargs: Any, answered_ids: set) -> Any:
        """
        Remove raw tool calls without a tool result from ``additional_kwargs``.

        Returns the original object when nothing has to be removed; otherwise
        a shallow copy, so the source message is never mutated.
        """
        if not isinstance(additional_kwargs, dict):
            return additional_kwargs
        raw_calls = additional_kwargs.get('tool_calls')
        if not isinstance(raw_calls, list):
            return additional_kwargs

        # Entries of an unknown (non-dict) shape are left alone.
        kept = [
            raw
            for raw in raw_calls
            if not (isinstance(raw, dict) and raw.get('id') not in answered_ids)
        ]
        if len(kept) == len(raw_calls):
            return additional_kwargs

        cleaned = {key: value for key, value in additional_kwargs.items() if key != 'tool_calls'}
        if kept:
            cleaned['tool_calls'] = kept
        return cleaned

    def _rebuild_ai_message(
        self,
        original: AIMessage,
        tool_calls: list,
        answered_ids: set,
    ) -> AIMessage:
        """
        Build a copy of ``original`` that only references answered tool calls.

        Besides ``tool_calls`` the orphaned calls are also stripped from
        list-form content and from ``additional_kwargs['tool_calls']``.
        """
        # NOTE(review): langchain_core's AIMessage is understood to re-parse
        # additional_kwargs['tool_calls'] when tool_calls is empty, which would
        # bring the removed calls back - confirm against the installed version.
        return AIMessage(
            content=self._drop_orphaned_tool_use_blocks(original.content, answered_ids),
            tool_calls=tool_calls,
            additional_kwargs=self._drop_orphaned_raw_tool_calls(
                original.additional_kwargs, answered_ids
            ),
            response_metadata=original.response_metadata,
            id=original.id,
            name=original.name,
        )

    def _cleanup_tool_use_messages(self, messages: list[AnyMessage]) -> list[AnyMessage]:
        """
        Clean up messages by removing tool_use blocks that don't have corresponding tool_results.

        Updates ``self.stats`` as a side effect. The input list and its
        messages are not modified; affected AIMessages are rebuilt.

        Args:
            messages: List of messages to clean

        Returns:
            Cleaned list of messages
        """
        cleaned_messages: list[AnyMessage] = []

        for index, current_msg in enumerate(messages):
            if not (isinstance(current_msg, AIMessage) and current_msg.tool_calls):
                # Not an AIMessage with tool_calls, add as-is.
                cleaned_messages.append(current_msg)
                continue

            # Collect the ids answered by the run of ToolMessages that
            # directly follows this AIMessage.
            answered_ids = set()
            cursor = index + 1
            while cursor < len(messages) and isinstance(messages[cursor], ToolMessage):
                answered_ids.add(messages[cursor].tool_call_id)
                cursor += 1

            all_ids = [tc.get('id') for tc in current_msg.tool_calls]
            filtered_tool_calls = [
                tc for tc in current_msg.tool_calls if tc.get('id') in answered_ids
            ]

            removed_count = len(current_msg.tool_calls) - len(filtered_tool_calls)
            if removed_count > 0:
                self.stats['removed_tool_calls'] += removed_count
                logger.warning(
                    "Removed %d orphaned tool_use(s) from AIMessage. "
                    "tool_call_ids: %s, valid_ids: %s",
                    removed_count,
                    all_ids,
                    answered_ids,
                )

            if filtered_tool_calls:
                # Has valid tool_calls, keep the message.
                cleaned_messages.append(
                    self._rebuild_ai_message(current_msg, filtered_tool_calls, answered_ids)
                )
            elif self._has_meaningful_content(current_msg):
                # No valid tool_calls but meaningful content: keep the text only.
                cleaned_messages.append(
                    self._rebuild_ai_message(current_msg, [], answered_ids)
                )
                self.stats['cleaned_ai_messages'] += 1
                logger.info(
                    "Removed all tool_calls from AIMessage but kept content. "
                    "Content preview: %s...",
                    # str() so list-form content is truncated by characters,
                    # not by number of blocks.
                    str(current_msg.content)[:50],
                )
            else:
                # No valid tool_calls and no meaningful content: drop the message.
                self.stats['removed_ai_messages'] += 1
                logger.info(
                    "Removed entire AIMessage with orphaned tool_calls (no meaningful content). "
                    "Removed tool_call_ids: %s",
                    all_ids,
                )

        return cleaned_messages

    def _clean_request(self, request: ModelRequest) -> ModelRequest:
        """
        Return a copy of ``request`` whose messages have been cleaned.

        A summary is logged only when this particular call changed something;
        the numbers in it are the per-call deltas, not the cumulative totals
        (those remain available through ``get_stats``).
        """
        before = dict(self.stats)
        cleaned_messages = self._cleanup_tool_use_messages(request.messages)
        delta = {key: value - before.get(key, 0) for key, value in self.stats.items()}

        if any(delta.values()):
            logger.info(
                "ToolUseCleanupMiddleware: Removed %d messages, "
                "cleaned %d messages, "
                "removed %d tool_calls.",
                delta['removed_ai_messages'],
                delta['cleaned_ai_messages'],
                delta['removed_tool_calls'],
            )

        # Override with cleaned messages.
        return request.override(messages=cleaned_messages)

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        """
        Synchronous wrapper to clean up orphaned tool_use blocks before model call.

        Args:
            request: The model request whose messages are cleaned.
            handler: Callable that receives the cleaned request.

        Returns:
            Whatever ``handler`` returns for the cleaned request.
        """
        return handler(self._clean_request(request))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Any],
    ) -> ModelResponse:
        """
        Async wrapper to clean up orphaned tool_use blocks before model call.

        Args:
            request: The model request whose messages are cleaned.
            handler: Callable whose result is awaited with the cleaned request.

        Returns:
            The awaited result of ``handler`` for the cleaned request.
        """
        return await handler(self._clean_request(request))

    def get_stats(self) -> dict[str, int]:
        """Get a copy of the cumulative cleanup statistics."""
        return self.stats.copy()

    def reset_stats(self):
        """Reset cleanup statistics."""
        self.stats = self._empty_stats()