From 61c6b69aa50b82aa0930a668b78fc875a3a8aa6c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?=
Date: Tue, 23 Dec 2025 12:04:26 +0800
Subject: [PATCH] add agent/tool_use_cleanup_middleware.py

---
 agent/deep_assistant.py              |   3 +
 agent/logging_handler.py             |  24 ++-
 agent/tool_use_cleanup_middleware.py | 211 +++++++++++++++++++++++++++
 3 files changed, 237 insertions(+), 1 deletion(-)
 create mode 100644 agent/tool_use_cleanup_middleware.py

diff --git a/agent/deep_assistant.py b/agent/deep_assistant.py
index 54c6d27..dd04083 100644
--- a/agent/deep_assistant.py
+++ b/agent/deep_assistant.py
@@ -11,6 +11,7 @@ from langchain_mcp_adapters.client import MultiServerMCPClient
 from utils.fastapi_utils import detect_provider
 from .guideline_middleware import GuidelineMiddleware
 from .tool_output_length_middleware import ToolOutputLengthMiddleware
+from .tool_use_cleanup_middleware import ToolUseCleanupMiddleware
 from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
 from agent.agent_config import AgentConfig
 from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
@@ -141,6 +142,8 @@ async def init_agent(config: AgentConfig):
     else:
         # Build the middleware list
         middleware = []
+        # Add ToolUseCleanupMiddleware first so orphaned tool_use blocks are cleaned up
+        middleware.append(ToolUseCleanupMiddleware())
         # Only add GuidelineMiddleware when enable_thinking is True
         if config.enable_thinking:
             middleware.append(GuidelineMiddleware(llm_instance, config, system_prompt))
diff --git a/agent/logging_handler.py b/agent/logging_handler.py
index 96eea99..e54e003 100644
--- a/agent/logging_handler.py
+++ b/agent/logging_handler.py
@@ -1,8 +1,9 @@
 """Logging callback handler module."""
 
 import logging
-from typing import Any, Optional, Dict
+from typing import Any, Optional, Dict, List
 from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.messages import BaseMessage
 
 
 class LoggingCallbackHandler(BaseCallbackHandler):
@@ -11,6 +12,27 @@ class LoggingCallbackHandler(BaseCallbackHandler):
     def __init__(self, logger_name: str = 'app'):
         self.logger = logging.getLogger(logger_name)
 
+    # def on_chat_model_start(
+    #     self,
+    #     serialized: Dict[str, Any],
+    #     messages: List[List[BaseMessage]],
+    #     **kwargs: Any,
+    # ) -> None:
+    #     """Called when the chat model starts."""
+    #     self.logger.info("✅ Chat Model Start - Messages:")
+    #     for msg_list in messages:
+    #         for msg in msg_list:
+    #             msg_type = msg.__class__.__name__
+    #             content = msg.content if hasattr(msg, 'content') else str(msg)
+    #             self.logger.info(f"[{msg_type}] {content}")
+
+    # def on_llm_start(self, serialized: Dict[str, Any], prompts: Any, **kwargs: Any) -> None:
+    #     """Called when a plain (non-chat) LLM starts."""
+    #     self.logger.info("✅ LLM Start - Input:")
+    #     for prompt in prompts:
+    #         self.logger.info(str(prompt))
+
+
     def on_llm_end(self, response, **kwargs: Any) -> None:
         """Called when the LLM finishes."""
         self.logger.info("✅ LLM End - Output:")
