""" Token counting utility module. Uses tiktoken instead of LangChain's default chars_per_token=3.3, supporting accurate token counting for Chinese/Japanese/English multilingual text. References langchain_core.messages.utils.count_tokens_approximately for message reading. """ import json import logging import math from typing import Any, Dict, Sequence from functools import lru_cache try: import tiktoken TIKTOKEN_AVAILABLE = True except ImportError: TIKTOKEN_AVAILABLE = False # Try importing LangChain message types try: from langchain_core.messages import ( BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage, FunctionMessage, ChatMessage, convert_to_messages, ) from langchain_core.messages.utils import _get_message_openai_role LANGCHAIN_AVAILABLE = True except ImportError: LANGCHAIN_AVAILABLE = False BaseMessage = None AIMessage = None HumanMessage = None SystemMessage = None ToolMessage = None FunctionMessage = None ChatMessage = None logger = logging.getLogger('app') # Supported model encoding mapping MODEL_TO_ENCODING: Dict[str, str] = { # OpenAI models "gpt-4o": "o200k_base", "gpt-4o-mini": "o200k_base", "gpt-4-turbo": "cl100k_base", "gpt-4": "cl100k_base", "gpt-3.5-turbo": "cl100k_base", "gpt-3.5": "cl100k_base", # Claude uses cl100k_base as an approximation "claude": "cl100k_base", # Other models default to cl100k_base } @lru_cache(maxsize=128) def _get_encoding(model_name: str) -> Any: """ Get the tiktoken encoder for a model (with caching). Args: model_name: Model name Returns: tiktoken.Encoding instance """ if not TIKTOKEN_AVAILABLE: return None # Normalize model name model_lower = model_name.lower() # Find matching encoding encoding_name = None for key, encoding in MODEL_TO_ENCODING.items(): if key in model_lower: encoding_name = encoding break # Default to cl100k_base (suitable for most modern models) if encoding_name is None: encoding_name = "cl100k_base" try: return tiktoken.get_encoding(encoding_name) except Exception as e: logger.warning(f"Failed to get tiktoken encoding {encoding_name}: {e}") return None def count_tokens(text: str, model_name: str = "gpt-4o") -> int: """ Count the number of tokens in text. Args: text: Text to count model_name: Model name for selecting the appropriate encoder Returns: Number of tokens """ if not text: return 0 encoding = _get_encoding(model_name) if encoding is None: # Fallback to character estimation when tiktoken is unavailable (conservative estimate) # Chinese/Japanese ~1.5 chars/token, English ~4 chars/token # Use 2.5 as a middle value for mixed text return max(1, len(text) // 2) try: tokens = encoding.encode(text) return len(tokens) except Exception as e: logger.warning(f"Failed to encode text: {e}") return max(1, len(text) // 2) def _get_role(message: Dict[str, Any]) -> str: """ Get the role of a message (references _get_message_openai_role). Args: message: Message dictionary Returns: Role string """ # Prefer the type field msg_type = message.get("type", "") if msg_type == "ai" or msg_type == "AIMessage": return "assistant" elif msg_type == "human" or msg_type == "HumanMessage": return "user" elif msg_type == "tool" or msg_type == "ToolMessage": return "tool" elif msg_type == "system" or msg_type == "SystemMessage": # Check for __openai_role__ additional_kwargs = message.get("additional_kwargs", {}) if isinstance(additional_kwargs, dict): return additional_kwargs.get("__openai_role__", "system") return "system" elif msg_type == "function" or msg_type == "FunctionMessage": return "function" elif msg_type == "chat" or msg_type == "ChatMessage": return message.get("role", "user") else: # If there is a role field, use it directly if "role" in message: return message["role"] return "user" def count_message_tokens(message: Dict[str, Any] | BaseMessage, model_name: str = "gpt-4o") -> int: """ Count tokens in a message (references count_tokens_approximately for message reading). Includes: - Message content (content) - Message role (role) - Message name (name) - AIMessage tool_calls - ToolMessage tool_call_id Args: message: Message object (dict or BaseMessage) model_name: Model name Returns: Number of tokens """ # Convert to dict format for processing if LANGCHAIN_AVAILABLE and isinstance(message, BaseMessage): # Convert BaseMessage to dict msg_dict = message.model_dump(exclude={"type"}) else: msg_dict = message if isinstance(message, dict) else {} token_count = 0 encoding = _get_encoding(model_name) # 1. Process content content = msg_dict.get("content", "") if isinstance(content, str): token_count += count_tokens(content, model_name) elif isinstance(content, list): # Process multimodal content blocks for block in content: if isinstance(block, str): token_count += count_tokens(block, model_name) elif isinstance(block, dict): block_type = block.get("type", "") if block_type == "text": token_count += count_tokens(block.get("text", ""), model_name) elif block_type == "image_url": # Image token calculation (OpenAI standard: 85 tokens/base + 170 tokens per tile) token_count += 85 elif block_type == "tool_use": # tool_use block token_count += count_tokens(block.get("name", ""), model_name) input_data = block.get("input", {}) if isinstance(input_data, dict): token_count += count_tokens(json.dumps(input_data, ensure_ascii=False), model_name) elif isinstance(input_data, str): token_count += count_tokens(input_data, model_name) elif block_type == "tool_result": # tool_result block result_content = block.get("content", "") if isinstance(result_content, str): token_count += count_tokens(result_content, model_name) elif isinstance(result_content, list): for sub_block in result_content: if isinstance(sub_block, dict): if sub_block.get("type") == "text": token_count += count_tokens(sub_block.get("text", ""), model_name) token_count += count_tokens(block.get("tool_use_id", ""), model_name) elif block_type == "json": json_data = block.get("json", {}) token_count += count_tokens(json.dumps(json_data, ensure_ascii=False), model_name) else: # Other types: serialize the entire block token_count += count_tokens(repr(block), model_name) else: # Other content types: serialize and count token_count += count_tokens(repr(content), model_name) # 2. Process tool_calls (only when content is not a list) if msg_dict.get("type") in ["ai", "AIMessage"] or isinstance(msg_dict.get("tool_calls"), list): tool_calls = msg_dict.get("tool_calls", []) # Only count tool_calls separately when content is not a list # (In Anthropic format, tool_calls are already included in content's tool_use blocks) if not isinstance(content, list) and tool_calls: tool_calls_str = repr(tool_calls) token_count += count_tokens(tool_calls_str, model_name) # 3. Process tool_call_id (ToolMessage) tool_call_id = msg_dict.get("tool_call_id", "") if tool_call_id: token_count += count_tokens(tool_call_id, model_name) # 4. Process role role = _get_role(msg_dict) token_count += count_tokens(role, model_name) # 5. Process name name = msg_dict.get("name", "") if name: token_count += count_tokens(name, model_name) # 6. Add per-message format overhead (following OpenAI's calculation) # Approximately 4 tokens of format overhead per message token_count += 4 return token_count def count_messages_tokens(messages: Sequence[Dict[str, Any]] | Sequence[BaseMessage], model_name: str = "gpt-4o") -> int: """ Count total tokens in a message list. Args: messages: Message list (list of dicts or BaseMessage) model_name: Model name Returns: Total token count """ if not messages: return 0 total = 0 for message in messages: total += count_message_tokens(message, model_name) # Add reply estimate (3 tokens) total += 3 return int(math.ceil(total)) def create_token_counter(model_name: str = "gpt-4o"): """ Create a token counting function for use with SummarizationMiddleware. Args: model_name: Model name Returns: Token counting function """ if not TIKTOKEN_AVAILABLE: logger.warning("tiktoken not available, falling back to character-based estimation") # Fallback to character estimation (following count_tokens_approximately) def fallback_counter(messages) -> int: token_count = 0.0 for message in messages: # Convert to dict format for processing if LANGCHAIN_AVAILABLE and isinstance(message, BaseMessage): msg_dict = message.model_dump(exclude={"type"}) else: msg_dict = message if isinstance(message, dict) else {} message_chars = 0 content = msg_dict.get("content", "") if isinstance(content, str): message_chars += len(content) elif isinstance(content, list): message_chars += len(repr(content)) # Process tool_calls if (msg_dict.get("type") in ["ai", "AIMessage"] and not isinstance(content, list) and msg_dict.get("tool_calls")): message_chars += len(repr(msg_dict.get("tool_calls"))) # Process tool_call_id if msg_dict.get("tool_call_id"): message_chars += len(msg_dict.get("tool_call_id", "")) # Process role role = _get_role(msg_dict) message_chars += len(role) # Process name if msg_dict.get("name"): message_chars += len(msg_dict.get("name", "")) # Use 2.5 as chars_per_token (suitable for Chinese/Japanese/English mixed text) token_count += math.ceil(message_chars / 2.5) token_count += 3.0 # extra_tokens_per_message return int(math.ceil(token_count)) return fallback_counter def token_counter(messages: Sequence[Dict[str, Any]] | Sequence[BaseMessage]) -> int: """Token counting function.""" return count_messages_tokens(messages, model_name) return token_counter