feat: support embedding models

parent 9b81b89975
commit bd4303aee7
@@ -13,22 +13,45 @@ from django.db.models import QuerySet
 from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
-from common.config.embedding_config import VectorStore, EmbeddingModel
+from common.config.embedding_config import VectorStore, EmbeddingModelManage
 from common.db.search import native_search
 from common.util.file_util import get_file_content
-from dataset.models import Paragraph
+from dataset.models import Paragraph, DataSet
 from embedding.models import SearchMode
+from setting.models import Model
+from setting.models_provider import get_model
 from smartdoc.conf import PROJECT_DIR


+def get_model_by_id(_id):
+    model = QuerySet(Model).filter(id=_id).first()
+    if model is None:
+        raise Exception("模型不存在")
+    return model
+
+
+def get_embedding_id(dataset_id_list):
+    dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
+    if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
+        raise Exception("知识库向量模型不一致")
+    if len(dataset_list) == 0:
+        raise Exception("知识库设置错误,请重新设置知识库")
+    return dataset_list[0].embedding_mode_id
+
+
 class BaseSearchDatasetStep(ISearchDatasetStep):

     def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
                 exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
                 search_mode: str = None,
                 **kwargs) -> List[ParagraphPipelineModel]:
+        if len(dataset_id_list) == 0:
+            return []
         exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
-        embedding_model = EmbeddingModel.get_embedding_model()
+        model_id = get_embedding_id(dataset_id_list)
+        model = get_model_by_id(model_id)
+        self.context['model_name'] = model.name
+        embedding_model = EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model))
         embedding_value = embedding_model.embed_query(exec_problem_text)
         vector = VectorStore.get_embedding_vector()
         embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,

@@ -101,7 +124,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
             'run_time': self.context['run_time'],
             'problem_text': step_args.get(
                 'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
-            'model_name': EmbeddingModel.get_embedding_model().model_name,
+            'model_name': self.context.get('model_name'),
             'message_tokens': 0,
             'answer_tokens': 0,
             'cost': 0
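Note: the step now resolves the embedding model per knowledge base instead of using a process-wide singleton, and it requires all selected knowledge bases to share one model. A minimal, framework-free sketch of that consistency check (plain dicts stand in for the Django QuerySet used in the real code):

```python
# Sketch of the per-dataset model resolution this commit introduces: all
# selected knowledge bases must share one embedding model, otherwise the
# query cannot be answered within a single vector space.
def get_embedding_id(datasets: list[dict]) -> str:
    model_ids = {d["embedding_mode_id"] for d in datasets}
    if len(model_ids) > 1:
        raise Exception("knowledge bases use inconsistent embedding models")
    if not datasets:
        raise Exception("knowledge base misconfigured, please reselect")
    return datasets[0]["embedding_mode_id"]


if __name__ == "__main__":
    datasets = [{"embedding_mode_id": "m1"}, {"embedding_mode_id": "m1"}]
    assert get_embedding_id(datasets) == "m1"
```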
@@ -13,20 +13,47 @@ from django.db.models import QuerySet
 from application.flow.i_step_node import NodeResult
 from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
-from common.config.embedding_config import EmbeddingModel, VectorStore
+from common.config.embedding_config import VectorStore, EmbeddingModelManage
 from common.db.search import native_search
 from common.util.file_util import get_file_content
-from dataset.models import Document, Paragraph
+from dataset.models import Document, Paragraph, DataSet
 from embedding.models import SearchMode
+from setting.models import Model
+from setting.models_provider import get_model
 from smartdoc.conf import PROJECT_DIR


+def get_model_by_id(_id):
+    model = QuerySet(Model).filter(id=_id).first()
+    if model is None:
+        raise Exception("模型不存在")
+    return get_model(model)
+
+
+def get_embedding_id(dataset_id_list):
+    dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
+    if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
+        raise Exception("知识库向量模型不一致")
+    if len(dataset_list) == 0:
+        raise Exception("知识库设置错误,请重新设置知识库")
+    return dataset_list[0].embedding_mode_id
+
+
+def get_none_result(question):
+    return NodeResult(
+        {'paragraph_list': [], 'is_hit_handling_method': [], 'question': question, 'data': '',
+         'directly_return': ''}, {})
+
+
 class BaseSearchDatasetNode(ISearchDatasetStepNode):
     def execute(self, dataset_id_list, dataset_setting, question,
                 exclude_paragraph_id_list=None,
                 **kwargs) -> NodeResult:
         self.context['question'] = question
-        embedding_model = EmbeddingModel.get_embedding_model()
+        if len(dataset_id_list) == 0:
+            return get_none_result(question)
+        model_id = get_embedding_id(dataset_id_list)
+        embedding_model = EmbeddingModelManage.get_model(model_id, get_model_by_id)
         embedding_value = embedding_model.embed_query(question)
         vector = VectorStore.get_embedding_vector()
         exclude_document_id_list = [str(document.id) for document in

@@ -37,7 +64,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
                                        exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
                                        dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
         if embedding_list is None:
-            return NodeResult({'paragraph_list': [], 'is_hit_handling_method': []}, {})
+            return get_none_result(question)
         paragraph_list = self.list_paragraph(embedding_list, vector)
         result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
         return NodeResult({'paragraph_list': result,
@@ -26,7 +26,7 @@ from rest_framework import serializers
 from application.flow.workflow_manage import Flow
 from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion
 from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
-from common.config.embedding_config import VectorStore, EmbeddingModel
+from common.config.embedding_config import VectorStore
 from common.constants.authentication_type import AuthenticationType
 from common.db.search import get_dynamics_model, native_search, native_page_search
 from common.db.sql_execute import select_list

@@ -36,7 +36,7 @@ from common.util.common import valid_license
 from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from dataset.models import DataSet, Document, Image
-from dataset.serializers.common_serializers import list_paragraph
+from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list
 from embedding.models import SearchMode
 from setting.models import AuthOperate
 from setting.models.model_management import Model

@@ -415,12 +415,13 @@ class ApplicationSerializer(serializers.Serializer):
                 QuerySet(Document).filter(
                     dataset_id__in=dataset_id_list,
                     is_active=False)]
+            model = get_embedding_model_by_dataset_id_list(dataset_id_list)
             # vector store retrieval
             hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
                                        self.data.get('top_number'),
                                        self.data.get('similarity'),
                                        SearchMode(self.data.get('search_mode')),
-                                       EmbeddingModel.get_embedding_model())
+                                       model)
             hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
             p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
             return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
apps/common/cache/mem_cache.py (vendored)

@@ -41,3 +41,7 @@ class MemCache(LocMemCache):
                 delete_keys.append(key)
         for key in delete_keys:
             self._delete(key)
+
+    def clear_timeout_data(self):
+        for key in list(self._cache.keys()):  # copy: get() may evict entries
+            self.get(key)
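Note: `clear_timeout_data` leans on LocMemCache's lazy expiry, where reading an expired key deletes it as a side effect; the key list is copied first because that eviction mutates the dict mid-iteration. A self-contained sketch of the same idea outside Django (the `TTLDict` class is hypothetical):

```python
import time


class TTLDict:
    """Hypothetical stand-in for the LocMemCache behaviour relied on above."""

    def __init__(self):
        self._data = {}  # key -> (value, expires_at)

    def set(self, key, value, timeout):
        self._data[key] = (value, time.time() + timeout)

    def get(self, key):
        entry = self._data.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if time.time() > expires_at:
            del self._data[key]  # lazy eviction on read
            return None
        return value

    def clear_timeout_data(self):
        # sweeping is just "read every key"; expired entries evict themselves
        for key in list(self._data.keys()):
            self.get(key)
```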
@@ -6,33 +6,31 @@
 @date:2023/10/23 16:03
 @desc:
 """
-from langchain_huggingface.embeddings import HuggingFaceEmbeddings
+import time

-from smartdoc.const import CONFIG
+from common.cache.mem_cache import MemCache


-class EmbeddingModel:
-    instance = None
+class EmbeddingModelManage:
+    cache = MemCache('model', {})
+    up_clear_time = time.time()

     @staticmethod
-    def get_embedding_model():
-        """
-        Get the embedding model
-        :return:
-        """
-        if EmbeddingModel.instance is None:
-            model_name = CONFIG.get('EMBEDDING_MODEL_NAME')
-            cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH')
-            device = CONFIG.get('EMBEDDING_DEVICE')
-            encode_kwargs = {'normalize_embeddings': True}
-            e = HuggingFaceEmbeddings(
-                model_name=model_name,
-                cache_folder=cache_folder,
-                model_kwargs={'device': device},
-                encode_kwargs=encode_kwargs,
-            )
-            EmbeddingModel.instance = e
-        return EmbeddingModel.instance
+    def get_model(_id, get_model):
+        model_instance = EmbeddingModelManage.cache.get(_id)
+        if model_instance is None:
+            model_instance = get_model(_id)
+            EmbeddingModelManage.cache.set(_id, model_instance, timeout=60 * 30)
+            return model_instance
+        # renew the cache entry
+        EmbeddingModelManage.cache.touch(_id, timeout=60 * 30)
+        EmbeddingModelManage.clear_timeout_cache()
+        return model_instance
+
+    @staticmethod
+    def clear_timeout_cache():
+        if time.time() - EmbeddingModelManage.up_clear_time > 60:
+            EmbeddingModelManage.cache.clear_timeout_data()


 class VectorStore:
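Note: `EmbeddingModelManage` is a cache-aside wrapper: the caller supplies a loader that is only invoked on a miss, hits renew the 30-minute TTL via `touch`, and the lazy sweep runs at most once a minute. A hedged usage sketch (the loader below is hypothetical; in this commit it is a lambda over `setting.models_provider.get_model`, and the cache assumes a configured Django process):

```python
# Usage sketch, assuming EmbeddingModelManage as defined above.
def load_embedding_model(model_id: str):
    # hypothetical loader; the real code decrypts credentials and builds
    # the provider client for the given model row
    print(f"loading model {model_id} from the database")
    return object()  # stands in for a langchain Embeddings instance


model = EmbeddingModelManage.get_model("m1", load_embedding_model)  # miss: loads
model = EmbeddingModelManage.get_model("m1", load_embedding_model)  # hit: cached, TTL renewed
```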
@@ -15,8 +15,9 @@ from typing import List
 import django.db.models
 from blinker import signal
 from django.db.models import QuerySet
+from langchain_core.embeddings import Embeddings

-from common.config.embedding_config import VectorStore, EmbeddingModel
+from common.config.embedding_config import VectorStore
 from common.db.search import native_search, get_dynamics_model
 from common.event.common import poxy, embedding_poxy
 from common.util.file_util import get_file_content

@@ -89,11 +90,11 @@ class ListenerManagement:

     @staticmethod
     @embedding_poxy
-    def embedding_by_paragraph(paragraph_id):
+    def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
         """
         Vectorize a paragraph by paragraph id
-        :param paragraph_id: paragraph id
-        :return: None
+        @param paragraph_id: paragraph id
+        @param embedding_model: embedding model
         """
         max_kb.info(f"开始--->向量化段落:{paragraph_id}")
         status = Status.success

@@ -107,7 +108,7 @@ class ListenerManagement:
             # delete the paragraph's vectors
             VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
             # batch vectorization
-            VectorStore.get_embedding_vector().batch_save(data_list)
+            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
         except Exception as e:
             max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}')
             status = Status.error

@@ -117,10 +118,11 @@ class ListenerManagement:

     @staticmethod
     @embedding_poxy
-    def embedding_by_document(document_id):
+    def embedding_by_document(document_id, embedding_model: Embeddings):
         """
         Vectorize a document
-        :param document_id: document id
+        @param document_id: document id
+        @param embedding_model: embedding model
         :return: None
         """
         max_kb.info(f"开始--->向量化文档:{document_id}")

@@ -138,7 +140,7 @@ class ListenerManagement:
             # delete the document's vector data
             VectorStore.get_embedding_vector().delete_by_document_id(document_id)
             # batch vectorization
-            VectorStore.get_embedding_vector().batch_save(data_list)
+            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
         except Exception as e:
             max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
             status = Status.error

@@ -151,10 +153,11 @@ class ListenerManagement:

     @staticmethod
     @embedding_poxy
-    def embedding_by_dataset(dataset_id):
+    def embedding_by_dataset(dataset_id, embedding_model: Embeddings):
         """
         Vectorize a knowledge base
-        :param dataset_id: knowledge base id
+        @param dataset_id: knowledge base id
+        @param embedding_model: embedding model
         :return: None
         """
         max_kb.info(f"开始--->向量化数据集:{dataset_id}")

@@ -162,7 +165,7 @@ class ListenerManagement:
             document_list = QuerySet(Document).filter(dataset_id=dataset_id)
             max_kb.info(f"数据集文档:{[d.name for d in document_list]}")
             for document in document_list:
-                ListenerManagement.embedding_by_document(document.id)
+                ListenerManagement.embedding_by_document(document.id, embedding_model)
         except Exception as e:
             max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
         finally:

@@ -245,11 +248,6 @@ class ListenerManagement:
     def delete_embedding_by_dataset_id_list(source_ids: List[str]):
         VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids)

-    @staticmethod
-    @poxy
-    def init_embedding_model(ages):
-        EmbeddingModel.get_embedding_model()
-
     def run(self):
         # add vectors by problem id
         ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)

@@ -276,8 +274,7 @@ class ListenerManagement:
         ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
         # enable paragraph vectors
         ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
-        # initialize the embedding model
-        ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)
         # sync web-site knowledge base
         ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
         # sync web-site documents
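Note: the listeners now receive the embedding model through the signal's keyword arguments instead of pulling a global singleton, which is why the old `init_embedding_model` hook can be deleted. A minimal blinker sketch of that handoff (the signal and receiver names are illustrative; the diff wires `embedding_by_dataset_signal` the same way):

```python
from blinker import signal

embedding_by_dataset_signal = signal('embedding_by_dataset')


def embedding_by_dataset(dataset_id, embedding_model=None):
    # blinker passes send()'s positional sender plus keyword arguments through
    print(f"vectorizing dataset {dataset_id} with {embedding_model!r}")


embedding_by_dataset_signal.connect(embedding_by_dataset)
embedding_by_dataset_signal.send('dataset-1', embedding_model='bge-m3')
```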
@@ -1,4 +1,4 @@
-# Generated by Django 4.2.13 on 2024-07-15 15:56
+# Generated by Django 4.2.13 on 2024-07-17 13:56

 import dataset.models.data_set
 from django.db import migrations, models

@@ -15,7 +15,7 @@ class Migration(migrations.Migration):
     operations = [
         migrations.AddField(
             model_name='dataset',
-            name='embedding_mode_id',
+            name='embedding_mode',
             field=models.ForeignKey(default=dataset.models.data_set.default_model, on_delete=django.db.models.deletion.DO_NOTHING, to='setting.model', verbose_name='向量模型'),
         ),
     ]
@@ -49,8 +49,8 @@ class DataSet(AppModelMixin):
     user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户")
     type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
                             default=Type.base)
-    embedding_mode_id = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型",
+    embedding_mode = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型",
                             default=default_model)
     meta = models.JSONField(verbose_name="元数据", default=dict)

     class Meta:
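Note: renaming the field from `embedding_mode_id` to `embedding_mode` follows Django's ForeignKey convention: the field attribute holds the related object, while Django adds an implicit `<name>_id` attribute for the raw key, which is why call sites elsewhere in this commit can still read `dataset.embedding_mode_id` without an extra query. A minimal sketch, assuming a configured Django project (the class name is an illustrative stand-in for `DataSet`):

```python
from django.db import models


class DataSetSketch(models.Model):  # illustrative stand-in for DataSet
    # dataset.embedding_mode    -> the related Model instance (extra query)
    # dataset.embedding_mode_id -> the raw foreign-key value, no query
    embedding_mode = models.ForeignKey('setting.Model', on_delete=models.DO_NOTHING)

    class Meta:
        app_label = 'dataset'  # required when sketching outside an app module
```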
@@ -14,6 +14,7 @@ from django.db.models import QuerySet
 from drf_yasg import openapi
 from rest_framework import serializers

+from common.config.embedding_config import EmbeddingModelManage
 from common.db.search import native_search
 from common.db.sql_execute import update_execute
 from common.exception.app_exception import AppApiException

@@ -21,7 +22,8 @@ from common.mixins.api_mixin import ApiMixin
 from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from common.util.fork import Fork
-from dataset.models import Paragraph, Problem, ProblemParagraphMapping
+from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet
+from setting.models_provider import get_model
 from smartdoc.conf import PROJECT_DIR

@@ -130,3 +132,18 @@ class ProblemParagraphManage:
         result = [problem_model for problem_model, is_create in problem_content_dict.values() if
                   is_create], problem_paragraph_mapping_list
         return result
+
+
+def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
+    dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
+    if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
+        raise Exception("知识库向量模型不一致")
+    if len(dataset_list) == 0:
+        raise Exception("知识库设置错误,请重新设置知识库")
+    return EmbeddingModelManage.get_model(str(dataset_list[0].id),
+                                          lambda _id: get_model(dataset_list[0].embedding_mode))
+
+
+def get_embedding_model_by_dataset_id(dataset_id: str):
+    dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
+    return EmbeddingModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode))
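Note: the hit-test paths below feed the resolved model into `vector.hit_test` and then fold the hits into a dict keyed by paragraph id. That fold is self-contained and worth seeing in isolation; a runnable sketch of the exact expression used in both serializer hunks:

```python
from functools import reduce

# Standalone sketch of the hit_dict fold used after vector.hit_test:
# later hits with the same paragraph_id overwrite earlier ones.
hit_list = [
    {'paragraph_id': 'p1', 'similarity': 0.91},
    {'paragraph_id': 'p2', 'similarity': 0.84},
]
hit_dict = reduce(lambda x, y: {**x, **y},
                  [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
assert hit_dict['p1']['similarity'] == 0.91
```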
@@ -15,7 +15,6 @@ from functools import reduce
 from typing import Dict, List
 from urllib.parse import urlparse

-from django.conf import settings
 from django.contrib.postgres.fields import ArrayField
 from django.core import validators
 from django.db import transaction, models

@@ -25,7 +24,7 @@ from drf_yasg import openapi
 from rest_framework import serializers

 from application.models import ApplicationDatasetMapping
-from common.config.embedding_config import VectorStore, EmbeddingModel
+from common.config.embedding_config import VectorStore
 from common.db.search import get_dynamics_model, native_page_search, native_search
 from common.db.sql_execute import select_list
 from common.event import ListenerManagement, SyncWebDatasetArgs

@@ -37,7 +36,8 @@ from common.util.file_util import get_file_content
 from common.util.fork import ChildLink, Fork
 from common.util.split_model import get_split_model
 from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
-from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage
+from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
+    get_embedding_model_by_dataset_id
 from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
 from embedding.models import SearchMode
 from setting.models import AuthOperate

@@ -359,8 +359,9 @@ class DataSetSerializers(serializers.ModelSerializer):

     @staticmethod
     def post_embedding_dataset(document_list, dataset_id):
+        model = get_embedding_model_by_dataset_id(dataset_id)
         # send the vectorization event
-        ListenerManagement.embedding_by_dataset_signal.send(dataset_id)
+        ListenerManagement.embedding_by_dataset_signal.send(dataset_id, embedding_model=model)
         return document_list

     def save_qa(self, instance: Dict, with_valid=True):

@@ -565,12 +566,13 @@ class DataSetSerializers(serializers.ModelSerializer):
                 QuerySet(Document).filter(
                     dataset_id=self.data.get('id'),
                     is_active=False)]
+        model = get_embedding_model_by_dataset_id(self.data.get('id'))
         # vector store retrieval
         hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
                                    self.data.get('top_number'),
                                    self.data.get('similarity'),
                                    SearchMode(self.data.get('search_mode')),
-                                   EmbeddingModel.get_embedding_model())
+                                   model)
         hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
         p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
         return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
@@ -10,9 +10,8 @@ import threading
 from abc import ABC, abstractmethod
 from typing import List, Dict

-from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_core.embeddings import Embeddings

-from common.config.embedding_config import EmbeddingModel
 from common.util.common import sub_array
 from embedding.models import SourceType, SearchMode

@@ -51,7 +50,7 @@ class BaseVectorStore(ABC):

     def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
              is_active: bool,
-             embedding=None):
+             embedding: Embeddings):
         """
         Insert vector data
         :param source_id: resource id

@@ -64,13 +63,10 @@ class BaseVectorStore(ABC):
         :param paragraph_id paragraph id
         :return: bool
         """
-
-        if embedding is None:
-            embedding = EmbeddingModel.get_embedding_model()
         self.save_pre_handler()
         self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)

-    def batch_save(self, data_list: List[Dict], embedding=None):
+    def batch_save(self, data_list: List[Dict], embedding: Embeddings):
         # acquire the lock
         lock.acquire()
         try:

@@ -80,8 +76,6 @@ class BaseVectorStore(ABC):
             :param embedding: embedding handler
             :return: bool
             """
-            if embedding is None:
-                embedding = EmbeddingModel.get_embedding_model()
             self.save_pre_handler()
             result = sub_array(data_list)
             for child_array in result:

@@ -94,17 +88,17 @@ class BaseVectorStore(ABC):
     @abstractmethod
     def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
               is_active: bool,
-              embedding: HuggingFaceEmbeddings):
+              embedding: Embeddings):
         pass

     @abstractmethod
-    def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
+    def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
         pass

     def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
               exclude_paragraph_list: list[str],
               is_active: bool,
-              embedding: HuggingFaceEmbeddings):
+              embedding: Embeddings):
         if dataset_id_list is None or len(dataset_id_list) == 0:
             return []
         embedding_query = embedding.embed_query(query_text)

@@ -123,7 +117,7 @@ class BaseVectorStore(ABC):
     def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
                  similarity: float,
                  search_mode: SearchMode,
-                 embedding: HuggingFaceEmbeddings):
+                 embedding: Embeddings):
         pass

     @abstractmethod

@@ -142,14 +136,6 @@ class BaseVectorStore(ABC):
     def update_by_source_ids(self, source_ids: List[str], instance: Dict):
         pass

-    @abstractmethod
-    def embed_documents(self, text_list: List[str]):
-        pass
-
-    @abstractmethod
-    def embed_query(self, text: str):
-        pass
-
     @abstractmethod
     def delete_by_dataset_id(self, dataset_id: str):
         pass
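Note: typing the store against langchain_core's `Embeddings` base class instead of the concrete `HuggingFaceEmbeddings` is what lets any provider-built model flow through `_save`/`_batch_save`. A minimal sketch of a conforming implementation (the hashing trick is purely illustrative):

```python
from langchain_core.embeddings import Embeddings


class FakeEmbeddings(Embeddings):
    """Illustrative Embeddings implementation; real ones come from the providers."""

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        # deterministic toy vectors, enough to exercise the interface
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> list[float]:
        return [float(ord(c) % 10) for c in text[:8]]
```

Anything satisfying these two methods can now be passed wherever the diff writes `embedding: Embeddings`, which is why the store's own `embed_documents`/`embed_query` pass-throughs could be deleted.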
@@ -13,9 +13,8 @@ from abc import ABC, abstractmethod
 from typing import Dict, List

 from django.db.models import QuerySet
-from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_core.embeddings import Embeddings

-from common.config.embedding_config import EmbeddingModel
 from common.db.search import generate_sql_by_query_dict
 from common.db.sql_execute import select_list
 from common.util.file_util import get_file_content

@@ -33,14 +32,6 @@ class PGVector(BaseVectorStore):
     def update_by_source_ids(self, source_ids: List[str], instance: Dict):
         QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)

-    def embed_documents(self, text_list: List[str]):
-        embedding = EmbeddingModel.get_embedding_model()
-        return embedding.embed_documents(text_list)
-
-    def embed_query(self, text: str):
-        embedding = EmbeddingModel.get_embedding_model()
-        return embedding.embed_query(text)
-
     def vector_is_create(self) -> bool:
         # created at project startup; no need to create it again
         return True

@@ -50,7 +41,7 @@ class PGVector(BaseVectorStore):

     def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
               is_active: bool,
-              embedding: HuggingFaceEmbeddings):
+              embedding: Embeddings):
         text_embedding = embedding.embed_query(text)
         embedding = Embedding(id=uuid.uuid1(),
                               dataset_id=dataset_id,

@@ -64,7 +55,7 @@ class PGVector(BaseVectorStore):
         embedding.save()
         return True

-    def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
+    def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
         texts = [row.get('text') for row in text_list]
         embeddings = embedding.embed_documents(texts)
         embedding_list = [Embedding(id=uuid.uuid1(),

@@ -83,7 +74,7 @@ class PGVector(BaseVectorStore):
     def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
                  similarity: float,
                  search_mode: SearchMode,
-                 embedding: HuggingFaceEmbeddings):
+                 embedding: Embeddings):
         if dataset_id_list is None or len(dataset_id_list) == 0:
             return []
         exclude_dict = {}
@@ -6,3 +6,85 @@
 @date:2023/10/31 17:16
 @desc:
 """
+import json
+from typing import Dict
+
+from common.util.rsa_util import rsa_long_decrypt
+from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
+
+
+def get_model_(provider, model_type, model_name, credential):
+    """
+    Get a model instance
+    @param provider: provider
+    @param model_type: model type
+    @param model_name: model name
+    @param credential: credential
+    @return: model instance
+    """
+    model = get_provider(provider).get_model(model_type, model_name,
+                                             json.loads(
+                                                 rsa_long_decrypt(credential)),
+                                             streaming=True)
+    return model
+
+
+def get_model(model):
+    """
+    Get a model instance
+    @param model: a database Model instance
+    @return: model instance
+    """
+    return get_model_(model.provider, model.model_type, model.model_name, model.credential)
+
+
+def get_provider(provider):
+    """
+    Get a provider instance
+    @param provider: provider string
+    @return: provider instance
+    """
+    return ModelProvideConstants[provider].value
+
+
+def get_model_list(provider, model_type):
+    """
+    Get the model list
+    @param provider: provider string
+    @param model_type: model type
+    @return: model list
+    """
+    return get_provider(provider).get_model_list(model_type)
+
+
+def get_model_credential(provider, model_type, model_name):
+    """
+    Get a model credential instance
+    @param provider: provider string
+    @param model_type: model type
+    @param model_name: model name
+    @return: credential instance
+    """
+    return get_provider(provider).get_model_credential(model_type, model_name)
+
+
+def get_model_type_list(provider):
+    """
+    Get the model type list
+    @param provider: provider string
+    @return: model type list
+    """
+    return get_provider(provider).get_model_type_list()
+
+
+def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
+    """
+    Validate model credentials
+    @param provider: provider string
+    @param model_type: model type
+    @param model_name: model name
+    @param model_credential: credential data
+    @param raise_exception: whether to raise on failure
+    @return: True|False
+    """
+    return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, raise_exception)
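Note: `get_provider` is a plain enum dispatch, so a model row's `provider` string selects the provider object whose `get_model` builds the live client from the RSA-decrypted credential. A self-contained sketch of that pattern (the provider name and class here are illustrative, not real `ModelProvideConstants` entries):

```python
from enum import Enum


class FakeProvider:
    """Illustrative provider; real ones implement get_model, get_model_list, etc."""

    def get_model(self, model_type, model_name, credential, streaming=True):
        # the real providers construct a langchain client from the credential dict
        return f"<{model_name} ({model_type}) client>"


class ModelProvideConstantsSketch(Enum):  # stand-in for ModelProvideConstants
    model_fake_provider = FakeProvider()


def get_provider(provider: str):
    return ModelProvideConstantsSketch[provider].value


print(get_provider('model_fake_provider').get_model('EMBEDDING', 'bge-m3', {}))
```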
@@ -104,9 +104,6 @@ CACHES = {
     "token_cache": {
         'BACKEND': 'common.cache.file_cache.FileCache',
         'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache")  # cache directory
-    },
-    "chat_cache": {
-        'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
     }
 }