feat: 增加全文检索和混合检索方式
This commit is contained in:
parent
8fe1a147ff
commit
c89ae29429
@ -6,15 +6,16 @@
|
|||||||
@date:2024/1/9 18:10
|
@date:2024/1/9 18:10
|
||||||
@desc: 检索知识库
|
@desc: 检索知识库
|
||||||
"""
|
"""
|
||||||
|
import re
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import List, Type
|
from typing import List, Type
|
||||||
|
|
||||||
|
from django.core import validators
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
|
||||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
|
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
|
||||||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||||
from common.util.field_message import ErrMessage
|
from common.util.field_message import ErrMessage
|
||||||
from dataset.models import Paragraph
|
|
||||||
|
|
||||||
|
|
||||||
class ISearchDatasetStep(IBaseChatPipelineStep):
|
class ISearchDatasetStep(IBaseChatPipelineStep):
|
||||||
@ -38,6 +39,10 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
|
|||||||
# 相似度 0-1之间
|
# 相似度 0-1之间
|
||||||
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
|
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
|
||||||
error_messages=ErrMessage.float("引用分段数"))
|
error_messages=ErrMessage.float("引用分段数"))
|
||||||
|
# Retrieval mode: must be exactly one of embedding / keywords / blend.
# The alternation is grouped so that ^ and $ apply to every option; without the group only
# "embedding" is start-anchored and only "blend" is end-anchored, so a value such as
# "xkeywordsx" would pass validation.
search_mode = serializers.CharField(required=True, validators=[
    validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
                              message="类型只支持embedding|keywords|blend", code=500)
], error_messages=ErrMessage.char("检索模式"))
|
||||||
|
|
||||||
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
|
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
|
||||||
return self.InstanceSerializer
|
return self.InstanceSerializer
|
||||||
@ -50,6 +55,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||||
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
||||||
|
search_mode: str = None,
|
||||||
**kwargs) -> List[ParagraphPipelineModel]:
|
**kwargs) -> List[ParagraphPipelineModel]:
|
||||||
"""
|
"""
|
||||||
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
|
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
|
||||||
@ -60,6 +66,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
|
|||||||
:param exclude_document_id_list: 需要排除的文档id
|
:param exclude_document_id_list: 需要排除的文档id
|
||||||
:param exclude_paragraph_id_list: 需要排除段落id
|
:param exclude_paragraph_id_list: 需要排除段落id
|
||||||
:param padding_problem_text 补全问题
|
:param padding_problem_text 补全问题
|
||||||
|
:param search_mode 检索模式
|
||||||
:return: 段落列表
|
:return: 段落列表
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from common.config.embedding_config import VectorStore, EmbeddingModel
|
|||||||
from common.db.search import native_search
|
from common.db.search import native_search
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from dataset.models import Paragraph
|
from dataset.models import Paragraph
|
||||||
|
from embedding.models import SearchMode
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
@ -24,13 +25,14 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
|
|||||||
|
|
||||||
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||||
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
||||||
|
search_mode: str = None,
|
||||||
**kwargs) -> List[ParagraphPipelineModel]:
|
**kwargs) -> List[ParagraphPipelineModel]:
|
||||||
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
|
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
|
||||||
embedding_model = EmbeddingModel.get_embedding_model()
|
embedding_model = EmbeddingModel.get_embedding_model()
|
||||||
embedding_value = embedding_model.embed_query(exec_problem_text)
|
embedding_value = embedding_model.embed_query(exec_problem_text)
|
||||||
vector = VectorStore.get_embedding_vector()
|
vector = VectorStore.get_embedding_vector()
|
||||||
embedding_list = vector.query(embedding_value, dataset_id_list, exclude_document_id_list,
|
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
|
||||||
exclude_paragraph_id_list, True, top_n, similarity)
|
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
|
||||||
if embedding_list is None:
|
if embedding_list is None:
|
||||||
return []
|
return []
|
||||||
paragraph_list = self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)
|
paragraph_list = self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from users.models import User
|
|||||||
|
|
||||||
|
|
||||||
def get_dataset_setting_dict():
|
def get_dataset_setting_dict():
|
||||||
return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000}
|
return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000, 'search_mode': 'embedding'}
|
||||||
|
|
||||||
|
|
||||||
def get_model_setting_dict():
|
def get_model_setting_dict():
|
||||||
|
|||||||
@ -8,12 +8,13 @@
|
|||||||
"""
|
"""
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from django.contrib.postgres.fields import ArrayField
|
from django.contrib.postgres.fields import ArrayField
|
||||||
from django.core import cache
|
from django.core import cache, validators
|
||||||
from django.core import signing
|
from django.core import signing
|
||||||
from django.db import transaction, models
|
from django.db import transaction, models
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
@ -32,6 +33,7 @@ from common.util.field_message import ErrMessage
|
|||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from dataset.models import DataSet, Document
|
from dataset.models import DataSet, Document
|
||||||
from dataset.serializers.common_serializers import list_paragraph
|
from dataset.serializers.common_serializers import list_paragraph
|
||||||
|
from embedding.models import SearchMode
|
||||||
from setting.models import AuthOperate
|
from setting.models import AuthOperate
|
||||||
from setting.models.model_management import Model
|
from setting.models.model_management import Model
|
||||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||||
@ -77,6 +79,10 @@ class DatasetSettingSerializer(serializers.Serializer):
|
|||||||
error_messages=ErrMessage.float("相识度"))
|
error_messages=ErrMessage.float("相识度"))
|
||||||
max_paragraph_char_number = serializers.IntegerField(required=True, min_value=500, max_value=10000,
|
max_paragraph_char_number = serializers.IntegerField(required=True, min_value=500, max_value=10000,
|
||||||
error_messages=ErrMessage.integer("最多引用字符数"))
|
error_messages=ErrMessage.integer("最多引用字符数"))
|
||||||
|
# Retrieval mode: must be exactly one of embedding / keywords / blend.
# The alternation is grouped so that ^ and $ apply to every option; without the group only
# "embedding" is start-anchored and only "blend" is end-anchored, so a value such as
# "xkeywordsx" would pass validation.
search_mode = serializers.CharField(required=True, validators=[
    validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
                              message="类型只支持embedding|keywords|blend", code=500)
], error_messages=ErrMessage.char("检索模式"))
|
||||||
|
|
||||||
|
|
||||||
class ModelSettingSerializer(serializers.Serializer):
|
class ModelSettingSerializer(serializers.Serializer):
|
||||||
@ -291,6 +297,10 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
error_messages=ErrMessage.integer("topN"))
|
error_messages=ErrMessage.integer("topN"))
|
||||||
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
|
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
|
||||||
error_messages=ErrMessage.float("相关度"))
|
error_messages=ErrMessage.float("相关度"))
|
||||||
|
# Retrieval mode: must be exactly one of embedding / keywords / blend.
# The alternation is grouped so that ^ and $ apply to every option; without the group only
# "embedding" is start-anchored and only "blend" is end-anchored, so a value such as
# "xkeywordsx" would pass validation.
search_mode = serializers.CharField(required=True, validators=[
    validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
                              message="类型只支持embedding|keywords|blend", code=500)
], error_messages=ErrMessage.char("检索模式"))
|
||||||
|
|
||||||
def is_valid(self, *, raise_exception=False):
|
def is_valid(self, *, raise_exception=False):
|
||||||
super().is_valid(raise_exception=True)
|
super().is_valid(raise_exception=True)
|
||||||
@ -312,6 +322,7 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
|
hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
|
||||||
self.data.get('top_number'),
|
self.data.get('top_number'),
|
||||||
self.data.get('similarity'),
|
self.data.get('similarity'),
|
||||||
|
SearchMode(self.data.get('search_mode')),
|
||||||
EmbeddingModel.get_embedding_model())
|
EmbeddingModel.get_embedding_model())
|
||||||
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
|
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
|
||||||
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
|
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
|
||||||
|
|||||||
@ -77,6 +77,8 @@ class ChatInfo:
|
|||||||
'model_id': self.application.model.id if self.application.model is not None else None,
|
'model_id': self.application.model.id if self.application.model is not None else None,
|
||||||
'problem_optimization': self.application.problem_optimization,
|
'problem_optimization': self.application.problem_optimization,
|
||||||
'stream': True,
|
'stream': True,
|
||||||
|
'search_mode': self.application.dataset_setting.get(
|
||||||
|
'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding'
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -184,9 +186,9 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||||||
pipeline_manage_builder.append_step(BaseResetProblemStep)
|
pipeline_manage_builder.append_step(BaseResetProblemStep)
|
||||||
# 构建流水线管理器
|
# 构建流水线管理器
|
||||||
pipeline_message = (pipeline_manage_builder.append_step(BaseSearchDatasetStep)
|
pipeline_message = (pipeline_manage_builder.append_step(BaseSearchDatasetStep)
|
||||||
.append_step(BaseGenerateHumanMessageStep)
|
.append_step(BaseGenerateHumanMessageStep)
|
||||||
.append_step(BaseChatStep)
|
.append_step(BaseChatStep)
|
||||||
.build())
|
.build())
|
||||||
exclude_paragraph_id_list = []
|
exclude_paragraph_id_list = []
|
||||||
# 相同问题是否需要排除已经查询到的段落
|
# 相同问题是否需要排除已经查询到的段落
|
||||||
if re_chat:
|
if re_chat:
|
||||||
|
|||||||
@ -161,6 +161,8 @@ class ApplicationApi(ApiMixin):
|
|||||||
default=0.6),
|
default=0.6),
|
||||||
'max_paragraph_char_number': openapi.Schema(type=openapi.TYPE_NUMBER, title='最多引用字符数',
|
'max_paragraph_char_number': openapi.Schema(type=openapi.TYPE_NUMBER, title='最多引用字符数',
|
||||||
description="最多引用字符数", default=3000),
|
description="最多引用字符数", default=3000),
|
||||||
|
'search_mode': openapi.Schema(type=openapi.TYPE_STRING, title='检索模式',
|
||||||
|
description="embedding|keywords|blend", default='embedding'),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -343,7 +343,8 @@ class Application(APIView):
|
|||||||
ApplicationSerializer.HitTest(data={'id': application_id, 'user_id': request.user.id,
|
ApplicationSerializer.HitTest(data={'id': application_id, 'user_id': request.user.id,
|
||||||
"query_text": request.query_params.get("query_text"),
|
"query_text": request.query_params.get("query_text"),
|
||||||
"top_number": request.query_params.get("top_number"),
|
"top_number": request.query_params.get("top_number"),
|
||||||
'similarity': request.query_params.get('similarity')}).hit_test(
|
'similarity': request.query_params.get('similarity'),
|
||||||
|
'search_mode': request.query_params.get('search_mode')}).hit_test(
|
||||||
))
|
))
|
||||||
|
|
||||||
class Operate(APIView):
|
class Operate(APIView):
|
||||||
|
|||||||
@ -33,6 +33,13 @@ class CommonApi:
|
|||||||
default=0.6,
|
default=0.6,
|
||||||
required=True,
|
required=True,
|
||||||
description='相关性'),
|
description='相关性'),
|
||||||
|
openapi.Parameter(name='search_mode',
|
||||||
|
in_=openapi.IN_QUERY,
|
||||||
|
type=openapi.TYPE_STRING,
|
||||||
|
default="embedding",
|
||||||
|
required=True,
|
||||||
|
description='检索模式embedding|keywords|blend'
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
107
apps/common/util/ts_vecto_util.py
Normal file
107
apps/common/util/ts_vecto_util.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: ts_vecto_util.py
|
||||||
|
@date:2024/4/16 15:26
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import jieba
|
||||||
|
from jieba import analyse
|
||||||
|
|
||||||
|
from common.util.split_model import group_by
|
||||||
|
|
||||||
|
# Pool of single-character placeholder keys: the characters with code points 38 through 83.
jieba_word_list_cache = list(map(chr, range(38, 84)))

# Register each placeholder with jieba in its wrapped '#x#' form
# (presumably so the tokenizer keeps the wrapped placeholder as one token -- confirm).
for jieba_word in jieba_word_list_cache:
    jieba.add_word(f"#{jieba_word}#")
|
||||||
|
# r"(?i)\b(?:https?|ftp|tcp|file)://[^\s]+\b",
|
||||||
|
# 某些不分词数据
|
||||||
|
# r'"([^"]*)"'
|
||||||
|
# Patterns for text that is collected as whole words instead of being segmented:
# version strings such as "v1.2.3" and e-mail addresses.
# The dots of the version pattern are escaped so they only match a literal '.', and the
# top-level-domain class is [A-Za-z] (the former [A-Z|a-z] also accepted a literal '|').
word_pattern_list = [r"v\d+\.\d+\.\d+",
                     r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}"]

# Tokens that occur as a substring of this string are dropped by to_ts_vector and to_query.
remove_chars = '\n , :\'<>!@#¥%……&*()!@#$%^&*(): ;,/"./-'
|
||||||
|
|
||||||
|
|
||||||
|
def get_word_list(text: str):
    """Collect the substrings of ``text`` matched by the entries of ``word_pattern_list``.

    Matches are gathered pattern by pattern, in the order ``re.findall`` returns them.
    A match that contains ':' is split on ':' and its pieces are collected instead of
    the whole match.

    :param text: text to scan
    :return: list of matched words (duplicates are kept)
    """
    collected = []
    for pattern in word_pattern_list:
        for match in re.findall(pattern, text):
            # findall yields tuples when a pattern has several groups, plain strings otherwise
            pieces = match if isinstance(match, tuple) else (match,)
            for word in pieces:
                # a word must not contain ':', so split on it; without ':' split returns [word]
                collected.extend(word.split(":"))
    return collected
|
||||||
|
|
||||||
|
|
||||||
|
def replace_word(word_dict, text: str):
    """Replace every protected word in ``text`` with its placeholder key.

    An occurrence is replaced only when it is not directly preceded or followed by '#'.

    :param word_dict: mapping of placeholder key -> original word
    :param text: text to rewrite
    :return: the rewritten text
    """
    for key, word in word_dict.items():
        # The word is literal text, so escape it: otherwise '.', '+' and other regex
        # metacharacters in a version string or e-mail address match unrelated text
        # (or fail to match the word itself).
        pattern = '(?<!#)' + re.escape(word) + '(?!#)'
        # A callable replacement inserts the key verbatim, with no backslash processing.
        text = re.sub(pattern, lambda _match, key=key: key, text)
    return text
|
||||||
|
|
||||||
|
|
||||||
|
def get_word_key(text: str, use_word_list):
    """Pick a placeholder key that occurs neither in ``text`` nor among the keys already in use.

    :param text: text the placeholder will be inserted into
    :param use_word_list: keys already handed out, either bare ('&') or wrapped ('#&#')
    :return: an unused bare key; to_word_dict wraps it as '#key#'
    """
    for j_word in jieba_word_list_cache:
        wrapped = '#' + j_word + '#'
        # to_word_dict stores keys in wrapped form, so a bare-only comparison never
        # matched them; check both forms.
        if j_word in text or j_word in use_word_list or wrapped in use_word_list:
            continue
        return j_word
    # every predefined placeholder is taken: fall back to a unique id
    j_word = str(uuid.uuid1())
    # Register the wrapped form, which is the form inserted into the text and the form
    # the predefined placeholders are registered in.
    jieba.add_word('#' + j_word + '#')
    return j_word
|
||||||
|
|
||||||
|
|
||||||
|
def to_word_dict(word_list: List, text: str):
    """Assign a distinct '#key#' placeholder to each word of ``word_list``.

    :param word_list: words that must be kept whole
    :param text: text the placeholders will be inserted into
    :return: dict mapping '#key#' -> word
    """
    word_dict = {}
    # Track the bare keys already handed out. Passing the wrapped dict keys ('#x#') to
    # get_word_key never matched its bare-character check, so every word received the
    # same placeholder and overwrote the previous entry.
    used_keys = set()
    for word in word_list:
        key = get_word_key(text, used_keys)
        used_keys.add(key)
        word_dict['#' + key + '#'] = word
    return word_dict
|
||||||
|
|
||||||
|
|
||||||
|
def get_key_by_word_dict(key, word_dict):
    """Look ``key`` up in ``word_dict``, falling back to ``key`` itself.

    :param key: token to translate
    :param word_dict: mapping consulted for the token
    :return: the mapped value, or ``key`` when there is no entry or the entry is None
    """
    mapped = word_dict.get(key)
    return key if mapped is None else mapped
|
||||||
|
|
||||||
|
|
||||||
|
def to_ts_vector(text: str):
    """Build a "word:pos1,pos2 word:pos ..." string from ``text``.

    Words matched by ``get_word_list`` are swapped for placeholders before segmentation
    and mapped back afterwards. Each remaining token is lower-cased and followed by up to
    20 positions, each being the token's ``jieba.tokenize`` start value plus one.

    :param text: text to index
    :return: space-separated ``word:positions`` entries
    """
    # words that must not be segmented
    word_list = get_word_list(text)
    # placeholder -> word mapping
    word_dict = to_word_dict(word_list, text)
    # swap the protected words for their placeholders
    text = replace_word(word_dict, text)
    # segment the text
    tokens = jieba.tokenize(text, mode='search')
    result_ = [{'word': get_key_by_word_dict(item[0], word_dict), 'index': item[1]} for item in tokens]
    result_group = group_by(result_, lambda r: r['word'])
    entries = []
    for key in result_group:
        # Skip tokens that are a substring of remove_chars and whitespace-only tokens.
        # The previous check, len(key.strip()) >= 0, was always true and let blank tokens through.
        if key in remove_chars or not key.strip():
            continue
        positions = [str(item['index'] + 1) for item in result_group[key]][:20]
        entries.append(f"{key.lower()}:{','.join(positions)}")
    return " ".join(entries)
|
||||||
|
|
||||||
|
|
||||||
|
def to_query(text: str):
    """Turn ``text`` into a space-separated keyword string.

    Words matched by ``get_word_list`` are swapped for placeholders, up to five tags are
    extracted with ``jieba.analyse.extract_tags`` (restricted to the listed parts of
    speech), and placeholders among the tags are mapped back to their words.

    :param text: the question text
    :return: the extracted keywords joined by single spaces
    """
    # words that must not be segmented
    protected_words = get_word_list(text)
    # placeholder -> word mapping
    placeholders = to_word_dict(protected_words, text)
    # swap the protected words for their placeholders
    text = replace_word(placeholders, text)
    tags = analyse.extract_tags(text, topK=5, withWeight=True, allowPOS=('ns', 'n', 'vn', 'v', 'eng'))
    keywords = []
    for tag, _weight in tags:
        # drop tags that are a substring of remove_chars
        if remove_chars.__contains__(tag):
            continue
        keywords.append(get_key_by_word_dict(tag, placeholders))
    result = " ".join(keywords)
    # remove the protected words from the jieba dictionary
    for protected in protected_words:
        jieba.del_word(protected)
    return result
|
||||||
@ -37,6 +37,7 @@ from common.util.split_model import get_split_model
|
|||||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
|
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
|
||||||
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer
|
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer
|
||||||
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
|
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
|
||||||
|
from embedding.models import SearchMode
|
||||||
from setting.models import AuthOperate
|
from setting.models import AuthOperate
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
@ -457,6 +458,10 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
error_messages=ErrMessage.char("响应Top"))
|
error_messages=ErrMessage.char("响应Top"))
|
||||||
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
|
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
|
||||||
error_messages=ErrMessage.char("相似度"))
|
error_messages=ErrMessage.char("相似度"))
|
||||||
|
# Retrieval mode: must be exactly one of embedding / keywords / blend.
# The alternation is grouped so that ^ and $ apply to every option; without the group only
# "embedding" is start-anchored and only "blend" is end-anchored, so a value such as
# "xkeywordsx" would pass validation.
search_mode = serializers.CharField(required=True, validators=[
    validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
                              message="类型只支持embedding|keywords|blend", code=500)
], error_messages=ErrMessage.char("检索模式"))
|
||||||
|
|
||||||
def is_valid(self, *, raise_exception=True):
|
def is_valid(self, *, raise_exception=True):
|
||||||
super().is_valid(raise_exception=True)
|
super().is_valid(raise_exception=True)
|
||||||
@ -474,6 +479,7 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
|
hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
|
||||||
self.data.get('top_number'),
|
self.data.get('top_number'),
|
||||||
self.data.get('similarity'),
|
self.data.get('similarity'),
|
||||||
|
SearchMode(self.data.get('search_mode')),
|
||||||
EmbeddingModel.get_embedding_model())
|
EmbeddingModel.get_embedding_model())
|
||||||
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
|
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
|
||||||
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
|
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
|
||||||
|
|||||||
@ -111,7 +111,8 @@ class Dataset(APIView):
|
|||||||
DataSetSerializers.HitTest(data={'id': dataset_id, 'user_id': request.user.id,
|
DataSetSerializers.HitTest(data={'id': dataset_id, 'user_id': request.user.id,
|
||||||
"query_text": request.query_params.get("query_text"),
|
"query_text": request.query_params.get("query_text"),
|
||||||
"top_number": request.query_params.get("top_number"),
|
"top_number": request.query_params.get("top_number"),
|
||||||
'similarity': request.query_params.get('similarity')}).hit_test(
|
'similarity': request.query_params.get('similarity'),
|
||||||
|
'search_mode': request.query_params.get('search_mode')}).hit_test(
|
||||||
))
|
))
|
||||||
|
|
||||||
class Operate(APIView):
|
class Operate(APIView):
|
||||||
|
|||||||
54
apps/embedding/migrations/0002_embedding_search_vector.py
Normal file
54
apps/embedding/migrations/0002_embedding_search_vector.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# Generated by Django 4.1.13 on 2024-04-16 11:43
|
||||||
|
|
||||||
|
import django.contrib.postgres.search
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
from common.util.common import sub_array
|
||||||
|
from common.util.ts_vecto_util import to_ts_vector
|
||||||
|
from dataset.models import Status
|
||||||
|
from embedding.models import Embedding
|
||||||
|
|
||||||
|
|
||||||
|
def update_embedding_search_vector(embedding, paragraph_list):
    """Build an ``Embedding`` instance carrying the search vector for one embedding row.

    :param embedding: mapping with the row's 'id' and 'paragraph' values
    :param paragraph_list: paragraphs to look the row's paragraph up in
    :return: ``Embedding`` with the row id and a search vector built from the first matching
             paragraph's title + content, or an empty search vector when none matches
    """
    matched = next((p for p in paragraph_list if p.id == embedding.get('paragraph')), None)
    if matched is None:
        return Embedding(id=embedding.get('id'), search_vector="")
    content = matched.title + matched.content
    return Embedding(id=embedding.get('id'), search_vector=to_ts_vector(content))
