From c2622e4a5d458a13eba6589e5b9a874f93929424 Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Thu, 29 Aug 2024 11:50:20 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E8=AE=A1=E7=AE=97tokens=E6=B6=88=E8=80=97=E6=8A=A5=E9=94=99?= =?UTF-8?q?=E7=9A=84=E7=BC=BA=E9=99=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../impl/ollama_model_provider/model/llm.py | 8 ++++++++ .../impl/openai_model_provider/model/llm.py | 9 +++++++++ 2 files changed, 17 insertions(+) diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py index d70f0de1..abb5fcb8 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py @@ -43,3 +43,11 @@ class OllamaChatModel(MaxKBBaseModel, ChatOpenAI): return OllamaChatModel(model=model_name, openai_api_base=base_url, openai_api_key=model_credential.get('api_key'), stream_usage=True, **optional_params) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + + def get_num_tokens(self, text: str) -> int: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py index d73065e8..40f90fa0 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py @@ -8,10 +8,18 @@ """ from typing import List, Dict +from langchain_core.messages import BaseMessage, get_buffer_string from langchain_openai.chat_models import ChatOpenAI + +from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel +def custom_get_token_ids(text: str): + tokenizer = TokenizerManage.get_tokenizer() + return tokenizer.encode(text) + + class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI): @staticmethod @@ -32,5 +40,6 @@ class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI): **optional_params, streaming=True, stream_usage=True, + custom_get_token_ids=custom_get_token_ids ) return azure_chat_open_ai