Convert all Chinese comments, docstrings, logger/print output, HTTPException detail messages, and API response messages to English across the entire codebase. Functional zh/ja localized strings (e.g. prompt templates, timezone display names, date formats) are preserved as-is. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
354 lines
12 KiB
Python
354 lines
12 KiB
Python
"""
|
|
Token counting utility module.
|
|
|
|
Uses tiktoken instead of LangChain's default chars_per_token=3.3,
|
|
supporting accurate token counting for Chinese/Japanese/English multilingual text.
|
|
|
|
References langchain_core.messages.utils.count_tokens_approximately for message reading.
|
|
"""
|
|
import json
|
|
import logging
|
|
import math
|
|
from typing import Any, Dict, Sequence
|
|
from functools import lru_cache
|
|
|
|
try:
|
|
import tiktoken
|
|
TIKTOKEN_AVAILABLE = True
|
|
except ImportError:
|
|
TIKTOKEN_AVAILABLE = False
|
|
|
|
# Try importing LangChain message types
|
|
try:
|
|
from langchain_core.messages import (
|
|
BaseMessage,
|
|
AIMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
ToolMessage,
|
|
FunctionMessage,
|
|
ChatMessage,
|
|
convert_to_messages,
|
|
)
|
|
from langchain_core.messages.utils import _get_message_openai_role
|
|
LANGCHAIN_AVAILABLE = True
|
|
except ImportError:
|
|
LANGCHAIN_AVAILABLE = False
|
|
BaseMessage = None
|
|
AIMessage = None
|
|
HumanMessage = None
|
|
SystemMessage = None
|
|
ToolMessage = None
|
|
FunctionMessage = None
|
|
ChatMessage = None
|
|
|
|
logger = logging.getLogger('app')
|
|
|
|
|
|
# Model-name fragment -> tiktoken encoding name.
#
# _get_encoding walks this table in insertion order and takes the first key
# that is a substring of the lowercased model name, so more specific names
# ("gpt-4o") are listed ahead of the shorter names they contain ("gpt-4").
# Models that match nothing here fall back to cl100k_base in _get_encoding.
MODEL_TO_ENCODING: Dict[str, str] = {
    # OpenAI: o200k_base family
    "gpt-4o": "o200k_base",
    "gpt-4o-mini": "o200k_base",
    # OpenAI: cl100k_base family
    "gpt-4-turbo": "cl100k_base",
    "gpt-4": "cl100k_base",
    "gpt-3.5-turbo": "cl100k_base",
    "gpt-3.5": "cl100k_base",
    # Claude: cl100k_base is used as an approximation
    "claude": "cl100k_base",
}
|
|
|
|
|
|
@lru_cache(maxsize=128)
def _get_encoding(model_name: str) -> Any:
    """
    Resolve the tiktoken encoder for a model name (results are memoized).

    The lowercased model name is matched against the keys of
    MODEL_TO_ENCODING by substring, first match wins; names that match no
    key use "cl100k_base".

    Args:
        model_name: Model name

    Returns:
        tiktoken.Encoding instance, or None when tiktoken is not installed
        or the encoding could not be loaded (a warning is logged in the
        latter case).
    """
    if not TIKTOKEN_AVAILABLE:
        return None

    lowered = model_name.lower()

    # First table key contained in the model name decides the encoding;
    # the second argument to next() is the default for unmatched models.
    encoding_name = next(
        (mapped for fragment, mapped in MODEL_TO_ENCODING.items() if fragment in lowered),
        "cl100k_base",
    )

    try:
        return tiktoken.get_encoding(encoding_name)
    except Exception as e:
        logger.warning(f"Failed to get tiktoken encoding {encoding_name}: {e}")
        return None
