From 525801d7f5acac51bcaca0f0941f33f1537fad69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?= Date: Wed, 4 Feb 2026 15:31:41 +0800 Subject: [PATCH] update summary --- agent/deep_assistant.py | 12 +- agent/summarization_middleware.py | 61 ++++++ pyproject.toml | 1 + utils/settings.py | 7 +- utils/token_counter.py | 353 ++++++++++++++++++++++++++++++ 5 files changed, 427 insertions(+), 7 deletions(-) create mode 100644 agent/summarization_middleware.py create mode 100644 utils/token_counter.py diff --git a/agent/deep_assistant.py b/agent/deep_assistant.py index 3631ecd..647e085 100644 --- a/agent/deep_assistant.py +++ b/agent/deep_assistant.py @@ -12,7 +12,8 @@ from deepagents.backends.sandbox import SandboxBackendProtocol from deepagents_cli.agent import create_cli_agent from langchain.agents import create_agent from langgraph.store.base import BaseStore -from langchain.agents.middleware import SummarizationMiddleware +from langchain.agents.middleware import SummarizationMiddleware as LangchainSummarizationMiddleware +from .summarization_middleware import SummarizationMiddleware from langchain_mcp_adapters.client import MultiServerMCPClient from sympy.printing.cxx import none from utils.fastapi_utils import detect_provider @@ -21,11 +22,13 @@ from .tool_output_length_middleware import ToolOutputLengthMiddleware from .tool_use_cleanup_middleware import ToolUseCleanupMiddleware from utils.settings import ( SUMMARIZATION_MAX_TOKENS, - SUMMARIZATION_MESSAGES_TO_KEEP, + SUMMARIZATION_TOKENS_TO_KEEP, TOOL_OUTPUT_MAX_LENGTH, MCP_HTTP_TIMEOUT, MCP_SSE_READ_TIMEOUT, + DEFAULT_TRIM_TOKEN_LIMIT ) +from utils.token_counter import create_token_counter from agent.agent_config import AgentConfig from .mem0_manager import get_mem0_manager from .mem0_middleware import create_mem0_middleware @@ -252,8 +255,9 @@ async def init_agent(config: AgentConfig): summarization_middleware = SummarizationMiddleware( model=llm_instance, trigger=('tokens', SUMMARIZATION_MAX_TOKENS), 
- keep=('messages', SUMMARIZATION_MESSAGES_TO_KEEP), - summary_prompt="请简洁地总结以上对话的要点,包括重要的用户信息、讨论过的话题和关键结论。" + trim_tokens_to_summarize=DEFAULT_TRIM_TOKEN_LIMIT, + keep=('tokens', SUMMARIZATION_TOKENS_TO_KEEP), + token_counter=create_token_counter(config.model_name) ) middleware.append(summarization_middleware) diff --git a/agent/summarization_middleware.py b/agent/summarization_middleware.py new file mode 100644 index 0000000..1d3ba47 --- /dev/null +++ b/agent/summarization_middleware.py @@ -0,0 +1,61 @@ +"""Custom Summarization middleware with summary tag support.""" + +from typing import Any +from collections.abc import Callable +from langchain_core.messages import AIMessage, AnyMessage, HumanMessage +from langgraph.runtime import Runtime +from langchain.agents.middleware.summarization import SummarizationMiddleware as LangchainSummarizationMiddleware +from langchain.agents.middleware.types import AgentState + + +class SummarizationMiddleware(LangchainSummarizationMiddleware): + """Summarization middleware that outputs summary in tags instead of direct output.""" + + def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str: + """Generate summary for the given messages with message_tag in metadata.""" + if not messages_to_summarize: + return "No previous conversation history." + + trimmed_messages = self._trim_messages_for_summary(messages_to_summarize) + if not trimmed_messages: + return "Previous conversation was too long to summarize." + + try: + response = self.model.invoke( + self.summary_prompt.format(messages=trimmed_messages), + config={"metadata": {"message_tag": "SUMMARY"}} + ) + return response.text.strip() + except Exception as e: # noqa: BLE001 + return f"Error generating summary: {e!s}" + + async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str: + """Generate summary for the given messages with message_tag in metadata.""" + if not messages_to_summarize: + return "No previous conversation history." 
+ + trimmed_messages = self._trim_messages_for_summary(messages_to_summarize) + if not trimmed_messages: + return "Previous conversation was too long to summarize." + + try: + response = await self.model.ainvoke( + self.summary_prompt.format(messages=trimmed_messages), + config={"metadata": {"message_tag": "SUMMARY"}} + ) + return response.text.strip() + except Exception as e: # noqa: BLE001 + return f"Error generating summary: {e!s}" + + def _build_new_messages(self, summary: str) -> list[HumanMessage | AIMessage]: + """Build messages with summary wrapped in <summary> tags. + + Similar to how GuidelineMiddleware wraps thinking content in tags, + this wraps the summary in <summary> tags with message_tag set to "SUMMARY". + """ + # Create an AIMessage with the summary wrapped in <summary> tags + content = f"<summary>\n{summary}\n</summary>" + message = AIMessage(content=content) + # Set message_tag so the frontend can identify and handle this message appropriately + message.additional_kwargs["message_tag"] = "SUMMARY" + return [message] diff --git a/pyproject.toml b/pyproject.toml index 3403c38..52ee14a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "mem0ai (>=0.1.50,<0.3.0)", "psycopg2-binary (>=2.9.11,<3.0.0)", "json-repair (>=0.29.0,<0.30.0)", + "tiktoken (>=0.5.0,<1.0.0)", ] [tool.poetry.requires-plugins] diff --git a/utils/settings.py b/utils/settings.py index 7b2005f..22c67d3 100644 --- a/utils/settings.py +++ b/utils/settings.py @@ -2,7 +2,7 @@ import os # 必填参数 # API Settings -BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai") +BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api.gbase.ai") MASTERKEY = os.getenv("MASTERKEY", "master") FASTAPI_URL = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001') @@ -12,8 +12,9 @@ MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000)) # 可选参数 # Summarization Settings -SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000 -SUMMARIZATION_MESSAGES_TO_KEEP =
int(os.getenv("SUMMARIZATION_MESSAGES_TO_KEEP", 20)) +SUMMARIZATION_MAX_TOKENS = int(MAX_CONTEXT_TOKENS/3) +SUMMARIZATION_TOKENS_TO_KEEP = int(SUMMARIZATION_MAX_TOKENS/3) +DEFAULT_TRIM_TOKEN_LIMIT = SUMMARIZATION_MAX_TOKENS - SUMMARIZATION_TOKENS_TO_KEEP + 5000 # Agent Cache Settings TOOL_CACHE_MAX_SIZE = int(os.getenv("TOOL_CACHE_MAX_SIZE", 20)) diff --git a/utils/token_counter.py b/utils/token_counter.py new file mode 100644 index 0000000..5e0876b --- /dev/null +++ b/utils/token_counter.py @@ -0,0 +1,353 @@ +""" +Token 计数工具模块 + +使用 tiktoken 替代 LangChain 默认的 chars_per_token=3.3, +支持中/日/英多语言的精确 token 计算。 + +参考 langchain_core.messages.utils.count_tokens_approximately 的消息读取方式。 +""" +import json +import logging +import math +from typing import Any, Dict, Sequence +from functools import lru_cache + +try: + import tiktoken + TIKTOKEN_AVAILABLE = True +except ImportError: + TIKTOKEN_AVAILABLE = False + +# 尝试导入 LangChain 的消息类型 +try: + from langchain_core.messages import ( + BaseMessage, + AIMessage, + HumanMessage, + SystemMessage, + ToolMessage, + FunctionMessage, + ChatMessage, + convert_to_messages, + ) + from langchain_core.messages.utils import _get_message_openai_role + LANGCHAIN_AVAILABLE = True +except ImportError: + LANGCHAIN_AVAILABLE = False + BaseMessage = None + AIMessage = None + HumanMessage = None + SystemMessage = None + ToolMessage = None + FunctionMessage = None + ChatMessage = None + +logger = logging.getLogger('app') + + +# 支持的模型编码映射 +MODEL_TO_ENCODING: Dict[str, str] = { + # OpenAI 模型 + "gpt-4o": "o200k_base", + "gpt-4o-mini": "o200k_base", + "gpt-4-turbo": "cl100k_base", + "gpt-4": "cl100k_base", + "gpt-3.5-turbo": "cl100k_base", + "gpt-3.5": "cl100k_base", + # Claude 使用 cl100k_base 作为近似 + "claude": "cl100k_base", + # 其他模型默认使用 cl100k_base +} + + +@lru_cache(maxsize=128) +def _get_encoding(model_name: str) -> Any: + """ + 获取模型的 tiktoken 编码器(带缓存) + + Args: + model_name: 模型名称 + + Returns: + tiktoken.Encoding 实例 + """ + if not TIKTOKEN_AVAILABLE: + 
return None + + # 标准化模型名称 + model_lower = model_name.lower() + + # 查找匹配的编码 + encoding_name = None + for key, encoding in MODEL_TO_ENCODING.items(): + if key in model_lower: + encoding_name = encoding + break + + # 默认使用 cl100k_base(适用于大多数现代模型) + if encoding_name is None: + encoding_name = "cl100k_base" + + try: + return tiktoken.get_encoding(encoding_name) + except Exception as e: + logger.warning(f"Failed to get tiktoken encoding {encoding_name}: {e}") + return None + + +def count_tokens(text: str, model_name: str = "gpt-4o") -> int: + """ + 计算文本的 token 数量 + + Args: + text: 要计算的文本 + model_name: 模型名称,用于选择合适的编码器 + + Returns: + token 数量 + """ + if not text: + return 0 + + encoding = _get_encoding(model_name) + + if encoding is None: + # tiktoken 不可用时,回退到字符估算(保守估计) + # 中文/日文约 1.5 字符/token,英文约 4 字符/token + # 混合文本使用 2.5 作为中间值 + return max(1, len(text) // 2) + + try: + tokens = encoding.encode(text) + return len(tokens) + except Exception as e: + logger.warning(f"Failed to encode text: {e}") + return max(1, len(text) // 2) + + +def _get_role(message: Dict[str, Any]) -> str: + """ + 获取消息的 role(参考 _get_message_openai_role) + + Args: + message: 消息字典 + + Returns: + role 字符串 + """ + # 优先使用 type 字段 + msg_type = message.get("type", "") + + if msg_type == "ai" or msg_type == "AIMessage": + return "assistant" + elif msg_type == "human" or msg_type == "HumanMessage": + return "user" + elif msg_type == "tool" or msg_type == "ToolMessage": + return "tool" + elif msg_type == "system" or msg_type == "SystemMessage": + # 检查是否有 __openai_role__ + additional_kwargs = message.get("additional_kwargs", {}) + if isinstance(additional_kwargs, dict): + return additional_kwargs.get("__openai_role__", "system") + return "system" + elif msg_type == "function" or msg_type == "FunctionMessage": + return "function" + elif msg_type == "chat" or msg_type == "ChatMessage": + return message.get("role", "user") + else: + # 如果有 role 字段,直接使用 + if "role" in message: + return message["role"] + return "user" + 
+ +def count_message_tokens(message: Dict[str, Any] | BaseMessage, model_name: str = "gpt-4o") -> int: + """ + 计算消息的 token 数量(参考 count_tokens_approximately 的消息读取方式) + + 包括: + - 消息内容 (content) + - 消息角色 (role) + - 消息名称 (name) + - AIMessage 的 tool_calls + - ToolMessage 的 tool_call_id + + Args: + message: 消息对象(字典或 BaseMessage) + model_name: 模型名称 + + Returns: + token 数量 + """ + # 转换为字典格式处理 + if LANGCHAIN_AVAILABLE and isinstance(message, BaseMessage): + # 将 BaseMessage 转换为字典 + msg_dict = message.model_dump(exclude={"type"}) + else: + msg_dict = message if isinstance(message, dict) else {} + + token_count = 0 + encoding = _get_encoding(model_name) + + # 1. 处理 content + content = msg_dict.get("content", "") + + if isinstance(content, str): + token_count += count_tokens(content, model_name) + elif isinstance(content, list): + # 处理多模态内容块 + for block in content: + if isinstance(block, str): + token_count += count_tokens(block, model_name) + elif isinstance(block, dict): + block_type = block.get("type", "") + + if block_type == "text": + token_count += count_tokens(block.get("text", ""), model_name) + elif block_type == "image_url": + # 图片 token 计算(OpenAI 标准:85 tokens/base + 每个 tile 170 tokens) + token_count += 85 + elif block_type == "tool_use": + # tool_use 块 + token_count += count_tokens(block.get("name", ""), model_name) + input_data = block.get("input", {}) + if isinstance(input_data, dict): + token_count += count_tokens(json.dumps(input_data, ensure_ascii=False), model_name) + elif isinstance(input_data, str): + token_count += count_tokens(input_data, model_name) + elif block_type == "tool_result": + # tool_result 块 + result_content = block.get("content", "") + if isinstance(result_content, str): + token_count += count_tokens(result_content, model_name) + elif isinstance(result_content, list): + for sub_block in result_content: + if isinstance(sub_block, dict): + if sub_block.get("type") == "text": + token_count += count_tokens(sub_block.get("text", ""), model_name) + 
token_count += count_tokens(block.get("tool_use_id", ""), model_name) + elif block_type == "json": + json_data = block.get("json", {}) + token_count += count_tokens(json.dumps(json_data, ensure_ascii=False), model_name) + else: + # 其他类型,将整个 block 序列化 + token_count += count_tokens(repr(block), model_name) + else: + # 其他类型的 content,序列化后计算 + token_count += count_tokens(repr(content), model_name) + + # 2. 处理 tool_calls(仅当 content 不是 list 时) + if msg_dict.get("type") in ["ai", "AIMessage"] or isinstance(msg_dict.get("tool_calls"), list): + tool_calls = msg_dict.get("tool_calls", []) + # 只有在 content 不是 list 时才单独计算 tool_calls + # (因为 Anthropic 格式中 tool_calls 已包含在 content 的 tool_use 块中) + if not isinstance(content, list) and tool_calls: + tool_calls_str = repr(tool_calls) + token_count += count_tokens(tool_calls_str, model_name) + + # 3. 处理 tool_call_id(ToolMessage) + tool_call_id = msg_dict.get("tool_call_id", "") + if tool_call_id: + token_count += count_tokens(tool_call_id, model_name) + + # 4. 处理 role + role = _get_role(msg_dict) + token_count += count_tokens(role, model_name) + + # 5. 处理 name + name = msg_dict.get("name", "") + if name: + token_count += count_tokens(name, model_name) + + # 6. 
添加每条消息的格式开销(参考 OpenAI 的计算方式) + # 每条消息约有 4 个 token 的格式开销 + token_count += 4 + + return token_count + + +def count_messages_tokens(messages: Sequence[Dict[str, Any]] | Sequence[BaseMessage], model_name: str = "gpt-4o") -> int: + """ + 计算消息列表的总 token 数量 + + Args: + messages: 消息列表(字典列表或 BaseMessage 列表) + model_name: 模型名称 + + Returns: + 总 token 数量 + """ + if not messages: + return 0 + + total = 0 + for message in messages: + total += count_message_tokens(message, model_name) + + # 添加回复的估算(3 tokens) + total += 3 + + return int(math.ceil(total)) + + +def create_token_counter(model_name: str = "gpt-4o"): + """ + 创建 token 计数函数,用于传入 SummarizationMiddleware + + Args: + model_name: 模型名称 + + Returns: + token 计数函数 + """ + if not TIKTOKEN_AVAILABLE: + logger.warning("tiktoken not available, falling back to character-based estimation") + # 回退到字符估算(参考 count_tokens_approximately 的方式) + def fallback_counter(messages) -> int: + token_count = 0.0 + for message in messages: + # 转换为字典格式处理 + if LANGCHAIN_AVAILABLE and isinstance(message, BaseMessage): + msg_dict = message.model_dump(exclude={"type"}) + else: + msg_dict = message if isinstance(message, dict) else {} + + message_chars = 0 + content = msg_dict.get("content", "") + + if isinstance(content, str): + message_chars += len(content) + elif isinstance(content, list): + message_chars += len(repr(content)) + + # 处理 tool_calls + if (msg_dict.get("type") in ["ai", "AIMessage"] and + not isinstance(content, list) and + msg_dict.get("tool_calls")): + message_chars += len(repr(msg_dict.get("tool_calls"))) + + # 处理 tool_call_id + if msg_dict.get("tool_call_id"): + message_chars += len(msg_dict.get("tool_call_id", "")) + + # 处理 role + role = _get_role(msg_dict) + message_chars += len(role) + + # 处理 name + if msg_dict.get("name"): + message_chars += len(msg_dict.get("name", "")) + + # 使用 2.5 作为 chars_per_token(适合中/日/英混合文本) + token_count += math.ceil(message_chars / 2.5) + token_count += 3.0 # extra_tokens_per_message + + return 
int(math.ceil(token_count)) + + return fallback_counter + + def token_counter(messages: Sequence[Dict[str, Any]] | Sequence[BaseMessage]) -> int: + """Token 计数函数""" + return count_messages_tokens(messages, model_name) + + return token_counter