refactor: shared model

This commit is contained in:
wxg0103 2025-06-23 16:43:08 +08:00
parent 2c20733957
commit 37a2041d8d
3 changed files with 65 additions and 21 deletions

View File

@ -10,6 +10,7 @@ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from common.config.embedding_config import ModelManage
from common.database_model_manage.database_model_manage import DatabaseModelManage
from common.exception.app_exception import AppApiException
from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt
from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
@ -394,7 +395,22 @@ class ModelSerializer(serializers.Serializer):
return True
class SharedModelSerializer(serializers.Serializer):
def get_authorized_tool(tool_query_set, workspace_id, model_workspace_authorization):
white_authorized_tool_ids = QuerySet(model_workspace_authorization).filter(
workspace_id=workspace_id, authentication_type='WHITE_LIST'
).values_list('model_id', flat=True)
black_authorized_tool_ids = QuerySet(model_workspace_authorization).filter(
workspace_id=workspace_id, authentication_type='BLACK_LIST'
).values_list('model_id', flat=True)
tool_query_set = tool_query_set.filter(
id__in=white_authorized_tool_ids
).exclude(
id__in=black_authorized_tool_ids
)
return tool_query_set
class WorkspaceSharedModelSerializer(serializers.Serializer):
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
name = serializers.CharField(required=False, max_length=64, label=_('model name'))
model_type = serializers.CharField(required=False, label=_('model type'))
@ -404,7 +420,10 @@ class SharedModelSerializer(serializers.Serializer):
def get_share_model_list(self):
self.is_valid(raise_exception=True)
queryset = QuerySet(Model).filter(workspace_id='None')
workspace_id = self.data.get('workspace_id')
queryset = self._build_queryset(workspace_id)
return [
{
'id': str(model.id),
@ -419,3 +438,23 @@ class SharedModelSerializer(serializers.Serializer):
}
for model in queryset.order_by("-create_time")
]
def _build_queryset(self, workspace_id):
queryset = QuerySet(Model)
if workspace_id:
model_workspace_authorization = DatabaseModelManage.get_model("model_workspace_authorization")
if model_workspace_authorization is not None:
queryset = get_authorized_tool(queryset, workspace_id,
model_workspace_authorization=model_workspace_authorization)
for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']:
value = self.data.get(field)
if value is not None:
if field == 'name':
queryset = queryset.filter(**{f'{field}__icontains': value})
elif field == 'create_user':
queryset = queryset.filter(user_id=value)
else:
queryset = queryset.filter(**{field: value})
return queryset

View File

@ -18,7 +18,7 @@ urlpatterns = [
path('workspace/<str:workspace_id>/model/<str:model_id>/pause_download',
views.ModelSetting.PauseDownload.as_view()),
path('workspace/<str:workspace_id>/model/<str:model_id>/meta', views.ModelSetting.ModelMeta.as_view()),
path('workspace/<str:workspace_id>/shared/model', views.SharedModel.as_view()),
path('system/shared/workspace/<str:workspace_id>/model', views.WorkspaceSharedModelSetting.as_view()),
]
if os.environ.get('SERVER_NAME', 'web') == 'local_model':

View File

@ -21,7 +21,8 @@ from common.utils.common import query_params_to_single_dict
from models_provider.api.model import ModelCreateAPI, GetModelApi, ModelEditApi, ModelListResponse, DefaultModelResponse
from models_provider.api.provide import ProvideApi
from models_provider.models import Model
from models_provider.serializers.model_serializer import ModelSerializer, SharedModelSerializer
from models_provider.serializers.model_serializer import ModelSerializer, SharedModelSerializer, \
WorkspaceSharedModelSerializer
from system_manage.views import encryption_str
@ -65,7 +66,7 @@ class ModelSetting(APIView):
request=ModelCreateAPI.get_request(),
responses=ModelCreateAPI.get_response())
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission(),
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
@log(menu='model', operate='Create model',
get_operation_object=lambda r, k: {'name': r.date.get('name')},
get_details=get_edit_model_details,
@ -95,7 +96,7 @@ class ModelSetting(APIView):
responses=ModelListResponse.get_response(),
tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
def get(self, request: Request, workspace_id: str):
return result.success(
ModelSerializer.Query(
@ -114,7 +115,7 @@ class ModelSetting(APIView):
responses=ModelEditApi.get_response(),
tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_permission(),
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
@log(menu='model', operate='Update model',
get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')),
get_details=get_edit_model_details,
@ -133,7 +134,7 @@ class ModelSetting(APIView):
responses=DefaultModelResponse.get_response(),
tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_permission(),
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
@log(menu='model', operate='Delete model',
get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')),
)
@ -150,7 +151,7 @@ class ModelSetting(APIView):
responses=GetModelApi.get_response(),
tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
def get(self, request: Request, workspace_id: str, model_id: str):
return result.success(
ModelSerializer.Operate(
@ -168,7 +169,7 @@ class ModelSetting(APIView):
responses=ProvideApi.ModelParamsForm.get_response(),
tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
def get(self, request: Request, workspace_id: str, model_id: str):
return result.success(
ModelSerializer.ModelParams(data={'id': model_id}).get_model_params())
@ -182,7 +183,7 @@ class ModelSetting(APIView):
responses=ProvideApi.ModelParamsForm.get_response(),
tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
@log(menu='model', operate='Save model parameter form',
get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')),
)
@ -204,7 +205,7 @@ class ModelSetting(APIView):
responses=GetModelApi.get_response(),
tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
def get(self, request: Request, workspace_id: str, model_id: str):
return result.success(
ModelSerializer.Operate(data={'id': model_id, 'workspace_id': workspace_id}).one_meta(with_valid=True))
@ -221,25 +222,29 @@ class ModelSetting(APIView):
responses=DefaultModelResponse.get_response(),
tags=[_('Model')]) # type: ignore
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission(),
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
def put(self, request: Request, workspace_id: str, model_id: str):
return result.success(
ModelSerializer.Operate(data={'id': model_id, 'workspace_id': workspace_id}).pause_download())
class SharedModel(APIView):
class WorkspaceSharedModelSetting(APIView):
authentication_classes = [TokenAuth]
@extend_schema(
methods=['Get'],
summary=_('Get Share model'),
description=_('Get Share model'),
operation_id=_('Get Share model'), # type: ignore
parameters=ModelCreateAPI.get_parameters(),
responses=ModelListResponse.get_response(),
summary=_('Get Share model by workspace id'),
description=_('Get Share model by workspace id'),
operation_id=_('Get Share model by workspace id'), # type: ignore
parameters=ModelListResponse.get_parameters(),
responses=DefaultModelResponse.get_response(),
tags=[_('Shared Model')]
) # type: ignore
@has_permissions(PermissionConstants.MODEL_READ, RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
@has_permissions(
PermissionConstants.MODEL_READ.get_workspace_permission(),
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(),
RoleConstants.USER.get_workspace_role(),
)
def get(self, request: Request, workspace_id: str):
return result.success(
SharedModelSerializer(data={'workspace_id': workspace_id}).get_share_model_list())
WorkspaceSharedModelSerializer(data={'workspace_id': workspace_id}).get_share_model_list())