|
||||||
|
|
||||||
|
|
||||||
|
def save_keywords(apps, schema_editor):
    """Back-fill ``search_vector`` for the existing embedding rows, one document at a time.

    Each document is first saved with ``Status.embedding``; its embedding rows are then
    rewritten in batches of 20. A failing batch is printed and skipped.

    :param apps: migration app registry used to resolve the models
    :param schema_editor: schema editor whose connection alias selects the database
    """
    document_model = apps.get_model("dataset", "Document")
    embedding_model = apps.get_model("embedding", "Embedding")
    paragraph_model = apps.get_model('dataset', 'Paragraph')
    db_alias = schema_editor.connection.alias
    for doc in document_model.objects.using(db_alias).all():
        doc.status = Status.embedding
        doc.save()
        paragraph_list = paragraph_model.objects.using(db_alias).filter(document=doc).all()
        embedding_rows = embedding_model.objects.using(db_alias).filter(document=doc).values(
            'id', 'search_vector', 'paragraph')
        pending = [update_embedding_search_vector(row, paragraph_list) for row in embedding_rows]
        for batch in sub_array(pending, 20):
            try:
                embedding_model.objects.using(db_alias).bulk_update(batch, ['search_vector'])
            except Exception as e:
                # best effort: report the failed batch and continue with the next one
                print(e)
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
    """Add the ``search_vector`` column to ``embedding`` and back-fill it for existing rows."""

    dependencies = [
        ('embedding', '0001_initial'),
    ]

    operations = [
        migrations.AddField(
            model_name='embedding',
            name='search_vector',
            field=django.contrib.postgres.search.SearchVectorField(default='', verbose_name='分词'),
        ),
        # RunPython without reverse_code makes the migration irreversible. Nothing has to be
        # undone for the back-fill itself, because reversing AddField drops the column.
        migrations.RunPython(save_keywords, reverse_code=migrations.RunPython.noop)
    ]
