perf: 应用的AI模型修改为不必填 (#297)

This commit is contained in:
shaohuzhang1 2024-04-28 17:09:12 +08:00 committed by GitHub
parent 5705f3c4a8
commit 7b5ccd9089
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 91 additions and 79 deletions

View File

@ -54,7 +54,7 @@ class IChatStep(IBaseChatPipelineStep):
message_list = serializers.ListField(required=True, child=MessageField(required=True), message_list = serializers.ListField(required=True, child=MessageField(required=True),
error_messages=ErrMessage.list("对话列表")) error_messages=ErrMessage.list("对话列表"))
# 大语言模型 # 大语言模型
chat_model = ModelField(error_messages=ErrMessage.list("大语言模型")) chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.list("大语言模型"))
# 段落列表 # 段落列表
paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表")) paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表"))
# 对话id # 对话id

View File

@ -126,6 +126,26 @@ class BaseChatStep(IChatStep):
result.append({'role': 'ai', 'content': answer_text}) result.append({'role': 'ai', 'content': answer_text})
return result return result
@staticmethod
def get_stream_result(message_list: List[BaseMessage],
chat_model: BaseChatModel = None,
paragraph_list=None,
no_references_setting=None):
if paragraph_list is None:
paragraph_list = []
directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
for paragraph in paragraph_list if
paragraph.hit_handling_method == 'directly_return']
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
return iter(directly_return_chunk_list), False
elif no_references_setting.get(
'status') == 'designated_answer':
return iter([AIMessageChunk(content=no_references_setting.get('value'))]), False
if chat_model is None:
return iter([AIMessageChunk('抱歉,没有在知识库中查询到相关信息。')]), False
else:
return chat_model.stream(message_list), True
def execute_stream(self, message_list: List[BaseMessage], def execute_stream(self, message_list: List[BaseMessage],
chat_id, chat_id,
problem_text, problem_text,
@ -136,29 +156,8 @@ class BaseChatStep(IChatStep):
padding_problem_text: str = None, padding_problem_text: str = None,
client_id=None, client_type=None, client_id=None, client_type=None,
no_references_setting=None): no_references_setting=None):
is_ai_chat = False chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
# 调用模型 no_references_setting)
if chat_model is None:
chat_result = iter(
[AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
else:
if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
'status') == 'designated_answer':
chat_result = iter([AIMessageChunk(content=no_references_setting.get('value'))])
else:
if paragraph_list is not None and len(paragraph_list) > 0:
directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
for paragraph in paragraph_list if
paragraph.hit_handling_method == 'directly_return']
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
chat_result = iter(directly_return_chunk_list)
else:
chat_result = chat_model.stream(message_list)
is_ai_chat = True
else:
chat_result = chat_model.stream(message_list)
is_ai_chat = True
chat_record_id = uuid.uuid1() chat_record_id = uuid.uuid1()
r = StreamingHttpResponse( r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list, streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
@ -169,6 +168,27 @@ class BaseChatStep(IChatStep):
r['Cache-Control'] = 'no-cache' r['Cache-Control'] = 'no-cache'
return r return r
@staticmethod
def get_block_result(message_list: List[BaseMessage],
chat_model: BaseChatModel = None,
paragraph_list=None,
no_references_setting=None):
if paragraph_list is None:
paragraph_list = []
directly_return_chunk_list = [AIMessage(content=paragraph.content)
for paragraph in paragraph_list if
paragraph.hit_handling_method == 'directly_return']
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
return directly_return_chunk_list[0], False
elif no_references_setting.get(
'status') == 'designated_answer':
return AIMessage(no_references_setting.get('value')), False
if chat_model is None:
return AIMessage('抱歉,没有在知识库中查询到相关信息。'), False
else:
return chat_model.invoke(message_list), True
def execute_block(self, message_list: List[BaseMessage], def execute_block(self, message_list: List[BaseMessage],
chat_id, chat_id,
problem_text, problem_text,
@ -178,28 +198,8 @@ class BaseChatStep(IChatStep):
manage: PipelineManage = None, manage: PipelineManage = None,
padding_problem_text: str = None, padding_problem_text: str = None,
client_id=None, client_type=None, no_references_setting=None): client_id=None, client_type=None, no_references_setting=None):
is_ai_chat = False
# 调用模型 # 调用模型
if chat_model is None: chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list, no_references_setting)
chat_result = AIMessage(
content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list]))
else:
if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
'status') == 'designated_answer':
chat_result = AIMessage(content=no_references_setting.get('value'))
else:
if paragraph_list is not None and len(paragraph_list) > 0:
directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
for paragraph in paragraph_list if
paragraph.hit_handling_method == 'directly_return']
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
chat_result = iter(directly_return_chunk_list)
else:
chat_result = chat_model.invoke(message_list)
is_ai_chat = True
else:
chat_result = chat_model.invoke(message_list)
is_ai_chat = True
chat_record_id = uuid.uuid1() chat_record_id = uuid.uuid1()
if is_ai_chat: if is_ai_chat:
request_token = chat_model.get_num_tokens_from_messages(message_list) request_token = chat_model.get_num_tokens_from_messages(message_list)

