"""
Token counting utilities.

Uses tiktoken instead of LangChain's default chars_per_token=3.3 estimate,
giving accurate token counts for mixed Chinese/Japanese/English text.

Message fields are read the same way as
langchain_core.messages.utils.count_tokens_approximately.
"""
import json
import logging
import math
from typing import Any, Dict, Sequence
from functools import lru_cache

# tiktoken is optional; when missing, callers fall back to character-based
# estimation (see count_tokens / create_token_counter).
try:
    import tiktoken

    TIKTOKEN_AVAILABLE = True
except ImportError:
    TIKTOKEN_AVAILABLE = False

# Try to import LangChain message types; on failure the names are bound to
# None so isinstance checks must be guarded by LANGCHAIN_AVAILABLE.
try:
    from langchain_core.messages import (
        BaseMessage,
        AIMessage,
        HumanMessage,
        SystemMessage,
        ToolMessage,
        FunctionMessage,
        ChatMessage,
        convert_to_messages,
    )
    # NOTE(review): _get_message_openai_role and convert_to_messages are
    # imported but not referenced in this module — confirm they are used
    # elsewhere before removing.
    from langchain_core.messages.utils import _get_message_openai_role

    LANGCHAIN_AVAILABLE = True
except ImportError:
    LANGCHAIN_AVAILABLE = False
    BaseMessage = None
    AIMessage = None
    HumanMessage = None
    SystemMessage = None
    ToolMessage = None
    FunctionMessage = None
    ChatMessage = None

logger = logging.getLogger('app')


# Model name (substring) -> tiktoken encoding name. Matched in insertion
# order by _get_encoding, so more specific keys must come first
# (e.g. "gpt-4o" before "gpt-4").
MODEL_TO_ENCODING: Dict[str, str] = {
    # OpenAI models
    "gpt-4o": "o200k_base",
    "gpt-4o-mini": "o200k_base",
    "gpt-4-turbo": "cl100k_base",
    "gpt-4": "cl100k_base",
    "gpt-3.5-turbo": "cl100k_base",
    "gpt-3.5": "cl100k_base",
    # Claude: cl100k_base is used as an approximation
    "claude": "cl100k_base",
    # All other models default to cl100k_base
}

@lru_cache(maxsize=128)
def _get_encoding(model_name: str) -> Any:
    """
    Resolve and cache the tiktoken encoder for a model name.

    The model name is lowercased and matched by substring against the keys
    of MODEL_TO_ENCODING; the first matching key wins.

    Args:
        model_name: Model name (e.g. "gpt-4o", "claude-3-sonnet").

    Returns:
        A tiktoken.Encoding instance, or None when tiktoken is unavailable
        or the encoding cannot be loaded.
    """
    if not TIKTOKEN_AVAILABLE:
        return None

    lowered = model_name.lower()

    # First mapping key contained in the model name wins; cl100k_base is
    # the default and suits most modern models.
    encoding_name = next(
        (enc for key, enc in MODEL_TO_ENCODING.items() if key in lowered),
        "cl100k_base",
    )

    try:
        return tiktoken.get_encoding(encoding_name)
    except Exception as e:
        logger.warning(f"Failed to get tiktoken encoding {encoding_name}: {e}")
        return None

def count_tokens(text: str, model_name: str = "gpt-4o") -> int:
    """
    Count the number of tokens in a text string.

    Args:
        text: Text to count.
        model_name: Model name, used to pick a suitable tiktoken encoding.

    Returns:
        Token count (0 for empty text, at least 1 otherwise).
    """
    if not text:
        return 0

    encoding = _get_encoding(model_name)

    if encoding is None:
        # tiktoken unavailable: fall back to a character-based estimate.
        # Chinese/Japanese run ~1.5 chars/token, English ~4 chars/token;
        # 2.5 chars/token is the documented middle ground for mixed text.
        # (Fix: the code previously used // 2, contradicting this comment
        # and the 2.5 divisor used by create_token_counter's fallback.)
        return max(1, math.ceil(len(text) / 2.5))

    try:
        return len(encoding.encode(text))
    except Exception as e:
        logger.warning(f"Failed to encode text: {e}")
        # Same character-based estimate as the tiktoken-unavailable path.
        return max(1, math.ceil(len(text) / 2.5))

