feat: 高级编排支持文件上传(WIP)

This commit is contained in:
CaptainB 2024-11-13 14:48:01 +08:00 committed by 刘瑞斌
parent 60097b4903
commit 88b6eebf35
12 changed files with 90 additions and 30 deletions

View File

@ -9,12 +9,7 @@ from common.util.field_message import ErrMessage
class DocumentExtractNodeSerializer(serializers.Serializer): class DocumentExtractNodeSerializer(serializers.Serializer):
# 需要查询的数据集id列表 document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档"))
file_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
error_messages=ErrMessage.list("数据集id列表"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
class IDocumentExtractNode(INode): class IDocumentExtractNode(INode):
@ -24,7 +19,9 @@ class IDocumentExtractNode(INode):
return DocumentExtractNodeSerializer return DocumentExtractNodeSerializer
def _run(self): def _run(self):
return self.execute(**self.flow_params_serializer.data) res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('document_list')[0],
self.node_params_serializer.data.get('document_list')[1:])
return self.execute(document=res, **self.flow_params_serializer.data)
def execute(self, file_list, **kwargs) -> NodeResult: def execute(self, document, **kwargs) -> NodeResult:
pass pass

View File

@ -1,11 +1,26 @@
# coding=utf-8 # coding=utf-8
from application.flow.i_step_node import NodeResult
from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode
class BaseDocumentExtractNode(IDocumentExtractNode): class BaseDocumentExtractNode(IDocumentExtractNode):
def execute(self, file_list, **kwargs): def execute(self, document, **kwargs):
pass self.context['document_list'] = document
content = ''
if len(document) > 0:
for doc in document:
content += doc['name']
content += '\n-----------------------------------\n'
return NodeResult({'content': content}, {})
def get_details(self, index: int, **kwargs): def get_details(self, index: int, **kwargs):
pass return {
'name': self.node.properties.get('stepName'),
"index": index,
'run_time': self.context.get('run_time'),
'type': self.node.type,
'content': self.context.get('content'),
'status': self.status,
'err_message': self.err_message,
'document_list': self.context.get('document_list')
}

View File

@ -18,7 +18,7 @@ class ImageUnderstandNodeSerializer(serializers.Serializer):
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片仅1张")) image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
class IImageUnderstandNode(INode): class IImageUnderstandNode(INode):

View File

@ -25,7 +25,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo
node.context['question'] = node_variable['question'] node.context['question'] = node_variable['question']
node.context['run_time'] = time.time() - node.context['start_time'] node.context['run_time'] = time.time() - node.context['start_time']
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
workflow.answer += answer node.answer_text = answer
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):

View File

@ -52,8 +52,12 @@ class BaseStartStepNode(IStarNode):
""" """
开始节点 初始化全局变量 开始节点 初始化全局变量
""" """
return NodeResult({'question': question, 'image': self.workflow_manage.image_list}, node_variable = {
workflow_variable) 'question': question,
'image': self.workflow_manage.image_list,
'document': self.workflow_manage.document_list
}
return NodeResult(node_variable, workflow_variable)
def get_details(self, index: int, **kwargs): def get_details(self, index: int, **kwargs):
global_fields = [] global_fields = []

View File

@ -240,16 +240,20 @@ class NodeChunk:
class WorkflowManage: class WorkflowManage:
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler, def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None, base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None,
document_list=None,
start_node_id=None, start_node_id=None,
start_node_data=None, chat_record=None): start_node_data=None, chat_record=None):
if form_data is None: if form_data is None:
form_data = {} form_data = {}
if image_list is None: if image_list is None:
image_list = [] image_list = []
if document_list is None:
document_list = []
self.start_node = None self.start_node = None
self.start_node_result_future = None self.start_node_result_future = None
self.form_data = form_data self.form_data = form_data
self.image_list = image_list self.image_list = image_list
self.document_list = document_list
self.params = params self.params = params
self.flow = flow self.flow = flow
self.lock = threading.Lock() self.lock = threading.Lock()

View File

