# qwen_agent/utils/token_counter.py

"""
Token 计数工具模块
使用 tiktoken 替代 LangChain 默认的 chars_per_token=3.3
支持中/日/英多语言的精确 token 计算。
参考 langchain_core.messages.utils.count_tokens_approximately 的消息读取方式。
"""
import json
import logging
import math
from typing import Any, Dict, Sequence
from functools import lru_cache
try:
    import tiktoken
    TIKTOKEN_AVAILABLE = True
except ImportError:
    TIKTOKEN_AVAILABLE = False

# Try to import LangChain's message types.
try:
    from langchain_core.messages import (
        BaseMessage,
        AIMessage,
        HumanMessage,
        SystemMessage,
        ToolMessage,
        FunctionMessage,
        ChatMessage,
        convert_to_messages,
    )
    from langchain_core.messages.utils import _get_message_openai_role
    LANGCHAIN_AVAILABLE = True
except ImportError:
    LANGCHAIN_AVAILABLE = False
    BaseMessage = None
    AIMessage = None
    HumanMessage = None
    SystemMessage = None
    ToolMessage = None
    FunctionMessage = None
    ChatMessage = None

logger = logging.getLogger('app')

# Supported model-to-encoding mapping.
MODEL_TO_ENCODING: Dict[str, str] = {
    # OpenAI models
    "gpt-4o": "o200k_base",
    "gpt-4o-mini": "o200k_base",
    "gpt-4-turbo": "cl100k_base",
    "gpt-4": "cl100k_base",
    "gpt-3.5-turbo": "cl100k_base",
    "gpt-3.5": "cl100k_base",
    # Claude: use cl100k_base as an approximation
    "claude": "cl100k_base",
    # Everything else defaults to cl100k_base
}

@lru_cache(maxsize=128)
def _get_encoding(model_name: str) -> Any:
    """
    Get the tiktoken encoder for a model (cached).

    Args:
        model_name: Model name.

    Returns:
        A tiktoken.Encoding instance, or None if tiktoken is unavailable.
    """
    if not TIKTOKEN_AVAILABLE:
        return None
    # Normalize the model name.
    model_lower = model_name.lower()
    # Look for a matching encoding.
    encoding_name = None
    for key, encoding in MODEL_TO_ENCODING.items():
        if key in model_lower:
            encoding_name = encoding
            break
    # Default to cl100k_base, which suits most modern models.
    if encoding_name is None:
        encoding_name = "cl100k_base"
    try:
        return tiktoken.get_encoding(encoding_name)
    except Exception as e:
        logger.warning(f"Failed to get tiktoken encoding {encoding_name}: {e}")
        return None
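
# Illustrative lookups for _get_encoding (the model names here are made up;
# with tiktoken missing, every call returns None instead):
#
#     _get_encoding("azure-gpt-4o-2024")  # substring "gpt-4o" -> o200k_base
#     _get_encoding("my-claude-proxy")    # substring "claude" -> cl100k_base
#     _get_encoding("qwen-max")           # no key matches -> cl100k_base default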

def count_tokens(text: str, model_name: str = "gpt-4o") -> int:
    """
    Count the tokens in a piece of text.

    Args:
        text: Text to count.
        model_name: Model name, used to pick a suitable encoder.

    Returns:
        Token count.
    """
    if not text:
        return 0
    encoding = _get_encoding(model_name)
    if encoding is None:
        # When tiktoken is unavailable, fall back to a conservative
        # character-based estimate: Chinese/Japanese text runs about
        # 1.5 chars per token and English about 4, so dividing by 2
        # serves as a middle value for mixed text.
        return max(1, len(text) // 2)
    try:
        tokens = encoding.encode(text)
        return len(tokens)
    except Exception as e:
        logger.warning(f"Failed to encode text: {e}")
        return max(1, len(text) // 2)
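
# Usage sketch (illustrative; exact counts depend on the installed tiktoken
# version, so no specific numbers are asserted here):
#
#     n_en = count_tokens("Hello, world!", model_name="gpt-4o")
#     n_zh = count_tokens("你好,世界", model_name="gpt-4o")
#     # With tiktoken missing, both fall back to max(1, len(text) // 2).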

def _get_role(message: Dict[str, Any]) -> str:
    """
    Get a message's role (modelled on _get_message_openai_role).

    Args:
        message: Message dict.

    Returns:
        Role string.
    """
    # Prefer the "type" field.
    msg_type = message.get("type", "")
    if msg_type == "ai" or msg_type == "AIMessage":
        return "assistant"
    elif msg_type == "human" or msg_type == "HumanMessage":
        return "user"
    elif msg_type == "tool" or msg_type == "ToolMessage":
        return "tool"
    elif msg_type == "system" or msg_type == "SystemMessage":
        # Check for an explicit __openai_role__ override.
        additional_kwargs = message.get("additional_kwargs", {})
        if isinstance(additional_kwargs, dict):
            return additional_kwargs.get("__openai_role__", "system")
        return "system"
    elif msg_type == "function" or msg_type == "FunctionMessage":
        return "function"
    elif msg_type == "chat" or msg_type == "ChatMessage":
        return message.get("role", "user")
    else:
        # Fall back to an explicit "role" field when present.
        if "role" in message:
            return message["role"]
        return "user"

def count_message_tokens(message: Dict[str, Any] | BaseMessage, model_name: str = "gpt-4o") -> int:
    """
    Count one message's tokens (fields are read the same way as in
    count_tokens_approximately).

    Covers:
    - message content (content)
    - message role (role)
    - message name (name)
    - tool_calls on AIMessage
    - tool_call_id on ToolMessage

    Args:
        message: Message object (dict or BaseMessage).
        model_name: Model name.

    Returns:
        Token count.
    """
    # Normalize to a dict. Keep the "type" field: _get_role and the
    # tool_calls branch below both rely on it.
    if LANGCHAIN_AVAILABLE and isinstance(message, BaseMessage):
        msg_dict = message.model_dump()
    else:
        msg_dict = message if isinstance(message, dict) else {}

    token_count = 0
    # 1. Content.
    content = msg_dict.get("content", "")
    if isinstance(content, str):
        token_count += count_tokens(content, model_name)
    elif isinstance(content, list):
        # Multimodal content blocks.
        for block in content:
            if isinstance(block, str):
                token_count += count_tokens(block, model_name)
            elif isinstance(block, dict):
                block_type = block.get("type", "")
                if block_type == "text":
                    token_count += count_tokens(block.get("text", ""), model_name)
                elif block_type == "image_url":
                    # Image tokens (OpenAI convention: 85 tokens base plus
                    # 170 per tile; only the base is counted here).
                    token_count += 85
                elif block_type == "tool_use":
                    # tool_use block
                    token_count += count_tokens(block.get("name", ""), model_name)
                    input_data = block.get("input", {})
                    if isinstance(input_data, dict):
                        token_count += count_tokens(json.dumps(input_data, ensure_ascii=False), model_name)
                    elif isinstance(input_data, str):
                        token_count += count_tokens(input_data, model_name)
                elif block_type == "tool_result":
                    # tool_result block
                    result_content = block.get("content", "")
                    if isinstance(result_content, str):
                        token_count += count_tokens(result_content, model_name)
                    elif isinstance(result_content, list):
                        for sub_block in result_content:
                            if isinstance(sub_block, dict):
                                if sub_block.get("type") == "text":
                                    token_count += count_tokens(sub_block.get("text", ""), model_name)
                    token_count += count_tokens(block.get("tool_use_id", ""), model_name)
                elif block_type == "json":
                    json_data = block.get("json", {})
                    token_count += count_tokens(json.dumps(json_data, ensure_ascii=False), model_name)
                else:
                    # Any other block type: count its repr.
                    token_count += count_tokens(repr(block), model_name)
    else:
        # Any other content type: serialize and count.
        token_count += count_tokens(repr(content), model_name)
    # 2. tool_calls (counted separately only when content is not a list,
    # because in the Anthropic format tool_calls are already covered by the
    # tool_use blocks inside content).
    if msg_dict.get("type") in ["ai", "AIMessage"] or isinstance(msg_dict.get("tool_calls"), list):
        tool_calls = msg_dict.get("tool_calls", [])
        if not isinstance(content, list) and tool_calls:
            tool_calls_str = repr(tool_calls)
            token_count += count_tokens(tool_calls_str, model_name)

    # 3. tool_call_id (ToolMessage).
    tool_call_id = msg_dict.get("tool_call_id", "")
    if tool_call_id:
        token_count += count_tokens(tool_call_id, model_name)

    # 4. Role.
    role = _get_role(msg_dict)
    token_count += count_tokens(role, model_name)

    # 5. Name.
    name = msg_dict.get("name", "")
    if name:
        token_count += count_tokens(name, model_name)

    # 6. Per-message formatting overhead (following OpenAI's accounting):
    # roughly 4 tokens of framing per message.
    token_count += 4

    return token_count
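
# Example (illustrative; the id value is made up): counting a ToolMessage-shaped
# dict covers its content, its tool_call_id, its resolved role ("tool"), and
# the fixed 4-token per-message framing overhead.
#
#     msg = {
#         "type": "tool",
#         "content": "42 degrees",
#         "tool_call_id": "call_abc123",
#     }
#     n = count_message_tokens(msg, model_name="gpt-4o")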

def count_messages_tokens(messages: Sequence[Dict[str, Any]] | Sequence[BaseMessage], model_name: str = "gpt-4o") -> int:
    """
    Count the total tokens of a message list.

    Args:
        messages: Message list (dicts or BaseMessage objects).
        model_name: Model name.

    Returns:
        Total token count.
    """
    if not messages:
        return 0
    total = 0
    for message in messages:
        total += count_message_tokens(message, model_name)
    # Allow 3 tokens for priming the reply.
    total += 3
    return int(math.ceil(total))
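
# Example (illustrative): the list total is the sum of the per-message counts
# plus the 3-token allowance for priming the reply.
#
#     history = [
#         {"role": "system", "content": "You are a helpful assistant."},
#         {"role": "user", "content": "Summarize this document."},
#     ]
#     total = count_messages_tokens(history, model_name="gpt-4o")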

def create_token_counter(model_name: str = "gpt-4o"):
    """
    Create a token-counting function to pass into SummarizationMiddleware.

    Args:
        model_name: Model name.

    Returns:
        A token-counting function.
    """
    if not TIKTOKEN_AVAILABLE:
        logger.warning("tiktoken not available, falling back to character-based estimation")

        # Fall back to character-based estimation (following the approach of
        # count_tokens_approximately).
        def fallback_counter(messages) -> int:
            token_count = 0.0
            for message in messages:
                # Normalize to a dict, keeping "type" for _get_role and the
                # tool_calls check below.
                if LANGCHAIN_AVAILABLE and isinstance(message, BaseMessage):
                    msg_dict = message.model_dump()
                else:
                    msg_dict = message if isinstance(message, dict) else {}
                message_chars = 0
                content = msg_dict.get("content", "")
                if isinstance(content, str):
                    message_chars += len(content)
                elif isinstance(content, list):
                    message_chars += len(repr(content))
                # tool_calls
                if (msg_dict.get("type") in ["ai", "AIMessage"] and
                        not isinstance(content, list) and
                        msg_dict.get("tool_calls")):
                    message_chars += len(repr(msg_dict.get("tool_calls")))
                # tool_call_id
                if msg_dict.get("tool_call_id"):
                    message_chars += len(msg_dict.get("tool_call_id", ""))
                # role
                role = _get_role(msg_dict)
                message_chars += len(role)
                # name
                if msg_dict.get("name"):
                    message_chars += len(msg_dict.get("name", ""))
                # Use 2.5 chars per token, a reasonable middle value for
                # mixed Chinese/Japanese/English text.
                token_count += math.ceil(message_chars / 2.5)
                token_count += 3.0  # extra_tokens_per_message
            return int(math.ceil(token_count))

        return fallback_counter
    def token_counter(messages: Sequence[Dict[str, Any]] | Sequence[BaseMessage]) -> int:
        """Token-counting function."""
        return count_messages_tokens(messages, model_name)

    return token_counter
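
# Minimal self-check sketch (illustrative; not part of the module's public
# API). Works with or without tiktoken installed, since create_token_counter
# falls back to the character-based estimator when tiktoken is missing.
if __name__ == "__main__":
    counter = create_token_counter("gpt-4o")
    demo = [
        {"role": "user", "content": "Hello!"},
        {"type": "ai", "content": "Hi! How can I help?"},
    ]
    print("estimated tokens:", counter(demo))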