refactor: azure llm params
This commit is contained in:
parent
b0a4e9e78f
commit
a2b6620b10
@ -10,6 +10,7 @@ import traceback
|
|||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
|
from openai import BadRequestError
|
||||||
|
|
||||||
from common import forms
|
from common import forms
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
@ -37,6 +38,17 @@ class AzureLLMModelParams(BaseForm):
|
|||||||
precision=0)
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
|
class o3MiniLLMModelParams(BaseForm):
|
||||||
|
max_completion_tokens = forms.SliderField(
|
||||||
|
TooltipLabel(_('Output the maximum Tokens'),
|
||||||
|
_('Specify the maximum number of tokens that the model can generate')),
|
||||||
|
required=True, default_value=800,
|
||||||
|
_min=1,
|
||||||
|
_max=5000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
@ -57,7 +69,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
model.invoke([HumanMessage(content=gettext('Hello'))])
|
model.invoke([HumanMessage(content=gettext('Hello'))])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException) or isinstance(e, BadRequestError):
|
||||||
raise e
|
raise e
|
||||||
if raise_exception:
|
if raise_exception:
|
||||||
raise AppApiException(ValidCode.valid_error.value,
|
raise AppApiException(ValidCode.valid_error.value,
|
||||||
@ -79,4 +91,6 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
deployment_name = forms.TextInputField("Deployment name", required=True)
|
deployment_name = forms.TextInputField("Deployment name", required=True)
|
||||||
|
|
||||||
def get_model_params_setting_form(self, model_name):
|
def get_model_params_setting_form(self, model_name):
|
||||||
|
if 'o3' in model_name or 'o1' in model_name:
|
||||||
|
return o3MiniLLMModelParams()
|
||||||
return AzureLLMModelParams()
|
return AzureLLMModelParams()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user