refactor: model list
This commit is contained in:
parent
8fc074fecb
commit
75df321783
@ -12,6 +12,7 @@ from rest_framework import serializers
|
|||||||
from django.db.models.query_utils import Q
|
from django.db.models.query_utils import Q
|
||||||
from common.config.embedding_config import ModelManage
|
from common.config.embedding_config import ModelManage
|
||||||
from common.database_model_manage.database_model_manage import DatabaseModelManage
|
from common.database_model_manage.database_model_manage import DatabaseModelManage
|
||||||
|
from common.db.search import native_search
|
||||||
from common.db.sql_execute import select_list
|
from common.db.sql_execute import select_list
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
from common.utils.common import get_file_content
|
from common.utils.common import get_file_content
|
||||||
@ -21,6 +22,8 @@ from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
|
|||||||
from models_provider.constants.model_provider_constants import ModelProvideConstants
|
from models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||||
from models_provider.models import Model, Status
|
from models_provider.models import Model, Status
|
||||||
from models_provider.tools import get_model_credential
|
from models_provider.tools import get_model_credential
|
||||||
|
from system_manage.models import WorkspaceUserResourcePermission
|
||||||
|
from users.serializers.user import is_workspace_manage
|
||||||
|
|
||||||
|
|
||||||
def get_default_model_params_setting(provider, model_type, model_name):
|
def get_default_model_params_setting(provider, model_type, model_name):
|
||||||
@ -326,6 +329,7 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
return ModelModelSerializer(model).data
|
return ModelModelSerializer(model).data
|
||||||
|
|
||||||
class Query(serializers.Serializer):
|
class Query(serializers.Serializer):
|
||||||
|
user_id = serializers.CharField(required=True, label=_("User ID"))
|
||||||
name = serializers.CharField(required=False, max_length=64, label=_('model name'))
|
name = serializers.CharField(required=False, max_length=64, label=_('model name'))
|
||||||
model_type = serializers.CharField(required=False, label=_('model type'))
|
model_type = serializers.CharField(required=False, label=_('model type'))
|
||||||
model_name = serializers.CharField(required=False, label=_('base model'))
|
model_name = serializers.CharField(required=False, label=_('base model'))
|
||||||
@ -333,17 +337,40 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
create_user = serializers.CharField(required=False, label=_('create user'))
|
create_user = serializers.CharField(required=False, label=_('create user'))
|
||||||
workspace_id = serializers.CharField(required=False, label=_('workspace id'))
|
workspace_id = serializers.CharField(required=False, label=_('workspace id'))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_x_pack_ee():
|
||||||
|
workspace_user_role_mapping_model = DatabaseModelManage.get_model("workspace_user_role_mapping")
|
||||||
|
role_permission_mapping_model = DatabaseModelManage.get_model("role_permission_mapping_model")
|
||||||
|
return workspace_user_role_mapping_model is not None and role_permission_mapping_model is not None
|
||||||
|
|
||||||
def list(self, workspace_id, with_valid):
|
def list(self, workspace_id, with_valid):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
|
user_id = self.data.get("user_id")
|
||||||
|
workspace_manage = is_workspace_manage(user_id, workspace_id)
|
||||||
|
query_params = self._build_query_params(workspace_id, workspace_manage, user_id)
|
||||||
|
is_x_pack_ee = self.is_x_pack_ee()
|
||||||
|
return native_search(query_params,
|
||||||
|
select_string=get_file_content(
|
||||||
|
os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql',
|
||||||
|
'list_model.sql' if workspace_manage else (
|
||||||
|
'list_model_user_ee.sql' if is_x_pack_ee else 'list_model_user.sql')
|
||||||
|
)))
|
||||||
|
|
||||||
query_params = self._build_query_params(workspace_id)
|
def share_list(self, workspace_id, with_valid=True):
|
||||||
return [self._build_model_data(model) for model in query_params.order_by("-create_time")]
|
if with_valid:
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
user_id = self.data.get("user_id")
|
||||||
|
query_params = self._build_query_params(workspace_id, False, user_id)
|
||||||
|
return [self._build_model_data(model) for model in
|
||||||
|
query_params.get('model_query_set').order_by("-create_time")]
|
||||||
|
|
||||||
def model_list(self, workspace_id, with_valid=True):
|
def model_list(self, workspace_id, with_valid=True):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
queryset = self._build_query_params(workspace_id)
|
user_id = self.data.get("user_id")
|
||||||
|
workspace_manage = is_workspace_manage(user_id, workspace_id)
|
||||||
|
queryset = self._build_query_params(workspace_id, workspace_manage, user_id)
|
||||||
get_authorized_model = DatabaseModelManage.get_model("get_authorized_model")
|
get_authorized_model = DatabaseModelManage.get_model("get_authorized_model")
|
||||||
|
|
||||||
shared_queryset = QuerySet(Model).filter(workspace_id='None')
|
shared_queryset = QuerySet(Model).filter(workspace_id='None')
|
||||||
@ -352,14 +379,20 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
|
|
||||||
# 构建共享模型和普通模型列表
|
# 构建共享模型和普通模型列表
|
||||||
shared_model = [self._build_model_data(model) for model in shared_queryset.order_by("-create_time")]
|
shared_model = [self._build_model_data(model) for model in shared_queryset.order_by("-create_time")]
|
||||||
normal_model = [self._build_model_data(model) for model in queryset.order_by("-create_time")]
|
|
||||||
|
|
||||||
|
is_x_pack_ee = self.is_x_pack_ee()
|
||||||
|
normal_model = native_search(queryset,
|
||||||
|
select_string=get_file_content(
|
||||||
|
os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql',
|
||||||
|
'list_model.sql' if workspace_manage else (
|
||||||
|
'list_model_user_ee.sql' if is_x_pack_ee else 'list_model_user.sql')
|
||||||
|
)))
|
||||||
return {
|
return {
|
||||||
"shared_model": shared_model,
|
"shared_model": shared_model,
|
||||||
"model": normal_model
|
"model": normal_model
|
||||||
}
|
}
|
||||||
|
|
||||||
def _build_query_params(self, workspace_id):
|
def _build_query_params(self, workspace_id, workspace_manage: bool, user_id):
|
||||||
queryset = QuerySet(Model)
|
queryset = QuerySet(Model)
|
||||||
if workspace_id:
|
if workspace_id:
|
||||||
queryset = queryset.filter(workspace_id=workspace_id)
|
queryset = queryset.filter(workspace_id=workspace_id)
|
||||||
@ -372,7 +405,15 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
queryset = queryset.filter(user_id=value)
|
queryset = queryset.filter(user_id=value)
|
||||||
else:
|
else:
|
||||||
queryset = queryset.filter(**{field: value})
|
queryset = queryset.filter(**{field: value})
|
||||||
return queryset
|
return {
|
||||||
|
'model_query_set': queryset,
|
||||||
|
'workspace_user_resource_permission_query_set': QuerySet(WorkspaceUserResourcePermission).filter(
|
||||||
|
auth_target_type="MODEL",
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
user_id=user_id)} if (
|
||||||
|
not workspace_manage) else {
|
||||||
|
'model_query_set': queryset,
|
||||||
|
}
|
||||||
|
|
||||||
def _build_model_data(self, model):
|
def _build_model_data(self, model):
|
||||||
return {
|
return {
|
||||||
|
|||||||
16
apps/models_provider/sql/list_model.sql
Normal file
16
apps/models_provider/sql/list_model.sql
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
SELECT model."id"::text, model."name",
|
||||||
|
model.model_name,
|
||||||
|
model.meta,
|
||||||
|
model.credential,
|
||||||
|
model.model_params_form,
|
||||||
|
model.model_type,
|
||||||
|
model.provider,
|
||||||
|
model.status,
|
||||||
|
model.create_time,
|
||||||
|
model.update_time,
|
||||||
|
model.user_id,
|
||||||
|
"user"."nick_name" as "nick_name",
|
||||||
|
model.workspace_id
|
||||||
|
from model
|
||||||
|
left join "user" on user_id = "user".id
|
||||||
|
${model_query_set}
|
||||||
19
apps/models_provider/sql/list_model_user.sql
Normal file
19
apps/models_provider/sql/list_model_user.sql
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
SELECT *
|
||||||
|
FROM (SELECT model."id"::text, model."name",
|
||||||
|
model.model_name,
|
||||||
|
model.meta,
|
||||||
|
model.credential,
|
||||||
|
model.model_params_form,
|
||||||
|
model.model_type,
|
||||||
|
model.provider,
|
||||||
|
model.status,
|
||||||
|
model.create_time,
|
||||||
|
model.update_time,
|
||||||
|
model.user_id,
|
||||||
|
"user"."nick_name" as "nick_name",
|
||||||
|
model.workspace_id
|
||||||
|
from model
|
||||||
|
left join "user" on user_id = "user".id
|
||||||
|
where model."id" in (select target
|
||||||
|
from workspace_user_resource_permission ${workspace_user_resource_permission_query_set}
|
||||||
|
and 'VIEW' = any (permission_list)) ) temp ${model_query_set}
|
||||||
38
apps/models_provider/sql/list_model_user_ee.sql
Normal file
38
apps/models_provider/sql/list_model_user_ee.sql
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
SELECT *
|
||||||
|
FROM (SELECT model."id"::text, model."name",
|
||||||
|
model.model_name,
|
||||||
|
model.meta,
|
||||||
|
model.credential,
|
||||||
|
model.model_params_form,
|
||||||
|
model.model_type,
|
||||||
|
model.provider,
|
||||||
|
model.status,
|
||||||
|
model.create_time,
|
||||||
|
model.update_time,
|
||||||
|
model.user_id,
|
||||||
|
"user"."nick_name" as "nick_name",
|
||||||
|
model.workspace_id
|
||||||
|
from model
|
||||||
|
left join "user" on user_id = "user".id
|
||||||
|
where model."id" in (select target
|
||||||
|
from workspace_user_resource_permission ${workspace_user_resource_permission_query_set}
|
||||||
|
and case
|
||||||
|
when auth_type = 'ROLE' then
|
||||||
|
'ROLE' = any (permission_list)
|
||||||
|
and
|
||||||
|
'MODEL:READ' in (select (case
|
||||||
|
when user_role_relation.role_id = any (array['USER'])
|
||||||
|
THEN 'MODEL:READ'
|
||||||
|
else role_permission.permission_id END)
|
||||||
|
from role_permission role_permission
|
||||||
|
right join user_role_relation user_role_relation
|
||||||
|
on user_role_relation.role_id = role_permission.role_id
|
||||||
|
where user_role_relation.user_id = workspace_user_resource_permission.user_id
|
||||||
|
and user_role_relation.workspace_id =
|
||||||
|
workspace_user_resource_permission.workspace_id)
|
||||||
|
|
||||||
|
else
|
||||||
|
'VIEW' = any (permission_list)
|
||||||
|
end) ) temp ${model_query_set}
|
||||||
|
|
||||||
|
|
||||||
@ -1,8 +0,0 @@
|
|||||||
select model_id
|
|
||||||
from model_workspace_authorization
|
|
||||||
where case
|
|
||||||
when authentication_type = 'WHITE_LIST' then
|
|
||||||
%s = any (workspace_id_list)
|
|
||||||
else
|
|
||||||
not %s = any(workspace_id_list)
|
|
||||||
end
|
|
||||||
@ -101,7 +101,8 @@ class ModelSetting(APIView):
|
|||||||
def get(self, request: Request, workspace_id: str):
|
def get(self, request: Request, workspace_id: str):
|
||||||
return result.success(
|
return result.success(
|
||||||
ModelSerializer.Query(
|
ModelSerializer.Query(
|
||||||
data={**query_params_to_single_dict(request.query_params)}).list(workspace_id=workspace_id,
|
data={**query_params_to_single_dict(request.query_params), 'user_id': str(request.user.id)}).list(
|
||||||
|
workspace_id=workspace_id,
|
||||||
with_valid=True))
|
with_valid=True))
|
||||||
|
|
||||||
class Operate(APIView):
|
class Operate(APIView):
|
||||||
@ -266,5 +267,6 @@ class ModelList(APIView):
|
|||||||
def get(self, request: Request, workspace_id: str):
|
def get(self, request: Request, workspace_id: str):
|
||||||
return result.success(
|
return result.success(
|
||||||
ModelSerializer.Query(
|
ModelSerializer.Query(
|
||||||
data={**query_params_to_single_dict(request.query_params)}).model_list(workspace_id=workspace_id,
|
data={**query_params_to_single_dict(request.query_params), 'user_id': str(request.user.id)}).model_list(
|
||||||
|
workspace_id=workspace_id,
|
||||||
with_valid=True))
|
with_valid=True))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user