|
|
|
|
|
|
def count_tokens(text: str, model_name: str = "gpt-4o") -> int:
    """
    Count the number of tokens in text.

    Uses the tiktoken encoder selected by _get_encoding. When no encoder is
    available, or encoding raises, the count is estimated from the character
    length instead.

    Args:
        text: Text to count
        model_name: Model name for selecting the appropriate encoder

    Returns:
        Number of tokens; 0 for empty/falsy text, otherwise at least 1 when
        the character-based estimate is used.
    """
    if not text:
        return 0

    encoding = _get_encoding(model_name)

    if encoding is not None:
        try:
            # disallowed_special=() makes tiktoken encode special-token
            # markers that appear in ordinary text (e.g. "<|endoftext|>")
            # as plain text. With the default setting encode() raises
            # ValueError on such input, which would needlessly push the
            # count onto the rough estimate below.
            return len(encoding.encode(text, disallowed_special=()))
        except Exception as e:
            logger.warning(f"Failed to encode text: {e}")

    # Character-based fallback. Chinese/Japanese is roughly 1.5 chars/token
    # and English roughly 4 chars/token; 2 chars/token is used here as a
    # deliberately conservative (over-counting) estimate for mixed text.
    return max(1, len(text) // 2)
|
|
|
|
|
|
def _get_role(message: Dict[str, Any]) -> str:
|
|
"""
|
|
Get the role of a message (references _get_message_openai_role).
|
|
|
|
Args:
|
|
message: Message dictionary
|
|
|
|
Returns:
|
|
Role string
|
|
"""
|
|
# Prefer the type field
|
|
msg_type = message.get("type", "")
|
|
|
|
if msg_type == "ai" or msg_type == "AIMessage":
|
|
return "assistant"
|
|
elif msg_type == "human" or msg_type == "HumanMessage":
|
|
return "user"
|
|
elif msg_type == "tool" or msg_type == "ToolMessage":
|
|
return "tool"
|
|
elif msg_type == "system" or msg_type == "SystemMessage":
|
|
# Check for __openai_role__
|
|
additional_kwargs = message.get("additional_kwargs", {})
|
|
if isinstance(additional_kwargs, dict):
|
|
return additional_kwargs.get("__openai_role__", "system")
|
|
return "system"
|
|
elif msg_type == "function" or msg_type == "FunctionMessage":
|
|
return "function"
|
|
elif msg_type == "chat" or msg_type == "ChatMessage":
|
|
return message.get("role", "user")
|
|
else:
|
|
# If there is a role field, use it directly
|
|
if "role" in message:
|
|
return message["role"]
|
|
return "user"
|
|
|
|
|
|
def _count_content_block_tokens(block: Any, model_name: str) -> int:
    """Count tokens for a single entry of a list-valued message ``content``."""
    if isinstance(block, str):
        return count_tokens(block, model_name)
    if not isinstance(block, dict):
        # Entries that are neither str nor dict contribute nothing.
        return 0

    block_type = block.get("type", "")

    if block_type == "text":
        return count_tokens(block.get("text", ""), model_name)

    if block_type == "image_url":
        # Flat 85-token base cost per image; the additional per-tile cost
        # (170 tokens per tile in OpenAI's scheme) is not modeled.
        return 85

    if block_type == "tool_use":
        total = count_tokens(block.get("name", ""), model_name)
        input_data = block.get("input", {})
        if isinstance(input_data, dict):
            total += count_tokens(json.dumps(input_data, ensure_ascii=False), model_name)
        elif isinstance(input_data, str):
            total += count_tokens(input_data, model_name)
        return total

    if block_type == "tool_result":
        total = 0
        result_content = block.get("content", "")
        if isinstance(result_content, str):
            total += count_tokens(result_content, model_name)
        elif isinstance(result_content, list):
            # Only text sub-blocks of a structured tool result are counted.
            for sub_block in result_content:
                if isinstance(sub_block, dict) and sub_block.get("type") == "text":
                    total += count_tokens(sub_block.get("text", ""), model_name)
        total += count_tokens(block.get("tool_use_id", ""), model_name)
        return total

    if block_type == "json":
        return count_tokens(json.dumps(block.get("json", {}), ensure_ascii=False), model_name)

    # Unknown block type: count its repr as a rough stand-in.
    return count_tokens(repr(block), model_name)


def count_message_tokens(message: Dict[str, Any] | BaseMessage, model_name: str = "gpt-4o") -> int:
    """
    Count tokens in a message (references count_tokens_approximately for message reading).

    Includes:
    - Message content (content), either a string or a list of content blocks
    - Message role (role)
    - Message name (name)
    - AIMessage tool_calls (only when content is not a list)
    - ToolMessage tool_call_id
    - A fixed 4-token per-message format overhead

    Args:
        message: Message object (dict or BaseMessage). Any other object is
            treated as an empty message.
        model_name: Model name

    Returns:
        Number of tokens
    """
    # Normalize to a plain dict.
    if LANGCHAIN_AVAILABLE and isinstance(message, BaseMessage):
        # Keep "type" in the dump: _get_role and the tool_calls check below
        # both read it. Excluding it made every LangChain message without a
        # "role" field resolve to the "user" role.
        msg_dict = message.model_dump()
    else:
        msg_dict = message if isinstance(message, dict) else {}

    token_count = 0

    # 1. Content
    content = msg_dict.get("content", "")
    if content is None:
        # An explicit None carries no text; do not count the string "None".
        content = ""

    if isinstance(content, str):
        token_count += count_tokens(content, model_name)
    elif isinstance(content, list):
        # Multimodal / structured content blocks
        for block in content:
            token_count += _count_content_block_tokens(block, model_name)
    else:
        # Other content types: count the repr
        token_count += count_tokens(repr(content), model_name)

    # 2. tool_calls. Skipped when content is a list, because in the
    #    Anthropic format the calls already appear there as tool_use blocks
    #    and would otherwise be counted twice.
    tool_calls = msg_dict.get("tool_calls")
    is_ai_like = msg_dict.get("type") in ["ai", "AIMessage"] or isinstance(tool_calls, list)
    if is_ai_like and tool_calls and not isinstance(content, list):
        token_count += count_tokens(repr(tool_calls), model_name)

    # 3. tool_call_id (ToolMessage)
    tool_call_id = msg_dict.get("tool_call_id", "")
    if tool_call_id:
        token_count += count_tokens(tool_call_id, model_name)

    # 4. Role
    token_count += count_tokens(_get_role(msg_dict), model_name)

    # 5. Name
    name = msg_dict.get("name", "")
    if name:
        token_count += count_tokens(name, model_name)

    # 6. Per-message format overhead (approximately 4 tokens per message,
    #    following OpenAI's calculation)
    token_count += 4

    return token_count
|
|
|
|
|
|
def count_messages_tokens(messages: Sequence[Dict[str, Any]] | Sequence[BaseMessage], model_name: str = "gpt-4o") -> int:
    """
    Sum the token counts of every message in a list.

    Each message is measured with count_message_tokens; a fixed 3 tokens
    are then added as the reply estimate. An empty or missing list counts
    as 0 (the reply estimate is not added).

    Args:
        messages: Message list (list of dicts or BaseMessage)
        model_name: Model name

    Returns:
        Total token count
    """
    if not messages:
        return 0

    per_message_total = sum(count_message_tokens(item, model_name) for item in messages)

    # Reply estimate (3 tokens) on top of the per-message totals
    return int(math.ceil(per_message_total + 3))
|
|
|
|
|
|
def create_token_counter(model_name: str = "gpt-4o"):
    """
    Create a token counting function for use with SummarizationMiddleware.

    When tiktoken is importable the returned function delegates to
    count_messages_tokens with the given model name. Otherwise a warning is
    logged once, here, and a character-based estimator is returned instead.

    Args:
        model_name: Model name

    Returns:
        Token counting function taking a message list and returning an int
    """
    if not TIKTOKEN_AVAILABLE:
        logger.warning("tiktoken not available, falling back to character-based estimation")

        # Fallback to character estimation (following count_tokens_approximately)
        def fallback_counter(messages) -> int:
            token_count = 0.0
            for message in messages:
                # Normalize to a plain dict.
                if LANGCHAIN_AVAILABLE and isinstance(message, BaseMessage):
                    # Keep "type" in the dump: the tool_calls check and
                    # _get_role below both read it. With the type excluded,
                    # LangChain AI messages never had their tool_calls
                    # counted and every role was measured as "user".
                    msg_dict = message.model_dump()
                else:
                    msg_dict = message if isinstance(message, dict) else {}

                message_chars = 0
                content = msg_dict.get("content", "")

                if isinstance(content, str):
                    message_chars += len(content)
                elif isinstance(content, list):
                    message_chars += len(repr(content))

                # tool_calls: same gate as count_message_tokens, so plain
                # dict messages that carry a tool_calls list are counted as
                # well. Skipped for list content, where the calls are
                # already part of the content blocks.
                tool_calls = msg_dict.get("tool_calls")
                is_ai_like = (
                    msg_dict.get("type") in ["ai", "AIMessage"]
                    or isinstance(tool_calls, list)
                )
                if is_ai_like and tool_calls and not isinstance(content, list):
                    message_chars += len(repr(tool_calls))

                # tool_call_id
                if msg_dict.get("tool_call_id"):
                    message_chars += len(msg_dict.get("tool_call_id", ""))

                # Role
                message_chars += len(_get_role(msg_dict))

                # Name
                if msg_dict.get("name"):
                    message_chars += len(msg_dict.get("name", ""))

                # Use 2.5 as chars_per_token (suitable for Chinese/Japanese/English mixed text)
                token_count += math.ceil(message_chars / 2.5)
                token_count += 3.0  # extra_tokens_per_message

            return int(math.ceil(token_count))

        return fallback_counter

    def token_counter(messages: Sequence[Dict[str, Any]] | Sequence[BaseMessage]) -> int:
        """Token counting function."""
        return count_messages_tokens(messages, model_name)

    return token_counter
|