feat: Support direct output via node parameter settings #846

shaohuzhang1 2024-08-02 14:21:29 +08:00 committed by GitHub
parent 76c1acbabb
commit 35f0c18dd3
13 changed files with 219 additions and 240 deletions

View File

@@ -10,6 +10,7 @@ import time
 from abc import abstractmethod
 from typing import Type, Dict, List
+from django.core import cache
 from django.db.models import QuerySet
 from rest_framework import serializers
@@ -18,7 +19,6 @@ from application.models.api_key_model import ApplicationPublicAccessClient
 from common.constants.authentication_type import AuthenticationType
 from common.field.common import InstanceField
 from common.util.field_message import ErrMessage
-from django.core import cache

 chat_cache = cache.caches['chat_cache']
@@ -27,9 +27,13 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
     if step_variable is not None:
         for key in step_variable:
             node.context[key] = step_variable[key]
+        if workflow.is_result() and 'answer' in step_variable:
+            yield step_variable['answer']
+            workflow.answer += step_variable['answer']
     if global_variable is not None:
         for key in global_variable:
             workflow.context[key] = global_variable[key]
+    node.context['run_time'] = time.time() - node.context['start_time']
 class WorkFlowPostHandler:
@@ -70,18 +74,14 @@ class WorkFlowPostHandler:
 class NodeResult:
-    def __init__(self, node_variable: Dict, workflow_variable: Dict, _to_response=None, _write_context=write_context):
+    def __init__(self, node_variable: Dict, workflow_variable: Dict,
+                 _write_context=write_context):
         self._write_context = _write_context
         self.node_variable = node_variable
         self.workflow_variable = workflow_variable
-        self._to_response = _to_response

     def write_context(self, node, workflow):
-        self._write_context(self.node_variable, self.workflow_variable, node, workflow)
+        return self._write_context(self.node_variable, self.workflow_variable, node, workflow)
-
-    def to_response(self, chat_id, chat_record_id, node, workflow, post_handler: WorkFlowPostHandler):
-        return self._to_response(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow,
-                                 post_handler)

     def is_assertion_result(self):
         return 'branch_id' in self.node_variable
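
The substantive change in this file: `NodeResult` drops the `_to_response` hook, and `write_context` now returns whatever `_write_context` produces, which may be a generator that streams the answer. A minimal runnable sketch of that contract, using `DemoNode`/`DemoWorkflow` as illustrative stand-ins rather than the real MaxKB classes:

```python
from typing import Dict


def demo_write_context(step_variable: Dict, global_variable: Dict, node, workflow):
    # Copy step variables into the node context; if the workflow marks this
    # node as a result node, stream the answer back to the caller.
    if step_variable is not None:
        for key in step_variable:
            node.context[key] = step_variable[key]
        if workflow.is_result() and 'answer' in step_variable:
            yield step_variable['answer']
            workflow.answer += step_variable['answer']


class NodeResult:
    def __init__(self, node_variable: Dict, workflow_variable: Dict,
                 _write_context=demo_write_context):
        self._write_context = _write_context
        self.node_variable = node_variable
        self.workflow_variable = workflow_variable

    def write_context(self, node, workflow):
        # May return None (plain function) or a generator the caller drains.
        return self._write_context(self.node_variable, self.workflow_variable, node, workflow)


class DemoNode:
    def __init__(self):
        self.context = {}


class DemoWorkflow:
    def __init__(self):
        self.answer = ''

    def is_result(self):
        return True


node, workflow = DemoNode(), DemoWorkflow()
gen = NodeResult({'answer': 'hello'}, {}).write_context(node, workflow)
for chunk in gen:          # streaming mode forwards chunks; blocking mode just drains with list(gen)
    print(chunk)           # -> hello
print(workflow.answer)     # -> hello
```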

View File

@@ -22,6 +22,8 @@ class ChatNodeSerializer(serializers.Serializer):
     # Number of rounds of multi-turn dialogue
     dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
+    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))

 class IChatNode(INode):
     type = 'ai-chat-node'

View File

@@ -13,12 +13,25 @@ from typing import List, Dict
 from langchain.schema import HumanMessage, SystemMessage
 from langchain_core.messages import BaseMessage
-from application.flow import tools
 from application.flow.i_step_node import NodeResult, INode
 from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
 from setting.models_provider.tools import get_model_instance_by_model_user_id

+def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
+    chat_model = node_variable.get('chat_model')
+    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
+    answer_tokens = chat_model.get_num_tokens(answer)
+    node.context['message_tokens'] = message_tokens
+    node.context['answer_tokens'] = answer_tokens
+    node.context['answer'] = answer
+    node.context['history_message'] = node_variable['history_message']
+    node.context['question'] = node_variable['question']
+    node.context['run_time'] = time.time() - node.context['start_time']
+    if workflow.is_result():
+        workflow.answer += answer
 def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
     """
     Write context data (streaming)
@@ -31,15 +44,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
     answer = ''
     for chunk in response:
         answer += chunk.content
-    chat_model = node_variable.get('chat_model')
-    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
-    answer_tokens = chat_model.get_num_tokens(answer)
-    node.context['message_tokens'] = message_tokens
-    node.context['answer_tokens'] = answer_tokens
-    node.context['answer'] = answer
-    node.context['history_message'] = node_variable['history_message']
-    node.context['question'] = node_variable['question']
-    node.context['run_time'] = time.time() - node.context['start_time']
+        yield answer
+    _write_context(node_variable, workflow_variable, node, workflow, answer)
 def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@@ -51,71 +57,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
     @param workflow: workflow manager
     """
     response = node_variable.get('result')
-    chat_model = node_variable.get('chat_model')
     answer = response.content
-    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
-    answer_tokens = chat_model.get_num_tokens(answer)
-    node.context['message_tokens'] = message_tokens
-    node.context['answer_tokens'] = answer_tokens
-    node.context['answer'] = answer
-    node.context['history_message'] = node_variable['history_message']
-    node.context['question'] = node_variable['question']
+    _write_context(node_variable, workflow_variable, node, workflow, answer)
-def get_to_response_write_context(node_variable: Dict, node: INode):
-    def _write_context(answer, status=200):
-        chat_model = node_variable.get('chat_model')
-        if status == 200:
-            answer_tokens = chat_model.get_num_tokens(answer)
-            message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
-        else:
-            answer_tokens = 0
-            message_tokens = 0
-            node.err_message = answer
-            node.status = status
-        node.context['message_tokens'] = message_tokens
-        node.context['answer_tokens'] = answer_tokens
-        node.context['answer'] = answer
-        node.context['run_time'] = time.time() - node.context['start_time']
-
-    return _write_context
-
-
-def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
-                       post_handler):
-    """
-    Convert streaming data into a streaming response
-    @param chat_id: conversation id
-    @param chat_record_id: chat record id
-    @param node_variable: node data
-    @param workflow_variable: workflow data
-    @param node: node
-    @param workflow: workflow manager
-    @param post_handler: post handler, executed after the result is emitted
-    @return: streaming response
-    """
-    response = node_variable.get('result')
-    _write_context = get_to_response_write_context(node_variable, node)
-    return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
-
-
-def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
-                post_handler):
-    """
-    Convert the result into a response
-    @param chat_id: conversation id
-    @param chat_record_id: chat record id
-    @param node_variable: node data
-    @param workflow_variable: workflow data
-    @param node: node
-    @param workflow: workflow manager
-    @param post_handler: post handler
-    @return: response
-    """
-    response = node_variable.get('result')
-    _write_context = get_to_response_write_context(node_variable, node)
-    return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
 class BaseChatNode(IChatNode):
@@ -132,13 +75,12 @@ class BaseChatNode(IChatNode):
             r = chat_model.stream(message_list)
             return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
                                'history_message': history_message, 'question': question.content}, {},
-                              _write_context=write_context_stream,
-                              _to_response=to_stream_response)
+                              _write_context=write_context_stream)
         else:
             r = chat_model.invoke(message_list)
             return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
                                'history_message': history_message, 'question': question.content}, {},
-                              _write_context=write_context, _to_response=to_response)
+                              _write_context=write_context)

     @staticmethod
     def get_history_message(history_chat_record, dialogue_number):
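
The streaming writer is now a generator: it yields partial output while the LLM stream is consumed, then records token usage and timing once the stream is exhausted. A rough, runnable illustration with stand-ins (`StubChatModel` implements only the two token-counting methods this file calls; it is not the real langchain interface, and the signature is simplified):

```python
import time
from typing import Dict


class StubChatModel:
    # Stand-in: mimics only the token-counting methods used above.
    def get_num_tokens_from_messages(self, message_list):
        return sum(len(str(m).split()) for m in message_list)

    def get_num_tokens(self, text):
        return len(text.split())


class Chunk:
    def __init__(self, content):
        self.content = content


class DemoNode:
    def __init__(self):
        self.context = {'start_time': time.time()}


def write_context_stream(node_variable: Dict, node):
    # Yield the accumulated answer while draining the stream, then record
    # token usage and timing once the stream is exhausted.
    response = node_variable.get('result')
    answer = ''
    for chunk in response:
        answer += chunk.content
        yield answer
    node.context['answer'] = answer
    node.context['answer_tokens'] = node_variable['chat_model'].get_num_tokens(answer)
    node.context['run_time'] = time.time() - node.context['start_time']


node = DemoNode()
stream = iter([Chunk('Hello '), Chunk('world')])
for partial in write_context_stream({'result': stream, 'chat_model': StubChatModel()}, node):
    print(repr(partial))              # 'Hello ' then 'Hello world'
print(node.context['answer_tokens'])  # 2
```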

View File

@@ -20,6 +20,7 @@ class ReplyNodeParamsSerializer(serializers.Serializer):
     fields = serializers.ListField(required=False, error_messages=ErrMessage.list("引用字段"))
     content = serializers.CharField(required=False, allow_blank=True, allow_null=True,
                                     error_messages=ErrMessage.char("直接回答内容"))
+    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))

     def is_valid(self, *, raise_exception=False):
         super().is_valid(raise_exception=True)

View File

@@ -6,69 +6,19 @@
     @date: 2024/6/11 17:25
     @desc:
 """
-from typing import List, Dict
-from langchain_core.messages import AIMessage, AIMessageChunk
-from application.flow import tools
-from application.flow.i_step_node import NodeResult, INode
+from typing import List
+from application.flow.i_step_node import NodeResult
 from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode
-def get_to_response_write_context(node_variable: Dict, node: INode):
-    def _write_context(answer, status=200):
-        node.context['answer'] = answer
-
-    return _write_context
-
-
-def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
-                       post_handler):
-    """
-    Convert streaming data into a streaming response
-    @param chat_id: conversation id
-    @param chat_record_id: chat record id
-    @param node_variable: node data
-    @param workflow_variable: workflow data
-    @param node: node
-    @param workflow: workflow manager
-    @param post_handler: post handler, executed after the result is emitted
-    @return: streaming response
-    """
-    response = node_variable.get('result')
-    _write_context = get_to_response_write_context(node_variable, node)
-    return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
-
-
-def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
-                post_handler):
-    """
-    Convert the result into a response
-    @param chat_id: conversation id
-    @param chat_record_id: chat record id
-    @param node_variable: node data
-    @param workflow_variable: workflow data
-    @param node: node
-    @param workflow: workflow manager
-    @param post_handler: post handler
-    @return: response
-    """
-    response = node_variable.get('result')
-    _write_context = get_to_response_write_context(node_variable, node)
-    return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
 class BaseReplyNode(IReplyNode):
     def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
         if reply_type == 'referencing':
             result = self.get_reference_content(fields)
         else:
             result = self.generate_reply_content(content)
-        if stream:
-            return NodeResult({'result': iter([AIMessageChunk(content=result)]), 'answer': result}, {},
-                              _to_response=to_stream_response)
-        else:
-            return NodeResult({'result': AIMessage(content=result), 'answer': result}, {}, _to_response=to_response)
+        return NodeResult({'answer': result}, {})

     def generate_reply_content(self, prompt):
         return self.workflow_manage.generate_prompt(prompt)
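
The reply node no longer builds a response itself; it hands `{'answer': result}` to the generic `write_context` in `i_step_node.py`, which only streams the answer when the node is flagged as a result node. A toy check of that behavior with a simplified stand-in writer (not the MaxKB function):

```python
# Stand-in writer: the answer always lands in the node context, but only
# reaches the user when is_result is on.
def write_context(step_variable, node_context, is_result):
    for key in step_variable:
        node_context[key] = step_variable[key]
    if is_result and 'answer' in step_variable:
        yield step_variable['answer']


for flag in (True, False):
    ctx = {}
    chunks = list(write_context({'answer': 'hi'}, ctx, flag))
    print(flag, chunks, ctx)
# True ['hi'] {'answer': 'hi'}
# False [] {'answer': 'hi'}
```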

View File

@@ -22,6 +22,8 @@ class QuestionNodeSerializer(serializers.Serializer):
     # Number of rounds of multi-turn dialogue
     dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
+    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))

 class IQuestionNode(INode):
     type = 'question-node'

View File

@@ -13,12 +13,25 @@ from typing import List, Dict
 from langchain.schema import HumanMessage, SystemMessage
 from langchain_core.messages import BaseMessage
-from application.flow import tools
 from application.flow.i_step_node import NodeResult, INode
 from application.flow.step_node.question_node.i_question_node import IQuestionNode
 from setting.models_provider.tools import get_model_instance_by_model_user_id

+def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
+    chat_model = node_variable.get('chat_model')
+    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
+    answer_tokens = chat_model.get_num_tokens(answer)
+    node.context['message_tokens'] = message_tokens
+    node.context['answer_tokens'] = answer_tokens
+    node.context['answer'] = answer
+    node.context['history_message'] = node_variable['history_message']
+    node.context['question'] = node_variable['question']
+    node.context['run_time'] = time.time() - node.context['start_time']
+    if workflow.is_result():
+        workflow.answer += answer
 def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
     """
     Write context data (streaming)
@@ -31,15 +44,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
     answer = ''
     for chunk in response:
         answer += chunk.content
-    chat_model = node_variable.get('chat_model')
-    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
-    answer_tokens = chat_model.get_num_tokens(answer)
-    node.context['message_tokens'] = message_tokens
-    node.context['answer_tokens'] = answer_tokens
-    node.context['answer'] = answer
-    node.context['history_message'] = node_variable['history_message']
-    node.context['question'] = node_variable['question']
-    node.context['run_time'] = time.time() - node.context['start_time']
+        yield answer
+    _write_context(node_variable, workflow_variable, node, workflow, answer)
 def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@@ -51,71 +57,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
     @param workflow: workflow manager
     """
     response = node_variable.get('result')
-    chat_model = node_variable.get('chat_model')
     answer = response.content
-    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
-    answer_tokens = chat_model.get_num_tokens(answer)
-    node.context['message_tokens'] = message_tokens
-    node.context['answer_tokens'] = answer_tokens
-    node.context['answer'] = answer
-    node.context['history_message'] = node_variable['history_message']
-    node.context['question'] = node_variable['question']
+    _write_context(node_variable, workflow_variable, node, workflow, answer)
-def get_to_response_write_context(node_variable: Dict, node: INode):
-    def _write_context(answer, status=200):
-        chat_model = node_variable.get('chat_model')
-        if status == 200:
-            answer_tokens = chat_model.get_num_tokens(answer)
-            message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
-        else:
-            answer_tokens = 0
-            message_tokens = 0
-            node.err_message = answer
-            node.status = status
-        node.context['message_tokens'] = message_tokens
-        node.context['answer_tokens'] = answer_tokens
-        node.context['answer'] = answer
-        node.context['run_time'] = time.time() - node.context['start_time']
-
-    return _write_context
-
-
-def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
-                       post_handler):
-    """
-    Convert streaming data into a streaming response
-    @param chat_id: conversation id
-    @param chat_record_id: chat record id
-    @param node_variable: node data
-    @param workflow_variable: workflow data
-    @param node: node
-    @param workflow: workflow manager
-    @param post_handler: post handler, executed after the result is emitted
-    @return: streaming response
-    """
-    response = node_variable.get('result')
-    _write_context = get_to_response_write_context(node_variable, node)
-    return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
-
-
-def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
-                post_handler):
-    """
-    Convert the result into a response
-    @param chat_id: conversation id
-    @param chat_record_id: chat record id
-    @param node_variable: node data
-    @param workflow_variable: workflow data
-    @param node: node
-    @param workflow: workflow manager
-    @param post_handler: post handler
-    @return: response
-    """
-    response = node_variable.get('result')
-    _write_context = get_to_response_write_context(node_variable, node)
-    return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
 class BaseQuestionNode(IQuestionNode):
@@ -131,15 +74,13 @@ class BaseQuestionNode(IQuestionNode):
         if stream:
             r = chat_model.stream(message_list)
             return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
-                               'get_to_response_write_context': get_to_response_write_context,
                                'history_message': history_message, 'question': question.content}, {},
-                              _write_context=write_context_stream,
-                              _to_response=to_stream_response)
+                              _write_context=write_context_stream)
         else:
             r = chat_model.invoke(message_list)
             return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
                                'history_message': history_message, 'question': question.content}, {},
-                              _write_context=write_context, _to_response=to_response)
+                              _write_context=write_context)

     @staticmethod
     def get_history_message(history_chat_record, dialogue_number):

View File

@@ -85,3 +85,21 @@ def to_response(chat_id, chat_record_id, response: BaseMessage, workflow, write_context, post_handler):
     post_handler.handler(chat_id, chat_record_id, answer, workflow)
     return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
                            'content': answer, 'is_end': True})
+
+
+def to_response_simple(chat_id, chat_record_id, response: BaseMessage, workflow,
+                       post_handler: WorkFlowPostHandler):
+    answer = response.content
+    post_handler.handler(chat_id, chat_record_id, answer, workflow)
+    return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+                           'content': answer, 'is_end': True})
+
+
+def to_stream_response_simple(stream_event):
+    r = StreamingHttpResponse(
+        streaming_content=stream_event,
+        content_type='text/event-stream;charset=utf-8',
+        charset='utf-8')
+    r['Cache-Control'] = 'no-cache'
+    return r
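
`to_response_simple` mirrors the existing `to_response` but skips the per-node write-context callback, and `to_stream_response_simple` just wraps an event generator in Django's `StreamingHttpResponse` with an SSE content type. Each event the workflow yields is a `data: <json>` frame terminated by a blank line; a minimal sketch that builds and parses one frame (field names taken from `get_chunk_content` in `workflow_manage.py`, shown below):

```python
import json

# Build one frame the way get_chunk_content does, then parse it back.
frame = 'data: ' + json.dumps(
    {'chat_id': 'c1', 'id': 'r1', 'operate': True, 'content': 'Hel', 'is_end': False},
    ensure_ascii=False) + '\n\n'


def parse_sse_frame(frame: str) -> dict:
    # An SSE data frame is "data: <payload>" followed by a blank line.
    assert frame.startswith('data: ') and frame.endswith('\n\n')
    return json.loads(frame[len('data: '):])


event = parse_sse_frame(frame)
print(event['content'], event['is_end'])  # Hel False
```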

View File

@@ -6,10 +6,11 @@
     @date: 2024/1/9 17:40
     @desc:
 """
+import json
 from functools import reduce
 from typing import List, Dict

-from langchain_core.messages import AIMessageChunk, AIMessage
+from langchain_core.messages import AIMessage
 from langchain_core.prompts import PromptTemplate

 from application.flow import tools
@@ -63,7 +64,6 @@ class Flow:
     def get_search_node(self):
         return [node for node in self.nodes if node.type == 'search-dataset-node']
-
     def is_valid(self):
         """
         Validate workflow data
@@ -140,33 +140,71 @@ class WorkflowManage:
         self.work_flow_post_handler = work_flow_post_handler
         self.current_node = None
         self.current_result = None
+        self.answer = ""

     def run(self):
-        """
-        Run the workflow
-        """
+        if self.params.get('stream'):
+            return self.run_stream()
+        return self.run_block()
+
+    def run_block(self):
         try:
             while self.has_next_node(self.current_result):
                 self.current_node = self.get_next_node()
                 self.node_context.append(self.current_node)
                 self.current_result = self.current_node.run()
-                if self.has_next_node(self.current_result):
-                    self.current_result.write_context(self.current_node, self)
-                else:
-                    r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'],
-                                                        self.current_node, self,
-                                                        self.work_flow_post_handler)
-                    return r
+                result = self.current_result.write_context(self.current_node, self)
+                if result is not None:
+                    list(result)
+                if not self.has_next_node(self.current_result):
+                    return tools.to_response_simple(self.params['chat_id'], self.params['chat_record_id'],
+                                                    AIMessage(self.answer), self,
+                                                    self.work_flow_post_handler)
         except Exception as e:
-            if self.params.get('stream'):
-                return tools.to_stream_response(self.params['chat_id'], self.params['chat_record_id'],
-                                                iter([AIMessageChunk(str(e))]), self,
-                                                self.current_node.get_write_error_context(e),
-                                                self.work_flow_post_handler)
-            else:
-                return tools.to_response(self.params['chat_id'], self.params['chat_record_id'],
-                                         AIMessage(str(e)), self, self.current_node.get_write_error_context(e),
-                                         self.work_flow_post_handler)
+            return tools.to_response(self.params['chat_id'], self.params['chat_record_id'],
+                                     AIMessage(str(e)), self, self.current_node.get_write_error_context(e),
+                                     self.work_flow_post_handler)
+
+    def run_stream(self):
+        return tools.to_stream_response_simple(self.stream_event())
+
+    def stream_event(self):
+        try:
+            while self.has_next_node(self.current_result):
+                self.current_node = self.get_next_node()
+                self.node_context.append(self.current_node)
+                self.current_result = self.current_node.run()
+                result = self.current_result.write_context(self.current_node, self)
+                if result is not None:
+                    for r in result:
+                        if self.is_result():
+                            yield self.get_chunk_content(r)
+                if not self.has_next_node(self.current_result):
+                    yield self.get_chunk_content('', True)
+                    break
+            self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
+                                                self.answer,
+                                                self)
+        except Exception as e:
+            self.current_node.get_write_error_context(e)
+            self.answer += str(e)
+            self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
+                                                self.answer,
+                                                self)
+            yield self.get_chunk_content(str(e), True)
+
+    def is_result(self):
+        """
+        Determine whether the current node should return its output to the user
+        @return:
+        """
+        return self.current_node.node_params.get('is_result', not self.has_next_node(
+            self.current_result)) if self.current_node.node_params is not None else False
+
+    def get_chunk_content(self, chunk, is_end=False):
+        return 'data: ' + json.dumps(
+            {'chat_id': self.params['chat_id'], 'id': self.params['chat_record_id'], 'operate': True,
+             'content': chunk, 'is_end': is_end}, ensure_ascii=False) + "\n\n"

     def has_next_node(self, node_result: NodeResult | None):
         """

View File

@@ -170,3 +170,13 @@ export const nodeDict: any = {
 export function isWorkFlow(type: string | undefined) {
   return type === 'WORK_FLOW'
 }
+
+export function isLastNode(nodeModel: any) {
+  const incoming = nodeModel.graphModel.getNodeIncomingNode(nodeModel.id)
+  const outcomming = nodeModel.graphModel.getNodeOutgoingNode(nodeModel.id)
+  if (incoming.length > 0 && outcomming.length === 0) {
+    return true
+  } else {
+    return false
+  }
+}

View File

@@ -132,6 +132,23 @@
           class="w-full"
         />
       </el-form-item>
+      <el-form-item label="返回内容" @click.prevent>
+        <template #label>
+          <div class="flex align-center">
+            <div class="mr-4">
+              <span>返回内容<span class="danger">*</span></span>
+            </div>
+            <el-tooltip effect="dark" placement="right" popper-class="max-w-200">
+              <template #content>
+                关闭后该节点的内容则不输出给用户。
+                如果你想让用户看到该节点的输出内容，请打开开关。
+              </template>
+              <AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
+            </el-tooltip>
+          </div>
+        </template>
+        <el-switch size="small" v-model="chat_data.is_result" />
+      </el-form-item>
     </el-form>
   </el-card>
@@ -156,6 +173,7 @@
 import applicationApi from '@/api/application'
 import useStore from '@/stores'
 import { relatedObject } from '@/utils/utils'
 import type { Provider } from '@/api/type/model'
+import { isLastNode } from '@/workflow/common/data'

 const { model } = useStore()
 const isKeyDown = ref(false)
@@ -180,7 +198,8 @@ const form = {
   model_id: '',
   system: '',
   prompt: defaultPrompt,
-  dialogue_number: 1
+  dialogue_number: 1,
+  is_result: false
 }

 const chat_data = computed({
@@ -240,6 +259,12 @@ const openCreateModel = (provider?: Provider) => {
 onMounted(() => {
   getProvider()
   getModel()
+  if (typeof props.nodeModel.properties.node_data?.is_result === 'undefined') {
+    if (isLastNode(props.nodeModel)) {
+      set(props.nodeModel.properties.node_data, 'is_result', true)
+    }
+  }
   set(props.nodeModel, 'validate', validate)
 })
 </script>

View File

@@ -133,6 +133,23 @@
           class="w-full"
         />
       </el-form-item>
+      <el-form-item label="返回内容" @click.prevent>
+        <template #label>
+          <div class="flex align-center">
+            <div class="mr-4">
+              <span>返回内容<span class="danger">*</span></span>
+            </div>
+            <el-tooltip effect="dark" placement="right" popper-class="max-w-200">
+              <template #content>
+                关闭后该节点的内容则不输出给用户。
+                如果你想让用户看到该节点的输出内容，请打开开关。
+              </template>
+              <AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
+            </el-tooltip>
+          </div>
+        </template>
+        <el-switch size="small" v-model="form_data.is_result" />
+      </el-form-item>
     </el-form>
   </el-card>
   <!-- Add template dialog -->
@@ -156,6 +173,8 @@
 import applicationApi from '@/api/application'
 import useStore from '@/stores'
 import { relatedObject } from '@/utils/utils'
 import type { Provider } from '@/api/type/model'
+import { isLastNode } from '@/workflow/common/data'

 const { model } = useStore()
 const isKeyDown = ref(false)
 const wheel = (e: any) => {
@@ -177,7 +196,8 @@ const form = {
   model_id: '',
   system: '你是一个问题优化大师',
   prompt: defaultPrompt,
-  dialogue_number: 1
+  dialogue_number: 1,
+  is_result: false
 }

 const form_data = computed({
@@ -237,6 +257,11 @@ const openCreateModel = (provider?: Provider) => {
 onMounted(() => {
   getProvider()
   getModel()
+  if (typeof props.nodeModel.properties.node_data?.is_result === 'undefined') {
+    if (isLastNode(props.nodeModel)) {
+      set(props.nodeModel.properties.node_data, 'is_result', true)
+    }
+  }
   set(props.nodeModel, 'validate', validate)
 })
 </script>

View File

@@ -46,6 +46,23 @@
           v-model="form_data.fields"
         />
       </el-form-item>
+      <el-form-item label="返回内容" @click.prevent>
+        <template #label>
+          <div class="flex align-center">
+            <div class="mr-4">
+              <span>返回内容<span class="danger">*</span></span>
+            </div>
+            <el-tooltip effect="dark" placement="right" popper-class="max-w-200">
+              <template #content>
+                关闭后该节点的内容则不输出给用户。
+                如果你想让用户看到该节点的输出内容，请打开开关。
+              </template>
+              <AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
+            </el-tooltip>
+          </div>
+        </template>
+        <el-switch size="small" v-model="form_data.is_result" />
+      </el-form-item>
     </el-form>
   </el-card>
   <!-- Reply content popup -->
@@ -64,12 +81,14 @@ import { set } from 'lodash'
 import NodeContainer from '@/workflow/common/NodeContainer.vue'
 import NodeCascader from '@/workflow/common/NodeCascader.vue'
 import { ref, computed, onMounted } from 'vue'
+import { isLastNode } from '@/workflow/common/data'

 const props = defineProps<{ nodeModel: any }>()

 const form = {
   reply_type: 'content',
   content: '',
-  fields: []
+  fields: [],
+  is_result: false
 }

 const footers: any = [null, '=', 0]
@@ -111,6 +130,12 @@ const validate = () => {
 }

 onMounted(() => {
+  if (typeof props.nodeModel.properties.node_data?.is_result === 'undefined') {
+    if (isLastNode(props.nodeModel)) {
+      set(props.nodeModel.properties.node_data, 'is_result', true)
+    }
+  }
   set(props.nodeModel, 'validate', validate)
 })
 </script>