feat: 支持节点参数设置直接输出 #846
This commit is contained in:
parent
76c1acbabb
commit
35f0c18dd3
@ -10,6 +10,7 @@ import time
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Type, Dict, List
|
from typing import Type, Dict, List
|
||||||
|
|
||||||
|
from django.core import cache
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
|
||||||
@ -18,7 +19,6 @@ from application.models.api_key_model import ApplicationPublicAccessClient
|
|||||||
from common.constants.authentication_type import AuthenticationType
|
from common.constants.authentication_type import AuthenticationType
|
||||||
from common.field.common import InstanceField
|
from common.field.common import InstanceField
|
||||||
from common.util.field_message import ErrMessage
|
from common.util.field_message import ErrMessage
|
||||||
from django.core import cache
|
|
||||||
|
|
||||||
chat_cache = cache.caches['chat_cache']
|
chat_cache = cache.caches['chat_cache']
|
||||||
|
|
||||||
@ -27,9 +27,13 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
|
|||||||
if step_variable is not None:
|
if step_variable is not None:
|
||||||
for key in step_variable:
|
for key in step_variable:
|
||||||
node.context[key] = step_variable[key]
|
node.context[key] = step_variable[key]
|
||||||
|
if workflow.is_result() and 'answer' in step_variable:
|
||||||
|
yield step_variable['answer']
|
||||||
|
workflow.answer += step_variable['answer']
|
||||||
if global_variable is not None:
|
if global_variable is not None:
|
||||||
for key in global_variable:
|
for key in global_variable:
|
||||||
workflow.context[key] = global_variable[key]
|
workflow.context[key] = global_variable[key]
|
||||||
|
node.context['run_time'] = time.time() - node.context['start_time']
|
||||||
|
|
||||||
|
|
||||||
class WorkFlowPostHandler:
|
class WorkFlowPostHandler:
|
||||||
@ -70,18 +74,14 @@ class WorkFlowPostHandler:
|
|||||||
|
|
||||||
|
|
||||||
class NodeResult:
|
class NodeResult:
|
||||||
def __init__(self, node_variable: Dict, workflow_variable: Dict, _to_response=None, _write_context=write_context):
|
def __init__(self, node_variable: Dict, workflow_variable: Dict,
|
||||||
|
_write_context=write_context):
|
||||||
self._write_context = _write_context
|
self._write_context = _write_context
|
||||||
self.node_variable = node_variable
|
self.node_variable = node_variable
|
||||||
self.workflow_variable = workflow_variable
|
self.workflow_variable = workflow_variable
|
||||||
self._to_response = _to_response
|
|
||||||
|
|
||||||
def write_context(self, node, workflow):
|
def write_context(self, node, workflow):
|
||||||
self._write_context(self.node_variable, self.workflow_variable, node, workflow)
|
return self._write_context(self.node_variable, self.workflow_variable, node, workflow)
|
||||||
|
|
||||||
def to_response(self, chat_id, chat_record_id, node, workflow, post_handler: WorkFlowPostHandler):
|
|
||||||
return self._to_response(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow,
|
|
||||||
post_handler)
|
|
||||||
|
|
||||||
def is_assertion_result(self):
|
def is_assertion_result(self):
|
||||||
return 'branch_id' in self.node_variable
|
return 'branch_id' in self.node_variable
|
||||||
|
|||||||
@ -22,6 +22,8 @@ class ChatNodeSerializer(serializers.Serializer):
|
|||||||
# 多轮对话数量
|
# 多轮对话数量
|
||||||
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
|
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
|
||||||
|
|
||||||
|
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
|
||||||
|
|
||||||
|
|
||||||
class IChatNode(INode):
|
class IChatNode(INode):
|
||||||
type = 'ai-chat-node'
|
type = 'ai-chat-node'
|
||||||
|
|||||||
@ -13,12 +13,25 @@ from typing import List, Dict
|
|||||||
from langchain.schema import HumanMessage, SystemMessage
|
from langchain.schema import HumanMessage, SystemMessage
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
|
||||||
from application.flow import tools
|
|
||||||
from application.flow.i_step_node import NodeResult, INode
|
from application.flow.i_step_node import NodeResult, INode
|
||||||
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
|
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
|
||||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||||
|
|
||||||
|
|
||||||
|
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
|
||||||
|
chat_model = node_variable.get('chat_model')
|
||||||
|
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
||||||
|
answer_tokens = chat_model.get_num_tokens(answer)
|
||||||
|
node.context['message_tokens'] = message_tokens
|
||||||
|
node.context['answer_tokens'] = answer_tokens
|
||||||
|
node.context['answer'] = answer
|
||||||
|
node.context['history_message'] = node_variable['history_message']
|
||||||
|
node.context['question'] = node_variable['question']
|
||||||
|
node.context['run_time'] = time.time() - node.context['start_time']
|
||||||
|
if workflow.is_result():
|
||||||
|
workflow.answer += answer
|
||||||
|
|
||||||
|
|
||||||
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||||
"""
|
"""
|
||||||
写入上下文数据 (流式)
|
写入上下文数据 (流式)
|
||||||
@ -31,15 +44,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
|
|||||||
answer = ''
|
answer = ''
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
answer += chunk.content
|
answer += chunk.content
|
||||||
chat_model = node_variable.get('chat_model')
|
yield answer
|
||||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||||
answer_tokens = chat_model.get_num_tokens(answer)
|
|
||||||
node.context['message_tokens'] = message_tokens
|
|
||||||
node.context['answer_tokens'] = answer_tokens
|
|
||||||
node.context['answer'] = answer
|
|
||||||
node.context['history_message'] = node_variable['history_message']
|
|
||||||
node.context['question'] = node_variable['question']
|
|
||||||
node.context['run_time'] = time.time() - node.context['start_time']
|
|
||||||
|
|
||||||
|
|
||||||
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||||
@ -51,71 +57,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
|
|||||||
@param workflow: 工作流管理器
|
@param workflow: 工作流管理器
|
||||||
"""
|
"""
|
||||||
response = node_variable.get('result')
|
response = node_variable.get('result')
|
||||||
chat_model = node_variable.get('chat_model')
|
|
||||||
answer = response.content
|
answer = response.content
|
||||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||||
answer_tokens = chat_model.get_num_tokens(answer)
|
|
||||||
node.context['message_tokens'] = message_tokens
|
|
||||||
node.context['answer_tokens'] = answer_tokens
|
|
||||||
node.context['answer'] = answer
|
|
||||||
node.context['history_message'] = node_variable['history_message']
|
|
||||||
node.context['question'] = node_variable['question']
|
|
||||||
|
|
||||||
|
|
||||||
def get_to_response_write_context(node_variable: Dict, node: INode):
|
|
||||||
def _write_context(answer, status=200):
|
|
||||||
chat_model = node_variable.get('chat_model')
|
|
||||||
|
|
||||||
if status == 200:
|
|
||||||
answer_tokens = chat_model.get_num_tokens(answer)
|
|
||||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
|
||||||
else:
|
|
||||||
answer_tokens = 0
|
|
||||||
message_tokens = 0
|
|
||||||
node.err_message = answer
|
|
||||||
node.status = status
|
|
||||||
node.context['message_tokens'] = message_tokens
|
|
||||||
node.context['answer_tokens'] = answer_tokens
|
|
||||||
node.context['answer'] = answer
|
|
||||||
node.context['run_time'] = time.time() - node.context['start_time']
|
|
||||||
|
|
||||||
return _write_context
|
|
||||||
|
|
||||||
|
|
||||||
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
|
|
||||||
post_handler):
|
|
||||||
"""
|
|
||||||
将流式数据 转换为 流式响应
|
|
||||||
@param chat_id: 会话id
|
|
||||||
@param chat_record_id: 对话记录id
|
|
||||||
@param node_variable: 节点数据
|
|
||||||
@param workflow_variable: 工作流数据
|
|
||||||
@param node: 节点
|
|
||||||
@param workflow: 工作流管理器
|
|
||||||
@param post_handler: 后置处理器 输出结果后执行
|
|
||||||
@return: 流式响应
|
|
||||||
"""
|
|
||||||
response = node_variable.get('result')
|
|
||||||
_write_context = get_to_response_write_context(node_variable, node)
|
|
||||||
return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
|
|
||||||
|
|
||||||
|
|
||||||
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
|
|
||||||
post_handler):
|
|
||||||
"""
|
|
||||||
将结果转换
|
|
||||||
@param chat_id: 会话id
|
|
||||||
@param chat_record_id: 对话记录id
|
|
||||||
@param node_variable: 节点数据
|
|
||||||
@param workflow_variable: 工作流数据
|
|
||||||
@param node: 节点
|
|
||||||
@param workflow: 工作流管理器
|
|
||||||
@param post_handler: 后置处理器
|
|
||||||
@return: 响应
|
|
||||||
"""
|
|
||||||
response = node_variable.get('result')
|
|
||||||
_write_context = get_to_response_write_context(node_variable, node)
|
|
||||||
return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseChatNode(IChatNode):
|
class BaseChatNode(IChatNode):
|
||||||
@ -132,13 +75,12 @@ class BaseChatNode(IChatNode):
|
|||||||
r = chat_model.stream(message_list)
|
r = chat_model.stream(message_list)
|
||||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||||
'history_message': history_message, 'question': question.content}, {},
|
'history_message': history_message, 'question': question.content}, {},
|
||||||
_write_context=write_context_stream,
|
_write_context=write_context_stream)
|
||||||
_to_response=to_stream_response)
|
|
||||||
else:
|
else:
|
||||||
r = chat_model.invoke(message_list)
|
r = chat_model.invoke(message_list)
|
||||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||||
'history_message': history_message, 'question': question.content}, {},
|
'history_message': history_message, 'question': question.content}, {},
|
||||||
_write_context=write_context, _to_response=to_response)
|
_write_context=write_context)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_history_message(history_chat_record, dialogue_number):
|
def get_history_message(history_chat_record, dialogue_number):
|
||||||
|
|||||||
@ -20,6 +20,7 @@ class ReplyNodeParamsSerializer(serializers.Serializer):
|
|||||||
fields = serializers.ListField(required=False, error_messages=ErrMessage.list("引用字段"))
|
fields = serializers.ListField(required=False, error_messages=ErrMessage.list("引用字段"))
|
||||||
content = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
content = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||||
error_messages=ErrMessage.char("直接回答内容"))
|
error_messages=ErrMessage.char("直接回答内容"))
|
||||||
|
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
|
||||||
|
|
||||||
def is_valid(self, *, raise_exception=False):
|
def is_valid(self, *, raise_exception=False):
|
||||||
super().is_valid(raise_exception=True)
|
super().is_valid(raise_exception=True)
|
||||||
|
|||||||
@ -6,69 +6,19 @@
|
|||||||
@date:2024/6/11 17:25
|
@date:2024/6/11 17:25
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
from typing import List, Dict
|
from typing import List
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
from application.flow.i_step_node import NodeResult
|
||||||
|
|
||||||
from application.flow import tools
|
|
||||||
from application.flow.i_step_node import NodeResult, INode
|
|
||||||
from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode
|
from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode
|
||||||
|
|
||||||
|
|
||||||
def get_to_response_write_context(node_variable: Dict, node: INode):
|
|
||||||
def _write_context(answer, status=200):
|
|
||||||
node.context['answer'] = answer
|
|
||||||
|
|
||||||
return _write_context
|
|
||||||
|
|
||||||
|
|
||||||
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
|
|
||||||
post_handler):
|
|
||||||
"""
|
|
||||||
将流式数据 转换为 流式响应
|
|
||||||
@param chat_id: 会话id
|
|
||||||
@param chat_record_id: 对话记录id
|
|
||||||
@param node_variable: 节点数据
|
|
||||||
@param workflow_variable: 工作流数据
|
|
||||||
@param node: 节点
|
|
||||||
@param workflow: 工作流管理器
|
|
||||||
@param post_handler: 后置处理器 输出结果后执行
|
|
||||||
@return: 流式响应
|
|
||||||
"""
|
|
||||||
response = node_variable.get('result')
|
|
||||||
_write_context = get_to_response_write_context(node_variable, node)
|
|
||||||
return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
|
|
||||||
|
|
||||||
|
|
||||||
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
|
|
||||||
post_handler):
|
|
||||||
"""
|
|
||||||
将结果转换
|
|
||||||
@param chat_id: 会话id
|
|
||||||
@param chat_record_id: 对话记录id
|
|
||||||
@param node_variable: 节点数据
|
|
||||||
@param workflow_variable: 工作流数据
|
|
||||||
@param node: 节点
|
|
||||||
@param workflow: 工作流管理器
|
|
||||||
@param post_handler: 后置处理器
|
|
||||||
@return: 响应
|
|
||||||
"""
|
|
||||||
response = node_variable.get('result')
|
|
||||||
_write_context = get_to_response_write_context(node_variable, node)
|
|
||||||
return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseReplyNode(IReplyNode):
|
class BaseReplyNode(IReplyNode):
|
||||||
def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
|
def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
|
||||||
if reply_type == 'referencing':
|
if reply_type == 'referencing':
|
||||||
result = self.get_reference_content(fields)
|
result = self.get_reference_content(fields)
|
||||||
else:
|
else:
|
||||||
result = self.generate_reply_content(content)
|
result = self.generate_reply_content(content)
|
||||||
if stream:
|
return NodeResult({'answer': result}, {})
|
||||||
return NodeResult({'result': iter([AIMessageChunk(content=result)]), 'answer': result}, {},
|
|
||||||
_to_response=to_stream_response)
|
|
||||||
else:
|
|
||||||
return NodeResult({'result': AIMessage(content=result), 'answer': result}, {}, _to_response=to_response)
|
|
||||||
|
|
||||||
def generate_reply_content(self, prompt):
|
def generate_reply_content(self, prompt):
|
||||||
return self.workflow_manage.generate_prompt(prompt)
|
return self.workflow_manage.generate_prompt(prompt)
|
||||||
|
|||||||
@ -22,6 +22,8 @@ class QuestionNodeSerializer(serializers.Serializer):
|
|||||||
# 多轮对话数量
|
# 多轮对话数量
|
||||||
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
|
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
|
||||||
|
|
||||||
|
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
|
||||||
|
|
||||||
|
|
||||||
class IQuestionNode(INode):
|
class IQuestionNode(INode):
|
||||||
type = 'question-node'
|
type = 'question-node'
|
||||||
|
|||||||
@ -13,12 +13,25 @@ from typing import List, Dict
|
|||||||
from langchain.schema import HumanMessage, SystemMessage
|
from langchain.schema import HumanMessage, SystemMessage
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
|
||||||
from application.flow import tools
|
|
||||||
from application.flow.i_step_node import NodeResult, INode
|
from application.flow.i_step_node import NodeResult, INode
|
||||||
from application.flow.step_node.question_node.i_question_node import IQuestionNode
|
from application.flow.step_node.question_node.i_question_node import IQuestionNode
|
||||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||||
|
|
||||||
|
|
||||||
|
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
|
||||||
|
chat_model = node_variable.get('chat_model')
|
||||||
|
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
||||||
|
answer_tokens = chat_model.get_num_tokens(answer)
|
||||||
|
node.context['message_tokens'] = message_tokens
|
||||||
|
node.context['answer_tokens'] = answer_tokens
|
||||||
|
node.context['answer'] = answer
|
||||||
|
node.context['history_message'] = node_variable['history_message']
|
||||||
|
node.context['question'] = node_variable['question']
|
||||||
|
node.context['run_time'] = time.time() - node.context['start_time']
|
||||||
|
if workflow.is_result():
|
||||||
|
workflow.answer += answer
|
||||||
|
|
||||||
|
|
||||||
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||||
"""
|
"""
|
||||||
写入上下文数据 (流式)
|
写入上下文数据 (流式)
|
||||||
@ -31,15 +44,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
|
|||||||
answer = ''
|
answer = ''
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
answer += chunk.content
|
answer += chunk.content
|
||||||
chat_model = node_variable.get('chat_model')
|
yield answer
|
||||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||||
answer_tokens = chat_model.get_num_tokens(answer)
|
|
||||||
node.context['message_tokens'] = message_tokens
|
|
||||||
node.context['answer_tokens'] = answer_tokens
|
|
||||||
node.context['answer'] = answer
|
|
||||||
node.context['history_message'] = node_variable['history_message']
|
|
||||||
node.context['question'] = node_variable['question']
|
|
||||||
node.context['run_time'] = time.time() - node.context['start_time']
|
|
||||||
|
|
||||||
|
|
||||||
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||||
@ -51,71 +57,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
|
|||||||
@param workflow: 工作流管理器
|
@param workflow: 工作流管理器
|
||||||
"""
|
"""
|
||||||
response = node_variable.get('result')
|
response = node_variable.get('result')
|
||||||
chat_model = node_variable.get('chat_model')
|
|
||||||
answer = response.content
|
answer = response.content
|
||||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||||
answer_tokens = chat_model.get_num_tokens(answer)
|
|
||||||
node.context['message_tokens'] = message_tokens
|
|
||||||
node.context['answer_tokens'] = answer_tokens
|
|
||||||
node.context['answer'] = answer
|
|
||||||
node.context['history_message'] = node_variable['history_message']
|
|
||||||
node.context['question'] = node_variable['question']
|
|
||||||
|
|
||||||
|
|
||||||
def get_to_response_write_context(node_variable: Dict, node: INode):
|
|
||||||
def _write_context(answer, status=200):
|
|
||||||
chat_model = node_variable.get('chat_model')
|
|
||||||
|
|
||||||
if status == 200:
|
|
||||||
answer_tokens = chat_model.get_num_tokens(answer)
|
|
||||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
|
||||||
else:
|
|
||||||
answer_tokens = 0
|
|
||||||
message_tokens = 0
|
|
||||||
node.err_message = answer
|
|
||||||
node.status = status
|
|
||||||
node.context['message_tokens'] = message_tokens
|
|
||||||
node.context['answer_tokens'] = answer_tokens
|
|
||||||
node.context['answer'] = answer
|
|
||||||
node.context['run_time'] = time.time() - node.context['start_time']
|
|
||||||
|
|
||||||
return _write_context
|
|
||||||
|
|
||||||
|
|
||||||
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
|
|
||||||
post_handler):
|
|
||||||
"""
|
|
||||||
将流式数据 转换为 流式响应
|
|
||||||
@param chat_id: 会话id
|
|
||||||
@param chat_record_id: 对话记录id
|
|
||||||
@param node_variable: 节点数据
|
|
||||||
@param workflow_variable: 工作流数据
|
|
||||||
@param node: 节点
|
|
||||||
@param workflow: 工作流管理器
|
|
||||||
@param post_handler: 后置处理器 输出结果后执行
|
|
||||||
@return: 流式响应
|
|
||||||
"""
|
|
||||||
response = node_variable.get('result')
|
|
||||||
_write_context = get_to_response_write_context(node_variable, node)
|
|
||||||
return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
|
|
||||||
|
|
||||||
|
|
||||||
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
|
|
||||||
post_handler):
|
|
||||||
"""
|
|
||||||
将结果转换
|
|
||||||
@param chat_id: 会话id
|
|
||||||
@param chat_record_id: 对话记录id
|
|
||||||
@param node_variable: 节点数据
|
|
||||||
@param workflow_variable: 工作流数据
|
|
||||||
@param node: 节点
|
|
||||||
@param workflow: 工作流管理器
|
|
||||||
@param post_handler: 后置处理器
|
|
||||||
@return: 响应
|
|
||||||
"""
|
|
||||||
response = node_variable.get('result')
|
|
||||||
_write_context = get_to_response_write_context(node_variable, node)
|
|
||||||
return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseQuestionNode(IQuestionNode):
|
class BaseQuestionNode(IQuestionNode):
|
||||||
@ -131,15 +74,13 @@ class BaseQuestionNode(IQuestionNode):
|
|||||||
if stream:
|
if stream:
|
||||||
r = chat_model.stream(message_list)
|
r = chat_model.stream(message_list)
|
||||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||||
'get_to_response_write_context': get_to_response_write_context,
|
|
||||||
'history_message': history_message, 'question': question.content}, {},
|
'history_message': history_message, 'question': question.content}, {},
|
||||||
_write_context=write_context_stream,
|
_write_context=write_context_stream)
|
||||||
_to_response=to_stream_response)
|
|
||||||
else:
|
else:
|
||||||
r = chat_model.invoke(message_list)
|
r = chat_model.invoke(message_list)
|
||||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||||
'history_message': history_message, 'question': question.content}, {},
|
'history_message': history_message, 'question': question.content}, {},
|
||||||
_write_context=write_context, _to_response=to_response)
|
_write_context=write_context)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_history_message(history_chat_record, dialogue_number):
|
def get_history_message(history_chat_record, dialogue_number):
|
||||||
|
|||||||
@ -85,3 +85,21 @@ def to_response(chat_id, chat_record_id, response: BaseMessage, workflow, write_
|
|||||||
post_handler.handler(chat_id, chat_record_id, answer, workflow)
|
post_handler.handler(chat_id, chat_record_id, answer, workflow)
|
||||||
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
||||||
'content': answer, 'is_end': True})
|
'content': answer, 'is_end': True})
|
||||||
|
|
||||||
|
|
||||||
|
def to_response_simple(chat_id, chat_record_id, response: BaseMessage, workflow,
|
||||||
|
post_handler: WorkFlowPostHandler):
|
||||||
|
answer = response.content
|
||||||
|
post_handler.handler(chat_id, chat_record_id, answer, workflow)
|
||||||
|
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
||||||
|
'content': answer, 'is_end': True})
|
||||||
|
|
||||||
|
|
||||||
|
def to_stream_response_simple(stream_event):
|
||||||
|
r = StreamingHttpResponse(
|
||||||
|
streaming_content=stream_event,
|
||||||
|
content_type='text/event-stream;charset=utf-8',
|
||||||
|
charset='utf-8')
|
||||||
|
|
||||||
|
r['Cache-Control'] = 'no-cache'
|
||||||
|
return r
|
||||||
|
|||||||
@ -6,10 +6,11 @@
|
|||||||
@date:2024/1/9 17:40
|
@date:2024/1/9 17:40
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
|
import json
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
|
|
||||||
from langchain_core.messages import AIMessageChunk, AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
|
||||||
from application.flow import tools
|
from application.flow import tools
|
||||||
@ -63,7 +64,6 @@ class Flow:
|
|||||||
def get_search_node(self):
|
def get_search_node(self):
|
||||||
return [node for node in self.nodes if node.type == 'search-dataset-node']
|
return [node for node in self.nodes if node.type == 'search-dataset-node']
|
||||||
|
|
||||||
|
|
||||||
def is_valid(self):
|
def is_valid(self):
|
||||||
"""
|
"""
|
||||||
校验工作流数据
|
校验工作流数据
|
||||||
@ -140,34 +140,72 @@ class WorkflowManage:
|
|||||||
self.work_flow_post_handler = work_flow_post_handler
|
self.work_flow_post_handler = work_flow_post_handler
|
||||||
self.current_node = None
|
self.current_node = None
|
||||||
self.current_result = None
|
self.current_result = None
|
||||||
|
self.answer = ""
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""
|
if self.params.get('stream'):
|
||||||
运行工作流
|
return self.run_stream()
|
||||||
"""
|
return self.run_block()
|
||||||
|
|
||||||
|
def run_block(self):
|
||||||
try:
|
try:
|
||||||
while self.has_next_node(self.current_result):
|
while self.has_next_node(self.current_result):
|
||||||
self.current_node = self.get_next_node()
|
self.current_node = self.get_next_node()
|
||||||
self.node_context.append(self.current_node)
|
self.node_context.append(self.current_node)
|
||||||
self.current_result = self.current_node.run()
|
self.current_result = self.current_node.run()
|
||||||
if self.has_next_node(self.current_result):
|
result = self.current_result.write_context(self.current_node, self)
|
||||||
self.current_result.write_context(self.current_node, self)
|
if result is not None:
|
||||||
else:
|
list(result)
|
||||||
r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'],
|
if not self.has_next_node(self.current_result):
|
||||||
self.current_node, self,
|
return tools.to_response_simple(self.params['chat_id'], self.params['chat_record_id'],
|
||||||
|
AIMessage(self.answer), self,
|
||||||
self.work_flow_post_handler)
|
self.work_flow_post_handler)
|
||||||
return r
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if self.params.get('stream'):
|
|
||||||
return tools.to_stream_response(self.params['chat_id'], self.params['chat_record_id'],
|
|
||||||
iter([AIMessageChunk(str(e))]), self,
|
|
||||||
self.current_node.get_write_error_context(e),
|
|
||||||
self.work_flow_post_handler)
|
|
||||||
else:
|
|
||||||
return tools.to_response(self.params['chat_id'], self.params['chat_record_id'],
|
return tools.to_response(self.params['chat_id'], self.params['chat_record_id'],
|
||||||
AIMessage(str(e)), self, self.current_node.get_write_error_context(e),
|
AIMessage(str(e)), self, self.current_node.get_write_error_context(e),
|
||||||
self.work_flow_post_handler)
|
self.work_flow_post_handler)
|
||||||
|
|
||||||
|
def run_stream(self):
|
||||||
|
return tools.to_stream_response_simple(self.stream_event())
|
||||||
|
|
||||||
|
def stream_event(self):
|
||||||
|
try:
|
||||||
|
while self.has_next_node(self.current_result):
|
||||||
|
self.current_node = self.get_next_node()
|
||||||
|
self.node_context.append(self.current_node)
|
||||||
|
self.current_result = self.current_node.run()
|
||||||
|
result = self.current_result.write_context(self.current_node, self)
|
||||||
|
if result is not None:
|
||||||
|
for r in result:
|
||||||
|
if self.is_result():
|
||||||
|
yield self.get_chunk_content(r)
|
||||||
|
if not self.has_next_node(self.current_result):
|
||||||
|
yield self.get_chunk_content('', True)
|
||||||
|
break
|
||||||
|
self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
|
||||||
|
self.answer,
|
||||||
|
self)
|
||||||
|
except Exception as e:
|
||||||
|
self.current_node.get_write_error_context(e)
|
||||||
|
self.answer += str(e)
|
||||||
|
self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
|
||||||
|
self.answer,
|
||||||
|
self)
|
||||||
|
yield self.get_chunk_content(str(e), True)
|
||||||
|
|
||||||
|
def is_result(self):
|
||||||
|
"""
|
||||||
|
判断是否是返回节点
|
||||||
|
@return:
|
||||||
|
"""
|
||||||
|
return self.current_node.node_params.get('is_result', not self.has_next_node(
|
||||||
|
self.current_result)) if self.current_node.node_params is not None else False
|
||||||
|
|
||||||
|
def get_chunk_content(self, chunk, is_end=False):
|
||||||
|
return 'data: ' + json.dumps(
|
||||||
|
{'chat_id': self.params['chat_id'], 'id': self.params['chat_record_id'], 'operate': True,
|
||||||
|
'content': chunk, 'is_end': is_end}, ensure_ascii=False) + "\n\n"
|
||||||
|
|
||||||
def has_next_node(self, node_result: NodeResult | None):
|
def has_next_node(self, node_result: NodeResult | None):
|
||||||
"""
|
"""
|
||||||
是否有下一个可运行的节点
|
是否有下一个可运行的节点
|
||||||
|
|||||||
@ -170,3 +170,13 @@ export const nodeDict: any = {
|
|||||||
export function isWorkFlow(type: string | undefined) {
|
export function isWorkFlow(type: string | undefined) {
|
||||||
return type === 'WORK_FLOW'
|
return type === 'WORK_FLOW'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function isLastNode(nodeModel: any) {
|
||||||
|
const incoming = nodeModel.graphModel.getNodeIncomingNode(nodeModel.id)
|
||||||
|
const outcomming = nodeModel.graphModel.getNodeOutgoingNode(nodeModel.id)
|
||||||
|
if (incoming.length > 0 && outcomming.length === 0) {
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -132,6 +132,23 @@
|
|||||||
class="w-full"
|
class="w-full"
|
||||||
/>
|
/>
|
||||||
</el-form-item>
|
</el-form-item>
|
||||||
|
<el-form-item label="返回内容" @click.prevent>
|
||||||
|
<template #label>
|
||||||
|
<div class="flex align-center">
|
||||||
|
<div class="mr-4">
|
||||||
|
<span>返回内容<span class="danger">*</span></span>
|
||||||
|
</div>
|
||||||
|
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
|
||||||
|
<template #content>
|
||||||
|
关闭后该节点的内容则不输出给用户。
|
||||||
|
如果你想让用户看到该节点的输出内容,请打开开关。
|
||||||
|
</template>
|
||||||
|
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||||
|
</el-tooltip>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<el-switch size="small" v-model="chat_data.is_result" />
|
||||||
|
</el-form-item>
|
||||||
</el-form>
|
</el-form>
|
||||||
</el-card>
|
</el-card>
|
||||||
|
|
||||||
@ -156,6 +173,7 @@ import applicationApi from '@/api/application'
|
|||||||
import useStore from '@/stores'
|
import useStore from '@/stores'
|
||||||
import { relatedObject } from '@/utils/utils'
|
import { relatedObject } from '@/utils/utils'
|
||||||
import type { Provider } from '@/api/type/model'
|
import type { Provider } from '@/api/type/model'
|
||||||
|
import { isLastNode } from '@/workflow/common/data'
|
||||||
|
|
||||||
const { model } = useStore()
|
const { model } = useStore()
|
||||||
const isKeyDown = ref(false)
|
const isKeyDown = ref(false)
|
||||||
@ -180,7 +198,8 @@ const form = {
|
|||||||
model_id: '',
|
model_id: '',
|
||||||
system: '',
|
system: '',
|
||||||
prompt: defaultPrompt,
|
prompt: defaultPrompt,
|
||||||
dialogue_number: 1
|
dialogue_number: 1,
|
||||||
|
is_result: false
|
||||||
}
|
}
|
||||||
|
|
||||||
const chat_data = computed({
|
const chat_data = computed({
|
||||||
@ -240,6 +259,12 @@ const openCreateModel = (provider?: Provider) => {
|
|||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
getProvider()
|
getProvider()
|
||||||
getModel()
|
getModel()
|
||||||
|
if (typeof props.nodeModel.properties.node_data?.is_result === 'undefined') {
|
||||||
|
if (isLastNode(props.nodeModel)) {
|
||||||
|
set(props.nodeModel.properties.node_data, 'is_result', true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
set(props.nodeModel, 'validate', validate)
|
set(props.nodeModel, 'validate', validate)
|
||||||
})
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@ -133,6 +133,23 @@
|
|||||||
class="w-full"
|
class="w-full"
|
||||||
/>
|
/>
|
||||||
</el-form-item>
|
</el-form-item>
|
||||||
|
<el-form-item label="返回内容" @click.prevent>
|
||||||
|
<template #label>
|
||||||
|
<div class="flex align-center">
|
||||||
|
<div class="mr-4">
|
||||||
|
<span>返回内容<span class="danger">*</span></span>
|
||||||
|
</div>
|
||||||
|
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
|
||||||
|
<template #content>
|
||||||
|
关闭后该节点的内容则不输出给用户。
|
||||||
|
如果你想让用户看到该节点的输出内容,请打开开关。
|
||||||
|
</template>
|
||||||
|
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||||
|
</el-tooltip>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<el-switch size="small" v-model="form_data.is_result" />
|
||||||
|
</el-form-item>
|
||||||
</el-form>
|
</el-form>
|
||||||
</el-card>
|
</el-card>
|
||||||
<!-- 添加模版 -->
|
<!-- 添加模版 -->
|
||||||
@ -156,6 +173,8 @@ import applicationApi from '@/api/application'
|
|||||||
import useStore from '@/stores'
|
import useStore from '@/stores'
|
||||||
import { relatedObject } from '@/utils/utils'
|
import { relatedObject } from '@/utils/utils'
|
||||||
import type { Provider } from '@/api/type/model'
|
import type { Provider } from '@/api/type/model'
|
||||||
|
import { isLastNode } from '@/workflow/common/data'
|
||||||
|
|
||||||
const { model } = useStore()
|
const { model } = useStore()
|
||||||
const isKeyDown = ref(false)
|
const isKeyDown = ref(false)
|
||||||
const wheel = (e: any) => {
|
const wheel = (e: any) => {
|
||||||
@ -177,7 +196,8 @@ const form = {
|
|||||||
model_id: '',
|
model_id: '',
|
||||||
system: '你是一个问题优化大师',
|
system: '你是一个问题优化大师',
|
||||||
prompt: defaultPrompt,
|
prompt: defaultPrompt,
|
||||||
dialogue_number: 1
|
dialogue_number: 1,
|
||||||
|
is_result: false
|
||||||
}
|
}
|
||||||
|
|
||||||
const form_data = computed({
|
const form_data = computed({
|
||||||
@ -237,6 +257,11 @@ const openCreateModel = (provider?: Provider) => {
|
|||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
getProvider()
|
getProvider()
|
||||||
getModel()
|
getModel()
|
||||||
|
if (typeof props.nodeModel.properties.node_data?.is_result === 'undefined') {
|
||||||
|
if (isLastNode(props.nodeModel)) {
|
||||||
|
set(props.nodeModel.properties.node_data, 'is_result', true)
|
||||||
|
}
|
||||||
|
}
|
||||||
set(props.nodeModel, 'validate', validate)
|
set(props.nodeModel, 'validate', validate)
|
||||||
})
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@ -46,6 +46,23 @@
|
|||||||
v-model="form_data.fields"
|
v-model="form_data.fields"
|
||||||
/>
|
/>
|
||||||
</el-form-item>
|
</el-form-item>
|
||||||
|
<el-form-item label="返回内容" @click.prevent>
|
||||||
|
<template #label>
|
||||||
|
<div class="flex align-center">
|
||||||
|
<div class="mr-4">
|
||||||
|
<span>返回内容<span class="danger">*</span></span>
|
||||||
|
</div>
|
||||||
|
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
|
||||||
|
<template #content>
|
||||||
|
关闭后该节点的内容则不输出给用户。
|
||||||
|
如果你想让用户看到该节点的输出内容,请打开开关。
|
||||||
|
</template>
|
||||||
|
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||||
|
</el-tooltip>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<el-switch size="small" v-model="form_data.is_result" />
|
||||||
|
</el-form-item>
|
||||||
</el-form>
|
</el-form>
|
||||||
</el-card>
|
</el-card>
|
||||||
<!-- 回复内容弹出层 -->
|
<!-- 回复内容弹出层 -->
|
||||||
@ -64,12 +81,14 @@ import { set } from 'lodash'
|
|||||||
import NodeContainer from '@/workflow/common/NodeContainer.vue'
|
import NodeContainer from '@/workflow/common/NodeContainer.vue'
|
||||||
import NodeCascader from '@/workflow/common/NodeCascader.vue'
|
import NodeCascader from '@/workflow/common/NodeCascader.vue'
|
||||||
import { ref, computed, onMounted } from 'vue'
|
import { ref, computed, onMounted } from 'vue'
|
||||||
|
import { isLastNode } from '@/workflow/common/data'
|
||||||
|
|
||||||
const props = defineProps<{ nodeModel: any }>()
|
const props = defineProps<{ nodeModel: any }>()
|
||||||
const form = {
|
const form = {
|
||||||
reply_type: 'content',
|
reply_type: 'content',
|
||||||
content: '',
|
content: '',
|
||||||
fields: []
|
fields: [],
|
||||||
|
is_result: false
|
||||||
}
|
}
|
||||||
const footers: any = [null, '=', 0]
|
const footers: any = [null, '=', 0]
|
||||||
|
|
||||||
@ -111,6 +130,12 @@ const validate = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
|
if (typeof props.nodeModel.properties.node_data?.is_result === 'undefined') {
|
||||||
|
if (isLastNode(props.nodeModel)) {
|
||||||
|
set(props.nodeModel.properties.node_data, 'is_result', true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
set(props.nodeModel, 'validate', validate)
|
set(props.nodeModel, 'validate', validate)
|
||||||
})
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user