feat: 支持工作流ai对话节点添加节点上下文 (#1791)

This commit is contained in:
shaohuzhang1 2024-12-09 11:17:58 +08:00 committed by GitHub
parent 5c64d630a0
commit f65546a619
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 49 additions and 7 deletions

View File

@ -26,6 +26,8 @@ class ChatNodeSerializer(serializers.Serializer):
model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer("模型参数相关设置")) model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer("模型参数相关设置"))
dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("上下文类型"))
class IChatNode(INode): class IChatNode(INode):
type = 'ai-chat-node' type = 'ai-chat-node'
@ -39,5 +41,6 @@ class IChatNode(INode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
            chat_record_id,
            model_params_setting=None,
            dialogue_type=None,
            **kwargs) -> NodeResult:
    """Abstract entry point for the AI-chat node; implemented by BaseChatNode.

    dialogue_type selects the history scope ('NODE' or 'WORKFLOW'); a None value
    is treated as 'WORKFLOW' by the concrete implementation.
    """
    pass

View File

@ -12,7 +12,7 @@ from typing import List, Dict
from django.db.models import QuerySet from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage, AIMessage
from application.flow.i_step_node import NodeResult, INode from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
@ -72,6 +72,22 @@ def get_default_model_params_setting(model_id):
return model_params_setting return model_params_setting
def get_node_message(chat_record, runtime_node_id):
    """Return the node-scoped history for one chat record.

    Looks up the record's stored details for `runtime_node_id`; when present,
    yields the node's question/answer pair as [HumanMessage, AIMessage],
    otherwise an empty list.
    """
    details = chat_record.get_node_details_runtime_node_id(runtime_node_id)
    if details is None:
        return []
    question = details.get('question')
    answer = details.get('answer')
    return [HumanMessage(question), AIMessage(answer)]
def get_workflow_message(chat_record):
    """Return the workflow-level history for one chat record: its human
    question followed by its AI answer."""
    human = chat_record.get_human_message()
    ai = chat_record.get_ai_message()
    return [human, ai]
def get_message(chat_record, dialogue_type, runtime_node_id):
    """Dispatch on the context scope: node-local messages for 'NODE',
    full workflow messages for anything else."""
    if dialogue_type == 'NODE':
        return get_node_message(chat_record, runtime_node_id)
    return get_workflow_message(chat_record)
class BaseChatNode(IChatNode): class BaseChatNode(IChatNode):
def save_context(self, details, workflow_manage): def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer') self.context['answer'] = details.get('answer')
@ -80,12 +96,17 @@ class BaseChatNode(IChatNode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting=None, model_params_setting=None,
dialogue_type=None,
**kwargs) -> NodeResult: **kwargs) -> NodeResult:
if dialogue_type is None:
dialogue_type = 'WORKFLOW'
if model_params_setting is None: if model_params_setting is None:
model_params_setting = get_default_model_params_setting(model_id) model_params_setting = get_default_model_params_setting(model_id)
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting) **model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number) history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type,
self.runtime_node_id)
self.context['history_message'] = history_message self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt) question = self.generate_prompt_question(prompt)
self.context['question'] = question.content self.context['question'] = question.content
@ -103,10 +124,10 @@ class BaseChatNode(IChatNode):
_write_context=write_context) _write_context=write_context)
@staticmethod
def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id):
    """Collect the last `dialogue_number` chat records as one flat message list.

    Each record contributes either its node-scoped messages (dialogue_type ==
    'NODE') or its workflow-level question/answer, via get_message.
    """
    start_index = max(len(history_chat_record) - dialogue_number, 0)
    # Flatten with a comprehension; the previous reduce(lambda x, y: [*x, *y], ...)
    # rebuilt the accumulator list on every step, which is accidentally O(n^2).
    return [message
            for index in range(start_index, len(history_chat_record))
            for message in get_message(history_chat_record[index], dialogue_type, runtime_node_id)]

View File

@ -167,5 +167,8 @@ class ChatRecord(AppModelMixin):
def get_ai_message(self): def get_ai_message(self):
return AIMessage(content=self.answer_text) return AIMessage(content=self.answer_text)
def get_node_details_runtime_node_id(self, runtime_node_id):
    """Fetch this record's stored execution details for the given runtime
    node id; None when that node left no details."""
    return self.details.get(runtime_node_id)
class Meta: class Meta:
db_table = "application_chat_record" db_table = "application_chat_record"

View File

@ -93,9 +93,8 @@
v-if="showAnchor" v-if="showAnchor"
@mousemove.stop @mousemove.stop
@mousedown.stop @mousedown.stop
@keydown.stop
@click.stop @click.stop
@wheel.stop @wheel="handleWheel"
:show="showAnchor" :show="showAnchor"
:id="id" :id="id"
style="left: 100%; top: 50%; transform: translate(0, -50%)" style="left: 100%; top: 50%; transform: translate(0, -50%)"
@ -142,6 +141,12 @@ const showNode = computed({
return true return true
} }
}) })
// Keep wheel scrolling local to the node unless a zoom chord (Ctrl/⌘ + wheel)
// is pressed, in which case the event is allowed to bubble to the canvas.
const handleWheel = (event: any) => {
  const zoomChord = event.ctrlKey || event.metaKey
  if (zoomChord) {
    return
  }
  event.stopPropagation()
}
const node_status = computed(() => { const node_status = computed(() => {
if (props.nodeModel.properties.status) { if (props.nodeModel.properties.status) {
return props.nodeModel.properties.status return props.nodeModel.properties.status

View File

@ -148,6 +148,15 @@
/> />
</el-form-item> </el-form-item>
<el-form-item label="历史聊天记录"> <el-form-item label="历史聊天记录">
<template #label>
<div class="flex-between">
<div>历史聊天记录</div>
<el-select v-model="chat_data.dialogue_type" type="small" style="width: 100px">
<el-option label="节点" value="NODE" />
<el-option label="工作流" value="WORKFLOW" />
</el-select>
</div>
</template>
<el-input-number <el-input-number
v-model="chat_data.dialogue_number" v-model="chat_data.dialogue_number"
:min="0" :min="0"
@ -246,7 +255,8 @@ const form = {
dialogue_number: 1, dialogue_number: 1,
is_result: false, is_result: false,
temperature: null, temperature: null,
max_tokens: null max_tokens: null,
dialogue_type: 'WORKFLOW'
} }
const chat_data = computed({ const chat_data = computed({