|
||||||
@ -10,6 +10,7 @@ from django.db import models
|
|||||||
|
|
||||||
from common.field.vector_field import VectorField
|
from common.field.vector_field import VectorField
|
||||||
from dataset.models.data_set import Document, Paragraph, DataSet
|
from dataset.models.data_set import Document, Paragraph, DataSet
|
||||||
|
from django.contrib.postgres.search import SearchVectorField
|
||||||
|
|
||||||
|
|
||||||
class SourceType(models.TextChoices):
|
class SourceType(models.TextChoices):
|
||||||
@ -19,6 +20,12 @@ class SourceType(models.TextChoices):
|
|||||||
TITLE = 2, '标题'
|
TITLE = 2, '标题'
|
||||||
|
|
||||||
|
|
||||||
|
class SearchMode(models.TextChoices):
    """Retrieval modes accepted by the search code; each member's value equals its name."""
    # presumably vector-similarity retrieval -- confirm against the vector store
    embedding = 'embedding'
    # presumably full-text keyword retrieval -- confirm against the vector store
    keywords = 'keywords'
    # presumably the combination of both -- confirm against the vector store
    blend = 'blend'
|
||||||
|
|
||||||
|
|
||||||
class Embedding(models.Model):
|
class Embedding(models.Model):
|
||||||
id = models.CharField(max_length=128, primary_key=True, verbose_name="主键id")
|
id = models.CharField(max_length=128, primary_key=True, verbose_name="主键id")
|
||||||
|
|
||||||
@ -37,6 +44,8 @@ class Embedding(models.Model):
|
|||||||
|
|
||||||
embedding = VectorField(verbose_name="向量")
|
embedding = VectorField(verbose_name="向量")
|
||||||
|
|
||||||
|
search_vector = SearchVectorField(verbose_name="分词", default="")
|
||||||
|
|
||||||
meta = models.JSONField(verbose_name="元数据", default=dict)
|
meta = models.JSONField(verbose_name="元数据", default=dict)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|||||||
26
apps/embedding/sql/blend_search.sql
Normal file
26
apps/embedding/sql/blend_search.sql
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
/* Blended search: every embedding row is scored by (1 - vector distance) plus the
   full-text rank of its search_vector; the best row per paragraph is kept, then the
   rows are filtered by the score threshold, sorted by score and limited. */
