"""Custom Summarization middleware with summary tag support.""" from typing import Any from collections.abc import Callable from langchain_core.messages import AIMessage, AnyMessage, HumanMessage from langgraph.runtime import Runtime from langchain.agents.middleware.summarization import SummarizationMiddleware as LangchainSummarizationMiddleware from langchain.agents.middleware.types import AgentState class SummarizationMiddleware(LangchainSummarizationMiddleware): """Summarization middleware that outputs summary in tags instead of direct output.""" def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str: """Generate summary for the given messages with message_tag in metadata.""" if not messages_to_summarize: return "No previous conversation history." trimmed_messages = self._trim_messages_for_summary(messages_to_summarize) if not trimmed_messages: return "Previous conversation was too long to summarize." try: response = self.model.invoke( self.summary_prompt.format(messages=trimmed_messages), config={"metadata": {"message_tag": "SUMMARY"}} ) return response.text.strip() except Exception as e: # noqa: BLE001 return f"Error generating summary: {e!s}" async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str: """Generate summary for the given messages with message_tag in metadata.""" if not messages_to_summarize: return "No previous conversation history." trimmed_messages = self._trim_messages_for_summary(messages_to_summarize) if not trimmed_messages: return "Previous conversation was too long to summarize." try: response = await self.model.ainvoke( self.summary_prompt.format(messages=trimmed_messages), config={"metadata": {"message_tag": "SUMMARY"}} ) return response.text.strip() except Exception as e: # noqa: BLE001 return f"Error generating summary: {e!s}" def _build_new_messages(self, summary: str) -> list[HumanMessage | AIMessage]: """Build messages with summary wrapped in tags. Similar to how GuidelineMiddleware wraps thinking content in tags, this wraps the summary in tags with message_tag set to "SUMMARY". """ # Create an AIMessage with the summary wrapped in tags content = f"\n{summary}\n" message = AIMessage(content=content) # Set message_tag so the frontend can identify and handle this message appropriately message.additional_kwargs["message_tag"] = "SUMMARY" return [message]