refactor: gemini
This commit is contained in:
parent
aebbc794cd
commit
74ce98c39b
@ -13,7 +13,7 @@ from google.ai.generativelanguage_v1beta.types import (
|
||||
Tool as GoogleTool,
|
||||
)
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.outputs import ChatGenerationChunk
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_google_genai._function_utils import _ToolConfigDict, _ToolDict
|
||||
@ -22,6 +22,8 @@ from langchain_google_genai.chat_models import _chat_with_retry, _response_to_re
|
||||
from langchain_google_genai._common import (
|
||||
SafetySettingDict,
|
||||
)
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
@ -46,10 +48,18 @@ class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
|
||||
return self.__dict__.get('_last_generation_info')
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
return self.get_last_generation_info().get('input_tokens', 0)
|
||||
try:
|
||||
return self.get_last_generation_info().get('input_tokens', 0)
|
||||
except Exception as e:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
return self.get_last_generation_info().get('output_tokens', 0)
|
||||
try:
|
||||
return self.get_last_generation_info().get('output_tokens', 0)
|
||||
except Exception as e:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return len(tokenizer.encode(text))
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user