update summary

朱潮 2026-02-04 15:31:41 +08:00
parent 352a2f2f44
commit 525801d7f5
5 changed files with 427 additions and 7 deletions

View File

@@ -12,7 +12,8 @@ from deepagents.backends.sandbox import SandboxBackendProtocol
 from deepagents_cli.agent import create_cli_agent
 from langchain.agents import create_agent
 from langgraph.store.base import BaseStore
-from langchain.agents.middleware import SummarizationMiddleware
+from langchain.agents.middleware import SummarizationMiddleware as LangchainSummarizationMiddleware
+from .summarization_middleware import SummarizationMiddleware
 from langchain_mcp_adapters.client import MultiServerMCPClient
 from sympy.printing.cxx import none
 from utils.fastapi_utils import detect_provider
@@ -21,11 +22,13 @@ from .tool_output_length_middleware import ToolOutputLengthMiddleware
 from .tool_use_cleanup_middleware import ToolUseCleanupMiddleware
 from utils.settings import (
     SUMMARIZATION_MAX_TOKENS,
-    SUMMARIZATION_MESSAGES_TO_KEEP,
+    SUMMARIZATION_TOKENS_TO_KEEP,
     TOOL_OUTPUT_MAX_LENGTH,
     MCP_HTTP_TIMEOUT,
     MCP_SSE_READ_TIMEOUT,
+    DEFAULT_TRIM_TOKEN_LIMIT
 )
+from utils.token_counter import create_token_counter
 from agent.agent_config import AgentConfig
 from .mem0_manager import get_mem0_manager
 from .mem0_middleware import create_mem0_middleware
@@ -252,8 +255,9 @@ async def init_agent(config: AgentConfig):
     summarization_middleware = SummarizationMiddleware(
         model=llm_instance,
         trigger=('tokens', SUMMARIZATION_MAX_TOKENS),
-        keep=('messages', SUMMARIZATION_MESSAGES_TO_KEEP),
-        summary_prompt="请简洁地总结以上对话的要点,包括重要的用户信息、讨论过的话题和关键结论。"
+        trim_tokens_to_summarize=DEFAULT_TRIM_TOKEN_LIMIT,
+        keep=('tokens', SUMMARIZATION_TOKENS_TO_KEEP),
+        token_counter=create_token_counter(config.model_name)
     )
     middleware.append(summarization_middleware)
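
For context: the switch from keep=('messages', …) to keep=('tokens', …) budgets retention in tokens rather than a fixed message count, and trim_tokens_to_summarize caps how much history is fed to the summarizer model. A minimal standalone sketch of the same wiring (llm_instance and the settings names are assumed from the hunk above; parameter semantics follow the SummarizationMiddleware constructor as used here):

from utils.settings import (
    SUMMARIZATION_MAX_TOKENS,
    SUMMARIZATION_TOKENS_TO_KEEP,
    DEFAULT_TRIM_TOKEN_LIMIT,
)
from utils.token_counter import create_token_counter
from .summarization_middleware import SummarizationMiddleware

summarizer = SummarizationMiddleware(
    model=llm_instance,                                  # assumed: the chat model instance
    trigger=('tokens', SUMMARIZATION_MAX_TOKENS),        # summarize once history exceeds this budget
    keep=('tokens', SUMMARIZATION_TOKENS_TO_KEEP),       # most recent tokens preserved verbatim
    trim_tokens_to_summarize=DEFAULT_TRIM_TOKEN_LIMIT,   # cap on tokens sent to the summarizer
    token_counter=create_token_counter("gpt-4o"),        # tiktoken-based counter (new file below)
)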

View File

@@ -0,0 +1,61 @@
"""Custom Summarization middleware with summary tag support."""
from typing import Any
from collections.abc import Callable

from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
from langgraph.runtime import Runtime
from langchain.agents.middleware.summarization import SummarizationMiddleware as LangchainSummarizationMiddleware
from langchain.agents.middleware.types import AgentState


class SummarizationMiddleware(LangchainSummarizationMiddleware):
    """Summarization middleware that outputs summary in <summary> tags instead of direct output."""

    def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
        """Generate summary for the given messages with message_tag in metadata."""
        if not messages_to_summarize:
            return "No previous conversation history."

        trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
        if not trimmed_messages:
            return "Previous conversation was too long to summarize."

        try:
            response = self.model.invoke(
                self.summary_prompt.format(messages=trimmed_messages),
                config={"metadata": {"message_tag": "SUMMARY"}},
            )
            return response.text.strip()
        except Exception as e:  # noqa: BLE001
            return f"Error generating summary: {e!s}"

    async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
        """Async variant: generate summary for the given messages with message_tag in metadata."""
        if not messages_to_summarize:
            return "No previous conversation history."

        trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
        if not trimmed_messages:
            return "Previous conversation was too long to summarize."

        try:
            response = await self.model.ainvoke(
                self.summary_prompt.format(messages=trimmed_messages),
                config={"metadata": {"message_tag": "SUMMARY"}},
            )
            return response.text.strip()
        except Exception as e:  # noqa: BLE001
            return f"Error generating summary: {e!s}"

    def _build_new_messages(self, summary: str) -> list[HumanMessage | AIMessage]:
        """Build messages with the summary wrapped in <summary> tags.

        Similar to how GuidelineMiddleware wraps thinking content in <thinking> tags,
        this wraps the summary in <summary> tags with message_tag set to "SUMMARY".
        """
        # Create an AIMessage with the summary wrapped in <summary> tags
        content = f"<summary>\n{summary}\n</summary>"
        message = AIMessage(content=content)
        # Set message_tag so the frontend can identify and handle this message appropriately
        message.additional_kwargs["message_tag"] = "SUMMARY"
        return [message]
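
Since message_tag lives in additional_kwargs (a project convention, not a LangChain field), a consumer such as the frontend stream handler could detect the tagged summary along these lines (a sketch; is_summary_message is a hypothetical helper):

from langchain_core.messages import AIMessage

def is_summary_message(message) -> bool:
    # True for messages produced by _build_new_messages above
    return (
        isinstance(message, AIMessage)
        and message.additional_kwargs.get("message_tag") == "SUMMARY"
    )

For a summary string s, the resulting message content is "<summary>\n" + s + "\n</summary>", so the tag can also be recovered from the content itself.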

View File

@@ -35,6 +35,7 @@ dependencies = [
     "mem0ai (>=0.1.50,<0.3.0)",
     "psycopg2-binary (>=2.9.11,<3.0.0)",
     "json-repair (>=0.29.0,<0.30.0)",
+    "tiktoken (>=0.5.0,<1.0.0)",
 ]

 [tool.poetry.requires-plugins]

View File

@@ -2,7 +2,7 @@ import os

 # Required parameters
 # API Settings
-BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
+BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api.gbase.ai")
 MASTERKEY = os.getenv("MASTERKEY", "master")
 FASTAPI_URL = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
@@ -12,8 +12,9 @@ MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000))

 # Optional parameters
 # Summarization Settings
-SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000
-SUMMARIZATION_MESSAGES_TO_KEEP = int(os.getenv("SUMMARIZATION_MESSAGES_TO_KEEP", 20))
+SUMMARIZATION_MAX_TOKENS = int(MAX_CONTEXT_TOKENS / 3)
+SUMMARIZATION_TOKENS_TO_KEEP = int(SUMMARIZATION_MAX_TOKENS / 3)
+DEFAULT_TRIM_TOKEN_LIMIT = SUMMARIZATION_MAX_TOKENS - SUMMARIZATION_TOKENS_TO_KEEP + 5000

 # Agent Cache Settings
 TOOL_CACHE_MAX_SIZE = int(os.getenv("TOOL_CACHE_MAX_SIZE", 20))
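
To make the new token budgets concrete, assume MAX_CONTEXT_TOKENS = 128000 (an illustrative value, not taken from the diff):

SUMMARIZATION_MAX_TOKENS = int(128000 / 3)        # 42666: summarization trigger threshold
SUMMARIZATION_TOKENS_TO_KEEP = int(42666 / 3)     # 14222: recent tokens kept verbatim
DEFAULT_TRIM_TOKEN_LIMIT = 42666 - 14222 + 5000   # 33444: max tokens fed to the summarizer

Unlike the old MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000 formula, summarization now triggers at a third of the context window, leaving headroom for model output and tool results.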

utils/token_counter.py Normal file
View File

@@ -0,0 +1,353 @@
"""
Token counting utilities.

Uses tiktoken in place of LangChain's default chars_per_token=3.3 heuristic,
for accurate token counts on mixed Chinese/Japanese/English text.
Message fields are read the same way as
langchain_core.messages.utils.count_tokens_approximately.
"""
import json
import logging
import math
from typing import Any, Dict, Sequence
from functools import lru_cache

try:
    import tiktoken
    TIKTOKEN_AVAILABLE = True
except ImportError:
    TIKTOKEN_AVAILABLE = False

# Try to import LangChain's message types
try:
    from langchain_core.messages import (
        BaseMessage,
        AIMessage,
        HumanMessage,
        SystemMessage,
        ToolMessage,
        FunctionMessage,
        ChatMessage,
        convert_to_messages,
    )
    from langchain_core.messages.utils import _get_message_openai_role
    LANGCHAIN_AVAILABLE = True
except ImportError:
    LANGCHAIN_AVAILABLE = False
    BaseMessage = None
    AIMessage = None
    HumanMessage = None
    SystemMessage = None
    ToolMessage = None
    FunctionMessage = None
    ChatMessage = None

logger = logging.getLogger('app')
# Supported model-to-encoding mapping
MODEL_TO_ENCODING: Dict[str, str] = {
    # OpenAI models
    "gpt-4o": "o200k_base",
    "gpt-4o-mini": "o200k_base",
    "gpt-4-turbo": "cl100k_base",
    "gpt-4": "cl100k_base",
    "gpt-3.5-turbo": "cl100k_base",
    "gpt-3.5": "cl100k_base",
    # Claude models use cl100k_base as an approximation
    "claude": "cl100k_base",
    # Other models default to cl100k_base
}


@lru_cache(maxsize=128)
def _get_encoding(model_name: str) -> Any:
    """
    Get the tiktoken encoder for a model (cached).

    Args:
        model_name: Model name.

    Returns:
        A tiktoken.Encoding instance, or None if unavailable.
    """
    if not TIKTOKEN_AVAILABLE:
        return None

    # Normalize the model name
    model_lower = model_name.lower()

    # Look up a matching encoding by substring
    encoding_name = None
    for key, encoding in MODEL_TO_ENCODING.items():
        if key in model_lower:
            encoding_name = encoding
            break

    # Default to cl100k_base (suitable for most modern models)
    if encoding_name is None:
        encoding_name = "cl100k_base"

    try:
        return tiktoken.get_encoding(encoding_name)
    except Exception as e:
        logger.warning(f"Failed to get tiktoken encoding {encoding_name}: {e}")
        return None
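
# Illustrative lookups for the substring match above (insertion order matters:
# "gpt-4o" is checked before "gpt-4", so date-suffixed names resolve correctly):
#   _get_encoding("gpt-4o-2024-08-06")  -> o200k_base
#   _get_encoding("claude-3-5-sonnet")  -> cl100k_base (approximation)
#   _get_encoding("qwen-72b-instruct")  -> cl100k_base (default fallback)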
def count_tokens(text: str, model_name: str = "gpt-4o") -> int:
    """
    Count the tokens in a piece of text.

    Args:
        text: Text to count.
        model_name: Model name, used to pick the appropriate encoder.

    Returns:
        Token count.
    """
    if not text:
        return 0

    encoding = _get_encoding(model_name)
    if encoding is None:
        # When tiktoken is unavailable, fall back to a conservative character estimate:
        # Chinese/Japanese run about 1.5 chars/token, English about 4 chars/token,
        # so len(text) // 2 splits the difference for mixed text.
        return max(1, len(text) // 2)

    try:
        tokens = encoding.encode(text)
        return len(tokens)
    except Exception as e:
        logger.warning(f"Failed to encode text: {e}")
        return max(1, len(text) // 2)
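
# Rough calibration of the fallback above (illustrative, not exact):
#   a 1000-character English paragraph is ~250 tokens under tiktoken, while the
#   len(text) // 2 fallback reports 500 -- an over-count for English and a mild
#   under-count for pure CJK (~667 tokens), a far safer middle ground than the
#   old 3.3 chars/token assumption.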
def _get_role(message: Dict[str, Any]) -> str:
    """
    Get a message's role (mirrors _get_message_openai_role).

    Args:
        message: Message dict.

    Returns:
        Role string.
    """
    # Prefer the type field
    msg_type = message.get("type", "")
    if msg_type == "ai" or msg_type == "AIMessage":
        return "assistant"
    elif msg_type == "human" or msg_type == "HumanMessage":
        return "user"
    elif msg_type == "tool" or msg_type == "ToolMessage":
        return "tool"
    elif msg_type == "system" or msg_type == "SystemMessage":
        # Check for an overriding __openai_role__
        additional_kwargs = message.get("additional_kwargs", {})
        if isinstance(additional_kwargs, dict):
            return additional_kwargs.get("__openai_role__", "system")
        return "system"
    elif msg_type == "function" or msg_type == "FunctionMessage":
        return "function"
    elif msg_type == "chat" or msg_type == "ChatMessage":
        return message.get("role", "user")
    else:
        # If a role field is present, use it directly
        if "role" in message:
            return message["role"]
        return "user"
def count_message_tokens(message: Dict[str, Any] | BaseMessage, model_name: str = "gpt-4o") -> int:
    """
    Count the tokens in a single message (fields are read the way
    count_tokens_approximately reads them), covering:
    - message content (content)
    - message role (role)
    - message name (name)
    - tool_calls on AIMessage
    - tool_call_id on ToolMessage

    Args:
        message: Message object (dict or BaseMessage).
        model_name: Model name.

    Returns:
        Token count.
    """
    # Normalize to a dict for processing
    if LANGCHAIN_AVAILABLE and isinstance(message, BaseMessage):
        # Convert the BaseMessage to a dict
        msg_dict = message.model_dump(exclude={"type"})
    else:
        msg_dict = message if isinstance(message, dict) else {}

    token_count = 0
    encoding = _get_encoding(model_name)

    # 1. Handle content
    content = msg_dict.get("content", "")
    if isinstance(content, str):
        token_count += count_tokens(content, model_name)
    elif isinstance(content, list):
        # Handle multimodal content blocks
        for block in content:
            if isinstance(block, str):
                token_count += count_tokens(block, model_name)
            elif isinstance(block, dict):
                block_type = block.get("type", "")
                if block_type == "text":
                    token_count += count_tokens(block.get("text", ""), model_name)
                elif block_type == "image_url":
                    # Image tokens (OpenAI convention: 85 tokens base plus 170 per tile)
                    token_count += 85
                elif block_type == "tool_use":
                    # tool_use block
                    token_count += count_tokens(block.get("name", ""), model_name)
                    input_data = block.get("input", {})
                    if isinstance(input_data, dict):
                        token_count += count_tokens(json.dumps(input_data, ensure_ascii=False), model_name)
                    elif isinstance(input_data, str):
                        token_count += count_tokens(input_data, model_name)
                elif block_type == "tool_result":
                    # tool_result block
                    result_content = block.get("content", "")
                    if isinstance(result_content, str):
                        token_count += count_tokens(result_content, model_name)
                    elif isinstance(result_content, list):
                        for sub_block in result_content:
                            if isinstance(sub_block, dict):
                                if sub_block.get("type") == "text":
                                    token_count += count_tokens(sub_block.get("text", ""), model_name)
                    token_count += count_tokens(block.get("tool_use_id", ""), model_name)
                elif block_type == "json":
                    json_data = block.get("json", {})
                    token_count += count_tokens(json.dumps(json_data, ensure_ascii=False), model_name)
                else:
                    # Other block types: serialize the whole block
                    token_count += count_tokens(repr(block), model_name)
    else:
        # Other content types: serialize, then count
        token_count += count_tokens(repr(content), model_name)

    # 2. Handle tool_calls (only when content is not a list)
    if msg_dict.get("type") in ["ai", "AIMessage"] or isinstance(msg_dict.get("tool_calls"), list):
        tool_calls = msg_dict.get("tool_calls", [])
        # Only count tool_calls separately when content is not a list
        # (in the Anthropic format they are already present as tool_use blocks in content)
        if not isinstance(content, list) and tool_calls:
            tool_calls_str = repr(tool_calls)
            token_count += count_tokens(tool_calls_str, model_name)

    # 3. Handle tool_call_id (ToolMessage)
    tool_call_id = msg_dict.get("tool_call_id", "")
    if tool_call_id:
        token_count += count_tokens(tool_call_id, model_name)

    # 4. Handle role
    role = _get_role(msg_dict)
    token_count += count_tokens(role, model_name)

    # 5. Handle name
    name = msg_dict.get("name", "")
    if name:
        token_count += count_tokens(name, model_name)

    # 6. Add per-message formatting overhead (following OpenAI's accounting):
    # each message carries roughly 4 tokens of chat-format framing
    token_count += 4

    return token_count
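
# Worked example (hypothetical message; exact counts depend on the encoder):
#   msg = {"type": "human", "content": "What is 2+2?"}
#   count_message_tokens(msg) = tokens("What is 2+2?") + tokens("user") + 4
# i.e. content + role + the fixed 4-token chat-format framing overhead,
# mirroring OpenAI's published per-message accounting.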
def count_messages_tokens(messages: Sequence[Dict[str, Any]] | Sequence[BaseMessage], model_name: str = "gpt-4o") -> int:
    """
    Count the total tokens in a list of messages.

    Args:
        messages: Message list (dicts or BaseMessage objects).
        model_name: Model name.

    Returns:
        Token count.
    """
    if not messages:
        return 0

    total = 0
    for message in messages:
        total += count_message_tokens(message, model_name)

    # Add an estimate for the assistant reply priming (3 tokens)
    total += 3

    return int(math.ceil(total))
def create_token_counter(model_name: str = "gpt-4o"):
    """
    Create a token-counting function to pass into SummarizationMiddleware.

    Args:
        model_name: Model name.

    Returns:
        A token-counting function.
    """
    if not TIKTOKEN_AVAILABLE:
        logger.warning("tiktoken not available, falling back to character-based estimation")

        # Fall back to character-based estimation (mirroring count_tokens_approximately)
        def fallback_counter(messages) -> int:
            token_count = 0.0
            for message in messages:
                # Normalize to a dict for processing
                if LANGCHAIN_AVAILABLE and isinstance(message, BaseMessage):
                    msg_dict = message.model_dump(exclude={"type"})
                else:
                    msg_dict = message if isinstance(message, dict) else {}

                message_chars = 0
                content = msg_dict.get("content", "")
                if isinstance(content, str):
                    message_chars += len(content)
                elif isinstance(content, list):
                    message_chars += len(repr(content))

                # Handle tool_calls
                if (msg_dict.get("type") in ["ai", "AIMessage"] and
                        not isinstance(content, list) and
                        msg_dict.get("tool_calls")):
                    message_chars += len(repr(msg_dict.get("tool_calls")))

                # Handle tool_call_id
                if msg_dict.get("tool_call_id"):
                    message_chars += len(msg_dict.get("tool_call_id", ""))

                # Handle role
                role = _get_role(msg_dict)
                message_chars += len(role)

                # Handle name
                if msg_dict.get("name"):
                    message_chars += len(msg_dict.get("name", ""))

                # Use 2.5 chars_per_token (suits mixed Chinese/Japanese/English text)
                token_count += math.ceil(message_chars / 2.5)
                token_count += 3.0  # extra_tokens_per_message

            return int(math.ceil(token_count))

        return fallback_counter

    def token_counter(messages: Sequence[Dict[str, Any]] | Sequence[BaseMessage]) -> int:
        """Token-counting closure bound to model_name."""
        return count_messages_tokens(messages, model_name)

    return token_counter
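
A quick usage sketch for the new counter (message contents are illustrative; the returned callable maps a message sequence to an int, which is the shape SummarizationMiddleware's token_counter parameter expects here):

from langchain_core.messages import AIMessage, HumanMessage
from utils.token_counter import create_token_counter

counter = create_token_counter("gpt-4o")
history = [
    HumanMessage(content="帮我总结一下这份文档"),
    AIMessage(content="好的,这份文档主要讨论了三个问题……"),
]
print(counter(history))  # total tokens, including roles and per-message overhead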