TOOL_OUTPUT_MAX_LENGTH

朱潮 2025-12-15 23:54:32 +08:00
parent d077b447f0
commit b6975e1762
3 changed files with 509 additions and 4 deletions

View File

@@ -13,7 +13,8 @@ from langgraph.checkpoint.memory import MemorySaver
 from utils.fastapi_utils import detect_provider
 from .guideline_middleware import GuidelineMiddleware
-from utils.settings import SUMMARIZATION_MAX_TOKENS
+from .tool_output_length_middleware import ToolOutputLengthMiddleware
+from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
 
 class LoggingCallbackHandler(BaseCallbackHandler):
     """Custom CallbackHandler that uses the project's logger for logging"""
@@ -167,6 +168,17 @@ async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None,
     # Build the middleware list
     middleware = [GuidelineMiddleware(bot_id, llm_instance, system, robot_type, language, user_identifier)]
 
+    # Add the tool-output length control middleware
+    tool_output_middleware = ToolOutputLengthMiddleware(
+        max_length=getattr(generate_cfg, 'tool_output_max_length', None) or TOOL_OUTPUT_MAX_LENGTH,
+        truncation_strategy=getattr(generate_cfg, 'tool_output_truncation_strategy', 'smart'),
+        tool_filters=getattr(generate_cfg, 'tool_output_filters', None),  # optionally restrict to specific tools
+        exclude_tools=getattr(generate_cfg, 'tool_output_exclude', []),  # tools exempt from truncation
+        preserve_code_blocks=getattr(generate_cfg, 'preserve_code_blocks', True),
+        preserve_json=getattr(generate_cfg, 'preserve_json', True)
+    )
+    middleware.append(tool_output_middleware)
+
     # Initialize the checkpointer and middleware
     checkpointer = None
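
A note on the wiring above: each knob falls back from the per-bot generate_cfg to a global default, and max_length is resolved with Python's or operator, so a per-bot value of 0 (or None) also falls back to TOOL_OUTPUT_MAX_LENGTH. A minimal sketch of that resolution logic, with GenerateCfg as a hypothetical stand-in for the real config object:

from dataclasses import dataclass
from typing import Optional

TOOL_OUTPUT_MAX_LENGTH = 7893  # the global default derived in the settings file (see below)

@dataclass
class GenerateCfg:  # hypothetical stand-in for the per-bot config
    tool_output_max_length: Optional[int] = None

def resolve_max_length(cfg) -> int:
    # Same precedence as the diff: a truthy per-bot value wins, else the global setting
    return getattr(cfg, 'tool_output_max_length', None) or TOOL_OUTPUT_MAX_LENGTH

assert resolve_max_length(GenerateCfg()) == TOOL_OUTPUT_MAX_LENGTH
assert resolve_max_length(GenerateCfg(tool_output_max_length=1000)) == 1000
assert resolve_max_length(GenerateCfg(tool_output_max_length=0)) == TOOL_OUTPUT_MAX_LENGTH  # 0 falls back too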

View File

@@ -0,0 +1,489 @@
"""
Tool Output Length Middleware for LangGraph Agents
This middleware provides configurable control over the length of tool outputs,
helping to manage context window usage and improve agent performance.
"""
import json
import logging
import re
from typing import Any, Awaitable, Callable, Dict, List, Optional

from langchain.agents.middleware import AgentMiddleware
from langchain.tools.tool_node import ToolCallRequest
from langchain_core.messages import ToolMessage
from langgraph.types import Command
logger = logging.getLogger('app')
class ToolOutputLengthMiddleware(AgentMiddleware):
"""
Middleware to control and truncate tool output length in LangGraph agents.
Features:
- Configurable maximum output length
- Multiple truncation strategies (end, start, smart, preserve_blocks)
- Selective tool filtering
- Metadata tracking for truncated responses
- Comprehensive logging
"""
def __init__(
self,
max_length: int = 2000,
truncation_strategy: str = "smart",
tool_filters: Optional[List[str]] = None,
exclude_tools: Optional[List[str]] = None,
add_metadata: bool = True,
preserve_code_blocks: bool = True,
preserve_json: bool = True,
ellipsis: str = "\n\n...[response truncated due to length]..."
):
"""
Initialize the ToolOutputLengthMiddleware.
Args:
max_length: Maximum character length for tool outputs
truncation_strategy: Strategy for truncation ('end', 'start', 'smart', 'preserve_blocks')
tool_filters: List of tool names to apply truncation to (None = all tools)
exclude_tools: List of tool names to exclude from truncation
add_metadata: Whether to add metadata about truncation
preserve_code_blocks: Whether to preserve code blocks in smart mode
preserve_json: Whether to preserve JSON structure in smart mode
ellipsis: Text to append when truncating
"""
self.max_length = max_length
self.truncation_strategy = truncation_strategy
self.tool_filters = tool_filters
self.exclude_tools = exclude_tools or []
self.add_metadata = add_metadata
self.preserve_code_blocks = preserve_code_blocks
self.preserve_json = preserve_json
self.ellipsis = ellipsis
# Statistics tracking
self.stats = {
'total_tool_calls': 0,
'truncated_calls': 0,
'total_chars_saved': 0
}
# Compile regex patterns for efficient matching
self.code_block_pattern = re.compile(r'```(\w+)?\n.*?```', re.DOTALL)
self.json_pattern = re.compile(r'\{.*?\}|\[.*?\]', re.DOTALL)
logger.info(f"ToolOutputLengthMiddleware initialized: max_length={max_length}, strategy={truncation_strategy}")
def should_process_tool(self, tool_name: str) -> bool:
"""Check if a tool should be processed based on filters."""
# If no filters specified, process all tools
if self.tool_filters is None:
return tool_name not in self.exclude_tools
# Check if tool is in filter list and not excluded
return tool_name in self.tool_filters and tool_name not in self.exclude_tools
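        # Filter semantics, with illustrative tool names:
        #   tool_filters=None,           exclude_tools=['calculator'] -> truncate everything except 'calculator'
        #   tool_filters=['web_search'], exclude_tools=[]             -> truncate only 'web_search'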
def truncate_content(self, content: str, tool_name: str) -> tuple[str, Dict[str, Any]]:
"""
Truncate content based on the configured strategy.
Returns:
Tuple of (truncated_content, metadata)
"""
original_length = len(content)
if original_length <= self.max_length:
return content, {'truncated': False}
metadata = {
'truncated': True,
'original_length': original_length,
'truncated_length': 0,
'strategy': self.truncation_strategy,
'tool_name': tool_name
}
if self.truncation_strategy == "end":
truncated = self._truncate_end(content)
elif self.truncation_strategy == "start":
truncated = self._truncate_start(content)
elif self.truncation_strategy == "smart":
truncated = self._smart_truncate(content, tool_name)
elif self.truncation_strategy == "preserve_blocks":
truncated = self._preserve_blocks_truncate(content)
else:
# Default to end truncation
truncated = self._truncate_end(content)
metadata['strategy'] = 'end'
metadata['truncated_length'] = len(truncated)
metadata['chars_saved'] = original_length - len(truncated)
return truncated, metadata
    def _truncate_end(self, content: str) -> str:
        """Simple truncation from the end, budgeting for the ellipsis so the result stays within max_length."""
        return content[:self.max_length - len(self.ellipsis)] + self.ellipsis

    def _truncate_start(self, content: str) -> str:
        """Truncate from the start, keeping the end, budgeting for the ellipsis."""
        return self.ellipsis + content[-(self.max_length - len(self.ellipsis)):]
def _smart_truncate(self, content: str, tool_name: str) -> str:
"""
Smart truncation that tries to preserve important content.
Strategies:
1. Look for natural break points (paragraphs, sentences)
2. Preserve code blocks and JSON if enabled
3. Try to maintain context around key information
"""
# Check if content is primarily code or JSON
if self.preserve_code_blocks and self._is_mainly_code(content):
return self._preserve_code_blocks(content)
if self.preserve_json and self._is_json(content):
return self._preserve_json(content)
# Try to find a good breaking point
target_length = self.max_length - len(self.ellipsis)
truncated_part = content[:target_length]
# Look for paragraph breaks
paragraph_break = truncated_part.rfind('\n\n')
if paragraph_break > target_length * 0.8:
return content[:paragraph_break] + self.ellipsis
# Look for sentence endings
for end_char in ['.', '!', '?']:
last_pos = truncated_part.rfind(end_char)
if last_pos > target_length * 0.7:
return content[:last_pos + 1] + self.ellipsis
# Look for line breaks
line_break = truncated_part.rfind('\n')
if line_break > target_length * 0.9:
return content[:line_break] + self.ellipsis
# Fallback to simple truncation
return truncated_part + self.ellipsis
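        # Worked example of the thresholds above: the budget is
        # max_length - len(ellipsis); a paragraph break is used only if it lies
        # in the last 20% of that budget, a sentence end in the last 30%, a lone
        # newline in the last 10%; otherwise the text is cut mid-line.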
def _preserve_blocks_truncate(self, content: str) -> str:
"""
Preserve important blocks (code, JSON, etc.) while truncating the rest.
"""
# Find all important blocks
important_blocks = []
if self.preserve_code_blocks:
for match in self.code_block_pattern.finditer(content):
important_blocks.append({
'start': match.start(),
'end': match.end(),
'content': match.group(),
'type': 'code'
})
if self.preserve_json:
for match in self.json_pattern.finditer(content):
# Avoid duplicating code blocks that contain JSON
if not any(b['start'] <= match.start() <= b['end'] for b in important_blocks):
important_blocks.append({
'start': match.start(),
'end': match.end(),
'content': match.group(),
'type': 'json'
})
# Sort blocks by position
important_blocks.sort(key=lambda x: x['start'])
# Calculate total preserved content
preserved_content = ""
current_pos = 0
for block in important_blocks:
# Add content before the block (truncated if needed)
before_length = block['start'] - current_pos
if before_length > 0:
available_space = self.max_length - len(preserved_content) - len(block['content']) - len(self.ellipsis)
if available_space > 0:
before_content = content[current_pos:block['start']]
if len(before_content) > available_space:
before_content = before_content[:available_space] + "..."
preserved_content += before_content
# Add the preserved block
if len(preserved_content) + len(block['content']) <= self.max_length:
preserved_content += block['content']
current_pos = block['end']
else:
break
# Add ellipsis if content was truncated
if len(preserved_content) < len(content):
preserved_content += self.ellipsis
return preserved_content if preserved_content else content[:self.max_length] + self.ellipsis
    def _is_mainly_code(self, content: str) -> bool:
        """Check if content is primarily fenced code blocks (more than half of its characters)."""
        code_chars = sum(len(m.group(0)) for m in self.code_block_pattern.finditer(content))
        return code_chars > len(content) * 0.5
    def _is_json(self, content: str) -> bool:
        """Check if content parses as JSON."""
        try:
            json.loads(content)
            return True
        except (json.JSONDecodeError, TypeError):
            return False
def _preserve_code_blocks(self, content: str) -> str:
"""Preserve code blocks while truncating other content."""
return self._preserve_blocks_truncate(content)
    def _preserve_json(self, content: str) -> str:
        """Preserve JSON structure while truncating overly long string values."""
        try:
            data = json.loads(content)

            # Recursively truncate string values
            def truncate_strings(obj):
                if isinstance(obj, str):
                    if len(obj) > self.max_length // 2:
                        return obj[:self.max_length // 2] + "...[truncated]..."
                    return obj
                elif isinstance(obj, dict):
                    return {k: truncate_strings(v) for k, v in obj.items()}
                elif isinstance(obj, list):
                    return [truncate_strings(item) for item in obj]
                return obj

            truncated_data = truncate_strings(data)
            return json.dumps(truncated_data, ensure_ascii=False, indent=2)
        except (json.JSONDecodeError, TypeError):
            return self._truncate_end(content)
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""
Intercept and potentially truncate tool call responses.
This method is called when a tool is invoked and allows us to
intercept and modify the tool's response before it's returned
to the agent.
"""
# Get tool name from the request
tool_name = request.tool_call['name']
# Check if this tool should be processed
if not self.should_process_tool(tool_name):
# If not in filters, just pass through
return handler(request)
# Update statistics
self.stats['total_tool_calls'] += 1
# Execute the tool to get the response
try:
result = handler(request)
except Exception as e:
logger.error(f"Tool execution failed for '{tool_name}': {e}")
raise
# Handle different return types
if isinstance(result, ToolMessage):
# Process ToolMessage
content = result.text
if content and len(content) > self.max_length:
# Truncate the content
truncated_content, metadata = self.truncate_content(content, tool_name)
                # Create a new ToolMessage with truncated content, carrying over
                # the tool_call_id, name, and status of the original message
                truncated_message = ToolMessage(
                    content=truncated_content,
                    tool_call_id=result.tool_call_id,
                    name=result.name,
                    status=result.status
                )
# Add metadata if requested
if self.add_metadata:
truncated_message.additional_kwargs = result.additional_kwargs.copy()
truncated_message.additional_kwargs.update({
'truncation_info': metadata,
'original_length': len(content)
})
# Update statistics
self.stats['truncated_calls'] += 1
self.stats['total_chars_saved'] += metadata.get('chars_saved', 0)
logger.info(
f"Tool output truncated for '{tool_name}': "
f"{len(content)} -> {len(truncated_content)} chars "
f"(saved {metadata.get('chars_saved', 0)} chars)"
)
return truncated_message
        elif isinstance(result, Command):
            # For Command objects, inspect the state update for ToolMessages
            # that may need truncation (guard against a None update)
            if isinstance(result.update, dict) and 'messages' in result.update:
messages = result.update['messages']
updated_messages = []
for msg in messages:
if isinstance(msg, ToolMessage):
                    # Fall back to the current tool's name if the message has none
                    tool_msg_tool_name = getattr(msg, 'name', None) or tool_name
                    if self.should_process_tool(tool_msg_tool_name) and isinstance(msg.content, str) and len(msg.content) > self.max_length:
# Truncate the ToolMessage content
truncated_content, metadata = self.truncate_content(msg.content, tool_msg_tool_name)
                            # Create a new ToolMessage, preserving id, name, and status
                            truncated_msg = ToolMessage(
                                content=truncated_content,
                                tool_call_id=msg.tool_call_id,
                                name=msg.name,
                                status=msg.status
                            )
# Add metadata if requested
if self.add_metadata:
truncated_msg.additional_kwargs = msg.additional_kwargs.copy()
truncated_msg.additional_kwargs.update({
'truncation_info': metadata,
'original_length': len(msg.content)
})
# Update statistics
self.stats['truncated_calls'] += 1
self.stats['total_chars_saved'] += metadata.get('chars_saved', 0)
logger.info(
f"Tool output truncated for '{tool_msg_tool_name}' in Command: "
f"{len(msg.content)} -> {len(truncated_content)} chars "
f"(saved {metadata.get('chars_saved', 0)} chars)"
)
updated_messages.append(truncated_msg)
else:
updated_messages.append(msg)
else:
updated_messages.append(msg)
# Update the Command with the modified messages
result.update['messages'] = updated_messages
return result
    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
"""Async version of wrap_tool_call."""
tool_name = request.tool_call['name']
if not self.should_process_tool(tool_name):
return await handler(request)
self.stats['total_tool_calls'] += 1
# Execute the tool asynchronously
try:
result = await handler(request)
except Exception as e:
logger.error(f"Tool execution failed for '{tool_name}': {e}")
raise
# Handle different return types (same logic as sync version)
if isinstance(result, ToolMessage):
content = result.text
if content and len(content) > self.max_length:
truncated_content, metadata = self.truncate_content(content, tool_name)
                truncated_message = ToolMessage(
                    content=truncated_content,
                    tool_call_id=result.tool_call_id,
                    name=result.name,
                    status=result.status
                )
if self.add_metadata:
truncated_message.additional_kwargs = result.additional_kwargs.copy()
truncated_message.additional_kwargs.update({
'truncation_info': metadata,
'original_length': len(content)
})
self.stats['truncated_calls'] += 1
self.stats['total_chars_saved'] += metadata.get('chars_saved', 0)
logger.info(
f"Tool output truncated for '{tool_name}': "
f"{len(content)} -> {len(truncated_content)} chars "
f"(saved {metadata.get('chars_saved', 0)} chars)"
)
return truncated_message
        elif isinstance(result, Command) and isinstance(result.update, dict) and 'messages' in result.update:
messages = result.update['messages']
updated_messages = []
for msg in messages:
if isinstance(msg, ToolMessage):
                    tool_msg_tool_name = getattr(msg, 'name', None) or tool_name
                    if self.should_process_tool(tool_msg_tool_name) and isinstance(msg.content, str) and len(msg.content) > self.max_length:
truncated_content, metadata = self.truncate_content(msg.content, tool_msg_tool_name)
                            truncated_msg = ToolMessage(
                                content=truncated_content,
                                tool_call_id=msg.tool_call_id,
                                name=msg.name,
                                status=msg.status
                            )
if self.add_metadata:
truncated_msg.additional_kwargs = msg.additional_kwargs.copy()
truncated_msg.additional_kwargs.update({
'truncation_info': metadata,
'original_length': len(msg.content)
})
self.stats['truncated_calls'] += 1
self.stats['total_chars_saved'] += metadata.get('chars_saved', 0)
logger.info(
f"Tool output truncated for '{tool_msg_tool_name}' in Command: "
f"{len(msg.content)} -> {len(truncated_content)} chars "
f"(saved {metadata.get('chars_saved', 0)} chars)"
)
updated_messages.append(truncated_msg)
else:
updated_messages.append(msg)
else:
updated_messages.append(msg)
result.update['messages'] = updated_messages
return result
def get_stats(self) -> Dict[str, Any]:
"""Get statistics about truncation activity."""
return self.stats.copy()
def reset_stats(self):
"""Reset truncation statistics."""
self.stats = {
'total_tool_calls': 0,
'truncated_calls': 0,
'total_chars_saved': 0
}
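
For reference, the truncation logic can be exercised in isolation, without running an agent. A minimal sketch, assuming this module is importable under the name used by the import in the first file of this commit:

from tool_output_length_middleware import ToolOutputLengthMiddleware

mw = ToolOutputLengthMiddleware(max_length=200, truncation_strategy="smart")

long_text = ("First paragraph. " * 10) + "\n\n" + ("Second paragraph. " * 10)
short, meta = mw.truncate_content(long_text, tool_name="web_search")

print(meta["truncated"], meta["original_length"], "->", meta["truncated_length"])
print(mw.get_stats())  # still all zeros: stats are only updated inside wrap_tool_call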

View File

@@ -1,7 +1,7 @@
 import os
 
 # LLM Token Settings
-MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 65536))
+MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 32679))
 MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000))
 SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000
@@ -21,7 +21,7 @@ FILE_CACHE_SIZE = int(os.getenv("FILE_CACHE_SIZE", 1000))
 FILE_CACHE_TTL = int(os.getenv("FILE_CACHE_TTL", 300))
 
 # API Settings
-BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
+BACKEND_HOST = os.getenv("BACKEND_HOST", "http://backend:8000")
 MASTERKEY = os.getenv("MASTERKEY", "master")
 FASTAPI_URL = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
@@ -33,3 +33,7 @@ TOKENIZERS_PARALLELISM = os.getenv("TOKENIZERS_PARALLELISM", "true")
 
 # Embedding Model Settings
 SENTENCE_TRANSFORMER_MODEL = os.getenv("SENTENCE_TRANSFORMER_MODEL", "TaylorAI/gte-tiny")
+
+# Tool Output Length Control Settings
+TOOL_OUTPUT_MAX_LENGTH = SUMMARIZATION_MAX_TOKENS // 3
+TOOL_OUTPUT_TRUNCATION_STRATEGY = os.getenv("TOOL_OUTPUT_TRUNCATION_STRATEGY", "smart")
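
A worked example of the default budget under this commit's new defaults (note that max_length in the middleware is compared against characters, so the divide-by-three is presumably a conservative tokens-to-characters margin):

# SUMMARIZATION_MAX_TOKENS = 32679 - 8000 - 1000 = 23679   (tokens)
# TOOL_OUTPUT_MAX_LENGTH   = 23679 // 3          = 7893    (used as a character cap)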