View File

@ -28,7 +28,7 @@ class IResetProblemStep(IBaseChatPipelineStep):
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True), history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
error_messages=ErrMessage.list("历史对答")) error_messages=ErrMessage.list("历史对答"))
# 大语言模型 # 大语言模型
chat_model = ModelField(error_messages=ErrMessage.base("大语言模型")) chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.base("大语言模型"))
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]: def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer return self.InstanceSerializer

View File

@ -22,6 +22,8 @@ prompt = (
class BaseResetProblemStep(IResetProblemStep): class BaseResetProblemStep(IResetProblemStep):
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None, def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
**kwargs) -> str: **kwargs) -> str:
if chat_model is None:
return problem_text
start_index = len(history_chat_record) - 3 start_index = len(history_chat_record) - 3
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in for index in

View File

@ -47,7 +47,8 @@ chat_cache = cache.caches['chat_cache']
class ModelDatasetAssociation(serializers.Serializer): class ModelDatasetAssociation(serializers.Serializer):
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id")) model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("模型id"))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True, dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True,
error_messages=ErrMessage.uuid( error_messages=ErrMessage.uuid(
"知识库id")), "知识库id")),
@ -57,8 +58,9 @@ class ModelDatasetAssociation(serializers.Serializer):
super().is_valid(raise_exception=True) super().is_valid(raise_exception=True)
model_id = self.data.get('model_id') model_id = self.data.get('model_id')
user_id = self.data.get('user_id') user_id = self.data.get('user_id')
if not QuerySet(Model).filter(id=model_id).exists(): if model_id is not None and len(model_id) > 0:
raise AppApiException(500, f'模型不存在【{model_id}') if not QuerySet(Model).filter(id=model_id).exists():
raise AppApiException(500, f'模型不存在【{model_id}')
dataset_id_list = list(set(self.data.get('dataset_id_list'))) dataset_id_list = list(set(self.data.get('dataset_id_list')))
exist_dataset_id_list = [str(dataset.id) for dataset in exist_dataset_id_list = [str(dataset.id) for dataset in
QuerySet(DataSet).filter(id__in=dataset_id_list, user_id=user_id)] QuerySet(DataSet).filter(id__in=dataset_id_list, user_id=user_id)]
@ -109,7 +111,8 @@ class ApplicationSerializer(serializers.Serializer):
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
max_length=256, min_length=1, max_length=256, min_length=1,
error_messages=ErrMessage.char("应用描述")) error_messages=ErrMessage.char("应用描述"))
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型")) model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("模型"))
multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话")) multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话"))
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024, prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024,
error_messages=ErrMessage.char("开场白")) error_messages=ErrMessage.char("开场白"))
@ -254,7 +257,8 @@ class ApplicationSerializer(serializers.Serializer):
error_messages=ErrMessage.char("应用名称")) error_messages=ErrMessage.char("应用名称"))
desc = serializers.CharField(required=False, max_length=256, min_length=1, allow_null=True, allow_blank=True, desc = serializers.CharField(required=False, max_length=256, min_length=1, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("应用描述")) error_messages=ErrMessage.char("应用描述"))
model_id = serializers.CharField(required=False, error_messages=ErrMessage.char("模型")) model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True,
error_messages=ErrMessage.char("模型"))
multiple_rounds_dialogue = serializers.BooleanField(required=False, multiple_rounds_dialogue = serializers.BooleanField(required=False,
error_messages=ErrMessage.boolean("多轮会话")) error_messages=ErrMessage.boolean("多轮会话"))
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024, prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024,
@ -494,13 +498,14 @@ class ApplicationSerializer(serializers.Serializer):
application_id = self.data.get("application_id") application_id = self.data.get("application_id")
application = QuerySet(Application).get(id=application_id) application = QuerySet(Application).get(id=application_id)
if instance.get('model_id') is None or len(instance.get('model_id')) == 0:
model = QuerySet(Model).filter( application.model_id = None
id=instance.get('model_id') if 'model_id' in instance else application.model_id, else:
user_id=application.user_id).first() model = QuerySet(Model).filter(
if model is None: id=instance.get('model_id'),
raise AppApiException(500, "模型不存在") user_id=application.user_id).first()
if model is None:
raise AppApiException(500, "模型不存在")
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status', update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
'dataset_setting', 'model_setting', 'problem_optimization', 'dataset_setting', 'model_setting', 'problem_optimization',
'api_key_is_active', 'icon'] 'api_key_is_active', 'icon']

