refactor: shared model
This commit is contained in:
parent
2c20733957
commit
37a2041d8d
@ -10,6 +10,7 @@ from django.utils.translation import gettext_lazy as _
|
|||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
|
||||||
from common.config.embedding_config import ModelManage
|
from common.config.embedding_config import ModelManage
|
||||||
|
from common.database_model_manage.database_model_manage import DatabaseModelManage
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt
|
from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt
|
||||||
from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
|
from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
|
||||||
@ -394,7 +395,22 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class SharedModelSerializer(serializers.Serializer):
|
def get_authorized_tool(tool_query_set, workspace_id, model_workspace_authorization):
|
||||||
|
white_authorized_tool_ids = QuerySet(model_workspace_authorization).filter(
|
||||||
|
workspace_id=workspace_id, authentication_type='WHITE_LIST'
|
||||||
|
).values_list('model_id', flat=True)
|
||||||
|
black_authorized_tool_ids = QuerySet(model_workspace_authorization).filter(
|
||||||
|
workspace_id=workspace_id, authentication_type='BLACK_LIST'
|
||||||
|
).values_list('model_id', flat=True)
|
||||||
|
tool_query_set = tool_query_set.filter(
|
||||||
|
id__in=white_authorized_tool_ids
|
||||||
|
).exclude(
|
||||||
|
id__in=black_authorized_tool_ids
|
||||||
|
)
|
||||||
|
return tool_query_set
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceSharedModelSerializer(serializers.Serializer):
|
||||||
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
|
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
|
||||||
name = serializers.CharField(required=False, max_length=64, label=_('model name'))
|
name = serializers.CharField(required=False, max_length=64, label=_('model name'))
|
||||||
model_type = serializers.CharField(required=False, label=_('model type'))
|
model_type = serializers.CharField(required=False, label=_('model type'))
|
||||||
@ -404,7 +420,10 @@ class SharedModelSerializer(serializers.Serializer):
|
|||||||
|
|
||||||
def get_share_model_list(self):
|
def get_share_model_list(self):
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
queryset = QuerySet(Model).filter(workspace_id='None')
|
workspace_id = self.data.get('workspace_id')
|
||||||
|
|
||||||
|
queryset = self._build_queryset(workspace_id)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
'id': str(model.id),
|
'id': str(model.id),
|
||||||
@ -419,3 +438,23 @@ class SharedModelSerializer(serializers.Serializer):
|
|||||||
}
|
}
|
||||||
for model in queryset.order_by("-create_time")
|
for model in queryset.order_by("-create_time")
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def _build_queryset(self, workspace_id):
|
||||||
|
queryset = QuerySet(Model)
|
||||||
|
if workspace_id:
|
||||||
|
model_workspace_authorization = DatabaseModelManage.get_model("model_workspace_authorization")
|
||||||
|
if model_workspace_authorization is not None:
|
||||||
|
queryset = get_authorized_tool(queryset, workspace_id,
|
||||||
|
model_workspace_authorization=model_workspace_authorization)
|
||||||
|
|
||||||
|
for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']:
|
||||||
|
value = self.data.get(field)
|
||||||
|
if value is not None:
|
||||||
|
if field == 'name':
|
||||||
|
queryset = queryset.filter(**{f'{field}__icontains': value})
|
||||||
|
elif field == 'create_user':
|
||||||
|
queryset = queryset.filter(user_id=value)
|
||||||
|
else:
|
||||||
|
queryset = queryset.filter(**{field: value})
|
||||||
|
|
||||||
|
return queryset
|
||||||
|
|||||||
@ -18,7 +18,7 @@ urlpatterns = [
|
|||||||
path('workspace/<str:workspace_id>/model/<str:model_id>/pause_download',
|
path('workspace/<str:workspace_id>/model/<str:model_id>/pause_download',
|
||||||
views.ModelSetting.PauseDownload.as_view()),
|
views.ModelSetting.PauseDownload.as_view()),
|
||||||
path('workspace/<str:workspace_id>/model/<str:model_id>/meta', views.ModelSetting.ModelMeta.as_view()),
|
path('workspace/<str:workspace_id>/model/<str:model_id>/meta', views.ModelSetting.ModelMeta.as_view()),
|
||||||
path('workspace/<str:workspace_id>/shared/model', views.SharedModel.as_view()),
|
path('system/shared/workspace/<str:workspace_id>/model', views.WorkspaceSharedModelSetting.as_view()),
|
||||||
]
|
]
|
||||||
|
|
||||||
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
||||||
|
|||||||
@ -21,7 +21,8 @@ from common.utils.common import query_params_to_single_dict
|
|||||||
from models_provider.api.model import ModelCreateAPI, GetModelApi, ModelEditApi, ModelListResponse, DefaultModelResponse
|
from models_provider.api.model import ModelCreateAPI, GetModelApi, ModelEditApi, ModelListResponse, DefaultModelResponse
|
||||||
from models_provider.api.provide import ProvideApi
|
from models_provider.api.provide import ProvideApi
|
||||||
from models_provider.models import Model
|
from models_provider.models import Model
|
||||||
from models_provider.serializers.model_serializer import ModelSerializer, SharedModelSerializer
|
from models_provider.serializers.model_serializer import ModelSerializer, SharedModelSerializer, \
|
||||||
|
WorkspaceSharedModelSerializer
|
||||||
from system_manage.views import encryption_str
|
from system_manage.views import encryption_str
|
||||||
|
|
||||||
|
|
||||||
@ -65,7 +66,7 @@ class ModelSetting(APIView):
|
|||||||
request=ModelCreateAPI.get_request(),
|
request=ModelCreateAPI.get_request(),
|
||||||
responses=ModelCreateAPI.get_response())
|
responses=ModelCreateAPI.get_response())
|
||||||
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission(),
|
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission(),
|
||||||
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
||||||
@log(menu='model', operate='Create model',
|
@log(menu='model', operate='Create model',
|
||||||
get_operation_object=lambda r, k: {'name': r.date.get('name')},
|
get_operation_object=lambda r, k: {'name': r.date.get('name')},
|
||||||
get_details=get_edit_model_details,
|
get_details=get_edit_model_details,
|
||||||
@ -95,7 +96,7 @@ class ModelSetting(APIView):
|
|||||||
responses=ModelListResponse.get_response(),
|
responses=ModelListResponse.get_response(),
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
|
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
|
||||||
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
||||||
def get(self, request: Request, workspace_id: str):
|
def get(self, request: Request, workspace_id: str):
|
||||||
return result.success(
|
return result.success(
|
||||||
ModelSerializer.Query(
|
ModelSerializer.Query(
|
||||||
@ -114,7 +115,7 @@ class ModelSetting(APIView):
|
|||||||
responses=ModelEditApi.get_response(),
|
responses=ModelEditApi.get_response(),
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_permission(),
|
@has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_permission(),
|
||||||
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
||||||
@log(menu='model', operate='Update model',
|
@log(menu='model', operate='Update model',
|
||||||
get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')),
|
get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')),
|
||||||
get_details=get_edit_model_details,
|
get_details=get_edit_model_details,
|
||||||
@ -133,7 +134,7 @@ class ModelSetting(APIView):
|
|||||||
responses=DefaultModelResponse.get_response(),
|
responses=DefaultModelResponse.get_response(),
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_permission(),
|
@has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_permission(),
|
||||||
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
||||||
@log(menu='model', operate='Delete model',
|
@log(menu='model', operate='Delete model',
|
||||||
get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')),
|
get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')),
|
||||||
)
|
)
|
||||||
@ -150,7 +151,7 @@ class ModelSetting(APIView):
|
|||||||
responses=GetModelApi.get_response(),
|
responses=GetModelApi.get_response(),
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
|
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
|
||||||
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
||||||
def get(self, request: Request, workspace_id: str, model_id: str):
|
def get(self, request: Request, workspace_id: str, model_id: str):
|
||||||
return result.success(
|
return result.success(
|
||||||
ModelSerializer.Operate(
|
ModelSerializer.Operate(
|
||||||
@ -168,7 +169,7 @@ class ModelSetting(APIView):
|
|||||||
responses=ProvideApi.ModelParamsForm.get_response(),
|
responses=ProvideApi.ModelParamsForm.get_response(),
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
|
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
|
||||||
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
||||||
def get(self, request: Request, workspace_id: str, model_id: str):
|
def get(self, request: Request, workspace_id: str, model_id: str):
|
||||||
return result.success(
|
return result.success(
|
||||||
ModelSerializer.ModelParams(data={'id': model_id}).get_model_params())
|
ModelSerializer.ModelParams(data={'id': model_id}).get_model_params())
|
||||||
@ -182,7 +183,7 @@ class ModelSetting(APIView):
|
|||||||
responses=ProvideApi.ModelParamsForm.get_response(),
|
responses=ProvideApi.ModelParamsForm.get_response(),
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
|
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
|
||||||
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
||||||
@log(menu='model', operate='Save model parameter form',
|
@log(menu='model', operate='Save model parameter form',
|
||||||
get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')),
|
get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')),
|
||||||
)
|
)
|
||||||
@ -204,7 +205,7 @@ class ModelSetting(APIView):
|
|||||||
responses=GetModelApi.get_response(),
|
responses=GetModelApi.get_response(),
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
|
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
|
||||||
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
||||||
def get(self, request: Request, workspace_id: str, model_id: str):
|
def get(self, request: Request, workspace_id: str, model_id: str):
|
||||||
return result.success(
|
return result.success(
|
||||||
ModelSerializer.Operate(data={'id': model_id, 'workspace_id': workspace_id}).one_meta(with_valid=True))
|
ModelSerializer.Operate(data={'id': model_id, 'workspace_id': workspace_id}).one_meta(with_valid=True))
|
||||||
@ -221,25 +222,29 @@ class ModelSetting(APIView):
|
|||||||
responses=DefaultModelResponse.get_response(),
|
responses=DefaultModelResponse.get_response(),
|
||||||
tags=[_('Model')]) # type: ignore
|
tags=[_('Model')]) # type: ignore
|
||||||
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission(),
|
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission(),
|
||||||
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
||||||
def put(self, request: Request, workspace_id: str, model_id: str):
|
def put(self, request: Request, workspace_id: str, model_id: str):
|
||||||
return result.success(
|
return result.success(
|
||||||
ModelSerializer.Operate(data={'id': model_id, 'workspace_id': workspace_id}).pause_download())
|
ModelSerializer.Operate(data={'id': model_id, 'workspace_id': workspace_id}).pause_download())
|
||||||
|
|
||||||
|
|
||||||
class SharedModel(APIView):
|
class WorkspaceSharedModelSetting(APIView):
|
||||||
authentication_classes = [TokenAuth]
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
@extend_schema(
|
@extend_schema(
|
||||||
methods=['Get'],
|
methods=['Get'],
|
||||||
summary=_('Get Share model'),
|
summary=_('Get Share model by workspace id'),
|
||||||
description=_('Get Share model'),
|
description=_('Get Share model by workspace id'),
|
||||||
operation_id=_('Get Share model'), # type: ignore
|
operation_id=_('Get Share model by workspace id'), # type: ignore
|
||||||
parameters=ModelCreateAPI.get_parameters(),
|
parameters=ModelListResponse.get_parameters(),
|
||||||
responses=ModelListResponse.get_response(),
|
responses=DefaultModelResponse.get_response(),
|
||||||
tags=[_('Shared Model')]
|
tags=[_('Shared Model')]
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
@has_permissions(PermissionConstants.MODEL_READ, RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
@has_permissions(
|
||||||
|
PermissionConstants.MODEL_READ.get_workspace_permission(),
|
||||||
|
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(),
|
||||||
|
RoleConstants.USER.get_workspace_role(),
|
||||||
|
)
|
||||||
def get(self, request: Request, workspace_id: str):
|
def get(self, request: Request, workspace_id: str):
|
||||||
return result.success(
|
return result.success(
|
||||||
SharedModelSerializer(data={'workspace_id': workspace_id}).get_share_model_list())
|
WorkspaceSharedModelSerializer(data={'workspace_id': workspace_id}).get_share_model_list())
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user