feat: 模型管理支持向量模型,知识库可以关联向量模型
feat: 模型管理支持向量模型,知识库可以关联向量模型
This commit is contained in:
commit
d3d09b10ec
@ -43,6 +43,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
|
|||||||
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
|
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
|
||||||
message="类型只支持register|reset_password", code=500)
|
message="类型只支持register|reset_password", code=500)
|
||||||
], error_messages=ErrMessage.char("检索模式"))
|
], error_messages=ErrMessage.char("检索模式"))
|
||||||
|
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||||
|
|
||||||
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
|
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
|
||||||
return self.InstanceSerializer
|
return self.InstanceSerializer
|
||||||
@ -56,6 +57,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
|
|||||||
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||||
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
||||||
search_mode: str = None,
|
search_mode: str = None,
|
||||||
|
user_id=None,
|
||||||
**kwargs) -> List[ParagraphPipelineModel]:
|
**kwargs) -> List[ParagraphPipelineModel]:
|
||||||
"""
|
"""
|
||||||
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
|
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
|
||||||
@ -67,6 +69,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
|
|||||||
:param exclude_paragraph_id_list: 需要排除段落id
|
:param exclude_paragraph_id_list: 需要排除段落id
|
||||||
:param padding_problem_text 补全问题
|
:param padding_problem_text 补全问题
|
||||||
:param search_mode 检索模式
|
:param search_mode 检索模式
|
||||||
|
:param user_id 用户id
|
||||||
:return: 段落列表
|
:return: 段落列表
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -13,22 +13,48 @@ from django.db.models import QuerySet
|
|||||||
|
|
||||||
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
||||||
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
|
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
|
||||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
from common.config.embedding_config import VectorStore, ModelManage
|
||||||
from common.db.search import native_search
|
from common.db.search import native_search
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from dataset.models import Paragraph
|
from dataset.models import Paragraph, DataSet
|
||||||
from embedding.models import SearchMode
|
from embedding.models import SearchMode
|
||||||
|
from setting.models import Model
|
||||||
|
from setting.models_provider import get_model
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_by_id(_id, user_id):
    """Load a model record by id and check that ``user_id`` may use it.

    :param _id: primary key of the ``Model`` row to load
    :param user_id: id of the user on whose behalf the model is used
    :return: the matching ``Model`` row
    :raises Exception: if no row matches, or if the row is marked ``PRIVATE``
        and belongs to a different user
    """
    record = QuerySet(Model).filter(id=_id).first()
    if record is None:
        raise Exception("模型不存在")
    is_private = record.permission_type == 'PRIVATE'
    # Both ids are compared as strings so UUID objects and plain strings match.
    if is_private and str(record.user_id) != str(user_id):
        raise Exception(f"无权限使用此模型:{record.name}")
    return record
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_id(dataset_id_list):
    """Return the embedding model id shared by all of the given datasets.

    :param dataset_id_list: ids of the datasets to look up
    :return: the ``embedding_mode_id`` common to every matched dataset
    :raises Exception: if no dataset matches the ids, or if the matched
        datasets do not all reference the same embedding model
    """
    # Materialise once so every check below reads the same rows.
    dataset_list = list(QuerySet(DataSet).filter(id__in=dataset_id_list))
    if not dataset_list:
        raise Exception("知识库设置错误,请重新设置知识库")
    embedding_ids = {dataset.embedding_mode_id for dataset in dataset_list}
    if len(embedding_ids) > 1:
        # One query vector cannot be compared against vectors produced by
        # different embedding models, so mixed datasets are rejected.
        raise Exception("知识库向量模型不一致")
    return dataset_list[0].embedding_mode_id
|
||||||
|
|
||||||
|
|
||||||
class BaseSearchDatasetStep(ISearchDatasetStep):
|
class BaseSearchDatasetStep(ISearchDatasetStep):
|
||||||
|
|
||||||
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||||
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
||||||
search_mode: str = None,
|
search_mode: str = None,
|
||||||
|
user_id=None,
|
||||||
**kwargs) -> List[ParagraphPipelineModel]:
|
**kwargs) -> List[ParagraphPipelineModel]:
|
||||||
|
if len(dataset_id_list) == 0:
|
||||||
|
return []
|
||||||
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
|
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
|
||||||
embedding_model = EmbeddingModel.get_embedding_model()
|
model_id = get_embedding_id(dataset_id_list)
|
||||||
|
model = get_model_by_id(model_id, user_id)
|
||||||
|
self.context['model_name'] = model.name
|
||||||
|
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
|
||||||
embedding_value = embedding_model.embed_query(exec_problem_text)
|
embedding_value = embedding_model.embed_query(exec_problem_text)
|
||||||
vector = VectorStore.get_embedding_vector()
|
vector = VectorStore.get_embedding_vector()
|
||||||
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
|
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
|
||||||
@ -101,7 +127,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
|
|||||||
'run_time': self.context['run_time'],
|
'run_time': self.context['run_time'],
|
||||||
'problem_text': step_args.get(
|
'problem_text': step_args.get(
|
||||||
'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
|
'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
|
||||||
'model_name': EmbeddingModel.get_embedding_model().model_name,
|
'model_name': self.context.get('model_name'),
|
||||||
'message_tokens': 0,
|
'message_tokens': 0,
|
||||||
'answer_tokens': 0,
|
'answer_tokens': 0,
|
||||||
'cost': 0
|
'cost': 0
|
||||||
|
|||||||
@ -111,6 +111,7 @@ class FlowParamsSerializer(serializers.Serializer):
|
|||||||
|
|
||||||
client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
|
client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
|
||||||
|
|
||||||
|
# UUID fields take the uuid error-message set, as the other UUIDField declarations do.
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||||
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案"))
|
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案"))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -13,20 +13,50 @@ from django.db.models import QuerySet
|
|||||||
|
|
||||||
from application.flow.i_step_node import NodeResult
|
from application.flow.i_step_node import NodeResult
|
||||||
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
|
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
|
||||||
from common.config.embedding_config import EmbeddingModel, VectorStore
|
from common.config.embedding_config import VectorStore, ModelManage
|
||||||
from common.db.search import native_search
|
from common.db.search import native_search
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from dataset.models import Document, Paragraph
|
from dataset.models import Document, Paragraph, DataSet
|
||||||
from embedding.models import SearchMode
|
from embedding.models import SearchMode
|
||||||
|
from setting.models import Model
|
||||||
|
from setting.models_provider import get_model
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_by_id(_id, user_id):
    """Load a model record by id and check that ``user_id`` may use it.

    :param _id: primary key of the ``Model`` row to load
    :param user_id: id of the user on whose behalf the model is used
    :return: the matching ``Model`` row
    :raises Exception: if no row matches, or if the row is marked ``PRIVATE``
        and belongs to a different user
    """
    record = QuerySet(Model).filter(id=_id).first()
    if record is None:
        raise Exception("模型不存在")
    is_private = record.permission_type == 'PRIVATE'
    # Both ids are compared as strings so UUID objects and plain strings match.
    if is_private and str(record.user_id) != str(user_id):
        raise Exception(f"无权限使用此模型:{record.name}")
    return record
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_id(dataset_id_list):
    """Return the embedding model id shared by all of the given datasets.

    :param dataset_id_list: ids of the datasets to look up
    :return: the ``embedding_mode_id`` common to every matched dataset
    :raises Exception: if no dataset matches the ids, or if the matched
        datasets do not all reference the same embedding model
    """
    # Materialise once so every check below reads the same rows.
    dataset_list = list(QuerySet(DataSet).filter(id__in=dataset_id_list))
    if not dataset_list:
        raise Exception("知识库设置错误,请重新设置知识库")
    embedding_ids = {dataset.embedding_mode_id for dataset in dataset_list}
    if len(embedding_ids) > 1:
        # One query vector cannot be compared against vectors produced by
        # different embedding models, so mixed datasets are rejected.
        raise Exception("知识库向量模型不一致")
    return dataset_list[0].embedding_mode_id
|
||||||
|
|
||||||
|
|
||||||
|
def get_none_result(question):
    """Build the ``NodeResult`` used when the dataset search has nothing to return.

    :param question: the question text, echoed back in the result payload
    :return: a ``NodeResult`` with empty paragraph/hit lists and empty
        ``data`` / ``directly_return`` strings
    """
    empty_payload = {
        'paragraph_list': [],
        'is_hit_handling_method': [],
        'question': question,
        'data': '',
        'directly_return': '',
    }
    return NodeResult(empty_payload, {})
|
||||||
|
|
||||||
|
|
||||||
class BaseSearchDatasetNode(ISearchDatasetStepNode):
|
class BaseSearchDatasetNode(ISearchDatasetStepNode):
|
||||||
def execute(self, dataset_id_list, dataset_setting, question,
|
def execute(self, dataset_id_list, dataset_setting, question,
|
||||||
exclude_paragraph_id_list=None,
|
exclude_paragraph_id_list=None,
|
||||||
**kwargs) -> NodeResult:
|
**kwargs) -> NodeResult:
|
||||||
self.context['question'] = question
|
self.context['question'] = question
|
||||||
embedding_model = EmbeddingModel.get_embedding_model()
|
if len(dataset_id_list) == 0:
|
||||||
|
return get_none_result(question)
|
||||||
|
model_id = get_embedding_id(dataset_id_list)
|
||||||
|
model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id'))
|
||||||
|
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
|
||||||
embedding_value = embedding_model.embed_query(question)
|
embedding_value = embedding_model.embed_query(question)
|
||||||
vector = VectorStore.get_embedding_vector()
|
vector = VectorStore.get_embedding_vector()
|
||||||
exclude_document_id_list = [str(document.id) for document in
|
exclude_document_id_list = [str(document.id) for document in
|
||||||
@ -37,7 +67,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
|
|||||||
exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
|
exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
|
||||||
dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
|
dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
|
||||||
if embedding_list is None:
|
if embedding_list is None:
|
||||||
return NodeResult({'paragraph_list': [], 'is_hit_handling_method': []}, {})
|
return get_none_result(question)
|
||||||
paragraph_list = self.list_paragraph(embedding_list, vector)
|
paragraph_list = self.list_paragraph(embedding_list, vector)
|
||||||
result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
|
result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
|
||||||
result = sorted(result, key=lambda p: p.get('similarity'), reverse=True)
|
result = sorted(result, key=lambda p: p.get('similarity'), reverse=True)
|
||||||
|
|||||||
19
apps/application/migrations/0010_alter_chatrecord_details.py
Normal file
19
apps/application/migrations/0010_alter_chatrecord_details.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# Generated by Django 4.2.13 on 2024-07-15 15:52
|
||||||
|
|
||||||
|
import application.models.application
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
    """Alter ``chatrecord.details`` to a JSONField using the project's DateEncoder."""

    # Must run after the previous migration of the ``application`` app.
    dependencies = [
        ('application', '0009_application_type_application_work_flow_and_more'),
    ]

    operations = [
        migrations.AlterField(
            model_name='chatrecord',
            name='details',
            field=models.JSONField(
                default=dict,
                encoder=application.models.application.DateEncoder,
                verbose_name='对话详情',
            ),
        ),
    ]
|
||||||
@ -26,7 +26,7 @@ from rest_framework import serializers
|
|||||||
from application.flow.workflow_manage import Flow
|
from application.flow.workflow_manage import Flow
|
||||||
from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion
|
from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion
|
||||||
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
|
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
|
||||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
from common.config.embedding_config import VectorStore
|
||||||
from common.constants.authentication_type import AuthenticationType
|
from common.constants.authentication_type import AuthenticationType
|
||||||
from common.db.search import get_dynamics_model, native_search, native_page_search
|
from common.db.search import get_dynamics_model, native_search, native_page_search
|
||||||
from common.db.sql_execute import select_list
|
from common.db.sql_execute import select_list
|
||||||
@ -36,7 +36,7 @@ from common.util.common import valid_license
|
|||||||
from common.util.field_message import ErrMessage
|
from common.util.field_message import ErrMessage
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from dataset.models import DataSet, Document, Image
|
from dataset.models import DataSet, Document, Image
|
||||||
from dataset.serializers.common_serializers import list_paragraph
|
from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list
|
||||||
from embedding.models import SearchMode
|
from embedding.models import SearchMode
|
||||||
from setting.models import AuthOperate
|
from setting.models import AuthOperate
|
||||||
from setting.models.model_management import Model
|
from setting.models.model_management import Model
|
||||||
@ -415,12 +415,13 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
QuerySet(Document).filter(
|
QuerySet(Document).filter(
|
||||||
dataset_id__in=dataset_id_list,
|
dataset_id__in=dataset_id_list,
|
||||||
is_active=False)]
|
is_active=False)]
|
||||||
|
model = get_embedding_model_by_dataset_id_list(dataset_id_list)
|
||||||
# 向量库检索
|
# 向量库检索
|
||||||
hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
|
hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
|
||||||
self.data.get('top_number'),
|
self.data.get('top_number'),
|
||||||
self.data.get('similarity'),
|
self.data.get('similarity'),
|
||||||
SearchMode(self.data.get('search_mode')),
|
SearchMode(self.data.get('search_mode')),
|
||||||
EmbeddingModel.get_embedding_model())
|
model)
|
||||||
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
|
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
|
||||||
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
|
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
|
||||||
return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
|
return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
|
||||||
@ -522,12 +523,14 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
if not QuerySet(Application).filter(id=self.data.get('application_id')).exists():
|
if not QuerySet(Application).filter(id=self.data.get('application_id')).exists():
|
||||||
raise AppApiException(500, '不存在的应用id')
|
raise AppApiException(500, '不存在的应用id')
|
||||||
|
|
||||||
def list_model(self, with_valid=True):
|
def list_model(self, model_type=None, with_valid=True):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid()
|
self.is_valid()
|
||||||
|
if model_type is None:
|
||||||
|
model_type = "LLM"
|
||||||
application = QuerySet(Application).filter(id=self.data.get("application_id")).first()
|
application = QuerySet(Application).filter(id=self.data.get("application_id")).first()
|
||||||
return ModelSerializer.Query(
|
return ModelSerializer.Query(
|
||||||
data={'user_id': application.user_id}).list(
|
data={'user_id': application.user_id, 'model_type': model_type}).list(
|
||||||
with_valid=True)
|
with_valid=True)
|
||||||
|
|
||||||
def delete(self, with_valid=True):
|
def delete(self, with_valid=True):
|
||||||
|
|||||||
@ -88,7 +88,9 @@ class ChatInfo:
|
|||||||
'no_references_setting': self.application.dataset_setting.get(
|
'no_references_setting': self.application.dataset_setting.get(
|
||||||
'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else {
|
'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else {
|
||||||
'status': 'ai_questioning',
|
'status': 'ai_questioning',
|
||||||
'value': '{question}'}
|
'value': '{question}',
|
||||||
|
},
|
||||||
|
'user_id': self.application.user_id
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -221,11 +223,13 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||||||
stream = self.data.get('stream')
|
stream = self.data.get('stream')
|
||||||
client_id = self.data.get('client_id')
|
client_id = self.data.get('client_id')
|
||||||
client_type = self.data.get('client_type')
|
client_type = self.data.get('client_type')
|
||||||
|
user_id = chat_info.application.user_id
|
||||||
work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
|
work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
|
||||||
{'history_chat_record': chat_info.chat_record_list, 'question': message,
|
{'history_chat_record': chat_info.chat_record_list, 'question': message,
|
||||||
'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
|
'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
|
||||||
'stream': stream,
|
'stream': stream,
|
||||||
're_chat': re_chat}, WorkFlowPostHandler(chat_info, client_id, client_type))
|
're_chat': re_chat,
|
||||||
|
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type))
|
||||||
r = work_flow_manage.run()
|
r = work_flow_manage.run()
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
|||||||
@ -29,6 +29,7 @@ from application.models.api_key_model import ApplicationAccessToken
|
|||||||
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
|
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
|
||||||
ModelSettingSerializer
|
ModelSettingSerializer
|
||||||
from application.serializers.chat_message_serializers import ChatInfo
|
from application.serializers.chat_message_serializers import ChatInfo
|
||||||
|
from common.config.embedding_config import ModelManage
|
||||||
from common.constants.permission_constants import RoleConstants
|
from common.constants.permission_constants import RoleConstants
|
||||||
from common.db.search import native_search, native_page_search, page_search, get_dynamics_model
|
from common.db.search import native_search, native_page_search, page_search, get_dynamics_model
|
||||||
from common.event import ListenerManagement
|
from common.event import ListenerManagement
|
||||||
@ -39,8 +40,10 @@ from common.util.file_util import get_file_content
|
|||||||
from common.util.lock import try_lock, un_lock
|
from common.util.lock import try_lock, un_lock
|
||||||
from common.util.rsa_util import rsa_long_decrypt
|
from common.util.rsa_util import rsa_long_decrypt
|
||||||
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
|
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
|
||||||
|
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
|
||||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers
|
from dataset.serializers.paragraph_serializers import ParagraphSerializers
|
||||||
from setting.models import Model
|
from setting.models import Model
|
||||||
|
from setting.models_provider import get_model
|
||||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
@ -241,12 +244,7 @@ class ChatSerializers(serializers.Serializer):
|
|||||||
application_id=application_id)]
|
application_id=application_id)]
|
||||||
chat_model = None
|
chat_model = None
|
||||||
if model is not None:
|
if model is not None:
|
||||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
chat_model = ModelManage.get_model(str(model.id), lambda _id: get_model(model))
|
||||||
json.loads(
|
|
||||||
rsa_long_decrypt(
|
|
||||||
model.credential)),
|
|
||||||
streaming=True)
|
|
||||||
|
|
||||||
chat_id = str(uuid.uuid1())
|
chat_id = str(uuid.uuid1())
|
||||||
chat_cache.set(chat_id,
|
chat_cache.set(chat_id,
|
||||||
ChatInfo(chat_id, chat_model, dataset_id_list,
|
ChatInfo(chat_id, chat_model, dataset_id_list,
|
||||||
@ -259,6 +257,7 @@ class ChatSerializers(serializers.Serializer):
|
|||||||
|
|
||||||
class OpenWorkFlowChat(serializers.Serializer):
|
class OpenWorkFlowChat(serializers.Serializer):
|
||||||
work_flow = WorkFlowSerializers(error_messages=ErrMessage.uuid("工作流"))
|
work_flow = WorkFlowSerializers(error_messages=ErrMessage.uuid("工作流"))
|
||||||
|
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
@ -269,7 +268,8 @@ class ChatSerializers(serializers.Serializer):
|
|||||||
dataset_setting={},
|
dataset_setting={},
|
||||||
model_setting={},
|
model_setting={},
|
||||||
problem_optimization=None,
|
problem_optimization=None,
|
||||||
type=ApplicationTypeChoices.WORK_FLOW
|
type=ApplicationTypeChoices.WORK_FLOW,
|
||||||
|
user_id=self.data.get('user_id')
|
||||||
)
|
)
|
||||||
work_flow_version = WorkFlowVersion(work_flow=work_flow)
|
work_flow_version = WorkFlowVersion(work_flow=work_flow)
|
||||||
chat_cache.set(chat_id,
|
chat_cache.set(chat_id,
|
||||||
@ -332,7 +332,8 @@ class ChatSerializers(serializers.Serializer):
|
|||||||
application = Application(id=None, dialogue_number=3, model=model,
|
application = Application(id=None, dialogue_number=3, model=model,
|
||||||
dataset_setting=self.data.get('dataset_setting'),
|
dataset_setting=self.data.get('dataset_setting'),
|
||||||
model_setting=self.data.get('model_setting'),
|
model_setting=self.data.get('model_setting'),
|
||||||
problem_optimization=self.data.get('problem_optimization'))
|
problem_optimization=self.data.get('problem_optimization'),
|
||||||
|
user_id=user_id)
|
||||||
chat_cache.set(chat_id,
|
chat_cache.set(chat_id,
|
||||||
ChatInfo(chat_id, chat_model, dataset_id_list,
|
ChatInfo(chat_id, chat_model, dataset_id_list,
|
||||||
[str(document.id) for document in
|
[str(document.id) for document in
|
||||||
@ -533,9 +534,10 @@ class ChatRecordSerializer(serializers.Serializer):
|
|||||||
raise AppApiException(500, "文档id不正确")
|
raise AppApiException(500, "文档id不正确")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def post_embedding_paragraph(chat_record, paragraph_id):
|
def post_embedding_paragraph(chat_record, paragraph_id, dataset_id):
|
||||||
|
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||||
# 发送向量化事件
|
# 发送向量化事件
|
||||||
ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id)
|
ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id, embedding_model=model)
|
||||||
return chat_record
|
return chat_record
|
||||||
|
|
||||||
@post(post_function=post_embedding_paragraph)
|
@post(post_function=post_embedding_paragraph)
|
||||||
@ -573,7 +575,7 @@ class ChatRecordSerializer(serializers.Serializer):
|
|||||||
chat_record.improve_paragraph_id_list.append(paragraph.id)
|
chat_record.improve_paragraph_id_list.append(paragraph.id)
|
||||||
# 添加标注
|
# 添加标注
|
||||||
chat_record.save()
|
chat_record.save()
|
||||||
return ChatRecordSerializerModel(chat_record).data, paragraph.id
|
return ChatRecordSerializerModel(chat_record).data, paragraph.id, dataset_id
|
||||||
|
|
||||||
class Operate(serializers.Serializer):
|
class Operate(serializers.Serializer):
|
||||||
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
|
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
|
||||||
|
|||||||
@ -67,6 +67,20 @@ class ApplicationApi(ApiMixin):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class Model(ApiMixin):
    """Swagger parameter declarations for the application model-list endpoint."""

    @staticmethod
    def get_request_params_api():
        """Return the path and query parameters accepted by the endpoint.

        :return: list with the required ``application_id`` path parameter and
            the optional ``model_type`` query parameter
        """
        application_id_param = openapi.Parameter(
            name='application_id',
            in_=openapi.IN_PATH,
            type=openapi.TYPE_STRING,
            required=True,
            description='应用id',
        )
        model_type_param = openapi.Parameter(
            name='model_type',
            in_=openapi.IN_QUERY,
            type=openapi.TYPE_STRING,
            required=False,
            description='模型类型',
        )
        return [application_id_param, model_type_param]
|
||||||
|
|
||||||
class ApiKey(ApiMixin):
|
class ApiKey(ApiMixin):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_request_params_api():
|
def get_request_params_api():
|
||||||
|
|||||||
@ -175,7 +175,7 @@ class Application(APIView):
|
|||||||
@swagger_auto_schema(operation_summary="获取模型列表",
|
@swagger_auto_schema(operation_summary="获取模型列表",
|
||||||
operation_id="获取模型列表",
|
operation_id="获取模型列表",
|
||||||
tags=["应用"],
|
tags=["应用"],
|
||||||
manual_parameters=ApplicationApi.ApiKey.get_request_params_api())
|
manual_parameters=ApplicationApi.Model.get_request_params_api())
|
||||||
@has_permissions(ViewPermission(
|
@has_permissions(ViewPermission(
|
||||||
[RoleConstants.ADMIN, RoleConstants.USER],
|
[RoleConstants.ADMIN, RoleConstants.USER],
|
||||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||||
@ -185,7 +185,7 @@ class Application(APIView):
|
|||||||
return result.success(
|
return result.success(
|
||||||
ApplicationSerializer.Operate(
|
ApplicationSerializer.Operate(
|
||||||
data={'application_id': application_id,
|
data={'application_id': application_id,
|
||||||
'user_id': request.user.id}).list_model())
|
'user_id': request.user.id}).list_model(request.query_params.get('model_type')))
|
||||||
|
|
||||||
class Profile(APIView):
|
class Profile(APIView):
|
||||||
authentication_classes = [TokenAuth]
|
authentication_classes = [TokenAuth]
|
||||||
|
|||||||
4
apps/common/cache/mem_cache.py
vendored
4
apps/common/cache/mem_cache.py
vendored
@ -41,3 +41,7 @@ class MemCache(LocMemCache):
|
|||||||
delete_keys.append(key)
|
delete_keys.append(key)
|
||||||
for key in delete_keys:
|
for key in delete_keys:
|
||||||
self._delete(key)
|
self._delete(key)
|
||||||
|
|
||||||
|
def clear_timeout_data(self):
    """Call ``self.get`` on every stored key so expired entries can be dropped.

    Eviction itself is left to ``self.get``; this method only walks the keys.
    """
    # Iterate over a snapshot: get() may remove entries, and mutating a dict
    # while iterating its live key view raises RuntimeError.
    # NOTE(review): keys in self._cache may already carry the cache's
    # prefix/version; confirm get() accepts them in that form.
    for key in list(self._cache.keys()):
        self.get(key)
|
||||||
|
|||||||
@ -6,33 +6,36 @@
|
|||||||
@date:2023/10/23 16:03
|
@date:2023/10/23 16:03
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
|
import time
|
||||||
|
|
||||||
from smartdoc.const import CONFIG
|
from common.cache.mem_cache import MemCache
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingModel:
|
class ModelManage:
    """Class-level cache of instantiated models, keyed by model id."""

    # Shared in-memory cache holding the model instances.
    cache = MemCache('model', {})
    # Time of the last sweep of expired cache entries.
    up_clear_time = time.time()

    @staticmethod
    def get_model(_id, get_model):
        """Return the cached model for ``_id``, building it on a cache miss.

        :param _id: cache key, i.e. the model id
        :param get_model: factory called as ``get_model(_id)`` when nothing is cached
        :return: the cached or newly built model instance
        """
        model_instance = ModelManage.cache.get(_id)
        if model_instance is None:
            model_instance = get_model(_id)
            ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
            return model_instance
        # Cache hit: renew the entry's timeout so models in use stay cached.
        ModelManage.cache.touch(_id, timeout=60 * 30)
        ModelManage.clear_timeout_cache()
        return model_instance

    @staticmethod
    def clear_timeout_cache():
        """Sweep expired cache entries, at most once every 60 seconds."""
        now = time.time()
        if now - ModelManage.up_clear_time > 60:
            # Record the sweep time first; without this reset every call made
            # after the first minute would run a full sweep.
            ModelManage.up_clear_time = now
            ModelManage.cache.clear_timeout_data()

    @staticmethod
    def delete_key(_id):
        """Remove the cached model for ``_id`` if one is present."""
        if ModelManage.cache.has_key(_id):
            ModelManage.cache.delete(_id)
|
||||||
|
|
||||||
|
|
||||||
class VectorStore:
|
class VectorStore:
|
||||||
|
|||||||
@ -14,14 +14,14 @@ embedding_thread_pool = ThreadPoolExecutor(3)
|
|||||||
|
|
||||||
|
|
||||||
def poxy(poxy_function):
|
def poxy(poxy_function):
|
||||||
def inner(args):
|
def inner(args, **keywords):
|
||||||
work_thread_pool.submit(poxy_function, args)
|
work_thread_pool.submit(poxy_function, args, **keywords)
|
||||||
|
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
|
|
||||||
def embedding_poxy(poxy_function):
|
def embedding_poxy(poxy_function):
|
||||||
def inner(args):
|
def inner(args, **keywords):
|
||||||
embedding_thread_pool.submit(poxy_function, args)
|
embedding_thread_pool.submit(poxy_function, args, **keywords)
|
||||||
|
|
||||||
return inner
|
return inner
|
||||||
|
|||||||
@ -15,8 +15,9 @@ from typing import List
|
|||||||
import django.db.models
|
import django.db.models
|
||||||
from blinker import signal
|
from blinker import signal
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
|
from langchain_core.embeddings import Embeddings
|
||||||
|
|
||||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
from common.config.embedding_config import VectorStore
|
||||||
from common.db.search import native_search, get_dynamics_model
|
from common.db.search import native_search, get_dynamics_model
|
||||||
from common.event.common import poxy, embedding_poxy
|
from common.event.common import poxy, embedding_poxy
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
@ -46,22 +47,26 @@ class SyncWebDocumentArgs:
|
|||||||
|
|
||||||
|
|
||||||
class UpdateProblemArgs:
|
class UpdateProblemArgs:
|
||||||
def __init__(self, problem_id: str, problem_content: str):
|
def __init__(self, problem_id: str, problem_content: str, embedding_model: Embeddings):
|
||||||
self.problem_id = problem_id
|
self.problem_id = problem_id
|
||||||
self.problem_content = problem_content
|
self.problem_content = problem_content
|
||||||
|
self.embedding_model = embedding_model
|
||||||
|
|
||||||
|
|
||||||
class UpdateEmbeddingDatasetIdArgs:
|
class UpdateEmbeddingDatasetIdArgs:
|
||||||
def __init__(self, paragraph_id_list: List[str], target_dataset_id: str):
|
def __init__(self, paragraph_id_list: List[str], target_dataset_id: str, target_embedding_model: Embeddings):
|
||||||
self.paragraph_id_list = paragraph_id_list
|
self.paragraph_id_list = paragraph_id_list
|
||||||
self.target_dataset_id = target_dataset_id
|
self.target_dataset_id = target_dataset_id
|
||||||
|
self.target_embedding_model = target_embedding_model
|
||||||
|
|
||||||
|
|
||||||
class UpdateEmbeddingDocumentIdArgs:
|
class UpdateEmbeddingDocumentIdArgs:
|
||||||
def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str):
|
def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str,
|
||||||
|
target_embedding_model: Embeddings = None):
|
||||||
self.paragraph_id_list = paragraph_id_list
|
self.paragraph_id_list = paragraph_id_list
|
||||||
self.target_document_id = target_document_id
|
self.target_document_id = target_document_id
|
||||||
self.target_dataset_id = target_dataset_id
|
self.target_dataset_id = target_dataset_id
|
||||||
|
self.target_embedding_model = target_embedding_model
|
||||||
|
|
||||||
|
|
||||||
class ListenerManagement:
|
class ListenerManagement:
|
||||||
@ -84,16 +89,46 @@ class ListenerManagement:
|
|||||||
delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list")
|
delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def embedding_by_problem(args):
|
def embedding_by_problem(args, embedding_model: Embeddings):
|
||||||
VectorStore.get_embedding_vector().save(**args)
|
VectorStore.get_embedding_vector().save(**args, embedding=embedding_model)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def embedding_by_paragraph_list(paragraph_id_list, embedding_model: Embeddings):
    """
    Load the embedding rows for the given paragraphs and hand them to
    ``embedding_by_paragraph_data_list`` for vectorisation.

    Any exception raised while querying or dispatching is logged and swallowed.

    @param paragraph_id_list: ids of the paragraphs to vectorise
    @param embedding_model: embedding model forwarded to the vectorisation step
    :return: None
    """
    try:
        # Problem rows are filtered through a dynamic model exposing the ``paragraph.id`` column.
        problem_query_set = QuerySet(
            get_dynamics_model({'paragraph.id': django.db.models.CharField()})
        ).filter(**{'paragraph.id__in': paragraph_id_list})
        paragraph_query_set = QuerySet(Paragraph).filter(id__in=paragraph_id_list)
        sql_file = os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')
        embedding_rows = native_search(
            {'problem': problem_query_set, 'paragraph': paragraph_query_set},
            select_string=get_file_content(sql_file))
        ListenerManagement.embedding_by_paragraph_data_list(
            embedding_rows,
            paragraph_id_list=paragraph_id_list,
            embedding_model=embedding_model)
    except Exception as e:
        max_kb_error.error(f'查询向量数据:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@embedding_poxy
|
@embedding_poxy
|
||||||
def embedding_by_paragraph(paragraph_id):
|
def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings):
    """
    Replace the stored vectors of a batch of paragraphs with freshly embedded ones.

    The existing vectors of the paragraphs are deleted, ``data_list`` is saved in
    one batch with ``embedding_model``, and finally the ``status`` column of the
    paragraphs is set to success or error depending on the outcome.

    @param data_list: rows to embed, passed unchanged to ``batch_save``
    @param paragraph_id_list: ids of the paragraphs being re-embedded
    @param embedding_model: embedding model passed to the vector store
    :return: None
    """
    max_kb.info(f'开始--->向量化段落:{paragraph_id_list}')
    # Bind the status before entering the try block: the finally clause reads it
    # on the success path as well, where the except branch never assigns it.
    status = Status.success
    try:
        # Delete the existing vectors of these paragraphs
        VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_id_list)
        # Embed and store the whole batch
        VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
    except Exception as e:
        max_kb_error.error(f'向量化段落:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
        status = Status.error
    finally:
        QuerySet(Paragraph).filter(id__in=paragraph_id_list).update(**{'status': status})
        max_kb.info(f'结束--->向量化段落:{paragraph_id_list}')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@embedding_poxy
|
||||||
|
def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
|
||||||
"""
|
"""
|
||||||
向量化段落 根据段落id
|
向量化段落 根据段落id
|
||||||
:param paragraph_id: 段落id
|
@param paragraph_id: 段落id
|
||||||
:return: None
|
@param embedding_model: 向量模型
|
||||||
"""
|
"""
|
||||||
max_kb.info(f"开始--->向量化段落:{paragraph_id}")
|
max_kb.info(f"开始--->向量化段落:{paragraph_id}")
|
||||||
status = Status.success
|
status = Status.success
|
||||||
@ -107,7 +142,7 @@ class ListenerManagement:
|
|||||||
# 删除段落
|
# 删除段落
|
||||||
VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
|
VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
|
||||||
# 批量向量化
|
# 批量向量化
|
||||||
VectorStore.get_embedding_vector().batch_save(data_list)
|
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}')
|
max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}')
|
||||||
status = Status.error
|
status = Status.error
|
||||||
@ -117,12 +152,15 @@ class ListenerManagement:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@embedding_poxy
|
@embedding_poxy
|
||||||
def embedding_by_document(document_id):
|
def embedding_by_document(document_id, embedding_model: Embeddings):
|
||||||
"""
|
"""
|
||||||
向量化文档
|
向量化文档
|
||||||
:param document_id: 文档id
|
@param document_id: 文档id
|
||||||
|
@param embedding_model 向量模型
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
|
if not try_lock('embedding' + str(document_id)):
|
||||||
|
return
|
||||||
max_kb.info(f"开始--->向量化文档:{document_id}")
|
max_kb.info(f"开始--->向量化文档:{document_id}")
|
||||||
QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding})
|
QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding})
|
||||||
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.embedding})
|
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.embedding})
|
||||||
@ -138,7 +176,7 @@ class ListenerManagement:
|
|||||||
# 删除文档向量数据
|
# 删除文档向量数据
|
||||||
VectorStore.get_embedding_vector().delete_by_document_id(document_id)
|
VectorStore.get_embedding_vector().delete_by_document_id(document_id)
|
||||||
# 批量向量化
|
# 批量向量化
|
||||||
VectorStore.get_embedding_vector().batch_save(data_list)
|
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
|
max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
|
||||||
status = Status.error
|
status = Status.error
|
||||||
@ -148,21 +186,24 @@ class ListenerManagement:
|
|||||||
**{'status': status, 'update_time': datetime.datetime.now()})
|
**{'status': status, 'update_time': datetime.datetime.now()})
|
||||||
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status})
|
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status})
|
||||||
max_kb.info(f"结束--->向量化文档:{document_id}")
|
max_kb.info(f"结束--->向量化文档:{document_id}")
|
||||||
|
un_lock('embedding' + str(document_id))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@embedding_poxy
|
@embedding_poxy
|
||||||
def embedding_by_dataset(dataset_id):
|
def embedding_by_dataset(dataset_id, embedding_model: Embeddings):
|
||||||
"""
|
"""
|
||||||
向量化知识库
|
向量化知识库
|
||||||
:param dataset_id: 知识库id
|
@param dataset_id: 知识库id
|
||||||
|
@param embedding_model 向量模型
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
max_kb.info(f"开始--->向量化数据集:{dataset_id}")
|
max_kb.info(f"开始--->向量化数据集:{dataset_id}")
|
||||||
try:
|
try:
|
||||||
|
ListenerManagement.delete_embedding_by_dataset(dataset_id)
|
||||||
document_list = QuerySet(Document).filter(dataset_id=dataset_id)
|
document_list = QuerySet(Document).filter(dataset_id=dataset_id)
|
||||||
max_kb.info(f"数据集文档:{[d.name for d in document_list]}")
|
max_kb.info(f"数据集文档:{[d.name for d in document_list]}")
|
||||||
for document in document_list:
|
for document in document_list:
|
||||||
ListenerManagement.embedding_by_document(document.id)
|
ListenerManagement.embedding_by_document(document.id, embedding_model=embedding_model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
|
max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
|
||||||
finally:
|
finally:
|
||||||
@ -224,14 +265,22 @@ class ListenerManagement:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs):
|
def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs):
|
||||||
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
|
if args.target_embedding_model is None:
|
||||||
{'dataset_id': args.target_dataset_id})
|
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
|
||||||
|
{'dataset_id': args.target_dataset_id})
|
||||||
|
else:
|
||||||
|
ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
|
||||||
|
embedding_model=args.target_embedding_model)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs):
|
def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs):
|
||||||
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
|
if args.target_embedding_model is None:
|
||||||
{'document_id': args.target_document_id,
|
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
|
||||||
'dataset_id': args.target_dataset_id})
|
{'document_id': args.target_document_id,
|
||||||
|
'dataset_id': args.target_dataset_id})
|
||||||
|
else:
|
||||||
|
ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
|
||||||
|
embedding_model=args.target_embedding_model)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_embedding_by_source_ids(source_ids: List[str]):
|
def delete_embedding_by_source_ids(source_ids: List[str]):
|
||||||
@ -245,11 +294,6 @@ class ListenerManagement:
|
|||||||
def delete_embedding_by_dataset_id_list(source_ids: List[str]):
|
def delete_embedding_by_dataset_id_list(source_ids: List[str]):
|
||||||
VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids)
|
VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@poxy
|
|
||||||
def init_embedding_model(ages):
|
|
||||||
EmbeddingModel.get_embedding_model()
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
# 添加向量 根据问题id
|
# 添加向量 根据问题id
|
||||||
ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)
|
ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)
|
||||||
@ -276,8 +320,7 @@ class ListenerManagement:
|
|||||||
ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
|
ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
|
||||||
# 启动段落向量
|
# 启动段落向量
|
||||||
ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
|
ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
|
||||||
# 初始化向量化模型
|
|
||||||
ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)
|
|
||||||
# 同步web站点知识库
|
# 同步web站点知识库
|
||||||
ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
|
ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
|
||||||
# 同步web站点 文档
|
# 同步web站点 文档
|
||||||
|
|||||||
21
apps/dataset/migrations/0006_dataset_embedding_mode.py
Normal file
21
apps/dataset/migrations/0006_dataset_embedding_mode.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# Generated by Django 4.2.13 on 2024-07-17 13:56
|
||||||
|
|
||||||
|
import dataset.models.data_set
|
||||||
|
from django.db import migrations, models
|
||||||
|
import django.db.models.deletion
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
    # Adds the ``embedding_mode`` foreign key (pointing at ``setting.model``) to the dataset model.

    dependencies = [
        ('setting', '0005_model_permission_type'),
        ('dataset', '0005_file'),
    ]

    operations = [
        migrations.AddField(
            model_name='dataset',
            name='embedding_mode',
            field=models.ForeignKey(
                default=dataset.models.data_set.default_model,
                on_delete=django.db.models.deletion.DO_NOTHING,
                to='setting.model',
                verbose_name='向量模型',
            ),
        ),
    ]
