qwen_agent/agent/tool_output_length_middleware.py
2025-12-23 20:13:46 +08:00

506 lines
20 KiB
Python

"""
Tool Output Length Middleware for LangGraph Agents
This middleware provides configurable control over the length of tool outputs,
helping to manage context window usage and improve agent performance.
"""
import logging
import re
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union

from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain.tools.tool_node import ToolCallRequest
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
from langgraph.runtime import Runtime
from langgraph.types import Command
logger = logging.getLogger('app')
class ToolOutputLengthMiddleware(AgentMiddleware):
"""
Middleware to control and truncate tool output length in LangGraph agents.
Features:
- Configurable maximum output length
- Multiple truncation strategies (end, start, smart, preserve_blocks)
- Selective tool filtering
- Metadata tracking for truncated responses
- Comprehensive logging
"""
def __init__(
    self,
    max_length: int = 2000,
    truncation_strategy: str = "smart",
    tool_filters: Optional[List[str]] = None,
    exclude_tools: Optional[List[str]] = None,
    add_metadata: bool = True,
    preserve_code_blocks: bool = True,
    preserve_json: bool = True,
    ellipsis: str = "\n\n...[response truncated due to length]..."
):
    """
    Initialize the ToolOutputLengthMiddleware.

    Args:
        max_length: Maximum character length for tool outputs
        truncation_strategy: Strategy for truncation ('end', 'start', 'smart', 'preserve_blocks')
        tool_filters: List of tool names to apply truncation to (None = all tools)
        exclude_tools: List of tool names to exclude from truncation
        add_metadata: Whether to add metadata about truncation
        preserve_code_blocks: Whether to preserve code blocks in smart mode
        preserve_json: Whether to preserve JSON structure in smart mode
        ellipsis: Text to append when truncating

    Raises:
        ValueError: If ``max_length`` is negative.
    """
    super().__init__()

    # A negative limit turns every slice in the truncation helpers into
    # "drop characters from the other end", so reject it up front.
    if max_length < 0:
        raise ValueError(f"max_length must be non-negative, got {max_length}")

    # truncate_content() falls back to 'end' for unknown strategies; make
    # that fallback visible instead of silent.
    if truncation_strategy not in ("end", "start", "smart", "preserve_blocks"):
        logger.warning(
            "Unknown truncation_strategy %r; 'end' truncation will be used",
            truncation_strategy,
        )

    self.max_length = max_length
    self.truncation_strategy = truncation_strategy
    self.tool_filters = tool_filters
    self.exclude_tools = exclude_tools or []
    self.add_metadata = add_metadata
    self.preserve_code_blocks = preserve_code_blocks
    self.preserve_json = preserve_json
    self.ellipsis = ellipsis

    # Statistics tracking
    self.stats = {
        'total_tool_calls': 0,
        'truncated_calls': 0,
        'total_chars_saved': 0
    }

    # Compile regex patterns once; they are reused on every truncation.
    self.code_block_pattern = re.compile(r'```(\w+)?\n.*?```', re.DOTALL)
    self.json_pattern = re.compile(r'\{.*?\}|\[.*?\]', re.DOTALL)

    logger.info(
        "ToolOutputLengthMiddleware initialized: max_length=%s, strategy=%s",
        max_length,
        truncation_strategy,
    )
def should_process_tool(self, tool_name: str) -> bool:
    """Return True when output of ``tool_name`` is subject to truncation.

    A tool listed in ``exclude_tools`` is never processed. Otherwise it is
    processed when no ``tool_filters`` are configured, or when it appears
    in ``tool_filters``.
    """
    if tool_name in self.exclude_tools:
        return False
    allowed = self.tool_filters
    return allowed is None or tool_name in allowed
