feat: 模型接口添加权限参数

This commit is contained in:
shaohuzhang1 2024-07-16 16:08:31 +08:00
parent 2f3d282c0d
commit 9b81b89975
2 changed files with 22 additions and 5 deletions

View File

@ -7,12 +7,14 @@
@desc: @desc:
""" """
import json import json
import re
import threading import threading
import time import time
import uuid import uuid
from typing import Dict from typing import Dict
from django.db.models import QuerySet from django.core import validators
from django.db.models import QuerySet, Q
from rest_framework import serializers from rest_framework import serializers
from application.models import Application from application.models import Application
@ -72,7 +74,7 @@ class ModelSerializer(serializers.Serializer):
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
user_id = self.data.get('user_id') user_id = self.data.get('user_id')
name = self.data.get('name') name = self.data.get('name')
model_query_set = QuerySet(Model).filter(user_id=user_id) model_query_set = QuerySet(Model).filter((Q(user_id=user_id) | Q(permission_type='PUBLIC')))
query_params = {} query_params = {}
if name is not None: if name is not None:
query_params['name__contains'] = name query_params['name__contains'] = name
@ -96,6 +98,11 @@ class ModelSerializer(serializers.Serializer):
model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型")) model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
permission_type = serializers.CharField(required=False, error_messages=ErrMessage.char("权限"), validators=[
validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"),
message="权限只支持PUBLIC|PRIVATE", code=500)
])
model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型")) model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
credential = serializers.DictField(required=False, error_messages=ErrMessage.dict("认证信息")) credential = serializers.DictField(required=False, error_messages=ErrMessage.dict("认证信息"))
@ -135,6 +142,11 @@ class ModelSerializer(serializers.Serializer):
model_type = serializers.CharField(required=True, error_messages=ErrMessage.char("模型类型")) model_type = serializers.CharField(required=True, error_messages=ErrMessage.char("模型类型"))
permission_type = serializers.CharField(required=True, error_messages=ErrMessage.char("权限"), validators=[
validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"),
message="权限只支持PUBLIC|PRIVATE", code=500)
])
model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型")) model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型"))
credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息")) credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息"))
@ -165,10 +177,12 @@ class ModelSerializer(serializers.Serializer):
provider = self.data.get('provider') provider = self.data.get('provider')
model_type = self.data.get('model_type') model_type = self.data.get('model_type')
model_name = self.data.get('model_name') model_name = self.data.get('model_name')
permission_type = self.data.get('permission_type')
model_credential_str = json.dumps(credential) model_credential_str = json.dumps(credential)
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name, model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
credential=rsa_long_encrypt(model_credential_str), credential=rsa_long_encrypt(model_credential_str),
provider=provider, model_type=model_type, model_name=model_name) provider=provider, model_type=model_type, model_name=model_name,
permission_type=permission_type)
model.save() model.save()
if status == Status.DOWNLOAD: if status == Status.DOWNLOAD:
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential)) thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
@ -245,7 +259,7 @@ class ModelSerializer(serializers.Serializer):
model.status = Status.DOWNLOAD model.status = Status.DOWNLOAD
else: else:
raise e raise e
update_keys = ['credential', 'name', 'model_type', 'model_name'] update_keys = ['credential', 'name', 'model_type', 'model_name', 'permission_type']
for update_key in update_keys: for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None: if update_key in instance and instance.get(update_key) is not None:
if update_key == 'credential': if update_key == 'credential':

View File

@ -74,6 +74,8 @@ class ModelCreateApi(ApiMixin):
'provider': openapi.Schema(type=openapi.TYPE_STRING, 'provider': openapi.Schema(type=openapi.TYPE_STRING,
title="供应商", title="供应商",
description="供应商"), description="供应商"),
'permission_type': openapi.Schema(type=openapi.TYPE_STRING, title="权限",
description="PUBLIC|PRIVATE"),
'model_type': openapi.Schema(type=openapi.TYPE_STRING, 'model_type': openapi.Schema(type=openapi.TYPE_STRING,
title="供应商", title="供应商",
description="供应商"), description="供应商"),
@ -82,7 +84,8 @@ class ModelCreateApi(ApiMixin):
description="供应商"), description="供应商"),
'credential': openapi.Schema(type=openapi.TYPE_OBJECT, 'credential': openapi.Schema(type=openapi.TYPE_OBJECT,
title="模型证书信息", title="模型证书信息",
description="模型证书信息") description="模型证书信息"),
} }
) )