perf: 应用的AI模型修改为不必填 (#297)
This commit is contained in:
parent
5705f3c4a8
commit
7b5ccd9089
@ -54,7 +54,7 @@ class IChatStep(IBaseChatPipelineStep):
|
|||||||
message_list = serializers.ListField(required=True, child=MessageField(required=True),
|
message_list = serializers.ListField(required=True, child=MessageField(required=True),
|
||||||
error_messages=ErrMessage.list("对话列表"))
|
error_messages=ErrMessage.list("对话列表"))
|
||||||
# 大语言模型
|
# 大语言模型
|
||||||
chat_model = ModelField(error_messages=ErrMessage.list("大语言模型"))
|
chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.list("大语言模型"))
|
||||||
# 段落列表
|
# 段落列表
|
||||||
paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表"))
|
paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表"))
|
||||||
# 对话id
|
# 对话id
|
||||||
|
|||||||
@ -126,6 +126,26 @@ class BaseChatStep(IChatStep):
|
|||||||
result.append({'role': 'ai', 'content': answer_text})
|
result.append({'role': 'ai', 'content': answer_text})
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_stream_result(message_list: List[BaseMessage],
|
||||||
|
chat_model: BaseChatModel = None,
|
||||||
|
paragraph_list=None,
|
||||||
|
no_references_setting=None):
|
||||||
|
if paragraph_list is None:
|
||||||
|
paragraph_list = []
|
||||||
|
directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
|
||||||
|
for paragraph in paragraph_list if
|
||||||
|
paragraph.hit_handling_method == 'directly_return']
|
||||||
|
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
|
||||||
|
return iter(directly_return_chunk_list), False
|
||||||
|
elif no_references_setting.get(
|
||||||
|
'status') == 'designated_answer':
|
||||||
|
return iter([AIMessageChunk(content=no_references_setting.get('value'))]), False
|
||||||
|
if chat_model is None:
|
||||||
|
return iter([AIMessageChunk('抱歉,没有在知识库中查询到相关信息。')]), False
|
||||||
|
else:
|
||||||
|
return chat_model.stream(message_list), True
|
||||||
|
|
||||||
def execute_stream(self, message_list: List[BaseMessage],
|
def execute_stream(self, message_list: List[BaseMessage],
|
||||||
chat_id,
|
chat_id,
|
||||||
problem_text,
|
problem_text,
|
||||||
@ -136,29 +156,8 @@ class BaseChatStep(IChatStep):
|
|||||||
padding_problem_text: str = None,
|
padding_problem_text: str = None,
|
||||||
client_id=None, client_type=None,
|
client_id=None, client_type=None,
|
||||||
no_references_setting=None):
|
no_references_setting=None):
|
||||||
is_ai_chat = False
|
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
|
||||||
# 调用模型
|
no_references_setting)
|
||||||
if chat_model is None:
|
|
||||||
chat_result = iter(
|
|
||||||
[AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
|
|
||||||
else:
|
|
||||||
if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
|
|
||||||
'status') == 'designated_answer':
|
|
||||||
chat_result = iter([AIMessageChunk(content=no_references_setting.get('value'))])
|
|
||||||
else:
|
|
||||||
if paragraph_list is not None and len(paragraph_list) > 0:
|
|
||||||
directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
|
|
||||||
for paragraph in paragraph_list if
|
|
||||||
paragraph.hit_handling_method == 'directly_return']
|
|
||||||
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
|
|
||||||
chat_result = iter(directly_return_chunk_list)
|
|
||||||
else:
|
|
||||||
chat_result = chat_model.stream(message_list)
|
|
||||||
is_ai_chat = True
|
|
||||||
else:
|
|
||||||
chat_result = chat_model.stream(message_list)
|
|
||||||
is_ai_chat = True
|
|
||||||
|
|
||||||
chat_record_id = uuid.uuid1()
|
chat_record_id = uuid.uuid1()
|
||||||
r = StreamingHttpResponse(
|
r = StreamingHttpResponse(
|
||||||
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
|
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
|
||||||
@ -169,6 +168,27 @@ class BaseChatStep(IChatStep):
|
|||||||
r['Cache-Control'] = 'no-cache'
|
r['Cache-Control'] = 'no-cache'
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_block_result(message_list: List[BaseMessage],
|
||||||
|
chat_model: BaseChatModel = None,
|
||||||
|
paragraph_list=None,
|
||||||
|
no_references_setting=None):
|
||||||
|
if paragraph_list is None:
|
||||||
|
paragraph_list = []
|
||||||
|
|
||||||
|
directly_return_chunk_list = [AIMessage(content=paragraph.content)
|
||||||
|
for paragraph in paragraph_list if
|
||||||
|
paragraph.hit_handling_method == 'directly_return']
|
||||||
|
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
|
||||||
|
return directly_return_chunk_list[0], False
|
||||||
|
elif no_references_setting.get(
|
||||||
|
'status') == 'designated_answer':
|
||||||
|
return AIMessage(no_references_setting.get('value')), False
|
||||||
|
if chat_model is None:
|
||||||
|
return AIMessage('抱歉,没有在知识库中查询到相关信息。'), False
|
||||||
|
else:
|
||||||
|
return chat_model.invoke(message_list), True
|
||||||
|
|
||||||
def execute_block(self, message_list: List[BaseMessage],
|
def execute_block(self, message_list: List[BaseMessage],
|
||||||
chat_id,
|
chat_id,
|
||||||
problem_text,
|
problem_text,
|
||||||
@ -178,28 +198,8 @@ class BaseChatStep(IChatStep):
|
|||||||
manage: PipelineManage = None,
|
manage: PipelineManage = None,
|
||||||
padding_problem_text: str = None,
|
padding_problem_text: str = None,
|
||||||
client_id=None, client_type=None, no_references_setting=None):
|
client_id=None, client_type=None, no_references_setting=None):
|
||||||
is_ai_chat = False
|
|
||||||
# 调用模型
|
# 调用模型
|
||||||
if chat_model is None:
|
chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list, no_references_setting)
|
||||||
chat_result = AIMessage(
|
|
||||||
content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list]))
|
|
||||||
else:
|
|
||||||
if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
|
|
||||||
'status') == 'designated_answer':
|
|
||||||
chat_result = AIMessage(content=no_references_setting.get('value'))
|
|
||||||
else:
|
|
||||||
if paragraph_list is not None and len(paragraph_list) > 0:
|
|
||||||
directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
|
|
||||||
for paragraph in paragraph_list if
|
|
||||||
paragraph.hit_handling_method == 'directly_return']
|
|
||||||
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
|
|
||||||
chat_result = iter(directly_return_chunk_list)
|
|
||||||
else:
|
|
||||||
chat_result = chat_model.invoke(message_list)
|
|
||||||
is_ai_chat = True
|
|
||||||
else:
|
|
||||||
chat_result = chat_model.invoke(message_list)
|
|
||||||
is_ai_chat = True
|
|
||||||
chat_record_id = uuid.uuid1()
|
chat_record_id = uuid.uuid1()
|
||||||
if is_ai_chat:
|
if is_ai_chat:
|
||||||
request_token = chat_model.get_num_tokens_from_messages(message_list)
|
request_token = chat_model.get_num_tokens_from_messages(message_list)
|
||||||
|
|||||||
@ -28,7 +28,7 @@ class IResetProblemStep(IBaseChatPipelineStep):
|
|||||||
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
|
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
|
||||||
error_messages=ErrMessage.list("历史对答"))
|
error_messages=ErrMessage.list("历史对答"))
|
||||||
# 大语言模型
|
# 大语言模型
|
||||||
chat_model = ModelField(error_messages=ErrMessage.base("大语言模型"))
|
chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.base("大语言模型"))
|
||||||
|
|
||||||
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
|
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
|
||||||
return self.InstanceSerializer
|
return self.InstanceSerializer
|
||||||
|
|||||||
@ -22,6 +22,8 @@ prompt = (
|
|||||||
class BaseResetProblemStep(IResetProblemStep):
|
class BaseResetProblemStep(IResetProblemStep):
|
||||||
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
|
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
|
||||||
**kwargs) -> str:
|
**kwargs) -> str:
|
||||||
|
if chat_model is None:
|
||||||
|
return problem_text
|
||||||
start_index = len(history_chat_record) - 3
|
start_index = len(history_chat_record) - 3
|
||||||
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
|
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
|
||||||
for index in
|
for index in
|
||||||
|
|||||||
@ -47,7 +47,8 @@ chat_cache = cache.caches['chat_cache']
|
|||||||
|
|
||||||
class ModelDatasetAssociation(serializers.Serializer):
|
class ModelDatasetAssociation(serializers.Serializer):
|
||||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||||
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
|
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||||||
|
error_messages=ErrMessage.char("模型id"))
|
||||||
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True,
|
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True,
|
||||||
error_messages=ErrMessage.uuid(
|
error_messages=ErrMessage.uuid(
|
||||||
"知识库id")),
|
"知识库id")),
|
||||||
@ -57,6 +58,7 @@ class ModelDatasetAssociation(serializers.Serializer):
|
|||||||
super().is_valid(raise_exception=True)
|
super().is_valid(raise_exception=True)
|
||||||
model_id = self.data.get('model_id')
|
model_id = self.data.get('model_id')
|
||||||
user_id = self.data.get('user_id')
|
user_id = self.data.get('user_id')
|
||||||
|
if model_id is not None and len(model_id) > 0:
|
||||||
if not QuerySet(Model).filter(id=model_id).exists():
|
if not QuerySet(Model).filter(id=model_id).exists():
|
||||||
raise AppApiException(500, f'模型不存在【{model_id}】')
|
raise AppApiException(500, f'模型不存在【{model_id}】')
|
||||||
dataset_id_list = list(set(self.data.get('dataset_id_list')))
|
dataset_id_list = list(set(self.data.get('dataset_id_list')))
|
||||||
@ -109,7 +111,8 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||||||
max_length=256, min_length=1,
|
max_length=256, min_length=1,
|
||||||
error_messages=ErrMessage.char("应用描述"))
|
error_messages=ErrMessage.char("应用描述"))
|
||||||
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型"))
|
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||||||
|
error_messages=ErrMessage.char("模型"))
|
||||||
multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话"))
|
multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话"))
|
||||||
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024,
|
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024,
|
||||||
error_messages=ErrMessage.char("开场白"))
|
error_messages=ErrMessage.char("开场白"))
|
||||||
@ -254,7 +257,8 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
error_messages=ErrMessage.char("应用名称"))
|
error_messages=ErrMessage.char("应用名称"))
|
||||||
desc = serializers.CharField(required=False, max_length=256, min_length=1, allow_null=True, allow_blank=True,
|
desc = serializers.CharField(required=False, max_length=256, min_length=1, allow_null=True, allow_blank=True,
|
||||||
error_messages=ErrMessage.char("应用描述"))
|
error_messages=ErrMessage.char("应用描述"))
|
||||||
model_id = serializers.CharField(required=False, error_messages=ErrMessage.char("模型"))
|
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||||
|
error_messages=ErrMessage.char("模型"))
|
||||||
multiple_rounds_dialogue = serializers.BooleanField(required=False,
|
multiple_rounds_dialogue = serializers.BooleanField(required=False,
|
||||||
error_messages=ErrMessage.boolean("多轮会话"))
|
error_messages=ErrMessage.boolean("多轮会话"))
|
||||||
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024,
|
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024,
|
||||||
@ -494,13 +498,14 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
application_id = self.data.get("application_id")
|
application_id = self.data.get("application_id")
|
||||||
|
|
||||||
application = QuerySet(Application).get(id=application_id)
|
application = QuerySet(Application).get(id=application_id)
|
||||||
|
if instance.get('model_id') is None or len(instance.get('model_id')) == 0:
|
||||||
|
application.model_id = None
|
||||||
|
else:
|
||||||
model = QuerySet(Model).filter(
|
model = QuerySet(Model).filter(
|
||||||
id=instance.get('model_id') if 'model_id' in instance else application.model_id,
|
id=instance.get('model_id'),
|
||||||
user_id=application.user_id).first()
|
user_id=application.user_id).first()
|
||||||
if model is None:
|
if model is None:
|
||||||
raise AppApiException(500, "模型不存在")
|
raise AppApiException(500, "模型不存在")
|
||||||
|
|
||||||
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
|
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
|
||||||
'dataset_setting', 'model_setting', 'problem_optimization',
|
'dataset_setting', 'model_setting', 'problem_optimization',
|
||||||
'api_key_is_active', 'icon']
|
'api_key_is_active', 'icon']
|
||||||
|
|||||||
@ -167,9 +167,11 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||||||
chat_cache.set(chat_id,
|
chat_cache.set(chat_id,
|
||||||
chat_info, timeout=60 * 30)
|
chat_info, timeout=60 * 30)
|
||||||
model = chat_info.application.model
|
model = chat_info.application.model
|
||||||
|
if model is None:
|
||||||
|
return chat_info
|
||||||
model = QuerySet(Model).filter(id=model.id).first()
|
model = QuerySet(Model).filter(id=model.id).first()
|
||||||
if model is None:
|
if model is None:
|
||||||
raise AppApiException(500, "模型不存在")
|
return chat_info
|
||||||
if model.status == Status.ERROR:
|
if model.status == Status.ERROR:
|
||||||
raise AppApiException(500, "当前模型不可用")
|
raise AppApiException(500, "当前模型不可用")
|
||||||
if model.status == Status.DOWNLOAD:
|
if model.status == Status.DOWNLOAD:
|
||||||
|
|||||||
@ -213,7 +213,8 @@ class ChatSerializers(serializers.Serializer):
|
|||||||
|
|
||||||
id = serializers.UUIDField(required=False, allow_null=True,
|
id = serializers.UUIDField(required=False, allow_null=True,
|
||||||
error_messages=ErrMessage.uuid("应用id"))
|
error_messages=ErrMessage.uuid("应用id"))
|
||||||
model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
|
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||||||
|
error_messages=ErrMessage.uuid("模型id"))
|
||||||
|
|
||||||
multiple_rounds_dialogue = serializers.BooleanField(required=True,
|
multiple_rounds_dialogue = serializers.BooleanField(required=True,
|
||||||
error_messages=ErrMessage.boolean("多轮会话"))
|
error_messages=ErrMessage.boolean("多轮会话"))
|
||||||
@ -246,14 +247,17 @@ class ChatSerializers(serializers.Serializer):
|
|||||||
def open(self):
|
def open(self):
|
||||||
user_id = self.is_valid(raise_exception=True)
|
user_id = self.is_valid(raise_exception=True)
|
||||||
chat_id = str(uuid.uuid1())
|
chat_id = str(uuid.uuid1())
|
||||||
|
model_id = self.data.get('model_id')
|
||||||
|
if model_id is not None and len(model_id) > 0:
|
||||||
model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
|
model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
|
||||||
if model is None:
|
|
||||||
raise AppApiException(500, "模型不存在")
|
|
||||||
dataset_id_list = self.data.get('dataset_id_list')
|
|
||||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||||
json.loads(
|
json.loads(
|
||||||
decrypt(model.credential)),
|
decrypt(model.credential)),
|
||||||
streaming=True)
|
streaming=True)
|
||||||
|
else:
|
||||||
|
model = None
|
||||||
|
chat_model = None
|
||||||
|
dataset_id_list = self.data.get('dataset_id_list')
|
||||||
application = Application(id=None, dialogue_number=3, model=model,
|
application = Application(id=None, dialogue_number=3, model=model,
|
||||||
dataset_setting=self.data.get('dataset_setting'),
|
dataset_setting=self.data.get('dataset_setting'),
|
||||||
model_setting=self.data.get('model_setting'),
|
model_setting=self.data.get('model_setting'),
|
||||||
|
|||||||
@ -224,7 +224,7 @@ const chartOpenId = ref('')
|
|||||||
const chatList = ref<any[]>([])
|
const chatList = ref<any[]>([])
|
||||||
|
|
||||||
const isDisabledChart = computed(
|
const isDisabledChart = computed(
|
||||||
() => !(inputValue.value.trim() && (props.appId || (props.data?.name && props.data?.model_id)))
|
() => !(inputValue.value.trim() && (props.appId || props.data?.name))
|
||||||
)
|
)
|
||||||
const isMdArray = (val: string) => val.match(/^-\s.*/m)
|
const isMdArray = (val: string) => val.match(/^-\s.*/m)
|
||||||
const prologueList = computed(() => {
|
const prologueList = computed(() => {
|
||||||
@ -509,9 +509,7 @@ function regenerationChart(item: chatType) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function getSourceDetail(row: any) {
|
function getSourceDetail(row: any) {
|
||||||
logApi
|
logApi.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading).then((res) => {
|
||||||
.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading)
|
|
||||||
.then((res) => {
|
|
||||||
const exclude_keys = ['answer_text', 'id']
|
const exclude_keys = ['answer_text', 'id']
|
||||||
Object.keys(res.data).forEach((key) => {
|
Object.keys(res.data).forEach((key) => {
|
||||||
if (!exclude_keys.includes(key)) {
|
if (!exclude_keys.includes(key)) {
|
||||||
|
|||||||
@ -48,7 +48,7 @@
|
|||||||
<el-form-item label="AI 模型" prop="model_id">
|
<el-form-item label="AI 模型" prop="model_id">
|
||||||
<template #label>
|
<template #label>
|
||||||
<div class="flex-between">
|
<div class="flex-between">
|
||||||
<span>AI 模型 <span class="danger">*</span></span>
|
<span>AI 模型 </span>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
<el-select
|
<el-select
|
||||||
@ -56,6 +56,7 @@
|
|||||||
placeholder="请选择 AI 模型"
|
placeholder="请选择 AI 模型"
|
||||||
class="w-full"
|
class="w-full"
|
||||||
popper-class="select-model"
|
popper-class="select-model"
|
||||||
|
:clearable="true"
|
||||||
>
|
>
|
||||||
<el-option-group
|
<el-option-group
|
||||||
v-for="(value, label) in modelOptions"
|
v-for="(value, label) in modelOptions"
|
||||||
@ -338,7 +339,7 @@ const rules = reactive<FormRules<ApplicationFormType>>({
|
|||||||
name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }],
|
name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }],
|
||||||
model_id: [
|
model_id: [
|
||||||
{
|
{
|
||||||
required: true,
|
required: false,
|
||||||
message: '请选择模型',
|
message: '请选择模型',
|
||||||
trigger: 'change'
|
trigger: 'change'
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user