def truncate_content(self, content: str, tool_name: str) -> tuple[str, Dict[str, Any]]:
    """
    Shorten ``content`` with the configured strategy.

    Content no longer than ``max_length`` is returned untouched. An
    unrecognised strategy is treated as 'end' and reported as such in the
    metadata.

    Returns:
        Tuple of (truncated_content, metadata)
    """
    size_before = len(content)
    if size_before <= self.max_length:
        return content, {'truncated': False}

    # Strategies whose helper takes only the content; 'smart' is handled
    # separately because its helper also receives the tool name.
    single_arg_helpers = (
        ("end", "_truncate_end"),
        ("start", "_truncate_start"),
        ("preserve_blocks", "_preserve_blocks_truncate"),
    )
    requested = self.truncation_strategy
    applied = requested

    if requested == "smart":
        shortened = self._smart_truncate(content, tool_name)
    else:
        helper_name = next(
            (helper for key, helper in single_arg_helpers if key == requested),
            None,
        )
        if helper_name is None:
            # Default to end truncation
            helper_name, applied = "_truncate_end", 'end'
        shortened = getattr(self, helper_name)(content)

    size_after = len(shortened)
    report = {
        'truncated': True,
        'original_length': size_before,
        'truncated_length': size_after,
        'strategy': applied,
        'tool_name': tool_name,
        'chars_saved': size_before - size_after,
    }
    return shortened, report
def _truncate_end(self, content: str) -> str:
    """Keep the beginning of ``content`` and append the truncation marker.

    The marker is counted against ``max_length`` so the returned string
    never exceeds the limit (and therefore never ends up longer than the
    input). Content already within the limit is returned unchanged. When
    the marker alone does not fit, the content is hard-cut to
    ``max_length`` characters without a marker.
    """
    limit = max(self.max_length, 0)
    if len(content) <= limit:
        return content

    keep = limit - len(self.ellipsis)
    if keep <= 0:
        # No room for the marker: honouring the limit matters more.
        return content[:limit]
    return content[:keep] + self.ellipsis
def _truncate_start(self, content: str) -> str:
    """Keep the end of ``content`` and prepend the truncation marker.

    The marker is counted against ``max_length`` so the returned string
    never exceeds the limit. Content already within the limit is returned
    unchanged. When the marker alone does not fit, only the last
    ``max_length`` characters are returned, without a marker.
    """
    limit = max(self.max_length, 0)
    total = len(content)
    if total <= limit:
        return content

    keep = limit - len(self.ellipsis)
    if keep <= 0:
        # Slice with an explicit start index: ``content[-0:]`` would
        # return the whole string when the limit is zero.
        return content[total - limit:]
    return self.ellipsis + content[total - keep:]
def _smart_truncate(self, content: str, tool_name: str) -> str:
    """
    Truncate at a natural boundary where one exists close to the limit.

    Order of attempts:
    1. Content that is mostly fenced code, or that parses as JSON, is
       handed to the dedicated code/JSON helpers (when enabled).
    2. Otherwise cut at the last paragraph break, then the last sentence
       end, then the last line break, each only if it lies far enough
       into the allowed window (80% / 70% / 90% respectively).
    3. Fall back to a plain cut at the window size.

    Args:
        content: Text to shorten.
        tool_name: Accepted for interface compatibility; not used here.

    Returns:
        The shortened text followed by the truncation marker, or a bare
        hard cut when the marker itself does not fit in ``max_length``.
    """
    # Check if content is primarily code or JSON
    if self.preserve_code_blocks and self._is_mainly_code(content):
        return self._preserve_code_blocks(content)
    if self.preserve_json and self._is_json(content):
        return self._preserve_json(content)

    target_length = self.max_length - len(self.ellipsis)
    if target_length <= 0:
        # The marker does not fit. A non-positive slice bound would cut
        # from the wrong end, so hard-cut to the limit instead.
        return content[:max(self.max_length, 0)]
    window = content[:target_length]

    # Look for paragraph breaks
    paragraph_break = window.rfind('\n\n')
    if paragraph_break > target_length * 0.8:
        return content[:paragraph_break] + self.ellipsis

    # Look for sentence endings: use the latest terminator of any kind so
    # a '.' early in the window cannot win over a later '!' or '?'.
    sentence_end = max(window.rfind(mark) for mark in '.!?')
    if sentence_end > target_length * 0.7:
        return content[:sentence_end + 1] + self.ellipsis

    # Look for line breaks
    line_break = window.rfind('\n')
    if line_break > target_length * 0.9:
        return content[:line_break] + self.ellipsis

    # Fallback to simple truncation
    return window + self.ellipsis
