feat: 支持复杂sql查询,修改数据集查询

This commit is contained in:
shaohuzhang1 2023-10-11 15:07:10 +08:00
parent 1e3097fa3f
commit 64e93679f8
5 changed files with 161 additions and 54 deletions

View File

@ -6,8 +6,9 @@
@date2023/10/7 18:20 @date2023/10/7 18:20
@desc: @desc:
""" """
from typing import Dict, Any
from django.db import DEFAULT_DB_ALIAS, models from django.db import DEFAULT_DB_ALIAS, models, connections
from django.db.models import QuerySet from django.db.models import QuerySet
from common.db.compiler import AppSQLCompiler from common.db.compiler import AppSQLCompiler
@ -30,8 +31,61 @@ def get_dynamics_model(attr: dict, table_name='dynamics'):
return type('Dynamics', (models.Model,), attributes) return type('Dynamics', (models.Model,), attributes)
def native_search(queryset: QuerySet, select_string: str, def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string: str,
field_replace_dict=None, field_replace_dict: None | Dict[str, Dict[str, str]] = None):
"""
生成 查询sql
:param queryset_dict: 多条件 查询条件
:param select_string: 查询sql
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
:return: sql:需要查询的sql params: sql 参数
"""
params_dict: Dict[int, Any] = {}
result_params = []
for key in queryset_dict.keys():
value = queryset_dict.get(key)
sql, params = compiler_queryset(value, None if field_replace_dict is None else field_replace_dict.get(key))
params_dict = {**params_dict, select_string.index("${" + key + "}"): params}
select_string = select_string.replace("${" + key + "}", sql)
for key in sorted(list(params_dict.keys())):
result_params = [*result_params, *params_dict.get(key)]
return select_string, result_params
def generate_sql_by_query(queryset: QuerySet, select_string: str,
field_replace_dict: None | Dict[str, str] = None):
"""
生成 查询sql
:param queryset: 查询条件
:param select_string: 原始sql
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
:return: sql:需要查询的sql params: sql 参数
"""
sql, params = compiler_queryset(queryset, field_replace_dict)
return select_string + " " + sql, params
def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, str] = None):
"""
解析 queryset查询对象
:param queryset: 查询对象
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
:return: sql:需要查询的sql params: sql 参数
"""
q = queryset.query
compiler = q.get_compiler(DEFAULT_DB_ALIAS)
if field_replace_dict is None:
field_replace_dict = get_field_replace_dict(queryset)
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
field_replace_dict=field_replace_dict)
sql, params = app_sql_compiler.get_query_str(with_table_name=False)
return sql, params
def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None,
with_search_one=False): with_search_one=False):
""" """
复杂查询 复杂查询
@ -41,19 +95,14 @@ def native_search(queryset: QuerySet, select_string: str,
:param with_search_one: 查询 :param with_search_one: 查询
:return: 查询结果 :return: 查询结果
""" """
if field_replace_dict is None: if isinstance(queryset, Dict):
field_replace_dict = get_field_replace_dict(queryset) exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict)
q = queryset.query
compiler = q.get_compiler(DEFAULT_DB_ALIAS)
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
field_replace_dict=field_replace_dict)
sql, params = app_sql_compiler.get_query_str(with_table_name=False)
if with_search_one:
return select_one(select_string + " " +
sql, params)
else: else:
return select_list(select_string + " " + exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict)
sql, params) if with_search_one:
return select_one(exec_sql, exec_params)
else:
return select_list(exec_sql, exec_params)
def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler): def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler):
@ -70,7 +119,7 @@ def page_search(current_page: int, page_size: int, queryset: QuerySet, post_reco
return Page(total, list(map(post_records_handler, result)), current_page, page_size) return Page(total, list(map(post_records_handler, result)), current_page, page_size)
def native_page_search(current_page: int, page_size: int, queryset: QuerySet, select_string: str, def native_page_search(current_page: int, page_size: int, queryset: QuerySet | Dict[str, QuerySet], select_string: str,
field_replace_dict=None, field_replace_dict=None,
post_records_handler=lambda r: r): post_records_handler=lambda r: r):
""" """
@ -83,20 +132,17 @@ def native_page_search(current_page: int, page_size: int, queryset: QuerySet, se
:param post_records_handler: 数据row处理器 :param post_records_handler: 数据row处理器
:return: 分页结果 :return: 分页结果
""" """
if field_replace_dict is None: if isinstance(queryset, Dict):
field_replace_dict = get_field_replace_dict(queryset) exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict)
q = queryset.query else:
compiler = q.get_compiler(DEFAULT_DB_ALIAS) exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict)
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection, total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql
field_replace_dict=field_replace_dict) total = select_one(total_sql, exec_params)
page_sql, params = app_sql_compiler.get_query_str(with_table_name=False) limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql(
total_sql = "SELECT \"count\"(*) FROM (%s) temp" % (select_string + " " + page_sql) ((current_page - 1) * page_size), (current_page * page_size)
total = select_one(total_sql, params) )
q.set_limits(((current_page - 1) * page_size), (current_page * page_size)) page_sql = exec_sql + " " + limit_sql
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection, result = select_list(page_sql, exec_params)
field_replace_dict=field_replace_dict)
page_sql, params = app_sql_compiler.get_query_str(with_table_name=False)
result = select_list(select_string + " " + page_sql, params)
return Page(total.get("count"), list(map(post_records_handler, result)), current_page, page_size) return Page(total.get("count"), list(map(post_records_handler, result)), current_page, page_size)

