""" Token 计数工具模块 使用 tiktoken 替代 LangChain 默认的 chars_per_token=3.3, 支持中/日/英多语言的精确 token 计算。 参考 langchain_core.messages.utils.count_tokens_approximately 的消息读取方式。 """ import json import logging import math from typing import Any, Dict, Sequence from functools import lru_cache try: import tiktoken TIKTOKEN_AVAILABLE = True except ImportError: TIKTOKEN_AVAILABLE = False # 尝试导入 LangChain 的消息类型 try: from langchain_core.messages import ( BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage, FunctionMessage, ChatMessage, convert_to_messages, ) from langchain_core.messages.utils import _get_message_openai_role LANGCHAIN_AVAILABLE = True except ImportError: LANGCHAIN_AVAILABLE = False BaseMessage = None AIMessage = None HumanMessage = None SystemMessage = None ToolMessage = None FunctionMessage = None ChatMessage = None logger = logging.getLogger('app') # ��持的模型编码映射 MODEL_TO_ENCODING: Dict[str, str] = { # OpenAI 模型 "gpt-4o": "o200k_base", "gpt-4o-mini": "o200k_base", "gpt-4-turbo": "cl100k_base", "gpt-4": "cl100k_base", "gpt-3.5-turbo": "cl100k_base", "gpt-3.5": "cl100k_base", # Claude 使用 cl100k_base 作为近似 "claude": "cl100k_base", # 其他模型默认使用 cl100k_base } @lru_cache(maxsize=128) def _get_encoding(model_name: str) -> Any: """ 获取模型的 tiktoken 编码器(带缓存) Args: model_name: 模型名称 Returns: tiktoken.Encoding 实例 """ if not TIKTOKEN_AVAILABLE: return None # 标准化模型名称 model_lower = model_name.lower() # 查找匹配的编码 encoding_name = None for key, encoding in MODEL_TO_ENCODING.items(): if key in model_lower: encoding_name = encoding break # 默认使用 cl100k_base(适用于大多数现代模型) if encoding_name is None: encoding_name = "cl100k_base" try: return tiktoken.get_encoding(encoding_name) except Exception as e: logger.warning(f"Failed to get tiktoken encoding {encoding_name}: {e}") return None def count_tokens(text: str, model_name: str = "gpt-4o") -> int: """ 计算文本的 token 数量 Args: text: 要计算的文本 model_name: 模型名称,用于选择合适的编码器 Returns: token 数量 """ if not text: return 0 encoding = _get_encoding(model_name) if encoding is None: # tiktoken 不可用时,回退到字符估算(保守估计) # 中文/日文约 1.5 字符/token,英文约 4 字符/token # 混合文本使用 2.5 作为中间值 return max(1, len(text) // 2) try: tokens = encoding.encode(text) return len(tokens) except Exception as e: logger.warning(f"Failed to encode text: {e}") return max(1, len(text) // 2) def _get_role(message: Dict[str, Any]) -> str: """ 获取消息的 role(参考 _get_message_openai_role) Args: message: 消息字典 Returns: role 字符串 """ # 优先使用 type 字段 msg_type = message.get("type", "") if msg_type == "ai" or msg_type == "AIMessage": return "assistant" elif msg_type == "human" or msg_type == "HumanMessage": return "user" elif msg_type == "tool" or msg_type == "ToolMessage": return "tool" elif msg_type == "system" or msg_type == "SystemMessage": # 检查是否有 __openai_role__ additional_kwargs = message.get("additional_kwargs", {}) if isinstance(additional_kwargs, dict): return additional_kwargs.get("__openai_role__", "system") return "system" elif msg_type == "function" or msg_type == "FunctionMessage": return "function" elif msg_type == "chat" or msg_type == "ChatMessage": return message.get("role", "user") else: # 如果有 role 字段,直接使用 if "role" in message: return message["role"] return "user" def count_message_tokens(message: Dict[str, Any] | BaseMessage, model_name: str = "gpt-4o") -> int: """ 计算消息的 token 数量(参考 count_tokens_approximately 的消息读取方式) 包括: - 消息内容 (content) - 消息角色 (role) - 消息名称 (name) - AIMessage 的 tool_calls - ToolMessage 的 tool_call_id Args: message: 消息对象(字典或 BaseMessage) model_name: 模型名称 Returns: token 数量 """ # 转换为字典格式处理 if LANGCHAIN_AVAILABLE and isinstance(message, BaseMessage): # 将 BaseMessage 转换为字典 msg_dict = message.model_dump(exclude={"type"}) else: msg_dict = message if isinstance(message, dict) else {} token_count = 0 encoding = _get_encoding(model_name) # 1. 处理 content content = msg_dict.get("content", "") if isinstance(content, str): token_count += count_tokens(content, model_name) elif isinstance(content, list): # 处理多模态内容块 for block in content: if isinstance(block, str): token_count += count_tokens(block, model_name) elif isinstance(block, dict): block_type = block.get("type", "") if block_type == "text": token_count += count_tokens(block.get("text", ""), model_name) elif block_type == "image_url": # 图片 token 计算(OpenAI 标准:85 tokens/base + 每个 tile 170 tokens) token_count += 85 elif block_type == "tool_use": # tool_use 块 token_count += count_tokens(block.get("name", ""), model_name) input_data = block.get("input", {}) if isinstance(input_data, dict): token_count += count_tokens(json.dumps(input_data, ensure_ascii=False), model_name) elif isinstance(input_data, str): token_count += count_tokens(input_data, model_name) elif block_type == "tool_result": # tool_result 块 result_content = block.get("content", "") if isinstance(result_content, str): token_count += count_tokens(result_content, model_name) elif isinstance(result_content, list): for sub_block in result_content: if isinstance(sub_block, dict): if sub_block.get("type") == "text": token_count += count_tokens(sub_block.get("text", ""), model_name) token_count += count_tokens(block.get("tool_use_id", ""), model_name) elif block_type == "json": json_data = block.get("json", {}) token_count += count_tokens(json.dumps(json_data, ensure_ascii=False), model_name) else: # 其他类型,将整个 block 序列化 token_count += count_tokens(repr(block), model_name) else: # 其他类型的 content,序列化后计算 token_count += count_tokens(repr(content), model_name) # 2. 处理 tool_calls(仅当 content 不是 list 时) if msg_dict.get("type") in ["ai", "AIMessage"] or isinstance(msg_dict.get("tool_calls"), list): tool_calls = msg_dict.get("tool_calls", []) # 只有在 content 不是 list 时才单独计算 tool_calls # (因为 Anthropic 格式中 tool_calls 已包含在 content 的 tool_use 块中) if not isinstance(content, list) and tool_calls: tool_calls_str = repr(tool_calls) token_count += count_tokens(tool_calls_str, model_name) # 3. 处理 tool_call_id(ToolMessage) tool_call_id = msg_dict.get("tool_call_id", "") if tool_call_id: token_count += count_tokens(tool_call_id, model_name) # 4. 处理 role role = _get_role(msg_dict) token_count += count_tokens(role, model_name) # 5. 处理 name name = msg_dict.get("name", "") if name: token_count += count_tokens(name, model_name) # 6. 添加每条消息的格式开销(参考 OpenAI 的计算方式) # 每条消息约有 4 个 token 的格式开销 token_count += 4 return token_count def count_messages_tokens(messages: Sequence[Dict[str, Any]] | Sequence[BaseMessage], model_name: str = "gpt-4o") -> int: """ 计算消息列表的总 token 数量 Args: messages: 消息列表(字典列表或 BaseMessage 列表) model_name: 模型名称 Returns: 总 token 数量 """ if not messages: return 0 total = 0 for message in messages: total += count_message_tokens(message, model_name) # 添加回复的估算(3 tokens) total += 3 return int(math.ceil(total)) def create_token_counter(model_name: str = "gpt-4o"): """ 创建 token 计数函数,用于传入 SummarizationMiddleware Args: model_name: 模型名称 Returns: token 计数函数 """ if not TIKTOKEN_AVAILABLE: logger.warning("tiktoken not available, falling back to character-based estimation") # 回退到字符估算(参考 count_tokens_approximately 的方式) def fallback_counter(messages) -> int: token_count = 0.0 for message in messages: # 转换为字典格式处理 if LANGCHAIN_AVAILABLE and isinstance(message, BaseMessage): msg_dict = message.model_dump(exclude={"type"}) else: msg_dict = message if isinstance(message, dict) else {} message_chars = 0 content = msg_dict.get("content", "") if isinstance(content, str): message_chars += len(content) elif isinstance(content, list): message_chars += len(repr(content)) # 处理 tool_calls if (msg_dict.get("type") in ["ai", "AIMessage"] and not isinstance(content, list) and msg_dict.get("tool_calls")): message_chars += len(repr(msg_dict.get("tool_calls"))) # 处理 tool_call_id if msg_dict.get("tool_call_id"): message_chars += len(msg_dict.get("tool_call_id", "")) # 处理 role role = _get_role(msg_dict) message_chars += len(role) # 处理 name if msg_dict.get("name"): message_chars += len(msg_dict.get("name", "")) # 使用 2.5 作为 chars_per_token(适合中/日/英混合文本) token_count += math.ceil(message_chars / 2.5) token_count += 3.0 # extra_tokens_per_message return int(math.ceil(token_count)) return fallback_counter def token_counter(messages: Sequence[Dict[str, Any]] | Sequence[BaseMessage]) -> int: """Token 计数函数""" return count_messages_tokens(messages, model_name) return token_counter