def _preserve_blocks_truncate(self, content: str) -> str:
    """
    Preserve important blocks (code, JSON, etc.) while truncating the rest.

    Fenced code blocks and JSON-looking spans are kept whole, in document
    order, for as long as they fit. Text between kept blocks is included
    when there is room and shortened (ending in "...") when there is not.
    Any budget left after the last kept block is filled with the text
    that follows it. The result, including the truncation marker, never
    exceeds ``max_length``.

    Content already within the limit is returned unchanged. If no block
    can be kept, the result is a plain cut from the start of the content
    followed by the marker.
    """
    limit = max(self.max_length, 0)
    if len(content) <= limit:
        return content

    marker = self.ellipsis
    budget = limit - len(marker)
    if budget <= 0:
        # Not even the marker fits; a hard cut is all that is possible.
        return content[:limit]

    # Collect (start, end) spans of the blocks worth keeping intact.
    spans = []
    if self.preserve_code_blocks:
        spans.extend(m.span() for m in self.code_block_pattern.finditer(content))
    if self.preserve_json:
        code_spans = list(spans)
        for match in self.json_pattern.finditer(content):
            start, end = match.span()
            # Skip JSON-looking spans that overlap a code block in any
            # way, so no text is emitted twice.
            overlaps_code = any(
                start < code_end and code_start < end
                for code_start, code_end in code_spans
            )
            if not overlaps_code:
                spans.append((start, end))
    spans.sort()

    gap_marker = "..."
    pieces = []
    used = 0
    cursor = 0
    tail_stop = len(content)
    for start, end in spans:
        if start < cursor:
            continue  # already covered by a block kept earlier
        block_len = end - start
        if used + block_len > budget:
            # This block does not fit; trailing filler must stop before
            # it rather than cut into it.
            tail_stop = start
            break
        # Text before the block gets whatever room the block leaves.
        room = budget - used - block_len
        gap = content[cursor:start]
        if len(gap) > room:
            if room > len(gap_marker):
                gap = gap[:room - len(gap_marker)] + gap_marker
            else:
                gap = gap[:room]
        pieces.append(gap)
        pieces.append(content[start:end])
        used += len(gap) + block_len
        cursor = end

    # Spend any remaining budget on the text after the last kept block.
    remaining = budget - used
    if remaining > 0:
        pieces.append(content[cursor:tail_stop][:remaining])

    body = "".join(pieces)
    if not body:
        # Nothing could be kept (e.g. one oversized block at the start).
        body = content[:budget]
    return body + marker
def _is_mainly_code(self, content: str) -> bool:
    """Return True when fenced code blocks make up more than half of ``content``.

    The share is measured in characters, counting each full match of
    ``code_block_pattern`` (fences included). Empty content is never
    considered code.
    """
    code_chars = sum(
        match.end() - match.start()
        for match in self.code_block_pattern.finditer(content)
    )
    return code_chars > len(content) * 0.5
def _is_json(self, content: str) -> bool:
    """Return True when ``content`` parses as a JSON document.

    Only parse failures are treated as "not JSON": malformed text
    (``ValueError``), non-text input (``TypeError``) and pathologically
    deep nesting (``RecursionError``). Other exceptions, such as
    ``KeyboardInterrupt``, are allowed to propagate.
    """
    import json

    try:
        json.loads(content)
    except (ValueError, TypeError, RecursionError):
        return False
    return True