View File

@ -11,7 +11,7 @@ class Page(dict):
""" """
def __init__(self, total: int, records: List, current_page: int, page_size: int, **kwargs): def __init__(self, total: int, records: List, current_page: int, page_size: int, **kwargs):
super().__init__(**{'total': total, 'records': records, 'current_page': current_page, 'page_size': page_size}) super().__init__(**{'total': total, 'records': records, 'current': current_page, 'size': page_size})
class Result(JsonResponse): class Result(JsonResponse):
@ -71,12 +71,12 @@ def get_page_api_response(response_data_schema: openapi.Schema):
default=1, default=1,
description="数据总条数"), description="数据总条数"),
"records": response_data_schema, "records": response_data_schema,
"current_page": openapi.Schema( "current": openapi.Schema(
type=openapi.TYPE_INTEGER, type=openapi.TYPE_INTEGER,
title="当前页", title="当前页",
default=1, default=1,
description="当前页"), description="当前页"),
"page_size": openapi.Schema( "size": openapi.Schema(
type=openapi.TYPE_INTEGER, type=openapi.TYPE_INTEGER,
title="每页大小", title="每页大小",
default=10, default=10,

View File

@ -11,6 +11,7 @@ import uuid
from functools import reduce from functools import reduce
from typing import Dict from typing import Dict
from django.contrib.postgres.fields import ArrayField
from django.core import validators from django.core import validators
from django.db import transaction, models from django.db import transaction, models
from django.db.models import QuerySet from django.db.models import QuerySet
@ -23,6 +24,7 @@ from common.mixins.api_mixin import ApiMixin
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from dataset.models.data_set import DataSet, Document, Paragraph from dataset.models.data_set import DataSet, Document, Paragraph
from dataset.serializers.document_serializers import CreateDocumentSerializers from dataset.serializers.document_serializers import CreateDocumentSerializers
from setting.models import AuthOperate
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
from users.models import User from users.models import User
@ -73,17 +75,38 @@ class DataSetSerializers(serializers.ModelSerializer):
message="数据集名称在1-256个字符之间") message="数据集名称在1-256个字符之间")
]) ])
user_id = serializers.CharField(required=True)
def get_query_set(self): def get_query_set(self):
user_id = self.data.get("user_id")
query_set_dict = {}
query_set = QuerySet(model=get_dynamics_model( query_set = QuerySet(model=get_dynamics_model(
{'dataset.name': models.CharField(), 'dataset.desc': models.CharField(), {'dataset.name': models.CharField(), 'dataset.desc': models.CharField(),
"document_temp.char_length": models.IntegerField()})) "document_temp.char_length": models.IntegerField()}))
if "desc" in self.data: if "desc" in self.data:
query_string = {'dataset.desc__contains', self.data.get("desc")} query_set = query_set.filter(**{'dataset.desc__contains': self.data.get("desc")})
query_set = query_set.filter(query_string)
if "name" in self.data: if "name" in self.data:
query_string = {'dataset.name__contains', self.data.get("name")} query_set = query_set.filter(**{'dataset.name__contains': self.data.get("name")})
query_set = query_set.filter(query_string)
return query_set query_set_dict['default_sql'] = query_set
query_set_dict['dataset_custom_sql'] = QuerySet(model=get_dynamics_model(
{'dataset.user_id': models.CharField(),
})).filter(
**{'dataset.user_id': user_id}
)
query_set_dict['team_member_permission_custom_sql'] = QuerySet(model=get_dynamics_model(
{'user_id': models.CharField(),
'team_member_permission.operate': ArrayField(verbose_name="权限操作列表",
base_field=models.CharField(max_length=256,
blank=True,
choices=AuthOperate.choices,
default=AuthOperate.USE)
)})).filter(
**{'user_id': user_id, 'team_member_permission.operate__contains': ['USE']})
return query_set_dict
def page(self, current_page: int, page_size: int): def page(self, current_page: int, page_size: int):
return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content( return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content(
@ -200,18 +223,32 @@ class DataSetSerializers(serializers.ModelSerializer):
dataset.delete() dataset.delete()
return True return True
def one(self, with_valid=True): def one(self, user_id, with_valid=True):
if with_valid: if with_valid:
self.is_valid() self.is_valid()
query_string = {'dataset.id', self.data.get("id")} query_set_dict = {'default_sql': QuerySet(model=get_dynamics_model(
query_set = QuerySet(model=get_dynamics_model( {'temp.id': models.UUIDField()})).filter(**{'temp.id': self.data.get("id")}),
{'dataset.id': models.UUIDField()})).filter(query_string) 'dataset_custom_sql': QuerySet(model=get_dynamics_model(
return native_search(query_set, select_string=get_file_content( {'dataset.user_id': models.CharField()})).filter(
**{'dataset.user_id': user_id}
), 'team_member_permission_custom_sql': QuerySet(
model=get_dynamics_model({'user_id': models.CharField(),
'team_member_permission.operate': ArrayField(
verbose_name="权限操作列表",
base_field=models.CharField(max_length=256,
blank=True,
choices=AuthOperate.choices,
default=AuthOperate.USE)
)})).filter(
**{'user_id': user_id, 'team_member_permission.operate__contains': ['USE']})}
return native_search(query_set_dict, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset.sql')), with_search_one=True) os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset.sql')), with_search_one=True)
def edit(self, dataset: Dict): def edit(self, dataset: Dict, user_id: str):
""" """
修改数据集 修改数据集
:param user_id: 用户id
:param dataset: Dict name desc :param dataset: Dict name desc
:return: :return:
""" """
@ -222,7 +259,7 @@ class DataSetSerializers(serializers.ModelSerializer):
if 'desc' in dataset: if 'desc' in dataset:
_dataset.desc = dataset.get("desc") _dataset.desc = dataset.get("desc")
_dataset.save() _dataset.save()
return self.one(with_valid=False) return self.one(with_valid=False, user_id=user_id)
@staticmethod @staticmethod
def get_request_body_api(): def get_request_body_api():

View File

@ -1,7 +1,30 @@
SELECT SELECT
dataset.*, *
document_temp."char_length",
"document_temp".document_count
FROM FROM
dataset dataset (
LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", dataset_id FROM "document" GROUP BY dataset_id ) "document_temp" ON dataset."id" = "document_temp".dataset_id SELECT
"temp_dataset".*,
"document_temp"."char_length",
"document_temp".document_count FROM (
SELECT dataset.*
FROM
dataset dataset
${dataset_custom_sql}
UNION
SELECT
*
FROM
dataset
WHERE
dataset."id" IN (
SELECT
team_member_permission.target
FROM
team_member team_member
LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id"
${team_member_permission_custom_sql}
)
) temp_dataset
LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", dataset_id FROM "document" GROUP BY dataset_id ) "document_temp" ON temp_dataset."id" = "document_temp".dataset_id
) temp
${default_sql}

View File

@ -29,7 +29,7 @@ class Dataset(APIView):
responses=get_api_response(DataSetSerializers.Query.get_response_body_api())) responses=get_api_response(DataSetSerializers.Query.get_response_body_api()))
@has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND) @has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
def get(self, request: Request): def get(self, request: Request):
d = DataSetSerializers.Query(data=request.query_params) d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)})
d.is_valid() d.is_valid()
return result.success(d.list()) return result.success(d.list())
@ -63,7 +63,7 @@ class Dataset(APIView):
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE, @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=keywords.get('dataset_id'))) dynamic_tag=keywords.get('dataset_id')))
def get(self, request: Request, dataset_id: str): def get(self, request: Request, dataset_id: str):
return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).one()) return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).one(user_id=request.user.id))
@action(methods="PUT", detail=False) @action(methods="PUT", detail=False)
@swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息", @swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息",
@ -72,7 +72,8 @@ class Dataset(APIView):
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=keywords.get('dataset_id'))) dynamic_tag=keywords.get('dataset_id')))
def put(self, request: Request, dataset_id: str): def put(self, request: Request, dataset_id: str):
return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).edit(request.data)) return result.success(
DataSetSerializers.Operate(data={'id': dataset_id}).edit(request.data, user_id=request.user.id))
class Page(APIView): class Page(APIView):
authentication_classes = [TokenAuth] authentication_classes = [TokenAuth]
@ -85,6 +86,6 @@ class Dataset(APIView):
responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api())) responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api()))
@has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND) @has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND)
def get(self, request: Request, current_page, page_size): def get(self, request: Request, current_page, page_size):
d = DataSetSerializers.Query(data=request.query_params) d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)})
d.is_valid() d.is_valid()
return result.success(d.page(current_page, page_size)) return result.success(d.page(current_page, page_size))