feat: support embedding models

shaohuzhang1 2024-07-17 17:01:57 +08:00
parent 9b81b89975
commit bd4303aee7
14 changed files with 223 additions and 98 deletions

View File

@@ -13,22 +13,45 @@ from django.db.models import QuerySet
 from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
-from common.config.embedding_config import VectorStore, EmbeddingModel
+from common.config.embedding_config import VectorStore, EmbeddingModelManage
 from common.db.search import native_search
 from common.util.file_util import get_file_content
-from dataset.models import Paragraph
+from dataset.models import Paragraph, DataSet
 from embedding.models import SearchMode
+from setting.models import Model
+from setting.models_provider import get_model
 from smartdoc.conf import PROJECT_DIR
+
+
+def get_model_by_id(_id):
+    model = QuerySet(Model).filter(id=_id).first()
+    if model is None:
+        raise Exception("Model does not exist")
+    return model
+
+
+def get_embedding_id(dataset_id_list):
+    dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
+    if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
+        raise Exception("The selected knowledge bases use inconsistent embedding models")
+    if len(dataset_list) == 0:
+        raise Exception("Knowledge base configuration error, please reconfigure the knowledge base")
+    return dataset_list[0].embedding_mode_id
+
+
 class BaseSearchDatasetStep(ISearchDatasetStep):

     def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
                 exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
                 search_mode: str = None,
                 **kwargs) -> List[ParagraphPipelineModel]:
+        if len(dataset_id_list) == 0:
+            return []
         exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
-        embedding_model = EmbeddingModel.get_embedding_model()
+        model_id = get_embedding_id(dataset_id_list)
+        model = get_model_by_id(model_id)
+        self.context['model_name'] = model.name
+        embedding_model = EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model))
         embedding_value = embedding_model.embed_query(exec_problem_text)
         vector = VectorStore.get_embedding_vector()
         embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
@@ -101,7 +124,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
             'run_time': self.context['run_time'],
             'problem_text': step_args.get(
                 'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
-            'model_name': EmbeddingModel.get_embedding_model().model_name,
+            'model_name': self.context.get('model_name'),
             'message_tokens': 0,
             'answer_tokens': 0,
             'cost': 0
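
Taken together, this hunk replaces the global singleton with per-request resolution: find the single embedding model shared by all selected knowledge bases, then fetch it through the cache. A quick standalone sketch of the set-based consistency check in `get_embedding_id` (the `Row` class is a hypothetical stand-in for `DataSet` records):

```python
# Hypothetical rows standing in for DataSet records.
class Row:
    def __init__(self, embedding_mode_id):
        self.embedding_mode_id = embedding_mode_id


rows = [Row("m1"), Row("m1")]
assert len({r.embedding_mode_id for r in rows}) == 1   # consistent: one shared model id

rows.append(Row("m2"))
assert len({r.embedding_mode_id for r in rows}) > 1    # mixed models -> get_embedding_id raises
```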

View File

@@ -13,20 +13,47 @@ from django.db.models import QuerySet
 from application.flow.i_step_node import NodeResult
 from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
-from common.config.embedding_config import EmbeddingModel, VectorStore
+from common.config.embedding_config import VectorStore, EmbeddingModelManage
 from common.db.search import native_search
 from common.util.file_util import get_file_content
-from dataset.models import Document, Paragraph
+from dataset.models import Document, Paragraph, DataSet
 from embedding.models import SearchMode
+from setting.models import Model
+from setting.models_provider import get_model
 from smartdoc.conf import PROJECT_DIR
+
+
+def get_model_by_id(_id):
+    model = QuerySet(Model).filter(id=_id).first()
+    if model is None:
+        raise Exception("Model does not exist")
+    return get_model(model)
+
+
+def get_embedding_id(dataset_id_list):
+    dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
+    if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
+        raise Exception("The selected knowledge bases use inconsistent embedding models")
+    if len(dataset_list) == 0:
+        raise Exception("Knowledge base configuration error, please reconfigure the knowledge base")
+    return dataset_list[0].embedding_mode_id
+
+
+def get_none_result(question):
+    return NodeResult(
+        {'paragraph_list': [], 'is_hit_handling_method': [], 'question': question, 'data': '',
+         'directly_return': ''}, {})
+
+
 class BaseSearchDatasetNode(ISearchDatasetStepNode):
     def execute(self, dataset_id_list, dataset_setting, question,
                 exclude_paragraph_id_list=None,
                 **kwargs) -> NodeResult:
         self.context['question'] = question
-        embedding_model = EmbeddingModel.get_embedding_model()
+        if len(dataset_id_list) == 0:
+            return get_none_result(question)
+        model_id = get_embedding_id(dataset_id_list)
+        embedding_model = EmbeddingModelManage.get_model(model_id, get_model_by_id)
         embedding_value = embedding_model.embed_query(question)
         vector = VectorStore.get_embedding_vector()
         exclude_document_id_list = [str(document.id) for document in
@@ -37,7 +64,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
                                     exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
                                     dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
         if embedding_list is None:
-            return NodeResult({'paragraph_list': [], 'is_hit_handling_method': []}, {})
+            return get_none_result(question)
         paragraph_list = self.list_paragraph(embedding_list, vector)
         result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
         return NodeResult({'paragraph_list': result,

View File

@@ -26,7 +26,7 @@ from rest_framework import serializers
 from application.flow.workflow_manage import Flow
 from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion
 from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
-from common.config.embedding_config import VectorStore, EmbeddingModel
+from common.config.embedding_config import VectorStore
 from common.constants.authentication_type import AuthenticationType
 from common.db.search import get_dynamics_model, native_search, native_page_search
 from common.db.sql_execute import select_list
@@ -36,7 +36,7 @@ from common.util.common import valid_license
 from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from dataset.models import DataSet, Document, Image
-from dataset.serializers.common_serializers import list_paragraph
+from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list
 from embedding.models import SearchMode
 from setting.models import AuthOperate
 from setting.models.model_management import Model
@@ -415,12 +415,13 @@ class ApplicationSerializer(serializers.Serializer):
                 QuerySet(Document).filter(
                     dataset_id__in=dataset_id_list,
                     is_active=False)]
+            model = get_embedding_model_by_dataset_id_list(dataset_id_list)
             # Vector store retrieval
             hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
                                        self.data.get('top_number'),
                                        self.data.get('similarity'),
                                        SearchMode(self.data.get('search_mode')),
-                                       EmbeddingModel.get_embedding_model())
+                                       model)
             hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
             p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
             return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),

View File

@@ -41,3 +41,8 @@ class MemCache(LocMemCache):
                 delete_keys.append(key)
         for key in delete_keys:
             self._delete(key)
+
+    def clear_timeout_data(self):
+        # Iterate over a snapshot: get() evicts expired entries and mutates the dict.
+        for key in list(self._cache.keys()):
+            self.get(key)
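
The new `clear_timeout_data` leans on `LocMemCache`'s lazy expiry: entries outlive their timeout in memory until something reads them, and a `get()` on an expired key evicts it. A small standalone demonstration of that Django behavior (illustration only, not MaxKB code):

```python
import time

from django.core.cache.backends.locmem import LocMemCache

cache = LocMemCache('demo', {})  # standalone instance; LocMemCache needs no Django settings here
cache.set('k', 'v', timeout=1)
time.sleep(1.1)
print(len(cache._cache))  # 1 -- the expired entry still occupies memory
cache.get('k')            # returns None and evicts the expired entry
print(len(cache._cache))  # 0 -- swept by the read
```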

View File

@@ -6,33 +6,31 @@
     @date:2023/10/23 16:03
     @desc:
 """
-from langchain_huggingface.embeddings import HuggingFaceEmbeddings
-from smartdoc.const import CONFIG
+import time
+
+from common.cache.mem_cache import MemCache


-class EmbeddingModel:
-    instance = None
+class EmbeddingModelManage:
+    cache = MemCache('model', {})
+    up_clear_time = time.time()

     @staticmethod
-    def get_embedding_model():
-        """
-        Get the embedding model
-        :return:
-        """
-        if EmbeddingModel.instance is None:
-            model_name = CONFIG.get('EMBEDDING_MODEL_NAME')
-            cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH')
-            device = CONFIG.get('EMBEDDING_DEVICE')
-            encode_kwargs = {'normalize_embeddings': True}
-            e = HuggingFaceEmbeddings(
-                model_name=model_name,
-                cache_folder=cache_folder,
-                model_kwargs={'device': device},
-                encode_kwargs=encode_kwargs,
-            )
-            EmbeddingModel.instance = e
-        return EmbeddingModel.instance
+    def get_model(_id, get_model):
+        model_instance = EmbeddingModelManage.cache.get(_id)
+        if model_instance is None:
+            model_instance = get_model(_id)
+            EmbeddingModelManage.cache.set(_id, model_instance, timeout=60 * 30)
+            return model_instance
+        # Renew the 30-minute lease on every hit
+        EmbeddingModelManage.cache.touch(_id, timeout=60 * 30)
+        EmbeddingModelManage.clear_timeout_cache()
+        return model_instance
+
+    @staticmethod
+    def clear_timeout_cache():
+        if time.time() - EmbeddingModelManage.up_clear_time > 60:
+            EmbeddingModelManage.cache.clear_timeout_data()
+            EmbeddingModelManage.up_clear_time = time.time()


 class VectorStore:
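
`EmbeddingModelManage` is a cache-aside wrapper: look the model up by id, build it with the caller-supplied loader on a miss, and renew the 30-minute lease on every hit, sweeping stale entries at most once a minute. A hedged usage sketch, with a hypothetical loader standing in for `setting.models_provider.get_model`:

```python
def load_embedding_model(_id):
    print(f"building embedding model {_id}")  # runs only on a cache miss
    return object()                           # stand-in for a real Embeddings instance

m1 = EmbeddingModelManage.get_model('model-1', load_embedding_model)  # miss: loader runs
m2 = EmbeddingModelManage.get_model('model-1', load_embedding_model)  # hit: lease renewed, no rebuild
assert m1 is m2
```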

View File

@@ -15,8 +15,9 @@ from typing import List
 import django.db.models
 from blinker import signal
 from django.db.models import QuerySet
+from langchain_core.embeddings import Embeddings

-from common.config.embedding_config import VectorStore, EmbeddingModel
+from common.config.embedding_config import VectorStore
 from common.db.search import native_search, get_dynamics_model
 from common.event.common import poxy, embedding_poxy
 from common.util.file_util import get_file_content
@@ -89,11 +90,11 @@ class ListenerManagement:
     @staticmethod
     @embedding_poxy
-    def embedding_by_paragraph(paragraph_id):
+    def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
         """
         Embed a paragraph by paragraph id
-        :param paragraph_id: paragraph id
-        :return: None
+        @param paragraph_id: paragraph id
+        @param embedding_model: embedding model
         """
         max_kb.info(f"Start ---> embedding paragraph: {paragraph_id}")
         status = Status.success
@@ -107,7 +108,7 @@ class ListenerManagement:
                 # Delete the paragraph's vectors
                 VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
                 # Batch embedding
-                VectorStore.get_embedding_vector().batch_save(data_list)
+                VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
         except Exception as e:
             max_kb_error.error(f'Embedding paragraph: {paragraph_id} error: {str(e)}{traceback.format_exc()}')
             status = Status.error
@@ -117,10 +118,11 @@ class ListenerManagement:
     @staticmethod
     @embedding_poxy
-    def embedding_by_document(document_id):
+    def embedding_by_document(document_id, embedding_model: Embeddings):
         """
         Embed a document
-        :param document_id: document id
+        @param document_id: document id
+        @param embedding_model: embedding model
         :return: None
         """
         max_kb.info(f"Start ---> embedding document: {document_id}")
@@ -138,7 +140,7 @@ class ListenerManagement:
                 # Delete the document's vector data
                 VectorStore.get_embedding_vector().delete_by_document_id(document_id)
                 # Batch embedding
-                VectorStore.get_embedding_vector().batch_save(data_list)
+                VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
         except Exception as e:
             max_kb_error.error(f'Embedding document: {document_id} error: {str(e)}{traceback.format_exc()}')
             status = Status.error
@@ -151,10 +153,11 @@ class ListenerManagement:
     @staticmethod
     @embedding_poxy
-    def embedding_by_dataset(dataset_id):
+    def embedding_by_dataset(dataset_id, embedding_model: Embeddings):
         """
         Embed a knowledge base
-        :param dataset_id: knowledge base id
+        @param dataset_id: knowledge base id
+        @param embedding_model: embedding model
         :return: None
         """
         max_kb.info(f"Start ---> embedding dataset: {dataset_id}")
@@ -162,7 +165,7 @@ class ListenerManagement:
             document_list = QuerySet(Document).filter(dataset_id=dataset_id)
             max_kb.info(f"Dataset documents: {[d.name for d in document_list]}")
             for document in document_list:
-                ListenerManagement.embedding_by_document(document.id)
+                ListenerManagement.embedding_by_document(document.id, embedding_model)
         except Exception as e:
             max_kb_error.error(f'Embedding dataset: {dataset_id} error: {str(e)}{traceback.format_exc()}')
         finally:
@@ -245,11 +248,6 @@ class ListenerManagement:
     def delete_embedding_by_dataset_id_list(source_ids: List[str]):
         VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids)

-    @staticmethod
-    @poxy
-    def init_embedding_model(ages):
-        EmbeddingModel.get_embedding_model()
-
     def run(self):
         # Add vectors by problem id
         ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)
@@ -276,8 +274,7 @@ class ListenerManagement:
         ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
         # Enable paragraph embeddings
         ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
-        # Initialize the embedding model
-        ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)
         # Sync web-site knowledge bases
         ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
         # Sync web-site documents
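
With the singleton gone, the embedding model now travels through the blinker signals as a keyword argument, e.g. `embedding_by_dataset_signal.send(dataset_id, embedding_model=model)` in the dataset serializer below. A minimal blinker sketch of that pattern (signal name and payload hypothetical):

```python
from blinker import signal

demo_signal = signal('demo_embedding_by_dataset')  # hypothetical signal name


def on_embedding(sender, embedding_model=None):
    # sender is the positional argument passed to send() -- here a dataset id
    print(f"embedding dataset {sender} with {embedding_model}")


demo_signal.connect(on_embedding)
demo_signal.send('dataset-42', embedding_model='fake-model')
```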

View File

@@ -1,4 +1,4 @@
-# Generated by Django 4.2.13 on 2024-07-15 15:56
+# Generated by Django 4.2.13 on 2024-07-17 13:56

 import dataset.models.data_set
 from django.db import migrations, models
@@ -15,7 +15,7 @@ class Migration(migrations.Migration):
     operations = [
         migrations.AddField(
             model_name='dataset',
-            name='embedding_mode_id',
+            name='embedding_mode',
             field=models.ForeignKey(default=dataset.models.data_set.default_model, on_delete=django.db.models.deletion.DO_NOTHING, to='setting.model', verbose_name='embedding model'),
         ),
     ]

View File

@@ -49,8 +49,8 @@ class DataSet(AppModelMixin):
     user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="owner")
     type = models.CharField(verbose_name='type', max_length=1, choices=Type.choices,
                             default=Type.base)
-    embedding_mode_id = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="embedding model",
-                                          default=default_model)
+    embedding_mode = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="embedding model",
+                                       default=default_model)
     meta = models.JSONField(verbose_name="metadata", default=dict)

     class Meta:
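
The rename from `embedding_mode_id` to `embedding_mode` matters because Django derives a `<field>_id` attribute for every ForeignKey: callers such as `get_embedding_id` can then read the raw key as `dataset.embedding_mode_id` without an extra query, while `dataset.embedding_mode` loads the related `Model` row. A reduced sketch of that behavior (model name hypothetical; needs a configured Django app to run):

```python
from django.db import models


class DemoDataSet(models.Model):
    # A ForeignKey named `embedding_mode` exposes two attributes:
    #   obj.embedding_mode     -> the related Model instance (may trigger a query)
    #   obj.embedding_mode_id  -> the raw foreign-key value, no extra query
    embedding_mode = models.ForeignKey('setting.Model', on_delete=models.DO_NOTHING)
```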

View File

@@ -14,6 +14,7 @@ from django.db.models import QuerySet
 from drf_yasg import openapi
 from rest_framework import serializers

+from common.config.embedding_config import EmbeddingModelManage
 from common.db.search import native_search
 from common.db.sql_execute import update_execute
 from common.exception.app_exception import AppApiException
@@ -21,7 +22,8 @@ from common.mixins.api_mixin import ApiMixin
 from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from common.util.fork import Fork
-from dataset.models import Paragraph, Problem, ProblemParagraphMapping
+from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet
+from setting.models_provider import get_model
 from smartdoc.conf import PROJECT_DIR
@@ -130,3 +132,18 @@ class ProblemParagraphManage:
         result = [problem_model for problem_model, is_create in problem_content_dict.values() if
                   is_create], problem_paragraph_mapping_list
         return result
+
+
+def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
+    dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
+    if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
+        raise Exception("The selected knowledge bases use inconsistent embedding models")
+    if len(dataset_list) == 0:
+        raise Exception("Knowledge base configuration error, please reconfigure the knowledge base")
+    return EmbeddingModelManage.get_model(str(dataset_list[0].id),
+                                          lambda _id: get_model(dataset_list[0].embedding_mode))
+
+
+def get_embedding_model_by_dataset_id(dataset_id: str):
+    dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
+    return EmbeddingModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode))

View File

@@ -15,7 +15,6 @@ from functools import reduce
 from typing import Dict, List
 from urllib.parse import urlparse

-from django.conf import settings
 from django.contrib.postgres.fields import ArrayField
 from django.core import validators
 from django.db import transaction, models
@@ -25,7 +24,7 @@ from drf_yasg import openapi
 from rest_framework import serializers

 from application.models import ApplicationDatasetMapping
-from common.config.embedding_config import VectorStore, EmbeddingModel
+from common.config.embedding_config import VectorStore
 from common.db.search import get_dynamics_model, native_page_search, native_search
 from common.db.sql_execute import select_list
 from common.event import ListenerManagement, SyncWebDatasetArgs
@@ -37,7 +36,8 @@ from common.util.file_util import get_file_content
 from common.util.fork import ChildLink, Fork
 from common.util.split_model import get_split_model
 from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
-from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage
+from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
+    get_embedding_model_by_dataset_id
 from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
 from embedding.models import SearchMode
 from setting.models import AuthOperate
@@ -359,8 +359,9 @@ class DataSetSerializers(serializers.ModelSerializer):

     @staticmethod
     def post_embedding_dataset(document_list, dataset_id):
+        model = get_embedding_model_by_dataset_id(dataset_id)
         # Send the embedding event
-        ListenerManagement.embedding_by_dataset_signal.send(dataset_id)
+        ListenerManagement.embedding_by_dataset_signal.send(dataset_id, embedding_model=model)
         return document_list

     def save_qa(self, instance: Dict, with_valid=True):
@@ -565,12 +566,13 @@ class DataSetSerializers(serializers.ModelSerializer):
                 QuerySet(Document).filter(
                     dataset_id=self.data.get('id'),
                     is_active=False)]
+            model = get_embedding_model_by_dataset_id(self.data.get('id'))
             # Vector store retrieval
             hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
                                        self.data.get('top_number'),
                                        self.data.get('similarity'),
                                        SearchMode(self.data.get('search_mode')),
-                                       EmbeddingModel.get_embedding_model())
+                                       model)
             hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
             p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
             return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),

View File

@@ -10,9 +10,8 @@ import threading
 from abc import ABC, abstractmethod
 from typing import List, Dict

-from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_core.embeddings import Embeddings

-from common.config.embedding_config import EmbeddingModel
 from common.util.common import sub_array
 from embedding.models import SourceType, SearchMode
@@ -51,7 +50,7 @@ class BaseVectorStore(ABC):
     def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
              is_active: bool,
-             embedding=None):
+             embedding: Embeddings):
         """
         Insert vector data
         :param source_id: source id
@@ -64,13 +63,10 @@ class BaseVectorStore(ABC):
         :param paragraph_id: paragraph id
         :return: bool
         """
-        if embedding is None:
-            embedding = EmbeddingModel.get_embedding_model()
         self.save_pre_handler()
         self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)

-    def batch_save(self, data_list: List[Dict], embedding=None):
+    def batch_save(self, data_list: List[Dict], embedding: Embeddings):
         # Acquire the lock
         lock.acquire()
         try:
@@ -80,8 +76,6 @@ class BaseVectorStore(ABC):
             :param embedding: embedding handler
             :return: bool
             """
-            if embedding is None:
-                embedding = EmbeddingModel.get_embedding_model()
             self.save_pre_handler()
             result = sub_array(data_list)
             for child_array in result:
@@ -94,17 +88,17 @@ class BaseVectorStore(ABC):
     @abstractmethod
     def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
               is_active: bool,
-              embedding: HuggingFaceEmbeddings):
+              embedding: Embeddings):
         pass

     @abstractmethod
-    def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
+    def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
         pass

     def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
                exclude_paragraph_list: list[str],
                is_active: bool,
-               embedding: HuggingFaceEmbeddings):
+               embedding: Embeddings):
         if dataset_id_list is None or len(dataset_id_list) == 0:
             return []
         embedding_query = embedding.embed_query(query_text)
@@ -123,7 +117,7 @@ class BaseVectorStore(ABC):
     def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
                  similarity: float,
                  search_mode: SearchMode,
-                 embedding: HuggingFaceEmbeddings):
+                 embedding: Embeddings):
         pass

     @abstractmethod
@@ -142,14 +136,6 @@ class BaseVectorStore(ABC):
     def update_by_source_ids(self, source_ids: List[str], instance: Dict):
         pass

-    @abstractmethod
-    def embed_documents(self, text_list: List[str]):
-        pass
-
-    @abstractmethod
-    def embed_query(self, text: str):
-        pass
-
     @abstractmethod
     def delete_by_dataset_id(self, dataset_id: str):
         pass
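
Typing the vector-store API against `langchain_core.embeddings.Embeddings` instead of the concrete `HuggingFaceEmbeddings` is what lets any provider's model flow through unchanged. The interface requires only two methods; a toy conforming implementation for illustration:

```python
from typing import List

from langchain_core.embeddings import Embeddings


class FakeEmbeddings(Embeddings):
    """Toy embedder: one-dimensional vectors from character codes; illustration only."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(sum(map(ord, text)) % 997)]
```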

View File

@@ -13,9 +13,8 @@ from abc import ABC, abstractmethod
 from typing import Dict, List

 from django.db.models import QuerySet
-from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_core.embeddings import Embeddings

-from common.config.embedding_config import EmbeddingModel
 from common.db.search import generate_sql_by_query_dict
 from common.db.sql_execute import select_list
 from common.util.file_util import get_file_content
@@ -33,14 +32,6 @@ class PGVector(BaseVectorStore):
     def update_by_source_ids(self, source_ids: List[str], instance: Dict):
         QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)

-    def embed_documents(self, text_list: List[str]):
-        embedding = EmbeddingModel.get_embedding_model()
-        return embedding.embed_documents(text_list)
-
-    def embed_query(self, text: str):
-        embedding = EmbeddingModel.get_embedding_model()
-        return embedding.embed_query(text)
-
     def vector_is_create(self) -> bool:
         # Created by default at project startup; no need to create it again
         return True
@@ -50,7 +41,7 @@ class PGVector(BaseVectorStore):
     def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
               is_active: bool,
-              embedding: HuggingFaceEmbeddings):
+              embedding: Embeddings):
         text_embedding = embedding.embed_query(text)
         embedding = Embedding(id=uuid.uuid1(),
                               dataset_id=dataset_id,
@@ -64,7 +55,7 @@ class PGVector(BaseVectorStore):
         embedding.save()
         return True

-    def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
+    def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
         texts = [row.get('text') for row in text_list]
         embeddings = embedding.embed_documents(texts)
         embedding_list = [Embedding(id=uuid.uuid1(),
@@ -83,7 +74,7 @@ class PGVector(BaseVectorStore):
     def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
                  similarity: float,
                  search_mode: SearchMode,
-                 embedding: HuggingFaceEmbeddings):
+                 embedding: Embeddings):
         if dataset_id_list is None or len(dataset_id_list) == 0:
             return []
         exclude_dict = {}
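
`_batch_save` embeds the whole batch with a single `embed_documents` call and only then builds the rows to insert. A reduced sketch of that shape, with a toy function in place of the real model:

```python
def embed_documents(texts):           # toy embedder standing in for Embeddings.embed_documents
    return [[float(len(text))] for text in texts]

rows = [{'text': 'alpha'}, {'text': 'beta'}]
vectors = embed_documents([row.get('text') for row in rows])   # one call for the whole batch
records = [{**row, 'embedding': vector} for row, vector in zip(rows, vectors)]
```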

View File

@@ -6,3 +6,85 @@
     @date:2023/10/31 17:16
     @desc:
 """
+import json
+from typing import Dict
+
+from common.util.rsa_util import rsa_long_decrypt
+from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
+
+
+def get_model_(provider, model_type, model_name, credential):
+    """
+    Get a model instance
+    @param provider: provider
+    @param model_type: model type
+    @param model_name: model name
+    @param credential: credential information
+    @return: model instance
+    """
+    model = get_provider(provider).get_model(model_type, model_name,
+                                             json.loads(
+                                                 rsa_long_decrypt(credential)),
+                                             streaming=True)
+    return model
+
+
+def get_model(model):
+    """
+    Get a model instance
+    @param model: Model database record
+    @return: model instance
+    """
+    return get_model_(model.provider, model.model_type, model.model_name, model.credential)
+
+
+def get_provider(provider):
+    """
+    Get a provider instance
+    @param provider: provider key string
+    @return: provider instance
+    """
+    return ModelProvideConstants[provider].value
+
+
+def get_model_list(provider, model_type):
+    """
+    Get the model list
+    @param provider: provider key string
+    @param model_type: model type
+    @return: model list
+    """
+    return get_provider(provider).get_model_list(model_type)
+
+
+def get_model_credential(provider, model_type, model_name):
+    """
+    Get the model credential handler
+    @param provider: provider key string
+    @param model_type: model type
+    @param model_name: model name
+    @return: credential instance
+    """
+    return get_provider(provider).get_model_credential(model_type, model_name)
+
+
+def get_model_type_list(provider):
+    """
+    Get the model type list
+    @param provider: provider key string
+    @return: model type list
+    """
+    return get_provider(provider).get_model_type_list()
+
+
+def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
+    """
+    Validate model credential parameters
+    @param provider: provider key string
+    @param model_type: model type
+    @param model_name: model name
+    @param model_credential: model credential data
+    @param raise_exception: whether to raise on failure
+    @return: True|False
+    """
+    return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, raise_exception)
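
`get_provider` is a plain enum lookup: `ModelProvideConstants[provider]` resolves the provider key to an enum member whose `value` is the provider implementation. A reduced sketch of the pattern (names hypothetical):

```python
from enum import Enum


class DemoProvider:
    def get_model_list(self, model_type):
        return [f'{model_type}-demo-model']


class DemoProvideConstants(Enum):
    # The member name doubles as the lookup key; the value is the provider instance.
    demo_provider = DemoProvider()


provider = DemoProvideConstants['demo_provider'].value
print(provider.get_model_list('EMBEDDING'))  # ['EMBEDDING-demo-model']
```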

View File

@@ -104,9 +104,6 @@ CACHES = {
     "token_cache": {
         'BACKEND': 'common.cache.file_cache.FileCache',
         'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache")  # folder path
-    },
-    "chat_cache": {
-        'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
     }
 }
} }