add agent/tool_use_cleanup_middleware.py
This commit is contained in:
parent
aaad9df20a
commit
61c6b69aa5
@ -11,6 +11,7 @@ from langchain_mcp_adapters.client import MultiServerMCPClient
|
|||||||
from utils.fastapi_utils import detect_provider
|
from utils.fastapi_utils import detect_provider
|
||||||
from .guideline_middleware import GuidelineMiddleware
|
from .guideline_middleware import GuidelineMiddleware
|
||||||
from .tool_output_length_middleware import ToolOutputLengthMiddleware
|
from .tool_output_length_middleware import ToolOutputLengthMiddleware
|
||||||
|
from .tool_use_cleanup_middleware import ToolUseCleanupMiddleware
|
||||||
from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
|
from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
|
||||||
from agent.agent_config import AgentConfig
|
from agent.agent_config import AgentConfig
|
||||||
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
||||||
@ -141,6 +142,8 @@ async def init_agent(config: AgentConfig):
|
|||||||
else:
|
else:
|
||||||
# 构建中间件列表
|
# 构建中间件列表
|
||||||
middleware = []
|
middleware = []
|
||||||
|
# 首先添加 ToolUseCleanupMiddleware 来清理孤立的 tool_use
|
||||||
|
middleware.append(ToolUseCleanupMiddleware())
|
||||||
# 只有在 enable_thinking 为 True 时才添加 GuidelineMiddleware
|
# 只有在 enable_thinking 为 True 时才添加 GuidelineMiddleware
|
||||||
if config.enable_thinking:
|
if config.enable_thinking:
|
||||||
middleware.append(GuidelineMiddleware(llm_instance, config, system_prompt))
|
middleware.append(GuidelineMiddleware(llm_instance, config, system_prompt))
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
"""日志回调处理器模块"""
|
"""日志回调处理器模块"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional, Dict
|
from typing import Any, Optional, Dict, List
|
||||||
from langchain_core.callbacks import BaseCallbackHandler
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
|
|
||||||
|
|
||||||
class LoggingCallbackHandler(BaseCallbackHandler):
|
class LoggingCallbackHandler(BaseCallbackHandler):
|
||||||
@ -11,6 +12,27 @@ class LoggingCallbackHandler(BaseCallbackHandler):
|
|||||||
def __init__(self, logger_name: str = 'app'):
|
def __init__(self, logger_name: str = 'app'):
|
||||||
self.logger = logging.getLogger(logger_name)
|
self.logger = logging.getLogger(logger_name)
|
||||||
|
|
||||||
|
# def on_chat_model_start(
|
||||||
|
# self,
|
||||||
|
# serialized: Dict[str, Any],
|
||||||
|
# messages: List[List[BaseMessage]],
|
||||||
|
# **kwargs: Any,
|
||||||
|
# ) -> None:
|
||||||
|
# """当 Chat 模型开始时调用"""
|
||||||
|
# self.logger.info("✅ Chat Model Start - Messages:")
|
||||||
|
# for msg_list in messages:
|
||||||
|
# for msg in msg_list:
|
||||||
|
# msg_type = msg.__class__.__name__
|
||||||
|
# content = msg.content if hasattr(msg, 'content') else str(msg)
|
||||||
|
# self.logger.info(f"[{msg_type}] {content}")
|
||||||
|
|
||||||
|
# def on_llm_start(self, serialized: Dict[str, Any], prompts: Any, **kwargs: Any) -> None:
|
||||||
|
# """当 LLM 开始时调用(用于普通 LLM,非 Chat 模型)"""
|
||||||
|
# self.logger.info("✅ LLM Start - Input:")
|
||||||
|
# for prompt in prompts:
|
||||||
|
# self.logger.info(str(prompt))
|
||||||
|
|
||||||
|
|
||||||
def on_llm_end(self, response, **kwargs: Any) -> None:
|
def on_llm_end(self, response, **kwargs: Any) -> None:
|
||||||
"""当 LLM 结束时调用"""
|
"""当 LLM 结束时调用"""
|
||||||
self.logger.info("✅ LLM End - Output:")
|
self.logger.info("✅ LLM End - Output:")
|
||||||
|
|||||||
211
agent/tool_use_cleanup_middleware.py
Normal file
211
agent/tool_use_cleanup_middleware.py
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
"""
|
||||||
|
Tool Use Cleanup Middleware for LangGraph Agents
|
||||||
|
|
||||||
|
This middleware removes tool_use blocks that don't have corresponding tool_result blocks,
|
||||||
|
preventing errors like: `tool_use` ids were found without `tool_result` blocks immediately after.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable
|
||||||
|
from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse
|
||||||
|
from langchain_core.messages import AIMessage, AnyMessage, ToolMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger('app')
|
||||||
|
|
||||||
|
|
||||||
|
class ToolUseCleanupMiddleware(AgentMiddleware):
|
||||||
|
"""
|
||||||
|
Middleware to clean up orphaned tool_use blocks in messages.
|
||||||
|
|
||||||
|
Ensures that every tool_use has a corresponding tool_result. If a tool_use
|
||||||
|
doesn't have a tool_result in the next message, it will be removed from the AIMessage.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- AIMessage(tool_calls=[tool_1]) + ToolMessage(tool_call_id=tool_1) -> keep
|
||||||
|
- AIMessage(tool_calls=[tool_1]) + AIMessage(content="你好") -> remove tool_1
|
||||||
|
- AIMessage(tool_calls=[tool_1], content="请稍候") + AIMessage(content="好了") -> keep content, remove tool_1
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.stats = {
|
||||||
|
'removed_ai_messages': 0,
|
||||||
|
'cleaned_ai_messages': 0,
|
||||||
|
'removed_tool_calls': 0
|
||||||
|
}
|
||||||
|
logger.info("ToolUseCleanupMiddleware initialized")
|
||||||
|
|
||||||
|
def _has_meaningful_content(self, message: AIMessage) -> bool:
|
||||||
|
"""
|
||||||
|
Check if an AIMessage has meaningful content (text) besides tool_calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: The AIMessage to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the message has non-empty content, False otherwise
|
||||||
|
"""
|
||||||
|
content = message.content
|
||||||
|
|
||||||
|
# Handle different content types
|
||||||
|
if content is None:
|
||||||
|
return False
|
||||||
|
elif isinstance(content, str):
|
||||||
|
return bool(content.strip())
|
||||||
|
elif isinstance(content, list):
|
||||||
|
# Content is a list of content blocks (e.g., text, image, etc.)
|
||||||
|
# Check if there are any meaningful text blocks
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict):
|
||||||
|
# Check for text type blocks with non-empty content
|
||||||
|
if block.get('type') == 'text' and block.get('text', '').strip():
|
||||||
|
return True
|
||||||
|
elif isinstance(block, str) and block.strip():
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# Other types, treat as having content if it's truthy
|
||||||
|
return bool(content)
|
||||||
|
|
||||||
|
def _cleanup_tool_use_messages(self, messages: list[AnyMessage]) -> list[AnyMessage]:
|
||||||
|
"""
|
||||||
|
Clean up messages by removing tool_use blocks that don't have corresponding tool_results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages to clean
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cleaned list of messages
|
||||||
|
"""
|
||||||
|
cleaned_messages = []
|
||||||
|
i = 0
|
||||||
|
|
||||||
|
while i < len(messages):
|
||||||
|
current_msg = messages[i]
|
||||||
|
|
||||||
|
# Check if current message is an AIMessage with tool_calls
|
||||||
|
if isinstance(current_msg, AIMessage) and current_msg.tool_calls:
|
||||||
|
# Check the next message(s) for corresponding ToolMessages
|
||||||
|
valid_tool_call_ids = set()
|
||||||
|
j = i + 1
|
||||||
|
while j < len(messages) and isinstance(messages[j], ToolMessage):
|
||||||
|
valid_tool_call_ids.add(messages[j].tool_call_id)
|
||||||
|
j += 1
|
||||||
|
|
||||||
|
# Filter tool_calls to only keep those with valid tool_call_ids
|
||||||
|
filtered_tool_calls = [
|
||||||
|
tc for tc in current_msg.tool_calls
|
||||||
|
if tc.get('id') in valid_tool_call_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
removed_count = len(current_msg.tool_calls) - len(filtered_tool_calls)
|
||||||
|
|
||||||
|
if removed_count > 0:
|
||||||
|
self.stats['removed_tool_calls'] += removed_count
|
||||||
|
logger.warning(
|
||||||
|
f"Removed {removed_count} orphaned tool_use(s) from AIMessage. "
|
||||||
|
f"tool_call_ids: {[tc.get('id') for tc in current_msg.tool_calls]}, "
|
||||||
|
f"valid_ids: {valid_tool_call_ids}"
|
||||||
|
)
|
||||||
|
|
||||||
|
has_content = self._has_meaningful_content(current_msg)
|
||||||
|
|
||||||
|
if filtered_tool_calls:
|
||||||
|
# Has valid tool_calls, keep the message
|
||||||
|
cleaned_msg = AIMessage(
|
||||||
|
content=current_msg.content,
|
||||||
|
tool_calls=filtered_tool_calls,
|
||||||
|
additional_kwargs=current_msg.additional_kwargs,
|
||||||
|
response_metadata=current_msg.response_metadata,
|
||||||
|
id=current_msg.id,
|
||||||
|
name=current_msg.name,
|
||||||
|
)
|
||||||
|
cleaned_messages.append(cleaned_msg)
|
||||||
|
elif has_content:
|
||||||
|
# No valid tool_calls but has meaningful content, keep without tool_calls
|
||||||
|
cleaned_msg = AIMessage(
|
||||||
|
content=current_msg.content,
|
||||||
|
tool_calls=[],
|
||||||
|
additional_kwargs=current_msg.additional_kwargs,
|
||||||
|
response_metadata=current_msg.response_metadata,
|
||||||
|
id=current_msg.id,
|
||||||
|
name=current_msg.name,
|
||||||
|
)
|
||||||
|
cleaned_messages.append(cleaned_msg)
|
||||||
|
self.stats['cleaned_ai_messages'] += 1
|
||||||
|
logger.info(
|
||||||
|
f"Removed all tool_calls from AIMessage but kept content. "
|
||||||
|
f"Content preview: {current_msg.content[:50]}..."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# No valid tool_calls and no meaningful content, completely remove this message
|
||||||
|
self.stats['removed_ai_messages'] += 1
|
||||||
|
logger.info(
|
||||||
|
f"Removed entire AIMessage with orphaned tool_calls (no meaningful content). "
|
||||||
|
f"Removed tool_call_ids: {[tc.get('id') for tc in current_msg.tool_calls]}"
|
||||||
|
)
|
||||||
|
# Don't add to cleaned_messages - skip this message entirely
|
||||||
|
else:
|
||||||
|
# Not an AIMessage with tool_calls, add as-is
|
||||||
|
cleaned_messages.append(current_msg)
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return cleaned_messages
|
||||||
|
|
||||||
|
def wrap_model_call(
|
||||||
|
self,
|
||||||
|
request: ModelRequest,
|
||||||
|
handler: Callable[[ModelRequest], ModelResponse],
|
||||||
|
) -> ModelResponse:
|
||||||
|
"""
|
||||||
|
Synchronous wrapper to clean up orphaned tool_use blocks before model call.
|
||||||
|
"""
|
||||||
|
cleaned_messages = self._cleanup_tool_use_messages(request.messages)
|
||||||
|
|
||||||
|
if (self.stats['removed_ai_messages'] > 0 or
|
||||||
|
self.stats['cleaned_ai_messages'] > 0 or
|
||||||
|
self.stats['removed_tool_calls'] > 0):
|
||||||
|
logger.info(
|
||||||
|
f"ToolUseCleanupMiddleware: Removed {self.stats['removed_ai_messages']} messages, "
|
||||||
|
f"cleaned {self.stats['cleaned_ai_messages']} messages, "
|
||||||
|
f"removed {self.stats['removed_tool_calls']} tool_calls."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Override with cleaned messages
|
||||||
|
cleaned_request = request.override(messages=cleaned_messages)
|
||||||
|
return handler(cleaned_request)
|
||||||
|
|
||||||
|
async def awrap_model_call(
|
||||||
|
self,
|
||||||
|
request: ModelRequest,
|
||||||
|
handler: Callable[[ModelRequest], Any],
|
||||||
|
) -> ModelResponse:
|
||||||
|
"""
|
||||||
|
Async wrapper to clean up orphaned tool_use blocks before model call.
|
||||||
|
"""
|
||||||
|
cleaned_messages = self._cleanup_tool_use_messages(request.messages)
|
||||||
|
|
||||||
|
if (self.stats['removed_ai_messages'] > 0 or
|
||||||
|
self.stats['cleaned_ai_messages'] > 0 or
|
||||||
|
self.stats['removed_tool_calls'] > 0):
|
||||||
|
logger.info(
|
||||||
|
f"ToolUseCleanupMiddleware: Removed {self.stats['removed_ai_messages']} messages, "
|
||||||
|
f"cleaned {self.stats['cleaned_ai_messages']} messages, "
|
||||||
|
f"removed {self.stats['removed_tool_calls']} tool_calls."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Override with cleaned messages
|
||||||
|
cleaned_request = request.override(messages=cleaned_messages)
|
||||||
|
return await handler(cleaned_request)
|
||||||
|
|
||||||
|
def get_stats(self) -> dict[str, int]:
|
||||||
|
"""Get statistics about cleanup activity."""
|
||||||
|
return self.stats.copy()
|
||||||
|
|
||||||
|
def reset_stats(self):
|
||||||
|
"""Reset cleanup statistics."""
|
||||||
|
self.stats = {
|
||||||
|
'removed_ai_messages': 0,
|
||||||
|
'cleaned_ai_messages': 0,
|
||||||
|
'removed_tool_calls': 0
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user