# Source: qwen_agent/agent/tool_use_cleanup_middleware.py
# (file-viewer header: 255 lines, 10 KiB, Python)

"""
Tool Use Cleanup Middleware for LangGraph Agents
This middleware removes tool_use blocks that don't have corresponding tool_result blocks,
preventing errors like: `tool_use` ids were found without `tool_result` blocks immediately after.
"""
import logging
from typing import Any, Callable
from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse
from langchain_core.messages import AIMessage, AnyMessage, ToolMessage
# Module-wide logger; records are emitted under the shared 'app' logger name
# rather than a per-module (__name__) logger.
logger = logging.getLogger('app')
class ToolUseCleanupMiddleware(AgentMiddleware):
    """Middleware that strips orphaned ``tool_use`` entries before a model call.

    Every tool call carried by an ``AIMessage`` must be answered by a
    ``ToolMessage`` in the run of messages immediately following it. Tool
    calls without such an answer are removed; if that leaves the AIMessage
    with neither tool calls nor text, the whole message is dropped.

    Examples:
        - AIMessage(tool_calls=[tool_1]) + ToolMessage(tool_call_id=tool_1)
          -> kept unchanged
        - AIMessage(tool_calls=[tool_1]) + AIMessage(content="hello")
          -> first message removed
        - AIMessage(tool_calls=[tool_1], content="please wait")
          + AIMessage(content="done")
          -> content kept, tool_1 removed

    Attributes:
        stats: Cumulative counters (``removed_ai_messages``,
            ``cleaned_ai_messages``, ``removed_tool_calls``) across all calls
            since construction or the last ``reset_stats()``.
    """

    def __init__(self):
        super().__init__()
        self.stats = self._new_stats()
        logger.info("ToolUseCleanupMiddleware initialized")

    @staticmethod
    def _new_stats() -> dict[str, int]:
        """Return a fresh, zeroed statistics dictionary."""
        return {
            'removed_ai_messages': 0,
            'cleaned_ai_messages': 0,
            'removed_tool_calls': 0,
        }

    @staticmethod
    def _is_meaningful_text_block(block: Any) -> bool:
        """Return True if *block* is a content block carrying non-blank text.

        Accepts both bare strings and ``{'type': 'text', 'text': ...}`` dicts.
        A missing, ``None`` or non-string ``text`` value counts as no text.
        """
        if isinstance(block, str):
            return bool(block.strip())
        if isinstance(block, dict) and block.get('type') == 'text':
            text = block.get('text')
            return isinstance(text, str) and bool(text.strip())
        return False

    @staticmethod
    def _rebuild_ai_message(original: AIMessage, content: Any, tool_calls: list) -> AIMessage:
        """Build a new AIMessage from *original* with replaced content/tool_calls.

        NOTE(review): ``additional_kwargs`` is forwarded untouched; if a
        provider stores raw tool-call payloads there, they are not filtered
        by this middleware - confirm against the model integration in use.
        """
        return AIMessage(
            content=content,
            tool_calls=tool_calls,
            additional_kwargs=original.additional_kwargs,
            response_metadata=original.response_metadata,
            id=original.id,
            name=original.name,
        )

    def _has_meaningful_content(self, message: AIMessage) -> bool:
        """Check whether an AIMessage carries text besides its tool_calls.

        Args:
            message: The AIMessage to check.

        Returns:
            True if the message has non-blank text content, False otherwise.
        """
        content = message.content
        if content is None:
            return False
        if isinstance(content, str):
            return bool(content.strip())
        if isinstance(content, list):
            # Only text counts as "meaningful"; images etc. do not.
            return any(self._is_meaningful_text_block(block) for block in content)
        # Unknown content type: fall back to plain truthiness.
        return bool(content)

    def _clean_content_blocks(self, content: Any, valid_tool_call_ids: set[str]) -> tuple[Any, bool]:
        """Remove orphaned ``tool_use`` blocks from message content.

        Args:
            content: The content to clean (str, list of blocks, or other).
            valid_tool_call_ids: Tool call ids that have a matching result
                and must therefore be kept.

        Returns:
            Tuple of ``(cleaned_content, has_meaningful_content)``. Non-list
            content is returned unchanged.
        """
        if isinstance(content, str):
            return content, bool(content.strip())
        if not isinstance(content, list):
            return content, bool(content)

        cleaned_blocks = []
        has_text_content = False
        for block in content:
            is_orphaned_tool_use = (
                isinstance(block, dict)
                and block.get('type') == 'tool_use'
                and block.get('id') not in valid_tool_call_ids
            )
            if is_orphaned_tool_use:
                continue
            # Everything else (text, images, valid tool_use, raw strings) is kept.
            cleaned_blocks.append(block)
            # Same text test as _has_meaningful_content, so bare-string blocks
            # count as text here too.
            if self._is_meaningful_text_block(block):
                has_text_content = True
        return cleaned_blocks, has_text_content

    def _cleanup_tool_use_messages(self, messages: list[AnyMessage]) -> list[AnyMessage]:
        """Drop tool calls that have no ToolMessage directly following them.

        Updates ``self.stats`` as a side effect.

        Args:
            messages: List of messages to clean.

        Returns:
            A new list of messages; the input list is not modified.
        """
        cleaned_messages = []
        total = len(messages)
        for index, msg in enumerate(messages):
            if not (isinstance(msg, AIMessage) and msg.tool_calls):
                # Not an AIMessage with tool_calls: pass through as-is.
                cleaned_messages.append(msg)
                continue

            # Collect ids answered by the contiguous run of ToolMessages
            # that immediately follows this AIMessage.
            valid_ids = set()
            lookahead = index + 1
            while lookahead < total and isinstance(messages[lookahead], ToolMessage):
                valid_ids.add(messages[lookahead].tool_call_id)
                lookahead += 1

            kept_calls = []
            removed_ids = []
            for tool_call in msg.tool_calls:
                call_id = tool_call.get('id')
                if call_id in valid_ids:
                    kept_calls.append(tool_call)
                else:
                    removed_ids.append(call_id)

            if removed_ids:
                self.stats['removed_tool_calls'] += len(removed_ids)
                logger.warning(
                    "Removed %d orphaned tool_use(s) from AIMessage. "
                    "tool_call_ids: %s, "
                    "valid_ids: %s",
                    len(removed_ids), removed_ids, valid_ids,
                )

            cleaned_content, has_text = self._clean_content_blocks(msg.content, valid_ids)

            if not removed_ids and cleaned_content == msg.content:
                # Nothing changed: keep the original object so that fields the
                # rebuild does not forward are preserved.
                cleaned_messages.append(msg)
            elif kept_calls:
                # Some tool calls are still valid: keep the message with them.
                cleaned_messages.append(
                    self._rebuild_ai_message(msg, cleaned_content, kept_calls)
                )
            elif has_text:
                # No valid tool calls left, but the text is worth keeping.
                cleaned_messages.append(
                    self._rebuild_ai_message(msg, cleaned_content, [])
                )
                self.stats['cleaned_ai_messages'] += 1
                logger.info(
                    "Removed all tool_calls from AIMessage but kept content. "
                    "Content preview: %s...",
                    str(cleaned_content)[:200],
                )
            else:
                # Neither valid tool calls nor text: drop the message entirely.
                self.stats['removed_ai_messages'] += 1
                logger.info(
                    "Removed entire AIMessage with orphaned tool_calls (no meaningful content). "
                    "Removed tool_call_ids: %s",
                    removed_ids,
                )
        return cleaned_messages

    def _prepare_request(self, request: ModelRequest) -> ModelRequest:
        """Return *request* with cleaned messages, logging this call's changes.

        The summary line reports only what was cleaned during this call (the
        delta of ``self.stats``), so it is not repeated on later calls that
        had nothing to clean.
        """
        before = self.get_stats()
        cleaned_messages = self._cleanup_tool_use_messages(request.messages)
        delta = {key: self.stats[key] - count for key, count in before.items()}
        if any(delta.values()):
            logger.info(
                "ToolUseCleanupMiddleware: Removed %d messages, "
                "cleaned %d messages, "
                "removed %d tool_calls.",
                delta['removed_ai_messages'],
                delta['cleaned_ai_messages'],
                delta['removed_tool_calls'],
            )
        # Override with cleaned messages.
        return request.override(messages=cleaned_messages)

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        """Synchronous wrapper: clean orphaned tool_use blocks, then call the model."""
        return handler(self._prepare_request(request))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Any],
    ) -> ModelResponse:
        """Async wrapper: clean orphaned tool_use blocks, then await the model call."""
        return await handler(self._prepare_request(request))

    def get_stats(self) -> dict[str, int]:
        """Return a copy of the cumulative cleanup statistics."""
        return self.stats.copy()

    def reset_stats(self):
        """Reset the cumulative cleanup statistics to zero."""
        self.stats = self._new_stats()