def _preserve_code_blocks(self, content: str) -> str:
    """Shorten code-heavy content via the block-preserving truncation helper."""
    shortened = self._preserve_blocks_truncate(content)
    return shortened
def _preserve_json(self, content: str) -> str:
    """Shorten a JSON document while keeping it valid JSON where possible.

    Every string value longer than ``max_length // 2`` is cut to that
    length and suffixed with "...[truncated]...". The document is then
    re-serialised, first indented and, if that is too long, in compact
    form. When neither form fits within ``max_length`` (for example a
    long array of short values), or when the content cannot be parsed,
    the original text is truncated with ``_truncate_end`` instead, which
    gives up JSON validity.
    """
    import json

    string_limit = self.max_length // 2

    def shorten(node):
        # Recursively walk the document, shortening only string values.
        if isinstance(node, str):
            if len(node) > string_limit:
                return node[:string_limit] + "...[truncated]..."
            return node
        if isinstance(node, dict):
            return {key: shorten(value) for key, value in node.items()}
        if isinstance(node, list):
            return [shorten(item) for item in node]
        return node

    try:
        shortened = shorten(json.loads(content))
        pretty = json.dumps(shortened, ensure_ascii=False, indent=2)
        if len(pretty) <= self.max_length:
            return pretty
        # Indentation can add more characters than shortening removed;
        # try the densest valid encoding before giving up on structure.
        compact = json.dumps(shortened, ensure_ascii=False, separators=(',', ':'))
        if len(compact) <= self.max_length:
            return compact
    except (ValueError, TypeError, RecursionError):
        # Unparseable or too deeply nested: fall through to a plain cut.
        pass
    return self._truncate_end(content)
def _format_error_message(self, tool_name: str, error: Exception) -> str:
    """Format a tool failure as user-friendly text.

    The text names the failing tool only; ``error`` is not included, so
    no technical detail is exposed to the user.
    """
    # User-friendly error message that does not expose technical details
    friendly_text = f"工具调用失败:{tool_name} 暂时无法使用,请稍后重试。"
    return friendly_text
def _tool_failure_message(self, request: ToolCallRequest, tool_name: str, error: Exception) -> ToolMessage:
    """Log a failed tool call and turn it into an error ToolMessage.

    The message carries the user-friendly text from
    ``_format_error_message`` and is flagged with ``status="error"``.
    """
    logger.error(f"Tool execution failed for '{tool_name}': {error}", exc_info=error)
    return ToolMessage(
        content=self._format_error_message(tool_name, error),
        # The id may be missing or None; ToolMessage needs a string.
        tool_call_id=request.tool_call.get('id') or '',
        name=tool_name,
        status="error",
    )


def _truncate_tool_message(self, message: ToolMessage, tool_name: str, log_context: str = "") -> ToolMessage:
    """Return ``message`` with its text shortened to the configured limit.

    The message is returned as-is when its text is empty or within
    ``max_length``. Otherwise a copy is returned whose content is the
    truncated text; every other field of the original is carried over.
    When ``add_metadata`` is set, ``truncation_info`` and
    ``original_length`` are added to the copy's ``additional_kwargs``.
    Truncation statistics are updated and the event is logged.

    Args:
        message: Tool message to inspect.
        tool_name: Name used for strategy selection, metadata and logs.
        log_context: Suffix inserted in the log line (e.g. " in Command").
    """
    text = message.text
    if not text or len(text) <= self.max_length:
        return message

    truncated_content, metadata = self.truncate_content(text, tool_name)

    changes: Dict[str, Any] = {'content': truncated_content}
    if self.add_metadata:
        extra = dict(message.additional_kwargs)
        extra.update({
            'truncation_info': metadata,
            'original_length': len(text)
        })
        changes['additional_kwargs'] = extra
    # Copy rather than rebuild so id, status, artifact and the other
    # fields of the original message are not lost.
    truncated_message = message.model_copy(update=changes)

    self.stats['truncated_calls'] += 1
    self.stats['total_chars_saved'] += metadata.get('chars_saved', 0)
    logger.info(
        f"Tool output truncated for '{tool_name}'{log_context}: "
        f"{len(text)} -> {len(truncated_content)} chars "
        f"(saved {metadata.get('chars_saved', 0)} chars)"
    )
    return truncated_message