|
||||||
@ -9,9 +9,11 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
from django.db.models import QuerySet
|
||||||
|
|
||||||
from common.db.sql_execute import select_one
|
from common.db.sql_execute import select_one
|
||||||
from common.mixins.app_model_mixin import AppModelMixin
|
from common.mixins.app_model_mixin import AppModelMixin
|
||||||
|
from setting.models import Model
|
||||||
from users.models import User
|
from users.models import User
|
||||||
|
|
||||||
|
|
||||||
@ -33,6 +35,10 @@ class HitHandlingMethod(models.TextChoices):
|
|||||||
directly_return = 'directly_return', '直接返回'
|
directly_return = 'directly_return', '直接返回'
|
||||||
|
|
||||||
|
|
||||||
|
def default_model():
    """Return the fixed UUID used as the default value of ``DataSet.embedding_mode``."""
    default_model_id = '42f63a3d-427e-11ef-b3ec-a8a1595801ab'
    return uuid.UUID(default_model_id)
|
||||||
|
|
||||||
|
|
||||||
class DataSet(AppModelMixin):
|
class DataSet(AppModelMixin):
|
||||||
"""
|
"""
|
||||||
数据集表
|
数据集表
|
||||||
@ -43,7 +49,8 @@ class DataSet(AppModelMixin):
|
|||||||
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户")
|
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户")
|
||||||
type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
|
type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
|
||||||
default=Type.base)
|
default=Type.base)
|
||||||
|
embedding_mode = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型",
|
||||||
|
default=default_model)
|
||||||
meta = models.JSONField(verbose_name="元数据", default=dict)
|
meta = models.JSONField(verbose_name="元数据", default=dict)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from django.db.models import QuerySet
|
|||||||
from drf_yasg import openapi
|
from drf_yasg import openapi
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
from common.config.embedding_config import ModelManage
|
||||||
from common.db.search import native_search
|
from common.db.search import native_search
|
||||||
from common.db.sql_execute import update_execute
|
from common.db.sql_execute import update_execute
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
@ -21,7 +22,8 @@ from common.mixins.api_mixin import ApiMixin
|
|||||||
from common.util.field_message import ErrMessage
|
from common.util.field_message import ErrMessage
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from common.util.fork import Fork
|
from common.util.fork import Fork
|
||||||
from dataset.models import Paragraph, Problem, ProblemParagraphMapping
|
from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet
|
||||||
|
from setting.models_provider import get_model
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
@ -130,3 +132,22 @@ class ProblemParagraphManage:
|
|||||||
result = [problem_model for problem_model, is_create in problem_content_dict.values() if
|
result = [problem_model for problem_model, is_create in problem_content_dict.values() if
|
||||||
is_create], problem_paragraph_mapping_list
|
is_create], problem_paragraph_mapping_list
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
    """
    Return the embedding model shared by all of the given datasets.

    @param dataset_id_list: ids of the datasets that will be searched together
    :return: the model returned by ``ModelManage.get_model`` for the first dataset
    :raises Exception: when no dataset matches the ids, or when the matched
                       datasets do not all reference the same embedding model
    """
    # Materialise the query once and load the related model row with it, so the
    # checks below and the loader callback need no further queries.
    dataset_list = list(QuerySet(DataSet).select_related('embedding_mode').filter(id__in=dataset_id_list))
    if not dataset_list:
        raise Exception("知识库设置错误,请重新设置知识库")
    if len({dataset.embedding_mode_id for dataset in dataset_list}) > 1:
        raise Exception("知识库向量模型不一致")
    first_dataset = dataset_list[0]
    # NOTE(review): the model is registered under the first dataset's id, as in the
    # sibling helpers; confirm how ModelManage invalidates it when a dataset
    # switches to another embedding model.
    return ModelManage.get_model(str(first_dataset.id),
                                 lambda _id: get_model(first_dataset.embedding_mode))
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_model_by_dataset_id(dataset_id: str):
    """
    Return the embedding model configured for a single dataset.

    @param dataset_id: id of the dataset (a ``str`` or ``UUID``; it is normalised
                       to ``str`` for the ``ModelManage`` key, like the sibling helpers)
    :return: the model returned by ``ModelManage.get_model``
    :raises Exception: when no dataset with this id exists
    """
    dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
    if dataset is None:
        # Fail with an explicit message instead of an AttributeError inside the loader callback.
        raise Exception("知识库设置错误,请重新设置知识库")
    return ModelManage.get_model(str(dataset_id), lambda _id: get_model(dataset.embedding_mode))
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_model_by_dataset(dataset):
    """
    Return the embedding model for an already-loaded dataset instance.

    @param dataset: dataset object exposing ``id`` and ``embedding_mode``
    :return: the model returned by ``ModelManage.get_model`` under ``str(dataset.id)``
    """
    model_key = str(dataset.id)

    def load_model(_id):
        # Invoked by ModelManage with the key; the key itself is not needed here.
        return get_model(dataset.embedding_mode)

    return ModelManage.get_model(model_key, load_model)
