fix: 修复模型计算tokens消耗报错的缺陷

This commit is contained in:
wxg0103 2024-08-29 11:50:20 +08:00 committed by shaohuzhang1
parent 83c91da117
commit c2622e4a5d
2 changed files with 17 additions and 0 deletions

View File

@ -43,3 +43,11 @@ class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
return OllamaChatModel(model=model_name, openai_api_base=base_url,
openai_api_key=model_credential.get('api_key'),
stream_usage=True, **optional_params)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Return the total token count of *messages*.

    Overrides the ChatOpenAI default so counting uses the shared local
    tokenizer (TokenizerManage) instead of the remote model's tokenizer —
    per the commit message, the default raised an error when computing
    token consumption for Ollama-served models.

    :param messages: chat messages to count; each is rendered to text via
        ``get_buffer_string`` before encoding.
    :return: sum of encoded token counts across all messages.
    """
    tokenizer = TokenizerManage.get_tokenizer()
    # Generator expression: no need to materialize an intermediate list for sum().
    return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)
def get_num_tokens(self, text: str) -> int:
    """Return how many tokens *text* encodes to under the shared local tokenizer."""
    return len(TokenizerManage.get_tokenizer().encode(text))

View File

@ -8,10 +8,18 @@
"""
from typing import List, Dict
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai.chat_models import ChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
def custom_get_token_ids(text: str):
    """Encode *text* with the shared local tokenizer and return its token ids.

    Passed to ChatOpenAI as ``custom_get_token_ids`` so token counting does
    not rely on the remote endpoint's tokenizer.
    """
    return TokenizerManage.get_tokenizer().encode(text)
class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
@staticmethod
@ -32,5 +40,6 @@ class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
**optional_params,
streaming=True,
stream_usage=True,
custom_get_token_ids=custom_get_token_ids
)
return azure_chat_open_ai