def _get_role(message: Dict[str, Any]) -> str:
|
||
"""
|
||
获取消息的 role(参考 _get_message_openai_role)
|
||
|
||
Args:
|
||
message: 消息字典
|
||
|
||
Returns:
|
||
role 字符串
|
||
"""
|
||
# 优先使用 type 字段
|
||
msg_type = message.get("type", "")
|
||
|
||
if msg_type == "ai" or msg_type == "AIMessage":
|
||
return "assistant"
|
||
elif msg_type == "human" or msg_type == "HumanMessage":
|
||
return "user"
|
||
elif msg_type == "tool" or msg_type == "ToolMessage":
|
||
return "tool"
|
||
elif msg_type == "system" or msg_type == "SystemMessage":
|
||
# 检查是否有 __openai_role__
|
||
additional_kwargs = message.get("additional_kwargs", {})
|
||
if isinstance(additional_kwargs, dict):
|
||
return additional_kwargs.get("__openai_role__", "system")
|
||
return "system"
|
||
elif msg_type == "function" or msg_type == "FunctionMessage":
|
||
return "function"
|
||
elif msg_type == "chat" or msg_type == "ChatMessage":
|
||
return message.get("role", "user")
|
||
else:
|
||
# 如果有 role 字段,直接使用
|
||
if "role" in message:
|
||
return message["role"]
|
||
return "user"
|
||
|
||
|
||
def count_message_tokens(message: Dict[str, Any] | BaseMessage, model_name: str = "gpt-4o") -> int:
    """
    Count the tokens of a single message (fields are read the same way as
    langchain's count_tokens_approximately).

    Counted parts:
    - message content
    - message role
    - message name
    - AIMessage tool_calls
    - ToolMessage tool_call_id

    Args:
        message: Message object (dict or BaseMessage).
        model_name: Model name used to select the encoding.

    Returns:
        Token count, including a fixed 4-token per-message format overhead.
    """
    # Normalize to a plain dict for uniform field access.
    if LANGCHAIN_AVAILABLE and isinstance(message, BaseMessage):
        msg_dict = message.model_dump(exclude={"type"})
    else:
        msg_dict = message if isinstance(message, dict) else {}

    token_count = 0
    # Fix: removed an unused `encoding = _get_encoding(model_name)` local;
    # all token math goes through count_tokens(), which fetches the
    # (lru-cached) encoding itself.

    # 1. content
    content = msg_dict.get("content", "")

    if isinstance(content, str):
        token_count += count_tokens(content, model_name)
    elif isinstance(content, list):
        # Multimodal content blocks.
        for block in content:
            if isinstance(block, str):
                token_count += count_tokens(block, model_name)
            elif isinstance(block, dict):
                block_type = block.get("type", "")

                if block_type == "text":
                    token_count += count_tokens(block.get("text", ""), model_name)
                elif block_type == "image_url":
                    # Flat image cost: OpenAI's 85-token base per image.
                    # NOTE(review): per-tile costs (170 tokens each) are not
                    # modeled here, so large images are under-counted.
                    token_count += 85
                elif block_type == "tool_use":
                    # tool_use block: tool name plus serialized input.
                    token_count += count_tokens(block.get("name", ""), model_name)
                    input_data = block.get("input", {})
                    if isinstance(input_data, dict):
                        token_count += count_tokens(json.dumps(input_data, ensure_ascii=False), model_name)
                    elif isinstance(input_data, str):
                        token_count += count_tokens(input_data, model_name)
                elif block_type == "tool_result":
                    # tool_result block: result content plus tool_use_id.
                    result_content = block.get("content", "")
                    if isinstance(result_content, str):
                        token_count += count_tokens(result_content, model_name)
                    elif isinstance(result_content, list):
                        for sub_block in result_content:
                            if isinstance(sub_block, dict):
                                if sub_block.get("type") == "text":
                                    token_count += count_tokens(sub_block.get("text", ""), model_name)
                    token_count += count_tokens(block.get("tool_use_id", ""), model_name)
                elif block_type == "json":
                    json_data = block.get("json", {})
                    token_count += count_tokens(json.dumps(json_data, ensure_ascii=False), model_name)
                else:
                    # Unknown block type: serialize the whole block.
                    token_count += count_tokens(repr(block), model_name)
    else:
        # Any other content type: serialize, then count.
        token_count += count_tokens(repr(content), model_name)

    # 2. tool_calls — only counted separately when content is NOT a list,
    # because Anthropic-style list content already carries tool calls as
    # tool_use blocks (counted above).
    if msg_dict.get("type") in ["ai", "AIMessage"] or isinstance(msg_dict.get("tool_calls"), list):
        tool_calls = msg_dict.get("tool_calls", [])
        if not isinstance(content, list) and tool_calls:
            token_count += count_tokens(repr(tool_calls), model_name)

    # 3. tool_call_id (ToolMessage)
    tool_call_id = msg_dict.get("tool_call_id", "")
    if tool_call_id:
        token_count += count_tokens(tool_call_id, model_name)

    # 4. role
    token_count += count_tokens(_get_role(msg_dict), model_name)

    # 5. name
    name = msg_dict.get("name", "")
    if name:
        token_count += count_tokens(name, model_name)

    # 6. Per-message format overhead (~4 tokens, per OpenAI's guidance).
    token_count += 4

    return token_count

