diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py index d70f0de1..abb5fcb8 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py @@ -43,3 +43,11 @@ class OllamaChatModel(MaxKBBaseModel, ChatOpenAI): return OllamaChatModel(model=model_name, openai_api_base=base_url, openai_api_key=model_credential.get('api_key'), stream_usage=True, **optional_params) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + + def get_num_tokens(self, text: str) -> int: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py index d73065e8..40f90fa0 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py @@ -8,10 +8,18 @@ """ from typing import List, Dict +from langchain_core.messages import BaseMessage, get_buffer_string from langchain_openai.chat_models import ChatOpenAI + +from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel +def custom_get_token_ids(text: str): + tokenizer = TokenizerManage.get_tokenizer() + return tokenizer.encode(text) + + class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI): @staticmethod @@ -32,5 +40,6 @@ class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI): **optional_params, streaming=True, stream_usage=True, + custom_get_token_ids=custom_get_token_ids ) return azure_chat_open_ai