View File

@ -167,9 +167,11 @@ class ChatMessageSerializer(serializers.Serializer):
chat_cache.set(chat_id, chat_cache.set(chat_id,
chat_info, timeout=60 * 30) chat_info, timeout=60 * 30)
model = chat_info.application.model model = chat_info.application.model
if model is None:
return chat_info
model = QuerySet(Model).filter(id=model.id).first() model = QuerySet(Model).filter(id=model.id).first()
if model is None: if model is None:
raise AppApiException(500, "模型不存在") return chat_info
if model.status == Status.ERROR: if model.status == Status.ERROR:
raise AppApiException(500, "当前模型不可用") raise AppApiException(500, "当前模型不可用")
if model.status == Status.DOWNLOAD: if model.status == Status.DOWNLOAD:

View File

@ -213,7 +213,8 @@ class ChatSerializers(serializers.Serializer):
id = serializers.UUIDField(required=False, allow_null=True, id = serializers.UUIDField(required=False, allow_null=True,
error_messages=ErrMessage.uuid("应用id")) error_messages=ErrMessage.uuid("应用id"))
model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.uuid("模型id"))
multiple_rounds_dialogue = serializers.BooleanField(required=True, multiple_rounds_dialogue = serializers.BooleanField(required=True,
error_messages=ErrMessage.boolean("多轮会话")) error_messages=ErrMessage.boolean("多轮会话"))
@ -246,14 +247,17 @@ class ChatSerializers(serializers.Serializer):
def open(self): def open(self):
user_id = self.is_valid(raise_exception=True) user_id = self.is_valid(raise_exception=True)
chat_id = str(uuid.uuid1()) chat_id = str(uuid.uuid1())
model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first() model_id = self.data.get('model_id')
if model is None: if model_id is not None and len(model_id) > 0:
raise AppApiException(500, "模型不存在") model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
streaming=True)
else:
model = None
chat_model = None
dataset_id_list = self.data.get('dataset_id_list') dataset_id_list = self.data.get('dataset_id_list')
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
streaming=True)
application = Application(id=None, dialogue_number=3, model=model, application = Application(id=None, dialogue_number=3, model=model,
dataset_setting=self.data.get('dataset_setting'), dataset_setting=self.data.get('dataset_setting'),
model_setting=self.data.get('model_setting'), model_setting=self.data.get('model_setting'),

View File

@ -224,7 +224,7 @@ const chartOpenId = ref('')
const chatList = ref<any[]>([]) const chatList = ref<any[]>([])
const isDisabledChart = computed( const isDisabledChart = computed(
() => !(inputValue.value.trim() && (props.appId || (props.data?.name && props.data?.model_id))) () => !(inputValue.value.trim() && (props.appId || props.data?.name))
) )
const isMdArray = (val: string) => val.match(/^-\s.*/m) const isMdArray = (val: string) => val.match(/^-\s.*/m)
const prologueList = computed(() => { const prologueList = computed(() => {
@ -509,16 +509,14 @@ function regenerationChart(item: chatType) {
} }
function getSourceDetail(row: any) { function getSourceDetail(row: any) {
logApi logApi.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading).then((res) => {
.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading) const exclude_keys = ['answer_text', 'id']
.then((res) => { Object.keys(res.data).forEach((key) => {
const exclude_keys = ['answer_text', 'id'] if (!exclude_keys.includes(key)) {
Object.keys(res.data).forEach((key) => { row[key] = res.data[key]
if (!exclude_keys.includes(key)) { }
row[key] = res.data[key]
}
})
}) })
})
return true return true
} }

View File

@ -48,7 +48,7 @@
<el-form-item label="AI 模型" prop="model_id"> <el-form-item label="AI 模型" prop="model_id">
<template #label> <template #label>
<div class="flex-between"> <div class="flex-between">
<span>AI 模型 <span class="danger">*</span></span> <span>AI 模型 </span>
</div> </div>
</template> </template>
<el-select <el-select
@ -56,6 +56,7 @@
placeholder="请选择 AI 模型" placeholder="请选择 AI 模型"
class="w-full" class="w-full"
popper-class="select-model" popper-class="select-model"
:clearable="true"
> >
<el-option-group <el-option-group
v-for="(value, label) in modelOptions" v-for="(value, label) in modelOptions"
@ -338,7 +339,7 @@ const rules = reactive<FormRules<ApplicationFormType>>({
name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }], name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }],
model_id: [ model_id: [
{ {
required: true, required: false,
message: '请选择模型', message: '请选择模型',
trigger: 'change' trigger: 'change'
} }