feat: 数据集修改关联应用

This commit is contained in:
shaohuzhang1 2023-12-04 16:32:50 +08:00
parent a4c1125fbe
commit e11a8946ac
7 changed files with 168 additions and 13 deletions

View File

@ -299,7 +299,7 @@ class ApplicationSerializer(serializers.Serializer):
if 'dataset_id_list' in instance:
dataset_id_list = instance.get('dataset_id_list')
# 当前用户可修改关联的数据集列表
application_dataset_id_list = [dataset_dict.get('id') for dataset_dict in
application_dataset_id_list = [str(dataset_dict.get('id')) for dataset_dict in
self.list_dataset(with_valid=False)]
for dataset_id in dataset_id_list:
if not application_dataset_id_list.__contains__(dataset_id):

View File

@ -17,7 +17,9 @@ from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
from application.models import ApplicationDatasetMapping
from common.db.search import get_dynamics_model, native_page_search, native_search
from common.db.sql_execute import select_list
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
@ -55,6 +57,48 @@ class DataSetSerializers(serializers.ModelSerializer):
model = DataSet
fields = ['id', 'name', 'desc', 'create_time', 'update_time']
class Application(ApiMixin, serializers.Serializer):
user_id = serializers.UUIDField(required=True)
dataset_id = serializers.UUIDField(required=True)
@staticmethod
def get_request_params_api():
return [
openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='数据集id')
]
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['id', 'name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'user_id', 'status',
'create_time',
'update_time'],
properties={
'id': openapi.Schema(type=openapi.TYPE_STRING, title="", description="主键id"),
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
description="是否开启多轮对话"),
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
title="示例列表", description="示例列表"),
'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户", description="所属用户"),
'status': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否发布", description='是否发布'),
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description='修改时间')
}
)
class Query(ApiMixin, serializers.Serializer):
"""
查询对象
@ -223,8 +267,15 @@ class DataSetSerializers(serializers.ModelSerializer):
}
)
class Edit(serializers.Serializer):
name = serializers.CharField(required=False)
desc = serializers.CharField(required=False)
application_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
class Operate(ApiMixin, serializers.Serializer):
id = serializers.CharField(required=True)
user_id = serializers.UUIDField(required=False)
def is_valid(self, *, raise_exception=True):
super().is_valid(raise_exception=True)
@ -242,6 +293,14 @@ class DataSetSerializers(serializers.ModelSerializer):
ListenerManagement.delete_embedding_by_dataset_signal.send(self.data.get('id'))
return True
def list_application(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
dataset = QuerySet(DataSet).get(id=self.data.get("id"))
return select_list(get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset_application.sql')),
[self.data.get('user_id'), dataset.user_id, self.data.get('user_id')])
def one(self, user_id, with_valid=True):
if with_valid:
self.is_valid()
@ -260,9 +319,15 @@ class DataSetSerializers(serializers.ModelSerializer):
default=AuthOperate.USE)
)})).filter(
**{'user_id': user_id, 'team_member_permission.operate__contains': ['USE']})}
return native_search(query_set_dict, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset.sql')), with_search_one=True)
all_application_list = [str(adm.get('id')) for adm in self.list_application(with_valid=False)]
return {**native_search(query_set_dict, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset.sql')), with_search_one=True),
'application_id_list': list(
filter(lambda application_id: all_application_list.__contains__(application_id),
[str(application_dataset_mapping.application_id) for
application_dataset_mapping in
QuerySet(ApplicationDatasetMapping).filter(
dataset_id=self.data.get('id'))]))}
def edit(self, dataset: Dict, user_id: str):
"""
@ -272,11 +337,32 @@ class DataSetSerializers(serializers.ModelSerializer):
:return:
"""
self.is_valid()
DataSetSerializers.Edit(data=dataset).is_valid(raise_exception=True)
_dataset = QuerySet(DataSet).get(id=self.data.get("id"))
if "name" in dataset:
_dataset.name = dataset.get("name")
if 'desc' in dataset:
_dataset.desc = dataset.get("desc")
if 'application_id_list' in dataset and dataset.get('application_id_list') is not None:
application_id_list = dataset.get('application_id_list')
# 当前用户可修改关联的数据集列表
application_dataset_id_list = [str(dataset_dict.get('id')) for dataset_dict in
self.list_application(with_valid=False)]
for dataset_id in application_id_list:
if not application_dataset_id_list.__contains__(dataset_id):
raise AppApiException(500, f"未知的应用id${dataset_id},无法关联")
# 删除已经关联的id
QuerySet(ApplicationDatasetMapping).filter(application_id__in=application_dataset_id_list,
dataset_id=self.data.get("id")).delete()
# 插入
QuerySet(ApplicationDatasetMapping).bulk_create(
[ApplicationDatasetMapping(application_id=application_id, dataset_id=self.data.get('id')) for
application_id in
application_id_list]) if len(application_id_list) > 0 else None
[ApplicationDatasetMapping(application_id=application_id, dataset_id=self.data.get('id')) for
application_id in application_id_list]
_dataset.save()
return self.one(with_valid=False, user_id=user_id)
@ -287,7 +373,10 @@ class DataSetSerializers(serializers.ModelSerializer):
required=['name', 'desc'],
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="数据集名称", description="数据集名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="数据集描述", description="数据集描述")
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="数据集描述", description="数据集描述"),
'application_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="应用id列表",
description="应用id列表",
items=openapi.Schema(type=openapi.TYPE_STRING))
}
)

View File

@ -0,0 +1,20 @@
SELECT
*
FROM
application
WHERE
user_id = %s UNION
SELECT
*
FROM
application
WHERE
"id" IN (
SELECT
team_member_permission.target
FROM
team_member team_member
LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id"
WHERE
( "team_member_permission"."auth_target_type" = 'APPLICATION' AND "team_member_permission"."operate"::text[] @> ARRAY['USE'] AND team_member.team_id = %s AND team_member.user_id =%s )
)

View File

@ -6,6 +6,7 @@ app_name = "dataset"
urlpatterns = [
path('dataset', views.Dataset.as_view(), name="dataset"),
path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
path('dataset/<str:dataset_id>/application', views.Dataset.Application.as_view()),
path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),

View File

@ -22,6 +22,20 @@ from dataset.serializers.dataset_serializers import DataSetSerializers
class Dataset(APIView):
authentication_classes = [TokenAuth]
class Application(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取数据集可用应用列表",
operation_id="获取数据集可用应用列表",
manual_parameters=DataSetSerializers.Application.get_request_params_api(),
responses=result.get_api_array_response(
DataSetSerializers.Application.get_response_body_api()),
tags=["数据集"])
def get(self, request: Request, dataset_id: str):
return result.success(DataSetSerializers.Operate(
data={'id': dataset_id, 'user_id': str(request.user.id)}).list_application())
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取数据集列表",
operation_id="获取数据集列表",
@ -71,7 +85,8 @@ class Dataset(APIView):
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=keywords.get('dataset_id')))
def get(self, request: Request, dataset_id: str):
return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).one(user_id=request.user.id))
return result.success(DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).one(
user_id=request.user.id))
@action(methods="PUT", detail=False)
@swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息",
@ -84,7 +99,8 @@ class Dataset(APIView):
dynamic_tag=keywords.get('dataset_id')))
def put(self, request: Request, dataset_id: str):
return result.success(
DataSetSerializers.Operate(data={'id': dataset_id}).edit(request.data, user_id=request.user.id))
DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).edit(request.data,
user_id=request.user.id))
class Page(APIView):
authentication_classes = [TokenAuth]

View File

@ -2,6 +2,7 @@ import { Result } from '@/request/Result'
import { get, post, del, put } from '@/request/index'
import type { datasetData } from '@/api/type/dataset'
import type { pageRequest } from '@/api/type/common'
import type { ApplicationFormType } from '@/api/type/application'
import { type Ref } from 'vue'
const prefix = '/dataset'
@ -88,12 +89,24 @@ const putDateset: (dataset_id: string, data: any) => Promise<Result<any>> = (
) => {
return put(`${prefix}/${dataset_id}`, data)
}
/**
*
* @param dataset_id
* @param loading
* @returns
*/
const listUsableApplication: (
dataset_id: string,
loading?: Ref<boolean>
) => Promise<Result<Array<ApplicationFormType>>> = (dataset_id, loading) => {
return get(`${prefix}/${dataset_id}/application`, {}, loading)
}
export default {
getDateset,
getAllDateset,
delDateset,
postDateset,
getDatesetDetail,
putDateset
putDateset,
listUsableApplication
}

View File

@ -25,25 +25,37 @@
:autosize="{ minRows: 3 }"
/>
</el-form-item>
<el-form-item v-loading="loading">
<el-row justify="space-between" style="width: 100%">
<el-col :span="11" v-for="(item, index) in application_list" :key="index" class="mb-16">
<CardCheckbox value-field="id" :data="item" v-model="form.application_id_list">
{{ item.name }}
</CardCheckbox>
</el-col>
</el-row>
</el-form-item>
</el-form>
</template>
<script setup lang="ts">
import { ref, reactive, onMounted, onUnmounted, computed, watch } from 'vue'
import useStore from '@/stores'
import DatasetApi from '@/api/dataset'
import CardCheckbox from '@/components/card-checkbox/index.vue'
import type { ApplicationFormType } from '@/api/type/application'
const props = defineProps({
data: {
type: Object,
default: () => {}
}
})
const loading = ref<boolean>(false)
const { dataset } = useStore()
const baseInfo = computed(() => dataset.baseInfo)
const application_list = ref<Array<ApplicationFormType>>([])
const form = ref<any>({
name: '',
desc: ''
desc: '',
application_id_list: []
})
const rules = reactive({
@ -58,6 +70,10 @@ watch(
if (value && JSON.stringify(value) !== '{}') {
form.value.name = value.name
form.value.desc = value.desc
form.value.application_id_list = value.application_id_list
DatasetApi.listUsableApplication(value.id, loading).then((ok) => {
application_list.value = ok.data
})
}
},
{