@ -230,7 +230,8 @@ class ChatMessageSerializer(serializers.Serializer):
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id")) client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型")) client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量")) form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量"))
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片仅1张")) image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档"))
def is_valid_application_workflow(self, *, raise_exception=False): def is_valid_application_workflow(self, *, raise_exception=False):
self.is_valid_intraday_access_num() self.is_valid_intraday_access_num()
@ -322,6 +323,7 @@ class ChatMessageSerializer(serializers.Serializer):
client_type = self.data.get('client_type') client_type = self.data.get('client_type')
form_data = self.data.get('form_data') form_data = self.data.get('form_data')
image_list = self.data.get('image_list') image_list = self.data.get('image_list')
document_list = self.data.get('document_list')
user_id = chat_info.application.user_id user_id = chat_info.application.user_id
chat_record_id = self.data.get('chat_record_id') chat_record_id = self.data.get('chat_record_id')
chat_record = None chat_record = None
@ -336,7 +338,7 @@ class ChatMessageSerializer(serializers.Serializer):
'client_id': client_id, 'client_id': client_id,
'client_type': client_type, 'client_type': client_type,
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type), 'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
base_to_response, form_data, image_list, self.data.get('runtime_node_id'), base_to_response, form_data, image_list, document_list, self.data.get('runtime_node_id'),
self.data.get('node_data'), chat_record) self.data.get('node_data'), chat_record)
r = work_flow_manage.run() r = work_flow_manage.run()
return r return r

View File

@ -132,6 +132,8 @@ class ChatView(APIView):
'image_list': request.data.get( 'image_list': request.data.get(
'image_list') if 'image_list' in request.data else [], 'image_list') if 'image_list' in request.data else [],
'document_list': request.data.get(
'document_list') if 'document_list' in request.data else [],
'client_type': request.auth.client_type, 'client_type': request.auth.client_type,
'runtime_node_id': request.data.get('runtime_node_id', None), 'runtime_node_id': request.data.get('runtime_node_id', None),
'node_data': request.data.get('node_data', {}), 'node_data': request.data.get('node_data', {}),

View File

