62 lines
2.8 KiB
Python
62 lines
2.8 KiB
Python
"""Custom Summarization middleware with summary tag support."""
|
|
|
|
from typing import Any
|
|
from collections.abc import Callable
|
|
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
|
|
from langgraph.runtime import Runtime
|
|
from langchain.agents.middleware.summarization import SummarizationMiddleware as LangchainSummarizationMiddleware
|
|
from langchain.agents.middleware.types import AgentState
|
|
|
|
|
|
class SummarizationMiddleware(LangchainSummarizationMiddleware):
    """Summarization middleware that emits the summary inside ``<summary>`` tags.

    Compared with the LangChain base class, this subclass:

    * makes the summarization model call with ``{"message_tag": "SUMMARY"}`` in the
      run-config metadata, and
    * replaces the summarized history with a single ``AIMessage`` whose content is the
      summary wrapped in ``<summary>...</summary>`` and whose ``additional_kwargs``
      carry ``message_tag="SUMMARY"``.
    """

    # Single source of truth for the tag used both in the model-call metadata and on
    # the emitted summary message, so the two can never drift apart.
    _SUMMARY_TAG = "SUMMARY"

    def _trim_for_summary(
        self, messages_to_summarize: list[AnyMessage]
    ) -> tuple[list[AnyMessage], str | None]:
        """Trim messages for summarization, or produce a fallback summary instead.

        Shared by the sync and async paths so their guard logic stays identical.

        Args:
            messages_to_summarize: Messages that should be condensed.

        Returns:
            ``(trimmed_messages, None)`` when the model should be called, or
            ``([], fallback_text)`` when ``fallback_text`` should be returned as the
            summary without calling the model.
        """
        if not messages_to_summarize:
            return [], "No previous conversation history."

        trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
        if not trimmed_messages:
            return [], "Previous conversation was too long to summarize."

        return trimmed_messages, None

    def _summary_run_config(self) -> dict[str, Any]:
        """Return a fresh run config whose metadata tags the summarization call."""
        return {"metadata": {"message_tag": self._SUMMARY_TAG}}

    def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
        """Generate a summary of the given messages with ``message_tag`` in metadata.

        Args:
            messages_to_summarize: Messages that should be condensed.

        Returns:
            The stripped model output. Otherwise one of:

            * a fixed fallback string when there is nothing to summarize, or when
              trimming leaves nothing;
            * ``"Error generating summary: <error>"`` if formatting the prompt or
              calling the model raises.
        """
        trimmed_messages, fallback = self._trim_for_summary(messages_to_summarize)
        if fallback is not None:
            return fallback

        try:
            response = self.model.invoke(
                self.summary_prompt.format(messages=trimmed_messages),
                config=self._summary_run_config(),
            )
            return response.text.strip()
        except Exception as e:  # noqa: BLE001
            # Deliberate best-effort: the error text becomes the summary instead of
            # propagating to the caller.
            return f"Error generating summary: {e!s}"

    async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
        """Async counterpart of ``_create_summary``; same fallbacks and error handling.

        Args:
            messages_to_summarize: Messages that should be condensed.

        Returns:
            The stripped model output. Otherwise one of:

            * a fixed fallback string when there is nothing to summarize, or when
              trimming leaves nothing;
            * ``"Error generating summary: <error>"`` if formatting the prompt or
              calling the model raises.
        """
        trimmed_messages, fallback = self._trim_for_summary(messages_to_summarize)
        if fallback is not None:
            return fallback

        try:
            response = await self.model.ainvoke(
                self.summary_prompt.format(messages=trimmed_messages),
                config=self._summary_run_config(),
            )
            return response.text.strip()
        except Exception as e:  # noqa: BLE001
            # Deliberate best-effort: the error text becomes the summary instead of
            # propagating to the caller.
            return f"Error generating summary: {e!s}"

    def _build_new_messages(self, summary: str) -> list[HumanMessage | AIMessage]:
        """Wrap ``summary`` in ``<summary>`` tags inside a single tagged ``AIMessage``.

        ``additional_kwargs["message_tag"]`` is set to ``"SUMMARY"`` so the frontend
        can identify and handle this message appropriately.

        NOTE(review): the base class presumably uses this list as the replacement for
        the summarized history. Some chat providers reject a history that begins with
        an assistant message, so confirm an ``AIMessage`` is acceptable in that
        position for the models in use.

        Args:
            summary: The summary text produced by ``_create_summary`` /
                ``_acreate_summary``.

        Returns:
            A one-element list containing the tagged summary message.
        """
        return [
            AIMessage(
                content=f"<summary>\n{summary}\n</summary>",
                additional_kwargs={"message_tag": self._SUMMARY_TAG},
            )
        ]
|