|
||||||
|
|||||||
@ -15,7 +15,6 @@ from functools import reduce
|
|||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from django.conf import settings
|
|
||||||
from django.contrib.postgres.fields import ArrayField
|
from django.contrib.postgres.fields import ArrayField
|
||||||
from django.core import validators
|
from django.core import validators
|
||||||
from django.db import transaction, models
|
from django.db import transaction, models
|
||||||
@ -25,7 +24,7 @@ from drf_yasg import openapi
|
|||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
|
||||||
from application.models import ApplicationDatasetMapping
|
from application.models import ApplicationDatasetMapping
|
||||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
from common.config.embedding_config import VectorStore
|
||||||
from common.db.search import get_dynamics_model, native_page_search, native_search
|
from common.db.search import get_dynamics_model, native_page_search, native_search
|
||||||
from common.db.sql_execute import select_list
|
from common.db.sql_execute import select_list
|
||||||
from common.event import ListenerManagement, SyncWebDatasetArgs
|
from common.event import ListenerManagement, SyncWebDatasetArgs
|
||||||
@ -37,7 +36,8 @@ from common.util.file_util import get_file_content
|
|||||||
from common.util.fork import ChildLink, Fork
|
from common.util.fork import ChildLink, Fork
|
||||||
from common.util.split_model import get_split_model
|
from common.util.split_model import get_split_model
|
||||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
|
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
|
||||||
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage
|
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
|
||||||
|
get_embedding_model_by_dataset_id
|
||||||
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
|
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
|
||||||
from embedding.models import SearchMode
|
from embedding.models import SearchMode
|
||||||
from setting.models import AuthOperate
|
from setting.models import AuthOperate
|
||||||
@ -206,6 +206,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
max_length=256,
|
max_length=256,
|
||||||
min_length=1)
|
min_length=1)
|
||||||
|
|
||||||
|
embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型"))
|
||||||
|
|
||||||
documents = DocumentInstanceSerializer(required=False, many=True)
|
documents = DocumentInstanceSerializer(required=False, many=True)
|
||||||
|
|
||||||
def is_valid(self, *, raise_exception=False):
|
def is_valid(self, *, raise_exception=False):
|
||||||
@ -226,6 +228,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
max_length=256,
|
max_length=256,
|
||||||
min_length=1)
|
min_length=1)
|
||||||
|
|
||||||
|
embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型"))
|
||||||
|
|
||||||
file_list = serializers.ListSerializer(required=True,
|
file_list = serializers.ListSerializer(required=True,
|
||||||
error_messages=ErrMessage.list("文件列表"),
|
error_messages=ErrMessage.list("文件列表"),
|
||||||
child=serializers.FileField(required=True,
|
child=serializers.FileField(required=True,
|
||||||
@ -296,6 +300,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
min_length=1)
|
min_length=1)
|
||||||
source_url = serializers.CharField(required=True, error_messages=ErrMessage.char("Web 根地址"), )
|
source_url = serializers.CharField(required=True, error_messages=ErrMessage.char("Web 根地址"), )
|
||||||
|
|
||||||
|
embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型"))
|
||||||
|
|
||||||
selector = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
selector = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||||||
error_messages=ErrMessage.char("选择器"))
|
error_messages=ErrMessage.char("选择器"))
|
||||||
|
|
||||||
@ -347,6 +353,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
properties={
|
properties={
|
||||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
|
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
|
||||||
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
|
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
|
||||||
|
'embedding_mode_id': openapi.Schema(type=openapi.TYPE_STRING, title="向量模型id",
|
||||||
|
description="向量模型id"),
|
||||||
'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url",
|
'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url",
|
||||||
description="web站点url"),
|
description="web站点url"),
|
||||||
'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
|
'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
|
||||||
@ -355,8 +363,9 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def post_embedding_dataset(document_list, dataset_id):
|
def post_embedding_dataset(document_list, dataset_id):
|
||||||
|
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||||
# 发送向量化事件
|
# 发送向量化事件
|
||||||
ListenerManagement.embedding_by_dataset_signal.send(dataset_id)
|
ListenerManagement.embedding_by_dataset_signal.send(dataset_id, embedding_model=model)
|
||||||
return document_list
|
return document_list
|
||||||
|
|
||||||
def save_qa(self, instance: Dict, with_valid=True):
|
def save_qa(self, instance: Dict, with_valid=True):
|
||||||
@ -365,7 +374,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
self.CreateQASerializers(data=instance).is_valid()
|
self.CreateQASerializers(data=instance).is_valid()
|
||||||
file_list = instance.get('file_list')
|
file_list = instance.get('file_list')
|
||||||
document_list = flat_map([DocumentSerializers.Create.parse_qa_file(file) for file in file_list])
|
document_list = flat_map([DocumentSerializers.Create.parse_qa_file(file) for file in file_list])
|
||||||
dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list}
|
dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list,
|
||||||
|
'embedding_mode_id': instance.get('embedding_mode_id')}
|
||||||
return self.save(dataset_instance, with_valid=True)
|
return self.save(dataset_instance, with_valid=True)
|
||||||
|
|
||||||
@valid_license(model=DataSet, count=50,
|
@valid_license(model=DataSet, count=50,
|
||||||
@ -381,7 +391,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
if QuerySet(DataSet).filter(user_id=user_id, name=instance.get('name')).exists():
|
if QuerySet(DataSet).filter(user_id=user_id, name=instance.get('name')).exists():
|
||||||
raise AppApiException(500, "知识库名称重复!")
|
raise AppApiException(500, "知识库名称重复!")
|
||||||
dataset = DataSet(
|
dataset = DataSet(
|
||||||
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id})
|
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
|
||||||
|
'embedding_mode_id': instance.get('embedding_mode_id')})
|
||||||
|
|
||||||
document_model_list = []
|
document_model_list = []
|
||||||
paragraph_model_list = []
|
paragraph_model_list = []
|
||||||
@ -452,7 +463,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
dataset = DataSet(
|
dataset = DataSet(
|
||||||
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
|
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
|
||||||
'type': Type.web,
|
'type': Type.web,
|
||||||
'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}})
|
'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector'),
|
||||||
|
'embedding_mode_id': instance.get('embedding_mode_id')}})
|
||||||
dataset.save()
|
dataset.save()
|
||||||
ListenerManagement.sync_web_dataset_signal.send(
|
ListenerManagement.sync_web_dataset_signal.send(
|
||||||
SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'),
|
SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'),
|
||||||
@ -500,6 +512,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
properties={
|
properties={
|
||||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
|
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
|
||||||
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
|
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
|
||||||
|
'embedding_mode_id': openapi.Schema(type=openapi.TYPE_STRING, title='向量模型',
|
||||||
|
description='向量模型'),
|
||||||
'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据",
|
'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据",
|
||||||
items=DocumentSerializers().Create.get_request_body_api()
|
items=DocumentSerializers().Create.get_request_body_api()
|
||||||
)
|
)
|
||||||
@ -557,12 +571,13 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
QuerySet(Document).filter(
|
QuerySet(Document).filter(
|
||||||
dataset_id=self.data.get('id'),
|
dataset_id=self.data.get('id'),
|
||||||
is_active=False)]
|
is_active=False)]
|
||||||
|
model = get_embedding_model_by_dataset_id(self.data.get('id'))
|
||||||
# 向量库检索
|
# 向量库检索
|
||||||
hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
|
hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
|
||||||
self.data.get('top_number'),
|
self.data.get('top_number'),
|
||||||
self.data.get('similarity'),
|
self.data.get('similarity'),
|
||||||
SearchMode(self.data.get('search_mode')),
|
SearchMode(self.data.get('search_mode')),
|
||||||
EmbeddingModel.get_embedding_model())
|
model)
|
||||||
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
|
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
|
||||||
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
|
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
|
||||||
return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
|
return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
|
||||||
@ -730,7 +745,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
def re_embedding(self, with_valid=True):
|
def re_embedding(self, with_valid=True):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id'))
|
model = get_embedding_model_by_dataset_id(self.data.get('id'))
|
||||||
|
ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id'), embedding_model=model)
|
||||||
|
|
||||||
def list_application(self, with_valid=True):
|
def list_application(self, with_valid=True):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
@ -769,6 +785,7 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
QuerySet(ApplicationDatasetMapping).filter(
|
QuerySet(ApplicationDatasetMapping).filter(
|
||||||
dataset_id=self.data.get('id'))]))}
|
dataset_id=self.data.get('id'))]))}
|
||||||
|
|
||||||
|
@transaction.atomic
|
||||||
def edit(self, dataset: Dict, user_id: str):
|
def edit(self, dataset: Dict, user_id: str):
|
||||||
"""
|
"""
|
||||||
修改知识库
|
修改知识库
|
||||||
@ -782,6 +799,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
raise AppApiException(500, "知识库名称重复!")
|
raise AppApiException(500, "知识库名称重复!")
|
||||||
_dataset = QuerySet(DataSet).get(id=self.data.get("id"))
|
_dataset = QuerySet(DataSet).get(id=self.data.get("id"))
|
||||||
DataSetSerializers.Edit(data=dataset).is_valid(dataset=_dataset)
|
DataSetSerializers.Edit(data=dataset).is_valid(dataset=_dataset)
|
||||||
|
if 'embedding_mode_id' in dataset:
|
||||||
|
_dataset.embedding_mode_id = dataset.get('embedding_mode_id')
|
||||||
if "name" in dataset:
|
if "name" in dataset:
|
||||||
_dataset.name = dataset.get("name")
|
_dataset.name = dataset.get("name")
|
||||||
if 'desc' in dataset:
|
if 'desc' in dataset:
|
||||||
|
|||||||
@ -41,7 +41,8 @@ from common.util.file_util import get_file_content
|
|||||||
from common.util.fork import Fork
|
from common.util.fork import Fork
|
||||||
from common.util.split_model import get_split_model
|
from common.util.split_model import get_split_model
|
||||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
|
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
|
||||||
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage
|
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
|
||||||
|
get_embedding_model_by_dataset_id
|
||||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
|
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
@ -234,12 +235,17 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
meta={})
|
meta={})
|
||||||
else:
|
else:
|
||||||
document_list.update(dataset_id=target_dataset_id)
|
document_list.update(dataset_id=target_dataset_id)
|
||||||
# 修改向量信息
|
model = None
|
||||||
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
|
if dataset.embedding_mode_id != target_dataset.embedding_mode_id:
|
||||||
[paragraph.id for paragraph in paragraph_list],
|
model = get_embedding_model_by_dataset_id(target_dataset_id)
|
||||||
target_dataset_id))
|
|
||||||
|
pid_list = [paragraph.id for paragraph in paragraph_list]
|
||||||
# 修改段落信息
|
# 修改段落信息
|
||||||
paragraph_list.update(dataset_id=target_dataset_id)
|
paragraph_list.update(dataset_id=target_dataset_id)
|
||||||
|
# 修改向量信息
|
||||||
|
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
|
||||||
|
pid_list,
|
||||||
|
target_dataset_id, model))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_target_dataset_problem(target_dataset_id: str,
|
def get_target_dataset_problem(target_dataset_id: str,
|
||||||
@ -392,7 +398,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
problem_paragraph_mapping_list) > 0 else None
|
problem_paragraph_mapping_list) > 0 else None
|
||||||
# 向量化
|
# 向量化
|
||||||
if with_embedding:
|
if with_embedding:
|
||||||
ListenerManagement.embedding_by_document_signal.send(document_id)
|
model = get_embedding_model_by_dataset_id(dataset_id=document.dataset_id)
|
||||||
|
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
|
||||||
else:
|
else:
|
||||||
document.status = Status.error
|
document.status = Status.error
|
||||||
document.save()
|
document.save()
|
||||||
@ -405,6 +412,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
class Operate(ApiMixin, serializers.Serializer):
|
class Operate(ApiMixin, serializers.Serializer):
|
||||||
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
|
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
|
||||||
"文档id"))
|
"文档id"))
|
||||||
|
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("数据集id"))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_request_params_api():
|
def get_request_params_api():
|
||||||
@ -530,7 +538,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
document_id = self.data.get("document_id")
|
document_id = self.data.get("document_id")
|
||||||
ListenerManagement.embedding_by_document_signal.send(document_id)
|
model = get_embedding_model_by_dataset_id(dataset_id=self.data.get('dataset_id'))
|
||||||
|
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
|
||||||
|
|
||||||
@transaction.atomic
|
@transaction.atomic
|
||||||
def delete(self):
|
def delete(self):
|
||||||
@ -599,8 +608,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def post_embedding(result, document_id):
|
def post_embedding(result, document_id, dataset_id):
|
||||||
ListenerManagement.embedding_by_document_signal.send(document_id)
|
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||||
|
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -646,7 +656,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
document_id = str(document_model.id)
|
document_id = str(document_model.id)
|
||||||
return DocumentSerializers.Operate(
|
return DocumentSerializers.Operate(
|
||||||
data={'dataset_id': dataset_id, 'document_id': document_id}).one(
|
data={'dataset_id': dataset_id, 'document_id': document_id}).one(
|
||||||
with_valid=True), document_id
|
with_valid=True), document_id, dataset_id
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_sync_handler(dataset_id):
|
def get_sync_handler(dataset_id):
|
||||||
@ -803,9 +813,10 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api())
|
return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def post_embedding(document_list):
|
def post_embedding(document_list, dataset_id):
|
||||||
for document_dict in document_list:
|
for document_dict in document_list:
|
||||||
ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'))
|
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||||
|
ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'), embedding_model=model)
|
||||||
return document_list
|
return document_list
|
||||||
|
|
||||||
@post(post_function=post_embedding)
|
@post(post_function=post_embedding)
|
||||||
@ -846,7 +857,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
return [],
|
return [],
|
||||||
query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
|
query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
|
||||||
return native_search(query_set, select_string=get_file_content(
|
return native_search(query_set, select_string=get_file_content(
|
||||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False),
|
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')),
|
||||||
|
with_search_one=False), dataset_id
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _batch_sync(document_id_list: List[str]):
|
def _batch_sync(document_id_list: List[str]):
|
||||||
|
|||||||
@ -20,9 +20,9 @@ from common.exception.app_exception import AppApiException
|
|||||||
from common.mixins.api_mixin import ApiMixin
|
from common.mixins.api_mixin import ApiMixin
|
||||||
from common.util.common import post
|
from common.util.common import post
|
||||||
from common.util.field_message import ErrMessage
|
from common.util.field_message import ErrMessage
|
||||||
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
|
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet
|
||||||
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
|
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
|
||||||
ProblemParagraphManage
|
ProblemParagraphManage, get_embedding_model_by_dataset_id, get_embedding_model_by_dataset
|
||||||
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
|
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
|
||||||
from embedding.models import SourceType
|
from embedding.models import SourceType
|
||||||
|
|
||||||
@ -132,6 +132,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
paragraph_id=self.data.get('paragraph_id'),
|
paragraph_id=self.data.get('paragraph_id'),
|
||||||
dataset_id=self.data.get('dataset_id'))
|
dataset_id=self.data.get('dataset_id'))
|
||||||
problem_paragraph_mapping.save()
|
problem_paragraph_mapping.save()
|
||||||
|
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
|
||||||
if with_embedding:
|
if with_embedding:
|
||||||
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
||||||
'is_active': True,
|
'is_active': True,
|
||||||
@ -140,7 +141,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
'document_id': self.data.get('document_id'),
|
'document_id': self.data.get('document_id'),
|
||||||
'paragraph_id': self.data.get('paragraph_id'),
|
'paragraph_id': self.data.get('paragraph_id'),
|
||||||
'dataset_id': self.data.get('dataset_id'),
|
'dataset_id': self.data.get('dataset_id'),
|
||||||
})
|
}, embedding_model=model)
|
||||||
|
|
||||||
return ProblemSerializers.Operate(
|
return ProblemSerializers.Operate(
|
||||||
data={'dataset_id': self.data.get('dataset_id'),
|
data={'dataset_id': self.data.get('dataset_id'),
|
||||||
@ -227,6 +228,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
problem_id=problem.id)
|
problem_id=problem.id)
|
||||||
problem_paragraph_mapping.save()
|
problem_paragraph_mapping.save()
|
||||||
if with_embedding:
|
if with_embedding:
|
||||||
|
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
|
||||||
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
||||||
'is_active': True,
|
'is_active': True,
|
||||||
'source_type': SourceType.PROBLEM,
|
'source_type': SourceType.PROBLEM,
|
||||||
@ -234,7 +236,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
'document_id': self.data.get('document_id'),
|
'document_id': self.data.get('document_id'),
|
||||||
'paragraph_id': self.data.get('paragraph_id'),
|
'paragraph_id': self.data.get('paragraph_id'),
|
||||||
'dataset_id': self.data.get('dataset_id'),
|
'dataset_id': self.data.get('dataset_id'),
|
||||||
})
|
}, embedding_model=model)
|
||||||
|
|
||||||
def un_association(self, with_valid=True):
|
def un_association(self, with_valid=True):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
@ -336,10 +338,11 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
# 修改mapping
|
# 修改mapping
|
||||||
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
|
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
|
||||||
['document_id'])
|
['document_id'])
|
||||||
|
|
||||||
# 修改向量段落信息
|
# 修改向量段落信息
|
||||||
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
|
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
|
||||||
[paragraph.id for paragraph in paragraph_list],
|
[paragraph.id for paragraph in paragraph_list],
|
||||||
target_document_id, target_dataset_id))
|
target_document_id, target_dataset_id, target_embedding_model=None))
|
||||||
# 修改段落信息
|
# 修改段落信息
|
||||||
paragraph_list.update(document_id=target_document_id)
|
paragraph_list.update(document_id=target_document_id)
|
||||||
# 不同数据集迁移
|
# 不同数据集迁移
|
||||||
@ -366,12 +369,19 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
# 修改mapping
|
# 修改mapping
|
||||||
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
|
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
|
||||||
['problem_id', 'dataset_id', 'document_id'])
|
['problem_id', 'dataset_id', 'document_id'])
|
||||||
# 修改向量段落信息
|
target_dataset = QuerySet(DataSet).filter(id=target_dataset_id).first()
|
||||||
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
|
dataset = QuerySet(DataSet).filter(id=dataset_id).first()
|
||||||
[paragraph.id for paragraph in paragraph_list],
|
embedding_model = None
|
||||||
target_document_id, target_dataset_id))
|
if target_dataset.embedding_mode_id != dataset.embedding_mode_id:
|
||||||
|
embedding_model = get_embedding_model_by_dataset(target_dataset)
|
||||||
|
pid_list = [paragraph.id for paragraph in paragraph_list]
|
||||||
# 修改段落信息
|
# 修改段落信息
|
||||||
paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id)
|
paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id)
|
||||||
|
# 修改向量段落信息
|
||||||
|
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
|
||||||
|
pid_list,
|
||||||
|
target_document_id, target_dataset_id, target_embedding_model=embedding_model))
|
||||||
|
|
||||||
update_document_char_length(document_id)
|
update_document_char_length(document_id)
|
||||||
update_document_char_length(target_document_id)
|
update_document_char_length(target_document_id)
|
||||||
|
|
||||||
@ -454,13 +464,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
raise AppApiException(500, "段落id不存在")
|
raise AppApiException(500, "段落id不存在")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def post_embedding(paragraph, instance):
|
def post_embedding(paragraph, instance, dataset_id):
|
||||||
if 'is_active' in instance and instance.get('is_active') is not None:
|
if 'is_active' in instance and instance.get('is_active') is not None:
|
||||||
s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
|
s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
|
||||||
'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
|
'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
|
||||||
s.send(paragraph.get('id'))
|
s.send(paragraph.get('id'))
|
||||||
else:
|
else:
|
||||||
ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'))
|
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||||
|
ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'), embedding_model=model)
|
||||||
return paragraph
|
return paragraph
|
||||||
|
|
||||||
@post(post_embedding)
|
@post(post_embedding)
|
||||||
@ -508,7 +519,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
|
|
||||||
_paragraph.save()
|
_paragraph.save()
|
||||||
update_document_char_length(self.data.get('document_id'))
|
update_document_char_length(self.data.get('document_id'))
|
||||||
return self.one(), instance
|
return self.one(), instance, self.data.get('dataset_id')
|
||||||
|
|
||||||
def get_problem_list(self):
|
def get_problem_list(self):
|
||||||
ProblemParagraphMapping(ProblemParagraphMapping)
|
ProblemParagraphMapping(ProblemParagraphMapping)
|
||||||
@ -582,7 +593,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
# 修改长度
|
# 修改长度
|
||||||
update_document_char_length(document_id)
|
update_document_char_length(document_id)
|
||||||
if with_embedding:
|
if with_embedding:
|
||||||
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id))
|
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||||
|
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id), embedding_model=model)
|
||||||
return ParagraphSerializers.Operate(
|
return ParagraphSerializers.Operate(
|
||||||
data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
|
data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
|
||||||
with_valid=True)
|
with_valid=True)
|
||||||
|
|||||||
@ -20,7 +20,8 @@ from common.event import ListenerManagement, UpdateProblemArgs
|
|||||||
from common.mixins.api_mixin import ApiMixin
|
from common.mixins.api_mixin import ApiMixin
|
||||||
from common.util.field_message import ErrMessage
|
from common.util.field_message import ErrMessage
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from dataset.models import Problem, Paragraph, ProblemParagraphMapping
|
from dataset.models import Problem, Paragraph, ProblemParagraphMapping, DataSet
|
||||||
|
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
@ -157,6 +158,8 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
|
|||||||
content = instance.get('content')
|
content = instance.get('content')
|
||||||
problem = QuerySet(Problem).filter(id=problem_id,
|
problem = QuerySet(Problem).filter(id=problem_id,
|
||||||
dataset_id=dataset_id).first()
|
dataset_id=dataset_id).first()
|
||||||
|
QuerySet(DataSet).filter(id=dataset_id)
|
||||||
problem.content = content
|
problem.content = content
|
||||||
problem.save()
|
problem.save()
|
||||||
ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content))
|
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||||
|
ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content, model))
|
||||||
|
|||||||
@ -52,7 +52,6 @@ class Dataset(APIView):
|
|||||||
@action(methods=['POST'], detail=False)
|
@action(methods=['POST'], detail=False)
|
||||||
@swagger_auto_schema(operation_summary="创建QA知识库",
|
@swagger_auto_schema(operation_summary="创建QA知识库",
|
||||||
operation_id="创建QA知识库",
|
operation_id="创建QA知识库",
|
||||||
|
|
||||||
manual_parameters=DataSetSerializers.Create.CreateQASerializers.get_request_params_api(),
|
manual_parameters=DataSetSerializers.Create.CreateQASerializers.get_request_params_api(),
|
||||||
responses=get_api_response(
|
responses=get_api_response(
|
||||||
DataSetSerializers.Create.CreateQASerializers.get_response_body_api()),
|
DataSetSerializers.Create.CreateQASerializers.get_response_body_api()),
|
||||||
|
|||||||
@ -10,9 +10,8 @@ import threading
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
|
|
||||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
|
|
||||||
from common.config.embedding_config import EmbeddingModel
|
|
||||||
from common.util.common import sub_array
|
from common.util.common import sub_array
|
||||||
from embedding.models import SourceType, SearchMode
|
from embedding.models import SourceType, SearchMode
|
||||||
|
|
||||||
@ -51,7 +50,7 @@ class BaseVectorStore(ABC):
|
|||||||
|
|
||||||
def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
||||||
is_active: bool,
|
is_active: bool,
|
||||||
embedding=None):
|
embedding: Embeddings):
|
||||||
"""
|
"""
|
||||||
插入向量数据
|
插入向量数据
|
||||||
:param source_id: 资源id
|
:param source_id: 资源id
|
||||||
@ -64,13 +63,10 @@ class BaseVectorStore(ABC):
|
|||||||
:param paragraph_id 段落id
|
:param paragraph_id 段落id
|
||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if embedding is None:
|
|
||||||
embedding = EmbeddingModel.get_embedding_model()
|
|
||||||
self.save_pre_handler()
|
self.save_pre_handler()
|
||||||
self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)
|
self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)
|
||||||
|
|
||||||
def batch_save(self, data_list: List[Dict], embedding=None):
|
def batch_save(self, data_list: List[Dict], embedding: Embeddings):
|
||||||
# 获取锁
|
# 获取锁
|
||||||
lock.acquire()
|
lock.acquire()
|
||||||
try:
|
try:
|
||||||
@ -80,8 +76,6 @@ class BaseVectorStore(ABC):
|
|||||||
:param embedding: 向量化处理器
|
:param embedding: 向量化处理器
|
||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
if embedding is None:
|
|
||||||
embedding = EmbeddingModel.get_embedding_model()
|
|
||||||
self.save_pre_handler()
|
self.save_pre_handler()
|
||||||
result = sub_array(data_list)
|
result = sub_array(data_list)
|
||||||
for child_array in result:
|
for child_array in result:
|
||||||
@ -94,17 +88,17 @@ class BaseVectorStore(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
||||||
is_active: bool,
|
is_active: bool,
|
||||||
embedding: HuggingFaceEmbeddings):
|
embedding: Embeddings):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
|
def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||||
exclude_paragraph_list: list[str],
|
exclude_paragraph_list: list[str],
|
||||||
is_active: bool,
|
is_active: bool,
|
||||||
embedding: HuggingFaceEmbeddings):
|
embedding: Embeddings):
|
||||||
if dataset_id_list is None or len(dataset_id_list) == 0:
|
if dataset_id_list is None or len(dataset_id_list) == 0:
|
||||||
return []
|
return []
|
||||||
embedding_query = embedding.embed_query(query_text)
|
embedding_query = embedding.embed_query(query_text)
|
||||||
@ -123,7 +117,7 @@ class BaseVectorStore(ABC):
|
|||||||
def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
|
def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
|
||||||
similarity: float,
|
similarity: float,
|
||||||
search_mode: SearchMode,
|
search_mode: SearchMode,
|
||||||
embedding: HuggingFaceEmbeddings):
|
embedding: Embeddings):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -142,14 +136,6 @@ class BaseVectorStore(ABC):
|
|||||||
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
|
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def embed_documents(self, text_list: List[str]):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def embed_query(self, text: str):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete_by_dataset_id(self, dataset_id: str):
|
def delete_by_dataset_id(self, dataset_id: str):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -13,9 +13,8 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
|
|
||||||
from common.config.embedding_config import EmbeddingModel
|
|
||||||
from common.db.search import generate_sql_by_query_dict
|
from common.db.search import generate_sql_by_query_dict
|
||||||
from common.db.sql_execute import select_list
|
from common.db.sql_execute import select_list
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
@ -33,14 +32,6 @@ class PGVector(BaseVectorStore):
|
|||||||
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
|
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
|
||||||
QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)
|
QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)
|
||||||
|
|
||||||
def embed_documents(self, text_list: List[str]):
|
|
||||||
embedding = EmbeddingModel.get_embedding_model()
|
|
||||||
return embedding.embed_documents(text_list)
|
|
||||||
|
|
||||||
def embed_query(self, text: str):
|
|
||||||
embedding = EmbeddingModel.get_embedding_model()
|
|
||||||
return embedding.embed_query(text)
|
|
||||||
|
|
||||||
def vector_is_create(self) -> bool:
|
def vector_is_create(self) -> bool:
|
||||||
# 项目启动默认是创建好的 不需要再创建
|
# 项目启动默认是创建好的 不需要再创建
|
||||||
return True
|
return True
|
||||||
@ -50,7 +41,7 @@ class PGVector(BaseVectorStore):
|
|||||||
|
|
||||||
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
||||||
is_active: bool,
|
is_active: bool,
|
||||||
embedding: HuggingFaceEmbeddings):
|
embedding: Embeddings):
|
||||||
text_embedding = embedding.embed_query(text)
|
text_embedding = embedding.embed_query(text)
|
||||||
embedding = Embedding(id=uuid.uuid1(),
|
embedding = Embedding(id=uuid.uuid1(),
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
@ -64,7 +55,7 @@ class PGVector(BaseVectorStore):
|
|||||||
embedding.save()
|
embedding.save()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
|
def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
|
||||||
texts = [row.get('text') for row in text_list]
|
texts = [row.get('text') for row in text_list]
|
||||||
embeddings = embedding.embed_documents(texts)
|
embeddings = embedding.embed_documents(texts)
|
||||||
embedding_list = [Embedding(id=uuid.uuid1(),
|
embedding_list = [Embedding(id=uuid.uuid1(),
|
||||||
@ -83,7 +74,7 @@ class PGVector(BaseVectorStore):
|
|||||||
def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
|
def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
|
||||||
similarity: float,
|
similarity: float,
|
||||||
search_mode: SearchMode,
|
search_mode: SearchMode,
|
||||||
embedding: HuggingFaceEmbeddings):
|
embedding: Embeddings):
|
||||||
if dataset_id_list is None or len(dataset_id_list) == 0:
|
if dataset_id_list is None or len(dataset_id_list) == 0:
|
||||||
return []
|
return []
|
||||||
exclude_dict = {}
|
exclude_dict = {}
|
||||||
|
|||||||
46
apps/setting/migrations/0005_model_permission_type.py
Normal file
46
apps/setting/migrations/0005_model_permission_type.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
# Generated by Django 4.2.13 on 2024-07-15 15:23
|
||||||
|
import json
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
from django.db.models import QuerySet
|
||||||
|
|
||||||
|
from common.util.rsa_util import rsa_long_encrypt
|
||||||
|
from setting.models import Status, PermissionType
|
||||||
|
from smartdoc.const import CONFIG
|
||||||
|
|
||||||
|
default_embedding_model_id = '42f63a3d-427e-11ef-b3ec-a8a1595801ab'
|
||||||
|
|
||||||
|
|
||||||
|
def save_default_embedding_model(apps, schema_editor):
    """
    Forward data migration: insert the default embedding model row.

    The row is created with the fixed id ``default_embedding_model_id``, the
    name ``maxkb-embedding``, provider ``model_local_provider`` and public
    permission. Its model name and cache folder are read from ``CONFIG``
    (``EMBEDDING_MODEL_NAME`` / ``EMBEDDING_MODEL_PATH``).

    @param apps: migration app registry used to resolve the historical model
    @param schema_editor: unused; required by the RunPython callback signature
    """
    model_cls = apps.get_model('setting', 'Model')
    embedding_cache_dir = CONFIG.get('EMBEDDING_MODEL_PATH')
    embedding_model_name = CONFIG.get('EMBEDDING_MODEL_NAME')
    # The credential is stored encrypted, so serialise it to JSON first.
    credential_json = json.dumps({'cache_folder': embedding_cache_dir})
    row = model_cls(
        id=default_embedding_model_id,
        name='maxkb-embedding',
        status=Status.SUCCESS,
        model_type="EMBEDDING",
        model_name=embedding_model_name,
        # NOTE(review): hard-coded owner id — presumably a built-in account; confirm.
        user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab',
        provider='model_local_provider',
        credential=rsa_long_encrypt(credential_json),
        meta={},
        permission_type=PermissionType.PUBLIC,
    )
    row.save()
|
||||||
|
|
||||||
|
|
||||||
|
def reverse_code_embedding_model(apps, schema_editor):
    """
    Reverse data migration: delete the row whose id is ``default_embedding_model_id``.

    @param apps: migration app registry used to resolve the historical model
    @param schema_editor: unused; required by the RunPython callback signature
    """
    model_cls = apps.get_model('setting', 'Model')
    default_rows = QuerySet(model_cls).filter(id=default_embedding_model_id)
    default_rows.delete()
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
    """Add the ``permission_type`` column to ``model`` and then seed the default embedding model."""

    dependencies = [('setting', '0004_alter_model_credential')]

    operations = [
        # Schema change first: the data step below sets permission_type on the new row.
        migrations.AddField(
            model_name='model',
            name='permission_type',
            field=models.CharField(
                choices=[('PUBLIC', '公开'), ('PRIVATE', '私有')],
                default='PRIVATE',
                max_length=20,
                verbose_name='权限类型',
            ),
        ),
        migrations.RunPython(save_default_embedding_model, reverse_code_embedding_model),
    ]
|
||||||
@ -22,6 +22,13 @@ class Status(models.TextChoices):
|
|||||||
|
|
||||||
DOWNLOAD = "DOWNLOAD", '下载中'
|
DOWNLOAD = "DOWNLOAD", '下载中'
|
||||||
|
|
||||||
|
PAUSE_DOWNLOAD = "PAUSE_DOWNLOAD", '暂停下载'
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionType(models.TextChoices):
    """Permission choices for a model record: ``PUBLIC`` or ``PRIVATE``."""

    PUBLIC = 'PUBLIC', '公开'
    PRIVATE = 'PRIVATE', '私有'
|
||||||
|
|
||||||
|
|
||||||
class Model(AppModelMixin):
|
class Model(AppModelMixin):
|
||||||
"""
|
"""
|
||||||
@ -46,6 +53,9 @@ class Model(AppModelMixin):
|
|||||||
|
|
||||||
meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict)
|
meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict)
|
||||||
|
|
||||||
|
permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices,
|
||||||
|
default=PermissionType.PRIVATE)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
db_table = "model"
|
db_table = "model"
|
||||||
unique_together = ['name', 'user_id']
|
unique_together = ['name', 'user_id']
|
||||||
|
|||||||
@ -6,3 +6,85 @@
|
|||||||
@date:2023/10/31 17:16
|
@date:2023/10/31 17:16
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
|
import json
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from common.util.rsa_util import rsa_long_decrypt
|
||||||
|
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_(provider, model_type, model_name, credential):
    """
    Build a model instance from its provider key and encrypted credential.

    The credential is decrypted with ``rsa_long_decrypt`` and parsed as JSON
    before being handed to the provider; ``streaming=True`` is always passed.

    @param provider:   provider key
    @param model_type: model type
    @param model_name: model name
    @param credential: encrypted credential string
    @return: the object returned by the provider's ``get_model``
    """
    build_model = get_provider(provider).get_model
    credential_dict = json.loads(rsa_long_decrypt(credential))
    return build_model(model_type, model_name, credential_dict, streaming=True)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(model):
    """
    Build a model instance from a stored model record.

    @param model: object exposing ``provider``, ``model_type``, ``model_name``
                  and ``credential`` (per the original docstring, a database
                  Model instance)
    @return: the model instance produced by ``get_model_``
    """
    return get_model_(
        provider=model.provider,
        model_type=model.model_type,
        model_name=model.model_name,
        credential=model.credential,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider(provider):
    """
    Look up a provider instance by its key.

    @param provider: provider key; must be a member name of ``ModelProvideConstants``
    @return: the ``value`` of the matching enum member
    """
    provider_member = ModelProvideConstants[provider]
    return provider_member.value
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_list(provider, model_type):
    """
    Return a provider's model list for one model type.

    @param provider:   provider key
    @param model_type: model type
    @return: whatever the provider's ``get_model_list`` returns
    """
    provider_instance = get_provider(provider)
    return provider_instance.get_model_list(model_type)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_credential(provider, model_type, model_name):
    """
    Return a provider's credential object for one model.

    @param provider:   provider key
    @param model_type: model type
    @param model_name: model name
    @return: whatever the provider's ``get_model_credential`` returns
    """
    provider_instance = get_provider(provider)
    return provider_instance.get_model_credential(model_type, model_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_type_list(provider):
    """
    Return the model types a provider offers.

    @param provider: provider key
    @return: whatever the provider's ``get_model_type_list`` returns
    """
    provider_instance = get_provider(provider)
    return provider_instance.get_model_type_list()
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
    """
    Validate a model's credential data through its provider.

    @param provider:         provider key
    @param model_type:       model type
    @param model_name:       model name
    @param model_credential: credential data to validate
    @param raise_exception:  forwarded to the provider; per the original
                             docstring, whether a failure should raise
    @return: the provider's validation result (documented as True|False)
    """
    provider_instance = get_provider(provider)
    return provider_instance.is_valid_credential(
        model_type,
        model_name,
        model_credential,
        raise_exception,
    )
|
||||||
|
|||||||
@ -61,7 +61,7 @@ class IModelProvider(ABC):
|
|||||||
def get_model_list(self, model_type):
|
def get_model_list(self, model_type):
|
||||||
if model_type is None:
|
if model_type is None:
|
||||||
raise AppApiException(500, '模型类型不能为空')
|
raise AppApiException(500, '模型类型不能为空')
|
||||||
return self.get_model_info_manage().get_model_list()
|
return self.get_model_info_manage().get_model_list_by_model_type(model_type)
|
||||||
|
|
||||||
def get_model_credential(self, model_type, model_name):
|
def get_model_credential(self, model_type, model_name):
|
||||||
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
||||||
@ -191,6 +191,9 @@ class ModelInfoManage:
|
|||||||
def get_model_list(self):
    """Return every registered model converted with its ``to_dict()``."""
    model_dicts = []
    for model_info in self.model_list:
        model_dicts.append(model_info.to_dict())
    return model_dicts
|
||||||
|
|
||||||
|
def get_model_list_by_model_type(self, model_type):
    """Return the ``to_dict()`` form of each registered model whose ``model_type`` equals the given one."""
    matching = []
    for model_info in self.model_list:
        if model_info.model_type == model_type:
            matching.append(model_info.to_dict())
    return matching
|
||||||
|
|
||||||
def get_model_type_list(self):
    """
    List the model types that have at least one registered model.

    Each entry is ``{'key': <type message>, 'value': <type code>}``, taken from
    the ``ModelTypeConst`` member's value; types without any model are omitted.
    """
    type_options = []
    for _type in ModelTypeConst:
        # model.model_type is compared against the enum member's name.
        if any(model.model_type == _type.name for model in self.model_list):
            type_options.append({'key': _type.value.get('message'), 'value': _type.value.get('code')})
    return type_options
|
||||||
|
|||||||
@ -18,6 +18,7 @@ from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import
|
|||||||
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
|
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
|
||||||
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
|
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
|
||||||
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
|
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
|
||||||
|
from setting.models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider
|
||||||
|
|
||||||
|
|
||||||
class ModelProvideConstants(Enum):
|
class ModelProvideConstants(Enum):
|
||||||
@ -31,3 +32,4 @@ class ModelProvideConstants(Enum):
|
|||||||
model_xf_provider = XunFeiModelProvider()
|
model_xf_provider = XunFeiModelProvider()
|
||||||
model_deepseek_provider = DeepSeekModelProvider()
|
model_deepseek_provider = DeepSeekModelProvider()
|
||||||
model_gemini_provider = GeminiModelProvider()
|
model_gemini_provider = GeminiModelProvider()
|
||||||
|
model_local_provider = LocalModelProvider()
|
||||||
|
|||||||
@ -0,0 +1,8 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: __init__.py
|
||||||
|
@date:2024/7/10 17:48
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
@ -0,0 +1,45 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: embedding.py
|
||||||
|
@date:2024/7/11 11:06
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
|
||||||
|
|
||||||
|
|
||||||
|
class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for the local embedding model."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """
        Check the credential by building the model and embedding a short probe text.

        Returns True on success. On a missing field or a failed probe it returns
        False, or raises AppApiException when ``raise_exception`` is set. An
        unsupported model type always raises, regardless of ``raise_exception``.
        """
        if model_type != 'EMBEDDING':
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        for required_key in ['cache_folder']:
            if required_key in model_credential:
                continue
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{required_key} 字段为必填字段')
        try:
            embedding_model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential)
            embedding_model.embed_query('你好')
        except AppApiException:
            # Application errors propagate unchanged, even when raise_exception is False.
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return the credential dict unchanged; no field is masked here."""
        return model

    # Form field labelled '模型目录'; its value is read as ``cache_folder``.
    cache_folder = forms.TextInputField('模型目录', required=True)
|
||||||
@ -0,0 +1 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" t="1720668342208" class="icon" viewBox="0 0 1024 1024" version="1.1" p-id="9052" width="100%" height="100%"><path d="M512.2 475.7c-8.2-0.3-16.1-2.4-23.4-6.1L192.6 330.2c-24.9-11.1-36.1-40.3-25-65.2 5-11.2 13.9-20.1 25-25l281.5-133.2a89.43 89.43 0 0 1 76 0L831.7 240c24.9 11.1 36.1 40.3 25 65.2-5 11.2-13.9 20.1-25 25L535.5 469.5c-7.2 3.8-15.2 5.9-23.3 6.2z m-76.5 452.5c-7.6 0-15.1-1.9-21.8-5.5L146.3 797.2c-17-8.9-27.5-26.5-27.3-45.6v-320c0.1-18 9.7-34.5 25.1-43.7 14.3-8.1 31.8-8.1 46.1 0l267.1 125.4c16.1 8.7 26.4 25.4 27.1 43.7v320.4c-0.2 17.9-9.6 34.4-24.9 43.7-7.2 4.3-15.4 6.8-23.8 7.1z m152.9 0c-8.3 0-16.5-2.2-23.8-6.3-15.3-9.3-24.7-25.8-24.9-43.7V556.9c0.4-18.2 10.4-34.8 26.2-43.7L835 387c14.2-7.5 31.4-7.1 45.2 1.1 15.5 9.1 25 25.7 25.1 43.7v319.8c0.4 18.9-9.7 36.5-26.2 45.6L610.5 922.8c-6.8 3.6-14.3 5.5-21.9 5.4z" p-id="9053"/></svg>
|
||||||
|
After Width: | Height: | Size: 931 B |
@ -0,0 +1,38 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: local_model_provider.py
|
||||||
|
@date:2024/04/19 13:5
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.util.file_util import get_file_content
|
||||||
|
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||||
|
ModelInfoManage
|
||||||
|
from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential
|
||||||
|
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
|
||||||
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
# The single embedding model shipped with the local provider (empty description).
embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
                                            LocalEmbeddingCredential(), LocalEmbedding)

# Registry for this provider: the model above is both listed and set as a default.
model_info_manage = (ModelInfoManage.builder().append_model_info(embedding_text2vec_base_chinese)
                     .append_default_model_info(embedding_text2vec_base_chinese)
                     .build())


class LocalModelProvider(IModelProvider):
    """Model provider registered under the key 'model_local_provider'."""

    def get_model_info_manage(self):
        """Return the module-level model registry."""
        return model_info_manage

    def get_model_provide_info(self):
        """Build the provider descriptor: key, display name and icon file content."""
        icon_path = os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl',
                                 'local_model_provider', 'icon', 'local_icon_svg')
        icon_content = get_file_content(icon_path)
        return ModelProvideInfo(provider='model_local_provider', name='本地模型', icon=icon_content)
|
||||||
@ -0,0 +1,22 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: embedding.py
|
||||||
|
@date:2024/7/11 14:06
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
    """HuggingFace embedding model constructed from a stored credential dict."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """
        Build a LocalEmbedding for ``model_name``.

        ``cache_folder`` and ``device`` are read from ``model_credential``
        (None when absent); embeddings are normalized. ``model_type`` and the
        extra ``**model_kwargs`` are accepted but not used here.
        """
        init_args = {
            'model_name': model_name,
            'cache_folder': model_credential.get('cache_folder'),
            'model_kwargs': {'device': model_credential.get('device')},
            'encode_kwargs': {'normalize_embeddings': True},
        }
        return LocalEmbedding(**init_args)
|
||||||
@ -0,0 +1,45 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: embedding.py
|
||||||
|
@date:2024/7/12 15:10
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaEmbeddingModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Ollama embedding models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """
        Validate an Ollama embedding credential.

        Checks that the provider supports ``model_type``, that the model list can
        be fetched from ``api_base``, that ``model_name`` is in that list (with or
        without a ":latest" suffix), and finally embeds a short probe text.

        Returns True on success. Every failure raises; ``raise_exception`` is
        accepted for interface compatibility but is not consulted.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        try:
            model_list = provider.get_base_model_list(model_credential.get('api_base'))
        except Exception as e:
            # Chain the original error so the real cause stays in the traceback.
            raise AppApiException(ValidCode.valid_error.value, "API 域名无效") from e
        listed_models = model_list.get('models')
        if listed_models is None:
            listed_models = []
        exist = []
        for listed in listed_models:
            # An entry without a 'model' field cannot match; skip it instead of
            # failing with AttributeError on None.replace().
            listed_name = listed.get('model')
            if listed_name is None:
                continue
            if listed_name == model_name or listed_name.replace(":latest", "") == model_name:
                exist.append(listed)
        if len(exist) == 0:
            # NOTE(review): this passes the enum member itself, while the other raises
            # in this class pass ``.value`` — confirm against the handler that
            # inspects the code before changing it.
            raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
        model = provider.get_model(model_type, model_name, model_credential)
        model.embed_query('你好')
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        """Return the credential dict unchanged; no field is masked here."""
        return model_info

    def build_model(self, model_info: Dict[str, object]):
        """Require the 'model' key in ``model_info`` (AppApiException 500 otherwise) and return self."""
        for key in ['model']:
            if key not in model_info:
                raise AppApiException(500, f'{key} 字段为必填字段')
        return self

    # Form field labelled 'API 域名'; its value is read as ``api_base``.
    api_base = forms.TextInputField('API 域名', required=True)
|
||||||
@ -0,0 +1,48 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: embedding.py
|
||||||
|
@date:2024/7/12 15:02
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from langchain_community.embeddings import OllamaEmbeddings
|
||||||
|
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaEmbedding(MaxKBBaseModel, OllamaEmbeddings):
    """Ollama embedding model constructed from a stored credential dict."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an OllamaEmbedding for ``model_name`` using ``api_base`` from the credential as base URL."""
        api_base = model_credential.get('api_base')
        return OllamaEmbedding(model=model_name, base_url=api_base)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents using an Ollama deployed embedding model.

        Each text is passed to ``_embed`` as-is, with nothing added to it.
        NOTE(review): presumably this override exists to bypass prefixes added by
        the base class — confirm against the base implementation.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return self._embed([f"{text}" for text in texts])

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using an Ollama deployed embedding model.

        The text is passed to ``_embed`` as-is, with nothing added to it.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        vectors = self._embed([f"{text}"])
        return vectors[0]
|
||||||
@ -20,7 +20,9 @@ from common.forms import BaseForm
|
|||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
|
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
|
||||||
BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode, ModelInfoManage
|
BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode, ModelInfoManage
|
||||||
|
from setting.models_provider.impl.ollama_model_provider.credential.embedding import OllamaEmbeddingModelCredential
|
||||||
from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential
|
from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential
|
||||||
|
from setting.models_provider.impl.ollama_model_provider.model.embedding import OllamaEmbedding
|
||||||
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
|
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
@ -88,14 +90,25 @@ model_info_list = [
|
|||||||
ModelInfo(
|
ModelInfo(
|
||||||
'phi3',
|
'phi3',
|
||||||
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
|
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||||
|
]
|
||||||
|
ollama_embedding_model_credential = OllamaEmbeddingModelCredential()
|
||||||
|
embedding_model_info = [
|
||||||
|
ModelInfo(
|
||||||
|
'nomic-embed-text',
|
||||||
|
'一个具有大令牌上下文窗口的高性能开放嵌入模型。',
|
||||||
|
ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding),
|
||||||
]
|
]
|
||||||
|
|
||||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_model_info_list(
|
||||||
|
embedding_model_info).append_default_model_info(
|
||||||
ModelInfo(
|
ModelInfo(
|
||||||
'phi3',
|
'phi3',
|
||||||
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
|
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
|
||||||
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).build()
|
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).append_default_model_info(ModelInfo(
|
||||||
|
'nomic-embed-text',
|
||||||
|
'一个具有大令牌上下文窗口的高性能开放嵌入模型。',
|
||||||
|
ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding), ).build()
|
||||||
|
|
||||||
|
|
||||||
def get_base_url(url: str):
|
def get_base_url(url: str):
|
||||||
@ -139,7 +152,6 @@ def convert(response_stream) -> Iterator[DownModelChunk]:
|
|||||||
temp = ""
|
temp = ""
|
||||||
|
|
||||||
if len(temp) > 0:
|
if len(temp) > 0:
|
||||||
print(temp)
|
|
||||||
rows = [t for t in temp.split("\n") if len(t) > 0]
|
rows = [t for t in temp.split("\n") if len(t) > 0]
|
||||||
for row in rows:
|
for row in rows:
|
||||||
yield convert_to_down_model_chunk(row, index)
|
yield convert_to_down_model_chunk(row, index)
|
||||||
@ -154,9 +166,6 @@ class OllamaModelProvider(IModelProvider):
|
|||||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon',
|
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon',
|
||||||
'ollama_icon_svg')))
|
'ollama_icon_svg')))
|
||||||
|
|
||||||
def get_dialogue_number(self):
|
|
||||||
return 2
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_base_model_list(api_base):
|
def get_base_model_list(api_base):
|
||||||
base_url = get_base_url(api_base)
|
base_url = get_base_url(api_base)
|
||||||
@ -165,7 +174,7 @@ class OllamaModelProvider(IModelProvider):
|
|||||||
return r.json()
|
return r.json()
|
||||||
|
|
||||||
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
|
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
|
||||||
api_base = model_credential.get('api_base')
|
api_base = model_credential.get('api_base', '')
|
||||||
base_url = get_base_url(api_base)
|
base_url = get_base_url(api_base)
|
||||||
r = requests.request(
|
r = requests.request(
|
||||||
method="POST",
|
method="POST",
|
||||||
|
|||||||
@ -0,0 +1,46 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: embedding.py
|
||||||
|
@date:2024/7/12 16:45
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for OpenAI embedding models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True):
        """
        Check the credential by building the model and embedding a short probe text.

        Returns True on success. On a missing field or a failed probe it returns
        False, or raises AppApiException when ``raise_exception`` is set (the
        default here). An unsupported model type always raises.
        """
        supported_types = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, supported_types))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for required_key in ['api_base', 'api_key']:
            if required_key in model_credential:
                continue
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{required_key} 字段为必填字段')
        try:
            embedding_model = provider.get_model(model_type, model_name, model_credential)
            embedding_model.embed_query('你好')
        except AppApiException:
            # Application errors propagate unchanged, even when raise_exception is False.
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of the credential dict with 'api_key' replaced by its ``encryption()`` result."""
        masked = dict(model)
        masked['api_key'] = super().encryption(model.get('api_key', ''))
        return masked

    # Form fields; their values are read as ``api_base`` and ``api_key``.
    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
@ -0,0 +1,23 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: embedding.py
|
||||||
|
@date:2024/7/12 17:44
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_community.embeddings import OpenAIEmbeddings
|
||||||
|
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
    """OpenAI embedding model constructed from a stored credential dict."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """
        Build an OpenAIEmbeddingModel for ``model_name``.

        ``api_key`` and ``api_base`` are read from ``model_credential``;
        ``model_type`` and the extra ``**model_kwargs`` are not used here.
        """
        init_args = {
            'api_key': model_credential.get('api_key'),
            'model': model_name,
            'openai_api_base': model_credential.get('api_base'),
        }
        return OpenAIEmbeddingModel(**init_args)
|
||||||
@ -11,7 +11,9 @@ import os
|
|||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
|
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
|
||||||
ModelTypeConst, ModelInfoManage
|
ModelTypeConst, ModelInfoManage
|
||||||
|
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
|
||||||
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
||||||
|
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
||||||
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
|
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
@ -58,11 +60,17 @@ model_info_list = [
|
|||||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||||
OpenAIChatModel)
|
OpenAIChatModel)
|
||||||
]
|
]
|
||||||
|
open_ai_embedding_credential = OpenAIEmbeddingCredential()
|
||||||
|
model_info_embedding_list = [
|
||||||
|
ModelInfo('text-embedding-ada-002', '',
|
||||||
|
ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
|
||||||
|
OpenAIEmbeddingModel)]
|
||||||
|
|
||||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
||||||
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||||
openai_llm_model_credential, OpenAIChatModel
|
openai_llm_model_credential, OpenAIChatModel
|
||||||
)).build()
|
)).append_model_info_list(model_info_embedding_list).append_default_model_info(
|
||||||
|
model_info_embedding_list[0]).build()
|
||||||
|
|
||||||
|
|
||||||
class OpenAIModelProvider(IModelProvider):
|
class OpenAIModelProvider(IModelProvider):
|
||||||
|
|||||||
@ -7,12 +7,14 @@
|
|||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from django.db.models import QuerySet
|
from django.core import validators
|
||||||
|
from django.db.models import QuerySet, Q
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
|
||||||
from application.models import Application
|
from application.models import Application
|
||||||
@ -36,6 +38,9 @@ class ModelPullManage:
|
|||||||
for chunk in response:
|
for chunk in response:
|
||||||
down_model_chunk[chunk.digest] = chunk.to_dict()
|
down_model_chunk[chunk.digest] = chunk.to_dict()
|
||||||
if time.time() - timestamp > 5:
|
if time.time() - timestamp > 5:
|
||||||
|
model_new = QuerySet(Model).filter(id=model.id).first()
|
||||||
|
if model_new.status == Status.PAUSE_DOWNLOAD:
|
||||||
|
return
|
||||||
QuerySet(Model).filter(id=model.id).update(
|
QuerySet(Model).filter(id=model.id).update(
|
||||||
meta={"down_model_chunk": list(down_model_chunk.values())})
|
meta={"down_model_chunk": list(down_model_chunk.values())})
|
||||||
timestamp = time.time()
|
timestamp = time.time()
|
||||||
@ -72,7 +77,7 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
user_id = self.data.get('user_id')
|
user_id = self.data.get('user_id')
|
||||||
name = self.data.get('name')
|
name = self.data.get('name')
|
||||||
model_query_set = QuerySet(Model).filter(user_id=user_id)
|
model_query_set = QuerySet(Model).filter((Q(user_id=user_id) | Q(permission_type='PUBLIC')))
|
||||||
query_params = {}
|
query_params = {}
|
||||||
if name is not None:
|
if name is not None:
|
||||||
query_params['name__contains'] = name
|
query_params['name__contains'] = name
|
||||||
@ -85,7 +90,8 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
|
|
||||||
return [
|
return [
|
||||||
{'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
{'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
||||||
'model_name': model.model_name, 'status': model.status, 'meta': model.meta} for model in
|
'model_name': model.model_name, 'status': model.status, 'meta': model.meta,
|
||||||
|
'permission_type': model.permission_type} for model in
|
||||||
model_query_set.filter(**query_params).order_by("-create_time")]
|
model_query_set.filter(**query_params).order_by("-create_time")]
|
||||||
|
|
||||||
class Edit(serializers.Serializer):
|
class Edit(serializers.Serializer):
|
||||||
@ -96,6 +102,11 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
|
|
||||||
model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
|
model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
|
||||||
|
|
||||||
|
# Optional permission flag. The alternation is grouped so both anchors apply to
# each option: "^PUBLIC|PRIVATE$" would parse as (^PUBLIC)|(PRIVATE$) and accept
# values such as "PUBLICxyz" or "xPRIVATE".
permission_type = serializers.CharField(required=False, error_messages=ErrMessage.char("权限"), validators=[
    validators.RegexValidator(regex=re.compile("^(PUBLIC|PRIVATE)$"),
                              message="权限只支持PUBLIC|PRIVATE", code=500)
])
|
||||||
|
|
||||||
model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
|
model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
|
||||||
|
|
||||||
credential = serializers.DictField(required=False, error_messages=ErrMessage.dict("认证信息"))
|
credential = serializers.DictField(required=False, error_messages=ErrMessage.dict("认证信息"))
|
||||||
@ -135,6 +146,11 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
|
|
||||||
model_type = serializers.CharField(required=True, error_messages=ErrMessage.char("模型类型"))
|
model_type = serializers.CharField(required=True, error_messages=ErrMessage.char("模型类型"))
|
||||||
|
|
||||||
|
# Required permission flag. The alternation is grouped so both anchors apply to
# each option: "^PUBLIC|PRIVATE$" would parse as (^PUBLIC)|(PRIVATE$) and accept
# values such as "PUBLICxyz" or "xPRIVATE".
permission_type = serializers.CharField(required=True, error_messages=ErrMessage.char("权限"), validators=[
    validators.RegexValidator(regex=re.compile("^(PUBLIC|PRIVATE)$"),
                              message="权限只支持PUBLIC|PRIVATE", code=500)
])
|
||||||
|
|
||||||
model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型"))
|
model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型"))
|
||||||
|
|
||||||
credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息"))
|
credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息"))
|
||||||
@ -165,10 +181,12 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
provider = self.data.get('provider')
|
provider = self.data.get('provider')
|
||||||
model_type = self.data.get('model_type')
|
model_type = self.data.get('model_type')
|
||||||
model_name = self.data.get('model_name')
|
model_name = self.data.get('model_name')
|
||||||
|
permission_type = self.data.get('permission_type')
|
||||||
model_credential_str = json.dumps(credential)
|
model_credential_str = json.dumps(credential)
|
||||||
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
|
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
|
||||||
credential=rsa_long_encrypt(model_credential_str),
|
credential=rsa_long_encrypt(model_credential_str),
|
||||||
provider=provider, model_type=model_type, model_name=model_name)
|
provider=provider, model_type=model_type, model_name=model_name,
|
||||||
|
permission_type=permission_type)
|
||||||
model.save()
|
model.save()
|
||||||
if status == Status.DOWNLOAD:
|
if status == Status.DOWNLOAD:
|
||||||
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
|
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
|
||||||
@ -184,7 +202,8 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
'meta': model.meta,
|
'meta': model.meta,
|
||||||
'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
|
'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
|
||||||
model.model_name).encryption_dict(
|
model.model_name).encryption_dict(
|
||||||
credential)}
|
credential),
|
||||||
|
'permission_type': model.permission_type}
|
||||||
|
|
||||||
class Operate(serializers.Serializer):
|
class Operate(serializers.Serializer):
|
||||||
id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
|
id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
|
||||||
@ -210,7 +229,8 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
||||||
'model_name': model.model_name,
|
'model_name': model.model_name,
|
||||||
'status': model.status,
|
'status': model.status,
|
||||||
'meta': model.meta, }
|
'meta': model.meta
|
||||||
|
}
|
||||||
|
|
||||||
def delete(self, with_valid=True):
|
def delete(self, with_valid=True):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
@ -221,6 +241,12 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
QuerySet(Model).filter(id=self.data.get('id')).delete()
|
QuerySet(Model).filter(id=self.data.get('id')).delete()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def pause_download(self, with_valid=True):
    """
    Mark the model identified by this serializer's ``id`` as PAUSE_DOWNLOAD.

    Runs serializer validation first unless ``with_valid`` is False.
    The update filters on ``id`` only. Always returns True.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    model_id = self.data.get('id')
    QuerySet(Model).filter(id=model_id).update(status=Status.PAUSE_DOWNLOAD)
    return True
