feat: 支持复杂sql查询,修改数据集查询
This commit is contained in:
parent
1e3097fa3f
commit
64e93679f8
@ -6,8 +6,9 @@
|
|||||||
@date:2023/10/7 18:20
|
@date:2023/10/7 18:20
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
from django.db import DEFAULT_DB_ALIAS, models
|
from django.db import DEFAULT_DB_ALIAS, models, connections
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
|
|
||||||
from common.db.compiler import AppSQLCompiler
|
from common.db.compiler import AppSQLCompiler
|
||||||
@ -30,8 +31,61 @@ def get_dynamics_model(attr: dict, table_name='dynamics'):
|
|||||||
return type('Dynamics', (models.Model,), attributes)
|
return type('Dynamics', (models.Model,), attributes)
|
||||||
|
|
||||||
|
|
||||||
def native_search(queryset: QuerySet, select_string: str,
|
def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string: str,
|
||||||
field_replace_dict=None,
|
field_replace_dict: None | Dict[str, Dict[str, str]] = None):
|
||||||
|
"""
|
||||||
|
生成 查询sql
|
||||||
|
:param queryset_dict: 多条件 查询条件
|
||||||
|
:param select_string: 查询sql
|
||||||
|
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
|
||||||
|
:return: sql:需要查询的sql params: sql 参数
|
||||||
|
"""
|
||||||
|
|
||||||
|
params_dict: Dict[int, Any] = {}
|
||||||
|
result_params = []
|
||||||
|
for key in queryset_dict.keys():
|
||||||
|
value = queryset_dict.get(key)
|
||||||
|
sql, params = compiler_queryset(value, None if field_replace_dict is None else field_replace_dict.get(key))
|
||||||
|
params_dict = {**params_dict, select_string.index("${" + key + "}"): params}
|
||||||
|
select_string = select_string.replace("${" + key + "}", sql)
|
||||||
|
|
||||||
|
for key in sorted(list(params_dict.keys())):
|
||||||
|
result_params = [*result_params, *params_dict.get(key)]
|
||||||
|
return select_string, result_params
|
||||||
|
|
||||||
|
|
||||||
|
def generate_sql_by_query(queryset: QuerySet, select_string: str,
|
||||||
|
field_replace_dict: None | Dict[str, str] = None):
|
||||||
|
"""
|
||||||
|
生成 查询sql
|
||||||
|
:param queryset: 查询条件
|
||||||
|
:param select_string: 原始sql
|
||||||
|
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
|
||||||
|
:return: sql:需要查询的sql params: sql 参数
|
||||||
|
"""
|
||||||
|
sql, params = compiler_queryset(queryset, field_replace_dict)
|
||||||
|
return select_string + " " + sql, params
|
||||||
|
|
||||||
|
|
||||||
|
def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, str] = None):
|
||||||
|
"""
|
||||||
|
解析 queryset查询对象
|
||||||
|
:param queryset: 查询对象
|
||||||
|
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
|
||||||
|
:return: sql:需要查询的sql params: sql 参数
|
||||||
|
"""
|
||||||
|
q = queryset.query
|
||||||
|
compiler = q.get_compiler(DEFAULT_DB_ALIAS)
|
||||||
|
if field_replace_dict is None:
|
||||||
|
field_replace_dict = get_field_replace_dict(queryset)
|
||||||
|
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
|
||||||
|
field_replace_dict=field_replace_dict)
|
||||||
|
sql, params = app_sql_compiler.get_query_str(with_table_name=False)
|
||||||
|
return sql, params
|
||||||
|
|
||||||
|
|
||||||
|
def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
|
||||||
|
field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None,
|
||||||
with_search_one=False):
|
with_search_one=False):
|
||||||
"""
|
"""
|
||||||
复杂查询
|
复杂查询
|
||||||
@ -41,19 +95,14 @@ def native_search(queryset: QuerySet, select_string: str,
|
|||||||
:param with_search_one: 查询
|
:param with_search_one: 查询
|
||||||
:return: 查询结果
|
:return: 查询结果
|
||||||
"""
|
"""
|
||||||
if field_replace_dict is None:
|
if isinstance(queryset, Dict):
|
||||||
field_replace_dict = get_field_replace_dict(queryset)
|
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict)
|
||||||
q = queryset.query
|
|
||||||
compiler = q.get_compiler(DEFAULT_DB_ALIAS)
|
|
||||||
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
|
|
||||||
field_replace_dict=field_replace_dict)
|
|
||||||
sql, params = app_sql_compiler.get_query_str(with_table_name=False)
|
|
||||||
if with_search_one:
|
|
||||||
return select_one(select_string + " " +
|
|
||||||
sql, params)
|
|
||||||
else:
|
else:
|
||||||
return select_list(select_string + " " +
|
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict)
|
||||||
sql, params)
|
if with_search_one:
|
||||||
|
return select_one(exec_sql, exec_params)
|
||||||
|
else:
|
||||||
|
return select_list(exec_sql, exec_params)
|
||||||
|
|
||||||
|
|
||||||
def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler):
|
def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler):
|
||||||
@ -70,7 +119,7 @@ def page_search(current_page: int, page_size: int, queryset: QuerySet, post_reco
|
|||||||
return Page(total, list(map(post_records_handler, result)), current_page, page_size)
|
return Page(total, list(map(post_records_handler, result)), current_page, page_size)
|
||||||
|
|
||||||
|
|
||||||
def native_page_search(current_page: int, page_size: int, queryset: QuerySet, select_string: str,
|
def native_page_search(current_page: int, page_size: int, queryset: QuerySet | Dict[str, QuerySet], select_string: str,
|
||||||
field_replace_dict=None,
|
field_replace_dict=None,
|
||||||
post_records_handler=lambda r: r):
|
post_records_handler=lambda r: r):
|
||||||
"""
|
"""
|
||||||
@ -83,20 +132,17 @@ def native_page_search(current_page: int, page_size: int, queryset: QuerySet, se
|
|||||||
:param post_records_handler: 数据row处理器
|
:param post_records_handler: 数据row处理器
|
||||||
:return: 分页结果
|
:return: 分页结果
|
||||||
"""
|
"""
|
||||||
if field_replace_dict is None:
|
if isinstance(queryset, Dict):
|
||||||
field_replace_dict = get_field_replace_dict(queryset)
|
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict)
|
||||||
q = queryset.query
|
else:
|
||||||
compiler = q.get_compiler(DEFAULT_DB_ALIAS)
|
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict)
|
||||||
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
|
total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql
|
||||||
field_replace_dict=field_replace_dict)
|
total = select_one(total_sql, exec_params)
|
||||||
page_sql, params = app_sql_compiler.get_query_str(with_table_name=False)
|
limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql(
|
||||||
total_sql = "SELECT \"count\"(*) FROM (%s) temp" % (select_string + " " + page_sql)
|
((current_page - 1) * page_size), (current_page * page_size)
|
||||||
total = select_one(total_sql, params)
|
)
|
||||||
q.set_limits(((current_page - 1) * page_size), (current_page * page_size))
|
page_sql = exec_sql + " " + limit_sql
|
||||||
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
|
result = select_list(page_sql, exec_params)
|
||||||
field_replace_dict=field_replace_dict)
|
|
||||||
page_sql, params = app_sql_compiler.get_query_str(with_table_name=False)
|
|
||||||
result = select_list(select_string + " " + page_sql, params)
|
|
||||||
return Page(total.get("count"), list(map(post_records_handler, result)), current_page, page_size)
|
return Page(total.get("count"), list(map(post_records_handler, result)), current_page, page_size)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -11,7 +11,7 @@ class Page(dict):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, total: int, records: List, current_page: int, page_size: int, **kwargs):
|
def __init__(self, total: int, records: List, current_page: int, page_size: int, **kwargs):
|
||||||
super().__init__(**{'total': total, 'records': records, 'current_page': current_page, 'page_size': page_size})
|
super().__init__(**{'total': total, 'records': records, 'current': current_page, 'size': page_size})
|
||||||
|
|
||||||
|
|
||||||
class Result(JsonResponse):
|
class Result(JsonResponse):
|
||||||
@ -71,12 +71,12 @@ def get_page_api_response(response_data_schema: openapi.Schema):
|
|||||||
default=1,
|
default=1,
|
||||||
description="数据总条数"),
|
description="数据总条数"),
|
||||||
"records": response_data_schema,
|
"records": response_data_schema,
|
||||||
"current_page": openapi.Schema(
|
"current": openapi.Schema(
|
||||||
type=openapi.TYPE_INTEGER,
|
type=openapi.TYPE_INTEGER,
|
||||||
title="当前页",
|
title="当前页",
|
||||||
default=1,
|
default=1,
|
||||||
description="当前页"),
|
description="当前页"),
|
||||||
"page_size": openapi.Schema(
|
"size": openapi.Schema(
|
||||||
type=openapi.TYPE_INTEGER,
|
type=openapi.TYPE_INTEGER,
|
||||||
title="每页大小",
|
title="每页大小",
|
||||||
default=10,
|
default=10,
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import uuid
|
|||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
from django.contrib.postgres.fields import ArrayField
|
||||||
from django.core import validators
|
from django.core import validators
|
||||||
from django.db import transaction, models
|
from django.db import transaction, models
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
@ -23,6 +24,7 @@ from common.mixins.api_mixin import ApiMixin
|
|||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from dataset.models.data_set import DataSet, Document, Paragraph
|
from dataset.models.data_set import DataSet, Document, Paragraph
|
||||||
from dataset.serializers.document_serializers import CreateDocumentSerializers
|
from dataset.serializers.document_serializers import CreateDocumentSerializers
|
||||||
|
from setting.models import AuthOperate
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
from users.models import User
|
from users.models import User
|
||||||
|
|
||||||
@ -73,17 +75,38 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
message="数据集名称在1-256个字符之间")
|
message="数据集名称在1-256个字符之间")
|
||||||
])
|
])
|
||||||
|
|
||||||
|
user_id = serializers.CharField(required=True)
|
||||||
|
|
||||||
def get_query_set(self):
|
def get_query_set(self):
|
||||||
|
user_id = self.data.get("user_id")
|
||||||
|
query_set_dict = {}
|
||||||
query_set = QuerySet(model=get_dynamics_model(
|
query_set = QuerySet(model=get_dynamics_model(
|
||||||
{'dataset.name': models.CharField(), 'dataset.desc': models.CharField(),
|
{'dataset.name': models.CharField(), 'dataset.desc': models.CharField(),
|
||||||
"document_temp.char_length": models.IntegerField()}))
|
"document_temp.char_length": models.IntegerField()}))
|
||||||
if "desc" in self.data:
|
if "desc" in self.data:
|
||||||
query_string = {'dataset.desc__contains', self.data.get("desc")}
|
query_set = query_set.filter(**{'dataset.desc__contains': self.data.get("desc")})
|
||||||
query_set = query_set.filter(query_string)
|
|
||||||
if "name" in self.data:
|
if "name" in self.data:
|
||||||
query_string = {'dataset.name__contains', self.data.get("name")}
|
query_set = query_set.filter(**{'dataset.name__contains': self.data.get("name")})
|
||||||
query_set = query_set.filter(query_string)
|
|
||||||
return query_set
|
query_set_dict['default_sql'] = query_set
|
||||||
|
|
||||||
|
query_set_dict['dataset_custom_sql'] = QuerySet(model=get_dynamics_model(
|
||||||
|
{'dataset.user_id': models.CharField(),
|
||||||
|
})).filter(
|
||||||
|
**{'dataset.user_id': user_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
query_set_dict['team_member_permission_custom_sql'] = QuerySet(model=get_dynamics_model(
|
||||||
|
{'user_id': models.CharField(),
|
||||||
|
'team_member_permission.operate': ArrayField(verbose_name="权限操作列表",
|
||||||
|
base_field=models.CharField(max_length=256,
|
||||||
|
blank=True,
|
||||||
|
choices=AuthOperate.choices,
|
||||||
|
default=AuthOperate.USE)
|
||||||
|
)})).filter(
|
||||||
|
**{'user_id': user_id, 'team_member_permission.operate__contains': ['USE']})
|
||||||
|
|
||||||
|
return query_set_dict
|
||||||
|
|
||||||
def page(self, current_page: int, page_size: int):
|
def page(self, current_page: int, page_size: int):
|
||||||
return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content(
|
return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content(
|
||||||
@ -200,18 +223,32 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
dataset.delete()
|
dataset.delete()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def one(self, with_valid=True):
|
def one(self, user_id, with_valid=True):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid()
|
self.is_valid()
|
||||||
query_string = {'dataset.id', self.data.get("id")}
|
query_set_dict = {'default_sql': QuerySet(model=get_dynamics_model(
|
||||||
query_set = QuerySet(model=get_dynamics_model(
|
{'temp.id': models.UUIDField()})).filter(**{'temp.id': self.data.get("id")}),
|
||||||
{'dataset.id': models.UUIDField()})).filter(query_string)
|
'dataset_custom_sql': QuerySet(model=get_dynamics_model(
|
||||||
return native_search(query_set, select_string=get_file_content(
|
{'dataset.user_id': models.CharField()})).filter(
|
||||||
|
**{'dataset.user_id': user_id}
|
||||||
|
), 'team_member_permission_custom_sql': QuerySet(
|
||||||
|
model=get_dynamics_model({'user_id': models.CharField(),
|
||||||
|
'team_member_permission.operate': ArrayField(
|
||||||
|
verbose_name="权限操作列表",
|
||||||
|
base_field=models.CharField(max_length=256,
|
||||||
|
blank=True,
|
||||||
|
choices=AuthOperate.choices,
|
||||||
|
default=AuthOperate.USE)
|
||||||
|
)})).filter(
|
||||||
|
**{'user_id': user_id, 'team_member_permission.operate__contains': ['USE']})}
|
||||||
|
|
||||||
|
return native_search(query_set_dict, select_string=get_file_content(
|
||||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset.sql')), with_search_one=True)
|
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset.sql')), with_search_one=True)
|
||||||
|
|
||||||
def edit(self, dataset: Dict):
|
def edit(self, dataset: Dict, user_id: str):
|
||||||
"""
|
"""
|
||||||
修改数据集
|
修改数据集
|
||||||
|
:param user_id: 用户id
|
||||||
:param dataset: Dict name desc
|
:param dataset: Dict name desc
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
@ -222,7 +259,7 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
if 'desc' in dataset:
|
if 'desc' in dataset:
|
||||||
_dataset.desc = dataset.get("desc")
|
_dataset.desc = dataset.get("desc")
|
||||||
_dataset.save()
|
_dataset.save()
|
||||||
return self.one(with_valid=False)
|
return self.one(with_valid=False, user_id=user_id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_request_body_api():
|
def get_request_body_api():
|
||||||
|
|||||||
@ -1,7 +1,30 @@
|
|||||||
SELECT
|
SELECT
|
||||||
dataset.*,
|
*
|
||||||
document_temp."char_length",
|
|
||||||
"document_temp".document_count
|
|
||||||
FROM
|
FROM
|
||||||
dataset dataset
|
(
|
||||||
LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", dataset_id FROM "document" GROUP BY dataset_id ) "document_temp" ON dataset."id" = "document_temp".dataset_id
|
SELECT
|
||||||
|
"temp_dataset".*,
|
||||||
|
"document_temp"."char_length",
|
||||||
|
"document_temp".document_count FROM (
|
||||||
|
SELECT dataset.*
|
||||||
|
FROM
|
||||||
|
dataset dataset
|
||||||
|
${dataset_custom_sql}
|
||||||
|
UNION
|
||||||
|
SELECT
|
||||||
|
*
|
||||||
|
FROM
|
||||||
|
dataset
|
||||||
|
WHERE
|
||||||
|
dataset."id" IN (
|
||||||
|
SELECT
|
||||||
|
team_member_permission.target
|
||||||
|
FROM
|
||||||
|
team_member team_member
|
||||||
|
LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id"
|
||||||
|
${team_member_permission_custom_sql}
|
||||||
|
)
|
||||||
|
) temp_dataset
|
||||||
|
LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", dataset_id FROM "document" GROUP BY dataset_id ) "document_temp" ON temp_dataset."id" = "document_temp".dataset_id
|
||||||
|
) temp
|
||||||
|
${default_sql}
|
||||||
@ -29,7 +29,7 @@ class Dataset(APIView):
|
|||||||
responses=get_api_response(DataSetSerializers.Query.get_response_body_api()))
|
responses=get_api_response(DataSetSerializers.Query.get_response_body_api()))
|
||||||
@has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
|
@has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
|
||||||
def get(self, request: Request):
|
def get(self, request: Request):
|
||||||
d = DataSetSerializers.Query(data=request.query_params)
|
d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)})
|
||||||
d.is_valid()
|
d.is_valid()
|
||||||
return result.success(d.list())
|
return result.success(d.list())
|
||||||
|
|
||||||
@ -63,7 +63,7 @@ class Dataset(APIView):
|
|||||||
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE,
|
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||||
dynamic_tag=keywords.get('dataset_id')))
|
dynamic_tag=keywords.get('dataset_id')))
|
||||||
def get(self, request: Request, dataset_id: str):
|
def get(self, request: Request, dataset_id: str):
|
||||||
return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).one())
|
return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).one(user_id=request.user.id))
|
||||||
|
|
||||||
@action(methods="PUT", detail=False)
|
@action(methods="PUT", detail=False)
|
||||||
@swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息",
|
@swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息",
|
||||||
@ -72,7 +72,8 @@ class Dataset(APIView):
|
|||||||
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||||
dynamic_tag=keywords.get('dataset_id')))
|
dynamic_tag=keywords.get('dataset_id')))
|
||||||
def put(self, request: Request, dataset_id: str):
|
def put(self, request: Request, dataset_id: str):
|
||||||
return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).edit(request.data))
|
return result.success(
|
||||||
|
DataSetSerializers.Operate(data={'id': dataset_id}).edit(request.data, user_id=request.user.id))
|
||||||
|
|
||||||
class Page(APIView):
|
class Page(APIView):
|
||||||
authentication_classes = [TokenAuth]
|
authentication_classes = [TokenAuth]
|
||||||
@ -85,6 +86,6 @@ class Dataset(APIView):
|
|||||||
responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api()))
|
responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api()))
|
||||||
@has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND)
|
@has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND)
|
||||||
def get(self, request: Request, current_page, page_size):
|
def get(self, request: Request, current_page, page_size):
|
||||||
d = DataSetSerializers.Query(data=request.query_params)
|
d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)})
|
||||||
d.is_valid()
|
d.is_valid()
|
||||||
return result.success(d.page(current_page, page_size))
|
return result.success(d.page(current_page, page_size))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user