SELECT
	paragraph_id,
	comprehensive_score,
	comprehensive_score AS similarity
FROM
	(
	/* DISTINCT ON together with the ORDER BY below keeps, for each paragraph_id,
	   the row with the highest similarity */
	SELECT DISTINCT ON
		( "paragraph_id" ) ( similarity ),* ,
		similarity AS comprehensive_score
	FROM
		(
		SELECT
			*,
			/* ts_rank_cd is called with normalization flag 32 */
			(( 1 - ( embedding.embedding <=> %s ) )+ts_rank_cd( embedding.search_vector, websearch_to_tsquery('simple', %s ), 32 )) AS similarity
		FROM
			embedding ${embedding_query}
		) TEMP
	ORDER BY
		paragraph_id,
		similarity DESC
	) DISTINCT_TEMP
WHERE
	comprehensive_score >%s
ORDER BY
	comprehensive_score DESC
LIMIT %s
|
||||||
17
apps/embedding/sql/keywords_search.sql
Normal file
17
apps/embedding/sql/keywords_search.sql
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
/* Keyword search: every embedding row is scored by the full-text rank of its
   search_vector against the query; the best row per paragraph is kept, then the
   rows are filtered by the score threshold, sorted by score and limited. */
SELECT
	paragraph_id,
	comprehensive_score,
	comprehensive_score as similarity
