refactor: tokens
This commit is contained in:
parent
3d1c43c020
commit
ff41b1ff6e
@ -20,5 +20,5 @@ class BaiLianChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
openai_api_base=model_credential.get('api_base'),
|
openai_api_base=model_credential.get('api_base'),
|
||||||
openai_api_key=model_credential.get('api_key'),
|
openai_api_key=model_credential.get('api_key'),
|
||||||
**optional_params
|
extra_body=optional_params
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import Dict
|
from typing import Dict, List
|
||||||
|
|
||||||
from botocore.config import Config
|
from botocore.config import Config
|
||||||
from langchain_community.chat_models import BedrockChat
|
from langchain_community.chat_models import BedrockChat
|
||||||
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
from models_provider.base_model_provider import MaxKBBaseModel
|
from models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
@ -72,6 +74,19 @@ class BedrockModel(MaxKBBaseModel, BedrockChat):
|
|||||||
config=config
|
config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
|
try:
|
||||||
|
return super().get_num_tokens_from_messages(messages)
|
||||||
|
except Exception as e:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||||
|
|
||||||
|
def get_num_tokens(self, text: str) -> int:
|
||||||
|
try:
|
||||||
|
return super().get_num_tokens(text)
|
||||||
|
except Exception as e:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return len(tokenizer.encode(text))
|
||||||
|
|
||||||
def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
|
def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
|
||||||
credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
|
credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
|
||||||
|
|||||||
@ -26,6 +26,6 @@ class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
openai_api_base='https://api.deepseek.com',
|
openai_api_base='https://api.deepseek.com',
|
||||||
openai_api_key=model_credential.get('api_key'),
|
openai_api_key=model_credential.get('api_key'),
|
||||||
**optional_params
|
extra_body=optional_params
|
||||||
)
|
)
|
||||||
return deepseek_chat_open_ai
|
return deepseek_chat_open_ai
|
||||||
|
|||||||
@ -26,6 +26,6 @@ class KimiChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||||||
openai_api_base=model_credential['api_base'],
|
openai_api_base=model_credential['api_base'],
|
||||||
openai_api_key=model_credential['api_key'],
|
openai_api_key=model_credential['api_key'],
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
**optional_params
|
extra_body=optional_params,
|
||||||
)
|
)
|
||||||
return kimi_chat_open_ai
|
return kimi_chat_open_ai
|
||||||
|
|||||||
@ -25,7 +25,7 @@ class OllamaLLMModelParams(BaseForm):
|
|||||||
_step=0.01,
|
_step=0.01,
|
||||||
precision=2)
|
precision=2)
|
||||||
|
|
||||||
max_tokens = forms.SliderField(
|
num_predict = forms.SliderField(
|
||||||
TooltipLabel(_('Output the maximum Tokens'),
|
TooltipLabel(_('Output the maximum Tokens'),
|
||||||
_('Specify the maximum number of tokens that the model can generate')),
|
_('Specify the maximum number of tokens that the model can generate')),
|
||||||
required=True, default_value=1024,
|
required=True, default_value=1024,
|
||||||
|
|||||||
@ -33,15 +33,15 @@ class OpenAIChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||||||
streaming = model_kwargs.get('streaming', True)
|
streaming = model_kwargs.get('streaming', True)
|
||||||
if 'o1' in model_name:
|
if 'o1' in model_name:
|
||||||
streaming = False
|
streaming = False
|
||||||
azure_chat_open_ai = OpenAIChatModel(
|
chat_open_ai = OpenAIChatModel(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
openai_api_base=model_credential.get('api_base'),
|
openai_api_base=model_credential.get('api_base'),
|
||||||
openai_api_key=model_credential.get('api_key'),
|
openai_api_key=model_credential.get('api_key'),
|
||||||
**optional_params,
|
extra_body=optional_params,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
custom_get_token_ids=custom_get_token_ids
|
custom_get_token_ids=custom_get_token_ids
|
||||||
)
|
)
|
||||||
return azure_chat_open_ai
|
return chat_open_ai
|
||||||
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -34,5 +34,5 @@ class SiliconCloudChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
openai_api_base=model_credential.get('api_base'),
|
openai_api_base=model_credential.get('api_base'),
|
||||||
openai_api_key=model_credential.get('api_key'),
|
openai_api_key=model_credential.get('api_key'),
|
||||||
**optional_params
|
extra_body=optional_params
|
||||||
)
|
)
|
||||||
|
|||||||
@ -31,7 +31,7 @@ class VllmChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
openai_api_base=model_credential.get('api_base'),
|
openai_api_base=model_credential.get('api_base'),
|
||||||
openai_api_key=model_credential.get('api_key'),
|
openai_api_key=model_credential.get('api_key'),
|
||||||
**optional_params,
|
extra_body=optional_params,
|
||||||
streaming=True,
|
streaming=True,
|
||||||
stream_usage=True,
|
stream_usage=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -17,5 +17,5 @@ class VolcanicEngineChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
openai_api_base=model_credential.get('api_base'),
|
openai_api_base=model_credential.get('api_base'),
|
||||||
openai_api_key=model_credential.get('api_key'),
|
openai_api_key=model_credential.get('api_key'),
|
||||||
**optional_params
|
extra_body=optional_params
|
||||||
)
|
)
|
||||||
|
|||||||
@ -34,7 +34,7 @@ class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
openai_api_base=base_url,
|
openai_api_base=base_url,
|
||||||
openai_api_key=model_credential.get('api_key'),
|
openai_api_key=model_credential.get('api_key'),
|
||||||
**optional_params
|
extra_body=optional_params
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user