refactor: model api
This commit is contained in:
parent
66c868e71f
commit
6e0e0d2366
@ -23,6 +23,52 @@ class ModelListResponse(APIMixin):
|
|||||||
|
|
||||||
return ModelListResult
|
return ModelListResult
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_parameters():
|
||||||
|
return [OpenApiParameter(
|
||||||
|
name="workspace_id",
|
||||||
|
description=_("workspace id"),
|
||||||
|
type=OpenApiTypes.STR,
|
||||||
|
location=OpenApiParameter.PATH,
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
OpenApiParameter(
|
||||||
|
name="name",
|
||||||
|
description=_("model name"),
|
||||||
|
type=OpenApiTypes.STR,
|
||||||
|
location=OpenApiParameter.QUERY,
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
OpenApiParameter(
|
||||||
|
name="model_type",
|
||||||
|
description=_("model type"),
|
||||||
|
type=OpenApiTypes.STR,
|
||||||
|
location=OpenApiParameter.QUERY,
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
OpenApiParameter(
|
||||||
|
name="model_name",
|
||||||
|
description=_("base model"),
|
||||||
|
type=OpenApiTypes.STR,
|
||||||
|
location=OpenApiParameter.QUERY,
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
OpenApiParameter(
|
||||||
|
name="provider",
|
||||||
|
description=_("provider"),
|
||||||
|
type=OpenApiTypes.STR,
|
||||||
|
location=OpenApiParameter.QUERY,
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
OpenApiParameter(
|
||||||
|
name="create_user",
|
||||||
|
description=_("create user"),
|
||||||
|
type=OpenApiTypes.STR,
|
||||||
|
location=OpenApiParameter.QUERY,
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class ModelCreateAPI(APIMixin):
|
class ModelCreateAPI(APIMixin):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -34,7 +80,7 @@ class ModelCreateAPI(APIMixin):
|
|||||||
return ModelCreateResponse
|
return ModelCreateResponse
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_query_params_api(cls):
|
def get_parameters(cls):
|
||||||
return [OpenApiParameter(
|
return [OpenApiParameter(
|
||||||
name="workspace_id",
|
name="workspace_id",
|
||||||
description=_("workspace id"),
|
description=_("workspace id"),
|
||||||
|
|||||||
@ -105,7 +105,7 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
|
|
||||||
class Operate(serializers.Serializer):
|
class Operate(serializers.Serializer):
|
||||||
id = serializers.UUIDField(required=True, label=_("model id"))
|
id = serializers.UUIDField(required=True, label=_("model id"))
|
||||||
user_id = serializers.UUIDField(required=True, label=_("user id"))
|
user_id = serializers.UUIDField(required=False, label=_("user id"))
|
||||||
|
|
||||||
def is_valid(self, *, raise_exception=False):
|
def is_valid(self, *, raise_exception=False):
|
||||||
super().is_valid(raise_exception=True)
|
super().is_valid(raise_exception=True)
|
||||||
@ -114,6 +114,8 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
).first()
|
).first()
|
||||||
if model is None:
|
if model is None:
|
||||||
raise AppApiException(500, _('Model does not exist'))
|
raise AppApiException(500, _('Model does not exist'))
|
||||||
|
if model.workspace_id == 'None':
|
||||||
|
raise AppApiException(500, _('Shared models cannot be deleted or modified'))
|
||||||
|
|
||||||
def one(self, with_valid=False):
|
def one(self, with_valid=False):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
@ -147,8 +149,6 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
model_id = self.data.get('id')
|
model_id = self.data.get('id')
|
||||||
model = Model.objects.filter(id=model_id).first()
|
model = Model.objects.filter(id=model_id).first()
|
||||||
if not model:
|
|
||||||
raise AppApiException(500, _("Model does not exist"))
|
|
||||||
# TODO : 这里可以添加模型删除的逻辑,需要注意删除模型时的权限和关联关系
|
# TODO : 这里可以添加模型删除的逻辑,需要注意删除模型时的权限和关联关系
|
||||||
# if model.model_type == 'LLM':
|
# if model.model_type == 'LLM':
|
||||||
# application_count = Application.objects.filter(model_id=model_id).count()
|
# application_count = Application.objects.filter(model_id=model_id).count()
|
||||||
@ -174,35 +174,32 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
model = QuerySet(Model).filter(id=self.data.get('id')).first()
|
model = QuerySet(Model).filter(id=self.data.get('id')).first()
|
||||||
|
|
||||||
if model is None:
|
credential, model_credential, provider_handler = ModelSerializer.Edit(
|
||||||
raise AppApiException(500, _('Model does not exist'))
|
data={**instance}).is_valid(
|
||||||
else:
|
model=model)
|
||||||
credential, model_credential, provider_handler = ModelSerializer.Edit(
|
try:
|
||||||
data={**instance}).is_valid(
|
model.status = Status.SUCCESS
|
||||||
model=model)
|
default_params = {item['field']: item['default_value'] for item in model.model_params_form}
|
||||||
try:
|
# 校验模型认证数据
|
||||||
model.status = Status.SUCCESS
|
provider_handler.is_valid_credential(model.model_type,
|
||||||
default_params = {item['field']: item['default_value'] for item in model.model_params_form}
|
instance.get("model_name"),
|
||||||
# 校验模型认证数据
|
credential,
|
||||||
provider_handler.is_valid_credential(model.model_type,
|
default_params,
|
||||||
instance.get("model_name"),
|
raise_exception=True)
|
||||||
credential,
|
|
||||||
default_params,
|
|
||||||
raise_exception=True)
|
|
||||||
|
|
||||||
except AppApiException as e:
|
except AppApiException as e:
|
||||||
if e.code == ValidCode.model_not_fount:
|
if e.code == ValidCode.model_not_fount:
|
||||||
model.status = Status.DOWNLOAD
|
model.status = Status.DOWNLOAD
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
update_keys = ['credential', 'name', 'model_type', 'model_name']
|
||||||
|
for update_key in update_keys:
|
||||||
|
if update_key in instance and instance.get(update_key) is not None:
|
||||||
|
if update_key == 'credential':
|
||||||
|
model_credential_str = json.dumps(credential)
|
||||||
|
model.__setattr__(update_key, rsa_long_encrypt(model_credential_str))
|
||||||
else:
|
else:
|
||||||
raise e
|
model.__setattr__(update_key, instance.get(update_key))
|
||||||
update_keys = ['credential', 'name', 'model_type', 'model_name']
|
|
||||||
for update_key in update_keys:
|
|
||||||
if update_key in instance and instance.get(update_key) is not None:
|
|
||||||
if update_key == 'credential':
|
|
||||||
model_credential_str = json.dumps(credential)
|
|
||||||
model.__setattr__(update_key, rsa_long_encrypt(model_credential_str))
|
|
||||||
else:
|
|
||||||
model.__setattr__(update_key, instance.get(update_key))
|
|
||||||
|
|
||||||
ModelManage.delete_key(str(model.id))
|
ModelManage.delete_key(str(model.id))
|
||||||
model.save()
|
model.save()
|
||||||
|
|||||||
@ -61,7 +61,7 @@ class ModelSetting(APIView):
|
|||||||
description=_("Create model"),
|
description=_("Create model"),
|
||||||
operation_id=_("Create model"), # type: ignore
|
operation_id=_("Create model"), # type: ignore
|
||||||
tags=[_("Model")], # type: ignore
|
tags=[_("Model")], # type: ignore
|
||||||
parameters=ModelCreateAPI.get_query_params_api(),
|
parameters=ModelCreateAPI.get_parameters(),
|
||||||
request=ModelCreateAPI.get_request(),
|
request=ModelCreateAPI.get_request(),
|
||||||
responses=ModelCreateAPI.get_response())
|
responses=ModelCreateAPI.get_response())
|
||||||
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission())
|
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission())
|
||||||
@ -90,7 +90,7 @@ class ModelSetting(APIView):
|
|||||||
summary=_('Query model list'),
|
summary=_('Query model list'),
|
||||||
description=_('Query model list'),
|
description=_('Query model list'),
|
||||||
operation_id=_('Query model list'), # type: ignore
|
operation_id=_('Query model list'), # type: ignore
|
||||||
parameters=ModelCreateAPI.get_query_params_api(),
|
parameters=ModelListResponse.get_parameters(),
|
||||||
responses=ModelListResponse.get_response(),
|
responses=ModelListResponse.get_response(),
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
|
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
|
||||||
@ -108,7 +108,7 @@ class ModelSetting(APIView):
|
|||||||
description=_('Update model'),
|
description=_('Update model'),
|
||||||
operation_id=_('Update model'), # type: ignore
|
operation_id=_('Update model'), # type: ignore
|
||||||
request=ModelEditApi.get_request(),
|
request=ModelEditApi.get_request(),
|
||||||
parameters=GetModelApi.get_query_params_api(),
|
parameters=GetModelApi.get_parameters(),
|
||||||
responses=ModelEditApi.get_response(),
|
responses=ModelEditApi.get_response(),
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_permission())
|
@has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_permission())
|
||||||
@ -125,7 +125,7 @@ class ModelSetting(APIView):
|
|||||||
summary=_('Delete model'),
|
summary=_('Delete model'),
|
||||||
description=_('Delete model'),
|
description=_('Delete model'),
|
||||||
operation_id=_('Delete model'), # type: ignore
|
operation_id=_('Delete model'), # type: ignore
|
||||||
parameters=GetModelApi.get_query_params_api(),
|
parameters=GetModelApi.get_parameters(),
|
||||||
responses=DefaultModelResponse.get_response(),
|
responses=DefaultModelResponse.get_response(),
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_permission())
|
@has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_permission())
|
||||||
@ -139,7 +139,7 @@ class ModelSetting(APIView):
|
|||||||
summary=_('Query model details'),
|
summary=_('Query model details'),
|
||||||
description=_('Query model details'),
|
description=_('Query model details'),
|
||||||
operation_id=_('Query model details'), # type: ignore
|
operation_id=_('Query model details'), # type: ignore
|
||||||
parameters=GetModelApi.get_query_params_api(),
|
parameters=GetModelApi.get_parameters(),
|
||||||
responses=GetModelApi.get_response(),
|
responses=GetModelApi.get_response(),
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
|
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
|
||||||
@ -154,7 +154,7 @@ class ModelSetting(APIView):
|
|||||||
summary=_('Get model parameter form'),
|
summary=_('Get model parameter form'),
|
||||||
description=_('Get model parameter form'),
|
description=_('Get model parameter form'),
|
||||||
operation_id=_('Get model parameter form'), # type: ignore
|
operation_id=_('Get model parameter form'), # type: ignore
|
||||||
parameters=GetModelApi.get_query_params_api(),
|
parameters=GetModelApi.get_parameters(),
|
||||||
responses=ProvideApi.ModelParamsForm.get_response(),
|
responses=ProvideApi.ModelParamsForm.get_response(),
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
|
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
|
||||||
@ -166,7 +166,7 @@ class ModelSetting(APIView):
|
|||||||
summary=_('Save model parameter form'),
|
summary=_('Save model parameter form'),
|
||||||
description=_('Save model parameter form'),
|
description=_('Save model parameter form'),
|
||||||
operation_id=_('Save model parameter form'), # type: ignore
|
operation_id=_('Save model parameter form'), # type: ignore
|
||||||
parameters=GetModelApi.get_query_params_api(),
|
parameters=GetModelApi.get_parameters(),
|
||||||
request=GetModelApi.get_request(),
|
request=GetModelApi.get_request(),
|
||||||
responses=ProvideApi.ModelParamsForm.get_response(),
|
responses=ProvideApi.ModelParamsForm.get_response(),
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@ -187,7 +187,7 @@ class ModelSetting(APIView):
|
|||||||
'Query model meta information, this interface does not carry authentication information'),
|
'Query model meta information, this interface does not carry authentication information'),
|
||||||
operation_id=_(
|
operation_id=_(
|
||||||
'Query model meta information, this interface does not carry authentication information'),
|
'Query model meta information, this interface does not carry authentication information'),
|
||||||
parameters=GetModelApi.get_query_params_api(),
|
parameters=GetModelApi.get_parameters(),
|
||||||
responses=GetModelApi.get_response(),
|
responses=GetModelApi.get_response(),
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
|
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
|
||||||
@ -202,7 +202,7 @@ class ModelSetting(APIView):
|
|||||||
summary=_('Pause model download'),
|
summary=_('Pause model download'),
|
||||||
description=_('Pause model download'),
|
description=_('Pause model download'),
|
||||||
operation_id=_('Pause model download'), # type: ignore
|
operation_id=_('Pause model download'), # type: ignore
|
||||||
parameters=GetModelApi.get_query_params_api(),
|
parameters=GetModelApi.get_parameters(),
|
||||||
request=GetModelApi.get_request(),
|
request=GetModelApi.get_request(),
|
||||||
responses=DefaultModelResponse.get_response(),
|
responses=DefaultModelResponse.get_response(),
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@ -218,9 +218,8 @@ class ModelSetting(APIView):
|
|||||||
summary=_('Get Share model'),
|
summary=_('Get Share model'),
|
||||||
description=_('Get Share model'),
|
description=_('Get Share model'),
|
||||||
operation_id=_('Get Share model'), # type: ignore
|
operation_id=_('Get Share model'), # type: ignore
|
||||||
parameters=GetModelApi.get_query_params_api(),
|
parameters=ModelListResponse.get_parameters(),
|
||||||
request=GetModelApi.get_request(),
|
responses=ModelListResponse.get_response(),
|
||||||
responses=DefaultModelResponse.get_response(),
|
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
|
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
|
||||||
def get(self, request: Request, workspace_id: str):
|
def get(self, request: Request, workspace_id: str):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user