""" Tool Output Length Middleware for LangGraph Agents This middleware provides configurable control over the length of tool outputs, helping to manage context window usage and improve agent performance. """ import logging import re from typing import Any, Dict, List, Optional, Union, Callable from langchain.agents.middleware import AgentMiddleware, AgentState from langchain.tools.tool_node import ToolCallRequest from langchain_core.messages import BaseMessage, ToolMessage, AIMessage from langgraph.types import Command from langgraph.runtime import Runtime logger = logging.getLogger('app') class ToolOutputLengthMiddleware(AgentMiddleware): """ Middleware to control and truncate tool output length in LangGraph agents. Features: - Configurable maximum output length - Multiple truncation strategies (end, start, smart, preserve_blocks) - Selective tool filtering - Metadata tracking for truncated responses - Comprehensive logging """ def __init__( self, max_length: int = 2000, truncation_strategy: str = "smart", tool_filters: Optional[List[str]] = None, exclude_tools: Optional[List[str]] = None, add_metadata: bool = True, preserve_code_blocks: bool = True, preserve_json: bool = True, ellipsis: str = "\n\n...[response truncated due to length]..." ): """ Initialize the ToolOutputLengthMiddleware. Args: max_length: Maximum character length for tool outputs truncation_strategy: Strategy for truncation ('end', 'start', 'smart', 'preserve_blocks') tool_filters: List of tool names to apply truncation to (None = all tools) exclude_tools: List of tool names to exclude from truncation add_metadata: Whether to add metadata about truncation preserve_code_blocks: Whether to preserve code blocks in smart mode preserve_json: Whether to preserve JSON structure in smart mode ellipsis: Text to append when truncating """ self.max_length = max_length self.truncation_strategy = truncation_strategy self.tool_filters = tool_filters self.exclude_tools = exclude_tools or [] self.add_metadata = add_metadata self.preserve_code_blocks = preserve_code_blocks self.preserve_json = preserve_json self.ellipsis = ellipsis # Statistics tracking self.stats = { 'total_tool_calls': 0, 'truncated_calls': 0, 'total_chars_saved': 0 } # Compile regex patterns for efficient matching self.code_block_pattern = re.compile(r'```(\w+)?\n.*?```', re.DOTALL) self.json_pattern = re.compile(r'\{.*?\}|\[.*?\]', re.DOTALL) logger.info(f"ToolOutputLengthMiddleware initialized: max_length={max_length}, strategy={truncation_strategy}") def should_process_tool(self, tool_name: str) -> bool: """Check if a tool should be processed based on filters.""" # If no filters specified, process all tools if self.tool_filters is None: return tool_name not in self.exclude_tools # Check if tool is in filter list and not excluded return tool_name in self.tool_filters and tool_name not in self.exclude_tools def truncate_content(self, content: str, tool_name: str) -> tuple[str, Dict[str, Any]]: """ Truncate content based on the configured strategy. Returns: Tuple of (truncated_content, metadata) """ original_length = len(content) if original_length <= self.max_length: return content, {'truncated': False} metadata = { 'truncated': True, 'original_length': original_length, 'truncated_length': 0, 'strategy': self.truncation_strategy, 'tool_name': tool_name } if self.truncation_strategy == "end": truncated = self._truncate_end(content) elif self.truncation_strategy == "start": truncated = self._truncate_start(content) elif self.truncation_strategy == "smart": truncated = self._smart_truncate(content, tool_name) elif self.truncation_strategy == "preserve_blocks": truncated = self._preserve_blocks_truncate(content) else: # Default to end truncation truncated = self._truncate_end(content) metadata['strategy'] = 'end' metadata['truncated_length'] = len(truncated) metadata['chars_saved'] = original_length - len(truncated) return truncated, metadata def _truncate_end(self, content: str) -> str: """Simple truncation from the end.""" return content[:self.max_length] + self.ellipsis def _truncate_start(self, content: str) -> str: """Truncate from the start, keeping the end.""" return self.ellipsis + content[-self.max_length:] def _smart_truncate(self, content: str, tool_name: str) -> str: """ Smart truncation that tries to preserve important content. Strategies: 1. Look for natural break points (paragraphs, sentences) 2. Preserve code blocks and JSON if enabled 3. Try to maintain context around key information """ # Check if content is primarily code or JSON if self.preserve_code_blocks and self._is_mainly_code(content): return self._preserve_code_blocks(content) if self.preserve_json and self._is_json(content): return self._preserve_json(content) # Try to find a good breaking point target_length = self.max_length - len(self.ellipsis) truncated_part = content[:target_length] # Look for paragraph breaks paragraph_break = truncated_part.rfind('\n\n') if paragraph_break > target_length * 0.8: return content[:paragraph_break] + self.ellipsis # Look for sentence endings for end_char in ['.', '!', '?']: last_pos = truncated_part.rfind(end_char) if last_pos > target_length * 0.7: return content[:last_pos + 1] + self.ellipsis # Look for line breaks line_break = truncated_part.rfind('\n') if line_break > target_length * 0.9: return content[:line_break] + self.ellipsis # Fallback to simple truncation return truncated_part + self.ellipsis def _preserve_blocks_truncate(self, content: str) -> str: """ Preserve important blocks (code, JSON, etc.) while truncating the rest. """ # Find all important blocks important_blocks = [] if self.preserve_code_blocks: for match in self.code_block_pattern.finditer(content): important_blocks.append({ 'start': match.start(), 'end': match.end(), 'content': match.group(), 'type': 'code' }) if self.preserve_json: for match in self.json_pattern.finditer(content): # Avoid duplicating code blocks that contain JSON if not any(b['start'] <= match.start() <= b['end'] for b in important_blocks): important_blocks.append({ 'start': match.start(), 'end': match.end(), 'content': match.group(), 'type': 'json' }) # Sort blocks by position important_blocks.sort(key=lambda x: x['start']) # Calculate total preserved content preserved_content = "" current_pos = 0 for block in important_blocks: # Add content before the block (truncated if needed) before_length = block['start'] - current_pos if before_length > 0: available_space = self.max_length - len(preserved_content) - len(block['content']) - len(self.ellipsis) if available_space > 0: before_content = content[current_pos:block['start']] if len(before_content) > available_space: before_content = before_content[:available_space] + "..." preserved_content += before_content # Add the preserved block if len(preserved_content) + len(block['content']) <= self.max_length: preserved_content += block['content'] current_pos = block['end'] else: break # Add ellipsis if content was truncated if len(preserved_content) < len(content): preserved_content += self.ellipsis return preserved_content if preserved_content else content[:self.max_length] + self.ellipsis def _is_mainly_code(self, content: str) -> bool: """Check if content is primarily code.""" code_matches = len(self.code_block_pattern.findall(content)) code_chars = sum(len(match[0]) for match in self.code_block_pattern.finditer(content)) return code_chars > len(content) * 0.5 def _is_json(self, content: str) -> bool: """Check if content is valid JSON.""" import json try: json.loads(content) return True except: return False def _preserve_code_blocks(self, content: str) -> str: """Preserve code blocks while truncating other content.""" return self._preserve_blocks_truncate(content) def _preserve_json(self, content: str) -> str: """Preserve JSON structure while possibly truncating string values.""" import json try: data = json.loads(content) # Recursively truncate string values def truncate_strings(obj): if isinstance(obj, str): if len(obj) > self.max_length // 2: return obj[:self.max_length // 2] + "...[truncated]..." return obj elif isinstance(obj, dict): return {k: truncate_strings(v) for k, v in obj.items()} elif isinstance(obj, list): return [truncate_strings(item) for item in obj] return obj truncated_data = truncate_strings(data) return json.dumps(truncated_data, ensure_ascii=False, indent=2) except: return self._truncate_end(content) def _format_error_message(self, tool_name: str, error: Exception) -> str: """格式化错误消息为用户友好的文本。""" # 用户友好的错误消息,不暴露技术细节 return f"工具调用失败:{tool_name} 暂时无法使用,请稍后重试。" def wrap_tool_call( self, request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command], ) -> ToolMessage | Command: """ Intercept and potentially truncate tool call responses. This method is called when a tool is invoked and allows us to intercept and modify the tool's response before it's returned to the agent. """ # Get tool name from the request tool_name = request.tool_call['name'] # Check if this tool should be processed if not self.should_process_tool(tool_name): # If not in filters, just pass through return handler(request) # Update statistics self.stats['total_tool_calls'] += 1 # Execute the tool to get the response try: result = handler(request) except Exception as e: logger.error(f"Tool execution failed for '{tool_name}': {e}") # 返回错误 ToolMessage 而不是重新抛出异常 error_message = self._format_error_message(tool_name, e) return ToolMessage( content=error_message, tool_call_id=request.tool_call.get('id', ''), name=tool_name ) # Handle different return types if isinstance(result, ToolMessage): # Process ToolMessage content = result.text if content and len(content) > self.max_length: # Truncate the content truncated_content, metadata = self.truncate_content(content, tool_name) # Create new ToolMessage with truncated content truncated_message = ToolMessage( content=truncated_content, tool_call_id=result.tool_call_id, name=result.name ) # Add metadata if requested if self.add_metadata: truncated_message.additional_kwargs = result.additional_kwargs.copy() truncated_message.additional_kwargs.update({ 'truncation_info': metadata, 'original_length': len(content) }) # Update statistics self.stats['truncated_calls'] += 1 self.stats['total_chars_saved'] += metadata.get('chars_saved', 0) logger.info( f"Tool output truncated for '{tool_name}': " f"{len(content)} -> {len(truncated_content)} chars " f"(saved {metadata.get('chars_saved', 0)} chars)" ) return truncated_message elif isinstance(result, Command): # For Command objects, we need to handle the state update # Check if the command contains messages that need truncation if 'messages' in result.update: messages = result.update['messages'] updated_messages = [] for msg in messages: if isinstance(msg, ToolMessage): tool_msg_tool_name = getattr(msg, 'name', tool_name) if self.should_process_tool(tool_msg_tool_name) and len(msg.content) > self.max_length: # Truncate the ToolMessage content truncated_content, metadata = self.truncate_content(msg.content, tool_msg_tool_name) # Create new ToolMessage truncated_msg = ToolMessage( content=truncated_content, tool_call_id=msg.tool_call_id, name=msg.name ) # Add metadata if requested if self.add_metadata: truncated_msg.additional_kwargs = msg.additional_kwargs.copy() truncated_msg.additional_kwargs.update({ 'truncation_info': metadata, 'original_length': len(msg.content) }) # Update statistics self.stats['truncated_calls'] += 1 self.stats['total_chars_saved'] += metadata.get('chars_saved', 0) logger.info( f"Tool output truncated for '{tool_msg_tool_name}' in Command: " f"{len(msg.content)} -> {len(truncated_content)} chars " f"(saved {metadata.get('chars_saved', 0)} chars)" ) updated_messages.append(truncated_msg) else: updated_messages.append(msg) else: updated_messages.append(msg) # Update the Command with the modified messages result.update['messages'] = updated_messages return result async def awrap_tool_call( self, request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command], ) -> ToolMessage | Command: """Async version of wrap_tool_call.""" tool_name = request.tool_call['name'] if not self.should_process_tool(tool_name): return await handler(request) self.stats['total_tool_calls'] += 1 # Execute the tool asynchronously try: result = await handler(request) except Exception as e: logger.error(f"Tool execution failed for '{tool_name}': {e}") # 返回错误 ToolMessage 而不是重新抛出异常 error_message = self._format_error_message(tool_name, e) return ToolMessage( content=error_message, tool_call_id=request.tool_call.get('id', ''), name=tool_name ) # Handle different return types (same logic as sync version) if isinstance(result, ToolMessage): content = result.text if content and len(content) > self.max_length: truncated_content, metadata = self.truncate_content(content, tool_name) truncated_message = ToolMessage( content=truncated_content, tool_call_id=result.tool_call_id, name=result.name ) if self.add_metadata: truncated_message.additional_kwargs = result.additional_kwargs.copy() truncated_message.additional_kwargs.update({ 'truncation_info': metadata, 'original_length': len(content) }) self.stats['truncated_calls'] += 1 self.stats['total_chars_saved'] += metadata.get('chars_saved', 0) logger.info( f"Tool output truncated for '{tool_name}': " f"{len(content)} -> {len(truncated_content)} chars " f"(saved {metadata.get('chars_saved', 0)} chars)" ) return truncated_message elif isinstance(result, Command) and 'messages' in result.update: messages = result.update['messages'] updated_messages = [] for msg in messages: if isinstance(msg, ToolMessage): tool_msg_tool_name = getattr(msg, 'name', tool_name) if self.should_process_tool(tool_msg_tool_name) and len(msg.content) > self.max_length: truncated_content, metadata = self.truncate_content(msg.content, tool_msg_tool_name) truncated_msg = ToolMessage( content=truncated_content, tool_call_id=msg.tool_call_id, name=msg.name ) if self.add_metadata: truncated_msg.additional_kwargs = msg.additional_kwargs.copy() truncated_msg.additional_kwargs.update({ 'truncation_info': metadata, 'original_length': len(msg.content) }) self.stats['truncated_calls'] += 1 self.stats['total_chars_saved'] += metadata.get('chars_saved', 0) logger.info( f"Tool output truncated for '{tool_msg_tool_name}' in Command: " f"{len(msg.content)} -> {len(truncated_content)} chars " f"(saved {metadata.get('chars_saved', 0)} chars)" ) updated_messages.append(truncated_msg) else: updated_messages.append(msg) else: updated_messages.append(msg) result.update['messages'] = updated_messages return result def get_stats(self) -> Dict[str, Any]: """Get statistics about truncation activity.""" return self.stats.copy() def reset_stats(self): """Reset truncation statistics.""" self.stats = { 'total_tool_calls': 0, 'truncated_calls': 0, 'total_chars_saved': 0 }