fix: 修复工作流编排历史数据模型参数必填

This commit is contained in:
shaohuzhang1 2024-08-29 14:06:36 +08:00 committed by shaohuzhang1
parent c2622e4a5d
commit b627b6638c
4 changed files with 28 additions and 3 deletions

View File

@ -24,7 +24,7 @@ class ChatNodeSerializer(serializers.Serializer):
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.integer("模型参数相关设置")) model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer("模型参数相关设置"))
class IChatNode(INode): class IChatNode(INode):

View File

@ -10,11 +10,14 @@ import time
from functools import reduce from functools import reduce
from typing import List, Dict from typing import List, Dict
from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
from application.flow.i_step_node import NodeResult, INode from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from setting.models import Model
from setting.models_provider import get_model_credential
from setting.models_provider.tools import get_model_instance_by_model_user_id from setting.models_provider.tools import get_model_instance_by_model_user_id
@ -61,11 +64,20 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
_write_context(node_variable, workflow_variable, node, workflow, answer) _write_context(node_variable, workflow_variable, node, workflow, answer)
def get_default_model_params_setting(model_id):
model = QuerySet(Model).filter(id=model_id).first()
credential = get_model_credential(model.provider, model.model_type, model.model_name)
model_params_setting = credential.get_model_params_setting_form(
model.model_name).get_default_form_data()
return model_params_setting
class BaseChatNode(IChatNode): class BaseChatNode(IChatNode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting, model_params_setting,
**kwargs) -> NodeResult: **kwargs) -> NodeResult:
if model_params_setting is None:
model_params_setting = get_default_model_params_setting(model_id)
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting) **model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number) history_message = self.get_history_message(history_chat_record, dialogue_number)

View File

@ -23,7 +23,7 @@ class QuestionNodeSerializer(serializers.Serializer):
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.integer("模型参数相关设置")) model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer("模型参数相关设置"))
class IQuestionNode(INode): class IQuestionNode(INode):

View File

@ -10,11 +10,14 @@ import time
from functools import reduce from functools import reduce
from typing import List, Dict from typing import List, Dict
from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
from application.flow.i_step_node import NodeResult, INode from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.question_node.i_question_node import IQuestionNode from application.flow.step_node.question_node.i_question_node import IQuestionNode
from setting.models import Model
from setting.models_provider import get_model_credential
from setting.models_provider.tools import get_model_instance_by_model_user_id from setting.models_provider.tools import get_model_instance_by_model_user_id
@ -61,10 +64,20 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
_write_context(node_variable, workflow_variable, node, workflow, answer) _write_context(node_variable, workflow_variable, node, workflow, answer)
def get_default_model_params_setting(model_id):
model = QuerySet(Model).filter(id=model_id).first()
credential = get_model_credential(model.provider, model.model_type, model.model_name)
model_params_setting = credential.get_model_params_setting_form(
model.model_name).get_default_form_data()
return model_params_setting
class BaseQuestionNode(IQuestionNode): class BaseQuestionNode(IQuestionNode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting, model_params_setting,
**kwargs) -> NodeResult: **kwargs) -> NodeResult:
if model_params_setting is None:
model_params_setting = get_default_model_params_setting(model_id)
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting) **model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number) history_message = self.get_history_message(history_chat_record, dialogue_number)