Merge branch 'prod' into bot_manager
commit 7e058e1505
@@ -135,6 +135,18 @@ workflows:
          branches:
            only:
              - prod
+      - build-and-push:
+          name: build-for-staging
+          context:
+            - ecr-new
+          path: .
+          dockerfile: Dockerfile
+          repo: catalog-agent
+          docker-tag: ''
+          filters:
+            branches:
+              only:
+                - staging
      - deploy:
          name: deploy-for-prod
          docker-tag: ''
@@ -149,6 +161,20 @@ workflows:
              - prod
          requires:
            - build-for-prod
+      - deploy:
+          name: deploy-for-staging
+          docker-tag: ''
+          path: '/home/ubuntu/cluster-for-B/gbase-staging/catalog-agent/deploy.yaml'
+          deploy-name: catalog-agent
+          deploy-namespace: gbase-staging
+          context:
+            - ecr-new
+          filters:
+            branches:
+              only:
+                - staging
+          requires:
+            - build-for-staging
      - docker-hub-build-push:
          name: docker-hub-build-push
          repo: gptbasesparticle/catalog-agent
.gitignore (vendored)
@@ -6,3 +6,6 @@ __pycache__
 models
 projects/queue_data
 worktree
+.idea/*
+
+.idea/
@@ -12,7 +12,8 @@ from deepagents.backends.sandbox import SandboxBackendProtocol
 from deepagents_cli.agent import create_cli_agent
 from langchain.agents import create_agent
 from langgraph.store.base import BaseStore
-from langchain.agents.middleware import SummarizationMiddleware
+from langchain.agents.middleware import SummarizationMiddleware as LangchainSummarizationMiddleware
+from .summarization_middleware import SummarizationMiddleware
 from langchain_mcp_adapters.client import MultiServerMCPClient
 from sympy.printing.cxx import none
 from utils.fastapi_utils import detect_provider
@@ -21,11 +22,13 @@ from .tool_output_length_middleware import ToolOutputLengthMiddleware
 from .tool_use_cleanup_middleware import ToolUseCleanupMiddleware
 from utils.settings import (
     SUMMARIZATION_MAX_TOKENS,
-    SUMMARIZATION_MESSAGES_TO_KEEP,
+    SUMMARIZATION_TOKENS_TO_KEEP,
     TOOL_OUTPUT_MAX_LENGTH,
     MCP_HTTP_TIMEOUT,
     MCP_SSE_READ_TIMEOUT,
+    DEFAULT_TRIM_TOKEN_LIMIT
 )
+from utils.token_counter import create_token_counter
 from agent.agent_config import AgentConfig
 from .mem0_manager import get_mem0_manager
 from .mem0_middleware import create_mem0_middleware
@@ -252,8 +255,9 @@ async def init_agent(config: AgentConfig):
     summarization_middleware = SummarizationMiddleware(
         model=llm_instance,
         trigger=('tokens', SUMMARIZATION_MAX_TOKENS),
-        keep=('messages', SUMMARIZATION_MESSAGES_TO_KEEP),
-        summary_prompt="Please concisely summarize the key points of the conversation above, including important user information, topics discussed, and key conclusions."
+        trim_tokens_to_summarize=DEFAULT_TRIM_TOKEN_LIMIT,
+        keep=('tokens', SUMMARIZATION_TOKENS_TO_KEEP),
+        token_counter=create_token_counter(config.model_name)
     )
     middleware.append(summarization_middleware)
 
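Taken together, summarization now triggers and retains history by token budget rather than message count, using a model-aware counter. A minimal sketch of the resulting configuration, assuming the keyword arguments behave as used in the hunk above; the numeric values are the defaults derived in utils/settings.py further down, and llm is a stand-in for the configured model instance:

# Hedged sketch, not the repo's literal code: numbers come from the
# utils/settings.py defaults below; llm stands in for the real model.
from agent.summarization_middleware import SummarizationMiddleware
from utils.token_counter import create_token_counter

llm = ...  # configured chat model instance

summarization_middleware = SummarizationMiddleware(
    model=llm,
    trigger=('tokens', 66666),            # summarize once history exceeds this budget
    trim_tokens_to_summarize=49444,       # cap on the history fed to the summarizer
    keep=('tokens', 22222),               # trailing tokens kept verbatim
    token_counter=create_token_counter("gpt-4o"),
)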
agent/summarization_middleware.py (new file)
@@ -0,0 +1,61 @@
+"""Custom Summarization middleware with summary tag support."""
+
+from typing import Any
+from collections.abc import Callable
+from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
+from langgraph.runtime import Runtime
+from langchain.agents.middleware.summarization import SummarizationMiddleware as LangchainSummarizationMiddleware
+from langchain.agents.middleware.types import AgentState
+
+
+class SummarizationMiddleware(LangchainSummarizationMiddleware):
+    """Summarization middleware that outputs summary in <summary> tags instead of direct output."""
+
+    def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
+        """Generate summary for the given messages with message_tag in metadata."""
+        if not messages_to_summarize:
+            return "No previous conversation history."
+
+        trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
+        if not trimmed_messages:
+            return "Previous conversation was too long to summarize."
+
+        try:
+            response = self.model.invoke(
+                self.summary_prompt.format(messages=trimmed_messages),
+                config={"metadata": {"message_tag": "SUMMARY"}}
+            )
+            return response.text.strip()
+        except Exception as e:  # noqa: BLE001
+            return f"Error generating summary: {e!s}"
+
+    async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
+        """Generate summary for the given messages with message_tag in metadata."""
+        if not messages_to_summarize:
+            return "No previous conversation history."
+
+        trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
+        if not trimmed_messages:
+            return "Previous conversation was too long to summarize."
+
+        try:
+            response = await self.model.ainvoke(
+                self.summary_prompt.format(messages=trimmed_messages),
+                config={"metadata": {"message_tag": "SUMMARY"}}
+            )
+            return response.text.strip()
+        except Exception as e:  # noqa: BLE001
+            return f"Error generating summary: {e!s}"
+
+    def _build_new_messages(self, summary: str) -> list[HumanMessage | AIMessage]:
+        """Build messages with summary wrapped in <summary> tags.
+
+        Similar to how GuidelineMiddleware wraps thinking content in <thinking> tags,
+        this wraps the summary in <summary> tags with message_tag set to "SUMMARY".
+        """
+        # Create an AIMessage with the summary wrapped in <summary> tags
+        content = f"<summary>\n{summary}\n</summary>"
+        message = AIMessage(content=content)
+        # Set message_tag so the frontend can identify and handle this message appropriately
+        message.additional_kwargs["message_tag"] = "SUMMARY"
+        return [message]
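For orientation, a small hedged sketch of what the override produces and how downstream consumers are meant to treat it; the summary text here is invented:

# Hedged sketch: the middleware above emits an AIMessage like this one and
# tags it so user-facing layers can suppress it. The summary text is invented.
from langchain_core.messages import AIMessage

summary = "User configured staging builds; prod deploy unchanged."
message = AIMessage(content=f"<summary>\n{summary}\n</summary>")
message.additional_kwargs["message_tag"] = "SUMMARY"

tag = message.additional_kwargs.get("message_tag", "ANSWER")
assert tag == "SUMMARY"  # consumers skip SUMMARY messages (see the response hunks below)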
@@ -35,6 +35,7 @@ dependencies = [
     "mem0ai (>=0.1.50,<0.3.0)",
     "psycopg2-binary (>=2.9.11,<3.0.0)",
     "json-repair (>=0.29.0,<0.30.0)",
+    "tiktoken (>=0.5.0,<1.0.0)",
 ]
 
 [tool.poetry.requires-plugins]
@@ -96,6 +96,9 @@ async def enhanced_generate_stream_response(
                 preamble_completed.set()
                 await output_queue.put(("preamble_done", None))
             meta_message_tag = metadata.get("message_tag", "ANSWER")
+            # SUMMARY content is not emitted
+            if meta_message_tag == "SUMMARY":
+                continue
             if meta_message_tag != message_tag:
                 message_tag = meta_message_tag
                 new_content = f"[{meta_message_tag}]\n"
@@ -234,6 +237,8 @@ async def create_agent_and_generate_response(
         if isinstance(msg,AIMessage):
             if len(msg.text)>0:
                 meta_message_tag = msg.additional_kwargs.get("message_tag", "ANSWER")
+                if meta_message_tag == "SUMMARY":
+                    continue
                 output_text = msg.text.replace("````","").replace("````","") if meta_message_tag == "THINK" else msg.text
                 response_text += f"[{meta_message_tag}]\n"+output_text+ "\n"
             if len(msg.tool_calls)>0 and config.tool_response:
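A minimal sketch of the tag-filtering rule these two hunks implement, with an invented message list:

# Hedged sketch of the [TAG]-prefix protocol used above; the messages are invented.
messages = [
    {"text": "reasoning...", "message_tag": "THINK"},
    {"text": "internal summary", "message_tag": "SUMMARY"},
    {"text": "final reply", "message_tag": "ANSWER"},
]

response_text = ""
for msg in messages:
    tag = msg.get("message_tag", "ANSWER")
    if tag == "SUMMARY":  # summaries never reach the user
        continue
    response_text += f"[{tag}]\n" + msg["text"] + "\n"

print(response_text)  # [THINK] and [ANSWER] sections only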
@@ -2,18 +2,19 @@ import os
 
 # Required parameters
 # API Settings
-BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
+BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api.gbase.ai")
 MASTERKEY = os.getenv("MASTERKEY", "master")
 FASTAPI_URL = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
 
 # LLM Token Settings
-MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 262144))
+MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 200000))
 MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000))
 
 # Optional parameters
 # Summarization Settings
-SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000
-SUMMARIZATION_MESSAGES_TO_KEEP = int(os.getenv("SUMMARIZATION_MESSAGES_TO_KEEP", 20))
+SUMMARIZATION_MAX_TOKENS = int(MAX_CONTEXT_TOKENS/3)
+SUMMARIZATION_TOKENS_TO_KEEP = int(SUMMARIZATION_MAX_TOKENS/3)
+DEFAULT_TRIM_TOKEN_LIMIT = SUMMARIZATION_MAX_TOKENS - SUMMARIZATION_TOKENS_TO_KEEP + 5000
 
 # Agent Cache Settings
 TOOL_CACHE_MAX_SIZE = int(os.getenv("TOOL_CACHE_MAX_SIZE", 20))
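A quick check of the budgets these defaults produce:

# Worked check of the derived defaults above (MAX_CONTEXT_TOKENS = 200000).
MAX_CONTEXT_TOKENS = 200000
SUMMARIZATION_MAX_TOKENS = int(MAX_CONTEXT_TOKENS / 3)            # 66666
SUMMARIZATION_TOKENS_TO_KEEP = int(SUMMARIZATION_MAX_TOKENS / 3)  # 22222
DEFAULT_TRIM_TOKEN_LIMIT = SUMMARIZATION_MAX_TOKENS - SUMMARIZATION_TOKENS_TO_KEEP + 5000  # 49444
assert DEFAULT_TRIM_TOKEN_LIMIT == 49444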
utils/token_counter.py (new file)
@@ -0,0 +1,353 @@
+"""
+Token counting utilities.
+
+Uses tiktoken in place of LangChain's default chars_per_token=3.3,
+giving accurate token counts for mixed Chinese/Japanese/English text.
+
+Message fields are read the same way as
+langchain_core.messages.utils.count_tokens_approximately.
+"""
+import json
+import logging
+import math
+from typing import Any, Dict, Sequence
+from functools import lru_cache
+
+try:
+    import tiktoken
+    TIKTOKEN_AVAILABLE = True
+except ImportError:
+    TIKTOKEN_AVAILABLE = False
+
+# Try to import LangChain's message types
+try:
+    from langchain_core.messages import (
+        BaseMessage,
+        AIMessage,
+        HumanMessage,
+        SystemMessage,
+        ToolMessage,
+        FunctionMessage,
+        ChatMessage,
+        convert_to_messages,
+    )
+    from langchain_core.messages.utils import _get_message_openai_role
+    LANGCHAIN_AVAILABLE = True
+except ImportError:
+    LANGCHAIN_AVAILABLE = False
+    BaseMessage = None
+    AIMessage = None
+    HumanMessage = None
+    SystemMessage = None
+    ToolMessage = None
+    FunctionMessage = None
+    ChatMessage = None
+
+logger = logging.getLogger('app')
+
+
+# Supported model-to-encoding mapping
+MODEL_TO_ENCODING: Dict[str, str] = {
+    # OpenAI models
+    "gpt-4o": "o200k_base",
+    "gpt-4o-mini": "o200k_base",
+    "gpt-4-turbo": "cl100k_base",
+    "gpt-4": "cl100k_base",
+    "gpt-3.5-turbo": "cl100k_base",
+    "gpt-3.5": "cl100k_base",
+    # Claude: use cl100k_base as an approximation
+    "claude": "cl100k_base",
+    # Other models default to cl100k_base
+}
+
+
+@lru_cache(maxsize=128)
+def _get_encoding(model_name: str) -> Any:
+    """
+    Get the tiktoken encoder for a model (cached).
+
+    Args:
+        model_name: model name
+
+    Returns:
+        a tiktoken.Encoding instance
+    """
+    if not TIKTOKEN_AVAILABLE:
+        return None
+
+    # Normalize the model name
+    model_lower = model_name.lower()
+
+    # Look for a matching encoding
+    encoding_name = None
+    for key, encoding in MODEL_TO_ENCODING.items():
+        if key in model_lower:
+            encoding_name = encoding
+            break
+
+    # Default to cl100k_base (suitable for most modern models)
+    if encoding_name is None:
+        encoding_name = "cl100k_base"
+
+    try:
+        return tiktoken.get_encoding(encoding_name)
+    except Exception as e:
+        logger.warning(f"Failed to get tiktoken encoding {encoding_name}: {e}")
+        return None
+
+
+def count_tokens(text: str, model_name: str = "gpt-4o") -> int:
+    """
+    Count the tokens in a piece of text.
+
+    Args:
+        text: the text to count
+        model_name: model name, used to pick a suitable encoder
+
+    Returns:
+        token count
+    """
+    if not text:
+        return 0
+
+    encoding = _get_encoding(model_name)
+
+    if encoding is None:
+        # When tiktoken is unavailable, fall back to a conservative character estimate:
+        # roughly 1.5 chars/token for Chinese/Japanese and ~4 chars/token for English;
+        # mixed text uses 2.5 as a middle value
+        return max(1, len(text) // 2)
+
+    try:
+        tokens = encoding.encode(text)
+        return len(tokens)
+    except Exception as e:
+        logger.warning(f"Failed to encode text: {e}")
+        return max(1, len(text) // 2)
+
+
+def _get_role(message: Dict[str, Any]) -> str:
+    """
+    Get a message's role (modeled on _get_message_openai_role).
+
+    Args:
+        message: message dict
+
+    Returns:
+        role string
+    """
+    # Prefer the type field
+    msg_type = message.get("type", "")
+
+    if msg_type == "ai" or msg_type == "AIMessage":
+        return "assistant"
+    elif msg_type == "human" or msg_type == "HumanMessage":
+        return "user"
+    elif msg_type == "tool" or msg_type == "ToolMessage":
+        return "tool"
+    elif msg_type == "system" or msg_type == "SystemMessage":
+        # Check for __openai_role__
+        additional_kwargs = message.get("additional_kwargs", {})
+        if isinstance(additional_kwargs, dict):
+            return additional_kwargs.get("__openai_role__", "system")
+        return "system"
+    elif msg_type == "function" or msg_type == "FunctionMessage":
+        return "function"
+    elif msg_type == "chat" or msg_type == "ChatMessage":
+        return message.get("role", "user")
+    else:
+        # If a role field is present, use it directly
+        if "role" in message:
+            return message["role"]
+        return "user"
+
+
+def count_message_tokens(message: Dict[str, Any] | BaseMessage, model_name: str = "gpt-4o") -> int:
+    """
+    Count one message's tokens (fields are read as in count_tokens_approximately).
+
+    Includes:
+    - message content
+    - message role
+    - message name
+    - AIMessage tool_calls
+    - ToolMessage tool_call_id
+
+    Args:
+        message: message object (dict or BaseMessage)
+        model_name: model name
+
+    Returns:
+        token count
+    """
+    # Convert to dict form for processing
+    if LANGCHAIN_AVAILABLE and isinstance(message, BaseMessage):
+        # Convert the BaseMessage to a dict
+        msg_dict = message.model_dump(exclude={"type"})
+    else:
+        msg_dict = message if isinstance(message, dict) else {}
+
+    token_count = 0
+    encoding = _get_encoding(model_name)
+
+    # 1. Handle content
+    content = msg_dict.get("content", "")
+
+    if isinstance(content, str):
+        token_count += count_tokens(content, model_name)
+    elif isinstance(content, list):
+        # Handle multimodal content blocks
+        for block in content:
+            if isinstance(block, str):
+                token_count += count_tokens(block, model_name)
+            elif isinstance(block, dict):
+                block_type = block.get("type", "")
+
+                if block_type == "text":
+                    token_count += count_tokens(block.get("text", ""), model_name)
+                elif block_type == "image_url":
+                    # Image token cost (OpenAI standard: 85 base tokens + 170 tokens per tile)
+                    token_count += 85
+                elif block_type == "tool_use":
+                    # tool_use block
+                    token_count += count_tokens(block.get("name", ""), model_name)
+                    input_data = block.get("input", {})
+                    if isinstance(input_data, dict):
+                        token_count += count_tokens(json.dumps(input_data, ensure_ascii=False), model_name)
+                    elif isinstance(input_data, str):
+                        token_count += count_tokens(input_data, model_name)
+                elif block_type == "tool_result":
+                    # tool_result block
+                    result_content = block.get("content", "")
+                    if isinstance(result_content, str):
+                        token_count += count_tokens(result_content, model_name)
+                    elif isinstance(result_content, list):
+                        for sub_block in result_content:
+                            if isinstance(sub_block, dict):
+                                if sub_block.get("type") == "text":
+                                    token_count += count_tokens(sub_block.get("text", ""), model_name)
+                    token_count += count_tokens(block.get("tool_use_id", ""), model_name)
+                elif block_type == "json":
+                    json_data = block.get("json", {})
+                    token_count += count_tokens(json.dumps(json_data, ensure_ascii=False), model_name)
+                else:
+                    # Other block types: serialize the whole block
+                    token_count += count_tokens(repr(block), model_name)
+    else:
+        # Other content types: serialize before counting
+        token_count += count_tokens(repr(content), model_name)
+
+    # 2. Handle tool_calls (only when content is not a list)
+    if msg_dict.get("type") in ["ai", "AIMessage"] or isinstance(msg_dict.get("tool_calls"), list):
+        tool_calls = msg_dict.get("tool_calls", [])
+        # Only count tool_calls separately when content is not a list
+        # (in the Anthropic format, tool_calls are already included as tool_use blocks in content)
+        if not isinstance(content, list) and tool_calls:
+            tool_calls_str = repr(tool_calls)
+            token_count += count_tokens(tool_calls_str, model_name)
+
+    # 3. Handle tool_call_id (ToolMessage)
+    tool_call_id = msg_dict.get("tool_call_id", "")
+    if tool_call_id:
+        token_count += count_tokens(tool_call_id, model_name)
+
+    # 4. Handle role
+    role = _get_role(msg_dict)
+    token_count += count_tokens(role, model_name)
+
+    # 5. Handle name
+    name = msg_dict.get("name", "")
+    if name:
+        token_count += count_tokens(name, model_name)
+
+    # 6. Add per-message formatting overhead (following OpenAI's accounting)
+    # Each message carries roughly 4 tokens of formatting overhead
+    token_count += 4
+
+    return token_count
+
+
+def count_messages_tokens(messages: Sequence[Dict[str, Any]] | Sequence[BaseMessage], model_name: str = "gpt-4o") -> int:
+    """
+    Count the total tokens of a message list.
+
+    Args:
+        messages: message list (dicts or BaseMessage objects)
+        model_name: model name
+
+    Returns:
+        total token count
+    """
+    if not messages:
+        return 0
+
+    total = 0
+    for message in messages:
+        total += count_message_tokens(message, model_name)
+
+    # Add an estimate for the reply priming (3 tokens)
+    total += 3
+
+    return int(math.ceil(total))
+
+
+def create_token_counter(model_name: str = "gpt-4o"):
+    """
+    Create a token-counting function to pass to SummarizationMiddleware.
+
+    Args:
+        model_name: model name
+
+    Returns:
+        a token-counting function
+    """
+    if not TIKTOKEN_AVAILABLE:
+        logger.warning("tiktoken not available, falling back to character-based estimation")
+        # Fall back to character estimation (as count_tokens_approximately does)
+        def fallback_counter(messages) -> int:
+            token_count = 0.0
+            for message in messages:
+                # Convert to dict form for processing
+                if LANGCHAIN_AVAILABLE and isinstance(message, BaseMessage):
+                    msg_dict = message.model_dump(exclude={"type"})
+                else:
+                    msg_dict = message if isinstance(message, dict) else {}
+
+                message_chars = 0
+                content = msg_dict.get("content", "")
+
+                if isinstance(content, str):
+                    message_chars += len(content)
+                elif isinstance(content, list):
+                    message_chars += len(repr(content))
+
+                # Handle tool_calls
+                if (msg_dict.get("type") in ["ai", "AIMessage"] and
+                        not isinstance(content, list) and
+                        msg_dict.get("tool_calls")):
+                    message_chars += len(repr(msg_dict.get("tool_calls")))
+
+                # Handle tool_call_id
+                if msg_dict.get("tool_call_id"):
+                    message_chars += len(msg_dict.get("tool_call_id", ""))
+
+                # Handle role
+                role = _get_role(msg_dict)
+                message_chars += len(role)
+
+                # Handle name
+                if msg_dict.get("name"):
+                    message_chars += len(msg_dict.get("name", ""))
+
+                # Use 2.5 as chars_per_token (suited to mixed Chinese/Japanese/English text)
+                token_count += math.ceil(message_chars / 2.5)
+                token_count += 3.0  # extra_tokens_per_message
+
+            return int(math.ceil(token_count))
+
+        return fallback_counter
+
+    def token_counter(messages: Sequence[Dict[str, Any]] | Sequence[BaseMessage]) -> int:
+        """Token counting function."""
+        return count_messages_tokens(messages, model_name)
+
+    return token_counter
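A hedged usage sketch of the new module, assuming tiktoken is installed; the message contents are invented:

# Usage sketch for utils/token_counter; message contents are invented.
from utils.token_counter import count_tokens, count_messages_tokens, create_token_counter

print(count_tokens("hello world", model_name="gpt-4o"))  # exact count via o200k_base

messages = [
    {"type": "human", "content": "Summarize the deployment discussion."},
    {"type": "ai", "content": "Staging and prod now build in separate jobs."},
]
# content + role (+ name/tool fields if present) + ~4 tokens/message + 3 for the reply
print(count_messages_tokens(messages, model_name="gpt-4o"))

counter = create_token_counter("gpt-4o")  # the callable passed to SummarizationMiddleware
print(counter(messages))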