refactor: check model use model_params
This commit is contained in:
parent
628cf705ce
commit
6412825d30
@ -81,7 +81,7 @@ def get_model_type_list(provider):
|
|||||||
return get_provider(provider).get_model_type_list()
|
return get_provider(provider).get_model_type_list()
|
||||||
|
|
||||||
|
|
||||||
def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
|
def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], model_params, raise_exception=False):
|
||||||
"""
|
"""
|
||||||
校验模型认证参数
|
校验模型认证参数
|
||||||
@param provider: 供应商字符串
|
@param provider: 供应商字符串
|
||||||
@ -91,4 +91,4 @@ def is_valid_credential(provider, model_type, model_name, model_credential: Dict
|
|||||||
@param raise_exception: 是否抛出错误
|
@param raise_exception: 是否抛出错误
|
||||||
@return: True|False
|
@return: True|False
|
||||||
"""
|
"""
|
||||||
return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, raise_exception)
|
return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, model_params, raise_exception)
|
||||||
|
|||||||
@ -67,9 +67,13 @@ class IModelProvider(ABC):
|
|||||||
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
||||||
return model_info.model_credential
|
return model_info.model_credential
|
||||||
|
|
||||||
def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
|
def get_model_params(self, model_type, model_name):
|
||||||
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
||||||
return model_info.model_credential.is_valid(model_type, model_name, model_credential, self,
|
return model_info.model_credential
|
||||||
|
|
||||||
|
def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object], model_params: Dict[str, object], raise_exception=False):
|
||||||
|
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
||||||
|
return model_info.model_credential.is_valid(model_type, model_name, model_credential, model_params, self,
|
||||||
raise_exception=raise_exception)
|
raise_exception=raise_exception)
|
||||||
|
|
||||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseModel:
|
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseModel:
|
||||||
@ -105,7 +109,7 @@ class MaxKBBaseModel(ABC):
|
|||||||
class BaseModelCredential(ABC):
|
class BaseModelCredential(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def is_valid(self, model_type: str, model_name, model: Dict[str, object], provider, raise_exception=True):
|
def is_valid(self, model_type: str, model_name, model: Dict[str, object], model_params, provider, raise_exception=True):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.embedding
|
|||||||
|
|
||||||
class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):
|
class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -37,7 +37,7 @@ class QwenModelParams(BaseForm):
|
|||||||
|
|
||||||
class QwenVLModelCredential(BaseForm, BaseModelCredential):
|
class QwenVLModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -49,7 +49,7 @@ class QwenVLModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
||||||
for chunk in res:
|
for chunk in res:
|
||||||
print(chunk)
|
print(chunk)
|
||||||
|
|||||||
@ -28,7 +28,7 @@ class BaiLianLLMModelParams(BaseForm):
|
|||||||
|
|
||||||
class BaiLianLLMModelCredential(BaseForm, BaseModelCredential):
|
class BaiLianLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -41,7 +41,7 @@ class BaiLianLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.invoke([HumanMessage(content='你好')])
|
model.invoke([HumanMessage(content='你好')])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.reranker
|
|||||||
|
|
||||||
class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential):
|
class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
if not model_type == 'RERANKER':
|
if not model_type == 'RERANKER':
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
|
|||||||
class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
|
class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
api_key = forms.PasswordInputField("API Key", required=True)
|
api_key = forms.PasswordInputField("API Key", required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -61,7 +61,7 @@ class QwenModelParams(BaseForm):
|
|||||||
|
|
||||||
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
|
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -73,7 +73,7 @@ class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.check_auth()
|
res = model.check_auth()
|
||||||
print(res)
|
print(res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -45,7 +45,7 @@ class AliyunBaiLianTTSModelGeneralParams(BaseForm):
|
|||||||
class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential):
|
class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential):
|
||||||
api_key = forms.PasswordInputField("API Key", required=True)
|
api_key = forms.PasswordInputField("API Key", required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -58,7 +58,7 @@ class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.check_auth()
|
model.check_auth()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
|
|||||||
@ -26,7 +26,7 @@ class BedrockEmbeddingCredential(BaseForm, BaseModelCredential):
|
|||||||
with open(credentials_path, 'w') as file:
|
with open(credentials_path, 'w') as file:
|
||||||
file.write(content)
|
file.write(content)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(mt.get('value') == model_type for mt in model_type_list):
|
if not any(mt.get('value') == model_type for mt in model_type_list):
|
||||||
|
|||||||
@ -44,7 +44,7 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
with open(credentials_path, 'w') as file:
|
with open(credentials_path, 'w') as file:
|
||||||
file.write(content)
|
file.write(content)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(mt.get('value') == model_type for mt in model_type_list):
|
if not any(mt.get('value') == model_type for mt in model_type_list):
|
||||||
@ -62,7 +62,7 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
|
self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
|
||||||
model_credential['secret_access_key'])
|
model_credential['secret_access_key'])
|
||||||
model_credential['credentials_profile_name'] = 'aws-profile'
|
model_credential['credentials_profile_name'] = 'aws-profile'
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.invoke([HumanMessage(content='你好')])
|
model.invoke([HumanMessage(content='你好')])
|
||||||
except AppApiException:
|
except AppApiException:
|
||||||
raise
|
raise
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
|
|||||||
|
|
||||||
class AzureOpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
|
class AzureOpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -33,7 +33,7 @@ class AzureOpenAIImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
|
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
|
||||||
api_key = forms.PasswordInputField("API Key (api_key)", required=True)
|
api_key = forms.PasswordInputField("API Key (api_key)", required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -46,7 +46,7 @@ class AzureOpenAIImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
||||||
for chunk in res:
|
for chunk in res:
|
||||||
print(chunk)
|
print(chunk)
|
||||||
|
|||||||
@ -35,7 +35,7 @@ class AzureLLMModelParams(BaseForm):
|
|||||||
|
|
||||||
class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -48,7 +48,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.invoke([HumanMessage(content='你好')])
|
model.invoke([HumanMessage(content='你好')])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
|
|||||||
@ -12,7 +12,7 @@ class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential):
|
|||||||
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
|
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
|
||||||
api_key = forms.PasswordInputField("API Key (api_key)", required=True)
|
api_key = forms.PasswordInputField("API Key (api_key)", required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -51,7 +51,7 @@ class AzureOpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
|
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
|
||||||
api_key = forms.PasswordInputField("API Key (api_key)", required=True)
|
api_key = forms.PasswordInputField("API Key (api_key)", required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -64,7 +64,7 @@ class AzureOpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.check_auth()
|
res = model.check_auth()
|
||||||
print(res)
|
print(res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -28,7 +28,7 @@ class AzureOpenAITTSModelCredential(BaseForm, BaseModelCredential):
|
|||||||
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
|
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
|
||||||
api_key = forms.PasswordInputField("API Key (api_key)", required=True)
|
api_key = forms.PasswordInputField("API Key (api_key)", required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -41,7 +41,7 @@ class AzureOpenAITTSModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.check_auth()
|
model.check_auth()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
|
|||||||
@ -35,7 +35,7 @@ class DeepSeekLLMModelParams(BaseForm):
|
|||||||
|
|
||||||
class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
|
class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -48,7 +48,7 @@ class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.invoke([HumanMessage(content='你好')])
|
model.invoke([HumanMessage(content='你好')])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
|
|||||||
|
|
||||||
|
|
||||||
class GeminiEmbeddingCredential(BaseForm, BaseModelCredential):
|
class GeminiEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=True):
|
raise_exception=True):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -31,7 +31,7 @@ class GeminiImageModelParams(BaseForm):
|
|||||||
class GeminiImageModelCredential(BaseForm, BaseModelCredential):
|
class GeminiImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -44,7 +44,7 @@ class GeminiImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
||||||
for chunk in res:
|
for chunk in res:
|
||||||
print(chunk)
|
print(chunk)
|
||||||
|
|||||||
@ -35,7 +35,7 @@ class GeminiLLMModelParams(BaseForm):
|
|||||||
|
|
||||||
class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
|
class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -48,7 +48,7 @@ class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.invoke([HumanMessage(content='你好')])
|
res = model.invoke([HumanMessage(content='你好')])
|
||||||
print(res)
|
print(res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
|
|||||||
class GeminiSTTModelCredential(BaseForm, BaseModelCredential):
|
class GeminiSTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -35,7 +35,7 @@ class KimiLLMModelParams(BaseForm):
|
|||||||
|
|
||||||
class KimiLLMModelCredential(BaseForm, BaseModelCredential):
|
class KimiLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -48,7 +48,7 @@ class KimiLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.invoke([HumanMessage(content='你好')])
|
model.invoke([HumanMessage(content='你好')])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from setting.models_provider.impl.local_model_provider.model.embedding import Lo
|
|||||||
|
|
||||||
class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
|
class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
if not model_type == 'EMBEDDING':
|
if not model_type == 'EMBEDDING':
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from setting.models_provider.impl.local_model_provider.model.reranker import Loc
|
|||||||
|
|
||||||
class LocalRerankerCredential(BaseForm, BaseModelCredential):
|
class LocalRerankerCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
if not model_type == 'RERANKER':
|
if not model_type == 'RERANKER':
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from setting.models_provider.impl.local_model_provider.model.embedding import Lo
|
|||||||
|
|
||||||
|
|
||||||
class OllamaEmbeddingModelCredential(BaseForm, BaseModelCredential):
|
class OllamaEmbeddingModelCredential(BaseForm, BaseModelCredential):
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -32,7 +32,7 @@ class OllamaImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
api_base = forms.TextInputField('API 域名', required=True)
|
api_base = forms.TextInputField('API 域名', required=True)
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -32,7 +32,7 @@ class OllamaLLMModelParams(BaseForm):
|
|||||||
|
|
||||||
|
|
||||||
class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
|
class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
|
|||||||
|
|
||||||
|
|
||||||
class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
|
class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=True):
|
raise_exception=True):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -32,7 +32,7 @@ class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
api_base = forms.TextInputField('API 域名', required=True)
|
api_base = forms.TextInputField('API 域名', required=True)
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -45,7 +45,7 @@ class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
||||||
for chunk in res:
|
for chunk in res:
|
||||||
print(chunk)
|
print(chunk)
|
||||||
|
|||||||
@ -35,7 +35,7 @@ class OpenAILLMModelParams(BaseForm):
|
|||||||
|
|
||||||
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -48,7 +48,8 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.invoke([HumanMessage(content='你好')])
|
model.invoke([HumanMessage(content='你好')])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
|
|||||||
@ -11,7 +11,7 @@ class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
|
|||||||
api_base = forms.TextInputField('API 域名', required=True)
|
api_base = forms.TextInputField('API 域名', required=True)
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -50,7 +50,7 @@ class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
api_base = forms.TextInputField('API 域名', required=True)
|
api_base = forms.TextInputField('API 域名', required=True)
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -63,7 +63,7 @@ class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.check_auth()
|
res = model.check_auth()
|
||||||
print(res)
|
print(res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -27,7 +27,7 @@ class OpenAITTSModelCredential(BaseForm, BaseModelCredential):
|
|||||||
api_base = forms.TextInputField('API 域名', required=True)
|
api_base = forms.TextInputField('API 域名', required=True)
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -40,7 +40,7 @@ class OpenAITTSModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.check_auth()
|
model.check_auth()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
|
|||||||
@ -37,7 +37,7 @@ class QwenModelParams(BaseForm):
|
|||||||
|
|
||||||
class QwenVLModelCredential(BaseForm, BaseModelCredential):
|
class QwenVLModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -49,7 +49,7 @@ class QwenVLModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
||||||
for chunk in res:
|
for chunk in res:
|
||||||
print(chunk)
|
print(chunk)
|
||||||
|
|||||||
@ -35,7 +35,7 @@ class QwenModelParams(BaseForm):
|
|||||||
|
|
||||||
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -47,7 +47,7 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.invoke([HumanMessage(content='你好')])
|
model.invoke([HumanMessage(content='你好')])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
|
|||||||
@ -61,7 +61,7 @@ class QwenModelParams(BaseForm):
|
|||||||
|
|
||||||
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
|
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -73,7 +73,7 @@ class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.check_auth()
|
res = model.check_auth()
|
||||||
print(res)
|
print(res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
|
|||||||
|
|
||||||
class TencentEmbeddingCredential(BaseForm, BaseModelCredential):
|
class TencentEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=True) -> bool:
|
raise_exception=True) -> bool:
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -37,7 +37,7 @@ class QwenModelParams(BaseForm):
|
|||||||
|
|
||||||
class TencentVisionModelCredential(BaseForm, BaseModelCredential):
|
class TencentVisionModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -49,7 +49,7 @@ class TencentVisionModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
||||||
for chunk in res:
|
for chunk in res:
|
||||||
print(chunk)
|
print(chunk)
|
||||||
|
|||||||
@ -36,12 +36,12 @@ class TencentLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def is_valid(self, model_type, model_name, model_credential, provider, raise_exception=False):
|
def is_valid(self, model_type, model_name, model_credential, provider, model_params, raise_exception=False):
|
||||||
if not (self._validate_model_type(model_type, provider, raise_exception) and
|
if not (self._validate_model_type(model_type, provider, raise_exception) and
|
||||||
self._validate_credential_fields(model_credential, raise_exception)):
|
self._validate_credential_fields(model_credential, raise_exception)):
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.invoke([HumanMessage(content='你好')])
|
model.invoke([HumanMessage(content='你好')])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if raise_exception:
|
if raise_exception:
|
||||||
|
|||||||
@ -85,12 +85,12 @@ class TencentTTIModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def is_valid(self, model_type, model_name, model_credential, provider, raise_exception=False):
|
def is_valid(self, model_type, model_name, model_credential, model_params, provider, raise_exception=False):
|
||||||
if not (self._validate_model_type(model_type, provider, raise_exception) and
|
if not (self._validate_model_type(model_type, provider, raise_exception) and
|
||||||
self._validate_credential_fields(model_credential, raise_exception)):
|
self._validate_credential_fields(model_credential, raise_exception)):
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.check_auth()
|
model.check_auth()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if raise_exception:
|
if raise_exception:
|
||||||
|
|||||||
@ -28,7 +28,7 @@ class VLLMModelParams(BaseForm):
|
|||||||
|
|
||||||
|
|
||||||
class VLLMModelCredential(BaseForm, BaseModelCredential):
|
class VLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -40,7 +40,7 @@ class VLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
exist = provider.get_model_info_by_name(model_list, model_name)
|
exist = provider.get_model_info_by_name(model_list, model_name)
|
||||||
if len(exist) == 0:
|
if len(exist) == 0:
|
||||||
raise AppApiException(ValidCode.valid_error.value, "模型不存在,请先下载模型")
|
raise AppApiException(ValidCode.valid_error.value, "模型不存在,请先下载模型")
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
try:
|
try:
|
||||||
res = model.invoke([HumanMessage(content='你好')])
|
res = model.invoke([HumanMessage(content='你好')])
|
||||||
print(res)
|
print(res)
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
|
|||||||
|
|
||||||
|
|
||||||
class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
|
class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=True):
|
raise_exception=True):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -30,7 +30,7 @@ class VolcanicEngineImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
api_base = forms.TextInputField('API 域名', required=True)
|
api_base = forms.TextInputField('API 域名', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -43,7 +43,7 @@ class VolcanicEngineImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
||||||
for chunk in res:
|
for chunk in res:
|
||||||
print(chunk)
|
print(chunk)
|
||||||
|
|||||||
@ -35,7 +35,7 @@ class VolcanicEngineLLMModelParams(BaseForm):
|
|||||||
|
|
||||||
class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential):
|
class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -48,7 +48,7 @@ class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.invoke([HumanMessage(content='你好')])
|
res = model.invoke([HumanMessage(content='你好')])
|
||||||
print(res)
|
print(res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -14,7 +14,7 @@ class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
|
|||||||
volcanic_token = forms.PasswordInputField('Access Token', required=True)
|
volcanic_token = forms.PasswordInputField('Access Token', required=True)
|
||||||
volcanic_cluster = forms.TextInputField('Cluster ID', required=True)
|
volcanic_cluster = forms.TextInputField('Cluster ID', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -31,7 +31,7 @@ class VolcanicEngineTTIModelCredential(BaseForm, BaseModelCredential):
|
|||||||
access_key = forms.PasswordInputField('Access Key ID', required=True)
|
access_key = forms.PasswordInputField('Access Key ID', required=True)
|
||||||
secret_key = forms.PasswordInputField('Secret Access Key', required=True)
|
secret_key = forms.PasswordInputField('Secret Access Key', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -44,7 +44,7 @@ class VolcanicEngineTTIModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.check_auth()
|
model.check_auth()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
|
|||||||
@ -42,7 +42,7 @@ class VolcanicEngineTTSModelCredential(BaseForm, BaseModelCredential):
|
|||||||
volcanic_token = forms.PasswordInputField('Access Token', required=True)
|
volcanic_token = forms.PasswordInputField('Access Token', required=True)
|
||||||
volcanic_cluster = forms.TextInputField('Cluster ID', required=True)
|
volcanic_cluster = forms.TextInputField('Cluster ID', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -55,7 +55,7 @@ class VolcanicEngineTTSModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.check_auth()
|
model.check_auth()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
|
|||||||
|
|
||||||
class QianfanEmbeddingCredential(BaseForm, BaseModelCredential):
|
class QianfanEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -34,12 +34,12 @@ class WenxinLLMModelParams(BaseForm):
|
|||||||
|
|
||||||
|
|
||||||
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
|
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model_info = [model.lower() for model in model.client.models()]
|
model_info = [model.lower() for model in model.client.models()]
|
||||||
if not model_info.__contains__(model_name.lower()):
|
if not model_info.__contains__(model_name.lower()):
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持')
|
raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持')
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
|
|||||||
|
|
||||||
class XFEmbeddingCredential(BaseForm, BaseModelCredential):
|
class XFEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -18,7 +18,7 @@ class XunFeiImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
spark_api_key = forms.PasswordInputField("API Key", required=True)
|
spark_api_key = forms.PasswordInputField("API Key", required=True)
|
||||||
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
|
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -31,7 +31,7 @@ class XunFeiImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||||
with open(f'{cwd}/img_1.png', 'rb') as f:
|
with open(f'{cwd}/img_1.png', 'rb') as f:
|
||||||
message_list = [ImageMessage(str(base64.b64encode(f.read()), 'utf-8')), HumanMessage('请概述这张图片')]
|
message_list = [ImageMessage(str(base64.b64encode(f.read()), 'utf-8')), HumanMessage('请概述这张图片')]
|
||||||
|
|||||||
@ -52,7 +52,7 @@ class XunFeiLLMModelProParams(BaseForm):
|
|||||||
|
|
||||||
class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
|
class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -65,7 +65,7 @@ class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.invoke([HumanMessage(content='你好')])
|
model.invoke([HumanMessage(content='你好')])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
|
|||||||
@ -14,7 +14,7 @@ class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
|
|||||||
spark_api_key = forms.PasswordInputField("API Key", required=True)
|
spark_api_key = forms.PasswordInputField("API Key", required=True)
|
||||||
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
|
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -36,7 +36,7 @@ class XunFeiTTSModelCredential(BaseForm, BaseModelCredential):
|
|||||||
spark_api_key = forms.PasswordInputField("API Key", required=True)
|
spark_api_key = forms.PasswordInputField("API Key", required=True)
|
||||||
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
|
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -49,7 +49,7 @@ class XunFeiTTSModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.check_auth()
|
model.check_auth()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from setting.models_provider.impl.local_model_provider.model.embedding import Lo
|
|||||||
|
|
||||||
|
|
||||||
class XinferenceEmbeddingModelCredential(BaseForm, BaseModelCredential):
|
class XinferenceEmbeddingModelCredential(BaseForm, BaseModelCredential):
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object],model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -32,7 +32,7 @@ class XinferenceImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
api_base = forms.TextInputField('API 域名', required=True)
|
api_base = forms.TextInputField('API 域名', required=True)
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -45,7 +45,7 @@ class XinferenceImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
||||||
for chunk in res:
|
for chunk in res:
|
||||||
print(chunk)
|
print(chunk)
|
||||||
|
|||||||
@ -28,7 +28,7 @@ class XinferenceLLMModelParams(BaseForm):
|
|||||||
|
|
||||||
|
|
||||||
class XinferenceLLMModelCredential(BaseForm, BaseModelCredential):
|
class XinferenceLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -40,7 +40,7 @@ class XinferenceLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
exist = provider.get_model_info_by_name(model_list, model_name)
|
exist = provider.get_model_info_by_name(model_list, model_name)
|
||||||
if len(exist) == 0:
|
if len(exist) == 0:
|
||||||
raise AppApiException(ValidCode.valid_error.value, "模型不存在,请先下载模型")
|
raise AppApiException(ValidCode.valid_error.value, "模型不存在,请先下载模型")
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.invoke([HumanMessage(content='你好')])
|
model.invoke([HumanMessage(content='你好')])
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
|
|||||||
|
|
||||||
|
|
||||||
class XInferenceRerankerModelCredential(BaseForm, BaseModelCredential):
|
class XInferenceRerankerModelCredential(BaseForm, BaseModelCredential):
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=True):
|
raise_exception=True):
|
||||||
if not model_type == 'RERANKER':
|
if not model_type == 'RERANKER':
|
||||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|||||||
@ -11,7 +11,7 @@ class XInferenceSTTModelCredential(BaseForm, BaseModelCredential):
|
|||||||
api_base = forms.TextInputField('API 域名', required=True)
|
api_base = forms.TextInputField('API 域名', required=True)
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
|||||||
@ -50,7 +50,7 @@ class XinferenceTextToImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
api_base = forms.TextInputField('API 域名', required=True)
|
api_base = forms.TextInputField('API 域名', required=True)
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -63,7 +63,7 @@ class XinferenceTextToImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.check_auth()
|
res = model.check_auth()
|
||||||
print(res)
|
print(res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -29,7 +29,7 @@ class XInferenceTTSModelCredential(BaseForm, BaseModelCredential):
|
|||||||
api_base = forms.TextInputField('API 域名', required=True)
|
api_base = forms.TextInputField('API 域名', required=True)
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -42,7 +42,7 @@ class XInferenceTTSModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.check_auth()
|
model.check_auth()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
|
|||||||
@ -29,7 +29,7 @@ class ZhiPuImageModelParams(BaseForm):
|
|||||||
class ZhiPuImageModelCredential(BaseForm, BaseModelCredential):
|
class ZhiPuImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -42,7 +42,7 @@ class ZhiPuImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
||||||
for chunk in res:
|
for chunk in res:
|
||||||
print(chunk)
|
print(chunk)
|
||||||
|
|||||||
@ -35,7 +35,7 @@ class ZhiPuLLMModelParams(BaseForm):
|
|||||||
|
|
||||||
class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):
|
class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -47,7 +47,7 @@ class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.invoke([HumanMessage(content='你好')])
|
model.invoke([HumanMessage(content='你好')])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
|
|||||||
@ -29,7 +29,7 @@ class ZhiPuTTIModelParams(BaseForm):
|
|||||||
class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential):
|
class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
@ -42,7 +42,7 @@ class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
res = model.check_auth()
|
res = model.check_auth()
|
||||||
print(res)
|
print(res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -189,9 +189,11 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
if QuerySet(Model).filter(user_id=self.data.get('user_id'),
|
if QuerySet(Model).filter(user_id=self.data.get('user_id'),
|
||||||
name=self.data.get('name')).exists():
|
name=self.data.get('name')).exists():
|
||||||
raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在')
|
raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在')
|
||||||
|
default_params = {item['field']: item['default_value'] for item in self.data.get('model_params_form')}
|
||||||
ModelProvideConstants[self.data.get('provider')].value.is_valid_credential(self.data.get('model_type'),
|
ModelProvideConstants[self.data.get('provider')].value.is_valid_credential(self.data.get('model_type'),
|
||||||
self.data.get('model_name'),
|
self.data.get('model_name'),
|
||||||
self.data.get('credential'),
|
self.data.get('credential'),
|
||||||
|
default_params,
|
||||||
raise_exception=True
|
raise_exception=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -354,10 +356,12 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
model=model)
|
model=model)
|
||||||
try:
|
try:
|
||||||
model.status = Status.SUCCESS
|
model.status = Status.SUCCESS
|
||||||
|
default_params = {item['field']: item['default_value'] for item in model.model_params_form}
|
||||||
# 校验模型认证数据
|
# 校验模型认证数据
|
||||||
provider_handler.is_valid_credential(model.model_type,
|
provider_handler.is_valid_credential(model.model_type,
|
||||||
instance.get("model_name"),
|
instance.get("model_name"),
|
||||||
credential,
|
credential,
|
||||||
|
default_params,
|
||||||
raise_exception=True)
|
raise_exception=True)
|
||||||
|
|
||||||
except AppApiException as e:
|
except AppApiException as e:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user