@ -39,7 +39,8 @@ interface chatType {
record_id: string record_id: string
chat_id: string chat_id: string
vote_status: string vote_status: string
status?: number status?: number,
execution_details: any[]
} }
export class ChatRecordManage { export class ChatRecordManage {

View File

@ -20,10 +20,12 @@
<div class="operate flex align-center"> <div class="operate flex align-center">
<span v-if="props.applicationDetails.file_upload_enable" class="flex align-center"> <span v-if="props.applicationDetails.file_upload_enable" class="flex align-center">
<!-- accept="image/jpeg, image/png, image/gif"-->
<el-upload <el-upload
action="#" action="#"
:auto-upload="false" :auto-upload="false"
:show-file-list="false" :show-file-list="false"
:accept="[...imageExtensions, ...documentExtensions].map((ext) => '.' + ext).join(',')"
:on-change="(file: any, fileList: any) => uploadFile(file, fileList)" :on-change="(file: any, fileList: any) => uploadFile(file, fileList)"
> >
<el-button text> <el-button text>
@ -126,6 +128,13 @@ const localLoading = computed({
emit('update:loading', v) emit('update:loading', v)
} }
}) })
const imageExtensions = ['jpg', 'jpeg', 'png', 'gif', 'bmp']
const documentExtensions = ['pdf', 'docx', 'txt', 'xls', 'xlsx', 'md', 'html']
const videoExtensions = ['mp4', 'avi', 'mov', 'mkv', 'flv']
const audioExtensions = ['mp3', 'wav', 'aac', 'flac']
const uploadFile = async (file: any, fileList: any) => { const uploadFile = async (file: any, fileList: any) => {
const { maxFiles, fileLimit } = props.applicationDetails.file_upload_setting const { maxFiles, fileLimit } = props.applicationDetails.file_upload_setting
if (fileList.length > maxFiles) { if (fileList.length > maxFiles) {
@ -141,7 +150,18 @@ const uploadFile = async (file: any, fileList: any) => {
const formData = new FormData() const formData = new FormData()
for (const file of fileList) { for (const file of fileList) {
formData.append('file', file.raw, file.name) formData.append('file', file.raw, file.name)
uploadFileList.value.push(file) //
const extension = file.name.split('.').pop().toLowerCase() //
if (imageExtensions.includes(extension)) {
uploadImageList.value.push(file)
} else if (documentExtensions.includes(extension)) {
uploadDocumentList.value.push(file)
} else if (videoExtensions.includes(extension)) {
// videos.push(file)
} else if (audioExtensions.includes(extension)) {
// audios.push(file)
}
} }
if (!chatId_context.value) { if (!chatId_context.value) {
@ -158,21 +178,22 @@ const uploadFile = async (file: any, fileList: any) => {
) )
.then((response) => { .then((response) => {
fileList.splice(0, fileList.length) fileList.splice(0, fileList.length)
uploadFileList.value.forEach((file: any) => { uploadImageList.value.forEach((file: any) => {
const f = response.data.filter((f: any) => f.name === file.name) const f = response.data.filter((f: any) => f.name === file.name)
if (f.length > 0) { if (f.length > 0) {
file.url = f[0].url file.url = f[0].url
file.file_id = f[0].file_id file.file_id = f[0].file_id
} }
}) })
console.log(uploadFileList.value) console.log(uploadDocumentList.value, uploadImageList.value)
}) })
} }
const recorderTime = ref(0) const recorderTime = ref(0)
const startRecorderTime = ref(false) const startRecorderTime = ref(false)
const recorderLoading = ref(false) const recorderLoading = ref(false)
const inputValue = ref<string>('') const inputValue = ref<string>('')
const uploadFileList = ref<Array<any>>([]) const uploadImageList = ref<Array<any>>([])
const uploadDocumentList = ref<Array<any>>([])
const mediaRecorderStatus = ref(true) const mediaRecorderStatus = ref(true)
// //
const mediaRecorder = ref<any>(null) const mediaRecorder = ref<any>(null)
@ -289,15 +310,20 @@ const handleTimeChange = () => {
handleTimeChange() handleTimeChange()
}, 1000) }, 1000)
} }
function sendChatHandle(event: any) { function sendChatHandle(event: any) {
if (!event.ctrlKey) { if (!event.ctrlKey) {
// ctrl // ctrl
event.preventDefault() event.preventDefault()
if (!isDisabledChart.value && !props.loading && !event.isComposing) { if (!isDisabledChart.value && !props.loading && !event.isComposing) {
if (inputValue.value.trim()) { if (inputValue.value.trim()) {
props.sendMessage(inputValue.value, { image_list: uploadFileList.value }) props.sendMessage(inputValue.value, {
image_list: uploadImageList.value,
document_list: uploadDocumentList.value
})
inputValue.value = '' inputValue.value = ''
uploadFileList.value = [] uploadImageList.value = []
uploadDocumentList.value = []
quickInputRef.value.textareaStyle.height = '45px' quickInputRef.value.textareaStyle.height = '45px'
} }
} }

View File

@ -22,10 +22,19 @@
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { type chatType } from '@/api/type/application' import { type chatType } from '@/api/type/application'
defineProps<{ import { onMounted } from 'vue'
const props = defineProps<{
application: any application: any
chatRecord: chatType chatRecord: chatType
}>() }>()
onMounted(() => {
if (props.chatRecord.execution_details?.length > 0) {
props.chatRecord.execution_details[0].image_list.forEach((image: any) => {
console.log('image', image.name, image.url)
})
}
})
</script> </script>
<style lang="scss" scoped> <style lang="scss" scoped>
</style> </style>

View File

@ -10,7 +10,7 @@
label-width="auto" label-width="auto"
ref="DatasetNodeFormRef" ref="DatasetNodeFormRef"
> >
<el-form-item label="选择文" :rules="{ <el-form-item label="选择文" :rules="{
type: 'array', type: 'array',
required: true, required: true,
message: '请选择文件', message: '请选择文件',
@ -21,8 +21,8 @@
ref="nodeCascaderRef" ref="nodeCascaderRef"
:nodeModel="nodeModel" :nodeModel="nodeModel"
class="w-full" class="w-full"
placeholder="请选择文" placeholder="请选择文"
v-model="form.file_list" v-model="form.document_list"
/> />
</el-form-item> </el-form-item>
</el-form> </el-form>
@ -39,7 +39,7 @@ import NodeCascader from '@/workflow/common/NodeCascader.vue'
const props = defineProps<{ nodeModel: any }>() const props = defineProps<{ nodeModel: any }>()
const form = { const form = {
file_list: [] document_list: ["start-node", "document"]
} }