diff --git a/agent/tool_use_cleanup_middleware.py b/agent/tool_use_cleanup_middleware.py
new file mode 100644
index 0000000..83fb733
--- /dev/null
+++ b/agent/tool_use_cleanup_middleware.py
@@ -0,0 +1,211 @@
+"""
+Tool Use Cleanup Middleware for LangGraph Agents
+
+This middleware removes tool_use blocks that don't have corresponding tool_result blocks,
+preventing errors like: `tool_use` ids were found without `tool_result` blocks immediately after.
+""" + +import logging +from typing import Any, Callable +from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse +from langchain_core.messages import AIMessage, AnyMessage, ToolMessage + +logger = logging.getLogger('app') + + +class ToolUseCleanupMiddleware(AgentMiddleware): + """ + Middleware to clean up orphaned tool_use blocks in messages. + + Ensures that every tool_use has a corresponding tool_result. If a tool_use + doesn't have a tool_result in the next message, it will be removed from the AIMessage. + + Examples: + - AIMessage(tool_calls=[tool_1]) + ToolMessage(tool_call_id=tool_1) -> keep + - AIMessage(tool_calls=[tool_1]) + AIMessage(content="你好") -> remove tool_1 + - AIMessage(tool_calls=[tool_1], content="请稍候") + AIMessage(content="好了") -> keep content, remove tool_1 + """ + + def __init__(self): + self.stats = { + 'removed_ai_messages': 0, + 'cleaned_ai_messages': 0, + 'removed_tool_calls': 0 + } + logger.info("ToolUseCleanupMiddleware initialized") + + def _has_meaningful_content(self, message: AIMessage) -> bool: + """ + Check if an AIMessage has meaningful content (text) besides tool_calls. + + Args: + message: The AIMessage to check + + Returns: + True if the message has non-empty content, False otherwise + """ + content = message.content + + # Handle different content types + if content is None: + return False + elif isinstance(content, str): + return bool(content.strip()) + elif isinstance(content, list): + # Content is a list of content blocks (e.g., text, image, etc.) + # Check if there are any meaningful text blocks + for block in content: + if isinstance(block, dict): + # Check for text type blocks with non-empty content + if block.get('type') == 'text' and block.get('text', '').strip(): + return True + elif isinstance(block, str) and block.strip(): + return True + return False + else: + # Other types, treat as having content if it's truthy + return bool(content) + + def _cleanup_tool_use_messages(self, messages: list[AnyMessage]) -> list[AnyMessage]: + """ + Clean up messages by removing tool_use blocks that don't have corresponding tool_results. + + Args: + messages: List of messages to clean + + Returns: + Cleaned list of messages + """ + cleaned_messages = [] + i = 0 + + while i < len(messages): + current_msg = messages[i] + + # Check if current message is an AIMessage with tool_calls + if isinstance(current_msg, AIMessage) and current_msg.tool_calls: + # Check the next message(s) for corresponding ToolMessages + valid_tool_call_ids = set() + j = i + 1 + while j < len(messages) and isinstance(messages[j], ToolMessage): + valid_tool_call_ids.add(messages[j].tool_call_id) + j += 1 + + # Filter tool_calls to only keep those with valid tool_call_ids + filtered_tool_calls = [ + tc for tc in current_msg.tool_calls + if tc.get('id') in valid_tool_call_ids + ] + + removed_count = len(current_msg.tool_calls) - len(filtered_tool_calls) + + if removed_count > 0: + self.stats['removed_tool_calls'] += removed_count + logger.warning( + f"Removed {removed_count} orphaned tool_use(s) from AIMessage. 
" + f"tool_call_ids: {[tc.get('id') for tc in current_msg.tool_calls]}, " + f"valid_ids: {valid_tool_call_ids}" + ) + + has_content = self._has_meaningful_content(current_msg) + + if filtered_tool_calls: + # Has valid tool_calls, keep the message + cleaned_msg = AIMessage( + content=current_msg.content, + tool_calls=filtered_tool_calls, + additional_kwargs=current_msg.additional_kwargs, + response_metadata=current_msg.response_metadata, + id=current_msg.id, + name=current_msg.name, + ) + cleaned_messages.append(cleaned_msg) + elif has_content: + # No valid tool_calls but has meaningful content, keep without tool_calls + cleaned_msg = AIMessage( + content=current_msg.content, + tool_calls=[], + additional_kwargs=current_msg.additional_kwargs, + response_metadata=current_msg.response_metadata, + id=current_msg.id, + name=current_msg.name, + ) + cleaned_messages.append(cleaned_msg) + self.stats['cleaned_ai_messages'] += 1 + logger.info( + f"Removed all tool_calls from AIMessage but kept content. " + f"Content preview: {current_msg.content[:50]}..." + ) + else: + # No valid tool_calls and no meaningful content, completely remove this message + self.stats['removed_ai_messages'] += 1 + logger.info( + f"Removed entire AIMessage with orphaned tool_calls (no meaningful content). " + f"Removed tool_call_ids: {[tc.get('id') for tc in current_msg.tool_calls]}" + ) + # Don't add to cleaned_messages - skip this message entirely + else: + # Not an AIMessage with tool_calls, add as-is + cleaned_messages.append(current_msg) + + i += 1 + + return cleaned_messages + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + """ + Synchronous wrapper to clean up orphaned tool_use blocks before model call. + """ + cleaned_messages = self._cleanup_tool_use_messages(request.messages) + + if (self.stats['removed_ai_messages'] > 0 or + self.stats['cleaned_ai_messages'] > 0 or + self.stats['removed_tool_calls'] > 0): + logger.info( + f"ToolUseCleanupMiddleware: Removed {self.stats['removed_ai_messages']} messages, " + f"cleaned {self.stats['cleaned_ai_messages']} messages, " + f"removed {self.stats['removed_tool_calls']} tool_calls." + ) + + # Override with cleaned messages + cleaned_request = request.override(messages=cleaned_messages) + return handler(cleaned_request) + + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Any], + ) -> ModelResponse: + """ + Async wrapper to clean up orphaned tool_use blocks before model call. + """ + cleaned_messages = self._cleanup_tool_use_messages(request.messages) + + if (self.stats['removed_ai_messages'] > 0 or + self.stats['cleaned_ai_messages'] > 0 or + self.stats['removed_tool_calls'] > 0): + logger.info( + f"ToolUseCleanupMiddleware: Removed {self.stats['removed_ai_messages']} messages, " + f"cleaned {self.stats['cleaned_ai_messages']} messages, " + f"removed {self.stats['removed_tool_calls']} tool_calls." + ) + + # Override with cleaned messages + cleaned_request = request.override(messages=cleaned_messages) + return await handler(cleaned_request) + + def get_stats(self) -> dict[str, int]: + """Get statistics about cleanup activity.""" + return self.stats.copy() + + def reset_stats(self): + """Reset cleanup statistics.""" + self.stats = { + 'removed_ai_messages': 0, + 'cleaned_ai_messages': 0, + 'removed_tool_calls': 0 + }