feat: 支持向量模型
This commit is contained in:
parent
ead263da22
commit
b14a799350
@ -43,6 +43,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
|
|||||||
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
|
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
|
||||||
message="类型只支持register|reset_password", code=500)
|
message="类型只支持register|reset_password", code=500)
|
||||||
], error_messages=ErrMessage.char("检索模式"))
|
], error_messages=ErrMessage.char("检索模式"))
|
||||||
|
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||||
|
|
||||||
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
|
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
|
||||||
return self.InstanceSerializer
|
return self.InstanceSerializer
|
||||||
@ -56,6 +57,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
|
|||||||
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||||
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
||||||
search_mode: str = None,
|
search_mode: str = None,
|
||||||
|
user_id=None,
|
||||||
**kwargs) -> List[ParagraphPipelineModel]:
|
**kwargs) -> List[ParagraphPipelineModel]:
|
||||||
"""
|
"""
|
||||||
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
|
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
|
||||||
@ -67,6 +69,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
|
|||||||
:param exclude_paragraph_id_list: 需要排除段落id
|
:param exclude_paragraph_id_list: 需要排除段落id
|
||||||
:param padding_problem_text 补全问题
|
:param padding_problem_text 补全问题
|
||||||
:param search_mode 检索模式
|
:param search_mode 检索模式
|
||||||
|
:param user_id 用户id
|
||||||
:return: 段落列表
|
:return: 段落列表
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from django.db.models import QuerySet
|
|||||||
|
|
||||||
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
||||||
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
|
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
|
||||||
from common.config.embedding_config import VectorStore, EmbeddingModelManage
|
from common.config.embedding_config import VectorStore, ModelManage
|
||||||
from common.db.search import native_search
|
from common.db.search import native_search
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from dataset.models import Paragraph, DataSet
|
from dataset.models import Paragraph, DataSet
|
||||||
@ -23,10 +23,12 @@ from setting.models_provider import get_model
|
|||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
def get_model_by_id(_id):
|
def get_model_by_id(_id, user_id):
|
||||||
model = QuerySet(Model).filter(id=_id).first()
|
model = QuerySet(Model).filter(id=_id).first()
|
||||||
if model is None:
|
if model is None:
|
||||||
raise Exception("模型不存在")
|
raise Exception("模型不存在")
|
||||||
|
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
|
||||||
|
raise Exception(f"无权限使用此模型:{model.name}")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@ -44,14 +46,15 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
|
|||||||
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||||
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
||||||
search_mode: str = None,
|
search_mode: str = None,
|
||||||
|
user_id=None,
|
||||||
**kwargs) -> List[ParagraphPipelineModel]:
|
**kwargs) -> List[ParagraphPipelineModel]:
|
||||||
if len(dataset_id_list) == 0:
|
if len(dataset_id_list) == 0:
|
||||||
return []
|
return []
|
||||||
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
|
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
|
||||||
model_id = get_embedding_id(dataset_id_list)
|
model_id = get_embedding_id(dataset_id_list)
|
||||||
model = get_model_by_id(model_id)
|
model = get_model_by_id(model_id, user_id)
|
||||||
self.context['model_name'] = model.name
|
self.context['model_name'] = model.name
|
||||||
embedding_model = EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model))
|
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
|
||||||
embedding_value = embedding_model.embed_query(exec_problem_text)
|
embedding_value = embedding_model.embed_query(exec_problem_text)
|
||||||
vector = VectorStore.get_embedding_vector()
|
vector = VectorStore.get_embedding_vector()
|
||||||
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
|
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from django.db.models import QuerySet
|
|||||||
|
|
||||||
from application.flow.i_step_node import NodeResult
|
from application.flow.i_step_node import NodeResult
|
||||||
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
|
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
|
||||||
from common.config.embedding_config import VectorStore, EmbeddingModelManage
|
from common.config.embedding_config import VectorStore, ModelManage
|
||||||
from common.db.search import native_search
|
from common.db.search import native_search
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from dataset.models import Document, Paragraph, DataSet
|
from dataset.models import Document, Paragraph, DataSet
|
||||||
@ -56,7 +56,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
|
|||||||
return get_none_result(question)
|
return get_none_result(question)
|
||||||
model_id = get_embedding_id(dataset_id_list)
|
model_id = get_embedding_id(dataset_id_list)
|
||||||
model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id'))
|
model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id'))
|
||||||
embedding_model = EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model))
|
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
|
||||||
embedding_value = embedding_model.embed_query(question)
|
embedding_value = embedding_model.embed_query(question)
|
||||||
vector = VectorStore.get_embedding_vector()
|
vector = VectorStore.get_embedding_vector()
|
||||||
exclude_document_id_list = [str(document.id) for document in
|
exclude_document_id_list = [str(document.id) for document in
|
||||||
|
|||||||
@ -88,7 +88,9 @@ class ChatInfo:
|
|||||||
'no_references_setting': self.application.dataset_setting.get(
|
'no_references_setting': self.application.dataset_setting.get(
|
||||||
'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else {
|
'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else {
|
||||||
'status': 'ai_questioning',
|
'status': 'ai_questioning',
|
||||||
'value': '{question}'}
|
'value': '{question}',
|
||||||
|
},
|
||||||
|
'user_id': self.application.user_id
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -29,6 +29,7 @@ from application.models.api_key_model import ApplicationAccessToken
|
|||||||
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
|
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
|
||||||
ModelSettingSerializer
|
ModelSettingSerializer
|
||||||
from application.serializers.chat_message_serializers import ChatInfo
|
from application.serializers.chat_message_serializers import ChatInfo
|
||||||
|
from common.config.embedding_config import ModelManage
|
||||||
from common.constants.permission_constants import RoleConstants
|
from common.constants.permission_constants import RoleConstants
|
||||||
from common.db.search import native_search, native_page_search, page_search, get_dynamics_model
|
from common.db.search import native_search, native_page_search, page_search, get_dynamics_model
|
||||||
from common.event import ListenerManagement
|
from common.event import ListenerManagement
|
||||||
@ -42,6 +43,7 @@ from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
|
|||||||
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
|
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
|
||||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers
|
from dataset.serializers.paragraph_serializers import ParagraphSerializers
|
||||||
from setting.models import Model
|
from setting.models import Model
|
||||||
|
from setting.models_provider import get_model
|
||||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
@ -242,12 +244,7 @@ class ChatSerializers(serializers.Serializer):
|
|||||||
application_id=application_id)]
|
application_id=application_id)]
|
||||||
chat_model = None
|
chat_model = None
|
||||||
if model is not None:
|
if model is not None:
|
||||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
chat_model = ModelManage.get_model(str(model.id), lambda _id: get_model(model))
|
||||||
json.loads(
|
|
||||||
rsa_long_decrypt(
|
|
||||||
model.credential)),
|
|
||||||
streaming=True)
|
|
||||||
|
|
||||||
chat_id = str(uuid.uuid1())
|
chat_id = str(uuid.uuid1())
|
||||||
chat_cache.set(chat_id,
|
chat_cache.set(chat_id,
|
||||||
ChatInfo(chat_id, chat_model, dataset_id_list,
|
ChatInfo(chat_id, chat_model, dataset_id_list,
|
||||||
|
|||||||
@ -11,26 +11,31 @@ import time
|
|||||||
from common.cache.mem_cache import MemCache
|
from common.cache.mem_cache import MemCache
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingModelManage:
|
class ModelManage:
|
||||||
cache = MemCache('model', {})
|
cache = MemCache('model', {})
|
||||||
up_clear_time = time.time()
|
up_clear_time = time.time()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_model(_id, get_model):
|
def get_model(_id, get_model):
|
||||||
model_instance = EmbeddingModelManage.cache.get(_id)
|
model_instance = ModelManage.cache.get(_id)
|
||||||
if model_instance is None:
|
if model_instance is None:
|
||||||
model_instance = get_model(_id)
|
model_instance = get_model(_id)
|
||||||
EmbeddingModelManage.cache.set(_id, model_instance, timeout=60 * 30)
|
ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
|
||||||
return model_instance
|
return model_instance
|
||||||
# 续期
|
# 续期
|
||||||
EmbeddingModelManage.cache.touch(_id, timeout=60 * 30)
|
ModelManage.cache.touch(_id, timeout=60 * 30)
|
||||||
EmbeddingModelManage.clear_timeout_cache()
|
ModelManage.clear_timeout_cache()
|
||||||
return model_instance
|
return model_instance
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def clear_timeout_cache():
|
def clear_timeout_cache():
|
||||||
if time.time() - EmbeddingModelManage.up_clear_time > 60:
|
if time.time() - ModelManage.up_clear_time > 60:
|
||||||
EmbeddingModelManage.cache.clear_timeout_data()
|
ModelManage.cache.clear_timeout_data()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def delete_key(_id):
|
||||||
|
if ModelManage.cache.has_key(_id):
|
||||||
|
ModelManage.cache.delete(_id)
|
||||||
|
|
||||||
|
|
||||||
class VectorStore:
|
class VectorStore:
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from django.db.models import QuerySet
|
|||||||
from drf_yasg import openapi
|
from drf_yasg import openapi
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
|
||||||
from common.config.embedding_config import EmbeddingModelManage
|
from common.config.embedding_config import ModelManage
|
||||||
from common.db.search import native_search
|
from common.db.search import native_search
|
||||||
from common.db.sql_execute import update_execute
|
from common.db.sql_execute import update_execute
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
@ -140,14 +140,14 @@ def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
|
|||||||
raise Exception("知识库未向量模型不一致")
|
raise Exception("知识库未向量模型不一致")
|
||||||
if len(dataset_list) == 0:
|
if len(dataset_list) == 0:
|
||||||
raise Exception("知识库设置错误,请重新设置知识库")
|
raise Exception("知识库设置错误,请重新设置知识库")
|
||||||
return EmbeddingModelManage.get_model(str(dataset_list[0].id),
|
return ModelManage.get_model(str(dataset_list[0].id),
|
||||||
lambda _id: get_model(dataset_list[0].embedding_mode))
|
lambda _id: get_model(dataset_list[0].embedding_mode))
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_model_by_dataset_id(dataset_id: str):
|
def get_embedding_model_by_dataset_id(dataset_id: str):
|
||||||
dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
|
dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
|
||||||
return EmbeddingModelManage.get_model(dataset_id, lambda _id: get_model(dataset.embedding_mode))
|
return ModelManage.get_model(dataset_id, lambda _id: get_model(dataset.embedding_mode))
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_model_by_dataset(dataset):
|
def get_embedding_model_by_dataset(dataset):
|
||||||
return EmbeddingModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode))
|
return ModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user