def count_messages_tokens(messages: Sequence[Dict[str, Any]] | Sequence[BaseMessage], model_name: str = "gpt-4o") -> int:
    """
    Count the total tokens of a list of messages.

    Args:
        messages: Message list (dicts or BaseMessage instances).
        model_name: Model name used to select the encoding.

    Returns:
        Total token count (0 for an empty list).
    """
    if not messages:
        return 0

    # Per-message counts are already ints, so no rounding is needed here.
    # (Fix: removed a redundant int(math.ceil(total)) on an int total.)
    total = sum(count_message_tokens(message, model_name) for message in messages)

    # Add a 3-token estimate for the assistant reply primer.
    return total + 3

def create_token_counter(model_name: str = "gpt-4o"):
    """
    Build a token-counting callable for use with SummarizationMiddleware.

    Args:
        model_name: Model name used to select the encoding.

    Returns:
        A function taking a message list and returning its token count.
        Uses tiktoken when available; otherwise a character-based estimate
        (reading message fields like count_tokens_approximately does).
    """
    if TIKTOKEN_AVAILABLE:
        def token_counter(messages: Sequence[Dict[str, Any]] | Sequence[BaseMessage]) -> int:
            """Token counting function backed by tiktoken."""
            return count_messages_tokens(messages, model_name)

        return token_counter

    logger.warning("tiktoken not available, falling back to character-based estimation")

    def fallback_counter(messages) -> int:
        """Character-based token estimate (2.5 chars/token for mixed text)."""
        running = 0.0
        for msg in messages:
            # Normalize to a plain dict.
            if LANGCHAIN_AVAILABLE and isinstance(msg, BaseMessage):
                data = msg.model_dump(exclude={"type"})
            else:
                data = msg if isinstance(msg, dict) else {}

            chars = 0
            body = data.get("content", "")

            if isinstance(body, str):
                chars += len(body)
            elif isinstance(body, list):
                chars += len(repr(body))

            # tool_calls: only when content is not a list (Anthropic-style
            # list content already embeds them as tool_use blocks).
            if (data.get("type") in ["ai", "AIMessage"] and
                    not isinstance(body, list) and
                    data.get("tool_calls")):
                chars += len(repr(data.get("tool_calls")))

            # tool_call_id (ToolMessage)
            if data.get("tool_call_id"):
                chars += len(data.get("tool_call_id", ""))

            # role
            chars += len(_get_role(data))

            # name
            if data.get("name"):
                chars += len(data.get("name", ""))

            # 2.5 chars/token suits mixed Chinese/Japanese/English text;
            # each message also carries a 3-token format overhead.
            running += math.ceil(chars / 2.5)
            running += 3.0

        return int(math.ceil(running))

    return fallback_counter
