feat: 高级编排支持文件上传(WIP)
This commit is contained in:
parent
72b91bee9a
commit
8a8305e75b
@ -16,9 +16,12 @@ from .direct_reply_node import *
|
|||||||
from .function_lib_node import *
|
from .function_lib_node import *
|
||||||
from .function_node import *
|
from .function_node import *
|
||||||
from .reranker_node import *
|
from .reranker_node import *
|
||||||
|
from .document_extract_node import *
|
||||||
|
from .image_understand_step_node import *
|
||||||
|
|
||||||
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
|
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
|
||||||
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode]
|
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode, BaseDocumentExtractNode,
|
||||||
|
BaseImageUnderstandNode]
|
||||||
|
|
||||||
|
|
||||||
def get_node(node_type):
|
def get_node(node_type):
|
||||||
|
|||||||
@ -0,0 +1 @@
|
|||||||
|
from .impl import *
|
||||||
@ -0,0 +1,30 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
from application.flow.i_step_node import INode, NodeResult
|
||||||
|
from common.util.field_message import ErrMessage
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentExtractNodeSerializer(serializers.Serializer):
|
||||||
|
# 需要查询的数据集id列表
|
||||||
|
file_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
|
||||||
|
error_messages=ErrMessage.list("数据集id列表"))
|
||||||
|
|
||||||
|
def is_valid(self, *, raise_exception=False):
|
||||||
|
super().is_valid(raise_exception=True)
|
||||||
|
|
||||||
|
|
||||||
|
class IDocumentExtractNode(INode):
|
||||||
|
type = 'document-extract-node'
|
||||||
|
|
||||||
|
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||||
|
return DocumentExtractNodeSerializer
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
return self.execute(**self.flow_params_serializer.data)
|
||||||
|
|
||||||
|
def execute(self, file_list, **kwargs) -> NodeResult:
|
||||||
|
pass
|
||||||
@ -0,0 +1 @@
|
|||||||
|
from .base_document_extract_node import BaseDocumentExtractNode
|
||||||
@ -0,0 +1,11 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDocumentExtractNode(IDocumentExtractNode):
|
||||||
|
def execute(self, file_list, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_details(self, index: int, **kwargs):
|
||||||
|
pass
|
||||||
@ -0,0 +1,3 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from .impl import *
|
||||||
@ -0,0 +1,39 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
from application.flow.i_step_node import INode, NodeResult
|
||||||
|
from common.util.field_message import ErrMessage
|
||||||
|
|
||||||
|
|
||||||
|
class ImageUnderstandNodeSerializer(serializers.Serializer):
|
||||||
|
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
|
||||||
|
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||||
|
error_messages=ErrMessage.char("角色设定"))
|
||||||
|
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
|
||||||
|
# 多轮对话数量
|
||||||
|
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
|
||||||
|
|
||||||
|
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
|
||||||
|
|
||||||
|
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片仅1张"))
|
||||||
|
|
||||||
|
|
||||||
|
class IImageUnderstandNode(INode):
|
||||||
|
type = 'image-understand-node'
|
||||||
|
|
||||||
|
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||||
|
return ImageUnderstandNodeSerializer
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('image_list')[0],
|
||||||
|
self.node_params_serializer.data.get('image_list')[1:])
|
||||||
|
return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||||
|
|
||||||
|
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
|
||||||
|
chat_record_id,
|
||||||
|
image,
|
||||||
|
**kwargs) -> NodeResult:
|
||||||
|
pass
|
||||||
@ -0,0 +1,3 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from .base_image_understand_node import BaseImageUnderstandNode
|
||||||
@ -0,0 +1,147 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from functools import reduce
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from django.db.models import QuerySet
|
||||||
|
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
from application.flow.i_step_node import NodeResult, INode
|
||||||
|
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
|
||||||
|
from dataset.models import File
|
||||||
|
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||||
|
|
||||||
|
|
||||||
|
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
|
||||||
|
chat_model = node_variable.get('chat_model')
|
||||||
|
message_tokens = node_variable['usage_metadata']['output_tokens'] if 'usage_metadata' in node_variable else 0
|
||||||
|
answer_tokens = chat_model.get_num_tokens(answer)
|
||||||
|
node.context['message_tokens'] = message_tokens
|
||||||
|
node.context['answer_tokens'] = answer_tokens
|
||||||
|
node.context['answer'] = answer
|
||||||
|
node.context['history_message'] = node_variable['history_message']
|
||||||
|
node.context['question'] = node_variable['question']
|
||||||
|
node.context['run_time'] = time.time() - node.context['start_time']
|
||||||
|
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
|
||||||
|
workflow.answer += answer
|
||||||
|
|
||||||
|
|
||||||
|
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||||
|
"""
|
||||||
|
写入上下文数据 (流式)
|
||||||
|
@param node_variable: 节点数据
|
||||||
|
@param workflow_variable: 全局数据
|
||||||
|
@param node: 节点
|
||||||
|
@param workflow: 工作流管理器
|
||||||
|
"""
|
||||||
|
response = node_variable.get('result')
|
||||||
|
answer = ''
|
||||||
|
for chunk in response:
|
||||||
|
answer += chunk.content
|
||||||
|
yield chunk.content
|
||||||
|
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||||
|
|
||||||
|
|
||||||
|
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||||
|
"""
|
||||||
|
写入上下文数据
|
||||||
|
@param node_variable: 节点数据
|
||||||
|
@param workflow_variable: 全局数据
|
||||||
|
@param node: 节点实例对象
|
||||||
|
@param workflow: 工作流管理器
|
||||||
|
"""
|
||||||
|
response = node_variable.get('result')
|
||||||
|
answer = response.content
|
||||||
|
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseImageUnderstandNode(IImageUnderstandNode):
|
||||||
|
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
|
||||||
|
image,
|
||||||
|
**kwargs) -> NodeResult:
|
||||||
|
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
|
||||||
|
history_message = self.get_history_message(history_chat_record, dialogue_number)
|
||||||
|
self.context['history_message'] = history_message
|
||||||
|
question = self.generate_prompt_question(prompt)
|
||||||
|
self.context['question'] = question.content
|
||||||
|
# todo 处理上传图片
|
||||||
|
message_list = self.generate_message_list(image_model, system, prompt, history_message, image)
|
||||||
|
self.context['message_list'] = message_list
|
||||||
|
self.context['image_list'] = image
|
||||||
|
if stream:
|
||||||
|
r = image_model.stream(message_list)
|
||||||
|
return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list,
|
||||||
|
'history_message': history_message, 'question': question.content}, {},
|
||||||
|
_write_context=write_context_stream)
|
||||||
|
else:
|
||||||
|
r = image_model.invoke(message_list)
|
||||||
|
return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list,
|
||||||
|
'history_message': history_message, 'question': question.content}, {},
|
||||||
|
_write_context=write_context)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_history_message(history_chat_record, dialogue_number):
|
||||||
|
start_index = len(history_chat_record) - dialogue_number
|
||||||
|
history_message = reduce(lambda x, y: [*x, *y], [
|
||||||
|
[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
|
||||||
|
for index in
|
||||||
|
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
|
||||||
|
return history_message
|
||||||
|
|
||||||
|
def generate_prompt_question(self, prompt):
|
||||||
|
return HumanMessage(self.workflow_manage.generate_prompt(prompt))
|
||||||
|
|
||||||
|
def generate_message_list(self, image_model, system: str, prompt: str, history_message, image):
|
||||||
|
if image is not None and len(image) > 0:
|
||||||
|
file_id = image[0]['file_id']
|
||||||
|
file = QuerySet(File).filter(id=file_id).first()
|
||||||
|
base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
|
||||||
|
messages = [HumanMessage(
|
||||||
|
content=[
|
||||||
|
{'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
|
||||||
|
{'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}},
|
||||||
|
])]
|
||||||
|
else:
|
||||||
|
messages = [HumanMessage(self.workflow_manage.generate_prompt(prompt))]
|
||||||
|
|
||||||
|
if system is not None and len(system) > 0:
|
||||||
|
return [
|
||||||
|
SystemMessage(self.workflow_manage.generate_prompt(system)),
|
||||||
|
*history_message,
|
||||||
|
*messages
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return [
|
||||||
|
*history_message,
|
||||||
|
*messages
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def reset_message_list(message_list: List[BaseMessage], answer_text):
|
||||||
|
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
|
||||||
|
message
|
||||||
|
in
|
||||||
|
message_list]
|
||||||
|
result.append({'role': 'ai', 'content': answer_text})
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_details(self, index: int, **kwargs):
|
||||||
|
return {
|
||||||
|
'name': self.node.properties.get('stepName'),
|
||||||
|
"index": index,
|
||||||
|
'run_time': self.context.get('run_time'),
|
||||||
|
'system': self.node_params.get('system'),
|
||||||
|
'history_message': [{'content': message.content, 'role': message.type} for message in
|
||||||
|
(self.context.get('history_message') if self.context.get(
|
||||||
|
'history_message') is not None else [])],
|
||||||
|
'question': self.context.get('question'),
|
||||||
|
'answer': self.context.get('answer'),
|
||||||
|
'type': self.node.type,
|
||||||
|
'message_tokens': self.context.get('message_tokens'),
|
||||||
|
'answer_tokens': self.context.get('answer_tokens'),
|
||||||
|
'status': self.status,
|
||||||
|
'err_message': self.err_message,
|
||||||
|
'image_list': self.context.get('image_list')
|
||||||
|
}
|
||||||
@ -41,7 +41,7 @@ class BaseStartStepNode(IStarNode):
|
|||||||
"""
|
"""
|
||||||
开始节点 初始化全局变量
|
开始节点 初始化全局变量
|
||||||
"""
|
"""
|
||||||
return NodeResult({'question': question},
|
return NodeResult({'question': question, 'image': self.workflow_manage.image_list},
|
||||||
workflow_variable)
|
workflow_variable)
|
||||||
|
|
||||||
def get_details(self, index: int, **kwargs):
|
def get_details(self, index: int, **kwargs):
|
||||||
@ -61,5 +61,6 @@ class BaseStartStepNode(IStarNode):
|
|||||||
'type': self.node.type,
|
'type': self.node.type,
|
||||||
'status': self.status,
|
'status': self.status,
|
||||||
'err_message': self.err_message,
|
'err_message': self.err_message,
|
||||||
|
'image_list': self.context.get('image'),
|
||||||
'global_fields': global_fields
|
'global_fields': global_fields
|
||||||
}
|
}
|
||||||
|
|||||||
@ -240,10 +240,13 @@ class NodeChunk:
|
|||||||
|
|
||||||
class WorkflowManage:
|
class WorkflowManage:
|
||||||
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
|
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
|
||||||
base_to_response: BaseToResponse = SystemToResponse(), form_data=None):
|
base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None):
|
||||||
if form_data is None:
|
if form_data is None:
|
||||||
form_data = {}
|
form_data = {}
|
||||||
|
if image_list is None:
|
||||||
|
image_list = []
|
||||||
self.form_data = form_data
|
self.form_data = form_data
|
||||||
|
self.image_list = image_list
|
||||||
self.params = params
|
self.params = params
|
||||||
self.flow = flow
|
self.flow = flow
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
|
|||||||
@ -0,0 +1,23 @@
|
|||||||
|
# Generated by Django 4.2.15 on 2024-11-07 11:22
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('application', '0018_workflowversion_name'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name='application',
|
||||||
|
name='file_upload_enable',
|
||||||
|
field=models.BooleanField(default=False, verbose_name='文件上传是否启用'),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name='application',
|
||||||
|
name='file_upload_setting',
|
||||||
|
field=models.JSONField(default={}, verbose_name='文件上传相关设置'),
|
||||||
|
),
|
||||||
|
]
|
||||||
@ -66,6 +66,9 @@ class Application(AppModelMixin):
|
|||||||
stt_model_enable = models.BooleanField(verbose_name="语音识别模型是否启用", default=False)
|
stt_model_enable = models.BooleanField(verbose_name="语音识别模型是否启用", default=False)
|
||||||
tts_type = models.CharField(verbose_name="语音播放类型", max_length=20, default="BROWSER")
|
tts_type = models.CharField(verbose_name="语音播放类型", max_length=20, default="BROWSER")
|
||||||
clean_time = models.IntegerField(verbose_name="清理时间", default=180)
|
clean_time = models.IntegerField(verbose_name="清理时间", default=180)
|
||||||
|
file_upload_enable = models.BooleanField(verbose_name="文件上传是否启用", default=False)
|
||||||
|
file_upload_setting = models.JSONField(verbose_name="文件上传相关设置", default={})
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_default_model_prompt():
|
def get_default_model_prompt():
|
||||||
|
|||||||
@ -823,6 +823,8 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
'stt_model_enable': application.stt_model_enable,
|
'stt_model_enable': application.stt_model_enable,
|
||||||
'tts_model_enable': application.tts_model_enable,
|
'tts_model_enable': application.tts_model_enable,
|
||||||
'tts_type': application.tts_type,
|
'tts_type': application.tts_type,
|
||||||
|
'file_upload_enable': application.file_upload_enable,
|
||||||
|
'file_upload_setting': application.file_upload_setting,
|
||||||
'work_flow': application.work_flow,
|
'work_flow': application.work_flow,
|
||||||
'show_source': application_access_token.show_source,
|
'show_source': application_access_token.show_source,
|
||||||
**application_setting_dict})
|
**application_setting_dict})
|
||||||
@ -876,6 +878,7 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
|
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
|
||||||
'dataset_setting', 'model_setting', 'problem_optimization', 'dialogue_number',
|
'dataset_setting', 'model_setting', 'problem_optimization', 'dialogue_number',
|
||||||
'stt_model_id', 'tts_model_id', 'tts_model_enable', 'stt_model_enable', 'tts_type',
|
'stt_model_id', 'tts_model_id', 'tts_model_enable', 'stt_model_enable', 'tts_type',
|
||||||
|
'file_upload_enable', 'file_upload_setting',
|
||||||
'api_key_is_active', 'icon', 'work_flow', 'model_params_setting', 'tts_model_params_setting',
|
'api_key_is_active', 'icon', 'work_flow', 'model_params_setting', 'tts_model_params_setting',
|
||||||
'problem_optimization_prompt', 'clean_time']
|
'problem_optimization_prompt', 'clean_time']
|
||||||
for update_key in update_keys:
|
for update_key in update_keys:
|
||||||
@ -941,6 +944,10 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
instance['tts_type'] = node_data['tts_type']
|
instance['tts_type'] = node_data['tts_type']
|
||||||
if 'tts_model_params_setting' in node_data:
|
if 'tts_model_params_setting' in node_data:
|
||||||
instance['tts_model_params_setting'] = node_data['tts_model_params_setting']
|
instance['tts_model_params_setting'] = node_data['tts_model_params_setting']
|
||||||
|
if 'file_upload_enable' in node_data:
|
||||||
|
instance['file_upload_enable'] = node_data['file_upload_enable']
|
||||||
|
if 'file_upload_setting' in node_data:
|
||||||
|
instance['file_upload_setting'] = node_data['file_upload_setting']
|
||||||
break
|
break
|
||||||
|
|
||||||
def speech_to_text(self, file, with_valid=True):
|
def speech_to_text(self, file, with_valid=True):
|
||||||
|
|||||||
@ -222,6 +222,7 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||||||
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
|
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
|
||||||
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
|
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
|
||||||
form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量"))
|
form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量"))
|
||||||
|
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片仅1张"))
|
||||||
|
|
||||||
def is_valid_application_workflow(self, *, raise_exception=False):
|
def is_valid_application_workflow(self, *, raise_exception=False):
|
||||||
self.is_valid_intraday_access_num()
|
self.is_valid_intraday_access_num()
|
||||||
@ -299,6 +300,7 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||||||
client_id = self.data.get('client_id')
|
client_id = self.data.get('client_id')
|
||||||
client_type = self.data.get('client_type')
|
client_type = self.data.get('client_type')
|
||||||
form_data = self.data.get('form_data')
|
form_data = self.data.get('form_data')
|
||||||
|
image_list = self.data.get('image_list')
|
||||||
user_id = chat_info.application.user_id
|
user_id = chat_info.application.user_id
|
||||||
work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
|
work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
|
||||||
{'history_chat_record': chat_info.chat_record_list, 'question': message,
|
{'history_chat_record': chat_info.chat_record_list, 'question': message,
|
||||||
@ -308,7 +310,7 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||||||
'client_id': client_id,
|
'client_id': client_id,
|
||||||
'client_type': client_type,
|
'client_type': client_type,
|
||||||
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
|
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
|
||||||
base_to_response, form_data)
|
base_to_response, form_data, image_list)
|
||||||
r = work_flow_manage.run()
|
r = work_flow_manage.run()
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
|||||||
@ -49,6 +49,7 @@ urlpatterns = [
|
|||||||
path('application/<str:application_id>/chat/<int:current_page>/<int:page_size>', views.ChatView.Page.as_view()),
|
path('application/<str:application_id>/chat/<int:current_page>/<int:page_size>', views.ChatView.Page.as_view()),
|
||||||
path('application/<str:application_id>/chat/<chat_id>', views.ChatView.Operate.as_view()),
|
path('application/<str:application_id>/chat/<chat_id>', views.ChatView.Operate.as_view()),
|
||||||
path('application/<str:application_id>/chat/<chat_id>/chat_record/', views.ChatView.ChatRecord.as_view()),
|
path('application/<str:application_id>/chat/<chat_id>/chat_record/', views.ChatView.ChatRecord.as_view()),
|
||||||
|
path('application/<str:application_id>/chat/<chat_id>/upload_file', views.ChatView.UploadFile.as_view()),
|
||||||
path('application/<str:application_id>/chat/<chat_id>/chat_record/<int:current_page>/<int:page_size>',
|
path('application/<str:application_id>/chat/<chat_id>/chat_record/<int:current_page>/<int:page_size>',
|
||||||
views.ChatView.ChatRecord.Page.as_view()),
|
views.ChatView.ChatRecord.Page.as_view()),
|
||||||
path('application/<str:application_id>/chat/<chat_id>/chat_record/<chat_record_id>',
|
path('application/<str:application_id>/chat/<chat_id>/chat_record/<chat_record_id>',
|
||||||
|
|||||||
@ -22,6 +22,7 @@ from common.constants.permission_constants import Permission, Group, Operate, \
|
|||||||
RoleConstants, ViewPermission, CompareConstants
|
RoleConstants, ViewPermission, CompareConstants
|
||||||
from common.response import result
|
from common.response import result
|
||||||
from common.util.common import query_params_to_single_dict
|
from common.util.common import query_params_to_single_dict
|
||||||
|
from dataset.serializers.file_serializers import FileSerializer
|
||||||
|
|
||||||
|
|
||||||
class Openai(APIView):
|
class Openai(APIView):
|
||||||
@ -128,6 +129,7 @@ class ChatView(APIView):
|
|||||||
'client_id': request.auth.client_id,
|
'client_id': request.auth.client_id,
|
||||||
'form_data': (request.data.get(
|
'form_data': (request.data.get(
|
||||||
'form_data') if 'form_data' in request.data else {}),
|
'form_data') if 'form_data' in request.data else {}),
|
||||||
|
'image_list': request.data.get('image_list') if 'image_list' in request.data else [],
|
||||||
'client_type': request.auth.client_type}).chat()
|
'client_type': request.auth.client_type}).chat()
|
||||||
|
|
||||||
@action(methods=['GET'], detail=False)
|
@action(methods=['GET'], detail=False)
|
||||||
@ -391,3 +393,28 @@ class ChatView(APIView):
|
|||||||
data={'chat_id': chat_id, 'chat_record_id': chat_record_id,
|
data={'chat_id': chat_id, 'chat_record_id': chat_record_id,
|
||||||
'dataset_id': dataset_id, 'document_id': document_id,
|
'dataset_id': dataset_id, 'document_id': document_id,
|
||||||
'paragraph_id': paragraph_id}).delete())
|
'paragraph_id': paragraph_id}).delete())
|
||||||
|
|
||||||
|
class UploadFile(APIView):
|
||||||
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
|
@action(methods=['POST'], detail=False)
|
||||||
|
@swagger_auto_schema(operation_summary="上传文件",
|
||||||
|
operation_id="上传文件",
|
||||||
|
manual_parameters=ChatRecordApi.get_request_params_api(),
|
||||||
|
tags=["应用/对话日志"]
|
||||||
|
)
|
||||||
|
@has_permissions(
|
||||||
|
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY,
|
||||||
|
RoleConstants.APPLICATION_ACCESS_TOKEN],
|
||||||
|
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||||
|
dynamic_tag=keywords.get('application_id'))])
|
||||||
|
)
|
||||||
|
def post(self, request: Request, application_id: str, chat_id: str):
|
||||||
|
files = request.FILES.getlist('file')
|
||||||
|
file_ids = []
|
||||||
|
meta = {'application_id': application_id, 'chat_id': chat_id}
|
||||||
|
for file in files:
|
||||||
|
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
|
||||||
|
file_ids.append({'name': file.name, 'url': file_url, 'file_id': file_url.split('/')[-1]})
|
||||||
|
return result.success(file_ids)
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,7 @@
|
|||||||
@date:2024/3/14 11:54
|
@date:2024/3/14 11:54
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
|
from .clean_orphaned_file_job import *
|
||||||
from .client_access_num_job import *
|
from .client_access_num_job import *
|
||||||
from .clean_chat_job import *
|
from .clean_chat_job import *
|
||||||
|
|
||||||
@ -13,3 +14,4 @@ from .clean_chat_job import *
|
|||||||
def run():
|
def run():
|
||||||
client_access_num_job.run()
|
client_access_num_job.run()
|
||||||
clean_chat_job.run()
|
clean_chat_job.run()
|
||||||
|
clean_orphaned_file_job.run()
|
||||||
|
|||||||
40
apps/common/job/clean_orphaned_file_job.py
Normal file
40
apps/common/job/clean_orphaned_file_job.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
|
from django.db.models import Q
|
||||||
|
from django_apscheduler.jobstores import DjangoJobStore
|
||||||
|
|
||||||
|
from application.models import Chat
|
||||||
|
from common.lock.impl.file_lock import FileLock
|
||||||
|
from dataset.models import File
|
||||||
|
|
||||||
|
scheduler = BackgroundScheduler()
|
||||||
|
scheduler.add_jobstore(DjangoJobStore(), "default")
|
||||||
|
lock = FileLock()
|
||||||
|
|
||||||
|
|
||||||
|
def clean_debug_file():
|
||||||
|
logging.getLogger("max_kb").info('开始清理没有关联会话的上传文件')
|
||||||
|
existing_chat_ids = set(Chat.objects.values_list('id', flat=True))
|
||||||
|
# UUID to str
|
||||||
|
existing_chat_ids = [str(chat_id) for chat_id in existing_chat_ids]
|
||||||
|
print(existing_chat_ids)
|
||||||
|
# 查找引用的不存在的 chat_id 并删除相关记录
|
||||||
|
deleted_count, _ = File.objects.filter(~Q(meta__chat_id__in=existing_chat_ids)).delete()
|
||||||
|
|
||||||
|
logging.getLogger("max_kb").info(f'结束清理没有关联会话的上传文件: {deleted_count}')
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
if lock.try_lock('clean_orphaned_file_job', 30 * 30):
|
||||||
|
try:
|
||||||
|
scheduler.start()
|
||||||
|
clean_orphaned_file = scheduler.get_job(job_id='clean_orphaned_file')
|
||||||
|
if clean_orphaned_file is not None:
|
||||||
|
clean_orphaned_file.remove()
|
||||||
|
scheduler.add_job(clean_debug_file, 'cron', hour='2', minute='0', second='0',
|
||||||
|
id='clean_orphaned_file')
|
||||||
|
finally:
|
||||||
|
lock.un_lock('clean_orphaned_file_job')
|
||||||
18
apps/dataset/migrations/0010_file_meta.py
Normal file
18
apps/dataset/migrations/0010_file_meta.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# Generated by Django 4.2.15 on 2024-11-07 15:32
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('dataset', '0009_alter_document_status_alter_paragraph_status'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name='file',
|
||||||
|
name='meta',
|
||||||
|
field=models.JSONField(default={}, verbose_name='文件关联数据'),
|
||||||
|
),
|
||||||
|
]
|
||||||
@ -141,6 +141,9 @@ class File(AppModelMixin):
|
|||||||
|
|
||||||
loid = models.IntegerField(verbose_name="loid")
|
loid = models.IntegerField(verbose_name="loid")
|
||||||
|
|
||||||
|
meta = models.JSONField(verbose_name="文件关联数据", default=dict)
|
||||||
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
db_table = "file"
|
db_table = "file"
|
||||||
|
|
||||||
@ -149,7 +152,6 @@ class File(AppModelMixin):
|
|||||||
):
|
):
|
||||||
result = select_one("SELECT lo_from_bytea(%s, %s::bytea) as loid", [0, bytea])
|
result = select_one("SELECT lo_from_bytea(%s, %s::bytea) as loid", [0, bytea])
|
||||||
self.loid = result['loid']
|
self.loid = result['loid']
|
||||||
self.file_name = 'speech.mp3'
|
|
||||||
super().save()
|
super().save()
|
||||||
|
|
||||||
def get_byte(self):
|
def get_byte(self):
|
||||||
|
|||||||
@ -56,12 +56,13 @@ mime_types = {"html": "text/html", "htm": "text/html", "shtml": "text/html", "cs
|
|||||||
|
|
||||||
class FileSerializer(serializers.Serializer):
|
class FileSerializer(serializers.Serializer):
|
||||||
file = UploadedFileField(required=True, error_messages=ErrMessage.image("文件"))
|
file = UploadedFileField(required=True, error_messages=ErrMessage.image("文件"))
|
||||||
|
meta = serializers.JSONField(required=False)
|
||||||
|
|
||||||
def upload(self, with_valid=True):
|
def upload(self, with_valid=True):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
file_id = uuid.uuid1()
|
file_id = uuid.uuid1()
|
||||||
file = File(id=file_id, file_name=self.data.get('file').name)
|
file = File(id=file_id, file_name=self.data.get('file').name, meta=self.data.get('meta'))
|
||||||
file.save(self.data.get('file').read())
|
file.save(self.data.get('file').read())
|
||||||
return f'/api/file/{file_id}'
|
return f'/api/file/{file_id}'
|
||||||
|
|
||||||
|
|||||||
@ -1,14 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
from abc import abstractmethod
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class BaseImage(BaseModel):
|
|
||||||
@abstractmethod
|
|
||||||
def check_auth(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def image_understand(self, image_file, text):
|
|
||||||
pass
|
|
||||||
@ -1,6 +1,10 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
from common import forms
|
from common import forms
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
from common.forms import BaseForm
|
from common.forms import BaseForm
|
||||||
@ -25,7 +29,7 @@ class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
model.check_auth()
|
model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
@ -1,12 +1,9 @@
|
|||||||
import base64
|
|
||||||
import os
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from openai import OpenAI
|
from langchain_openai.chat_models import ChatOpenAI
|
||||||
|
|
||||||
from common.config.tokenizer_manage_config import TokenizerManage
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
from setting.models_provider.impl.base_image import BaseImage
|
|
||||||
|
|
||||||
|
|
||||||
def custom_get_token_ids(text: str):
|
def custom_get_token_ids(text: str):
|
||||||
@ -14,66 +11,15 @@ def custom_get_token_ids(text: str):
|
|||||||
return tokenizer.encode(text)
|
return tokenizer.encode(text)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIImage(MaxKBBaseModel, BaseImage):
|
class OpenAIImage(MaxKBBaseModel, ChatOpenAI):
|
||||||
api_base: str
|
|
||||||
api_key: str
|
|
||||||
model: str
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.api_key = kwargs.get('api_key')
|
|
||||||
self.api_base = kwargs.get('api_base')
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
optional_params = {}
|
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||||
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
|
|
||||||
optional_params['max_tokens'] = model_kwargs['max_tokens']
|
|
||||||
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
|
|
||||||
optional_params['temperature'] = model_kwargs['temperature']
|
|
||||||
return OpenAIImage(
|
return OpenAIImage(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
api_base=model_credential.get('api_base'),
|
openai_api_base=model_credential.get('api_base'),
|
||||||
api_key=model_credential.get('api_key'),
|
openai_api_key=model_credential.get('api_key'),
|
||||||
|
stream_options={"include_usage": True},
|
||||||
**optional_params,
|
**optional_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_auth(self):
|
|
||||||
client = OpenAI(
|
|
||||||
base_url=self.api_base,
|
|
||||||
api_key=self.api_key
|
|
||||||
)
|
|
||||||
response_list = client.models.with_raw_response.list()
|
|
||||||
# print(response_list)
|
|
||||||
# cwd = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
# with open(f'{cwd}/img_1.png', 'rb') as f:
|
|
||||||
# self.image_understand(f, "一句话概述这个图片")
|
|
||||||
|
|
||||||
def image_understand(self, image_file, text):
|
|
||||||
client = OpenAI(
|
|
||||||
base_url=self.api_base,
|
|
||||||
api_key=self.api_key
|
|
||||||
)
|
|
||||||
base64_image = base64.b64encode(image_file.read()).decode('utf-8')
|
|
||||||
|
|
||||||
response = client.chat.completions.create(
|
|
||||||
model=self.model,
|
|
||||||
messages=[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": text,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
return response.choices[0].message.content
|
|
||||||
|
|||||||
@ -0,0 +1,69 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/7/11 18:41
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm, TooltipLabel
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class QwenModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||||
|
required=True, default_value=1.0,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.9,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
|
||||||
|
required=True, default_value=800,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
|
class QwenVLModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
for key in ['api_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
|
def get_model_params_setting_form(self, model_name):
|
||||||
|
return QwenModelParams()
|
||||||
@ -0,0 +1,22 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_community.chat_models import ChatOpenAI
|
||||||
|
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||||
|
chat_tong_yi = QwenVLChatModel(
|
||||||
|
model=model_name,
|
||||||
|
openai_api_key=model_credential.get('api_key'),
|
||||||
|
openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||||
|
stream_options={"include_usage": True},
|
||||||
|
model_kwargs=optional_params,
|
||||||
|
)
|
||||||
|
return chat_tong_yi
|
||||||
@ -11,21 +11,33 @@ import os
|
|||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||||
ModelInfoManage
|
ModelInfoManage
|
||||||
|
from setting.models_provider.impl.qwen_model_provider.credential.image import QwenVLModelCredential
|
||||||
from setting.models_provider.impl.qwen_model_provider.credential.llm import OpenAILLMModelCredential
|
from setting.models_provider.impl.qwen_model_provider.credential.llm import OpenAILLMModelCredential
|
||||||
|
from setting.models_provider.impl.qwen_model_provider.model.image import QwenVLChatModel
|
||||||
|
|
||||||
from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel
|
from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
qwen_model_credential = OpenAILLMModelCredential()
|
qwen_model_credential = OpenAILLMModelCredential()
|
||||||
|
qwenvl_model_credential = QwenVLModelCredential()
|
||||||
|
|
||||||
module_info_list = [
|
module_info_list = [
|
||||||
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
|
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
|
||||||
ModelInfo('qwen-plus', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
|
ModelInfo('qwen-plus', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
|
||||||
ModelInfo('qwen-max', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)
|
ModelInfo('qwen-max', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)
|
||||||
]
|
]
|
||||||
|
module_info_vl_list = [
|
||||||
|
ModelInfo('qwen-vl-max', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
|
||||||
|
ModelInfo('qwen-vl-max-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
|
||||||
|
ModelInfo('qwen-vl-plus-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
|
||||||
|
]
|
||||||
|
|
||||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(module_info_list).append_default_model_info(
|
model_info_manage = (ModelInfoManage.builder()
|
||||||
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)).build()
|
.append_model_info_list(module_info_list)
|
||||||
|
.append_default_model_info(
|
||||||
|
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel))
|
||||||
|
.append_model_info_list(module_info_vl_list)
|
||||||
|
.build())
|
||||||
|
|
||||||
|
|
||||||
class QwenModelProvider(IModelProvider):
|
class QwenModelProvider(IModelProvider):
|
||||||
|
|||||||
@ -0,0 +1,69 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/7/11 18:41
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm, TooltipLabel
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class QwenModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||||
|
required=True, default_value=1.0,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.9,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
|
||||||
|
required=True, default_value=800,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
|
class TencentVisionModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
for key in ['api_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
|
def get_model_params_setting_form(self, model_name):
|
||||||
|
return QwenModelParams()
|
||||||
@ -0,0 +1,25 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_openai.chat_models import ChatOpenAI
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
def custom_get_token_ids(text: str):
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return tokenizer.encode(text)
|
||||||
|
|
||||||
|
|
||||||
|
class TencentVision(MaxKBBaseModel, ChatOpenAI):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||||
|
return TencentVision(
|
||||||
|
model=model_name,
|
||||||
|
openai_api_base='https://api.hunyuan.cloud.tencent.com/v1',
|
||||||
|
openai_api_key=model_credential.get('api_key'),
|
||||||
|
stream_options={"include_usage": True},
|
||||||
|
**optional_params,
|
||||||
|
)
|
||||||
@ -7,8 +7,10 @@ from setting.models_provider.base_model_provider import (
|
|||||||
IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
|
IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
|
||||||
)
|
)
|
||||||
from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
|
from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
|
||||||
|
from setting.models_provider.impl.tencent_model_provider.credential.image import TencentVisionModelCredential
|
||||||
from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
|
from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
|
||||||
from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
|
from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
|
||||||
|
from setting.models_provider.impl.tencent_model_provider.model.image import TencentVision
|
||||||
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
|
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
@ -78,9 +80,18 @@ def _initialize_model_info():
|
|||||||
|
|
||||||
model_info_embedding_list = [tencent_embedding_model_info]
|
model_info_embedding_list = [tencent_embedding_model_info]
|
||||||
|
|
||||||
|
model_info_vision_list = [_create_model_info(
|
||||||
|
'hunyuan-vision',
|
||||||
|
'混元视觉模型',
|
||||||
|
ModelTypeConst.IMAGE,
|
||||||
|
TencentVisionModelCredential,
|
||||||
|
TencentVision)]
|
||||||
|
|
||||||
|
|
||||||
model_info_manage = ModelInfoManage.builder() \
|
model_info_manage = ModelInfoManage.builder() \
|
||||||
.append_model_info_list(model_info_list) \
|
.append_model_info_list(model_info_list) \
|
||||||
.append_model_info_list(model_info_embedding_list) \
|
.append_model_info_list(model_info_embedding_list) \
|
||||||
|
.append_model_info_list(model_info_vision_list) \
|
||||||
.append_default_model_info(model_info_list[0]) \
|
.append_default_model_info(model_info_list[0]) \
|
||||||
.build()
|
.build()
|
||||||
|
|
||||||
|
|||||||
@ -1,11 +1,15 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
from common import forms
|
from common import forms
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
from common.forms import BaseForm
|
from common.forms import BaseForm
|
||||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
from setting.models_provider.impl.xf_model_provider.model.image import ImageMessage
|
||||||
|
|
||||||
|
|
||||||
class XunFeiImageModelCredential(BaseForm, BaseModelCredential):
|
class XunFeiImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
@ -28,7 +32,10 @@ class XunFeiImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
model.check_auth()
|
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
with open(f'{cwd}/img_1.png', 'rb') as f:
|
||||||
|
message_list = [ImageMessage(str(base64.b64encode(f.read()), 'utf-8')), HumanMessage('请概述这张图片')]
|
||||||
|
model.stream(message_list)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 354 KiB After Width: | Height: | Size: 354 KiB |
@ -1,50 +1,58 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import base64
|
import base64
|
||||||
import datetime
|
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import ssl
|
from typing import Dict, Any, List, Optional, Iterator
|
||||||
from datetime import datetime, UTC
|
|
||||||
from typing import Dict
|
|
||||||
from urllib.parse import urlencode
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
import websockets
|
from docutils.utils import SystemMessage
|
||||||
|
from langchain_community.chat_models.sparkllm import ChatSparkLLM, _convert_delta_to_message_chunk
|
||||||
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
|
from langchain_core.messages import BaseMessage, ChatMessage, HumanMessage, AIMessage, AIMessageChunk
|
||||||
|
from langchain_core.outputs import ChatGenerationChunk
|
||||||
|
|
||||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
from setting.models_provider.impl.base_image import BaseImage
|
|
||||||
|
|
||||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
|
||||||
ssl_context.check_hostname = False
|
|
||||||
ssl_context.verify_mode = ssl.CERT_NONE
|
|
||||||
|
|
||||||
|
|
||||||
class XFSparkImage(MaxKBBaseModel, BaseImage):
|
class ImageMessage(HumanMessage):
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||||
|
message_dict: Dict[str, Any]
|
||||||
|
if isinstance(message, ChatMessage):
|
||||||
|
message_dict = {"role": "user", "content": message.content}
|
||||||
|
elif isinstance(message, ImageMessage):
|
||||||
|
message_dict = {"role": "user", "content": message.content, "content_type": "image"}
|
||||||
|
elif isinstance(message, HumanMessage):
|
||||||
|
message_dict = {"role": "user", "content": message.content}
|
||||||
|
elif isinstance(message, AIMessage):
|
||||||
|
message_dict = {"role": "assistant", "content": message.content}
|
||||||
|
if "function_call" in message.additional_kwargs:
|
||||||
|
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||||
|
# If function call only, content is None not empty string
|
||||||
|
if message_dict["content"] == "":
|
||||||
|
message_dict["content"] = None
|
||||||
|
if "tool_calls" in message.additional_kwargs:
|
||||||
|
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
|
||||||
|
# If tool calls only, content is None not empty string
|
||||||
|
if message_dict["content"] == "":
|
||||||
|
message_dict["content"] = None
|
||||||
|
elif isinstance(message, SystemMessage):
|
||||||
|
message_dict = {"role": "system", "content": message.content}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
|
||||||
|
return message_dict
|
||||||
|
|
||||||
|
|
||||||
|
class XFSparkImage(MaxKBBaseModel, ChatSparkLLM):
|
||||||
spark_app_id: str
|
spark_app_id: str
|
||||||
spark_api_key: str
|
spark_api_key: str
|
||||||
spark_api_secret: str
|
spark_api_secret: str
|
||||||
spark_api_url: str
|
spark_api_url: str
|
||||||
params: dict
|
|
||||||
|
|
||||||
# 初始化
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.spark_api_url = kwargs.get('spark_api_url')
|
|
||||||
self.spark_app_id = kwargs.get('spark_app_id')
|
|
||||||
self.spark_api_key = kwargs.get('spark_api_key')
|
|
||||||
self.spark_api_secret = kwargs.get('spark_api_secret')
|
|
||||||
self.params = kwargs.get('params')
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
optional_params = {'params': {}}
|
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||||
for key, value in model_kwargs.items():
|
|
||||||
if key not in ['model_id', 'use_local', 'streaming']:
|
|
||||||
optional_params['params'][key] = value
|
|
||||||
return XFSparkImage(
|
return XFSparkImage(
|
||||||
spark_app_id=model_credential.get('spark_app_id'),
|
spark_app_id=model_credential.get('spark_app_id'),
|
||||||
spark_api_key=model_credential.get('spark_api_key'),
|
spark_api_key=model_credential.get('spark_api_key'),
|
||||||
@ -53,118 +61,36 @@ class XFSparkImage(MaxKBBaseModel, BaseImage):
|
|||||||
**optional_params
|
**optional_params
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_url(self):
|
|
||||||
url = self.spark_api_url
|
|
||||||
host = urlparse(url).hostname
|
|
||||||
# 生成RFC1123格式的时间戳
|
|
||||||
gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
|
|
||||||
date = datetime.now(UTC).strftime(gmt_format)
|
|
||||||
|
|
||||||
# 拼接字符串
|
|
||||||
signature_origin = "host: " + host + "\n"
|
|
||||||
signature_origin += "date: " + date + "\n"
|
|
||||||
signature_origin += "GET " + "/v2.1/image " + "HTTP/1.1"
|
|
||||||
# 进行hmac-sha256进行加密
|
|
||||||
signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
|
|
||||||
digestmod=hashlib.sha256).digest()
|
|
||||||
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
|
|
||||||
|
|
||||||
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
|
|
||||||
self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha)
|
|
||||||
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
|
||||||
# 将请求的鉴权参数组合为字典
|
|
||||||
v = {
|
|
||||||
"authorization": authorization,
|
|
||||||
"date": date,
|
|
||||||
"host": host
|
|
||||||
}
|
|
||||||
# 拼接鉴权参数,生成url
|
|
||||||
url = url + '?' + urlencode(v)
|
|
||||||
# print("date: ",date)
|
|
||||||
# print("v: ",v)
|
|
||||||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
|
||||||
# print('websocket url :', url)
|
|
||||||
return url
|
|
||||||
|
|
||||||
def check_auth(self):
|
|
||||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
with open(f'{cwd}/img_1.png', 'rb') as f:
|
|
||||||
self.image_understand(f,"一句话概述这个图片")
|
|
||||||
|
|
||||||
def image_understand(self, image_file, question):
|
|
||||||
async def handle():
|
|
||||||
async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
|
|
||||||
# 发送 full client request
|
|
||||||
await self.send(ws, image_file, question)
|
|
||||||
return await self.handle_message(ws)
|
|
||||||
|
|
||||||
return asyncio.run(handle())
|
|
||||||
|
|
||||||
# 收到websocket消息的处理
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def handle_message(ws):
|
def generate_message(prompt: str, image) -> list[BaseMessage]:
|
||||||
# print(message)
|
if image is None:
|
||||||
answer = ''
|
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||||
while True:
|
with open(f'{cwd}/img_1.png', 'rb') as f:
|
||||||
res = await ws.recv()
|
base64_image = base64.b64encode(f.read()).decode("utf-8")
|
||||||
data = json.loads(res)
|
return [ImageMessage(f'data:image/jpeg;base64,{base64_image}'), HumanMessage(prompt)]
|
||||||
code = data['header']['code']
|
return [HumanMessage(prompt)]
|
||||||
if code != 0:
|
|
||||||
return f'请求错误: {code}, {data}'
|
|
||||||
else:
|
|
||||||
choices = data["payload"]["choices"]
|
|
||||||
status = choices["status"]
|
|
||||||
content = choices["text"][0]["content"]
|
|
||||||
# print(content, end="")
|
|
||||||
answer += content
|
|
||||||
# print(1)
|
|
||||||
if status == 2:
|
|
||||||
break
|
|
||||||
return answer
|
|
||||||
|
|
||||||
async def send(self, ws, image_file, question):
|
def _stream(
|
||||||
text = [
|
self,
|
||||||
{"role": "user", "content": str(base64.b64encode(image_file.read()), 'utf-8'), "content_type": "image"},
|
messages: List[BaseMessage],
|
||||||
{"role": "user", "content": question}
|
stop: Optional[List[str]] = None,
|
||||||
]
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
|
default_chunk_class = AIMessageChunk
|
||||||
|
|
||||||
data = {
|
self.client.arun(
|
||||||
"header": {
|
[convert_message_to_dict(m) for m in messages],
|
||||||
"app_id": self.spark_app_id
|
self.spark_user_id,
|
||||||
},
|
self.model_kwargs,
|
||||||
"parameter": {
|
streaming=True,
|
||||||
"chat": {
|
)
|
||||||
"domain": "image",
|
for content in self.client.subscribe(timeout=self.request_timeout):
|
||||||
"temperature": 0.5,
|
if "data" not in content:
|
||||||
"top_k": 4,
|
continue
|
||||||
"max_tokens": 2028,
|
delta = content["data"]
|
||||||
"auditing": "default"
|
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||||
}
|
cg_chunk = ChatGenerationChunk(message=chunk)
|
||||||
},
|
if run_manager:
|
||||||
"payload": {
|
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
|
||||||
"message": {
|
yield cg_chunk
|
||||||
"text": text
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
d = json.dumps(data)
|
|
||||||
await ws.send(d)
|
|
||||||
|
|
||||||
def is_cache_model(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_len(text):
|
|
||||||
length = 0
|
|
||||||
for content in text:
|
|
||||||
temp = content["content"]
|
|
||||||
leng = len(temp)
|
|
||||||
length += leng
|
|
||||||
return length
|
|
||||||
|
|
||||||
def check_len(self, text):
|
|
||||||
print("text-content-tokens:", self.get_len(text[1:]))
|
|
||||||
while (self.get_len(text[1:]) > 8000):
|
|
||||||
del text[1]
|
|
||||||
return text
|
|
||||||
|
|||||||
@ -37,7 +37,6 @@ model_info_list = [
|
|||||||
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||||
ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
||||||
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
|
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
|
||||||
ModelInfo('image', '', ModelTypeConst.IMAGE, image_model_credential, XFSparkImage),
|
|
||||||
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
|
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -286,6 +286,14 @@ const getApplicationTTSModel: (
|
|||||||
return get(`${prefix}/${application_id}/model`, { model_type: 'TTS' }, loading)
|
return get(`${prefix}/${application_id}/model`, { model_type: 'TTS' }, loading)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const getApplicationImageModel: (
|
||||||
|
application_id: string,
|
||||||
|
loading?: Ref<boolean>
|
||||||
|
) => Promise<Result<Array<any>>> = (application_id, loading) => {
|
||||||
|
return get(`${prefix}/${application_id}/model`, { model_type: 'IMAGE' }, loading)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 发布应用
|
* 发布应用
|
||||||
* @param 参数
|
* @param 参数
|
||||||
@ -350,6 +358,19 @@ const getModelParamsForm: (
|
|||||||
return get(`${prefix}/${application_id}/model_params_form/${model_id}`, undefined, loading)
|
return get(`${prefix}/${application_id}/model_params_form/${model_id}`, undefined, loading)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 上传文档图片附件
|
||||||
|
*/
|
||||||
|
const uploadFile: (
|
||||||
|
application_id: String,
|
||||||
|
chat_id: String,
|
||||||
|
data: any,
|
||||||
|
loading?: Ref<boolean>
|
||||||
|
) => Promise<Result<any>> = (application_id, chat_id, data, loading) => {
|
||||||
|
return post(`${prefix}/${application_id}/chat/${chat_id}/upload_file`, data, undefined, loading)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 语音转文本
|
* 语音转文本
|
||||||
*/
|
*/
|
||||||
@ -501,6 +522,7 @@ export default {
|
|||||||
getApplicationRerankerModel,
|
getApplicationRerankerModel,
|
||||||
getApplicationSTTModel,
|
getApplicationSTTModel,
|
||||||
getApplicationTTSModel,
|
getApplicationTTSModel,
|
||||||
|
getApplicationImageModel,
|
||||||
postSpeechToText,
|
postSpeechToText,
|
||||||
postTextToSpeech,
|
postTextToSpeech,
|
||||||
getPlatformStatus,
|
getPlatformStatus,
|
||||||
@ -513,5 +535,6 @@ export default {
|
|||||||
putWorkFlowVersion,
|
putWorkFlowVersion,
|
||||||
playDemoText,
|
playDemoText,
|
||||||
getUserList,
|
getUserList,
|
||||||
getApplicationList
|
getApplicationList,
|
||||||
|
uploadFile
|
||||||
}
|
}
|
||||||
|
|||||||
@ -54,6 +54,18 @@
|
|||||||
<div v-for="(f, i) in item.global_fields" :key="i">
|
<div v-for="(f, i) in item.global_fields" :key="i">
|
||||||
{{ f.label }}: {{ f.value }}
|
{{ f.label }}: {{ f.value }}
|
||||||
</div>
|
</div>
|
||||||
|
<div v-if="item.document_list?.length > 0">
|
||||||
|
上传的文档:
|
||||||
|
<div v-for="(f, i) in item.document_list" :key="i">
|
||||||
|
{{ f.name }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div v-if="item.image_list?.length > 0">
|
||||||
|
上传的图片:
|
||||||
|
<div v-for="(f, i) in item.image_list" :key="i">
|
||||||
|
{{ f.name }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
@ -96,7 +108,8 @@
|
|||||||
v-if="
|
v-if="
|
||||||
item.type == WorkflowType.AiChat ||
|
item.type == WorkflowType.AiChat ||
|
||||||
item.type == WorkflowType.Question ||
|
item.type == WorkflowType.Question ||
|
||||||
item.type == WorkflowType.Application
|
item.type == WorkflowType.Application ||
|
||||||
|
item.type == WorkflowType.ImageUnderstandNode
|
||||||
"
|
"
|
||||||
>
|
>
|
||||||
<div
|
<div
|
||||||
|
|||||||
@ -192,6 +192,19 @@
|
|||||||
/>
|
/>
|
||||||
|
|
||||||
<div class="operate flex align-center">
|
<div class="operate flex align-center">
|
||||||
|
<span v-if="props.data.file_upload_enable" class="flex align-center">
|
||||||
|
<el-upload
|
||||||
|
action="#"
|
||||||
|
:auto-upload="false"
|
||||||
|
:show-file-list="false"
|
||||||
|
:on-change="(file: any, fileList: any) => uploadFile(file, fileList)"
|
||||||
|
>
|
||||||
|
<el-button text>
|
||||||
|
<el-icon><Paperclip /></el-icon>
|
||||||
|
</el-button>
|
||||||
|
</el-upload>
|
||||||
|
<el-divider direction="vertical" />
|
||||||
|
</span>
|
||||||
<span v-if="props.data.stt_model_enable" class="flex align-center">
|
<span v-if="props.data.stt_model_enable" class="flex align-center">
|
||||||
<el-button text v-if="mediaRecorderStatus" @click="startRecording">
|
<el-button text v-if="mediaRecorderStatus" @click="startRecording">
|
||||||
<el-icon>
|
<el-icon>
|
||||||
@ -790,7 +803,7 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean) {
|
|||||||
is_stop: false,
|
is_stop: false,
|
||||||
record_id: '',
|
record_id: '',
|
||||||
vote_status: '-1',
|
vote_status: '-1',
|
||||||
status: undefined
|
status: undefined,
|
||||||
})
|
})
|
||||||
chatList.value.push(chat)
|
chatList.value.push(chat)
|
||||||
ChatManagement.addChatRecord(chat, 50, loading)
|
ChatManagement.addChatRecord(chat, 50, loading)
|
||||||
@ -809,7 +822,8 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean) {
|
|||||||
const obj = {
|
const obj = {
|
||||||
message: chat.problem_text,
|
message: chat.problem_text,
|
||||||
re_chat: re_chat || false,
|
re_chat: re_chat || false,
|
||||||
form_data: { ...form_data.value, ...api_form_data.value }
|
form_data: { ...form_data.value, ...api_form_data.value },
|
||||||
|
image_list: uploadFileList.value,
|
||||||
}
|
}
|
||||||
// 对话
|
// 对话
|
||||||
applicationApi
|
applicationApi
|
||||||
@ -832,6 +846,7 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean) {
|
|||||||
nextTick(() => {
|
nextTick(() => {
|
||||||
// 将滚动条滚动到最下面
|
// 将滚动条滚动到最下面
|
||||||
scrollDiv.value.setScrollTop(getMaxHeight())
|
scrollDiv.value.setScrollTop(getMaxHeight())
|
||||||
|
uploadFileList.value = []
|
||||||
})
|
})
|
||||||
const reader = response.body.getReader()
|
const reader = response.body.getReader()
|
||||||
// 处理流数据
|
// 处理流数据
|
||||||
@ -916,6 +931,42 @@ const handleScroll = () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 保存上传文件列表
|
||||||
|
const uploadFileList = ref<any>([])
|
||||||
|
const uploadFile = async (file: any, fileList: any) => {
|
||||||
|
const { maxFiles, fileLimit } = props.data.file_upload_setting
|
||||||
|
if (fileList.length > maxFiles) {
|
||||||
|
MsgWarning('最多上传' + maxFiles + '个文件')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if (fileList.filter((f: any) => f.size > fileLimit * 1024 * 1024).length > 0) { // MB
|
||||||
|
MsgWarning('单个文件大小不能超过' + fileLimit + 'MB')
|
||||||
|
fileList.splice(0, fileList.length)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
const formData = new FormData()
|
||||||
|
for (const file of fileList) {
|
||||||
|
formData.append('file', file.raw, file.name)
|
||||||
|
uploadFileList.value.push(file)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (props.chatId === 'new' || !chartOpenId.value) {
|
||||||
|
const res = await applicationApi.getChatOpen(props.data.id as string)
|
||||||
|
chartOpenId.value = res.data
|
||||||
|
}
|
||||||
|
applicationApi.uploadFile(props.data.id as string, chartOpenId.value, formData, loading).then((response) => {
|
||||||
|
fileList.splice(0, fileList.length)
|
||||||
|
uploadFileList.value.forEach((file: any) => {
|
||||||
|
const f = response.data.filter((f: any) => f.name === file.name)
|
||||||
|
if (f.length > 0) {
|
||||||
|
file.url = f[0].url
|
||||||
|
file.file_id = f[0].file_id
|
||||||
|
}
|
||||||
|
})
|
||||||
|
console.log(uploadFileList.value)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// 定义响应式引用
|
// 定义响应式引用
|
||||||
const mediaRecorder = ref<any>(null)
|
const mediaRecorder = ref<any>(null)
|
||||||
|
|
||||||
|
|||||||
@ -9,5 +9,7 @@ export enum WorkflowType {
|
|||||||
FunctionLib = 'function-lib-node',
|
FunctionLib = 'function-lib-node',
|
||||||
FunctionLibCustom = 'function-node',
|
FunctionLibCustom = 'function-node',
|
||||||
RrerankerNode = 'reranker-node',
|
RrerankerNode = 'reranker-node',
|
||||||
Application = 'application-node'
|
Application = 'application-node',
|
||||||
|
DocumentExtractNode = 'document-extract-node',
|
||||||
|
ImageUnderstandNode = 'image-understand-node',
|
||||||
}
|
}
|
||||||
|
|||||||
@ -168,13 +168,49 @@ export const rerankerNode = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
export const documentExtractNode = {
|
||||||
|
type: WorkflowType.DocumentExtractNode,
|
||||||
|
text: '提取文档中的内容',
|
||||||
|
label: '文档内容提取',
|
||||||
|
height: 252,
|
||||||
|
properties: {
|
||||||
|
stepName: '文档内容提取',
|
||||||
|
config: {
|
||||||
|
fields: [
|
||||||
|
{
|
||||||
|
label: '文件内容',
|
||||||
|
value: 'content'
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
export const imageUnderstandNode = {
|
||||||
|
type: WorkflowType.ImageUnderstandNode,
|
||||||
|
text: '识别出图片中的对象、场景等信息回答用户问题',
|
||||||
|
label: '图片理解',
|
||||||
|
height: 252,
|
||||||
|
properties: {
|
||||||
|
stepName: '图片理解',
|
||||||
|
config: {
|
||||||
|
fields: [
|
||||||
|
{
|
||||||
|
label: 'AI 回答内容',
|
||||||
|
value: 'content'
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
export const menuNodes = [
|
export const menuNodes = [
|
||||||
aiChatNode,
|
aiChatNode,
|
||||||
searchDatasetNode,
|
searchDatasetNode,
|
||||||
questionNode,
|
questionNode,
|
||||||
conditionNode,
|
conditionNode,
|
||||||
replyNode,
|
replyNode,
|
||||||
rerankerNode
|
rerankerNode,
|
||||||
|
documentExtractNode,
|
||||||
|
imageUnderstandNode
|
||||||
]
|
]
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -261,7 +297,9 @@ export const nodeDict: any = {
|
|||||||
[WorkflowType.FunctionLib]: functionLibNode,
|
[WorkflowType.FunctionLib]: functionLibNode,
|
||||||
[WorkflowType.FunctionLibCustom]: functionNode,
|
[WorkflowType.FunctionLibCustom]: functionNode,
|
||||||
[WorkflowType.RrerankerNode]: rerankerNode,
|
[WorkflowType.RrerankerNode]: rerankerNode,
|
||||||
[WorkflowType.Application]: applicationNode
|
[WorkflowType.Application]: applicationNode,
|
||||||
|
[WorkflowType.DocumentExtractNode]: documentExtractNode,
|
||||||
|
[WorkflowType.ImageUnderstandNode]: imageUnderstandNode,
|
||||||
}
|
}
|
||||||
export function isWorkFlow(type: string | undefined) {
|
export function isWorkFlow(type: string | undefined) {
|
||||||
return type === 'WORK_FLOW'
|
return type === 'WORK_FLOW'
|
||||||
|
|||||||
6
ui/src/workflow/icons/document-extract-node-icon.vue
Normal file
6
ui/src/workflow/icons/document-extract-node-icon.vue
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<template>
|
||||||
|
<AppAvatar shape="square" style="background: #7F3BF5">
|
||||||
|
<img src="@/assets/icon_document.svg" style="width: 65%" alt="" />
|
||||||
|
</AppAvatar>
|
||||||
|
</template>
|
||||||
|
<script setup lang="ts"></script>
|
||||||
6
ui/src/workflow/icons/image-understand-node-icon.vue
Normal file
6
ui/src/workflow/icons/image-understand-node-icon.vue
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<template>
|
||||||
|
<AppAvatar shape="square" style="background: #14C0FF;">
|
||||||
|
<img src="@/assets/icon_document.svg" style="width: 65%" alt="" />
|
||||||
|
</AppAvatar>
|
||||||
|
</template>
|
||||||
|
<script setup lang="ts"></script>
|
||||||
@ -0,0 +1,101 @@
|
|||||||
|
<template>
|
||||||
|
<el-dialog
|
||||||
|
title="文件上传设置"
|
||||||
|
v-model="dialogVisible"
|
||||||
|
:close-on-click-modal="false"
|
||||||
|
:close-on-press-escape="false"
|
||||||
|
:destroy-on-close="true"
|
||||||
|
:before-close="close"
|
||||||
|
append-to-body
|
||||||
|
>
|
||||||
|
<el-form
|
||||||
|
label-position="top"
|
||||||
|
ref="fieldFormRef"
|
||||||
|
:model="form_data"
|
||||||
|
require-asterisk-position="right">
|
||||||
|
<el-form-item label="单次上传最多文件数">
|
||||||
|
<el-slider v-model="form_data.maxFiles" show-input :show-input-controls="false" :min="1" :max="10" />
|
||||||
|
</el-form-item>
|
||||||
|
<el-form-item label="每个文件最大(MB)">
|
||||||
|
<el-slider v-model="form_data.fileLimit" show-input :show-input-controls="false" :min="1" :max="100" />
|
||||||
|
</el-form-item>
|
||||||
|
<el-form-item label="上传的文件类型">
|
||||||
|
<el-card style="width: 100%" class="mb-8">
|
||||||
|
<div class="flex-between">
|
||||||
|
<p>
|
||||||
|
文档(TXT、MD、DOCX、HTML、CSV、XLSX、XLS、PDF)
|
||||||
|
需要与文档内容提取节点配合使用
|
||||||
|
</p>
|
||||||
|
<el-checkbox v-model="form_data.document" />
|
||||||
|
</div>
|
||||||
|
</el-card>
|
||||||
|
<el-card style="width: 100%" class="mb-8">
|
||||||
|
<div class="flex-between">
|
||||||
|
<p>
|
||||||
|
图片(JPG、JPEG、PNG、GIF)
|
||||||
|
所选模型需要支持接收图片
|
||||||
|
</p>
|
||||||
|
<el-checkbox v-model="form_data.image" />
|
||||||
|
</div>
|
||||||
|
</el-card>
|
||||||
|
</el-form-item>
|
||||||
|
</el-form>
|
||||||
|
<template #footer>
|
||||||
|
<span class="dialog-footer">
|
||||||
|
<el-button @click.prevent="close"> 取消 </el-button>
|
||||||
|
<el-button type="primary" @click="submit()" :loading="loading">
|
||||||
|
确定
|
||||||
|
</el-button>
|
||||||
|
</span>
|
||||||
|
</template>
|
||||||
|
</el-dialog>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { nextTick, ref } from 'vue'
|
||||||
|
|
||||||
|
const emit = defineEmits(['refresh'])
|
||||||
|
const props = defineProps<{ nodeModel: any }>()
|
||||||
|
|
||||||
|
const dialogVisible = ref(false)
|
||||||
|
const loading = ref(false)
|
||||||
|
const fieldFormRef = ref()
|
||||||
|
const form_data = ref({
|
||||||
|
maxFiles: 3,
|
||||||
|
fileLimit: 50,
|
||||||
|
document: true,
|
||||||
|
image: false,
|
||||||
|
audio: false,
|
||||||
|
video: false
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
function open(data: any) {
|
||||||
|
dialogVisible.value = true
|
||||||
|
nextTick(() => {
|
||||||
|
form_data.value = { ...form_data.value, ...data }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
function close() {
|
||||||
|
dialogVisible.value = false
|
||||||
|
}
|
||||||
|
|
||||||
|
async function submit() {
|
||||||
|
const formEl = fieldFormRef.value
|
||||||
|
if (!formEl) return
|
||||||
|
await formEl.validate().then(() => {
|
||||||
|
emit('refresh', form_data.value)
|
||||||
|
props.nodeModel.graphModel.eventCenter.emit('refreshFileUploadConfig')
|
||||||
|
dialogVisible.value = false
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
defineExpose({
|
||||||
|
open
|
||||||
|
})
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<style scoped lang="scss">
|
||||||
|
|
||||||
|
</style>
|
||||||
@ -45,6 +45,36 @@
|
|||||||
@submitDialog="submitDialog"
|
@submitDialog="submitDialog"
|
||||||
/>
|
/>
|
||||||
</el-form-item>
|
</el-form-item>
|
||||||
|
<el-form-item >
|
||||||
|
<template #label>
|
||||||
|
<div class="flex-between">
|
||||||
|
<div class="flex align-center">
|
||||||
|
<span class="mr-4">文件上传</span>
|
||||||
|
<el-tooltip
|
||||||
|
effect="dark"
|
||||||
|
content="开启后,问答页面会显示上传文件的按钮。"
|
||||||
|
placement="right"
|
||||||
|
>
|
||||||
|
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||||
|
</el-tooltip>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<el-button
|
||||||
|
v-if="form_data.file_upload_enable"
|
||||||
|
type="primary"
|
||||||
|
link
|
||||||
|
@click="openFileUploadSettingDialog"
|
||||||
|
class="mr-4"
|
||||||
|
>
|
||||||
|
<el-icon class="mr-4">
|
||||||
|
<Setting />
|
||||||
|
</el-icon>
|
||||||
|
</el-button>
|
||||||
|
<el-switch size="small" v-model="form_data.file_upload_enable"/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
</el-form-item>
|
||||||
<UserInputFieldTable ref="UserInputFieldTableFef" :node-model="nodeModel" />
|
<UserInputFieldTable ref="UserInputFieldTableFef" :node-model="nodeModel" />
|
||||||
<ApiInputFieldTable ref="ApiInputFieldTableFef" :node-model="nodeModel" />
|
<ApiInputFieldTable ref="ApiInputFieldTableFef" :node-model="nodeModel" />
|
||||||
<el-form-item>
|
<el-form-item>
|
||||||
@ -139,7 +169,6 @@
|
|||||||
<el-icon class="mr-4">
|
<el-icon class="mr-4">
|
||||||
<Setting />
|
<Setting />
|
||||||
</el-icon>
|
</el-icon>
|
||||||
设置
|
|
||||||
</el-button>
|
</el-button>
|
||||||
<el-switch size="small" v-model="form_data.tts_model_enable" @change="ttsModelEnableChange"/>
|
<el-switch size="small" v-model="form_data.tts_model_enable" @change="ttsModelEnableChange"/>
|
||||||
</div>
|
</div>
|
||||||
@ -212,6 +241,7 @@
|
|||||||
</el-form-item>
|
</el-form-item>
|
||||||
</el-form>
|
</el-form>
|
||||||
<TTSModeParamSettingDialog ref="TTSModeParamSettingDialogRef" @refresh="refreshTTSForm" />
|
<TTSModeParamSettingDialog ref="TTSModeParamSettingDialogRef" @refresh="refreshTTSForm" />
|
||||||
|
<FileUploadSettingDialog ref="FileUploadSettingDialogRef" :node-model="nodeModel" @refresh="refreshFileUploadForm"/>
|
||||||
</NodeContainer>
|
</NodeContainer>
|
||||||
</template>
|
</template>
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
@ -229,6 +259,7 @@ import { t } from '@/locales'
|
|||||||
import TTSModeParamSettingDialog from '@/views/application/component/TTSModeParamSettingDialog.vue'
|
import TTSModeParamSettingDialog from '@/views/application/component/TTSModeParamSettingDialog.vue'
|
||||||
import ApiInputFieldTable from './component/ApiInputFieldTable.vue'
|
import ApiInputFieldTable from './component/ApiInputFieldTable.vue'
|
||||||
import UserInputFieldTable from './component/UserInputFieldTable.vue'
|
import UserInputFieldTable from './component/UserInputFieldTable.vue'
|
||||||
|
import FileUploadSettingDialog from '@/workflow/nodes/base-node/component/FileUploadSettingDialog.vue'
|
||||||
|
|
||||||
const { model } = useStore()
|
const { model } = useStore()
|
||||||
|
|
||||||
@ -244,6 +275,7 @@ const providerOptions = ref<Array<Provider>>([])
|
|||||||
const TTSModeParamSettingDialogRef = ref<InstanceType<typeof TTSModeParamSettingDialog>>()
|
const TTSModeParamSettingDialogRef = ref<InstanceType<typeof TTSModeParamSettingDialog>>()
|
||||||
const UserInputFieldTableFef = ref()
|
const UserInputFieldTableFef = ref()
|
||||||
const ApiInputFieldTableFef = ref()
|
const ApiInputFieldTableFef = ref()
|
||||||
|
const FileUploadSettingDialogRef = ref<InstanceType<typeof FileUploadSettingDialog>>()
|
||||||
|
|
||||||
const form = {
|
const form = {
|
||||||
name: '',
|
name: '',
|
||||||
@ -350,6 +382,14 @@ const refreshTTSForm = (data: any) => {
|
|||||||
form_data.value.tts_model_params_setting = data
|
form_data.value.tts_model_params_setting = data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const openFileUploadSettingDialog = () => {
|
||||||
|
FileUploadSettingDialogRef.value?.open(form_data.value.file_upload_setting)
|
||||||
|
}
|
||||||
|
|
||||||
|
const refreshFileUploadForm = (data: any) => {
|
||||||
|
form_data.value.file_upload_setting = data
|
||||||
|
}
|
||||||
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
set(props.nodeModel, 'validate', validate)
|
set(props.nodeModel, 'validate', validate)
|
||||||
if (!props.nodeModel.properties.node_data.tts_type) {
|
if (!props.nodeModel.properties.node_data.tts_type) {
|
||||||
|
|||||||
12
ui/src/workflow/nodes/document-extract-node/index.ts
Normal file
12
ui/src/workflow/nodes/document-extract-node/index.ts
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
import DocumentExtractNodeVue from './index.vue'
|
||||||
|
import { AppNode, AppNodeModel } from '@/workflow/common/app-node'
|
||||||
|
class RerankerNode extends AppNode {
|
||||||
|
constructor(props: any) {
|
||||||
|
super(props, DocumentExtractNodeVue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
export default {
|
||||||
|
type: 'document-extract-node',
|
||||||
|
model: AppNodeModel,
|
||||||
|
view: RerankerNode
|
||||||
|
}
|
||||||
64
ui/src/workflow/nodes/document-extract-node/index.vue
Normal file
64
ui/src/workflow/nodes/document-extract-node/index.vue
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
<template>
|
||||||
|
<NodeContainer :nodeModel="nodeModel">
|
||||||
|
<h5 class="title-decoration-1 mb-8">节点设置</h5>
|
||||||
|
<el-card shadow="never" class="card-never">
|
||||||
|
<el-form
|
||||||
|
@submit.prevent
|
||||||
|
:model="form_data"
|
||||||
|
label-position="top"
|
||||||
|
require-asterisk-position="right"
|
||||||
|
label-width="auto"
|
||||||
|
ref="DatasetNodeFormRef"
|
||||||
|
>
|
||||||
|
<el-form-item label="选择文件" :rules="{
|
||||||
|
type: 'array',
|
||||||
|
required: true,
|
||||||
|
message: '请选择文件',
|
||||||
|
trigger: 'change'
|
||||||
|
}"
|
||||||
|
>
|
||||||
|
<NodeCascader
|
||||||
|
ref="nodeCascaderRef"
|
||||||
|
:nodeModel="nodeModel"
|
||||||
|
class="w-full"
|
||||||
|
placeholder="请选择文件"
|
||||||
|
v-model="form.file_list"
|
||||||
|
/>
|
||||||
|
</el-form-item>
|
||||||
|
</el-form>
|
||||||
|
</el-card>
|
||||||
|
</NodeContainer>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import NodeContainer from '@/workflow/common/NodeContainer.vue'
|
||||||
|
import { computed } from 'vue'
|
||||||
|
import { set } from 'lodash'
|
||||||
|
import NodeCascader from '@/workflow/common/NodeCascader.vue'
|
||||||
|
|
||||||
|
const props = defineProps<{ nodeModel: any }>()
|
||||||
|
|
||||||
|
const form = {
|
||||||
|
file_list: []
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
const form_data = computed({
|
||||||
|
get: () => {
|
||||||
|
if (props.nodeModel.properties.node_data) {
|
||||||
|
return props.nodeModel.properties.node_data
|
||||||
|
} else {
|
||||||
|
set(props.nodeModel.properties, 'node_data', form)
|
||||||
|
}
|
||||||
|
return props.nodeModel.properties.node_data
|
||||||
|
},
|
||||||
|
set: (value) => {
|
||||||
|
set(props.nodeModel.properties, 'node_data', value)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<style lang="scss" scoped>
|
||||||
|
|
||||||
|
</style>
|
||||||
14
ui/src/workflow/nodes/image-understand/index.ts
Normal file
14
ui/src/workflow/nodes/image-understand/index.ts
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import ImageUnderstandNodeVue from './index.vue'
|
||||||
|
import { AppNode, AppNodeModel } from '@/workflow/common/app-node'
|
||||||
|
|
||||||
|
class RerankerNode extends AppNode {
|
||||||
|
constructor(props: any) {
|
||||||
|
super(props, ImageUnderstandNodeVue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export default {
|
||||||
|
type: 'image-understand-node',
|
||||||
|
model: AppNodeModel,
|
||||||
|
view: RerankerNode
|
||||||
|
}
|
||||||
277
ui/src/workflow/nodes/image-understand/index.vue
Normal file
277
ui/src/workflow/nodes/image-understand/index.vue
Normal file
@ -0,0 +1,277 @@
|
|||||||
|
<template>
|
||||||
|
<NodeContainer :node-model="nodeModel">
|
||||||
|
<h5 class="title-decoration-1 mb-8">节点设置</h5>
|
||||||
|
<el-card shadow="never" class="card-never">
|
||||||
|
<el-form
|
||||||
|
@submit.prevent
|
||||||
|
:model="form_data"
|
||||||
|
label-position="top"
|
||||||
|
require-asterisk-position="right"
|
||||||
|
class="mb-24"
|
||||||
|
label-width="auto"
|
||||||
|
ref="aiChatNodeFormRef"
|
||||||
|
hide-required-asterisk
|
||||||
|
>
|
||||||
|
<el-form-item
|
||||||
|
label="图片理解模型"
|
||||||
|
prop="model_id"
|
||||||
|
:rules="{
|
||||||
|
required: true,
|
||||||
|
message: '请选择图片理解模型',
|
||||||
|
trigger: 'change'
|
||||||
|
}"
|
||||||
|
>
|
||||||
|
<template #label>
|
||||||
|
<div class="flex-between w-full">
|
||||||
|
<div>
|
||||||
|
<span>图片理解模型<span class="danger">*</span></span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<el-select
|
||||||
|
@change="model_change"
|
||||||
|
@wheel="wheel"
|
||||||
|
:teleported="false"
|
||||||
|
v-model="form_data.model_id"
|
||||||
|
placeholder="请选择图片理解模型"
|
||||||
|
class="w-full"
|
||||||
|
popper-class="select-model"
|
||||||
|
:clearable="true"
|
||||||
|
>
|
||||||
|
<el-option-group
|
||||||
|
v-for="(value, label) in modelOptions"
|
||||||
|
:key="value"
|
||||||
|
:label="relatedObject(providerOptions, label, 'provider')?.name"
|
||||||
|
>
|
||||||
|
<el-option
|
||||||
|
v-for="item in value.filter((v: any) => v.status === 'SUCCESS')"
|
||||||
|
:key="item.id"
|
||||||
|
:label="item.name"
|
||||||
|
:value="item.id"
|
||||||
|
class="flex-between"
|
||||||
|
>
|
||||||
|
<div class="flex align-center">
|
||||||
|
<span
|
||||||
|
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
|
||||||
|
class="model-icon mr-8"
|
||||||
|
></span>
|
||||||
|
<span>{{ item.name }}</span>
|
||||||
|
<el-tag v-if="item.permission_type === 'PUBLIC'" type="info" class="info-tag ml-8"
|
||||||
|
>公用
|
||||||
|
</el-tag>
|
||||||
|
</div>
|
||||||
|
<el-icon class="check-icon" v-if="item.id === form_data.model_id">
|
||||||
|
<Check />
|
||||||
|
</el-icon>
|
||||||
|
</el-option>
|
||||||
|
<!-- 不可用 -->
|
||||||
|
<el-option
|
||||||
|
v-for="item in value.filter((v: any) => v.status !== 'SUCCESS')"
|
||||||
|
:key="item.id"
|
||||||
|
:label="item.name"
|
||||||
|
:value="item.id"
|
||||||
|
class="flex-between"
|
||||||
|
disabled
|
||||||
|
>
|
||||||
|
<div class="flex">
|
||||||
|
<span
|
||||||
|
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
|
||||||
|
class="model-icon mr-8"
|
||||||
|
></span>
|
||||||
|
<span>{{ item.name }}</span>
|
||||||
|
<span class="danger">(不可用)</span>
|
||||||
|
</div>
|
||||||
|
<el-icon class="check-icon" v-if="item.id === form_data.model_id">
|
||||||
|
<Check />
|
||||||
|
</el-icon>
|
||||||
|
</el-option>
|
||||||
|
</el-option-group>
|
||||||
|
</el-select>
|
||||||
|
</el-form-item>
|
||||||
|
|
||||||
|
<el-form-item label="角色设定">
|
||||||
|
<MdEditorMagnify
|
||||||
|
title="角色设定"
|
||||||
|
v-model="form_data.system"
|
||||||
|
style="height: 100px"
|
||||||
|
@submitDialog="submitSystemDialog"
|
||||||
|
placeholder="角色设定"
|
||||||
|
/>
|
||||||
|
</el-form-item>
|
||||||
|
<el-form-item
|
||||||
|
label="提示词"
|
||||||
|
prop="prompt"
|
||||||
|
:rules="{
|
||||||
|
required: true,
|
||||||
|
message: '请输入提示词',
|
||||||
|
trigger: 'blur'
|
||||||
|
}"
|
||||||
|
>
|
||||||
|
<template #label>
|
||||||
|
<div class="flex align-center">
|
||||||
|
<div class="mr-4">
|
||||||
|
<span>提示词<span class="danger">*</span></span>
|
||||||
|
</div>
|
||||||
|
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
|
||||||
|
<template #content
|
||||||
|
>通过调整提示词内容,可以引导大模型聊天方向,该提示词会被固定在上下文的开头,可以使用变量。
|
||||||
|
</template>
|
||||||
|
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||||
|
</el-tooltip>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<MdEditorMagnify
|
||||||
|
@wheel="wheel"
|
||||||
|
title="提示词"
|
||||||
|
v-model="form_data.prompt"
|
||||||
|
style="height: 150px"
|
||||||
|
@submitDialog="submitDialog"
|
||||||
|
/>
|
||||||
|
</el-form-item>
|
||||||
|
<el-form-item label="历史聊天记录">
|
||||||
|
<el-input-number
|
||||||
|
v-model="form_data.dialogue_number"
|
||||||
|
:min="0"
|
||||||
|
:value-on-clear="0"
|
||||||
|
controls-position="right"
|
||||||
|
class="w-full"
|
||||||
|
/>
|
||||||
|
</el-form-item>
|
||||||
|
<el-form-item label="选择图片" :rules="{
|
||||||
|
type: 'array',
|
||||||
|
required: true,
|
||||||
|
message: '请选择图片',
|
||||||
|
trigger: 'change'
|
||||||
|
}"
|
||||||
|
>
|
||||||
|
<NodeCascader
|
||||||
|
ref="nodeCascaderRef"
|
||||||
|
:nodeModel="nodeModel"
|
||||||
|
class="w-full"
|
||||||
|
placeholder="请选择图片"
|
||||||
|
v-model="form_data.image_list"
|
||||||
|
/>
|
||||||
|
</el-form-item>
|
||||||
|
<el-form-item label="返回内容" @click.prevent>
|
||||||
|
<template #label>
|
||||||
|
<div class="flex align-center">
|
||||||
|
<div class="mr-4">
|
||||||
|
<span>返回内容<span class="danger">*</span></span>
|
||||||
|
</div>
|
||||||
|
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
|
||||||
|
<template #content>
|
||||||
|
关闭后该节点的内容则不输出给用户。
|
||||||
|
如果你想让用户看到该节点的输出内容,请打开开关。
|
||||||
|
</template>
|
||||||
|
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||||
|
</el-tooltip>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<el-switch size="small" v-model="form_data.is_result" />
|
||||||
|
</el-form-item>
|
||||||
|
</el-form>
|
||||||
|
</el-card>
|
||||||
|
</NodeContainer>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import NodeContainer from '@/workflow/common/NodeContainer.vue'
|
||||||
|
import { computed, onMounted, ref } from 'vue'
|
||||||
|
import { groupBy, set } from 'lodash'
|
||||||
|
import { relatedObject } from '@/utils/utils'
|
||||||
|
import type { Provider } from '@/api/type/model'
|
||||||
|
import applicationApi from '@/api/application'
|
||||||
|
import { app } from '@/main'
|
||||||
|
import useStore from '@/stores'
|
||||||
|
import NodeCascader from '@/workflow/common/NodeCascader.vue'
|
||||||
|
|
||||||
|
const { model } = useStore()
|
||||||
|
|
||||||
|
const {
|
||||||
|
params: { id }
|
||||||
|
} = app.config.globalProperties.$route as any
|
||||||
|
|
||||||
|
const props = defineProps<{ nodeModel: any }>()
|
||||||
|
const modelOptions = ref<any>(null)
|
||||||
|
const providerOptions = ref<Array<Provider>>([])
|
||||||
|
|
||||||
|
|
||||||
|
const wheel = (e: any) => {
|
||||||
|
if (e.ctrlKey === true) {
|
||||||
|
e.preventDefault()
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
e.stopPropagation()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const defaultPrompt = `{{开始.question}}`
|
||||||
|
|
||||||
|
const form = {
|
||||||
|
model_id: '',
|
||||||
|
system: '',
|
||||||
|
prompt: defaultPrompt,
|
||||||
|
dialogue_number: 0,
|
||||||
|
is_result: true,
|
||||||
|
temperature: null,
|
||||||
|
max_tokens: null,
|
||||||
|
image_list: ["start-node", "image"]
|
||||||
|
}
|
||||||
|
|
||||||
|
const form_data = computed({
|
||||||
|
get: () => {
|
||||||
|
if (props.nodeModel.properties.node_data) {
|
||||||
|
return props.nodeModel.properties.node_data
|
||||||
|
} else {
|
||||||
|
set(props.nodeModel.properties, 'node_data', form)
|
||||||
|
}
|
||||||
|
return props.nodeModel.properties.node_data
|
||||||
|
},
|
||||||
|
set: (value) => {
|
||||||
|
set(props.nodeModel.properties, 'node_data', value)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
function getModel() {
|
||||||
|
if (id) {
|
||||||
|
applicationApi.getApplicationImageModel(id).then((res: any) => {
|
||||||
|
modelOptions.value = groupBy(res?.data, 'provider')
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
model.asyncGetModel().then((res: any) => {
|
||||||
|
modelOptions.value = groupBy(res?.data, 'provider')
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function getProvider() {
|
||||||
|
model.asyncGetProvider().then((res: any) => {
|
||||||
|
providerOptions.value = res?.data
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const model_change = (model_id?: string) => {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
function submitSystemDialog(val: string) {
|
||||||
|
set(props.nodeModel.properties.node_data, 'system', val)
|
||||||
|
}
|
||||||
|
|
||||||
|
function submitDialog(val: string) {
|
||||||
|
set(props.nodeModel.properties.node_data, 'prompt', val)
|
||||||
|
}
|
||||||
|
|
||||||
|
onMounted(() => {
|
||||||
|
getModel()
|
||||||
|
getProvider()
|
||||||
|
})
|
||||||
|
|
||||||
|
</script>
|
||||||
|
|
||||||
|
|
||||||
|
<style scoped lang="scss">
|
||||||
|
|
||||||
|
</style>
|
||||||
@ -61,8 +61,31 @@ const refreshFieldList = () => {
|
|||||||
}
|
}
|
||||||
props.nodeModel.graphModel.eventCenter.on('refreshFieldList', refreshFieldList)
|
props.nodeModel.graphModel.eventCenter.on('refreshFieldList', refreshFieldList)
|
||||||
|
|
||||||
|
const refreshFileUploadConfig = () => {
|
||||||
|
let fields = cloneDeep(props.nodeModel.properties.config.fields)
|
||||||
|
const form_data = props.nodeModel.graphModel.nodes
|
||||||
|
.filter((v: any) => v.id === 'base-node')
|
||||||
|
.map((v: any) => cloneDeep(v.properties.node_data.file_upload_setting))
|
||||||
|
.filter((v: any) => v)
|
||||||
|
if (form_data.length === 0) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fields = fields.filter((item: any) => item.value !== 'image' && item.value !== 'document')
|
||||||
|
let fileUploadFields = []
|
||||||
|
if (form_data[0].document) {
|
||||||
|
fileUploadFields.push({ label: '文档', value: 'document' })
|
||||||
|
}
|
||||||
|
if (form_data[0].image) {
|
||||||
|
fileUploadFields.push({ label: '图片', value: 'image' })
|
||||||
|
}
|
||||||
|
|
||||||
|
set(props.nodeModel.properties.config, 'fields', [...fields, ...fileUploadFields])
|
||||||
|
}
|
||||||
|
props.nodeModel.graphModel.eventCenter.on('refreshFileUploadConfig', refreshFileUploadConfig)
|
||||||
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
refreshFieldList()
|
refreshFieldList()
|
||||||
|
refreshFileUploadConfig()
|
||||||
})
|
})
|
||||||
</script>
|
</script>
|
||||||
<style lang="scss" scoped></style>
|
<style lang="scss" scoped></style>
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user