diff --git a/agent/tool_use_cleanup_middleware.py b/agent/tool_use_cleanup_middleware.py index 83fb733..5758f52 100644 --- a/agent/tool_use_cleanup_middleware.py +++ b/agent/tool_use_cleanup_middleware.py @@ -66,6 +66,45 @@ class ToolUseCleanupMiddleware(AgentMiddleware): # Other types, treat as having content if it's truthy return bool(content) + def _clean_content_blocks(self, content: Any, valid_tool_call_ids: set[str]) -> tuple[Any, bool]: + """ + Clean up content blocks by removing orphaned tool_use blocks. + + Args: + content: The content to clean (can be str, list, or other) + valid_tool_call_ids: Set of valid tool_call_ids to keep + + Returns: + Tuple of (cleaned_content, has_meaningful_content) + """ + if isinstance(content, str): + return content, bool(content.strip()) + elif isinstance(content, list): + # Filter out tool_use blocks that don't have valid ids + cleaned_blocks = [] + has_text_content = False + + for block in content: + if isinstance(block, dict): + block_type = block.get('type') + if block_type == 'tool_use': + # Only keep tool_use blocks with valid ids + tool_id = block.get('id') + if tool_id in valid_tool_call_ids: + cleaned_blocks.append(block) + elif block_type == 'text' and block.get('text', '').strip(): + cleaned_blocks.append(block) + has_text_content = True + else: + # Keep other block types (images, etc.) + cleaned_blocks.append(block) + else: + cleaned_blocks.append(block) + + return cleaned_blocks, has_text_content + else: + return content, bool(content) + def _cleanup_tool_use_messages(self, messages: list[AnyMessage]) -> list[AnyMessage]: """ Clean up messages by removing tool_use blocks that don't have corresponding tool_results. @@ -98,21 +137,25 @@ class ToolUseCleanupMiddleware(AgentMiddleware): ] removed_count = len(current_msg.tool_calls) - len(filtered_tool_calls) + removed_ids = [tc.get('id') for tc in current_msg.tool_calls if tc.get('id') not in valid_tool_call_ids] if removed_count > 0: self.stats['removed_tool_calls'] += removed_count logger.warning( f"Removed {removed_count} orphaned tool_use(s) from AIMessage. " - f"tool_call_ids: {[tc.get('id') for tc in current_msg.tool_calls]}, " + f"tool_call_ids: {removed_ids}, " f"valid_ids: {valid_tool_call_ids}" ) - has_content = self._has_meaningful_content(current_msg) + # Clean content blocks to remove orphaned tool_use blocks + cleaned_content, has_meaningful_content = self._clean_content_blocks( + current_msg.content, valid_tool_call_ids + ) if filtered_tool_calls: # Has valid tool_calls, keep the message cleaned_msg = AIMessage( - content=current_msg.content, + content=cleaned_content, tool_calls=filtered_tool_calls, additional_kwargs=current_msg.additional_kwargs, response_metadata=current_msg.response_metadata, @@ -120,10 +163,10 @@ class ToolUseCleanupMiddleware(AgentMiddleware): name=current_msg.name, ) cleaned_messages.append(cleaned_msg) - elif has_content: + elif has_meaningful_content: # No valid tool_calls but has meaningful content, keep without tool_calls cleaned_msg = AIMessage( - content=current_msg.content, + content=cleaned_content, tool_calls=[], additional_kwargs=current_msg.additional_kwargs, response_metadata=current_msg.response_metadata, @@ -134,14 +177,14 @@ class ToolUseCleanupMiddleware(AgentMiddleware): self.stats['cleaned_ai_messages'] += 1 logger.info( f"Removed all tool_calls from AIMessage but kept content. " - f"Content preview: {current_msg.content[:50]}..." + f"Content preview: {str(cleaned_content)[:200]}..." ) else: # No valid tool_calls and no meaningful content, completely remove this message self.stats['removed_ai_messages'] += 1 logger.info( f"Removed entire AIMessage with orphaned tool_calls (no meaningful content). " - f"Removed tool_call_ids: {[tc.get('id') for tc in current_msg.tool_calls]}" + f"Removed tool_call_ids: {removed_ids}" ) # Don't add to cleaned_messages - skip this message entirely else: diff --git a/mcp/mcp_settings_deep_agent.json b/mcp/mcp_settings_deep_agent.json new file mode 100644 index 0000000..ddf9962 --- /dev/null +++ b/mcp/mcp_settings_deep_agent.json @@ -0,0 +1,14 @@ +[ + { + "mcpServers": { + "rag_retrieve": { + "transport": "stdio", + "command": "python", + "args": [ + "./mcp/rag_retrieve_server.py", + "{bot_id}" + ] + } + } + } +]