refactor: model api

This commit is contained in:
wxg0103 2025-06-09 16:32:16 +08:00
parent 66c868e71f
commit 6e0e0d2366
3 changed files with 85 additions and 43 deletions

View File

@ -23,6 +23,52 @@ class ModelListResponse(APIMixin):
return ModelListResult
@staticmethod
def get_parameters():
return [OpenApiParameter(
name="workspace_id",
description=_("workspace id"),
type=OpenApiTypes.STR,
location=OpenApiParameter.PATH,
required=True,
),
OpenApiParameter(
name="name",
description=_("model name"),
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
required=False,
),
OpenApiParameter(
name="model_type",
description=_("model type"),
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
required=False,
),
OpenApiParameter(
name="model_name",
description=_("base model"),
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
required=False,
),
OpenApiParameter(
name="provider",
description=_("provider"),
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
required=False,
),
OpenApiParameter(
name="create_user",
description=_("create user"),
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
required=False,
)
]
class ModelCreateAPI(APIMixin):
@staticmethod
@ -34,7 +80,7 @@ class ModelCreateAPI(APIMixin):
return ModelCreateResponse
@classmethod
def get_query_params_api(cls):
def get_parameters(cls):
return [OpenApiParameter(
name="workspace_id",
description=_("workspace id"),

View File

@ -105,7 +105,7 @@ class ModelSerializer(serializers.Serializer):
class Operate(serializers.Serializer):
id = serializers.UUIDField(required=True, label=_("model id"))
user_id = serializers.UUIDField(required=True, label=_("user id"))
user_id = serializers.UUIDField(required=False, label=_("user id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -114,6 +114,8 @@ class ModelSerializer(serializers.Serializer):
).first()
if model is None:
raise AppApiException(500, _('Model does not exist'))
if model.workspace_id == 'None':
raise AppApiException(500, _('Shared models cannot be deleted or modified'))
def one(self, with_valid=False):
if with_valid:
@ -147,8 +149,6 @@ class ModelSerializer(serializers.Serializer):
self.is_valid(raise_exception=True)
model_id = self.data.get('id')
model = Model.objects.filter(id=model_id).first()
if not model:
raise AppApiException(500, _("Model does not exist"))
# TODO : 这里可以添加模型删除的逻辑,需要注意删除模型时的权限和关联关系
# if model.model_type == 'LLM':
# application_count = Application.objects.filter(model_id=model_id).count()
@ -174,35 +174,32 @@ class ModelSerializer(serializers.Serializer):
self.is_valid(raise_exception=True)
model = QuerySet(Model).filter(id=self.data.get('id')).first()
if model is None:
raise AppApiException(500, _('Model does not exist'))
else:
credential, model_credential, provider_handler = ModelSerializer.Edit(
data={**instance}).is_valid(
model=model)
try:
model.status = Status.SUCCESS
default_params = {item['field']: item['default_value'] for item in model.model_params_form}
# 校验模型认证数据
provider_handler.is_valid_credential(model.model_type,
instance.get("model_name"),
credential,
default_params,
raise_exception=True)
credential, model_credential, provider_handler = ModelSerializer.Edit(
data={**instance}).is_valid(
model=model)
try:
model.status = Status.SUCCESS
default_params = {item['field']: item['default_value'] for item in model.model_params_form}
# 校验模型认证数据
provider_handler.is_valid_credential(model.model_type,
instance.get("model_name"),
credential,
default_params,
raise_exception=True)
except AppApiException as e:
if e.code == ValidCode.model_not_fount:
model.status = Status.DOWNLOAD
except AppApiException as e:
if e.code == ValidCode.model_not_fount:
model.status = Status.DOWNLOAD
else:
raise e
update_keys = ['credential', 'name', 'model_type', 'model_name']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
if update_key == 'credential':
model_credential_str = json.dumps(credential)
model.__setattr__(update_key, rsa_long_encrypt(model_credential_str))
else:
raise e
update_keys = ['credential', 'name', 'model_type', 'model_name']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
if update_key == 'credential':
model_credential_str = json.dumps(credential)
model.__setattr__(update_key, rsa_long_encrypt(model_credential_str))
else:
model.__setattr__(update_key, instance.get(update_key))
model.__setattr__(update_key, instance.get(update_key))
ModelManage.delete_key(str(model.id))
model.save()

View File

@ -61,7 +61,7 @@ class ModelSetting(APIView):
description=_("Create model"),
operation_id=_("Create model"), # type: ignore
tags=[_("Model")], # type: ignore
parameters=ModelCreateAPI.get_query_params_api(),
parameters=ModelCreateAPI.get_parameters(),
request=ModelCreateAPI.get_request(),
responses=ModelCreateAPI.get_response())
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission())
@ -90,7 +90,7 @@ class ModelSetting(APIView):
summary=_('Query model list'),
description=_('Query model list'),
operation_id=_('Query model list'), # type: ignore
parameters=ModelCreateAPI.get_query_params_api(),
parameters=ModelListResponse.get_parameters(),
responses=ModelListResponse.get_response(),
tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
@ -108,7 +108,7 @@ class ModelSetting(APIView):
description=_('Update model'),
operation_id=_('Update model'), # type: ignore
request=ModelEditApi.get_request(),
parameters=GetModelApi.get_query_params_api(),
parameters=GetModelApi.get_parameters(),
responses=ModelEditApi.get_response(),
tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_permission())
@ -125,7 +125,7 @@ class ModelSetting(APIView):
summary=_('Delete model'),
description=_('Delete model'),
operation_id=_('Delete model'), # type: ignore
parameters=GetModelApi.get_query_params_api(),
parameters=GetModelApi.get_parameters(),
responses=DefaultModelResponse.get_response(),
tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_permission())
@ -139,7 +139,7 @@ class ModelSetting(APIView):
summary=_('Query model details'),
description=_('Query model details'),
operation_id=_('Query model details'), # type: ignore
parameters=GetModelApi.get_query_params_api(),
parameters=GetModelApi.get_parameters(),
responses=GetModelApi.get_response(),
tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
@ -154,7 +154,7 @@ class ModelSetting(APIView):
summary=_('Get model parameter form'),
description=_('Get model parameter form'),
operation_id=_('Get model parameter form'), # type: ignore
parameters=GetModelApi.get_query_params_api(),
parameters=GetModelApi.get_parameters(),
responses=ProvideApi.ModelParamsForm.get_response(),
tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
@ -166,7 +166,7 @@ class ModelSetting(APIView):
summary=_('Save model parameter form'),
description=_('Save model parameter form'),
operation_id=_('Save model parameter form'), # type: ignore
parameters=GetModelApi.get_query_params_api(),
parameters=GetModelApi.get_parameters(),
request=GetModelApi.get_request(),
responses=ProvideApi.ModelParamsForm.get_response(),
tags=[_('Model')]) # type: ignore
@ -187,7 +187,7 @@ class ModelSetting(APIView):
'Query model meta information, this interface does not carry authentication information'),
operation_id=_(
'Query model meta information, this interface does not carry authentication information'),
parameters=GetModelApi.get_query_params_api(),
parameters=GetModelApi.get_parameters(),
responses=GetModelApi.get_response(),
tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
@ -202,7 +202,7 @@ class ModelSetting(APIView):
summary=_('Pause model download'),
description=_('Pause model download'),
operation_id=_('Pause model download'), # type: ignore
parameters=GetModelApi.get_query_params_api(),
parameters=GetModelApi.get_parameters(),
request=GetModelApi.get_request(),
responses=DefaultModelResponse.get_response(),
tags=[_('Model')]) # type: ignore
@ -218,9 +218,8 @@ class ModelSetting(APIView):
summary=_('Get Share model'),
description=_('Get Share model'),
operation_id=_('Get Share model'), # type: ignore
parameters=GetModelApi.get_query_params_api(),
request=GetModelApi.get_request(),
responses=DefaultModelResponse.get_response(),
parameters=ModelListResponse.get_parameters(),
responses=ModelListResponse.get_response(),
tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
def get(self, request: Request, workspace_id: str):