FROM
	(
	/* DISTINCT ON together with the ORDER BY below keeps, for each paragraph_id,
	   the row with the highest similarity */
	SELECT DISTINCT ON
		("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score
	FROM
		( SELECT *,ts_rank_cd(embedding.search_vector,websearch_to_tsquery('simple',%s),32) AS similarity FROM embedding ${keywords_query}) TEMP
	ORDER BY
		paragraph_id,
		similarity DESC
	) DISTINCT_TEMP
WHERE comprehensive_score>%s
ORDER BY comprehensive_score DESC
LIMIT %s
|
||||||
@ -14,7 +14,7 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
|
|||||||
|
|
||||||
from common.config.embedding_config import EmbeddingModel
|
from common.config.embedding_config import EmbeddingModel
|
||||||
from common.util.common import sub_array
|
from common.util.common import sub_array
|
||||||
from embedding.models import SourceType
|
from embedding.models import SourceType, SearchMode
|
||||||
|
|
||||||
lock = threading.Lock()
|
lock = threading.Lock()
|
||||||
|
|
||||||
@ -113,13 +113,16 @@ class BaseVectorStore(ABC):
|
|||||||
return result[0]
|
return result[0]
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str],
|
def query(self, query_text:str,query_embedding: List[float], dataset_id_list: list[str],
|
||||||
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float):
|
exclude_document_id_list: list[str],
|
||||||
|
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
|
||||||
|
search_mode: SearchMode):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
|
def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
|
||||||
similarity: float,
|
similarity: float,
|
||||||
|
search_mode: SearchMode,
|
||||||
embedding: HuggingFaceEmbeddings):
|
embedding: HuggingFaceEmbeddings):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
@ -18,7 +19,8 @@ from common.config.embedding_config import EmbeddingModel
|
|||||||
from common.db.search import generate_sql_by_query_dict
|
from common.db.search import generate_sql_by_query_dict
|
||||||
from common.db.sql_execute import select_list
|
from common.db.sql_execute import select_list
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from embedding.models import Embedding, SourceType
|
from common.util.ts_vecto_util import to_ts_vector, to_query
|
||||||
|
from embedding.models import Embedding, SourceType, SearchMode
|
||||||
from embedding.vector.base_vector import BaseVectorStore
|
from embedding.vector.base_vector import BaseVectorStore
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
@ -57,7 +59,8 @@ class PGVector(BaseVectorStore):
|
|||||||
paragraph_id=paragraph_id,
|
paragraph_id=paragraph_id,
|
||||||
source_id=source_id,
|
source_id=source_id,
|
||||||
embedding=text_embedding,
|
embedding=text_embedding,
|
||||||
source_type=source_type)
|
source_type=source_type,
|
||||||
|
search_vector=to_ts_vector(text))
|
||||||
embedding.save()
|
embedding.save()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -71,13 +74,15 @@ class PGVector(BaseVectorStore):
|
|||||||
is_active=text_list[index].get('is_active', True),
|
is_active=text_list[index].get('is_active', True),
|
||||||
source_id=text_list[index].get('source_id'),
|
source_id=text_list[index].get('source_id'),
|
||||||
source_type=text_list[index].get('source_type'),
|
source_type=text_list[index].get('source_type'),
|
||||||
embedding=embeddings[index]) for index in
|
embedding=embeddings[index],
|
||||||
|
search_vector=to_ts_vector(text_list[index]['text'])) for index in
|
||||||
range(0, len(text_list))]
|
range(0, len(text_list))]
|
||||||
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
|
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
|
def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
|
||||||
similarity: float,
|
similarity: float,
|
||||||
|
search_mode: SearchMode,
|
||||||
embedding: HuggingFaceEmbeddings):
|
embedding: HuggingFaceEmbeddings):
|
||||||
if dataset_id_list is None or len(dataset_id_list) == 0:
|
if dataset_id_list is None or len(dataset_id_list) == 0:
|
||||||
return []
|
return []
|
||||||
@ -87,17 +92,14 @@ class PGVector(BaseVectorStore):
|
|||||||
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
|
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
|
||||||
exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
|
exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
|
||||||
query_set = query_set.exclude(**exclude_dict)
|
query_set = query_set.exclude(**exclude_dict)
|
||||||
exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
|
for search_handle in search_handle_list:
|
||||||
select_string=get_file_content(
|
if search_handle.support(search_mode):
|
||||||
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
|
return search_handle.handle(query_set, query_text, embedding_query, top_number, similarity, search_mode)
|
||||||
'hit_test.sql')),
|
|
||||||
with_table_name=True)
|
|
||||||
embedding_model = select_list(exec_sql,
|
|
||||||
[json.dumps(embedding_query), *exec_params, similarity, top_number])
|
|
||||||
return embedding_model
|
|
||||||
|
|
||||||
def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str],
|
def query(self, query_text: str, query_embedding: List[float], dataset_id_list: list[str],
|
||||||
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float):
|
exclude_document_id_list: list[str],
|
||||||
|
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
|
||||||
|
search_mode: SearchMode):
|
||||||
exclude_dict = {}
|
exclude_dict = {}
|
||||||
if dataset_id_list is None or len(dataset_id_list) == 0:
|
if dataset_id_list is None or len(dataset_id_list) == 0:
|
||||||
return []
|
return []
|
||||||
@ -107,14 +109,9 @@ class PGVector(BaseVectorStore):
|
|||||||
if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
|
if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
|
||||||
exclude_dict.__setitem__('paragraph_id__in', exclude_paragraph_list)
|
exclude_dict.__setitem__('paragraph_id__in', exclude_paragraph_list)
|
||||||
query_set = query_set.exclude(**exclude_dict)
|
query_set = query_set.exclude(**exclude_dict)
|
||||||
exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
|
for search_handle in search_handle_list:
|
||||||
select_string=get_file_content(
|
if search_handle.support(search_mode):
|
||||||
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
|
return search_handle.handle(query_set, query_text, query_embedding, top_n, similarity, search_mode)
|
||||||
'embedding_search.sql')),
|
|
||||||
with_table_name=True)
|
|
||||||
embedding_model = select_list(exec_sql,
|
|
||||||
[json.dumps(query_embedding), *exec_params, similarity, top_n])
|
|
||||||
return embedding_model
|
|
||||||
|
|
||||||
def update_by_source_id(self, source_id: str, instance: Dict):
|
def update_by_source_id(self, source_id: str, instance: Dict):
|
||||||
QuerySet(Embedding).filter(source_id=source_id).update(**instance)
|
QuerySet(Embedding).filter(source_id=source_id).update(**instance)
|
||||||
@ -141,3 +138,81 @@ class PGVector(BaseVectorStore):
|
|||||||
|
|
||||||
def delete_by_paragraph_id(self, paragraph_id: str):
|
def delete_by_paragraph_id(self, paragraph_id: str):
|
||||||
QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()
|
QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()
|
||||||
|
|
||||||
|
|
||||||
|
class ISearch(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def support(self, search_mode: SearchMode):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def handle(self, query_set, query_text, query_embedding, top_number: int,
|
||||||
|
similarity: float, search_mode: SearchMode):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingSearch(ISearch):
|
||||||
|
def handle(self,
|
||||||
|
query_set,
|
||||||
|
query_text,
|
||||||
|
query_embedding,
|
||||||
|
top_number: int,
|
||||||
|
similarity: float,
|
||||||
|
search_mode: SearchMode):
|
||||||
|
exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
|
||||||
|
select_string=get_file_content(
|
||||||
|
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
|
||||||
|
'embedding_search.sql')),
|
||||||
|
with_table_name=True)
|
||||||
|
embedding_model = select_list(exec_sql,
|
||||||
|
[json.dumps(query_embedding), *exec_params, similarity, top_number])
|
||||||
|
return embedding_model
|
||||||
|
|
||||||
|
def support(self, search_mode: SearchMode):
|
||||||
|
return search_mode.value == SearchMode.embedding.value
|
||||||
|
|
||||||
|
|
||||||
|
class KeywordsSearch(ISearch):
|
||||||
|
def handle(self,
|
||||||
|
query_set,
|
||||||
|
query_text,
|
||||||
|
query_embedding,
|
||||||
|
top_number: int,
|
||||||
|
similarity: float,
|
||||||
|
search_mode: SearchMode):
|
||||||
|
exec_sql, exec_params = generate_sql_by_query_dict({'keywords_query': query_set},
|
||||||
|
select_string=get_file_content(
|
||||||
|
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
|
||||||
|
'keywords_search.sql')),
|
||||||
|
with_table_name=True)
|
||||||
|
embedding_model = select_list(exec_sql,
|
||||||
|
[to_query(query_text), *exec_params, similarity, top_number])
|
||||||
|
return embedding_model
|
||||||
|
|
||||||
|
def support(self, search_mode: SearchMode):
|
||||||
|
return search_mode.value == SearchMode.keywords.value
|
||||||
|
|
||||||
|
|
||||||
|
class BlendSearch(ISearch):
|
||||||
|
def handle(self,
|
||||||
|
query_set,
|
||||||
|
query_text,
|
||||||
|
query_embedding,
|
||||||
|
top_number: int,
|
||||||
|
similarity: float,
|
||||||
|
search_mode: SearchMode):
|
||||||
|
exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
|
||||||
|
select_string=get_file_content(
|
||||||
|
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
|
||||||
|
'blend_search.sql')),
|
||||||
|
with_table_name=True)
|
||||||
|
embedding_model = select_list(exec_sql,
|
||||||
|
[json.dumps(query_embedding), to_query(query_text), *exec_params, similarity,
|
||||||
|
top_number])
|
||||||
|
return embedding_model
|
||||||
|
|
||||||
|
def support(self, search_mode: SearchMode):
|
||||||
|
return search_mode.value == SearchMode.blend.value
|
||||||
|
|
||||||
|
|
||||||
|
search_handle_list = [EmbeddingSearch(), KeywordsSearch(), BlendSearch()]
|
||||||
|
|||||||
@ -145,74 +145,12 @@
|
|||||||
<div class="flex-between">
|
<div class="flex-between">
|
||||||
<span>关联知识库</span>
|
<span>关联知识库</span>
|
||||||
<div>
|
<div>
|
||||||
<el-popover :visible="popoverVisible" :width="214" trigger="click">
|
<el-button type="primary" link @click="openParamSettingDialog">
|
||||||
<template #reference>
|
<AppIcon iconName="app-operation" class="mr-4"></AppIcon>参数设置
|
||||||
<el-button type="primary" link @click="datasetSettingChange('open')">
|
</el-button>
|
||||||
<AppIcon iconName="app-operation" class="mr-4"></AppIcon>参数设置
|
<el-button type="primary" link @click="openDatasetDialog">
|
||||||
</el-button>
|
<el-icon class="mr-4"><Plus /></el-icon>添加
|
||||||
</template>
|
</el-button>
|
||||||
<div class="dataset_setting">
|
|
||||||
<div class="form-item mb-16">
|
|
||||||
<div class="title flex align-center mb-8">
|
|
||||||
<span style="margin-right: 4px">相似度高于</span>
|
|
||||||
<el-tooltip
|
|
||||||
effect="dark"
|
|
||||||
content="相似度越高相关性越强。"
|
|
||||||
placement="right"
|
|
||||||
>
|
|
||||||
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
|
||||||
</el-tooltip>
|
|
||||||
</div>
|
|
||||||
<div @click.stop>
|
|
||||||
<el-input-number
|
|
||||||
v-model="dataset_setting.similarity"
|
|
||||||
:min="0"
|
|
||||||
:max="1"
|
|
||||||
:precision="3"
|
|
||||||
:step="0.1"
|
|
||||||
controls-position="right"
|
|
||||||
style="width: 180px"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="form-item mb-16">
|
|
||||||
<div class="title mb-8">引用分段数 TOP</div>
|
|
||||||
<div @click.stop>
|
|
||||||
<el-input-number
|
|
||||||
v-model="dataset_setting.top_n"
|
|
||||||
:min="1"
|
|
||||||
:max="10"
|
|
||||||
controls-position="right"
|
|
||||||
style="width: 180px"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="form-item mb-16">
|
|
||||||
<div class="title mb-8">最多引用字符数</div>
|
|
||||||
<div class="flex align-center">
|
|
||||||
<el-slider
|
|
||||||
v-model="dataset_setting.max_paragraph_char_number"
|
|
||||||
show-input
|
|
||||||
:show-input-controls="false"
|
|
||||||
:min="500"
|
|
||||||
:max="10000"
|
|
||||||
style="width: 180px"
|
|
||||||
class="custom-slider"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="text-right">
|
|
||||||
<el-button @click="popoverVisible = false">取消</el-button>
|
|
||||||
<el-button type="primary" @click="datasetSettingChange('close')"
|
|
||||||
>确认</el-button
|
|
||||||
>
|
|
||||||
</div>
|
|
||||||
</el-popover>
|
|
||||||
<el-button type="primary" link @click="openDatasetDialog"
|
|
||||||
><el-icon class="mr-4"><Plus /></el-icon>添加</el-button
|
|
||||||
>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
@ -221,13 +159,6 @@
|
|||||||
>关联的知识库展示在这里</el-text
|
>关联的知识库展示在这里</el-text
|
||||||
>
|
>
|
||||||
<el-row :gutter="12" v-else>
|
<el-row :gutter="12" v-else>
|
||||||
<!-- <el-col :xs="24" :sm="24" :md="12" :lg="12" :xl="12" class="mb-8">
|
|
||||||
<CardAdd
|
|
||||||
title="关联知识库"
|
|
||||||
@click="openDatasetDialog"
|
|
||||||
style="min-height: 50px; font-size: 14px"
|
|
||||||
/>
|
|
||||||
</el-col> -->
|
|
||||||
<el-col
|
<el-col
|
||||||
:xs="24"
|
:xs="24"
|
||||||
:sm="24"
|
:sm="24"
|
||||||
@ -311,6 +242,7 @@
|
|||||||
</el-col>
|
</el-col>
|
||||||
</el-row>
|
</el-row>
|
||||||
|
|
||||||
|
<ParamSettingDialog ref="ParamSettingDialogRef" @refresh="refreshParam" />
|
||||||
<AddDatasetDialog
|
<AddDatasetDialog
|
||||||
ref="AddDatasetDialogRef"
|
ref="AddDatasetDialogRef"
|
||||||
@addData="addDataset"
|
@addData="addDataset"
|
||||||
@ -318,6 +250,8 @@
|
|||||||
@refresh="refresh"
|
@refresh="refresh"
|
||||||
:loading="datasetLoading"
|
:loading="datasetLoading"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<!-- 添加模版 -->
|
||||||
<CreateModelDialog
|
<CreateModelDialog
|
||||||
ref="createModelRef"
|
ref="createModelRef"
|
||||||
@submit="getModel"
|
@submit="getModel"
|
||||||
@ -329,7 +263,8 @@
|
|||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { reactive, ref, watch, onMounted } from 'vue'
|
import { reactive, ref, watch, onMounted } from 'vue'
|
||||||
import { useRouter, useRoute } from 'vue-router'
|
import { useRouter, useRoute } from 'vue-router'
|
||||||
import { groupBy, cloneDeep } from 'lodash'
|
import { groupBy } from 'lodash'
|
||||||
|
import ParamSettingDialog from './components/ParamSettingDialog.vue'
|
||||||
import AddDatasetDialog from './components/AddDatasetDialog.vue'
|
import AddDatasetDialog from './components/AddDatasetDialog.vue'
|
||||||
import CreateModelDialog from '@/views/template/component/CreateModelDialog.vue'
|
import CreateModelDialog from '@/views/template/component/CreateModelDialog.vue'
|
||||||
import SelectProviderDialog from '@/views/template/component/SelectProviderDialog.vue'
|
import SelectProviderDialog from '@/views/template/component/SelectProviderDialog.vue'
|
||||||
@ -363,6 +298,8 @@ const defaultPrompt = `已知信息:
|
|||||||
问题:
|
问题:
|
||||||
{question}
|
{question}
|
||||||
`
|
`
|
||||||
|
|
||||||
|
const ParamSettingDialogRef = ref<InstanceType<typeof ParamSettingDialog>>()
|
||||||
const createModelRef = ref<InstanceType<typeof CreateModelDialog>>()
|
const createModelRef = ref<InstanceType<typeof CreateModelDialog>>()
|
||||||
const selectProviderRef = ref<InstanceType<typeof SelectProviderDialog>>()
|
const selectProviderRef = ref<InstanceType<typeof SelectProviderDialog>>()
|
||||||
|
|
||||||
@ -384,7 +321,8 @@ const applicationForm = ref<ApplicationFormType>({
|
|||||||
dataset_setting: {
|
dataset_setting: {
|
||||||
top_n: 3,
|
top_n: 3,
|
||||||
similarity: 0.6,
|
similarity: 0.6,
|
||||||
max_paragraph_char_number: 5000
|
max_paragraph_char_number: 5000,
|
||||||
|
search_mode: 'embedding'
|
||||||
},
|
},
|
||||||
model_setting: {
|
model_setting: {
|
||||||
prompt: defaultPrompt
|
prompt: defaultPrompt
|
||||||
@ -392,8 +330,6 @@ const applicationForm = ref<ApplicationFormType>({
|
|||||||
problem_optimization: false
|
problem_optimization: false
|
||||||
})
|
})
|
||||||
|
|
||||||
const popoverVisible = ref(false)
|
|
||||||
|
|
||||||
const rules = reactive<FormRules<ApplicationFormType>>({
|
const rules = reactive<FormRules<ApplicationFormType>>({
|
||||||
name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }],
|
name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }],
|
||||||
model_id: [
|
model_id: [
|
||||||
@ -408,17 +344,6 @@ const rules = reactive<FormRules<ApplicationFormType>>({
|
|||||||
const modelOptions = ref<any>(null)
|
const modelOptions = ref<any>(null)
|
||||||
const providerOptions = ref<Array<Provider>>([])
|
const providerOptions = ref<Array<Provider>>([])
|
||||||
const datasetList = ref([])
|
const datasetList = ref([])
|
||||||
const dataset_setting = ref<any>({})
|
|
||||||
|
|
||||||
function datasetSettingChange(val: string) {
|
|
||||||
if (val === 'open') {
|
|
||||||
popoverVisible.value = true
|
|
||||||
dataset_setting.value = cloneDeep(applicationForm.value.dataset_setting)
|
|
||||||
} else if (val === 'close') {
|
|
||||||
popoverVisible.value = false
|
|
||||||
applicationForm.value.dataset_setting = cloneDeep(dataset_setting.value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const submit = async (formEl: FormInstance | undefined) => {
|
const submit = async (formEl: FormInstance | undefined) => {
|
||||||
if (!formEl) return
|
if (!formEl) return
|
||||||
@ -438,6 +363,14 @@ const submit = async (formEl: FormInstance | undefined) => {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const openParamSettingDialog = () => {
|
||||||
|
ParamSettingDialogRef.value?.open(applicationForm.value.dataset_setting)
|
||||||
|
}
|
||||||
|
|
||||||
|
function refreshParam(data: any) {
|
||||||
|
applicationForm.value.dataset_setting = data
|
||||||
|
}
|
||||||
|
|
||||||
const openCreateModel = (provider?: Provider) => {
|
const openCreateModel = (provider?: Provider) => {
|
||||||
if (provider && provider.provider) {
|
if (provider && provider.provider) {
|
||||||
createModelRef.value?.open(provider)
|
createModelRef.value?.open(provider)
|
||||||
@ -560,13 +493,4 @@ onMounted(() => {
|
|||||||
.prologue-md-editor {
|
.prologue-md-editor {
|
||||||
height: 150px;
|
height: 150px;
|
||||||
}
|
}
|
||||||
.dataset_setting {
|
|
||||||
color: var(--el-text-color-regular);
|
|
||||||
font-weight: 400;
|
|
||||||
}
|
|
||||||
.custom-slider {
|
|
||||||
:deep(.el-input-number.is-without-controls .el-input__wrapper) {
|
|
||||||
padding: 0 !important;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
</style>
|
</style>
|
||||||
|
|||||||
152
ui/src/views/application/components/ParamSettingDialog.vue
Normal file
152
ui/src/views/application/components/ParamSettingDialog.vue
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
<template>
|
||||||
|
<el-dialog title="参数设置" class="param-dialog" v-model="dialogVisible" style="width: 550px">
|
||||||
|
<div class="dialog-max-height">
|
||||||
|
<el-scrollbar>
|
||||||
|
<div class="p-16">
|
||||||
|
<el-form label-position="top" ref="paramFormRef" :model="form">
|
||||||
|
<el-form-item label="检索模式">
|
||||||
|
<el-radio-group v-model="form.search_mode" class="card__radio">
|
||||||
|
<el-card
|
||||||
|
shadow="never"
|
||||||
|
class="mb-16"
|
||||||
|
:class="form.search_mode === 'embedding' ? 'active' : ''"
|
||||||
|
>
|
||||||
|
<el-radio value="embedding" size="large">
|
||||||
|
<p class="mb-4">向量检索</p>
|
||||||
|
<el-text type="info">通过向量距离计算与用户问题最相似的文本分段</el-text>
|
||||||
|
</el-radio>
|
||||||
|
</el-card>
|
||||||
|
<el-card
|
||||||
|
shadow="never"
|
||||||
|
class="mb-16"
|
||||||
|
:class="form.search_mode === 'keywords' ? 'active' : ''"
|
||||||
|
>
|
||||||
|
<el-radio value="keywords" size="large">
|
||||||
|
<p class="mb-4">全文检索</p>
|
||||||
|
<el-text type="info">通过关键词检索,返回包含关键词最多的文本分段</el-text>
|
||||||
|
</el-radio>
|
||||||
|
</el-card>
|
||||||
|
<el-card
|
||||||
|
shadow="never"
|
||||||
|
class="mb-16"
|
||||||
|
:class="form.search_mode === 'blend' ? 'active' : ''"
|
||||||
|
>
|
||||||
|
<el-radio value="blend" size="large">
|
||||||
|
<p class="mb-4">混合检索</p>
|
||||||
|
<el-text type="info"
|
||||||
|
>同时执行全文检索和向量检索,再进行重排序,从两类查询结果中选择匹配用户问题的最佳结果</el-text
|
||||||
|
>
|
||||||
|
</el-radio>
|
||||||
|
</el-card>
|
||||||
|
</el-radio-group>
|
||||||
|
</el-form-item>
|
||||||
|
<el-form-item>
|
||||||
|
<template #label>
|
||||||
|
<div class="flex align-center">
|
||||||
|
<span class="mr-4">相似度高于</span>
|
||||||
|
<el-tooltip effect="dark" content="相似度越高相关性越强。" placement="right">
|
||||||
|
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||||
|
</el-tooltip>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<el-input-number
|
||||||
|
v-model="form.similarity"
|
||||||
|
:min="0"
|
||||||
|
:max="1"
|
||||||
|
:precision="3"
|
||||||
|
:step="0.1"
|
||||||
|
controls-position="right"
|
||||||
|
/>
|
||||||
|
</el-form-item>
|
||||||
|
<el-form-item label="引用分段数 TOP">
|
||||||
|
<el-input-number v-model="form.top_n" :min="1" :max="10" controls-position="right" />
|
||||||
|
</el-form-item>
|
||||||
|
<el-form-item label="最多引用字符数">
|
||||||
|
<el-slider
|
||||||
|
v-model="form.max_paragraph_char_number"
|
||||||
|
show-input
|
||||||
|
:show-input-controls="false"
|
||||||
|
:min="500"
|
||||||
|
:max="10000"
|
||||||
|
class="custom-slider"
|
||||||
|
/>
|
||||||
|
</el-form-item>
|
||||||
|
</el-form>
|
||||||
|
</div>
|
||||||
|
</el-scrollbar>
|
||||||
|
</div>
|
||||||
|
<template #footer>
|
||||||
|
<span class="dialog-footer">
|
||||||
|
<el-button @click.prevent="dialogVisible = false"> 取消 </el-button>
|
||||||
|
<el-button type="primary" @click="submit(paramFormRef)" :loading="loading">
|
||||||
|
保存
|
||||||
|
</el-button>
|
||||||
|
</span>
|
||||||
|
</template>
|
||||||
|
</el-dialog>
|
||||||
|
</template>
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { ref, watch } from 'vue'
|
||||||
|
|
||||||
|
import type { FormInstance, FormRules } from 'element-plus'
|
||||||
|
|
||||||
|
const emit = defineEmits(['refresh'])
|
||||||
|
|
||||||
|
const paramFormRef = ref()
|
||||||
|
const form = ref<any>({
|
||||||
|
search_mode: 'embedding',
|
||||||
|
top_n: 3,
|
||||||
|
similarity: 0.6,
|
||||||
|
max_paragraph_char_number: 5000
|
||||||
|
})
|
||||||
|
|
||||||
|
const dialogVisible = ref<boolean>(false)
|
||||||
|
const loading = ref(false)
|
||||||
|
|
||||||
|
watch(dialogVisible, (bool) => {
|
||||||
|
if (!bool) {
|
||||||
|
form.value = {
|
||||||
|
search_mode: 'embedding',
|
||||||
|
top_n: 3,
|
||||||
|
similarity: 0.6,
|
||||||
|
max_paragraph_char_number: 5000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const open = (data: any) => {
|
||||||
|
form.value = { ...form.value, ...data }
|
||||||
|
dialogVisible.value = true
|
||||||
|
}
|
||||||
|
|
||||||
|
const submit = async (formEl: FormInstance | undefined) => {
|
||||||
|
if (!formEl) return
|
||||||
|
await formEl.validate((valid, fields) => {
|
||||||
|
if (valid) {
|
||||||
|
emit('refresh', form.value)
|
||||||
|
dialogVisible.value = false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
defineExpose({ open })
|
||||||
|
</script>
|
||||||
|
<style lang="scss" scope>
|
||||||
|
.param-dialog {
|
||||||
|
padding: 8px;
|
||||||
|
.el-dialog__header {
|
||||||
|
padding: 16px 16px 0 16px;
|
||||||
|
}
|
||||||
|
.el-dialog__body {
|
||||||
|
padding: 0 !important;
|
||||||
|
}
|
||||||
|
.dialog-max-height {
|
||||||
|
height: calc(100vh - 260px);
|
||||||
|
}
|
||||||
|
.custom-slider {
|
||||||
|
.el-input-number.is-without-controls .el-input__wrapper {
|
||||||
|
padding: 0 !important;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</style>
|
||||||
@ -78,11 +78,47 @@
|
|||||||
<ParagraphDialog ref="ParagraphDialogRef" :title="title" @refresh="refresh" />
|
<ParagraphDialog ref="ParagraphDialogRef" :title="title" @refresh="refresh" />
|
||||||
</LayoutContainer>
|
</LayoutContainer>
|
||||||
<div class="hit-test__operate p-24 pt-0">
|
<div class="hit-test__operate p-24 pt-0">
|
||||||
<el-popover :visible="popoverVisible" placement="right-end" :width="180" trigger="click">
|
<el-popover :visible="popoverVisible" placement="right-end" :width="500" trigger="click">
|
||||||
<template #reference>
|
<template #reference>
|
||||||
<el-button icon="Setting" class="mb-8" @click="settingChange('open')">参数设置</el-button>
|
<el-button icon="Setting" class="mb-8" @click="settingChange('open')">参数设置</el-button>
|
||||||
</template>
|
</template>
|
||||||
|
<div class="mb-16">
|
||||||
|
<div class="title mb-8">检索模式</div>
|
||||||
|
<el-radio-group v-model="cloneForm.search_mode" class="card__radio">
|
||||||
|
<el-card
|
||||||
|
shadow="never"
|
||||||
|
class="mb-16"
|
||||||
|
:class="cloneForm.search_mode === 'embedding' ? 'active' : ''"
|
||||||
|
>
|
||||||
|
<el-radio value="embedding" size="large">
|
||||||
|
<p class="mb-4">向量检索</p>
|
||||||
|
<el-text type="info">通过向量距离计算与用户问题最相似的文本分段</el-text>
|
||||||
|
</el-radio>
|
||||||
|
</el-card>
|
||||||
|
<el-card
|
||||||
|
shadow="never"
|
||||||
|
class="mb-16"
|
||||||
|
:class="cloneForm.search_mode === 'keywords' ? 'active' : ''"
|
||||||
|
>
|
||||||
|
<el-radio value="keywords" size="large">
|
||||||
|
<p class="mb-4">全文检索</p>
|
||||||
|
<el-text type="info">通过关键词检索,返回包含关键词最多的文本分段</el-text>
|
||||||
|
</el-radio>
|
||||||
|
</el-card>
|
||||||
|
<el-card
|
||||||
|
shadow="never"
|
||||||
|
class="mb-16"
|
||||||
|
:class="cloneForm.search_mode === 'blend' ? 'active' : ''"
|
||||||
|
>
|
||||||
|
<el-radio value="blend" size="large">
|
||||||
|
<p class="mb-4">混合检索</p>
|
||||||
|
<el-text type="info"
|
||||||
|
>同时执行全文检索和向量检索,再进行重排序,从两类查询结果中选择匹配用户问题的最佳结果</el-text
|
||||||
|
>
|
||||||
|
</el-radio>
|
||||||
|
</el-card>
|
||||||
|
</el-radio-group>
|
||||||
|
</div>
|
||||||
<div class="mb-16">
|
<div class="mb-16">
|
||||||
<div class="title mb-8">相似度高于</div>
|
<div class="title mb-8">相似度高于</div>
|
||||||
<el-input-number
|
<el-input-number
|
||||||
@ -92,7 +128,6 @@
|
|||||||
:precision="3"
|
:precision="3"
|
||||||
:step="0.1"
|
:step="0.1"
|
||||||
controls-position="right"
|
controls-position="right"
|
||||||
style="width: 145px"
|
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div class="mb-16">
|
<div class="mb-16">
|
||||||
@ -103,7 +138,6 @@
|
|||||||
:min="1"
|
:min="1"
|
||||||
:max="10"
|
:max="10"
|
||||||
controls-position="right"
|
controls-position="right"
|
||||||
style="width: 145px"
|
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div class="text-right">
|
<div class="text-right">
|
||||||
@ -161,7 +195,8 @@ const title = ref('')
|
|||||||
const inputValue = ref('')
|
const inputValue = ref('')
|
||||||
const formInline = ref({
|
const formInline = ref({
|
||||||
similarity: 0.6,
|
similarity: 0.6,
|
||||||
top_number: 5
|
top_number: 5,
|
||||||
|
search_mode: 'embedding'
|
||||||
})
|
})
|
||||||
|
|
||||||
// 第一次加载
|
// 第一次加载
|
||||||
@ -213,8 +248,7 @@ function sendChatHandle(event: any) {
|
|||||||
function getHitTestList() {
|
function getHitTestList() {
|
||||||
const obj = {
|
const obj = {
|
||||||
query_text: inputValue.value,
|
query_text: inputValue.value,
|
||||||
similarity: formInline.value.similarity,
|
...formInline.value
|
||||||
top_number: formInline.value.top_number
|
|
||||||
}
|
}
|
||||||
if (isDataset.value) {
|
if (isDataset.value) {
|
||||||
datasetApi.getDatasetHitTest(id, obj, loading).then((res) => {
|
datasetApi.getDatasetHitTest(id, obj, loading).then((res) => {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user