fix: 修复openai供应商计算tokens错误,修复旧版本应用编辑页面直接报错提示问题为必填参数错误
This commit is contained in:
parent
b627b6638c
commit
96ac12ea31
@ -96,7 +96,7 @@ class NoReferencesSetting(serializers.Serializer):
|
|||||||
|
|
||||||
|
|
||||||
def valid_model_params_setting(model_id, model_params_setting):
|
def valid_model_params_setting(model_id, model_params_setting):
|
||||||
if model_id is None:
|
if model_id is None or model_params_setting is None or len(model_params_setting.keys()) == 0:
|
||||||
return
|
return
|
||||||
model = QuerySet(Model).filter(id=model_id).first()
|
model = QuerySet(Model).filter(id=model_id).first()
|
||||||
credential = get_model_credential(model.provider, model.model_type, model.model_name)
|
credential = get_model_credential(model.provider, model.model_type, model.model_name)
|
||||||
@ -416,7 +416,7 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
model_setting=application.get('model_setting'),
|
model_setting=application.get('model_setting'),
|
||||||
problem_optimization=application.get('problem_optimization'),
|
problem_optimization=application.get('problem_optimization'),
|
||||||
type=ApplicationTypeChoices.SIMPLE,
|
type=ApplicationTypeChoices.SIMPLE,
|
||||||
model_params_setting=application.get('model_params_setting',{}),
|
model_params_setting=application.get('model_params_setting', {}),
|
||||||
work_flow={}
|
work_flow={}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -697,7 +697,9 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
ApplicationSerializer.Edit(data=instance).is_valid(
|
ApplicationSerializer.Edit(data=instance).is_valid(
|
||||||
raise_exception=True)
|
raise_exception=True)
|
||||||
application_id = self.data.get("application_id")
|
application_id = self.data.get("application_id")
|
||||||
valid_model_params_setting(instance.get('model_id'), instance.get('model_params_setting'))
|
valid_model_params_setting(instance.get('model_id'),
|
||||||
|
instance.get('model_params_setting'))
|
||||||
|
|
||||||
application = QuerySet(Application).get(id=application_id)
|
application = QuerySet(Application).get(id=application_id)
|
||||||
if instance.get('model_id') is None or len(instance.get('model_id')) == 0:
|
if instance.get('model_id') is None or len(instance.get('model_id')) == 0:
|
||||||
application.model_id = None
|
application.model_id = None
|
||||||
|
|||||||
@ -43,3 +43,17 @@ class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
|
|||||||
custom_get_token_ids=custom_get_token_ids
|
custom_get_token_ids=custom_get_token_ids
|
||||||
)
|
)
|
||||||
return azure_chat_open_ai
|
return azure_chat_open_ai
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
|
try:
|
||||||
|
super().get_num_tokens_from_messages(messages)
|
||||||
|
except Exception as e:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||||
|
|
||||||
|
def get_num_tokens(self, text: str) -> int:
|
||||||
|
try:
|
||||||
|
super().get_num_tokens(text)
|
||||||
|
except Exception as e:
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return len(tokenizer.encode(text))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user