506 lines
20 KiB
Python
506 lines
20 KiB
Python
"""
Tool Output Length Middleware for LangGraph Agents

This middleware provides configurable control over the length of tool outputs,
helping to manage context window usage and improve agent performance.
"""
|
|
|
|
import logging
|
|
import re
|
|
from typing import Any, Dict, List, Optional, Union, Callable
|
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
|
from langchain.tools.tool_node import ToolCallRequest
|
|
from langchain_core.messages import BaseMessage, ToolMessage, AIMessage
|
|
from langgraph.types import Command
|
|
from langgraph.runtime import Runtime
|
|
|
|
# Shared module logger. It deliberately uses the fixed "app" name rather than
# __name__, presumably so it inherits handlers configured for "app" elsewhere.
logger: logging.Logger = logging.getLogger(name="app")
|
|
|
|
|
|
class ToolOutputLengthMiddleware(AgentMiddleware):
    """
    Middleware to control and truncate tool output length in LangGraph agents.

    Features:
    - Configurable maximum output length
    - Multiple truncation strategies (end, start, smart, preserve_blocks)
    - Selective tool filtering
    - Metadata tracking for truncated responses
    - Comprehensive logging

    Truncated outputs, ellipsis included, are kept within ``max_length``
    characters as long as ``max_length`` is not shorter than the ellipsis.
    Exceptions raised by a processed tool are converted into an error
    ``ToolMessage``, except LangGraph control-flow exceptions (interrupts,
    parent commands), which are re-raised so the graph can handle them.
    """

    # Strategies understood by ``truncate_content``; anything else falls back to "end".
    _SUPPORTED_STRATEGIES = ("end", "start", "smart", "preserve_blocks")

    # Marker placed where the text between two preserved blocks was cut short.
    _GAP_MARKER = "..."

    def __init__(
        self,
        max_length: int = 2000,
        truncation_strategy: str = "smart",
        tool_filters: Optional[List[str]] = None,
        exclude_tools: Optional[List[str]] = None,
        add_metadata: bool = True,
        preserve_code_blocks: bool = True,
        preserve_json: bool = True,
        ellipsis: str = "\n\n...[response truncated due to length]..."
    ):
        """
        Initialize the ToolOutputLengthMiddleware.

        Args:
            max_length: Maximum character length for tool outputs (must be >= 0).
                The ellipsis counts toward this limit.
            truncation_strategy: Strategy for truncation ('end', 'start', 'smart',
                'preserve_blocks'). Unknown values log a warning and behave like 'end'.
            tool_filters: List of tool names to apply truncation to (None = all tools)
            exclude_tools: List of tool names to exclude from truncation
            add_metadata: Whether to add metadata about truncation to the
                message's ``additional_kwargs``
            preserve_code_blocks: Whether to preserve code blocks in smart mode
            preserve_json: Whether to preserve JSON structure in smart mode
            ellipsis: Text appended (or prepended for 'start') when truncating

        Raises:
            ValueError: If ``max_length`` is negative.
        """
        if max_length < 0:
            raise ValueError(f"max_length must be non-negative, got {max_length}")
        if truncation_strategy not in self._SUPPORTED_STRATEGIES:
            # Non-fatal for backward compatibility; truncate_content falls back to 'end'.
            logger.warning(
                f"Unknown truncation_strategy '{truncation_strategy}'; falling back to 'end'. "
                f"Supported: {', '.join(self._SUPPORTED_STRATEGIES)}"
            )

        self.max_length = max_length
        self.truncation_strategy = truncation_strategy
        self.tool_filters = tool_filters
        self.exclude_tools = exclude_tools or []
        self.add_metadata = add_metadata
        self.preserve_code_blocks = preserve_code_blocks
        self.preserve_json = preserve_json
        self.ellipsis = ellipsis

        # Statistics tracking
        self.stats = self._empty_stats()

        # Compile regex patterns for efficient matching.
        self.code_block_pattern = re.compile(r'```(\w+)?\n.*?```', re.DOTALL)
        # NOTE: heuristic only. Non-greedy matching stops at the first closing
        # bracket, so nested JSON is captured only partially.
        self.json_pattern = re.compile(r'\{.*?\}|\[.*?\]', re.DOTALL)

        logger.info(f"ToolOutputLengthMiddleware initialized: max_length={max_length}, strategy={truncation_strategy}")

    @staticmethod
    def _empty_stats() -> Dict[str, int]:
        """Return a fresh, zeroed statistics dictionary."""
        return {
            'total_tool_calls': 0,
            'truncated_calls': 0,
            'total_chars_saved': 0
        }

    def should_process_tool(self, tool_name: str) -> bool:
        """Return True if ``tool_name`` passes the include filter and is not excluded."""
        if tool_name in self.exclude_tools:
            return False
        # No include filter means every (non-excluded) tool is processed.
        return self.tool_filters is None or tool_name in self.tool_filters

    def truncate_content(self, content: str, tool_name: str) -> tuple[str, Dict[str, Any]]:
        """
        Truncate content based on the configured strategy.

        Args:
            content: The tool output text.
            tool_name: Name of the tool that produced it (recorded in metadata).

        Returns:
            Tuple of (truncated_content, metadata). Content that already fits
            is returned unchanged with ``{'truncated': False}``.
        """
        original_length = len(content)

        if original_length <= self.max_length:
            return content, {'truncated': False}

        strategy = self.truncation_strategy
        if strategy == "end":
            truncated = self._truncate_end(content)
        elif strategy == "start":
            truncated = self._truncate_start(content)
        elif strategy == "smart":
            truncated = self._smart_truncate(content, tool_name)
        elif strategy == "preserve_blocks":
            truncated = self._preserve_blocks_truncate(content)
        else:
            # Unknown strategy: default to end truncation and report it as such.
            truncated = self._truncate_end(content)
            strategy = 'end'

        return truncated, {
            'truncated': True,
            'original_length': original_length,
            'truncated_length': len(truncated),
            'strategy': strategy,
            'tool_name': tool_name,
            'chars_saved': original_length - len(truncated),
        }

    def _text_budget(self) -> int:
        """Number of original characters that fit alongside the ellipsis (never negative)."""
        return max(0, self.max_length - len(self.ellipsis))

    def _truncate_end(self, content: str) -> str:
        """Keep the beginning of ``content`` and append the ellipsis."""
        return content[:self._text_budget()] + self.ellipsis

    def _truncate_start(self, content: str) -> str:
        """Keep the end of ``content`` and prepend the ellipsis."""
        keep = self._text_budget()
        # Explicit start index: ``content[-0:]`` would return the whole string.
        return self.ellipsis + content[max(0, len(content) - keep):]

    def _smart_truncate(self, content: str, tool_name: str) -> str:
        """
        Smart truncation that tries to preserve important content.

        Strategies:
        1. Preserve code blocks and JSON if enabled and the content is mostly that
        2. Otherwise cut at a natural break point (paragraph, sentence, line)
           close to the length budget
        3. Fall back to a hard cut at the budget
        """
        # Check if content is primarily code or JSON
        if self.preserve_code_blocks and self._is_mainly_code(content):
            return self._preserve_code_blocks(content)

        if self.preserve_json and self._is_json(content):
            return self._preserve_json(content)

        target_length = self._text_budget()
        head = content[:target_length]

        # Prefer a paragraph break in the last 20% of the budget.
        paragraph_break = head.rfind('\n\n')
        if paragraph_break > target_length * 0.8:
            return content[:paragraph_break] + self.ellipsis

        # Then the latest sentence end (of any kind) in the last 30%.
        sentence_end = max(head.rfind(end_char) for end_char in '.!?')
        if sentence_end > target_length * 0.7:
            return content[:sentence_end + 1] + self.ellipsis

        # Then a line break in the last 10%.
        line_break = head.rfind('\n')
        if line_break > target_length * 0.9:
            return content[:line_break] + self.ellipsis

        # Fallback to simple truncation
        return head + self.ellipsis

    def _find_important_blocks(self, content: str) -> List[tuple[int, int]]:
        """Return sorted (start, end) spans of code blocks and JSON-looking text to keep intact."""
        spans: List[tuple[int, int]] = []

        if self.preserve_code_blocks:
            spans.extend(match.span() for match in self.code_block_pattern.finditer(content))

        if self.preserve_json:
            code_spans = list(spans)
            for match in self.json_pattern.finditer(content):
                # Skip JSON-looking text that starts inside a code block.
                if not any(start <= match.start() < end for start, end in code_spans):
                    spans.append(match.span())

        spans.sort()
        return spans

    def _preserve_blocks_truncate(self, content: str) -> str:
        """
        Preserve important blocks (code, JSON) while truncating the rest.

        Blocks are kept verbatim in document order until the next one no
        longer fits. Text between kept blocks is shortened (marked with
        "...") to make room. Any budget left over is filled with the text that
        follows the last kept block. The result, ellipsis included, stays within
        ``max_length`` whenever ``max_length >= len(self.ellipsis)``.
        """
        if len(content) <= self.max_length:
            return content

        budget = self._text_budget()
        marker = self._GAP_MARKER
        parts: List[str] = []
        used = 0
        pos = 0

        for start, end in self._find_important_blocks(content):
            if start < pos:
                continue  # overlaps a span that was already kept
            block_len = end - start
            room = budget - used - block_len
            if room < 0:
                break  # this block, and everything after it, cannot fit
            gap = content[pos:start]
            if len(gap) > room:
                gap = gap[:room - len(marker)] + marker if room >= len(marker) else ""
            parts.append(gap)
            parts.append(content[start:end])
            used += len(gap) + block_len
            pos = end

        # Spend whatever budget is left on the text that follows the last kept block.
        parts.append(content[pos:pos + (budget - used)])

        result = "".join(parts)
        if len(result) < len(content):
            result += self.ellipsis
        return result

    def _is_mainly_code(self, content: str) -> bool:
        """Return True if fenced code blocks make up more than half of ``content``."""
        code_chars = sum(match.end() - match.start() for match in self.code_block_pattern.finditer(content))
        return code_chars > len(content) * 0.5

    def _is_json(self, content: str) -> bool:
        """Return True if ``content`` parses as JSON."""
        import json

        try:
            json.loads(content)
        except (TypeError, ValueError, RecursionError):
            return False
        return True

    def _preserve_code_blocks(self, content: str) -> str:
        """Preserve code blocks while truncating other content."""
        return self._preserve_blocks_truncate(content)

    def _preserve_json(self, content: str) -> str:
        """
        Preserve JSON structure while shortening long string values.

        String values longer than ``max_length // 2`` are cut and tagged with
        "...[truncated]...", and the document is re-rendered with ``indent=2``.
        If the re-rendered JSON still exceeds ``max_length`` (pretty-printing can
        even grow it), it is end-truncated. Unparseable input is end-truncated
        as-is.
        """
        import json

        string_cap = self.max_length // 2

        def truncate_strings(obj):
            # Recursively shorten string values; other scalars pass through.
            if isinstance(obj, str):
                return obj[:string_cap] + "...[truncated]..." if len(obj) > string_cap else obj
            if isinstance(obj, dict):
                return {k: truncate_strings(v) for k, v in obj.items()}
            if isinstance(obj, list):
                return [truncate_strings(item) for item in obj]
            return obj

        try:
            data = json.loads(content)
            rendered = json.dumps(truncate_strings(data), ensure_ascii=False, indent=2)
        except (TypeError, ValueError, RecursionError):
            return self._truncate_end(content)

        if len(rendered) <= self.max_length:
            return rendered
        return self._truncate_end(rendered)

    def _format_error_message(self, tool_name: str, error: Exception) -> str:
        """Format a tool failure as user-friendly text."""
        # User-facing message; intentionally omits technical details of ``error``.
        return f"工具调用失败:{tool_name} 暂时无法使用,请稍后重试。"

    @staticmethod
    def _is_graph_control_flow(error: BaseException) -> bool:
        """Return True for LangGraph exceptions (interrupt, ParentCommand) that must propagate."""
        from langgraph.errors import GraphBubbleUp

        return isinstance(error, GraphBubbleUp)

    def _build_error_message(self, request: ToolCallRequest, tool_name: str, error: Exception) -> ToolMessage:
        """Log a tool failure (with traceback) and wrap it in an error ToolMessage."""
        logger.error(f"Tool execution failed for '{tool_name}': {error}", exc_info=True)
        return ToolMessage(
            content=self._format_error_message(tool_name, error),
            # ``or ''`` also covers an explicit None id, which ToolMessage rejects.
            tool_call_id=request.tool_call.get('id') or '',
            name=tool_name,
            status='error',
        )

    def _maybe_truncate_message(self, message: ToolMessage, tool_name: str, in_command: bool = False) -> ToolMessage:
        """
        Return ``message`` unchanged if it fits, else a truncated copy.

        The copy keeps every other field of the original message (id, status,
        artifact, ...). Truncation metadata is merged into ``additional_kwargs``
        when ``add_metadata`` is enabled. Statistics are updated as a side effect.
        """
        content = message.text
        if not content or len(content) <= self.max_length:
            return message

        truncated_content, metadata = self.truncate_content(content, tool_name)

        update: Dict[str, Any] = {'content': truncated_content}
        if self.add_metadata:
            update['additional_kwargs'] = {
                **message.additional_kwargs,
                'truncation_info': metadata,
                'original_length': len(content),
            }
        truncated_message = message.model_copy(update=update)

        self.stats['truncated_calls'] += 1
        self.stats['total_chars_saved'] += metadata.get('chars_saved', 0)

        location = " in Command" if in_command else ""
        logger.info(
            f"Tool output truncated for '{tool_name}'{location}: "
            f"{len(content)} -> {len(truncated_content)} chars "
            f"(saved {metadata.get('chars_saved', 0)} chars)"
        )

        return truncated_message

    def _truncate_command_messages(self, command: Command, default_tool_name: str) -> None:
        """Truncate oversized ToolMessages inside ``command.update['messages']`` in place."""
        update = command.update
        # ``update`` may be None or a non-dict value; only dict updates are inspected.
        if not isinstance(update, dict) or 'messages' not in update:
            return
        messages = update['messages']
        if not isinstance(messages, (list, tuple)):
            return

        updated_messages = []
        for msg in messages:
            if isinstance(msg, ToolMessage):
                # ``name`` always exists but may be None; fall back to the invoked tool.
                msg_tool_name = msg.name or default_tool_name
                if self.should_process_tool(msg_tool_name):
                    msg = self._maybe_truncate_message(msg, msg_tool_name, in_command=True)
            updated_messages.append(msg)

        update['messages'] = updated_messages

    def _process_result(self, result: ToolMessage | Command, tool_name: str) -> ToolMessage | Command:
        """Apply truncation to a tool result, whichever supported type it is."""
        if isinstance(result, ToolMessage):
            return self._maybe_truncate_message(result, tool_name)
        if isinstance(result, Command):
            self._truncate_command_messages(result, tool_name)
        return result

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        """
        Intercept and potentially truncate tool call responses.

        Tools rejected by ``should_process_tool`` are passed straight through,
        and their exceptions propagate. For processed tools:
        - exceptions become an error ToolMessage, except LangGraph
          control-flow exceptions, which are re-raised;
        - oversized outputs are truncated.
        """
        # Get tool name from the request
        tool_name = request.tool_call['name']

        if not self.should_process_tool(tool_name):
            return handler(request)

        self.stats['total_tool_calls'] += 1

        try:
            result = handler(request)
        except Exception as e:
            if self._is_graph_control_flow(e):
                raise
            # Return an error ToolMessage instead of re-raising.
            return self._build_error_message(request, tool_name, e)

        return self._process_result(result, tool_name)

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        """
        Async version of wrap_tool_call.

        ``handler`` is awaited here, so it must return an awaitable.
        """
        tool_name = request.tool_call['name']

        if not self.should_process_tool(tool_name):
            return await handler(request)

        self.stats['total_tool_calls'] += 1

        try:
            result = await handler(request)
        except Exception as e:
            if self._is_graph_control_flow(e):
                raise
            # Return an error ToolMessage instead of re-raising.
            return self._build_error_message(request, tool_name, e)

        return self._process_result(result, tool_name)

    def get_stats(self) -> Dict[str, Any]:
        """Get a copy of the truncation statistics."""
        return self.stats.copy()

    def reset_stats(self):
        """Reset truncation statistics."""
        self.stats = self._empty_stats()