feat: Application import and export (#1836)

This commit is contained in:
shaohuzhang1 2024-12-16 14:19:57 +08:00 committed by GitHub
parent 390014fa1b
commit 64443ee136
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 260 additions and 24 deletions

View File

@ -10,6 +10,7 @@ import datetime
import hashlib import hashlib
import json import json
import os import os
import pickle
import re import re
import uuid import uuid
from functools import reduce from functools import reduce
@ -19,10 +20,10 @@ from django.contrib.postgres.fields import ArrayField
from django.core import cache, validators from django.core import cache, validators
from django.core import signing from django.core import signing
from django.db import transaction, models from django.db import transaction, models
from django.db.models import QuerySet, Q from django.db.models import QuerySet
from django.http import HttpResponse from django.http import HttpResponse
from django.template import Template, Context from django.template import Template, Context
from rest_framework import serializers from rest_framework import serializers, status
from application.flow.workflow_manage import Flow from application.flow.workflow_manage import Flow
from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion
@ -34,15 +35,17 @@ from common.constants.authentication_type import AuthenticationType
from common.db.search import get_dynamics_model, native_search, native_page_search from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed
from common.field.common import UploadedImageField from common.field.common import UploadedImageField, UploadedFileField
from common.models.db_model_manage import DBModelManage from common.models.db_model_manage import DBModelManage
from common.response import result
from common.util.common import valid_license, password_encrypt from common.util.common import valid_license, password_encrypt
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from dataset.models import DataSet, Document, Image from dataset.models import DataSet, Document, Image
from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list
from embedding.models import SearchMode from embedding.models import SearchMode
from function_lib.serializers.function_lib_serializer import FunctionLibSerializer from function_lib.models.function import FunctionLib, PermissionType
from function_lib.serializers.function_lib_serializer import FunctionLibSerializer, FunctionLibModelSerializer
from setting.models import AuthOperate from setting.models import AuthOperate
from setting.models.model_management import Model from setting.models.model_management import Model
from setting.models_provider import get_model_credential from setting.models_provider import get_model_credential
@ -54,6 +57,13 @@ from users.models import User
chat_cache = cache.caches['chat_cache'] chat_cache = cache.caches['chat_cache']
class MKInstance:
def __init__(self, application: dict, function_lib_list: List[dict], version: str):
self.application = application
self.function_lib_list = function_lib_list
self.version = version
class ModelDatasetAssociation(serializers.Serializer): class ModelDatasetAssociation(serializers.Serializer):
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
@ -662,6 +672,72 @@ class ApplicationSerializer(serializers.Serializer):
get_application_access_token(application_access_token.access_token, False) get_application_access_token(application_access_token.access_token, False)
return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data)} return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data)}
class Import(serializers.Serializer):
file = UploadedFileField(required=True, error_messages=ErrMessage.image("文件"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
@valid_license(model=Application, count=5,
message='社区版最多支持 5 个应用如需拥有更多应用请联系我们https://fit2cloud.com/)。')
@transaction.atomic
def import_(self, with_valid=True):
if with_valid:
self.is_valid()
user_id = self.data.get('user_id')
mk_instance_bytes = self.data.get('file').read()
mk_instance = pickle.loads(mk_instance_bytes)
application = mk_instance.application
function_lib_list = mk_instance.function_lib_list
if len(function_lib_list) > 0:
function_lib_id_list = [function_lib.get('id') for function_lib in function_lib_list]
exits_function_lib_id_list = [str(function_lib.id) for function_lib in
QuerySet(FunctionLib).filter(id__in=function_lib_id_list)]
# 获取到需要插入的函数
function_lib_list = [function_lib for function_lib in function_lib_list if
not exits_function_lib_id_list.__contains__(function_lib.get('id'))]
application_model = self.to_application(application, user_id)
function_lib_model_list = [self.to_function_lib(f, user_id) for f in function_lib_list]
application_model.save()
QuerySet(FunctionLib).bulk_create(function_lib_model_list) if len(function_lib_model_list) > 0 else None
return True
@staticmethod
def to_application(application, user_id):
work_flow = application.get('work_flow')
for node in work_flow.get('nodes', []):
if node.get('type') == 'search-dataset-node':
node.get('properties', {}).get('node_data', {})['dataset_id_list'] = []
return Application(id=uuid.uuid1(), user_id=user_id, name=application.get('name'),
desc=application.get('desc'),
prologue=application.get('prologue'), dialogue_number=application.get('dialogue_number'),
dataset_setting=application.get('dataset_setting'),
model_params_setting=application.get('model_params_setting'),
tts_model_params_setting=application.get('tts_model_params_setting'),
problem_optimization=application.get('problem_optimization'),
icon=application.get('icon'),
work_flow=work_flow,
type=application.get('type'),
problem_optimization_prompt=application.get('problem_optimization_prompt'),
tts_model_enable=application.get('tts_model_enable'),
stt_model_enable=application.get('stt_model_enable'),
tts_type=application.get('tts_type'),
clean_time=application.get('clean_time'),
file_upload_enable=application.get('file_upload_enable'),
file_upload_setting=application.get('file_upload_setting'),
)
@staticmethod
def to_function_lib(function_lib, user_id):
"""
@param user_id: 用户id
@param function_lib: 函数库
@return:
"""
return FunctionLib(id=function_lib.get('id'), user_id=user_id, name=function_lib.get('name'),
code=function_lib.get('code'), input_field_list=function_lib.get('input_field_list'),
is_active=function_lib.get('is_active'),
permission_type=PermissionType.PRIVATE)
class Operate(serializers.Serializer): class Operate(serializers.Serializer):
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id")) application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
@ -708,6 +784,31 @@ class ApplicationSerializer(serializers.Serializer):
QuerySet(Application).filter(id=self.data.get('application_id')).delete() QuerySet(Application).filter(id=self.data.get('application_id')).delete()
return True return True
def export(self, with_valid=True):
try:
if with_valid:
self.is_valid()
application_id = self.data.get('application_id')
application = QuerySet(Application).filter(id=application_id).first()
function_lib_id_list = [node.get('properties', {}).get('node_data', {}).get('function_lib_id') for node
in
application.work_flow.get('nodes', []) if
node.get('type') == 'function-lib-node']
function_lib_list = []
if len(function_lib_id_list) > 0:
function_lib_list = QuerySet(FunctionLib).filter(id__in=function_lib_id_list)
application_dict = ApplicationSerializerModel(application).data
mk_instance = MKInstance(application_dict,
[FunctionLibModelSerializer(function_lib).data for function_lib in
function_lib_list], 'v1')
application_pickle = pickle.dumps(mk_instance)
response = HttpResponse(content_type='text/plain', content=application_pickle)
response['Content-Disposition'] = f'attachment; filename="{application.name}.mk"'
return response
except Exception as e:
return result.error(str(e), response_status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@transaction.atomic @transaction.atomic
def publish(self, instance, with_valid=True): def publish(self, instance, with_valid=True):
if with_valid: if with_valid:

View File

@ -336,6 +336,27 @@ class ApplicationApi(ApiMixin):
description='应用描述') description='应用描述')
] ]
class Export(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),
]
class Import(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='file',
in_=openapi.IN_FORM,
type=openapi.TYPE_FILE,
required=True,
description='上传图片文件')
]
class Operate(ApiMixin): class Operate(ApiMixin):
@staticmethod @staticmethod
def get_request_params_api(): def get_request_params_api():

