refactor: shared model
This commit is contained in:
parent
b2465b28dc
commit
516c8ea9d2
@ -336,20 +336,53 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
query_params = self._build_query_params(workspace_id)
|
query_params = self._build_query_params(workspace_id)
|
||||||
return self._fetch_models(query_params)
|
return self._fetch_models(query_params)
|
||||||
|
|
||||||
|
def model_list(self, workspace_id, with_valid=True):
|
||||||
|
if with_valid:
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
queryset = self._build_query_params(workspace_id)
|
||||||
|
model_workspace_authorization = DatabaseModelManage.get_model("model_workspace_authorization")
|
||||||
|
if model_workspace_authorization is not None:
|
||||||
|
queryset = get_authorized_tool(queryset, workspace_id,
|
||||||
|
model_workspace_authorization=model_workspace_authorization)
|
||||||
|
shared_model = []
|
||||||
|
normal_model = []
|
||||||
|
|
||||||
|
for model in queryset.order_by("-create_time"):
|
||||||
|
data = {
|
||||||
|
'id': str(model.id),
|
||||||
|
'provider': model.provider,
|
||||||
|
'name': model.name,
|
||||||
|
'model_type': model.model_type,
|
||||||
|
'model_name': model.model_name,
|
||||||
|
'status': model.status,
|
||||||
|
'meta': model.meta,
|
||||||
|
'user_id': model.user_id,
|
||||||
|
'username': model.user.nick_name
|
||||||
|
}
|
||||||
|
if model.workspace_id == 'None':
|
||||||
|
shared_model.append(data)
|
||||||
|
else:
|
||||||
|
normal_model.append(data)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"shared_model": shared_model,
|
||||||
|
"model": normal_model
|
||||||
|
}
|
||||||
|
|
||||||
def _build_query_params(self, workspace_id):
|
def _build_query_params(self, workspace_id):
|
||||||
query_params = {}
|
queryset = QuerySet(Model)
|
||||||
if workspace_id:
|
if workspace_id:
|
||||||
query_params['workspace_id'] = workspace_id
|
queryset = queryset.filter(workspace_id=workspace_id)
|
||||||
for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']:
|
for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']:
|
||||||
value = self.data.get(field)
|
value = self.data.get(field)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
if field == 'name':
|
if field == 'name':
|
||||||
query_params[f'{field}__icontains'] = value
|
queryset = queryset.filter(**{f'{field}__icontains': value})
|
||||||
elif field == 'create_user':
|
elif field == 'create_user':
|
||||||
query_params['user_id'] = value
|
queryset = queryset.filter(user_id=value)
|
||||||
else:
|
else:
|
||||||
query_params[field] = value
|
queryset = queryset.filter(**{field: value})
|
||||||
return query_params
|
return queryset
|
||||||
|
|
||||||
def _fetch_models(self, query_params):
|
def _fetch_models(self, query_params):
|
||||||
return [
|
return [
|
||||||
|
|||||||
@ -12,6 +12,7 @@ urlpatterns = [
|
|||||||
path('provider/model_params_form', views.Provide.ModelParamsForm.as_view()),
|
path('provider/model_params_form', views.Provide.ModelParamsForm.as_view()),
|
||||||
path('provider/model_form', views.Provide.ModelForm.as_view()),
|
path('provider/model_form', views.Provide.ModelForm.as_view()),
|
||||||
path('workspace/<str:workspace_id>/model', views.ModelSetting.as_view()),
|
path('workspace/<str:workspace_id>/model', views.ModelSetting.as_view()),
|
||||||
|
path('workspace/<str:workspace_id>/model_list', views.ModelList.as_view()),
|
||||||
path('workspace/<str:workspace_id>/model/<str:model_id>/model_params_form',
|
path('workspace/<str:workspace_id>/model/<str:model_id>/model_params_form',
|
||||||
views.ModelSetting.ModelParamsForm.as_view()),
|
views.ModelSetting.ModelParamsForm.as_view()),
|
||||||
path('workspace/<str:workspace_id>/model/<str:model_id>', views.ModelSetting.Operate.as_view()),
|
path('workspace/<str:workspace_id>/model/<str:model_id>', views.ModelSetting.Operate.as_view()),
|
||||||
|
|||||||
@ -248,3 +248,22 @@ class WorkspaceSharedModelSetting(APIView):
|
|||||||
def get(self, request: Request, workspace_id: str):
|
def get(self, request: Request, workspace_id: str):
|
||||||
return result.success(
|
return result.success(
|
||||||
WorkspaceSharedModelSerializer(data={'workspace_id': workspace_id}).get_share_model_list())
|
WorkspaceSharedModelSerializer(data={'workspace_id': workspace_id}).get_share_model_list())
|
||||||
|
|
||||||
|
|
||||||
|
class ModelList(APIView):
|
||||||
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
|
@extend_schema(methods=['GET'],
|
||||||
|
summary=_('Query all model list'),
|
||||||
|
description=_('Query all model list'),
|
||||||
|
operation_id=_('Query all model list'), # type: ignore
|
||||||
|
parameters=ModelListResponse.get_parameters(),
|
||||||
|
responses=ModelListResponse.get_response(),
|
||||||
|
tags=[_('Model')]) # type: ignore
|
||||||
|
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
|
||||||
|
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
|
||||||
|
def get(self, request: Request, workspace_id: str):
|
||||||
|
return result.success(
|
||||||
|
ModelSerializer.Query(
|
||||||
|
data={**query_params_to_single_dict(request.query_params)}).model_list(workspace_id=workspace_id,
|
||||||
|
with_valid=True))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user