|
||||||
|
|
||||||
def edit(self, instance: Dict, user_id: str, with_valid=True):
|
def edit(self, instance: Dict, user_id: str, with_valid=True):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
@ -245,7 +271,7 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
model.status = Status.DOWNLOAD
|
model.status = Status.DOWNLOAD
|
||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
update_keys = ['credential', 'name', 'model_type', 'model_name']
|
update_keys = ['credential', 'name', 'model_type', 'model_name', 'permission_type']
|
||||||
for update_key in update_keys:
|
for update_key in update_keys:
|
||||||
if update_key in instance and instance.get(update_key) is not None:
|
if update_key in instance and instance.get(update_key) is not None:
|
||||||
if update_key == 'credential':
|
if update_key == 'credential':
|
||||||
|
|||||||
@ -74,6 +74,8 @@ class ModelCreateApi(ApiMixin):
|
|||||||
'provider': openapi.Schema(type=openapi.TYPE_STRING,
|
'provider': openapi.Schema(type=openapi.TYPE_STRING,
|
||||||
title="供应商",
|
title="供应商",
|
||||||
description="供应商"),
|
description="供应商"),
|
||||||
|
'permission_type': openapi.Schema(type=openapi.TYPE_STRING, title="权限",
|
||||||
|
description="PUBLIC|PRIVATE"),
|
||||||
'model_type': openapi.Schema(type=openapi.TYPE_STRING,
|
'model_type': openapi.Schema(type=openapi.TYPE_STRING,
|
||||||
title="供应商",
|
title="供应商",
|
||||||
description="供应商"),
|
description="供应商"),
|
||||||
@ -82,7 +84,8 @@ class ModelCreateApi(ApiMixin):
|
|||||||
description="供应商"),
|
description="供应商"),
|
||||||
'credential': openapi.Schema(type=openapi.TYPE_OBJECT,
|
'credential': openapi.Schema(type=openapi.TYPE_OBJECT,
|
||||||
title="模型证书信息",
|
title="模型证书信息",
|
||||||
description="模型证书信息")
|
description="模型证书信息"),
|
||||||
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -16,6 +16,7 @@ urlpatterns = [
|
|||||||
name="provider/model_form"),
|
name="provider/model_form"),
|
||||||
path('model', views.Model.as_view(), name='model'),
|
path('model', views.Model.as_view(), name='model'),
|
||||||
path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
|
path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
|
||||||
|
path('model/<str:model_id>/pause_download', views.Model.PauseDownload.as_view(), name='model/operate'),
|
||||||
path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),
|
path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),
|
||||||
path('email_setting', views.SystemSetting.Email.as_view(), name='email_setting'),
|
path('email_setting', views.SystemSetting.Email.as_view(), name='email_setting'),
|
||||||
path('valid/<str:valid_type>/<int:valid_count>', views.Valid.as_view())
|
path('valid/<str:valid_type>/<int:valid_count>', views.Valid.as_view())
|
||||||
|
|||||||
@ -69,6 +69,18 @@ class Model(APIView):
|
|||||||
return result.success(
|
return result.success(
|
||||||
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one_meta(with_valid=True))
|
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one_meta(with_valid=True))
|
||||||
|
|
||||||
|
class PauseDownload(APIView):
|
||||||
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
|
@action(methods=['PUT'], detail=False)
|
||||||
|
@swagger_auto_schema(operation_summary="暂停模型下载",
|
||||||
|
operation_id="暂停模型下载",
|
||||||
|
tags=["模型"])
|
||||||
|
@has_permissions(PermissionConstants.MODEL_CREATE)
|
||||||
|
def put(self, request: Request, model_id: str):
|
||||||
|
return result.success(
|
||||||
|
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).pause_download())
|
||||||
|
|
||||||
class Operate(APIView):
|
class Operate(APIView):
|
||||||
authentication_classes = [TokenAuth]
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
|
|||||||
@ -104,9 +104,6 @@ CACHES = {
|
|||||||
"token_cache": {
|
"token_cache": {
|
||||||
'BACKEND': 'common.cache.file_cache.FileCache',
|
'BACKEND': 'common.cache.file_cache.FileCache',
|
||||||
'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache") # 文件夹路径
|
'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache") # 文件夹路径
|
||||||
},
|
|
||||||
"chat_cache": {
|
|
||||||
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -129,11 +129,12 @@ const getDatasetDetail: (dataset_id: string, loading?: Ref<boolean>) => Promise<
|
|||||||
"desc": true
|
"desc": true
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
const putDataset: (dataset_id: string, data: any) => Promise<Result<any>> = (
|
const putDataset: (
|
||||||
dataset_id,
|
dataset_id: string,
|
||||||
data: any
|
data: any,
|
||||||
) => {
|
loading?: Ref<boolean>
|
||||||
return put(`${prefix}/${dataset_id}`, data)
|
) => Promise<Result<any>> = (dataset_id, data, loading) => {
|
||||||
|
return put(`${prefix}/${dataset_id}`, data, undefined, loading)
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* 获取知识库 可关联的应用列表
|
* 获取知识库 可关联的应用列表
|
||||||
|
|||||||
@ -130,7 +130,18 @@ const getModelMetaById: (model_id: string, loading?: Ref<boolean>) => Promise<Re
|
|||||||
) => {
|
) => {
|
||||||
return get(`${prefix}/${model_id}/meta`, {}, loading)
|
return get(`${prefix}/${model_id}/meta`, {}, loading)
|
||||||
}
|
}
|
||||||
|
/**
|
||||||
|
* 暂停下载
|
||||||
|
* @param model_id 模型id
|
||||||
|
* @param loading 加载器
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
const pauseDownload: (model_id: string, loading?: Ref<boolean>) => Promise<Result<boolean>> = (
|
||||||
|
model_id,
|
||||||
|
loading
|
||||||
|
) => {
|
||||||
|
return put(`${prefix}/${model_id}/pause_download`, undefined, {}, loading)
|
||||||
|
}
|
||||||
const deleteModel: (model_id: string, loading?: Ref<boolean>) => Promise<Result<boolean>> = (
|
const deleteModel: (model_id: string, loading?: Ref<boolean>) => Promise<Result<boolean>> = (
|
||||||
model_id,
|
model_id,
|
||||||
loading
|
loading
|
||||||
@ -147,5 +158,6 @@ export default {
|
|||||||
updateModel,
|
updateModel,
|
||||||
deleteModel,
|
deleteModel,
|
||||||
getModelById,
|
getModelById,
|
||||||
getModelMetaById
|
getModelMetaById,
|
||||||
|
pauseDownload
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,6 +3,7 @@ interface datasetData {
|
|||||||
desc: String
|
desc: String
|
||||||
documents?: Array<any>
|
documents?: Array<any>
|
||||||
type?: String
|
type?: String
|
||||||
|
embedding_mode_id?: String
|
||||||
}
|
}
|
||||||
|
|
||||||
export type { datasetData }
|
export type { datasetData }
|
||||||
|
|||||||
@ -53,6 +53,7 @@ interface Model {
|
|||||||
* 模型类型
|
* 模型类型
|
||||||
*/
|
*/
|
||||||
model_type: string
|
model_type: string
|
||||||
|
permission_type: 'PUBLIC' | 'PRIVATE'
|
||||||
/**
|
/**
|
||||||
* 基础模型
|
* 基础模型
|
||||||
*/
|
*/
|
||||||
@ -68,7 +69,7 @@ interface Model {
|
|||||||
/**
|
/**
|
||||||
* 状态
|
* 状态
|
||||||
*/
|
*/
|
||||||
status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR'
|
status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR' | 'PAUSE_DOWNLOAD'
|
||||||
/**
|
/**
|
||||||
* 元数据
|
* 元数据
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -18,7 +18,8 @@
|
|||||||
</slot>
|
</slot>
|
||||||
<slot></slot>
|
<slot></slot>
|
||||||
</div>
|
</div>
|
||||||
<el-checkbox v-bind:modelValue="modelValue.includes(toModelValue)"> </el-checkbox>
|
<el-checkbox v-bind:modelValue="modelValue.includes(toModelValue)" @change="checkboxChange">
|
||||||
|
</el-checkbox>
|
||||||
</div>
|
</div>
|
||||||
</el-card>
|
</el-card>
|
||||||
</template>
|
</template>
|
||||||
@ -40,7 +41,7 @@ const toModelValue = computed(() => (props.valueField ? props.data[props.valueFi
|
|||||||
// set: (val) => val
|
// set: (val) => val
|
||||||
// })
|
// })
|
||||||
|
|
||||||
const emit = defineEmits(['update:modelValue'])
|
const emit = defineEmits(['update:modelValue', 'change'])
|
||||||
|
|
||||||
const checked = () => {
|
const checked = () => {
|
||||||
const value = props.modelValue ? props.modelValue : []
|
const value = props.modelValue ? props.modelValue : []
|
||||||
@ -53,6 +54,10 @@ const checked = () => {
|
|||||||
emit('update:modelValue', [...value, toModelValue.value])
|
emit('update:modelValue', [...value, toModelValue.value])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function checkboxChange() {
|
||||||
|
emit('change')
|
||||||
|
}
|
||||||
</script>
|
</script>
|
||||||
<style lang="scss" scoped>
|
<style lang="scss" scoped>
|
||||||
.card-checkbox {
|
.card-checkbox {
|
||||||
|
|||||||
93
ui/src/components/loading/DownloadLoading.vue
Normal file
93
ui/src/components/loading/DownloadLoading.vue
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
<template>
|
||||||
|
<div class="loading-container loader">
|
||||||
|
<div class="download-loading">
|
||||||
|
<div></div>
|
||||||
|
<div></div>
|
||||||
|
<div></div>
|
||||||
|
<div></div>
|
||||||
|
<div></div>
|
||||||
|
<div></div>
|
||||||
|
<div></div>
|
||||||
|
<div></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<script setup lang="ts"></script>
|
||||||
|
<style lang="scss" scoped>
|
||||||
|
.loading-container {
|
||||||
|
display: -webkit-flex; /*safari弹性布局*/
|
||||||
|
justify-content: center;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
@-webkit-keyframes loader {
|
||||||
|
0% {
|
||||||
|
opacity: 0.3;
|
||||||
|
}
|
||||||
|
80% {
|
||||||
|
opacity: 1;
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
opacity: 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.download-loading {
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
.download-loading div {
|
||||||
|
width: 5px;
|
||||||
|
height: 12px;
|
||||||
|
background: var(--el-color-info);
|
||||||
|
position: absolute;
|
||||||
|
border-radius: 2px;
|
||||||
|
margin: 0 auto;
|
||||||
|
}
|
||||||
|
.download-loading div:nth-child(1) {
|
||||||
|
top: -20px;
|
||||||
|
left: 0;
|
||||||
|
-webkit-animation: loader 1s -0.8s infinite ease-in-out;
|
||||||
|
}
|
||||||
|
.download-loading div:nth-child(2) {
|
||||||
|
top: -13px;
|
||||||
|
left: 13px;
|
||||||
|
-webkit-transform: rotate(45deg);
|
||||||
|
-webkit-animation: loader 1s -0.6s infinite ease-in-out;
|
||||||
|
}
|
||||||
|
.download-loading div:nth-child(3) {
|
||||||
|
top: 0px;
|
||||||
|
left: 20px;
|
||||||
|
-webkit-transform: rotate(90deg);
|
||||||
|
-webkit-animation: loader 1s -0.5s infinite ease-in-out;
|
||||||
|
}
|
||||||
|
.download-loading div:nth-child(4) {
|
||||||
|
top: 13px;
|
||||||
|
left: 13px;
|
||||||
|
-webkit-transform: rotate(-45deg);
|
||||||
|
-webkit-animation: loader 1s -0.4s infinite ease-in-out;
|
||||||
|
}
|
||||||
|
.download-loading div:nth-child(5) {
|
||||||
|
top: 20px;
|
||||||
|
left: 0px;
|
||||||
|
-webkit-transform: rotate(0deg);
|
||||||
|
-webkit-animation: loader 1s -0.3s infinite ease-in-out;
|
||||||
|
}
|
||||||
|
.download-loading div:nth-child(6) {
|
||||||
|
top: 13px;
|
||||||
|
left: -13px;
|
||||||
|
-webkit-transform: rotate(45deg);
|
||||||
|
-webkit-animation: loader 1s -0.2s infinite ease-in-out;
|
||||||
|
}
|
||||||
|
.download-loading div:nth-child(7) {
|
||||||
|
top: 0px;
|
||||||
|
left: -20px;
|
||||||
|
-webkit-transform: rotate(90deg);
|
||||||
|
-webkit-animation: loader 1s -0.1s infinite ease-in-out;
|
||||||
|
}
|
||||||
|
.download-loading div:nth-child(8) {
|
||||||
|
top: -13px;
|
||||||
|
left: -13px;
|
||||||
|
-webkit-transform: rotate(-45deg);
|
||||||
|
-webkit-animation: loader 1s 0s infinite ease-in-out;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
8
ui/src/enums/model.ts
Normal file
8
ui/src/enums/model.ts
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
export enum PermissionType {
|
||||||
|
PRIVATE = '私有',
|
||||||
|
PUBLIC = '公用'
|
||||||
|
}
|
||||||
|
export enum PermissionDesc {
|
||||||
|
PRIVATE = '仅自己使用',
|
||||||
|
PUBLIC = '所有用户都可使用,不能编辑'
|
||||||
|
}
|
||||||
@ -99,7 +99,7 @@
|
|||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
<template v-else-if="isDataset">
|
<template v-else-if="isDataset">
|
||||||
<div class="w-full text-left cursor" @click="router.push({ path: '/dataset/create' })">
|
<div class="w-full text-left cursor" @click="openCreateDialog">
|
||||||
<el-button link>
|
<el-button link>
|
||||||
<el-icon class="mr-4"><Plus /></el-icon> 创建知识库
|
<el-icon class="mr-4"><Plus /></el-icon> 创建知识库
|
||||||
</el-button>
|
</el-button>
|
||||||
@ -110,12 +110,14 @@
|
|||||||
</el-dropdown>
|
</el-dropdown>
|
||||||
</div>
|
</div>
|
||||||
<CreateApplicationDialog ref="CreateApplicationDialogRef" @refresh="refresh" />
|
<CreateApplicationDialog ref="CreateApplicationDialogRef" @refresh="refresh" />
|
||||||
|
<CreateDatasetDialog ref="CreateDatasetDialogRef" @refresh="refresh" />
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, onMounted, computed } from 'vue'
|
import { ref, onMounted, computed } from 'vue'
|
||||||
import { onBeforeRouteLeave, useRouter, useRoute } from 'vue-router'
|
import { onBeforeRouteLeave, useRouter, useRoute } from 'vue-router'
|
||||||
import CreateApplicationDialog from '@/views/application/component/CreateApplicationDialog.vue'
|
import CreateApplicationDialog from '@/views/application/component/CreateApplicationDialog.vue'
|
||||||
|
import CreateDatasetDialog from '@/views/dataset/component/CreateDatasetDialog.vue'
|
||||||
import { isAppIcon, isWorkFlow } from '@/utils/application'
|
import { isAppIcon, isWorkFlow } from '@/utils/application'
|
||||||
import useStore from '@/stores'
|
import useStore from '@/stores'
|
||||||
const { common, dataset, application } = useStore()
|
const { common, dataset, application } = useStore()
|
||||||
@ -130,6 +132,7 @@ onBeforeRouteLeave((to, from) => {
|
|||||||
common.saveBreadcrumb(null)
|
common.saveBreadcrumb(null)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const CreateDatasetDialogRef = ref()
|
||||||
const CreateApplicationDialogRef = ref()
|
const CreateApplicationDialogRef = ref()
|
||||||
const list = ref<any[]>([])
|
const list = ref<any[]>([])
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
@ -148,7 +151,11 @@ const isDataset = computed(() => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
function openCreateDialog() {
|
function openCreateDialog() {
|
||||||
CreateApplicationDialogRef.value.open()
|
if (isDataset.value) {
|
||||||
|
CreateDatasetDialogRef.value.open()
|
||||||
|
} else if (isApplication.value) {
|
||||||
|
CreateApplicationDialogRef.value.open()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function changeMenu(id: string) {
|
function changeMenu(id: string) {
|
||||||
|
|||||||
@ -13,9 +13,9 @@ const datasetRouter = {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
path: '/dataset/:type', // create 或者 upload
|
path: '/dataset/:type', // create 或者 upload
|
||||||
name: 'CreateDataset',
|
name: 'UploadDocumentDataset',
|
||||||
meta: { activeMenu: '/dataset' },
|
meta: { activeMenu: '/dataset' },
|
||||||
component: () => import('@/views/dataset/CreateDataset.vue'),
|
component: () => import('@/views/dataset/UploadDocumentDataset.vue'),
|
||||||
hidden: true
|
hidden: true
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
import { defineStore } from 'pinia'
|
import { defineStore } from 'pinia'
|
||||||
import modelApi from '@/api/model'
|
import modelApi from '@/api/model'
|
||||||
import type { modelRequest, Provider } from '@/api/type/model'
|
import type { ListModelRequest, Provider } from '@/api/type/model'
|
||||||
const useModelStore = defineStore({
|
const useModelStore = defineStore({
|
||||||
id: 'model',
|
id: 'model',
|
||||||
state: () => ({}),
|
state: () => ({}),
|
||||||
actions: {
|
actions: {
|
||||||
async asyncGetModel(data?: modelRequest) {
|
async asyncGetModel(data?: ListModelRequest) {
|
||||||
return new Promise((resolve, reject) => {
|
return new Promise((resolve, reject) => {
|
||||||
modelApi
|
modelApi
|
||||||
.getModel(data)
|
.getModel(data)
|
||||||
|
|||||||
@ -377,6 +377,11 @@ h5 {
|
|||||||
color: var(--el-color-primary);
|
color: var(--el-color-primary);
|
||||||
border: none;
|
border: none;
|
||||||
}
|
}
|
||||||
|
.danger-tag {
|
||||||
|
background: var(--tag-danger-bg);
|
||||||
|
color: #d03f3b;
|
||||||
|
border: none;
|
||||||
|
}
|
||||||
.success-tag {
|
.success-tag {
|
||||||
background: var(--tag-success-bg);
|
background: var(--tag-success-bg);
|
||||||
color: var(--el-color-success);
|
color: var(--el-color-success);
|
||||||
@ -388,6 +393,12 @@ h5 {
|
|||||||
border: none;
|
border: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.info-tag {
|
||||||
|
background: var(--app-text-color-light-1);
|
||||||
|
color: var(--app-text-color-secondary);
|
||||||
|
border: none;
|
||||||
|
}
|
||||||
|
|
||||||
.purple-tag {
|
.purple-tag {
|
||||||
background: #f2ebfe;
|
background: #f2ebfe;
|
||||||
color: #7f3bf5;
|
color: #7f3bf5;
|
||||||
|
|||||||
@ -30,6 +30,7 @@
|
|||||||
--tag-success-color: #2ca91f;
|
--tag-success-color: #2ca91f;
|
||||||
--tag-warning-bg: rgba(255, 136, 0, 0.2);
|
--tag-warning-bg: rgba(255, 136, 0, 0.2);
|
||||||
--tag-warning-color: #d97400;
|
--tag-warning-color: #d97400;
|
||||||
|
--tag-danger-bg: rgba(245, 74, 69, 0.2);
|
||||||
|
|
||||||
/** card */
|
/** card */
|
||||||
--card-width: 330px;
|
--card-width: 330px;
|
||||||
|
|||||||
@ -4,21 +4,28 @@
|
|||||||
v-model="dialogVisible"
|
v-model="dialogVisible"
|
||||||
width="600"
|
width="600"
|
||||||
append-to-body
|
append-to-body
|
||||||
|
class="addDataset-dialog"
|
||||||
>
|
>
|
||||||
<template #header="{ titleId, titleClass }">
|
<template #header="{ titleId, titleClass }">
|
||||||
<div class="my-header flex">
|
<div class="flex-between mb-8">
|
||||||
<h4 :id="titleId" :class="titleClass">
|
<h4 :id="titleId" :class="titleClass">
|
||||||
{{ $t('views.application.applicationForm.dialogues.addDataset') }}
|
{{ $t('views.application.applicationForm.dialogues.addDataset') }}
|
||||||
</h4>
|
</h4>
|
||||||
<el-button link class="ml-16" @click="refresh">
|
<div class="flex align-center">
|
||||||
<el-icon class="mr-4"><Refresh /></el-icon
|
<el-button link class="ml-16" @click="refresh">
|
||||||
>{{ $t('views.application.applicationForm.dialogues.refresh') }}
|
<el-icon class="mr-4"><Refresh /></el-icon
|
||||||
</el-button>
|
>{{ $t('views.application.applicationForm.dialogues.refresh') }}
|
||||||
|
</el-button>
|
||||||
|
<el-divider direction="vertical" />
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
<el-text type="info" class="color-secondary">
|
||||||
|
所选知识库必须使用相同的 Embedding 模型
|
||||||
|
</el-text>
|
||||||
</template>
|
</template>
|
||||||
<el-row :gutter="12" v-loading="loading">
|
<el-row :gutter="12" v-loading="loading">
|
||||||
<el-col :span="12" v-for="(item, index) in data" :key="index" class="mb-16">
|
<el-col :span="12" v-for="(item, index) in filterData" :key="index" class="mb-16">
|
||||||
<CardCheckbox value-field="id" :data="item" v-model="checkList">
|
<CardCheckbox value-field="id" :data="item" v-model="checkList" @change="changeHandle">
|
||||||
<span class="ellipsis">
|
<span class="ellipsis">
|
||||||
{{ item.name }}
|
{{ item.name }}
|
||||||
</span>
|
</span>
|
||||||
@ -26,19 +33,29 @@
|
|||||||
</el-col>
|
</el-col>
|
||||||
</el-row>
|
</el-row>
|
||||||
<template #footer>
|
<template #footer>
|
||||||
<span class="dialog-footer">
|
<div class="flex-between">
|
||||||
<el-button @click.prevent="dialogVisible = false">
|
<div>
|
||||||
{{ $t('views.application.applicationForm.buttons.cancel') }}
|
<el-text type="info" class="color-secondary" v-if="checkList.length > 0">
|
||||||
</el-button>
|
已选 {{ checkList.length }} 个知识库
|
||||||
<el-button type="primary" @click="submitHandle">
|
</el-text>
|
||||||
{{ $t('views.application.applicationForm.buttons.confirm') }}
|
<el-button link type="primary" v-if="checkList.length > 0" @click="clearCheck">
|
||||||
</el-button>
|
清空
|
||||||
</span>
|
</el-button>
|
||||||
|
</div>
|
||||||
|
<span>
|
||||||
|
<el-button @click.prevent="dialogVisible = false">
|
||||||
|
{{ $t('views.application.applicationForm.buttons.cancel') }}
|
||||||
|
</el-button>
|
||||||
|
<el-button type="primary" @click="submitHandle">
|
||||||
|
{{ $t('views.application.applicationForm.buttons.confirm') }}
|
||||||
|
</el-button>
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
</template>
|
</template>
|
||||||
</el-dialog>
|
</el-dialog>
|
||||||
</template>
|
</template>
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, watch } from 'vue'
|
import { computed, ref, watch } from 'vue'
|
||||||
const props = defineProps({
|
const props = defineProps({
|
||||||
data: {
|
data: {
|
||||||
type: Array<any>,
|
type: Array<any>,
|
||||||
@ -51,6 +68,13 @@ const emit = defineEmits(['addData', 'refresh'])
|
|||||||
|
|
||||||
const dialogVisible = ref<boolean>(false)
|
const dialogVisible = ref<boolean>(false)
|
||||||
const checkList = ref([])
|
const checkList = ref([])
|
||||||
|
const currentEmbedding = ref('')
|
||||||
|
|
||||||
|
const filterData = computed(() => {
|
||||||
|
return currentEmbedding.value
|
||||||
|
? props.data.filter((v) => v.embedding_mode_id === currentEmbedding.value)
|
||||||
|
: props.data
|
||||||
|
})
|
||||||
|
|
||||||
watch(dialogVisible, (bool) => {
|
watch(dialogVisible, (bool) => {
|
||||||
if (!bool) {
|
if (!bool) {
|
||||||
@ -58,6 +82,18 @@ watch(dialogVisible, (bool) => {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
function changeHandle() {
|
||||||
|
if (checkList.value.length === 1) {
|
||||||
|
currentEmbedding.value = props.data.filter(
|
||||||
|
(v) => v.id === checkList.value[0]
|
||||||
|
)[0].embedding_mode_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
function clearCheck() {
|
||||||
|
checkList.value = []
|
||||||
|
currentEmbedding.value = ''
|
||||||
|
}
|
||||||
|
|
||||||
const open = (checked: any) => {
|
const open = (checked: any) => {
|
||||||
checkList.value = checked
|
checkList.value = checked
|
||||||
dialogVisible.value = true
|
dialogVisible.value = true
|
||||||
@ -73,4 +109,13 @@ const refresh = () => {
|
|||||||
|
|
||||||
defineExpose({ open })
|
defineExpose({ open })
|
||||||
</script>
|
</script>
|
||||||
<style lang="scss" scope></style>
|
<style lang="scss" scope>
|
||||||
|
.addDataset-dialog {
|
||||||
|
.el-dialog__header.show-close {
|
||||||
|
padding-right: 15px;
|
||||||
|
}
|
||||||
|
.el-dialog__headerbtn {
|
||||||
|
top: 13px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
|||||||
@ -63,10 +63,10 @@
|
|||||||
</el-form>
|
</el-form>
|
||||||
<template #footer>
|
<template #footer>
|
||||||
<span class="dialog-footer">
|
<span class="dialog-footer">
|
||||||
<el-button @click.prevent="dialogVisible = false">
|
<el-button @click.prevent="dialogVisible = false" :loading="loading">
|
||||||
{{ $t('views.application.applicationForm.buttons.cancel') }}
|
{{ $t('views.application.applicationForm.buttons.cancel') }}
|
||||||
</el-button>
|
</el-button>
|
||||||
<el-button type="primary" @click="submitValid(applicationFormRef)">
|
<el-button type="primary" @click="submitValid(applicationFormRef)" :loading="loading">
|
||||||
{{ $t('views.application.applicationForm.buttons.create') }}
|
{{ $t('views.application.applicationForm.buttons.create') }}
|
||||||
</el-button>
|
</el-button>
|
||||||
</span>
|
</span>
|
||||||
@ -183,10 +183,7 @@ const submitValid = (formEl: FormInstance | undefined) => {
|
|||||||
if (res?.data) {
|
if (res?.data) {
|
||||||
submitHandle(formEl)
|
submitHandle(formEl)
|
||||||
} else {
|
} else {
|
||||||
MsgAlert(
|
MsgAlert('提示', '社区版最多支持 5 个应用,如需拥有更多应用,请升级为专业版。')
|
||||||
'提示',
|
|
||||||
'社区版最多支持 5 个应用,如需拥有更多应用,请联系我们(https://fit2cloud.com/)。'
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
<template>
|
<template>
|
||||||
<div class="authentication-setting p-24">
|
<div class="authentication-setting p-16-24">
|
||||||
<h4>{{ $t('login.authentication') }}</h4>
|
<h4>{{ $t('login.authentication') }}</h4>
|
||||||
<el-tabs v-model="activeName" class="demo-tabs" @tab-click="handleClick">
|
<el-tabs v-model="activeName" class="demo-tabs" @tab-click="handleClick">
|
||||||
<template v-for="(item, index) in tabList" :key="index">
|
<template v-for="(item, index) in tabList" :key="index">
|
||||||
@ -38,7 +38,7 @@ onMounted(() => {})
|
|||||||
background-color: var(--app-view-bg-color);
|
background-color: var(--app-view-bg-color);
|
||||||
box-sizing: border-box;
|
box-sizing: border-box;
|
||||||
min-width: 700px;
|
min-width: 700px;
|
||||||
height: calc(100vh - var(--app-header-height) - var(--app-view-padding) * 2 - 80px);
|
height: calc(100vh - var(--app-header-height) - var(--app-view-padding) * 2 - 70px);
|
||||||
box-sizing: border-box;
|
box-sizing: border-box;
|
||||||
.form-container {
|
.form-container {
|
||||||
width: 70%;
|
width: 70%;
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
<div class="dataset-setting main-calc-height">
|
<div class="dataset-setting main-calc-height">
|
||||||
<el-scrollbar>
|
<el-scrollbar>
|
||||||
<div class="p-24" v-loading="loading">
|
<div class="p-24" v-loading="loading">
|
||||||
|
<h4 class="title-decoration-1 mb-16">基本信息</h4>
|
||||||
<BaseForm ref="BaseFormRef" :data="detail" />
|
<BaseForm ref="BaseFormRef" :data="detail" />
|
||||||
|
|
||||||
<el-form
|
<el-form
|
||||||
@ -104,7 +105,7 @@ import { useRoute } from 'vue-router'
|
|||||||
import BaseForm from '@/views/dataset/component/BaseForm.vue'
|
import BaseForm from '@/views/dataset/component/BaseForm.vue'
|
||||||
import datasetApi from '@/api/dataset'
|
import datasetApi from '@/api/dataset'
|
||||||
import type { ApplicationFormType } from '@/api/type/application'
|
import type { ApplicationFormType } from '@/api/type/application'
|
||||||
import { MsgSuccess } from '@/utils/message'
|
import { MsgSuccess, MsgConfirm } from '@/utils/message'
|
||||||
import { isAppIcon } from '@/utils/application'
|
import { isAppIcon } from '@/utils/application'
|
||||||
import useStore from '@/stores'
|
import useStore from '@/stores'
|
||||||
const route = useRoute()
|
const route = useRoute()
|
||||||
@ -119,6 +120,8 @@ const loading = ref(false)
|
|||||||
const detail = ref<any>({})
|
const detail = ref<any>({})
|
||||||
const application_list = ref<Array<ApplicationFormType>>([])
|
const application_list = ref<Array<ApplicationFormType>>([])
|
||||||
const application_id_list = ref([])
|
const application_id_list = ref([])
|
||||||
|
const cloneModelId = ref('')
|
||||||
|
|
||||||
const form = ref<any>({
|
const form = ref<any>({
|
||||||
source_url: '',
|
source_url: '',
|
||||||
selector: ''
|
selector: ''
|
||||||
@ -132,7 +135,6 @@ async function submit() {
|
|||||||
if (await BaseFormRef.value?.validate()) {
|
if (await BaseFormRef.value?.validate()) {
|
||||||
await webFormRef.value.validate((valid: any) => {
|
await webFormRef.value.validate((valid: any) => {
|
||||||
if (valid) {
|
if (valid) {
|
||||||
loading.value = true
|
|
||||||
const obj =
|
const obj =
|
||||||
detail.value.type === '1'
|
detail.value.type === '1'
|
||||||
? {
|
? {
|
||||||
@ -144,15 +146,25 @@ async function submit() {
|
|||||||
application_id_list: application_id_list.value,
|
application_id_list: application_id_list.value,
|
||||||
...BaseFormRef.value.form
|
...BaseFormRef.value.form
|
||||||
}
|
}
|
||||||
datasetApi
|
|
||||||
.putDataset(id, obj)
|
if (cloneModelId.value !== BaseFormRef.value.form.embedding_mode_id) {
|
||||||
.then((res) => {
|
MsgConfirm(`提示`, `修改知识库向量模型后,需要对知识库重新向量化,是否继续保存?`, {
|
||||||
|
confirmButtonText: '重新向量化',
|
||||||
|
confirmButtonClass: 'primary'
|
||||||
|
})
|
||||||
|
.then(() => {
|
||||||
|
datasetApi.putDataset(id, obj, loading).then((res) => {
|
||||||
|
datasetApi.putReEmbeddingDataset(id).then(() => {
|
||||||
|
MsgSuccess('保存成功')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.catch(() => {})
|
||||||
|
} else {
|
||||||
|
datasetApi.putDataset(id, obj, loading).then((res) => {
|
||||||
MsgSuccess('保存成功')
|
MsgSuccess('保存成功')
|
||||||
loading.value = false
|
|
||||||
})
|
|
||||||
.catch(() => {
|
|
||||||
loading.value = false
|
|
||||||
})
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -161,10 +173,10 @@ async function submit() {
|
|||||||
function getDetail() {
|
function getDetail() {
|
||||||
dataset.asyncGetDatasetDetail(id, loading).then((res: any) => {
|
dataset.asyncGetDatasetDetail(id, loading).then((res: any) => {
|
||||||
detail.value = res.data
|
detail.value = res.data
|
||||||
|
cloneModelId.value = res.data?.embedding_mode_id
|
||||||
if (detail.value.type === '1') {
|
if (detail.value.type === '1') {
|
||||||
form.value = res.data.meta
|
form.value = res.data.meta
|
||||||
}
|
}
|
||||||
|
|
||||||
application_id_list.value = res.data?.application_id_list
|
application_id_list.value = res.data?.application_id_list
|
||||||
datasetApi.listUsableApplication(id, loading).then((ok) => {
|
datasetApi.listUsableApplication(id, loading).then((ok) => {
|
||||||
application_list.value = ok.data
|
application_list.value = ok.data
|
||||||
|
|||||||
@ -1,15 +1,18 @@
|
|||||||
<template>
|
<template>
|
||||||
<LayoutContainer :header="isCreate ? '创建知识库' : '上传文档'" class="create-dataset">
|
<LayoutContainer header="上传文档" class="create-dataset">
|
||||||
<template #backButton>
|
<template #backButton>
|
||||||
<back-button @click="back"></back-button>
|
<back-button @click="back"></back-button>
|
||||||
</template>
|
</template>
|
||||||
<div class="create-dataset__main flex" v-loading="loading">
|
<div class="create-dataset__main flex" v-loading="loading">
|
||||||
<div class="create-dataset__component main-calc-height">
|
<div class="create-dataset__component main-calc-height">
|
||||||
<template v-if="active === 0">
|
<template v-if="active === 0">
|
||||||
<StepFirst ref="StepFirstRef" />
|
<div class="upload-document p-24">
|
||||||
|
<!-- 上传文档 -->
|
||||||
|
<UploadComponent ref="UploadComponentRef" />
|
||||||
|
</div>
|
||||||
</template>
|
</template>
|
||||||
<template v-else-if="active === 1">
|
<template v-else-if="active === 1">
|
||||||
<StepSecond ref="StepSecondRef" />
|
<SetRules ref="SetRulesRef" />
|
||||||
</template>
|
</template>
|
||||||
<template v-else-if="active === 2">
|
<template v-else-if="active === 2">
|
||||||
<ResultSuccess :data="successInfo" />
|
<ResultSuccess :data="successInfo" />
|
||||||
@ -19,12 +22,7 @@
|
|||||||
<div class="create-dataset__footer text-right border-t" v-if="active !== 2">
|
<div class="create-dataset__footer text-right border-t" v-if="active !== 2">
|
||||||
<el-button @click="router.go(-1)" :disabled="loading">取消</el-button>
|
<el-button @click="router.go(-1)" :disabled="loading">取消</el-button>
|
||||||
<el-button @click="prev" v-if="active === 1" :disabled="loading">上一步</el-button>
|
<el-button @click="prev" v-if="active === 1" :disabled="loading">上一步</el-button>
|
||||||
<el-button
|
<el-button @click="next" type="primary" v-if="active === 0" :disabled="loading">
|
||||||
@click="next"
|
|
||||||
type="primary"
|
|
||||||
v-if="active === 0"
|
|
||||||
:disabled="loading || StepFirstRef?.loading"
|
|
||||||
>
|
|
||||||
创建并导入
|
创建并导入
|
||||||
</el-button>
|
</el-button>
|
||||||
<el-button @click="submit" type="primary" v-if="active === 1" :disabled="loading">
|
<el-button @click="submit" type="primary" v-if="active === 1" :disabled="loading">
|
||||||
@ -36,9 +34,9 @@
|
|||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed, onUnmounted } from 'vue'
|
import { ref, computed, onUnmounted } from 'vue'
|
||||||
import { useRouter, useRoute } from 'vue-router'
|
import { useRouter, useRoute } from 'vue-router'
|
||||||
import StepFirst from './step/StepFirst.vue'
|
import SetRules from './component/SetRules.vue'
|
||||||
import StepSecond from './step/StepSecond.vue'
|
import ResultSuccess from './component/ResultSuccess.vue'
|
||||||
import ResultSuccess from './step/ResultSuccess.vue'
|
import UploadComponent from './component/UploadComponent.vue'
|
||||||
import datasetApi from '@/api/dataset'
|
import datasetApi from '@/api/dataset'
|
||||||
import documentApi from '@/api/document'
|
import documentApi from '@/api/document'
|
||||||
import type { datasetData } from '@/api/type/dataset'
|
import type { datasetData } from '@/api/type/dataset'
|
||||||
@ -46,33 +44,17 @@ import { MsgConfirm, MsgSuccess } from '@/utils/message'
|
|||||||
|
|
||||||
import useStore from '@/stores'
|
import useStore from '@/stores'
|
||||||
const { dataset, document } = useStore()
|
const { dataset, document } = useStore()
|
||||||
const baseInfo = computed(() => dataset.baseInfo)
|
|
||||||
const webInfo = computed(() => dataset.webInfo)
|
|
||||||
const documentsFiles = computed(() => dataset.documentsFiles)
|
const documentsFiles = computed(() => dataset.documentsFiles)
|
||||||
const documentsType = computed(() => dataset.documentsType)
|
const documentsType = computed(() => dataset.documentsType)
|
||||||
|
|
||||||
const router = useRouter()
|
const router = useRouter()
|
||||||
const route = useRoute()
|
const route = useRoute()
|
||||||
const {
|
const {
|
||||||
params: { type },
|
|
||||||
query: { id } // id为datasetID,有id的是上传文档
|
query: { id } // id为datasetID,有id的是上传文档
|
||||||
} = route
|
} = route
|
||||||
const isCreate = type === 'create'
|
|
||||||
// const steps = [
|
|
||||||
// {
|
|
||||||
// ref: 'StepFirstRef',
|
|
||||||
// name: '上传文档',
|
|
||||||
// component: StepFirst
|
|
||||||
// },
|
|
||||||
// {
|
|
||||||
// ref: 'StepSecondRef',
|
|
||||||
// name: '设置分段规则',
|
|
||||||
// component: StepSecond
|
|
||||||
// }
|
|
||||||
// ]
|
|
||||||
|
|
||||||
const StepFirstRef = ref()
|
const SetRulesRef = ref()
|
||||||
const StepSecondRef = ref()
|
const UploadComponentRef = ref()
|
||||||
|
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
const disabled = ref(false)
|
const disabled = ref(false)
|
||||||
@ -81,7 +63,7 @@ const successInfo = ref<any>(null)
|
|||||||
|
|
||||||
async function next() {
|
async function next() {
|
||||||
disabled.value = true
|
disabled.value = true
|
||||||
if (await StepFirstRef.value?.onSubmit()) {
|
if (await UploadComponentRef.value.validate()) {
|
||||||
if (documentsType.value === 'QA') {
|
if (documentsType.value === 'QA') {
|
||||||
let fd = new FormData()
|
let fd = new FormData()
|
||||||
documentsFiles.value.forEach((item: any) => {
|
documentsFiles.value.forEach((item: any) => {
|
||||||
@ -118,16 +100,14 @@ const prev = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function clearStore() {
|
function clearStore() {
|
||||||
dataset.saveBaseInfo(null)
|
|
||||||
dataset.saveWebInfo(null)
|
|
||||||
dataset.saveDocumentsFile([])
|
dataset.saveDocumentsFile([])
|
||||||
dataset.saveDocumentsType('')
|
dataset.saveDocumentsType('')
|
||||||
}
|
}
|
||||||
function submit() {
|
function submit() {
|
||||||
loading.value = true
|
loading.value = true
|
||||||
const documents = [] as any
|
const documents = [] as any
|
||||||
StepSecondRef.value?.paragraphList.map((item: any) => {
|
SetRulesRef.value?.paragraphList.map((item: any) => {
|
||||||
if (!StepSecondRef.value?.checkedConnect) {
|
if (!SetRulesRef.value?.checkedConnect) {
|
||||||
item.content.map((v: any) => {
|
item.content.map((v: any) => {
|
||||||
delete v['problem_list']
|
delete v['problem_list']
|
||||||
})
|
})
|
||||||
@ -159,7 +139,7 @@ function submit() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
function back() {
|
function back() {
|
||||||
if (baseInfo.value || webInfo.value || documentsFiles.value?.length > 0) {
|
if (documentsFiles.value?.length > 0) {
|
||||||
MsgConfirm(`提示`, `当前的更改尚未保存,确认退出吗?`, {
|
MsgConfirm(`提示`, `当前的更改尚未保存,确认退出吗?`, {
|
||||||
confirmButtonText: '确认',
|
confirmButtonText: '确认',
|
||||||
type: 'warning'
|
type: 'warning'
|
||||||
@ -206,5 +186,10 @@ onUnmounted(() => {
|
|||||||
width: 100%;
|
width: 100%;
|
||||||
box-sizing: border-box;
|
box-sizing: border-box;
|
||||||
}
|
}
|
||||||
|
.upload-document {
|
||||||
|
width: 70%;
|
||||||
|
margin: 0 auto;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
@ -1,11 +1,11 @@
|
|||||||
<template>
|
<template>
|
||||||
<h4 class="title-decoration-1 mb-16">基本信息</h4>
|
|
||||||
<el-form
|
<el-form
|
||||||
ref="FormRef"
|
ref="FormRef"
|
||||||
:model="form"
|
:model="form"
|
||||||
:rules="rules"
|
:rules="rules"
|
||||||
label-position="top"
|
label-position="top"
|
||||||
require-asterisk-position="right"
|
require-asterisk-position="right"
|
||||||
|
v-loading="loading"
|
||||||
>
|
>
|
||||||
<el-form-item label="知识库名称" prop="name">
|
<el-form-item label="知识库名称" prop="name">
|
||||||
<el-input
|
<el-input
|
||||||
@ -27,14 +27,72 @@
|
|||||||
@blur="form.desc = form.desc.trim()"
|
@blur="form.desc = form.desc.trim()"
|
||||||
/>
|
/>
|
||||||
</el-form-item>
|
</el-form-item>
|
||||||
|
<el-form-item label="Embedding模型" prop="embedding_mode_id">
|
||||||
|
<el-select
|
||||||
|
v-model="form.embedding_mode_id"
|
||||||
|
placeholder="请选择Embedding模型"
|
||||||
|
class="w-full"
|
||||||
|
popper-class="select-model"
|
||||||
|
:clearable="true"
|
||||||
|
>
|
||||||
|
<el-option-group
|
||||||
|
v-for="(value, label) in modelOptions"
|
||||||
|
:key="value"
|
||||||
|
:label="relatedObject(providerOptions, label, 'provider')?.name"
|
||||||
|
>
|
||||||
|
<el-option
|
||||||
|
v-for="item in value.filter((v: any) => v.status === 'SUCCESS')"
|
||||||
|
:key="item.id"
|
||||||
|
:label="item.name"
|
||||||
|
:value="item.id"
|
||||||
|
class="flex-between"
|
||||||
|
>
|
||||||
|
<div class="flex">
|
||||||
|
<span
|
||||||
|
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
|
||||||
|
class="model-icon mr-8"
|
||||||
|
></span>
|
||||||
|
<span>{{ item.name }}</span>
|
||||||
|
</div>
|
||||||
|
<el-icon class="check-icon" v-if="item.id === form.embedding_mode_id"
|
||||||
|
><Check
|
||||||
|
/></el-icon>
|
||||||
|
</el-option>
|
||||||
|
<!-- 不可用 -->
|
||||||
|
<el-option
|
||||||
|
v-for="item in value.filter((v: any) => v.status !== 'SUCCESS')"
|
||||||
|
:key="item.id"
|
||||||
|
:label="item.name"
|
||||||
|
:value="item.id"
|
||||||
|
class="flex-between"
|
||||||
|
disabled
|
||||||
|
>
|
||||||
|
<div class="flex">
|
||||||
|
<span
|
||||||
|
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
|
||||||
|
class="model-icon mr-8"
|
||||||
|
></span>
|
||||||
|
<span>{{ item.name }}</span>
|
||||||
|
<span class="danger">{{
|
||||||
|
$t('views.application.applicationForm.form.aiModel.unavailable')
|
||||||
|
}}</span>
|
||||||
|
</div>
|
||||||
|
<el-icon class="check-icon" v-if="item.id === form.embedding_mode_id"
|
||||||
|
><Check
|
||||||
|
/></el-icon>
|
||||||
|
</el-option>
|
||||||
|
</el-option-group>
|
||||||
|
</el-select>
|
||||||
|
</el-form-item>
|
||||||
</el-form>
|
</el-form>
|
||||||
</template>
|
</template>
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, reactive, onMounted, onUnmounted, computed, watch } from 'vue'
|
import { ref, reactive, onMounted, onUnmounted, computed, watch } from 'vue'
|
||||||
import { useRoute } from 'vue-router'
|
import { groupBy } from 'lodash'
|
||||||
import useStore from '@/stores'
|
import useStore from '@/stores'
|
||||||
import type { datasetData } from '@/api/type/dataset'
|
import type { datasetData } from '@/api/type/dataset'
|
||||||
import { isAllPropertiesEmpty } from '@/utils/utils'
|
import { relatedObject } from '@/utils/utils'
|
||||||
|
import type { Provider } from '@/api/type/model'
|
||||||
|
|
||||||
const props = defineProps({
|
const props = defineProps({
|
||||||
data: {
|
data: {
|
||||||
@ -42,23 +100,23 @@ const props = defineProps({
|
|||||||
default: () => {}
|
default: () => {}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
const route = useRoute()
|
const { model } = useStore()
|
||||||
const {
|
|
||||||
params: { type }
|
|
||||||
} = route
|
|
||||||
const isCreate = type === 'create'
|
|
||||||
const { dataset } = useStore()
|
|
||||||
const baseInfo = computed(() => dataset.baseInfo)
|
|
||||||
const form = ref<datasetData>({
|
const form = ref<datasetData>({
|
||||||
name: '',
|
name: '',
|
||||||
desc: ''
|
desc: '',
|
||||||
|
embedding_mode_id: ''
|
||||||
})
|
})
|
||||||
|
|
||||||
const rules = reactive({
|
const rules = reactive({
|
||||||
name: [{ required: true, message: '请输入知识库名称', trigger: 'blur' }],
|
name: [{ required: true, message: '请输入知识库名称', trigger: 'blur' }],
|
||||||
desc: [{ required: true, message: '请输入知识库描述', trigger: 'blur' }]
|
desc: [{ required: true, message: '请输入知识库描述', trigger: 'blur' }],
|
||||||
|
embedding_mode_id: [{ required: true, message: '请输入Embedding模型', trigger: 'change' }]
|
||||||
})
|
})
|
||||||
|
|
||||||
const FormRef = ref()
|
const FormRef = ref()
|
||||||
|
const loading = ref(false)
|
||||||
|
const modelOptions = ref<any>([])
|
||||||
|
const providerOptions = ref<Array<Provider>>([])
|
||||||
|
|
||||||
watch(
|
watch(
|
||||||
() => props.data,
|
() => props.data,
|
||||||
@ -66,23 +124,13 @@ watch(
|
|||||||
if (value && JSON.stringify(value) !== '{}') {
|
if (value && JSON.stringify(value) !== '{}') {
|
||||||
form.value.name = value.name
|
form.value.name = value.name
|
||||||
form.value.desc = value.desc
|
form.value.desc = value.desc
|
||||||
|
form.value.embedding_mode_id = value.embedding_mode_id
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
immediate: true
|
immediate: true
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
watch(form.value, (value) => {
|
|
||||||
if (isAllPropertiesEmpty(value)) {
|
|
||||||
dataset.saveBaseInfo(null)
|
|
||||||
} else {
|
|
||||||
if (isCreate) {
|
|
||||||
dataset.saveBaseInfo(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
表单校验
|
表单校验
|
||||||
*/
|
*/
|
||||||
@ -93,16 +141,43 @@ function validate() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function getModel() {
|
||||||
|
loading.value = true
|
||||||
|
model
|
||||||
|
.asyncGetModel({ model_type: 'EMBEDDING' })
|
||||||
|
.then((res: any) => {
|
||||||
|
modelOptions.value = groupBy(res?.data, 'provider')
|
||||||
|
loading.value = false
|
||||||
|
})
|
||||||
|
.catch(() => {
|
||||||
|
loading.value = false
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
function getProvider() {
|
||||||
|
loading.value = true
|
||||||
|
model
|
||||||
|
.asyncGetProvider()
|
||||||
|
.then((res: any) => {
|
||||||
|
providerOptions.value = res?.data
|
||||||
|
loading.value = false
|
||||||
|
})
|
||||||
|
.catch(() => {
|
||||||
|
loading.value = false
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
if (baseInfo.value) {
|
getProvider()
|
||||||
form.value = baseInfo.value
|
getModel()
|
||||||
}
|
|
||||||
})
|
})
|
||||||
onUnmounted(() => {
|
onUnmounted(() => {
|
||||||
form.value = {
|
form.value = {
|
||||||
name: '',
|
name: '',
|
||||||
desc: ''
|
desc: '',
|
||||||
|
embedding_mode_id: ''
|
||||||
}
|
}
|
||||||
|
FormRef.value?.clearValidate()
|
||||||
})
|
})
|
||||||
|
|
||||||
defineExpose({
|
defineExpose({
|
||||||
|
|||||||
177
ui/src/views/dataset/component/CreateDatasetDialog.vue
Normal file
177
ui/src/views/dataset/component/CreateDatasetDialog.vue
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
<template>
|
||||||
|
<el-dialog title="创建知识库" v-model="dialogVisible" width="650" append-to-body>
|
||||||
|
<!-- 基本信息 -->
|
||||||
|
<BaseForm ref="BaseFormRef" v-if="dialogVisible" />
|
||||||
|
<el-form
|
||||||
|
ref="DatasetFormRef"
|
||||||
|
:rules="rules"
|
||||||
|
:model="datasetForm"
|
||||||
|
label-position="top"
|
||||||
|
require-asterisk-position="right"
|
||||||
|
>
|
||||||
|
<el-form-item label="知识库类型" required>
|
||||||
|
<el-radio-group v-model="datasetForm.type" class="card__radio" @change="radioChange">
|
||||||
|
<el-row :gutter="20">
|
||||||
|
<el-col :span="12">
|
||||||
|
<el-card
|
||||||
|
shadow="never"
|
||||||
|
class="mb-16"
|
||||||
|
:class="datasetForm.type === '0' ? 'active' : ''"
|
||||||
|
>
|
||||||
|
<el-radio value="0" size="large">
|
||||||
|
<div class="flex align-center">
|
||||||
|
<AppAvatar class="mr-8 avatar-blue" shape="square" :size="32">
|
||||||
|
<img src="@/assets/icon_document.svg" style="width: 58%" alt="" />
|
||||||
|
</AppAvatar>
|
||||||
|
<div>
|
||||||
|
<p class="mb-4">通用型</p>
|
||||||
|
<el-text type="info">可以通过上传文件或手动录入方式构建知识库</el-text>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</el-radio>
|
||||||
|
</el-card>
|
||||||
|
</el-col>
|
||||||
|
<el-col :span="12">
|
||||||
|
<el-card
|
||||||
|
shadow="never"
|
||||||
|
class="mb-16"
|
||||||
|
:class="datasetForm.type === '1' ? 'active' : ''"
|
||||||
|
>
|
||||||
|
<el-radio value="1" size="large">
|
||||||
|
<div class="flex align-center">
|
||||||
|
<AppAvatar class="mr-8 avatar-purple" shape="square" :size="32">
|
||||||
|
<img src="@/assets/icon_web.svg" style="width: 58%" alt="" />
|
||||||
|
</AppAvatar>
|
||||||
|
<div>
|
||||||
|
<p class="mb-4">Web 站点</p>
|
||||||
|
<el-text type="info">通过网站链接同步方式构建知识库 </el-text>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</el-radio>
|
||||||
|
</el-card>
|
||||||
|
</el-col>
|
||||||
|
</el-row>
|
||||||
|
</el-radio-group>
|
||||||
|
</el-form-item>
|
||||||
|
<el-form-item label="Web 根地址" prop="source_url" v-if="datasetForm.type === '1'">
|
||||||
|
<el-input
|
||||||
|
v-model="datasetForm.source_url"
|
||||||
|
placeholder="请输入 Web 根地址"
|
||||||
|
@blur="datasetForm.source_url = datasetForm.source_url.trim()"
|
||||||
|
/>
|
||||||
|
</el-form-item>
|
||||||
|
<el-form-item label="选择器" v-if="datasetForm.type === '1'">
|
||||||
|
<el-input
|
||||||
|
v-model="datasetForm.selector"
|
||||||
|
placeholder="默认为 body,可输入 .classname/#idname/tagname"
|
||||||
|
@blur="datasetForm.selector = datasetForm.selector.trim()"
|
||||||
|
/>
|
||||||
|
</el-form-item>
|
||||||
|
</el-form>
|
||||||
|
<template #footer>
|
||||||
|
<span class="dialog-footer">
|
||||||
|
<el-button @click.prevent="dialogVisible = false" :loading="loading">
|
||||||
|
{{ $t('views.application.applicationForm.buttons.cancel') }}
|
||||||
|
</el-button>
|
||||||
|
<el-button type="primary" @click="submitValid" :loading="loading">
|
||||||
|
{{ $t('views.application.applicationForm.buttons.create') }}
|
||||||
|
</el-button>
|
||||||
|
</span>
|
||||||
|
</template>
|
||||||
|
</el-dialog>
|
||||||
|
</template>
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { ref, watch, reactive } from 'vue'
|
||||||
|
import { useRouter, useRoute } from 'vue-router'
|
||||||
|
import BaseForm from './BaseForm.vue'
|
||||||
|
import datasetApi from '@/api/dataset'
|
||||||
|
import { MsgSuccess, MsgAlert } from '@/utils/message'
|
||||||
|
import useStore from '@/stores'
|
||||||
|
import { ValidType, ValidCount } from '@/enums/common'
|
||||||
|
|
||||||
|
const emit = defineEmits(['refresh'])
|
||||||
|
|
||||||
|
const { common, user } = useStore()
|
||||||
|
const router = useRouter()
|
||||||
|
const BaseFormRef = ref()
|
||||||
|
const DatasetFormRef = ref()
|
||||||
|
|
||||||
|
const loading = ref(false)
|
||||||
|
const dialogVisible = ref<boolean>(false)
|
||||||
|
|
||||||
|
const datasetForm = ref<any>({
|
||||||
|
type: '0',
|
||||||
|
source_url: '',
|
||||||
|
selector: ''
|
||||||
|
})
|
||||||
|
|
||||||
|
const rules = reactive({
|
||||||
|
source_url: [{ required: true, message: '请输入 Web 根地址', trigger: 'blur' }]
|
||||||
|
})
|
||||||
|
|
||||||
|
watch(dialogVisible, (bool) => {
|
||||||
|
if (!bool) {
|
||||||
|
datasetForm.value = {
|
||||||
|
type: '0',
|
||||||
|
source_url: '',
|
||||||
|
selector: ''
|
||||||
|
}
|
||||||
|
DatasetFormRef.value?.clearValidate()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const open = () => {
|
||||||
|
dialogVisible.value = true
|
||||||
|
}
|
||||||
|
|
||||||
|
const submitValid = () => {
|
||||||
|
if (user.isEnterprise()) {
|
||||||
|
submitHandle()
|
||||||
|
} else {
|
||||||
|
common.asyncGetValid(ValidType.Dataset, ValidCount.Dataset, loading).then(async (res: any) => {
|
||||||
|
if (res?.data) {
|
||||||
|
submitHandle()
|
||||||
|
} else {
|
||||||
|
MsgAlert('提示', '社区版最多支持 50 个知识库,如需拥有更多知识库,请升级为专业版。')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const submitHandle = async () => {
|
||||||
|
if (await BaseFormRef.value?.validate()) {
|
||||||
|
await DatasetFormRef.value.validate((valid: any) => {
|
||||||
|
if (valid) {
|
||||||
|
if (datasetForm.value.type === '0') {
|
||||||
|
const obj = {
|
||||||
|
...BaseFormRef.value.form,
|
||||||
|
type: datasetForm.value.type
|
||||||
|
}
|
||||||
|
datasetApi.postDataset(obj, loading).then((res) => {
|
||||||
|
MsgSuccess('创建成功')
|
||||||
|
router.push({ path: `/dataset/${res.data.id}/document` })
|
||||||
|
emit('refresh')
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
const obj = { ...BaseFormRef.value.form, ...datasetForm.value }
|
||||||
|
datasetApi.postWebDataset(obj, loading).then((res) => {
|
||||||
|
MsgSuccess('创建成功')
|
||||||
|
router.push({ path: `/dataset/${res.data.id}/document` })
|
||||||
|
emit('refresh')
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
function radioChange() {
|
||||||
|
datasetForm.value.source_url = ''
|
||||||
|
datasetForm.value.selector = ''
|
||||||
|
}
|
||||||
|
|
||||||
|
defineExpose({ open })
|
||||||
|
</script>
|
||||||
|
<style lang="scss" scope></style>
|
||||||
@ -65,7 +65,7 @@
|
|||||||
show-input
|
show-input
|
||||||
:show-input-controls="false"
|
:show-input-controls="false"
|
||||||
:min="50"
|
:min="50"
|
||||||
:max="4096"
|
:max="100000"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div class="form-item mb-16">
|
<div class="form-item mb-16">
|
||||||
@ -22,7 +22,7 @@
|
|||||||
>
|
>
|
||||||
<el-row :gutter="15">
|
<el-row :gutter="15">
|
||||||
<el-col :xs="24" :sm="12" :md="8" :lg="6" :xl="4" class="mb-16">
|
<el-col :xs="24" :sm="12" :md="8" :lg="6" :xl="4" class="mb-16">
|
||||||
<CardAdd title="创建知识库" @click="router.push({ path: '/dataset/create' })" />
|
<CardAdd title="创建知识库" @click="openCreateDialog" />
|
||||||
</el-col>
|
</el-col>
|
||||||
<template v-for="(item, index) in datasetList" :key="index">
|
<template v-for="(item, index) in datasetList" :key="index">
|
||||||
<el-col :xs="24" :sm="12" :md="8" :lg="6" :xl="4" class="mb-16">
|
<el-col :xs="24" :sm="12" :md="8" :lg="6" :xl="4" class="mb-16">
|
||||||
@ -107,17 +107,20 @@
|
|||||||
</InfiniteScroll>
|
</InfiniteScroll>
|
||||||
</div>
|
</div>
|
||||||
<SyncWebDialog ref="SyncWebDialogRef" @refresh="refresh" />
|
<SyncWebDialog ref="SyncWebDialogRef" @refresh="refresh" />
|
||||||
|
<CreateDatasetDialog ref="CreateDatasetDialogRef"/>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, onMounted, reactive, computed } from 'vue'
|
import { ref, onMounted, reactive, computed } from 'vue'
|
||||||
import SyncWebDialog from '@/views/dataset/component/SyncWebDialog.vue'
|
import SyncWebDialog from '@/views/dataset/component/SyncWebDialog.vue'
|
||||||
|
import CreateDatasetDialog from './component/CreateDatasetDialog.vue'
|
||||||
import datasetApi from '@/api/dataset'
|
import datasetApi from '@/api/dataset'
|
||||||
import { MsgSuccess, MsgConfirm } from '@/utils/message'
|
import { MsgSuccess, MsgConfirm } from '@/utils/message'
|
||||||
import { useRouter } from 'vue-router'
|
import { useRouter } from 'vue-router'
|
||||||
import { numberFormat } from '@/utils/utils'
|
import { numberFormat } from '@/utils/utils'
|
||||||
const router = useRouter()
|
const router = useRouter()
|
||||||
|
|
||||||
|
const CreateDatasetDialogRef = ref()
|
||||||
const SyncWebDialogRef = ref()
|
const SyncWebDialogRef = ref()
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
const datasetList = ref<any[]>([])
|
const datasetList = ref<any[]>([])
|
||||||
@ -129,6 +132,10 @@ const paginationConfig = reactive({
|
|||||||
|
|
||||||
const searchValue = ref('')
|
const searchValue = ref('')
|
||||||
|
|
||||||
|
function openCreateDialog() {
|
||||||
|
CreateDatasetDialogRef.value.open()
|
||||||
|
}
|
||||||
|
|
||||||
function refresh() {
|
function refresh() {
|
||||||
MsgSuccess('同步任务发送成功')
|
MsgSuccess('同步任务发送成功')
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,183 +0,0 @@
|
|||||||
<template>
|
|
||||||
<el-scrollbar>
|
|
||||||
<div class="upload-document p-24">
|
|
||||||
<!-- 基本信息 -->
|
|
||||||
<BaseForm ref="BaseFormRef" v-if="isCreate" />
|
|
||||||
<el-form
|
|
||||||
v-if="isCreate"
|
|
||||||
ref="webFormRef"
|
|
||||||
:rules="rules"
|
|
||||||
:model="form"
|
|
||||||
label-position="top"
|
|
||||||
require-asterisk-position="right"
|
|
||||||
>
|
|
||||||
<el-form-item label="知识库类型" required>
|
|
||||||
<el-radio-group v-model="form.type" class="card__radio" @change="radioChange">
|
|
||||||
<el-row :gutter="20">
|
|
||||||
<el-col :span="12">
|
|
||||||
<el-card shadow="never" class="mb-16" :class="form.type === '0' ? 'active' : ''">
|
|
||||||
<el-radio value="0" size="large">
|
|
||||||
<div class="flex align-center">
|
|
||||||
<AppAvatar class="mr-8 avatar-blue" shape="square" :size="32">
|
|
||||||
<img src="@/assets/icon_document.svg" style="width: 58%" alt="" />
|
|
||||||
</AppAvatar>
|
|
||||||
<div>
|
|
||||||
<p class="mb-4">通用型</p>
|
|
||||||
<el-text type="info">可以通过上传文件或手动录入方式构建知识库</el-text>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</el-radio>
|
|
||||||
</el-card>
|
|
||||||
</el-col>
|
|
||||||
<el-col :span="12">
|
|
||||||
<el-card shadow="never" class="mb-16" :class="form.type === '1' ? 'active' : ''">
|
|
||||||
<el-radio value="1" size="large">
|
|
||||||
<div class="flex align-center">
|
|
||||||
<AppAvatar class="mr-8 avatar-purple" shape="square" :size="32">
|
|
||||||
<img src="@/assets/icon_web.svg" style="width: 58%" alt="" />
|
|
||||||
</AppAvatar>
|
|
||||||
<div>
|
|
||||||
<p class="mb-4">Web 站点</p>
|
|
||||||
<el-text type="info">通过网站链接同步方式构建知识库 </el-text>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</el-radio>
|
|
||||||
</el-card>
|
|
||||||
</el-col>
|
|
||||||
</el-row>
|
|
||||||
</el-radio-group>
|
|
||||||
</el-form-item>
|
|
||||||
<el-form-item label="Web 根地址" prop="source_url" v-if="form.type === '1'">
|
|
||||||
<el-input
|
|
||||||
v-model="form.source_url"
|
|
||||||
placeholder="请输入 Web 根地址"
|
|
||||||
@blur="form.source_url = form.source_url.trim()"
|
|
||||||
/>
|
|
||||||
</el-form-item>
|
|
||||||
<el-form-item label="选择器" v-if="form.type === '1'">
|
|
||||||
<el-input
|
|
||||||
v-model="form.selector"
|
|
||||||
placeholder="默认为 body,可输入 .classname/#idname/tagname"
|
|
||||||
@blur="form.selector = form.selector.trim()"
|
|
||||||
/>
|
|
||||||
</el-form-item>
|
|
||||||
</el-form>
|
|
||||||
|
|
||||||
<!-- 上传文档 -->
|
|
||||||
<UploadComponent ref="UploadComponentRef" v-if="form.type === '0'" />
|
|
||||||
</div>
|
|
||||||
</el-scrollbar>
|
|
||||||
</template>
|
|
||||||
<script setup lang="ts">
|
|
||||||
import { ref, onMounted, reactive, watch } from 'vue'
|
|
||||||
import { useRouter, useRoute } from 'vue-router'
|
|
||||||
import BaseForm from '@/views/dataset/component/BaseForm.vue'
|
|
||||||
import UploadComponent from '@/views/dataset/component/UploadComponent.vue'
|
|
||||||
import { isAllPropertiesEmpty } from '@/utils/utils'
|
|
||||||
import datasetApi from '@/api/dataset'
|
|
||||||
import { MsgError, MsgSuccess } from '@/utils/message'
|
|
||||||
import useStore from '@/stores'
|
|
||||||
const { dataset } = useStore()
|
|
||||||
|
|
||||||
const route = useRoute()
|
|
||||||
const router = useRouter()
|
|
||||||
const {
|
|
||||||
params: { type }
|
|
||||||
} = route
|
|
||||||
const isCreate = type === 'create'
|
|
||||||
const BaseFormRef = ref()
|
|
||||||
const UploadComponentRef = ref()
|
|
||||||
const webFormRef = ref()
|
|
||||||
const loading = ref(false)
|
|
||||||
|
|
||||||
const form = ref<any>({
|
|
||||||
type: '0',
|
|
||||||
source_url: '',
|
|
||||||
selector: ''
|
|
||||||
})
|
|
||||||
|
|
||||||
const rules = reactive({
|
|
||||||
source_url: [{ required: true, message: '请输入 Web 根地址', trigger: 'blur' }]
|
|
||||||
})
|
|
||||||
|
|
||||||
watch(form.value, (value) => {
|
|
||||||
if (isAllPropertiesEmpty(value)) {
|
|
||||||
dataset.saveWebInfo(null)
|
|
||||||
} else {
|
|
||||||
dataset.saveWebInfo(value)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
function radioChange() {
|
|
||||||
dataset.saveDocumentsFile([])
|
|
||||||
dataset.saveDocumentsType('')
|
|
||||||
form.value.source_url = ''
|
|
||||||
form.value.selector = ''
|
|
||||||
}
|
|
||||||
|
|
||||||
const onSubmit = async () => {
|
|
||||||
if (isCreate) {
|
|
||||||
if (form.value.type === '0') {
|
|
||||||
if ((await BaseFormRef.value?.validate()) && (await UploadComponentRef.value.validate())) {
|
|
||||||
if (UploadComponentRef.value.form.fileList.length > 50) {
|
|
||||||
MsgError('每次最多上传50个文件!')
|
|
||||||
return false
|
|
||||||
} else {
|
|
||||||
/*
|
|
||||||
stores保存数据
|
|
||||||
*/
|
|
||||||
dataset.saveBaseInfo(BaseFormRef.value.form)
|
|
||||||
dataset.saveDocumentsType(UploadComponentRef.value.form.fileType)
|
|
||||||
dataset.saveDocumentsFile(UploadComponentRef.value.form.fileList)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (await BaseFormRef.value?.validate()) {
|
|
||||||
await webFormRef.value.validate((valid: any) => {
|
|
||||||
if (valid) {
|
|
||||||
const obj = { ...BaseFormRef.value.form, ...form.value }
|
|
||||||
datasetApi.postWebDataset(obj, loading).then((res) => {
|
|
||||||
MsgSuccess('提交成功')
|
|
||||||
dataset.saveBaseInfo(null)
|
|
||||||
dataset.saveWebInfo(null)
|
|
||||||
router.push({ path: `/dataset/${res.data.id}/document` })
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (await UploadComponentRef.value.validate()) {
|
|
||||||
/*
|
|
||||||
stores保存数据
|
|
||||||
*/
|
|
||||||
dataset.saveDocumentsType(UploadComponentRef.value.form.fileType)
|
|
||||||
dataset.saveDocumentsFile(UploadComponentRef.value.form.fileList)
|
|
||||||
return true
|
|
||||||
} else {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
onMounted(() => {})
|
|
||||||
|
|
||||||
defineExpose({
|
|
||||||
onSubmit,
|
|
||||||
loading
|
|
||||||
})
|
|
||||||
</script>
|
|
||||||
<style scoped lang="scss">
|
|
||||||
.upload-document {
|
|
||||||
width: 70%;
|
|
||||||
margin: 0 auto;
|
|
||||||
margin-bottom: 20px;
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
@ -141,4 +141,13 @@ const submit = async (formEl: FormInstance) => {
|
|||||||
|
|
||||||
defineExpose({ open })
|
defineExpose({ open })
|
||||||
</script>
|
</script>
|
||||||
<style lang="scss" scope></style>
|
<style lang="scss" scope>
|
||||||
|
.edit-mark-dialog {
|
||||||
|
.el-dialog__header.show-close {
|
||||||
|
padding-right: 15px;
|
||||||
|
}
|
||||||
|
.el-dialog__headerbtn {
|
||||||
|
top: 13px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
|||||||
@ -23,7 +23,7 @@
|
|||||||
v-if="isEdit"
|
v-if="isEdit"
|
||||||
v-model="form.content"
|
v-model="form.content"
|
||||||
placeholder="请输入分段内容"
|
placeholder="请输入分段内容"
|
||||||
:maxLength="4096"
|
:maxLength="100000"
|
||||||
:preview="false"
|
:preview="false"
|
||||||
:toolbars="toolbars"
|
:toolbars="toolbars"
|
||||||
style="height: 300px"
|
style="height: 300px"
|
||||||
@ -31,7 +31,7 @@
|
|||||||
:footers="footers"
|
:footers="footers"
|
||||||
>
|
>
|
||||||
<template #defFooters>
|
<template #defFooters>
|
||||||
<span style="margin-left: -6px">/ 4096</span>
|
<span style="margin-left: -6px">/ 100000</span>
|
||||||
</template>
|
</template>
|
||||||
</MdEditor>
|
</MdEditor>
|
||||||
<MdPreview
|
<MdPreview
|
||||||
|
|||||||
@ -55,6 +55,32 @@
|
|||||||
placeholder="请给基础模型设置一个名称"
|
placeholder="请给基础模型设置一个名称"
|
||||||
/>
|
/>
|
||||||
</el-form-item>
|
</el-form-item>
|
||||||
|
<el-form-item prop="model_type" :rules="base_form_data_rule.permission_type">
|
||||||
|
<template #label>
|
||||||
|
<span>权限</span>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<el-radio-group v-model="base_form_data.permission_type" class="card__radio">
|
||||||
|
<el-row :gutter="16">
|
||||||
|
<template v-for="(value, key) of PermissionType" :key="key">
|
||||||
|
<el-col :span="12">
|
||||||
|
<el-card
|
||||||
|
shadow="never"
|
||||||
|
class="mb-16"
|
||||||
|
:class="base_form_data.permission_type === key ? 'active' : ''"
|
||||||
|
>
|
||||||
|
<el-radio :value="key" size="large">
|
||||||
|
<p class="mb-4">{{ value }}</p>
|
||||||
|
<el-text type="info">
|
||||||
|
{{ PermissionDesc[key] }}
|
||||||
|
</el-text>
|
||||||
|
</el-radio>
|
||||||
|
</el-card>
|
||||||
|
</el-col>
|
||||||
|
</template>
|
||||||
|
</el-row>
|
||||||
|
</el-radio-group>
|
||||||
|
</el-form-item>
|
||||||
<el-form-item prop="model_type" :rules="base_form_data_rule.model_type">
|
<el-form-item prop="model_type" :rules="base_form_data_rule.model_type">
|
||||||
<template #label>
|
<template #label>
|
||||||
<span>模型类型</span>
|
<span>模型类型</span>
|
||||||
@ -74,6 +100,7 @@
|
|||||||
></el-option>
|
></el-option>
|
||||||
</el-select>
|
</el-select>
|
||||||
</el-form-item>
|
</el-form-item>
|
||||||
|
|
||||||
<el-form-item prop="model_name" :rules="base_form_data_rule.model_name">
|
<el-form-item prop="model_name" :rules="base_form_data_rule.model_name">
|
||||||
<template #label>
|
<template #label>
|
||||||
<div class="flex align-center" style="display: inline-flex">
|
<div class="flex align-center" style="display: inline-flex">
|
||||||
@ -135,6 +162,7 @@ import type { FormField } from '@/components/dynamics-form/type'
|
|||||||
import DynamicsForm from '@/components/dynamics-form/index.vue'
|
import DynamicsForm from '@/components/dynamics-form/index.vue'
|
||||||
import type { FormRules } from 'element-plus'
|
import type { FormRules } from 'element-plus'
|
||||||
import { MsgSuccess } from '@/utils/message'
|
import { MsgSuccess } from '@/utils/message'
|
||||||
|
import { PermissionType, PermissionDesc } from '@/enums/model'
|
||||||
|
|
||||||
const providerValue = ref<Provider>()
|
const providerValue = ref<Provider>()
|
||||||
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
|
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
|
||||||
@ -150,17 +178,18 @@ const dialogVisible = ref<boolean>(false)
|
|||||||
|
|
||||||
const base_form_data_rule = ref<FormRules>({
|
const base_form_data_rule = ref<FormRules>({
|
||||||
name: { required: true, trigger: 'blur', message: '模型名不能为空' },
|
name: { required: true, trigger: 'blur', message: '模型名不能为空' },
|
||||||
|
permission_type: { required: true, trigger: 'change', message: '权限不能为空' },
|
||||||
model_type: { required: true, trigger: 'change', message: '模型类型不能为空' },
|
model_type: { required: true, trigger: 'change', message: '模型类型不能为空' },
|
||||||
model_name: { required: true, trigger: 'change', message: '基础模型不能为空' }
|
model_name: { required: true, trigger: 'change', message: '基础模型不能为空' }
|
||||||
})
|
})
|
||||||
|
|
||||||
const base_form_data = ref<{
|
const base_form_data = ref<{
|
||||||
name: string
|
name: string
|
||||||
|
permission_type: string
|
||||||
model_type: string
|
model_type: string
|
||||||
|
|
||||||
model_name: string
|
model_name: string
|
||||||
}>({ name: '', model_type: '', model_name: '' })
|
}>({ name: '', model_type: '', model_name: '', permission_type: 'PRIVATE' })
|
||||||
|
|
||||||
const credential_form_data = ref<Dict<any>>({})
|
const credential_form_data = ref<Dict<any>>({})
|
||||||
|
|
||||||
@ -212,7 +241,7 @@ const list_base_model = (model_type: any) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const close = () => {
|
const close = () => {
|
||||||
base_form_data.value = { name: '', model_type: '', model_name: '' }
|
base_form_data.value = { name: '', model_type: '', model_name: '', permission_type: 'PRIVATE' }
|
||||||
credential_form_data.value = {}
|
credential_form_data.value = {}
|
||||||
model_form_field.value = []
|
model_form_field.value = []
|
||||||
base_model_list.value = []
|
base_model_list.value = []
|
||||||
|
|||||||
@ -48,6 +48,32 @@
|
|||||||
placeholder="请给基础模型设置一个名称"
|
placeholder="请给基础模型设置一个名称"
|
||||||
/>
|
/>
|
||||||
</el-form-item>
|
</el-form-item>
|
||||||
|
<el-form-item prop="model_type" :rules="base_form_data_rule.permission_type">
|
||||||
|
<template #label>
|
||||||
|
<span>权限</span>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<el-radio-group v-model="base_form_data.permission_type" class="card__radio">
|
||||||
|
<el-row :gutter="16">
|
||||||
|
<template v-for="(value, key) of PermissionType" :key="key">
|
||||||
|
<el-col :span="12">
|
||||||
|
<el-card
|
||||||
|
shadow="never"
|
||||||
|
class="mb-16"
|
||||||
|
:class="base_form_data.permission_type === key ? 'active' : ''"
|
||||||
|
>
|
||||||
|
<el-radio :value="key" size="large">
|
||||||
|
<p class="mb-4">{{ value }}</p>
|
||||||
|
<el-text type="info">
|
||||||
|
{{ PermissionDesc[key] }}
|
||||||
|
</el-text>
|
||||||
|
</el-radio>
|
||||||
|
</el-card>
|
||||||
|
</el-col>
|
||||||
|
</template>
|
||||||
|
</el-row>
|
||||||
|
</el-radio-group>
|
||||||
|
</el-form-item>
|
||||||
<el-form-item prop="model_type" :rules="base_form_data_rule.model_type">
|
<el-form-item prop="model_type" :rules="base_form_data_rule.model_type">
|
||||||
<template #label>
|
<template #label>
|
||||||
<span>模型类型</span>
|
<span>模型类型</span>
|
||||||
@ -128,7 +154,7 @@ import type { FormField } from '@/components/dynamics-form/type'
|
|||||||
import DynamicsForm from '@/components/dynamics-form/index.vue'
|
import DynamicsForm from '@/components/dynamics-form/index.vue'
|
||||||
import type { FormRules } from 'element-plus'
|
import type { FormRules } from 'element-plus'
|
||||||
import { MsgSuccess } from '@/utils/message'
|
import { MsgSuccess } from '@/utils/message'
|
||||||
import AppIcon from '@/components/icons/AppIcon.vue'
|
import { PermissionType, PermissionDesc } from '@/enums/model'
|
||||||
|
|
||||||
const providerValue = ref<Provider>()
|
const providerValue = ref<Provider>()
|
||||||
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
|
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
|
||||||
@ -151,11 +177,11 @@ const base_form_data_rule = ref<FormRules>({
|
|||||||
|
|
||||||
const base_form_data = ref<{
|
const base_form_data = ref<{
|
||||||
name: string
|
name: string
|
||||||
|
permission_type: string
|
||||||
model_type: string
|
model_type: string
|
||||||
|
|
||||||
model_name: string
|
model_name: string
|
||||||
}>({ name: '', model_type: '', model_name: '' })
|
}>({ name: '', model_type: '', model_name: '', permission_type: 'PRIVATE' })
|
||||||
|
|
||||||
const credential_form_data = ref<Dict<any>>({})
|
const credential_form_data = ref<Dict<any>>({})
|
||||||
|
|
||||||
@ -204,6 +230,7 @@ const open = (provider: Provider, model: Model) => {
|
|||||||
|
|
||||||
base_form_data.value = {
|
base_form_data.value = {
|
||||||
name: model.name,
|
name: model.name,
|
||||||
|
permission_type: model.permission_type,
|
||||||
model_type: model.model_type,
|
model_type: model.model_type,
|
||||||
model_name: model.model_name
|
model_name: model.model_name
|
||||||
}
|
}
|
||||||
@ -214,7 +241,7 @@ const open = (provider: Provider, model: Model) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const close = () => {
|
const close = () => {
|
||||||
base_form_data.value = { name: '', model_type: '', model_name: '' }
|
base_form_data.value = { name: '', model_type: '', model_name: '', permission_type: '' }
|
||||||
dynamicsFormRef.value?.ruleFormRef?.resetFields()
|
dynamicsFormRef.value?.ruleFormRef?.resetFields()
|
||||||
credential_form_data.value = {}
|
credential_form_data.value = {}
|
||||||
model_form_field.value = []
|
model_form_field.value = []
|
||||||
|
|||||||
@ -1,16 +1,30 @@
|
|||||||
<template>
|
<template>
|
||||||
<card-box :title="model.name" shadow="hover" class="model-card">
|
<card-box :title="model.name" shadow="hover" class="model-card">
|
||||||
<template #header>
|
<template #header>
|
||||||
<div class="flex align-center">
|
<div class="flex">
|
||||||
<span style="height: 32px; width: 32px" :innerHTML="icon" class="mr-12"></span>
|
<span style="height: 32px; width: 32px" :innerHTML="icon" class="mr-12"></span>
|
||||||
<auto-tooltip :content="model.name" style="max-width: 40%">
|
<div class="w-full">
|
||||||
{{ model.name }}
|
<div class="flex" style="height: 22px">
|
||||||
</auto-tooltip>
|
<auto-tooltip :content="model.name" style="max-width: 40%">
|
||||||
<div class="flex align-center" v-if="currentModel.status === 'ERROR'">
|
{{ model.name }}
|
||||||
<el-tag type="danger" class="ml-8">失败</el-tag>
|
</auto-tooltip>
|
||||||
<el-tooltip effect="dark" :content="errMessage" placement="top">
|
<span v-if="currentModel.status === 'ERROR'">
|
||||||
<el-icon class="danger ml-4" size="20"><Warning /></el-icon>
|
<el-tooltip effect="dark" :content="errMessage" placement="top">
|
||||||
</el-tooltip>
|
<el-icon class="danger ml-4" size="18"><Warning /></el-icon>
|
||||||
|
</el-tooltip>
|
||||||
|
</span>
|
||||||
|
<span v-if="currentModel.status === 'PAUSE_DOWNLOAD'">
|
||||||
|
<el-tooltip effect="dark" content="暂停下载" placement="top">
|
||||||
|
<el-icon class="danger ml-4" size="18"><Warning /></el-icon>
|
||||||
|
</el-tooltip>
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<div class="mt-4">
|
||||||
|
<el-tag v-if="model.permission_type === 'PRIVATE'" type="danger" class="danger-tag"
|
||||||
|
>私有</el-tag
|
||||||
|
>
|
||||||
|
<el-tag v-else type="info" class="info-tag">公有</el-tag>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
@ -29,18 +43,14 @@
|
|||||||
</div>
|
</div>
|
||||||
<!-- progress -->
|
<!-- progress -->
|
||||||
<div class="progress-mask" v-if="currentModel.status === 'DOWNLOAD'">
|
<div class="progress-mask" v-if="currentModel.status === 'DOWNLOAD'">
|
||||||
<el-progress
|
<DownloadLoading class="percentage" />
|
||||||
type="circle"
|
|
||||||
:width="56"
|
<div class="percentage-label flex-center">
|
||||||
color="#3370FF"
|
正在下载中 <span class="dotting"></span>
|
||||||
:percentage="progress"
|
<el-button link type="primary" class="ml-16" @click.stop="cancelDownload"
|
||||||
class="percentage"
|
>取消下载</el-button
|
||||||
>
|
>
|
||||||
<template #default="{ percentage }">
|
</div>
|
||||||
<span class="percentage-value">{{ percentage }}%</span>
|
|
||||||
</template>
|
|
||||||
</el-progress>
|
|
||||||
<span class="percentage-label">正在下载 <span class="dotting"></span></span>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<template #mouseEnter>
|
<template #mouseEnter>
|
||||||
@ -48,7 +58,13 @@
|
|||||||
<el-tooltip effect="dark" content="修改" placement="top">
|
<el-tooltip effect="dark" content="修改" placement="top">
|
||||||
<el-button text @click.stop="openEditModel">
|
<el-button text @click.stop="openEditModel">
|
||||||
<el-icon>
|
<el-icon>
|
||||||
<component :is="currentModel.status === 'ERROR' ? 'RefreshRight' : 'EditPen'" />
|
<component
|
||||||
|
:is="
|
||||||
|
currentModel.status === 'ERROR' || currentModel.status === 'PAUSE_DOWNLOAD'
|
||||||
|
? 'RefreshRight'
|
||||||
|
: 'EditPen'
|
||||||
|
"
|
||||||
|
/>
|
||||||
</el-icon>
|
</el-icon>
|
||||||
</el-button>
|
</el-button>
|
||||||
</el-tooltip>
|
</el-tooltip>
|
||||||
@ -68,6 +84,7 @@ import type { Provider, Model } from '@/api/type/model'
|
|||||||
import ModelApi from '@/api/model'
|
import ModelApi from '@/api/model'
|
||||||
import { computed, ref, onMounted, onBeforeUnmount } from 'vue'
|
import { computed, ref, onMounted, onBeforeUnmount } from 'vue'
|
||||||
import EditModel from '@/views/template/component/EditModel.vue'
|
import EditModel from '@/views/template/component/EditModel.vue'
|
||||||
|
import DownloadLoading from '@/components/loading/DownloadLoading.vue'
|
||||||
import { MsgConfirm } from '@/utils/message'
|
import { MsgConfirm } from '@/utils/message'
|
||||||
|
|
||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
@ -94,27 +111,6 @@ const errMessage = computed(() => {
|
|||||||
}
|
}
|
||||||
return ''
|
return ''
|
||||||
})
|
})
|
||||||
const progress = computed(() => {
|
|
||||||
if (currentModel.value) {
|
|
||||||
const down_model_chunk = currentModel.value.meta['down_model_chunk']
|
|
||||||
if (down_model_chunk) {
|
|
||||||
const maxObj = down_model_chunk
|
|
||||||
.filter((chunk: any) => chunk.index > 1)
|
|
||||||
.reduce(
|
|
||||||
(prev: any, current: any) => {
|
|
||||||
return (prev.index || 0) > (current.index || 0) ? prev : current
|
|
||||||
},
|
|
||||||
{ progress: 0 }
|
|
||||||
)
|
|
||||||
if (maxObj) {
|
|
||||||
return parseFloat(maxObj.progress?.toFixed(1))
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
})
|
|
||||||
const emit = defineEmits(['change', 'update:model'])
|
const emit = defineEmits(['change', 'update:model'])
|
||||||
const eidtModelRef = ref<InstanceType<typeof EditModel>>()
|
const eidtModelRef = ref<InstanceType<typeof EditModel>>()
|
||||||
let interval: any
|
let interval: any
|
||||||
@ -130,6 +126,13 @@ const deleteModel = () => {
|
|||||||
})
|
})
|
||||||
.catch(() => {})
|
.catch(() => {})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const cancelDownload = () => {
|
||||||
|
ModelApi.pauseDownload(props.model.id).then(() => {
|
||||||
|
downModel.value = undefined
|
||||||
|
emit('change')
|
||||||
|
})
|
||||||
|
}
|
||||||
const openEditModel = () => {
|
const openEditModel = () => {
|
||||||
const provider = props.provider_list.find((p) => p.provider === props.model.provider)
|
const provider = props.provider_list.find((p) => p.provider === props.model.provider)
|
||||||
if (provider) {
|
if (provider) {
|
||||||
@ -197,21 +200,21 @@ onBeforeUnmount(() => {
|
|||||||
z-index: 99;
|
z-index: 99;
|
||||||
text-align: center;
|
text-align: center;
|
||||||
.percentage {
|
.percentage {
|
||||||
top: 50%;
|
margin-top: 55px;
|
||||||
transform: translateY(-65%);
|
margin-bottom: 16px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.percentage-value {
|
// .percentage-value {
|
||||||
display: block;
|
// display: flex;
|
||||||
font-size: 12px;
|
// font-size: 13px;
|
||||||
color: var(--el-color-primary);
|
// align-items: center;
|
||||||
}
|
// color: var(--app-text-color-secondary);
|
||||||
|
// }
|
||||||
.percentage-label {
|
.percentage-label {
|
||||||
display: block;
|
margin-top: 50px;
|
||||||
margin-top: 45px;
|
|
||||||
margin-left: 10px;
|
margin-left: 10px;
|
||||||
font-size: 12px;
|
font-size: 13px;
|
||||||
color: var(--el-color-primary);
|
color: var(--app-text-color-secondary);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -145,10 +145,7 @@ function createUser() {
|
|||||||
title.value = '创建用户'
|
title.value = '创建用户'
|
||||||
UserDialogRef.value.open()
|
UserDialogRef.value.open()
|
||||||
} else {
|
} else {
|
||||||
MsgAlert(
|
MsgAlert('提示', '社区版最多支持 2 个用户,如需拥有更多用户,请升级为专业版。')
|
||||||
'提示',
|
|
||||||
'社区版最多支持 2 个用户,如需拥有更多用户,请联系我们(https://fit2cloud.com/)。'
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user