View File

@ -5,11 +5,13 @@ from . import views
app_name = "application" app_name = "application"
urlpatterns = [ urlpatterns = [
path('application', views.Application.as_view(), name="application"), path('application', views.Application.as_view(), name="application"),
path('application/import', views.Application.Import.as_view()),
path('application/profile', views.Application.Profile.as_view(), name='application/profile'), path('application/profile', views.Application.Profile.as_view(), name='application/profile'),
path('application/embed', views.Application.Embed.as_view()), path('application/embed', views.Application.Embed.as_view()),
path('application/authentication', views.Application.Authentication.as_view()), path('application/authentication', views.Application.Authentication.as_view()),
path('application/<str:application_id>/publish', views.Application.Publish.as_view()), path('application/<str:application_id>/publish', views.Application.Publish.as_view()),
path('application/<str:application_id>/edit_icon', views.Application.EditIcon.as_view()), path('application/<str:application_id>/edit_icon', views.Application.EditIcon.as_view()),
path('application/<str:application_id>/export', views.Application.Export.as_view()),
path('application/<str:application_id>/statistics/customer_count', path('application/<str:application_id>/statistics/customer_count',
views.ApplicationStatistics.CustomerCount.as_view()), views.ApplicationStatistics.CustomerCount.as_view()),
path('application/<str:application_id>/statistics/customer_count_trend', path('application/<str:application_id>/statistics/customer_count_trend',

View File

@ -27,7 +27,6 @@ from common.response import result
from common.swagger_api.common_api import CommonApi from common.swagger_api.common_api import CommonApi
from common.util.common import query_params_to_single_dict from common.util.common import query_params_to_single_dict
from dataset.serializers.dataset_serializers import DataSetSerializers from dataset.serializers.dataset_serializers import DataSetSerializers
from setting.swagger_api.provide_api import ProvideApi
chat_cache = cache.caches['chat_cache'] chat_cache = cache.caches['chat_cache']
@ -158,6 +157,34 @@ class Application(APIView):
data={'application_id': application_id, 'user_id': request.user.id, data={'application_id': application_id, 'user_id': request.user.id,
'image': request.FILES.get('file')}).edit(request.data)) 'image': request.FILES.get('file')}).edit(request.data))
class Import(APIView):
authentication_classes = [TokenAuth]
parser_classes = [MultiPartParser]
@action(methods="GET", detail=False)
@swagger_auto_schema(operation_summary="导入应用", operation_id="导入应用",
manual_parameters=ApplicationApi.Import.get_request_params_api(),
tags=["应用"]
)
@has_permissions(RoleConstants.ADMIN, RoleConstants.USER)
def post(self, request: Request):
return result.success(ApplicationSerializer.Import(
data={'user_id': request.user.id, 'file': request.FILES.get('file')}).import_())
class Export(APIView):
authentication_classes = [TokenAuth]
@action(methods="GET", detail=False)
@swagger_auto_schema(operation_summary="导出应用", operation_id="导出应用",
manual_parameters=ApplicationApi.Export.get_request_params_api(),
tags=["应用"]
)
@has_permissions(lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE,
dynamic_tag=keywords.get('application_id')))
def get(self, request: Request, application_id: str):
return ApplicationSerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id}).export()
class Embed(APIView): class Embed(APIView):
@action(methods=["GET"], detail=False) @action(methods=["GET"], detail=False)
@swagger_auto_schema(operation_summary="获取嵌入js", @swagger_auto_schema(operation_summary="获取嵌入js",
@ -362,7 +389,8 @@ class Application(APIView):
compare=CompareConstants.AND)) compare=CompareConstants.AND))
def put(self, request: Request, application_id: str): def put(self, request: Request, application_id: str):
return result.success( return result.success(
ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).edit(request.data)) ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).edit(
request.data))
@action(methods=['GET'], detail=False) @action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取应用 AccessToken信息", @swagger_auto_schema(operation_summary="获取应用 AccessToken信息",
@ -382,9 +410,10 @@ class Application(APIView):
class Authentication(APIView): class Authentication(APIView):
@action(methods=['OPTIONS'], detail=False) @action(methods=['OPTIONS'], detail=False)
def options(self, request, *args, **kwargs): def options(self, request, *args, **kwargs):
return HttpResponse(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true", return HttpResponse(
"Access-Control-Allow-Methods": "POST", headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true",
"Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}, ) "Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}, )
@action(methods=['POST'], detail=False) @action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="应用认证", @swagger_auto_schema(operation_summary="应用认证",
@ -404,6 +433,7 @@ class Application(APIView):
) )
@action(methods=['POST'], detail=False) @action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="创建应用", @swagger_auto_schema(operation_summary="创建应用",
operation_id="创建应用", operation_id="创建应用",
request_body=ApplicationApi.Create.get_request_body_api(), request_body=ApplicationApi.Create.get_request_body_api(),
@ -444,7 +474,8 @@ class Application(APIView):
"query_text": request.query_params.get("query_text"), "query_text": request.query_params.get("query_text"),
"top_number": request.query_params.get("top_number"), "top_number": request.query_params.get("top_number"),
'similarity': request.query_params.get('similarity'), 'similarity': request.query_params.get('similarity'),
'search_mode': request.query_params.get('search_mode')}).hit_test( 'search_mode': request.query_params.get(
'search_mode')}).hit_test(
)) ))
class Publish(APIView): class Publish(APIView):
@ -502,7 +533,8 @@ class Application(APIView):
compare=CompareConstants.AND)) compare=CompareConstants.AND))
def put(self, request: Request, application_id: str): def put(self, request: Request, application_id: str):
return result.success( return result.success(
ApplicationSerializer.Operate(data={'application_id': application_id, 'user_id': request.user.id}).edit( ApplicationSerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id}).edit(
request.data)) request.data))
@action(methods=['GET'], detail=False) @action(methods=['GET'], detail=False)
@ -528,11 +560,14 @@ class Application(APIView):
@swagger_auto_schema(operation_summary="获取当前应用可使用的知识库", @swagger_auto_schema(operation_summary="获取当前应用可使用的知识库",
operation_id="获取当前应用可使用的知识库", operation_id="获取当前应用可使用的知识库",
manual_parameters=ApplicationApi.Operate.get_request_params_api(), manual_parameters=ApplicationApi.Operate.get_request_params_api(),
responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()), responses=result.get_api_array_response(
DataSetSerializers.Query.get_response_body_api()),
tags=['应用']) tags=['应用'])
@has_permissions(ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], @has_permissions(ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, [lambda r, keywords: Permission(group=Group.APPLICATION,
dynamic_tag=keywords.get('application_id'))], operate=Operate.USE,
dynamic_tag=keywords.get(
'application_id'))],
compare=CompareConstants.AND)) compare=CompareConstants.AND))
def get(self, request: Request, application_id: str): def get(self, request: Request, application_id: str):
return result.success(ApplicationSerializer.Operate( return result.success(ApplicationSerializer.Operate(

View File

@ -157,10 +157,10 @@ def success(data, **kwargs):
return Result(data=data, **kwargs) return Result(data=data, **kwargs)
def error(message): def error(message, **kwargs):
""" """
获取一个失败的响应对象 获取一个失败的响应对象
:param message: 错误提示 :param message: 错误提示
:return: 接口响应对象 :return: 接口响应对象
""" """
return Result(code=500, message=message) return Result(code=500, message=message, **kwargs)

View File

@ -1,5 +1,5 @@
import { Result } from '@/request/Result' import { Result } from '@/request/Result'
import { get, post, postStream, del, put, request, download } from '@/request/index' import { get, post, postStream, del, put, request, download, exportFile } from '@/request/index'
import type { pageRequest } from '@/api/type/common' import type { pageRequest } from '@/api/type/common'
import type { ApplicationFormType } from '@/api/type/application' import type { ApplicationFormType } from '@/api/type/application'
import { type Ref } from 'vue' import { type Ref } from 'vue'
@ -300,7 +300,6 @@ const getApplicationTTIModel: (
return get(`${prefix}/${application_id}/model`, { model_type: 'TTI' }, loading) return get(`${prefix}/${application_id}/model`, { model_type: 'TTI' }, loading)
} }
/** /**
* *
* @param * @param
@ -377,7 +376,6 @@ const uploadFile: (
return post(`${prefix}/${application_id}/chat/${chat_id}/upload_file`, data, undefined, loading) return post(`${prefix}/${application_id}/chat/${chat_id}/upload_file`, data, undefined, loading)
} }
/** /**
* *
*/ */
@ -503,6 +501,28 @@ const getUserList: (type: string, loading?: Ref<boolean>) => Promise<Result<any>
return get(`/user/list/${type}`, undefined, loading) return get(`/user/list/${type}`, undefined, loading)
} }
const exportApplication = (
application_id: string,
application_name: string,
loading?: Ref<boolean>
) => {
return exportFile(
application_name + '.mk',
`/application/${application_id}/export`,
undefined,
loading
)
}
/**
*
*/
const importApplication: (data: any, loading?: Ref<boolean>) => Promise<Result<any>> = (
data,
loading
) => {
return post(`${prefix}/import`, data, undefined, loading)
}
export default { export default {
getAllAppilcation, getAllAppilcation,
getApplication, getApplication,
@ -544,5 +564,7 @@ export default {
playDemoText, playDemoText,
getUserList, getUserList,
getApplicationList, getApplicationList,
uploadFile uploadFile,
exportApplication,
importApplication
} }

View File

@ -227,7 +227,6 @@ export const exportExcel: (
) => { ) => {
return promise(request({ url: url, method: 'get', params, responseType: 'blob' }), loading).then( return promise(request({ url: url, method: 'get', params, responseType: 'blob' }), loading).then(
(res: any) => { (res: any) => {
console.log(res)
if (res) { if (res) {
const blob = new Blob([res], { const blob = new Blob([res], {
type: 'application/vnd.ms-excel' type: 'application/vnd.ms-excel'
@ -244,6 +243,35 @@ export const exportExcel: (
) )
} }
export const exportFile: (
fileName: string,
url: string,
params: any,
loading?: NProgress | Ref<boolean>
) => Promise<any> = (
fileName: string,
url: string,
params: any,
loading?: NProgress | Ref<boolean>
) => {
return promise(request({ url: url, method: 'get', params, responseType: 'blob' }), loading).then(
(res: any) => {
if (res) {
const blob = new Blob([res], {
type: 'application/octet-stream'
})
const link = document.createElement('a')
link.href = window.URL.createObjectURL(blob)
link.download = fileName
link.click()
//释放内存
window.URL.revokeObjectURL(link.href)
}
return true
}
)
}
export const exportExcelPost: ( export const exportExcelPost: (
fileName: string, fileName: string,
url: string, url: string,

View File

@ -3,6 +3,18 @@
<div class="flex-between mb-16"> <div class="flex-between mb-16">
<h4>{{ $t('views.application.applicationList.title') }}</h4> <h4>{{ $t('views.application.applicationList.title') }}</h4>
<div class="flex-between"> <div class="flex-between">
<el-upload
:file-list="[]"
class="flex-between"
action="#"
multiple
:auto-upload="false"
:show-file-list="false"
:limit="1"
:on-change="(file: any, fileList: any) => importApplication(file)"
>
<el-button>导入应用</el-button>
</el-upload>
<el-select <el-select
v-model="selectUserId" v-model="selectUserId"
class="mr-12" class="mr-12"
@ -128,7 +140,9 @@
<AppIcon iconName="app-copy"></AppIcon> <AppIcon iconName="app-copy"></AppIcon>
复制</el-dropdown-item 复制</el-dropdown-item
> >
<el-dropdown-item icon="Delete" @click.stop="exportApplication(item)">
导出
</el-dropdown-item>
<el-dropdown-item icon="Delete" @click.stop="deleteApplication(item)">{{ <el-dropdown-item icon="Delete" @click.stop="deleteApplication(item)">{{
$t('views.application.applicationList.card.delete.tooltip') $t('views.application.applicationList.card.delete.tooltip')
}}</el-dropdown-item> }}</el-dropdown-item>
@ -152,7 +166,7 @@ import { ref, onMounted, reactive } from 'vue'
import applicationApi from '@/api/application' import applicationApi from '@/api/application'
import CreateApplicationDialog from './component/CreateApplicationDialog.vue' import CreateApplicationDialog from './component/CreateApplicationDialog.vue'
import CopyApplicationDialog from './component/CopyApplicationDialog.vue' import CopyApplicationDialog from './component/CopyApplicationDialog.vue'
import { MsgSuccess, MsgConfirm, MsgAlert } from '@/utils/message' import { MsgSuccess, MsgConfirm, MsgAlert, MsgError } from '@/utils/message'
import { isAppIcon } from '@/utils/application' import { isAppIcon } from '@/utils/application'
import { useRouter } from 'vue-router' import { useRouter } from 'vue-router'
import { isWorkFlow } from '@/utils/application' import { isWorkFlow } from '@/utils/application'
@ -203,7 +217,20 @@ function settingApplication(row: any) {
router.push({ path: `/application/${row.id}/${row.type}/setting` }) router.push({ path: `/application/${row.id}/${row.type}/setting` })
} }
} }
const exportApplication = (application: any) => {
applicationApi.exportApplication(application.id, application.name, loading).catch((e) => {
e.response.data.text().then((res: string) => {
MsgError(`导出失败:${JSON.parse(res).message}`)
})
})
}
const importApplication = (file: any) => {
const formData = new FormData()
formData.append('file', file.raw, file.name)
applicationApi.importApplication(formData, loading).then((ok) => {
searchHandle()
})
}
function openCreateDialog() { function openCreateDialog() {
if (user.isEnterprise()) { if (user.isEnterprise()) {
CreateApplicationDialogRef.value.open() CreateApplicationDialogRef.value.open()