refactor: model api

This commit is contained in:
wxg0103 2025-06-09 16:32:16 +08:00
parent 66c868e71f
commit 6e0e0d2366
3 changed files with 85 additions and 43 deletions

View File

@ -23,6 +23,52 @@ class ModelListResponse(APIMixin):
return ModelListResult return ModelListResult
@staticmethod
def get_parameters():
return [OpenApiParameter(
name="workspace_id",
description=_("workspace id"),
type=OpenApiTypes.STR,
location=OpenApiParameter.PATH,
required=True,
),
OpenApiParameter(
name="name",
description=_("model name"),
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
required=False,
),
OpenApiParameter(
name="model_type",
description=_("model type"),
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
required=False,
),
OpenApiParameter(
name="model_name",
description=_("base model"),
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
required=False,
),
OpenApiParameter(
name="provider",
description=_("provider"),
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
required=False,
),
OpenApiParameter(
name="create_user",
description=_("create user"),
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
required=False,
)
]
class ModelCreateAPI(APIMixin): class ModelCreateAPI(APIMixin):
@staticmethod @staticmethod
@ -34,7 +80,7 @@ class ModelCreateAPI(APIMixin):
return ModelCreateResponse return ModelCreateResponse
@classmethod @classmethod
def get_query_params_api(cls): def get_parameters(cls):
return [OpenApiParameter( return [OpenApiParameter(
name="workspace_id", name="workspace_id",
description=_("workspace id"), description=_("workspace id"),

View File

@ -105,7 +105,7 @@ class ModelSerializer(serializers.Serializer):
class Operate(serializers.Serializer): class Operate(serializers.Serializer):
id = serializers.UUIDField(required=True, label=_("model id")) id = serializers.UUIDField(required=True, label=_("model id"))
user_id = serializers.UUIDField(required=True, label=_("user id")) user_id = serializers.UUIDField(required=False, label=_("user id"))
def is_valid(self, *, raise_exception=False): def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True) super().is_valid(raise_exception=True)
@ -114,6 +114,8 @@ class ModelSerializer(serializers.Serializer):
).first() ).first()
if model is None: if model is None:
raise AppApiException(500, _('Model does not exist')) raise AppApiException(500, _('Model does not exist'))
if model.workspace_id == 'None':
raise AppApiException(500, _('Shared models cannot be deleted or modified'))
def one(self, with_valid=False): def one(self, with_valid=False):
if with_valid: if with_valid:
@ -147,8 +149,6 @@ class ModelSerializer(serializers.Serializer):
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
model_id = self.data.get('id') model_id = self.data.get('id')
model = Model.objects.filter(id=model_id).first() model = Model.objects.filter(id=model_id).first()
if not model:
raise AppApiException(500, _("Model does not exist"))
# TODO : 这里可以添加模型删除的逻辑,需要注意删除模型时的权限和关联关系 # TODO : 这里可以添加模型删除的逻辑,需要注意删除模型时的权限和关联关系
# if model.model_type == 'LLM': # if model.model_type == 'LLM':
# application_count = Application.objects.filter(model_id=model_id).count() # application_count = Application.objects.filter(model_id=model_id).count()
@ -174,9 +174,6 @@ class ModelSerializer(serializers.Serializer):
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
model = QuerySet(Model).filter(id=self.data.get('id')).first() model = QuerySet(Model).filter(id=self.data.get('id')).first()
if model is None:
raise AppApiException(500, _('Model does not exist'))
else:
credential, model_credential, provider_handler = ModelSerializer.Edit( credential, model_credential, provider_handler = ModelSerializer.Edit(
data={**instance}).is_valid( data={**instance}).is_valid(
model=model) model=model)

View File

@ -61,7 +61,7 @@ class ModelSetting(APIView):
description=_("Create model"), description=_("Create model"),
operation_id=_("Create model"), # type: ignore operation_id=_("Create model"), # type: ignore
tags=[_("Model")], # type: ignore tags=[_("Model")], # type: ignore
parameters=ModelCreateAPI.get_query_params_api(), parameters=ModelCreateAPI.get_parameters(),
request=ModelCreateAPI.get_request(), request=ModelCreateAPI.get_request(),
responses=ModelCreateAPI.get_response()) responses=ModelCreateAPI.get_response())
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission()) @has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission())
@ -90,7 +90,7 @@ class ModelSetting(APIView):
summary=_('Query model list'), summary=_('Query model list'),
description=_('Query model list'), description=_('Query model list'),
operation_id=_('Query model list'), # type: ignore operation_id=_('Query model list'), # type: ignore
parameters=ModelCreateAPI.get_query_params_api(), parameters=ModelListResponse.get_parameters(),
responses=ModelListResponse.get_response(), responses=ModelListResponse.get_response(),
tags=[_('Model')]) # type: ignore tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission()) @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
@ -108,7 +108,7 @@ class ModelSetting(APIView):
description=_('Update model'), description=_('Update model'),
operation_id=_('Update model'), # type: ignore operation_id=_('Update model'), # type: ignore
request=ModelEditApi.get_request(), request=ModelEditApi.get_request(),
parameters=GetModelApi.get_query_params_api(), parameters=GetModelApi.get_parameters(),
responses=ModelEditApi.get_response(), responses=ModelEditApi.get_response(),
tags=[_('Model')]) # type: ignore tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_permission()) @has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_permission())
@ -125,7 +125,7 @@ class ModelSetting(APIView):
summary=_('Delete model'), summary=_('Delete model'),
description=_('Delete model'), description=_('Delete model'),
operation_id=_('Delete model'), # type: ignore operation_id=_('Delete model'), # type: ignore
parameters=GetModelApi.get_query_params_api(), parameters=GetModelApi.get_parameters(),
responses=DefaultModelResponse.get_response(), responses=DefaultModelResponse.get_response(),
tags=[_('Model')]) # type: ignore tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_permission()) @has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_permission())
@ -139,7 +139,7 @@ class ModelSetting(APIView):
summary=_('Query model details'), summary=_('Query model details'),
description=_('Query model details'), description=_('Query model details'),
operation_id=_('Query model details'), # type: ignore operation_id=_('Query model details'), # type: ignore
parameters=GetModelApi.get_query_params_api(), parameters=GetModelApi.get_parameters(),
responses=GetModelApi.get_response(), responses=GetModelApi.get_response(),
tags=[_('Model')]) # type: ignore tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission()) @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
@ -154,7 +154,7 @@ class ModelSetting(APIView):
summary=_('Get model parameter form'), summary=_('Get model parameter form'),
description=_('Get model parameter form'), description=_('Get model parameter form'),
operation_id=_('Get model parameter form'), # type: ignore operation_id=_('Get model parameter form'), # type: ignore
parameters=GetModelApi.get_query_params_api(), parameters=GetModelApi.get_parameters(),
responses=ProvideApi.ModelParamsForm.get_response(), responses=ProvideApi.ModelParamsForm.get_response(),
tags=[_('Model')]) # type: ignore tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission()) @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
@ -166,7 +166,7 @@ class ModelSetting(APIView):
summary=_('Save model parameter form'), summary=_('Save model parameter form'),
description=_('Save model parameter form'), description=_('Save model parameter form'),
operation_id=_('Save model parameter form'), # type: ignore operation_id=_('Save model parameter form'), # type: ignore
parameters=GetModelApi.get_query_params_api(), parameters=GetModelApi.get_parameters(),
request=GetModelApi.get_request(), request=GetModelApi.get_request(),
responses=ProvideApi.ModelParamsForm.get_response(), responses=ProvideApi.ModelParamsForm.get_response(),
tags=[_('Model')]) # type: ignore tags=[_('Model')]) # type: ignore
@ -187,7 +187,7 @@ class ModelSetting(APIView):
'Query model meta information, this interface does not carry authentication information'), 'Query model meta information, this interface does not carry authentication information'),
operation_id=_( operation_id=_(
'Query model meta information, this interface does not carry authentication information'), 'Query model meta information, this interface does not carry authentication information'),
parameters=GetModelApi.get_query_params_api(), parameters=GetModelApi.get_parameters(),
responses=GetModelApi.get_response(), responses=GetModelApi.get_response(),
tags=[_('Model')]) # type: ignore tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission()) @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
@ -202,7 +202,7 @@ class ModelSetting(APIView):
summary=_('Pause model download'), summary=_('Pause model download'),
description=_('Pause model download'), description=_('Pause model download'),
operation_id=_('Pause model download'), # type: ignore operation_id=_('Pause model download'), # type: ignore
parameters=GetModelApi.get_query_params_api(), parameters=GetModelApi.get_parameters(),
request=GetModelApi.get_request(), request=GetModelApi.get_request(),
responses=DefaultModelResponse.get_response(), responses=DefaultModelResponse.get_response(),
tags=[_('Model')]) # type: ignore tags=[_('Model')]) # type: ignore
@ -218,9 +218,8 @@ class ModelSetting(APIView):
summary=_('Get Share model'), summary=_('Get Share model'),
description=_('Get Share model'), description=_('Get Share model'),
operation_id=_('Get Share model'), # type: ignore operation_id=_('Get Share model'), # type: ignore
parameters=GetModelApi.get_query_params_api(), parameters=ModelListResponse.get_parameters(),
request=GetModelApi.get_request(), responses=ModelListResponse.get_response(),
responses=DefaultModelResponse.get_response(),
tags=[_('Model')]) # type: ignore tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission()) @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
def get(self, request: Request, workspace_id: str): def get(self, request: Request, workspace_id: str):