refactor: update gemini params
This commit is contained in:
parent
8e8ce795b2
commit
527a0953d8
@ -6,11 +6,18 @@
|
|||||||
@Author :Brian Yang
|
@Author :Brian Yang
|
||||||
@Date :5/13/24 7:40 AM
|
@Date :5/13/24 7:40 AM
|
||||||
"""
|
"""
|
||||||
from typing import List, Dict
|
from typing import List, Dict, Optional, Sequence, Union, Any, Iterator, cast
|
||||||
|
|
||||||
|
from google.ai.generativelanguage_v1 import GenerateContentResponse
|
||||||
|
from google.generativeai.responder import ToolDict
|
||||||
|
from google.generativeai.types import FunctionDeclarationType, SafetySettingDict
|
||||||
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
from langchain_core.outputs import ChatGenerationChunk
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
from langchain_google_genai._function_utils import _ToolConfigDict
|
||||||
|
from langchain_google_genai.chat_models import _chat_with_retry, _response_to_result
|
||||||
|
from google.generativeai.types import Tool as GoogleTool
|
||||||
from common.config.tokenizer_manage_config import TokenizerManage
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
@ -36,10 +43,49 @@ class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
|
|||||||
)
|
)
|
||||||
return gemini_chat
|
return gemini_chat
|
||||||
|
|
||||||
|
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
|
||||||
|
return self.__dict__.get('_last_generation_info')
|
||||||
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
tokenizer = TokenizerManage.get_tokenizer()
|
return self.get_last_generation_info().get('input_tokens', 0)
|
||||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
def get_num_tokens(self, text: str) -> int:
|
||||||
tokenizer = TokenizerManage.get_tokenizer()
|
return self.get_last_generation_info().get('output_tokens', 0)
|
||||||
return len(tokenizer.encode(text))
|
|
||||||
|
def _stream(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
*,
|
||||||
|
tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
|
||||||
|
functions: Optional[Sequence[FunctionDeclarationType]] = None,
|
||||||
|
safety_settings: Optional[SafetySettingDict] = None,
|
||||||
|
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
||||||
|
generation_config: Optional[Dict[str, Any]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
|
request = self._prepare_request(
|
||||||
|
messages,
|
||||||
|
stop=stop,
|
||||||
|
tools=tools,
|
||||||
|
functions=functions,
|
||||||
|
safety_settings=safety_settings,
|
||||||
|
tool_config=tool_config,
|
||||||
|
generation_config=generation_config,
|
||||||
|
)
|
||||||
|
response: GenerateContentResponse = _chat_with_retry(
|
||||||
|
request=request,
|
||||||
|
generation_method=self.client.stream_generate_content,
|
||||||
|
**kwargs,
|
||||||
|
metadata=self.default_metadata,
|
||||||
|
)
|
||||||
|
for chunk in response:
|
||||||
|
_chat_result = _response_to_result(chunk, stream=True)
|
||||||
|
gen = cast(ChatGenerationChunk, _chat_result.generations[0])
|
||||||
|
if gen.message:
|
||||||
|
token_usage = gen.message.usage_metadata
|
||||||
|
self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
|
||||||
|
if run_manager:
|
||||||
|
run_manager.on_llm_new_token(gen.text)
|
||||||
|
yield gen
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user