refactor: report Gemini token usage from streamed response metadata

wxg0103 2024-08-19 11:48:35 +08:00
parent 8e8ce795b2
commit 527a0953d8


@@ -6,11 +6,18 @@
 @Author Brian Yang
 @Date 5/13/24 7:40 AM
 """
-from typing import List, Dict
-
+from typing import List, Dict, Optional, Sequence, Union, Any, Iterator, cast
+from google.ai.generativelanguage_v1 import GenerateContentResponse
+from google.generativeai.responder import ToolDict
+from google.generativeai.types import FunctionDeclarationType, SafetySettingDict
+from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_core.outputs import ChatGenerationChunk
 from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_google_genai._function_utils import _ToolConfigDict
+from langchain_google_genai.chat_models import _chat_with_retry, _response_to_result
+from google.generativeai.types import Tool as GoogleTool
 
 from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 
@@ -36,10 +43,52 @@ class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
         )
         return gemini_chat
 
+    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
+        return self.__dict__.get('_last_generation_info')
+
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+        # Token counts now come from the last response's usage metadata; 0 before any run.
+        return (self.get_last_generation_info() or {}).get('input_tokens', 0)
 
     def get_num_tokens(self, text: str) -> int:
-        tokenizer = TokenizerManage.get_tokenizer()
-        return len(tokenizer.encode(text))
+        return (self.get_last_generation_info() or {}).get('output_tokens', 0)
+
+    def _stream(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            *,
+            tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
+            functions: Optional[Sequence[FunctionDeclarationType]] = None,
+            safety_settings: Optional[SafetySettingDict] = None,
+            tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
+            generation_config: Optional[Dict[str, Any]] = None,
+            **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        request = self._prepare_request(
+            messages,
+            stop=stop,
+            tools=tools,
+            functions=functions,
+            safety_settings=safety_settings,
+            tool_config=tool_config,
+            generation_config=generation_config,
+        )
+        response: GenerateContentResponse = _chat_with_retry(
+            request=request,
+            generation_method=self.client.stream_generate_content,
+            **kwargs,
+            metadata=self.default_metadata,
+        )
+        for chunk in response:
+            _chat_result = _response_to_result(chunk, stream=True)
+            gen = cast(ChatGenerationChunk, _chat_result.generations[0])
+            if gen.message:
+                # usage_metadata may be absent on intermediate chunks
+                token_usage = gen.message.usage_metadata
+                if token_usage:
+                    self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
+            if run_manager:
+                run_manager.on_llm_new_token(gen.text)
+            yield gen
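
Note: a minimal usage sketch of the refactored token accounting, assuming a directly constructed GeminiChatModel (the model name and key below are illustrative; in MaxKB the instance normally comes from the provider's new_instance factory with stored credentials). After this commit, token counts are read from the usage metadata cached during streaming rather than computed by a local tokenizer, so both counters return 0 until a stream has yielded a chunk carrying usage data.

    from langchain_core.messages import HumanMessage

    # Illustrative construction; the kwargs are standard ChatGoogleGenerativeAI fields.
    chat = GeminiChatModel(model='gemini-1.5-pro', google_api_key='...')

    messages = [HumanMessage(content='Hello, Gemini')]
    # .stream() is LangChain's public wrapper around the _stream override above.
    for chunk in chat.stream(messages):
        print(chunk.content, end='')

    # Both counters read the cached '_last_generation_info' dict, so their
    # arguments no longer affect the result.
    print(chat.get_num_tokens_from_messages(messages))  # input_tokens of the last run
    print(chat.get_num_tokens(''))                      # output_tokens of the last run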