diff --git a/agent/deep_assistant.py b/agent/deep_assistant.py index f49a63f..aa6a00c 100644 --- a/agent/deep_assistant.py +++ b/agent/deep_assistant.py @@ -13,7 +13,8 @@ from langgraph.checkpoint.memory import MemorySaver from utils.fastapi_utils import detect_provider from .guideline_middleware import GuidelineMiddleware -from utils.settings import SUMMARIZATION_MAX_TOKENS +from .tool_output_length_middleware import ToolOutputLengthMiddleware +from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH class LoggingCallbackHandler(BaseCallbackHandler): """自定义的 CallbackHandler,使用项目的 logger 来记录日志""" @@ -167,6 +168,17 @@ async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None, # 构建中间件列表 middleware = [GuidelineMiddleware(bot_id, llm_instance, system, robot_type, language, user_identifier)] + # 添加工具输出长度控制中间件 + tool_output_middleware = ToolOutputLengthMiddleware( + max_length=getattr(generate_cfg, 'tool_output_max_length', None) or TOOL_OUTPUT_MAX_LENGTH, + truncation_strategy=getattr(generate_cfg, 'tool_output_truncation_strategy', 'smart'), + tool_filters=getattr(generate_cfg, 'tool_output_filters', None), # 可配置特定工具 + exclude_tools=getattr(generate_cfg, 'tool_output_exclude', []), # 排除的工具 + preserve_code_blocks=getattr(generate_cfg, 'preserve_code_blocks', True), + preserve_json=getattr(generate_cfg, 'preserve_json', True) + ) + middleware.append(tool_output_middleware) + # 初始化 checkpointer 和中间件 checkpointer = None diff --git a/agent/tool_output_length_middleware.py b/agent/tool_output_length_middleware.py new file mode 100644 index 0000000..1afaff8 --- /dev/null +++ b/agent/tool_output_length_middleware.py @@ -0,0 +1,489 @@ +""" +Tool Output Length Middleware for LangGraph Agents + +This middleware provides configurable control over the length of tool outputs, +helping to manage context window usage and improve agent performance. 
+"""
+
+import json
+import logging
+import re
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Union
+from langchain.agents.middleware import AgentMiddleware, AgentState
+from langchain.tools.tool_node import ToolCallRequest
+from langchain_core.messages import BaseMessage, ToolMessage, AIMessage
+from langgraph.types import Command
+from langgraph.runtime import Runtime
+
+logger = logging.getLogger('app')
+
+
+class ToolOutputLengthMiddleware(AgentMiddleware):
+    """
+    Middleware that caps the length of tool outputs in LangGraph agents.
+
+    Tool results whose text is longer than ``max_length`` characters are
+    shortened before they are handed back to the agent. Every strategy
+    returns at most ``max_length`` characters, truncation marker included.
+
+    Features:
+    - Configurable maximum output length
+    - Multiple truncation strategies (end, start, smart, preserve_blocks)
+    - Selective tool filtering
+    - Metadata tracking for truncated responses
+    - Logging and simple statistics
+    """
+
+    # Strategies understood by truncate_content(); anything else falls back to "end".
+    VALID_STRATEGIES = ("end", "start", "smart", "preserve_blocks")
+
+    # Sentence terminators used by the smart strategy (ASCII and CJK).
+    _SENTENCE_ENDINGS = ".!?\u3002\uff01\uff1f"
+
+    # Marker appended to free text that is cut short in front of a preserved block.
+    _GAP_MARKER = "..."
+
+    def __init__(
+        self,
+        max_length: int = 2000,
+        truncation_strategy: str = "smart",
+        tool_filters: Optional[List[str]] = None,
+        exclude_tools: Optional[List[str]] = None,
+        add_metadata: bool = True,
+        preserve_code_blocks: bool = True,
+        preserve_json: bool = True,
+        ellipsis: str = "\n\n...[response truncated due to length]..."
+    ):
+        """
+        Initialize the ToolOutputLengthMiddleware.
+
+        Args:
+            max_length: Maximum character length for tool outputs (>= 0)
+            truncation_strategy: Strategy for truncation ('end', 'start', 'smart', 'preserve_blocks')
+            tool_filters: List of tool names to apply truncation to (None = all tools)
+            exclude_tools: List of tool names to exclude from truncation
+            add_metadata: Whether to add metadata about truncation
+            preserve_code_blocks: Whether to preserve code blocks in smart mode
+            preserve_json: Whether to preserve JSON structure in smart mode
+            ellipsis: Marker text inserted where content was removed
+
+        Raises:
+            ValueError: If ``max_length`` is negative.
+        """
+        super().__init__()
+
+        max_length = int(max_length)
+        if max_length < 0:
+            raise ValueError(f"max_length must be >= 0, got {max_length}")
+        if truncation_strategy not in self.VALID_STRATEGIES:
+            # Kept non-fatal: truncate_content() falls back to 'end' for unknown names.
+            logger.warning(
+                "Unknown truncation strategy '%s'; 'end' will be used",
+                truncation_strategy,
+            )
+
+        self.max_length = max_length
+        self.truncation_strategy = truncation_strategy
+        self.tool_filters = tool_filters
+        self.exclude_tools = exclude_tools or []
+        self.add_metadata = add_metadata
+        self.preserve_code_blocks = preserve_code_blocks
+        self.preserve_json = preserve_json
+        self.ellipsis = ellipsis
+
+        # Statistics tracking
+        self.stats: Dict[str, int] = {}
+        self.reset_stats()
+
+        # Compile regex patterns once; they are reused for every tool call
+        self.code_block_pattern = re.compile(r'```(\w+)?\n.*?```', re.DOTALL)
+        self.json_pattern = re.compile(r'\{.*?\}|\[.*?\]', re.DOTALL)
+
+        logger.info(
+            "ToolOutputLengthMiddleware initialized: max_length=%s, strategy=%s",
+            max_length, truncation_strategy,
+        )
+
+    def should_process_tool(self, tool_name: str) -> bool:
+        """Return True if output of ``tool_name`` is subject to truncation."""
+        if tool_name in self.exclude_tools:
+            return False
+        # No filter list means every non-excluded tool is processed
+        return self.tool_filters is None or tool_name in self.tool_filters
+
+    def truncate_content(self, content: str, tool_name: str) -> tuple[str, Dict[str, Any]]:
+        """
+        Truncate content based on the configured strategy.
+
+        Args:
+            content: Tool output text
+            tool_name: Name of the tool that produced the text
+
+        Returns:
+            Tuple of (truncated_content, metadata). ``metadata['truncated']``
+            is False when the content already fits; otherwise the metadata
+            also carries original/truncated lengths, the strategy actually
+            used, the tool name and the number of characters saved.
+        """
+        original_length = len(content)
+
+        if original_length <= self.max_length:
+            return content, {'truncated': False}
+
+        strategy = self.truncation_strategy
+        if strategy == "end":
+            truncated = self._truncate_end(content)
+        elif strategy == "start":
+            truncated = self._truncate_start(content)
+        elif strategy == "smart":
+            truncated = self._smart_truncate(content, tool_name)
+        elif strategy == "preserve_blocks":
+            truncated = self._preserve_blocks_truncate(content)
+        else:
+            # Default to end truncation
+            strategy = 'end'
+            truncated = self._truncate_end(content)
+
+        # Safety net: no strategy may hand back more than the configured limit
+        if len(truncated) > self.max_length:
+            truncated = self._truncate_end(content)
+
+        metadata = {
+            'truncated': True,
+            'original_length': original_length,
+            'truncated_length': len(truncated),
+            'strategy': strategy,
+            'tool_name': tool_name,
+            'chars_saved': original_length - len(truncated),
+        }
+        return truncated, metadata
+
+    def _truncate_end(self, content: str) -> str:
+        """Keep the beginning of ``content``; result is at most ``max_length`` chars."""
+        budget = self.max_length - len(self.ellipsis)
+        if budget <= 0:
+            # No room for the marker: hard cut
+            return content[:self.max_length]
+        return content[:budget] + self.ellipsis
+
+    def _truncate_start(self, content: str) -> str:
+        """Keep the end of ``content``; result is at most ``max_length`` chars."""
+        budget = self.max_length - len(self.ellipsis)
+        if budget <= 0:
+            # Explicit start index: content[-0:] would return the whole string
+            return content[len(content) - self.max_length:]
+        return self.ellipsis + content[-budget:]
+
+    def _smart_truncate(self, content: str, tool_name: str) -> str:
+        """
+        Truncate at a natural boundary where possible.
+
+        Order of attempts:
+        1. Mostly fenced code -> block-preserving truncation (if enabled)
+        2. JSON object/array -> structure-preserving truncation (if enabled)
+        3. Cut at a paragraph break, sentence end or line break that lies
+           close to the limit, otherwise cut at the limit
+
+        ``tool_name`` is accepted for interface compatibility and is unused.
+        """
+        if self.preserve_code_blocks and self._is_mainly_code(content):
+            return self._preserve_code_blocks(content)
+
+        if self.preserve_json and self._is_json(content):
+            return self._preserve_json(content)
+
+        target_length = self.max_length - len(self.ellipsis)
+        if target_length <= 0:
+            return self._truncate_end(content)
+        head = content[:target_length]
+
+        # Paragraph break in the last 20% of the budget
+        paragraph_break = head.rfind('\n\n')
+        if paragraph_break > target_length * 0.8:
+            return head[:paragraph_break] + self.ellipsis
+
+        # Latest sentence terminator in the last 30% of the budget
+        sentence_end = max(head.rfind(ch) for ch in self._SENTENCE_ENDINGS)
+        if sentence_end > target_length * 0.7:
+            return head[:sentence_end + 1] + self.ellipsis
+
+        # Line break in the last 10% of the budget
+        line_break = head.rfind('\n')
+        if line_break > target_length * 0.9:
+            return head[:line_break] + self.ellipsis
+
+        # Fallback to simple truncation
+        return head + self.ellipsis
+
+    def _find_important_blocks(self, content: str) -> List[Dict[str, Any]]:
+        """Return non-overlapping code/JSON spans of ``content`` sorted by position."""
+        blocks: List[Dict[str, Any]] = []
+
+        if self.preserve_code_blocks:
+            for match in self.code_block_pattern.finditer(content):
+                blocks.append({
+                    'start': match.start(),
+                    'end': match.end(),
+                    'content': match.group(),
+                    'type': 'code'
+                })
+
+        if self.preserve_json:
+            code_spans = [(b['start'], b['end']) for b in blocks]
+            for match in self.json_pattern.finditer(content):
+                # Skip JSON-looking spans that overlap a code block in any way
+                if any(match.start() < end and match.end() > start for start, end in code_spans):
+                    continue
+                blocks.append({
+                    'start': match.start(),
+                    'end': match.end(),
+                    'content': match.group(),
+                    'type': 'json'
+                })
+
+        blocks.sort(key=lambda b: b['start'])
+        return blocks
+
+    def _preserve_blocks_truncate(self, content: str) -> str:
+        """
+        Keep whole code/JSON blocks, in order, for as long as they fit.
+
+        Free text in front of a kept block is shortened (and marked with
+        ``...``) when it would crowd the block out. Collection stops at the
+        first block that does not fit; remaining budget is then spent on the
+        text that follows the last kept block. Falls back to end truncation
+        when not even the first block fits.
+        """
+        budget = self.max_length - len(self.ellipsis)
+        if budget <= 0:
+            return self._truncate_end(content)
+
+        pieces: List[str] = []
+        used = 0
+        cursor = 0
+
+        for block in self._find_important_blocks(content):
+            block_text = block['content']
+            if used + len(block_text) > budget:
+                break
+
+            gap = content[cursor:block['start']]
+            room = budget - used - len(block_text)
+            if len(gap) > room:
+                keep = room - len(self._GAP_MARKER)
+                gap = gap[:keep] + self._GAP_MARKER if keep > 0 else ""
+
+            pieces.append(gap)
+            pieces.append(block_text)
+            used += len(gap) + len(block_text)
+            cursor = block['end']
+
+        if not pieces:
+            return self._truncate_end(content)
+
+        # Use what is left of the budget for the text after the last kept block
+        pieces.append(content[cursor:cursor + (budget - used)])
+
+        preserved = "".join(pieces)
+        if preserved != content:
+            preserved += self.ellipsis
+        return preserved
+
+    def _is_mainly_code(self, content: str) -> bool:
+        """Return True if fenced code blocks make up more than half of ``content``."""
+        code_chars = sum(
+            match.end() - match.start()
+            for match in self.code_block_pattern.finditer(content)
+        )
+        return code_chars > len(content) * 0.5
+
+    def _is_json(self, content: str) -> bool:
+        """Return True if ``content`` parses as a JSON object or array."""
+        # Cheap pre-check; scalars are better served by plain text truncation
+        if not content.lstrip().startswith(('{', '[')):
+            return False
+        try:
+            json.loads(content)
+        except (ValueError, RecursionError):
+            return False
+        return True
+
+    def _preserve_code_blocks(self, content: str) -> str:
+        """Preserve code blocks while truncating other content."""
+        return self._preserve_blocks_truncate(content)
+
+    def _preserve_json(self, content: str) -> str:
+        """
+        Shorten long string values while keeping the JSON structure valid.
+
+        String values longer than ``max_length // 2`` are cut. If the
+        re-serialized document still exceeds ``max_length`` (pretty-printed
+        first, compact second), or the content is not valid JSON, plain end
+        truncation of the original text is used instead.
+        """
+        limit = self.max_length // 2
+
+        def shorten(obj):
+            if isinstance(obj, str):
+                return obj[:limit] + "...[truncated]..." if len(obj) > limit else obj
+            if isinstance(obj, dict):
+                return {key: shorten(value) for key, value in obj.items()}
+            if isinstance(obj, list):
+                return [shorten(item) for item in obj]
+            return obj
+
+        try:
+            shortened = shorten(json.loads(content))
+            for dump_kwargs in ({'indent': 2}, {'separators': (',', ':')}):
+                rendered = json.dumps(shortened, ensure_ascii=False, **dump_kwargs)
+                if len(rendered) <= self.max_length:
+                    return rendered
+        except (ValueError, RecursionError):
+            pass  # not JSON after all, or too deeply nested: fall through
+
+        return self._truncate_end(content)
+
+    @staticmethod
+    def _message_text(message: ToolMessage) -> Optional[str]:
+        """Return the textual content of ``message``, or None if it has none."""
+        content = message.content
+        if isinstance(content, str):
+            return content
+        # List-of-blocks content: rely on the message's own text accessor,
+        # which is a property in newer langchain-core and a method in older ones
+        text = getattr(message, 'text', None)
+        if callable(text) and not isinstance(text, str):
+            text = text()
+        return str(text) if isinstance(text, str) else None
+
+    def _truncate_message(
+        self,
+        message: ToolMessage,
+        tool_name: str,
+        log_context: str = "",
+    ) -> ToolMessage:
+        """
+        Return ``message`` unchanged if it fits, else a truncated copy.
+
+        The copy keeps every other field of the original (status, artifact,
+        id, ...). Non-text content blocks are replaced by the truncated text.
+        Statistics and the log are updated when truncation happens.
+        """
+        text = self._message_text(message)
+        if not text or len(text) <= self.max_length:
+            return message
+
+        truncated_text, metadata = self.truncate_content(text, tool_name)
+
+        # Always work on a fresh dict so the original message is never mutated
+        extra = dict(message.additional_kwargs or {})
+        if self.add_metadata:
+            extra.update({
+                'truncation_info': metadata,
+                'original_length': len(text)
+            })
+        truncated_message = message.model_copy(
+            update={'content': truncated_text, 'additional_kwargs': extra}
+        )
+
+        chars_saved = metadata.get('chars_saved', 0)
+        self.stats['truncated_calls'] += 1
+        self.stats['total_chars_saved'] += chars_saved
+
+        logger.info(
+            "Tool output truncated for '%s'%s: %d -> %d chars (saved %d chars)",
+            tool_name, log_context, len(text), len(truncated_text), chars_saved,
+        )
+        return truncated_message
+
+    def _process_result(
+        self,
+        result: ToolMessage | Command,
+        tool_name: str,
+    ) -> ToolMessage | Command:
+        """Apply truncation to a tool result (shared by sync and async paths)."""
+        if isinstance(result, ToolMessage):
+            return self._truncate_message(result, tool_name)
+
+        if isinstance(result, Command):
+            update = result.update
+            # Command.update may be None or a non-dict; only dict updates carry 'messages'
+            messages = update.get('messages') if isinstance(update, dict) else None
+            if isinstance(messages, (list, tuple)):
+                processed = []
+                for msg in messages:
+                    if isinstance(msg, ToolMessage):
+                        # ToolMessage.name may be None; fall back to the called tool
+                        msg_tool_name = msg.name or tool_name
+                        if self.should_process_tool(msg_tool_name):
+                            msg = self._truncate_message(msg, msg_tool_name, " in Command")
+                    processed.append(msg)
+                update['messages'] = processed
+
+        return result
+
+    def wrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], ToolMessage | Command],
+    ) -> ToolMessage | Command:
+        """
+        Run the tool via ``handler`` and truncate its output if needed.
+
+        Tools rejected by ``should_process_tool`` are passed straight
+        through. Exceptions raised by the tool are logged and re-raised.
+        """
+        tool_name = request.tool_call['name']
+
+        if not self.should_process_tool(tool_name):
+            return handler(request)
+
+        self.stats['total_tool_calls'] += 1
+
+        try:
+            result = handler(request)
+        except Exception as e:
+            logger.error("Tool execution failed for '%s': %s", tool_name, e)
+            raise
+
+        return self._process_result(result, tool_name)
+
+    async def awrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
+    ) -> ToolMessage | Command:
+        """Async version of wrap_tool_call; ``handler`` is awaited."""
+        tool_name = request.tool_call['name']
+
+        if not self.should_process_tool(tool_name):
+            return await handler(request)
+
+        self.stats['total_tool_calls'] += 1
+
+        try:
+            result = await handler(request)
+        except Exception as e:
+            logger.error("Tool execution failed for '%s': %s", tool_name, e)
+            raise
+
+        return self._process_result(result, tool_name)
+
+    def get_stats(self) -> Dict[str, Any]:
+        """Return a copy of the truncation statistics."""
+        return self.stats.copy()
+
+    def reset_stats(self):
+        """Reset truncation statistics."""
+        self.stats = {
+            'total_tool_calls': 0,
+            'truncated_calls': 0,
+            'total_chars_saved': 0
+        }
\ No newline at end of file
diff --git a/utils/settings.py b/utils/settings.py index 0714fe0..db08334 100644 --- a/utils/settings.py +++ b/utils/settings.py @@ -1,7 +1,7 @@ import os # LLM Token Settings -MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 65536)) +MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 32679)) MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000)) SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000 @@ -21,7 +21,7 @@ FILE_CACHE_SIZE = int(os.getenv("FILE_CACHE_SIZE", 1000)) FILE_CACHE_TTL = int(os.getenv("FILE_CACHE_TTL", 300)) # API Settings -BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai") +BACKEND_HOST = os.getenv("BACKEND_HOST", "http://backend:8000") MASTERKEY = os.getenv("MASTERKEY", "master") FASTAPI_URL = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001') @@ -32,4 +32,8 @@ PROJECT_DATA_DIR = os.getenv("PROJECT_DATA_DIR", "./projects/data") TOKENIZERS_PARALLELISM = os.getenv("TOKENIZERS_PARALLELISM", "true") # Embedding Model Settings -SENTENCE_TRANSFORMER_MODEL = os.getenv("SENTENCE_TRANSFORMER_MODEL", "TaylorAI/gte-tiny") \ No newline at end of file +SENTENCE_TRANSFORMER_MODEL = 
os.getenv("SENTENCE_TRANSFORMER_MODEL", "TaylorAI/gte-tiny") + +# Tool Output Length Control Settings +TOOL_OUTPUT_MAX_LENGTH = int(SUMMARIZATION_MAX_TOKENS/3) +TOOL_OUTPUT_TRUNCATION_STRATEGY = os.getenv("TOOL_OUTPUT_TRUNCATION_STRATEGY", "smart") \ No newline at end of file