def _postprocess_tool_result(self, result: ToolMessage | Command, tool_name: str) -> ToolMessage | Command:
    """Apply length control to whatever a tool handler returned.

    A ToolMessage is truncated directly. For a Command whose ``update``
    is a dict holding a list of messages, each ToolMessage in that list
    is truncated (subject to the tool filters) and the list is replaced
    in place. Anything else is returned untouched.
    """
    if isinstance(result, ToolMessage):
        return self._truncate_tool_message(result, tool_name)

    if isinstance(result, Command):
        update = result.update
        # ``update`` may be None or a non-dict state update.
        if isinstance(update, dict) and isinstance(update.get('messages'), (list, tuple)):
            updated_messages = []
            for msg in update['messages']:
                if isinstance(msg, ToolMessage):
                    # ``name`` can be None; fall back to the called tool.
                    msg_tool_name = msg.name or tool_name
                    if self.should_process_tool(msg_tool_name):
                        msg = self._truncate_tool_message(msg, msg_tool_name, " in Command")
                updated_messages.append(msg)
            update['messages'] = updated_messages
    return result


def wrap_tool_call(
    self,
    request: ToolCallRequest,
    handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
    """
    Intercept and potentially truncate tool call responses.

    Tools rejected by ``should_process_tool`` are passed straight through.
    For all others the call is counted, executed, and its result is
    length-limited. A failing tool yields an error ToolMessage instead of
    an exception, except for LangGraph control-flow exceptions, which are
    re-raised.

    Args:
        request: The tool call being executed.
        handler: Callable that actually runs the tool.

    Returns:
        The handler's result, with over-long tool output truncated.
    """
    # Local import, in the same style as this file's function-scope
    # ``import json``; keeps the method self-contained.
    from langgraph.errors import GraphBubbleUp

    tool_name = request.tool_call['name']
    if not self.should_process_tool(tool_name):
        return handler(request)

    self.stats['total_tool_calls'] += 1
    try:
        result = handler(request)
    except GraphBubbleUp:
        # Interrupts and similar signals must reach the graph runtime.
        raise
    except Exception as e:
        # Return an error ToolMessage instead of re-raising.
        return self._tool_failure_message(request, tool_name, e)
    return self._postprocess_tool_result(result, tool_name)


async def awrap_tool_call(
    self,
    request: ToolCallRequest,
    handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
    """Async version of wrap_tool_call.

    Identical behaviour, except that ``handler`` is awaited.
    """
    from langgraph.errors import GraphBubbleUp

    tool_name = request.tool_call['name']
    if not self.should_process_tool(tool_name):
        return await handler(request)

    self.stats['total_tool_calls'] += 1
    try:
        result = await handler(request)
    except GraphBubbleUp:
        # Interrupts and similar signals must reach the graph runtime.
        raise
    except Exception as e:
        # Return an error ToolMessage instead of re-raising.
        return self._tool_failure_message(request, tool_name, e)
    return self._postprocess_tool_result(result, tool_name)
def get_stats(self) -> Dict[str, Any]:
    """Return a shallow copy of the truncation counters.

    The copy can be modified by the caller without affecting the
    middleware's own counters.
    """
    snapshot = dict(self.stats)
    return snapshot
def reset_stats(self):
    """Replace the truncation counters with a fresh, all-zero set."""
    counter_names = ('total_tool_calls', 'truncated_calls', 'total_chars_saved')
    self.stats = dict.fromkeys(counter_names, 0)