feat: 模型管理支持向量模型,知识库可以关联向量模型

feat:  模型管理支持向量模型,知识库可以关联向量模型
This commit is contained in:
wangdan-fit2cloud 2024-07-19 02:21:24 -07:00 committed by GitHub
commit d3d09b10ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
74 changed files with 1562 additions and 544 deletions

View File

@ -43,6 +43,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"), validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
message="类型只支持register|reset_password", code=500) message="类型只支持register|reset_password", code=500)
], error_messages=ErrMessage.char("检索模式")) ], error_messages=ErrMessage.char("检索模式"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]: def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
return self.InstanceSerializer return self.InstanceSerializer
@ -56,6 +57,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None, search_mode: str = None,
user_id=None,
**kwargs) -> List[ParagraphPipelineModel]: **kwargs) -> List[ParagraphPipelineModel]:
""" """
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询 关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
@ -67,6 +69,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
:param exclude_paragraph_id_list: 需要排除段落id :param exclude_paragraph_id_list: 需要排除段落id
:param padding_problem_text 补全问题 :param padding_problem_text 补全问题
:param search_mode 检索模式 :param search_mode 检索模式
:param user_id 用户id
:return: 段落列表 :return: 段落列表
""" """
pass pass

View File

@ -13,22 +13,48 @@ from django.db.models import QuerySet
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
from common.config.embedding_config import VectorStore, EmbeddingModel from common.config.embedding_config import VectorStore, ModelManage
from common.db.search import native_search from common.db.search import native_search
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from dataset.models import Paragraph from dataset.models import Paragraph, DataSet
from embedding.models import SearchMode from embedding.models import SearchMode
from setting.models import Model
from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
def get_model_by_id(_id, user_id):
model = QuerySet(Model).filter(id=_id).first()
if model is None:
raise Exception("模型不存在")
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
raise Exception(f"无权限使用此模型:{model.name}")
return model
def get_embedding_id(dataset_id_list):
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
raise Exception("知识库未向量模型不一致")
if len(dataset_list) == 0:
raise Exception("知识库设置错误,请重新设置知识库")
return dataset_list[0].embedding_mode_id
class BaseSearchDatasetStep(ISearchDatasetStep): class BaseSearchDatasetStep(ISearchDatasetStep):
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None, search_mode: str = None,
user_id=None,
**kwargs) -> List[ParagraphPipelineModel]: **kwargs) -> List[ParagraphPipelineModel]:
if len(dataset_id_list) == 0:
return []
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
embedding_model = EmbeddingModel.get_embedding_model() model_id = get_embedding_id(dataset_id_list)
model = get_model_by_id(model_id, user_id)
self.context['model_name'] = model.name
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(exec_problem_text) embedding_value = embedding_model.embed_query(exec_problem_text)
vector = VectorStore.get_embedding_vector() vector = VectorStore.get_embedding_vector()
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list, embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
@ -101,7 +127,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
'run_time': self.context['run_time'], 'run_time': self.context['run_time'],
'problem_text': step_args.get( 'problem_text': step_args.get(
'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'), 'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
'model_name': EmbeddingModel.get_embedding_model().model_name, 'model_name': self.context.get('model_name'),
'message_tokens': 0, 'message_tokens': 0,
'answer_tokens': 0, 'answer_tokens': 0,
'cost': 0 'cost': 0

View File

@ -111,6 +111,7 @@ class FlowParamsSerializer(serializers.Serializer):
client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型")) client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id"))
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案")) re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案"))

View File

@ -13,20 +13,50 @@ from django.db.models import QuerySet
from application.flow.i_step_node import NodeResult from application.flow.i_step_node import NodeResult
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
from common.config.embedding_config import EmbeddingModel, VectorStore from common.config.embedding_config import VectorStore, ModelManage
from common.db.search import native_search from common.db.search import native_search
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from dataset.models import Document, Paragraph from dataset.models import Document, Paragraph, DataSet
from embedding.models import SearchMode from embedding.models import SearchMode
from setting.models import Model
from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
def get_model_by_id(_id, user_id):
model = QuerySet(Model).filter(id=_id).first()
if model is None:
raise Exception("模型不存在")
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
raise Exception(f"无权限使用此模型:{model.name}")
return model
def get_embedding_id(dataset_id_list):
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
raise Exception("知识库未向量模型不一致")
if len(dataset_list) == 0:
raise Exception("知识库设置错误,请重新设置知识库")
return dataset_list[0].embedding_mode_id
def get_none_result(question):
return NodeResult(
{'paragraph_list': [], 'is_hit_handling_method': [], 'question': question, 'data': '',
'directly_return': ''}, {})
class BaseSearchDatasetNode(ISearchDatasetStepNode): class BaseSearchDatasetNode(ISearchDatasetStepNode):
def execute(self, dataset_id_list, dataset_setting, question, def execute(self, dataset_id_list, dataset_setting, question,
exclude_paragraph_id_list=None, exclude_paragraph_id_list=None,
**kwargs) -> NodeResult: **kwargs) -> NodeResult:
self.context['question'] = question self.context['question'] = question
embedding_model = EmbeddingModel.get_embedding_model() if len(dataset_id_list) == 0:
return get_none_result(question)
model_id = get_embedding_id(dataset_id_list)
model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id'))
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(question) embedding_value = embedding_model.embed_query(question)
vector = VectorStore.get_embedding_vector() vector = VectorStore.get_embedding_vector()
exclude_document_id_list = [str(document.id) for document in exclude_document_id_list = [str(document.id) for document in
@ -37,7 +67,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
exclude_paragraph_id_list, True, dataset_setting.get('top_n'), exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode'))) dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
if embedding_list is None: if embedding_list is None:
return NodeResult({'paragraph_list': [], 'is_hit_handling_method': []}, {}) return get_none_result(question)
paragraph_list = self.list_paragraph(embedding_list, vector) paragraph_list = self.list_paragraph(embedding_list, vector)
result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list] result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
result = sorted(result, key=lambda p: p.get('similarity'), reverse=True) result = sorted(result, key=lambda p: p.get('similarity'), reverse=True)

View File

@ -0,0 +1,19 @@
# Generated by Django 4.2.13 on 2024-07-15 15:52
import application.models.application
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('application', '0009_application_type_application_work_flow_and_more'),
]
operations = [
migrations.AlterField(
model_name='chatrecord',
name='details',
field=models.JSONField(default=dict, encoder=application.models.application.DateEncoder, verbose_name='对话详情'),
),
]

View File

@ -26,7 +26,7 @@ from rest_framework import serializers
from application.flow.workflow_manage import Flow from application.flow.workflow_manage import Flow
from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
from common.config.embedding_config import VectorStore, EmbeddingModel from common.config.embedding_config import VectorStore
from common.constants.authentication_type import AuthenticationType from common.constants.authentication_type import AuthenticationType
from common.db.search import get_dynamics_model, native_search, native_page_search from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list from common.db.sql_execute import select_list
@ -36,7 +36,7 @@ from common.util.common import valid_license
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from dataset.models import DataSet, Document, Image from dataset.models import DataSet, Document, Image
from dataset.serializers.common_serializers import list_paragraph from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list
from embedding.models import SearchMode from embedding.models import SearchMode
from setting.models import AuthOperate from setting.models import AuthOperate
from setting.models.model_management import Model from setting.models.model_management import Model
@ -415,12 +415,13 @@ class ApplicationSerializer(serializers.Serializer):
QuerySet(Document).filter( QuerySet(Document).filter(
dataset_id__in=dataset_id_list, dataset_id__in=dataset_id_list,
is_active=False)] is_active=False)]
model = get_embedding_model_by_dataset_id_list(dataset_id_list)
# 向量库检索 # 向量库检索
hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list, hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
self.data.get('top_number'), self.data.get('top_number'),
self.data.get('similarity'), self.data.get('similarity'),
SearchMode(self.data.get('search_mode')), SearchMode(self.data.get('search_mode')),
EmbeddingModel.get_embedding_model()) model)
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {}) hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list]) p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'), return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
@ -522,12 +523,14 @@ class ApplicationSerializer(serializers.Serializer):
if not QuerySet(Application).filter(id=self.data.get('application_id')).exists(): if not QuerySet(Application).filter(id=self.data.get('application_id')).exists():
raise AppApiException(500, '不存在的应用id') raise AppApiException(500, '不存在的应用id')
def list_model(self, with_valid=True): def list_model(self, model_type=None, with_valid=True):
if with_valid: if with_valid:
self.is_valid() self.is_valid()
if model_type is None:
model_type = "LLM"
application = QuerySet(Application).filter(id=self.data.get("application_id")).first() application = QuerySet(Application).filter(id=self.data.get("application_id")).first()
return ModelSerializer.Query( return ModelSerializer.Query(
data={'user_id': application.user_id}).list( data={'user_id': application.user_id, 'model_type': model_type}).list(
with_valid=True) with_valid=True)
def delete(self, with_valid=True): def delete(self, with_valid=True):

View File

@ -88,7 +88,9 @@ class ChatInfo:
'no_references_setting': self.application.dataset_setting.get( 'no_references_setting': self.application.dataset_setting.get(
'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else { 'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else {
'status': 'ai_questioning', 'status': 'ai_questioning',
'value': '{question}'} 'value': '{question}',
},
'user_id': self.application.user_id
} }
@ -221,11 +223,13 @@ class ChatMessageSerializer(serializers.Serializer):
stream = self.data.get('stream') stream = self.data.get('stream')
client_id = self.data.get('client_id') client_id = self.data.get('client_id')
client_type = self.data.get('client_type') client_type = self.data.get('client_type')
user_id = chat_info.application.user_id
work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow), work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
{'history_chat_record': chat_info.chat_record_list, 'question': message, {'history_chat_record': chat_info.chat_record_list, 'question': message,
'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()), 'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
'stream': stream, 'stream': stream,
're_chat': re_chat}, WorkFlowPostHandler(chat_info, client_id, client_type)) 're_chat': re_chat,
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type))
r = work_flow_manage.run() r = work_flow_manage.run()
return r return r

View File

@ -29,6 +29,7 @@ from application.models.api_key_model import ApplicationAccessToken
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \ from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
ModelSettingSerializer ModelSettingSerializer
from application.serializers.chat_message_serializers import ChatInfo from application.serializers.chat_message_serializers import ChatInfo
from common.config.embedding_config import ModelManage
from common.constants.permission_constants import RoleConstants from common.constants.permission_constants import RoleConstants
from common.db.search import native_search, native_page_search, page_search, get_dynamics_model from common.db.search import native_search, native_page_search, page_search, get_dynamics_model
from common.event import ListenerManagement from common.event import ListenerManagement
@ -39,8 +40,10 @@ from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock from common.util.lock import try_lock, un_lock
from common.util.rsa_util import rsa_long_decrypt from common.util.rsa_util import rsa_long_decrypt
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers from dataset.serializers.paragraph_serializers import ParagraphSerializers
from setting.models import Model from setting.models import Model
from setting.models_provider import get_model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
@ -241,12 +244,7 @@ class ChatSerializers(serializers.Serializer):
application_id=application_id)] application_id=application_id)]
chat_model = None chat_model = None
if model is not None: if model is not None:
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, chat_model = ModelManage.get_model(str(model.id), lambda _id: get_model(model))
json.loads(
rsa_long_decrypt(
model.credential)),
streaming=True)
chat_id = str(uuid.uuid1()) chat_id = str(uuid.uuid1())
chat_cache.set(chat_id, chat_cache.set(chat_id,
ChatInfo(chat_id, chat_model, dataset_id_list, ChatInfo(chat_id, chat_model, dataset_id_list,
@ -259,6 +257,7 @@ class ChatSerializers(serializers.Serializer):
class OpenWorkFlowChat(serializers.Serializer): class OpenWorkFlowChat(serializers.Serializer):
work_flow = WorkFlowSerializers(error_messages=ErrMessage.uuid("工作流")) work_flow = WorkFlowSerializers(error_messages=ErrMessage.uuid("工作流"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
def open(self): def open(self):
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
@ -269,7 +268,8 @@ class ChatSerializers(serializers.Serializer):
dataset_setting={}, dataset_setting={},
model_setting={}, model_setting={},
problem_optimization=None, problem_optimization=None,
type=ApplicationTypeChoices.WORK_FLOW type=ApplicationTypeChoices.WORK_FLOW,
user_id=self.data.get('user_id')
) )
work_flow_version = WorkFlowVersion(work_flow=work_flow) work_flow_version = WorkFlowVersion(work_flow=work_flow)
chat_cache.set(chat_id, chat_cache.set(chat_id,
@ -332,7 +332,8 @@ class ChatSerializers(serializers.Serializer):
application = Application(id=None, dialogue_number=3, model=model, application = Application(id=None, dialogue_number=3, model=model,
dataset_setting=self.data.get('dataset_setting'), dataset_setting=self.data.get('dataset_setting'),
model_setting=self.data.get('model_setting'), model_setting=self.data.get('model_setting'),
problem_optimization=self.data.get('problem_optimization')) problem_optimization=self.data.get('problem_optimization'),
user_id=user_id)
chat_cache.set(chat_id, chat_cache.set(chat_id,
ChatInfo(chat_id, chat_model, dataset_id_list, ChatInfo(chat_id, chat_model, dataset_id_list,
[str(document.id) for document in [str(document.id) for document in
@ -533,9 +534,10 @@ class ChatRecordSerializer(serializers.Serializer):
raise AppApiException(500, "文档id不正确") raise AppApiException(500, "文档id不正确")
@staticmethod @staticmethod
def post_embedding_paragraph(chat_record, paragraph_id): def post_embedding_paragraph(chat_record, paragraph_id, dataset_id):
model = get_embedding_model_by_dataset_id(dataset_id)
# 发送向量化事件 # 发送向量化事件
ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id) ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id, embedding_model=model)
return chat_record return chat_record
@post(post_function=post_embedding_paragraph) @post(post_function=post_embedding_paragraph)
@ -573,7 +575,7 @@ class ChatRecordSerializer(serializers.Serializer):
chat_record.improve_paragraph_id_list.append(paragraph.id) chat_record.improve_paragraph_id_list.append(paragraph.id)
# 添加标注 # 添加标注
chat_record.save() chat_record.save()
return ChatRecordSerializerModel(chat_record).data, paragraph.id return ChatRecordSerializerModel(chat_record).data, paragraph.id, dataset_id
class Operate(serializers.Serializer): class Operate(serializers.Serializer):
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id")) chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))

View File

@ -67,6 +67,20 @@ class ApplicationApi(ApiMixin):
} }
) )
class Model(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),
openapi.Parameter(name='model_type', in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='模型类型'),
]
class ApiKey(ApiMixin): class ApiKey(ApiMixin):
@staticmethod @staticmethod
def get_request_params_api(): def get_request_params_api():

View File

@ -175,7 +175,7 @@ class Application(APIView):
@swagger_auto_schema(operation_summary="获取模型列表", @swagger_auto_schema(operation_summary="获取模型列表",
operation_id="获取模型列表", operation_id="获取模型列表",
tags=["应用"], tags=["应用"],
manual_parameters=ApplicationApi.ApiKey.get_request_params_api()) manual_parameters=ApplicationApi.Model.get_request_params_api())
@has_permissions(ViewPermission( @has_permissions(ViewPermission(
[RoleConstants.ADMIN, RoleConstants.USER], [RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
@ -185,7 +185,7 @@ class Application(APIView):
return result.success( return result.success(
ApplicationSerializer.Operate( ApplicationSerializer.Operate(
data={'application_id': application_id, data={'application_id': application_id,
'user_id': request.user.id}).list_model()) 'user_id': request.user.id}).list_model(request.query_params.get('model_type')))
class Profile(APIView): class Profile(APIView):
authentication_classes = [TokenAuth] authentication_classes = [TokenAuth]

View File

@ -41,3 +41,7 @@ class MemCache(LocMemCache):
delete_keys.append(key) delete_keys.append(key)
for key in delete_keys: for key in delete_keys:
self._delete(key) self._delete(key)
def clear_timeout_data(self):
for key in self._cache.keys():
self.get(key)

View File

@ -6,33 +6,36 @@
@date2023/10/23 16:03 @date2023/10/23 16:03
@desc: @desc:
""" """
from langchain_huggingface.embeddings import HuggingFaceEmbeddings import time
from smartdoc.const import CONFIG from common.cache.mem_cache import MemCache
class EmbeddingModel: class ModelManage:
instance = None cache = MemCache('model', {})
up_clear_time = time.time()
@staticmethod @staticmethod
def get_embedding_model(): def get_model(_id, get_model):
""" model_instance = ModelManage.cache.get(_id)
获取向量化模型 if model_instance is None:
:return: model_instance = get_model(_id)
""" ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
if EmbeddingModel.instance is None: return model_instance
model_name = CONFIG.get('EMBEDDING_MODEL_NAME') # 续期
cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH') ModelManage.cache.touch(_id, timeout=60 * 30)
device = CONFIG.get('EMBEDDING_DEVICE') ModelManage.clear_timeout_cache()
encode_kwargs = {'normalize_embeddings': True} return model_instance
e = HuggingFaceEmbeddings(
model_name=model_name, @staticmethod
cache_folder=cache_folder, def clear_timeout_cache():
model_kwargs={'device': device}, if time.time() - ModelManage.up_clear_time > 60:
encode_kwargs=encode_kwargs, ModelManage.cache.clear_timeout_data()
)
EmbeddingModel.instance = e @staticmethod
return EmbeddingModel.instance def delete_key(_id):
if ModelManage.cache.has_key(_id):
ModelManage.cache.delete(_id)
class VectorStore: class VectorStore:

View File

@ -14,14 +14,14 @@ embedding_thread_pool = ThreadPoolExecutor(3)
def poxy(poxy_function): def poxy(poxy_function):
def inner(args): def inner(args, **keywords):
work_thread_pool.submit(poxy_function, args) work_thread_pool.submit(poxy_function, args, **keywords)
return inner return inner
def embedding_poxy(poxy_function): def embedding_poxy(poxy_function):
def inner(args): def inner(args, **keywords):
embedding_thread_pool.submit(poxy_function, args) embedding_thread_pool.submit(poxy_function, args, **keywords)
return inner return inner

View File

@ -15,8 +15,9 @@ from typing import List
import django.db.models import django.db.models
from blinker import signal from blinker import signal
from django.db.models import QuerySet from django.db.models import QuerySet
from langchain_core.embeddings import Embeddings
from common.config.embedding_config import VectorStore, EmbeddingModel from common.config.embedding_config import VectorStore
from common.db.search import native_search, get_dynamics_model from common.db.search import native_search, get_dynamics_model
from common.event.common import poxy, embedding_poxy from common.event.common import poxy, embedding_poxy
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
@ -46,22 +47,26 @@ class SyncWebDocumentArgs:
class UpdateProblemArgs: class UpdateProblemArgs:
def __init__(self, problem_id: str, problem_content: str): def __init__(self, problem_id: str, problem_content: str, embedding_model: Embeddings):
self.problem_id = problem_id self.problem_id = problem_id
self.problem_content = problem_content self.problem_content = problem_content
self.embedding_model = embedding_model
class UpdateEmbeddingDatasetIdArgs: class UpdateEmbeddingDatasetIdArgs:
def __init__(self, paragraph_id_list: List[str], target_dataset_id: str): def __init__(self, paragraph_id_list: List[str], target_dataset_id: str, target_embedding_model: Embeddings):
self.paragraph_id_list = paragraph_id_list self.paragraph_id_list = paragraph_id_list
self.target_dataset_id = target_dataset_id self.target_dataset_id = target_dataset_id
self.target_embedding_model = target_embedding_model
class UpdateEmbeddingDocumentIdArgs: class UpdateEmbeddingDocumentIdArgs:
def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str): def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str,
target_embedding_model: Embeddings = None):
self.paragraph_id_list = paragraph_id_list self.paragraph_id_list = paragraph_id_list
self.target_document_id = target_document_id self.target_document_id = target_document_id
self.target_dataset_id = target_dataset_id self.target_dataset_id = target_dataset_id
self.target_embedding_model = target_embedding_model
class ListenerManagement: class ListenerManagement:
@ -84,16 +89,46 @@ class ListenerManagement:
delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list") delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list")
@staticmethod @staticmethod
def embedding_by_problem(args): def embedding_by_problem(args, embedding_model: Embeddings):
VectorStore.get_embedding_vector().save(**args) VectorStore.get_embedding_vector().save(**args, embedding=embedding_model)
@staticmethod
def embedding_by_paragraph_list(paragraph_id_list, embedding_model: Embeddings):
try:
data_list = native_search(
{'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
**{'paragraph.id__in': paragraph_id_list}),
'paragraph': QuerySet(Paragraph).filter(id__in=paragraph_id_list)},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
ListenerManagement.embedding_by_paragraph_data_list(data_list, paragraph_id_list=paragraph_id_list,
embedding_model=embedding_model)
except Exception as e:
max_kb_error.error(f'查询向量数据:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
@staticmethod @staticmethod
@embedding_poxy @embedding_poxy
def embedding_by_paragraph(paragraph_id): def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings):
max_kb.info(f'开始--->向量化段落:{paragraph_id_list}')
try:
# 删除段落
VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_id_list)
# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
except Exception as e:
max_kb_error.error(f'向量化段落:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error
finally:
QuerySet(Paragraph).filter(id__in=paragraph_id_list).update(**{'status': status})
max_kb.info(f'结束--->向量化段落:{paragraph_id_list}')
@staticmethod
@embedding_poxy
def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
""" """
向量化段落 根据段落id 向量化段落 根据段落id
:param paragraph_id: 段落id @param paragraph_id: 段落id
:return: None @param embedding_model: 向量模型
""" """
max_kb.info(f"开始--->向量化段落:{paragraph_id}") max_kb.info(f"开始--->向量化段落:{paragraph_id}")
status = Status.success status = Status.success
@ -107,7 +142,7 @@ class ListenerManagement:
# 删除段落 # 删除段落
VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id) VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
# 批量向量化 # 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list) VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
except Exception as e: except Exception as e:
max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}') max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error status = Status.error
@ -117,12 +152,15 @@ class ListenerManagement:
@staticmethod @staticmethod
@embedding_poxy @embedding_poxy
def embedding_by_document(document_id): def embedding_by_document(document_id, embedding_model: Embeddings):
""" """
向量化文档 向量化文档
:param document_id: 文档id @param document_id: 文档id
@param embedding_model 向量模型
:return: None :return: None
""" """
if not try_lock('embedding' + str(document_id)):
return
max_kb.info(f"开始--->向量化文档:{document_id}") max_kb.info(f"开始--->向量化文档:{document_id}")
QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding}) QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding})
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.embedding}) QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.embedding})
@ -138,7 +176,7 @@ class ListenerManagement:
# 删除文档向量数据 # 删除文档向量数据
VectorStore.get_embedding_vector().delete_by_document_id(document_id) VectorStore.get_embedding_vector().delete_by_document_id(document_id)
# 批量向量化 # 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list) VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
except Exception as e: except Exception as e:
max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}') max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error status = Status.error
@ -148,21 +186,24 @@ class ListenerManagement:
**{'status': status, 'update_time': datetime.datetime.now()}) **{'status': status, 'update_time': datetime.datetime.now()})
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status}) QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status})
max_kb.info(f"结束--->向量化文档:{document_id}") max_kb.info(f"结束--->向量化文档:{document_id}")
un_lock('embedding' + str(document_id))
@staticmethod @staticmethod
@embedding_poxy @embedding_poxy
def embedding_by_dataset(dataset_id): def embedding_by_dataset(dataset_id, embedding_model: Embeddings):
""" """
向量化知识库 向量化知识库
:param dataset_id: 知识库id @param dataset_id: 知识库id
@param embedding_model 向量模型
:return: None :return: None
""" """
max_kb.info(f"开始--->向量化数据集:{dataset_id}") max_kb.info(f"开始--->向量化数据集:{dataset_id}")
try: try:
ListenerManagement.delete_embedding_by_dataset(dataset_id)
document_list = QuerySet(Document).filter(dataset_id=dataset_id) document_list = QuerySet(Document).filter(dataset_id=dataset_id)
max_kb.info(f"数据集文档:{[d.name for d in document_list]}") max_kb.info(f"数据集文档:{[d.name for d in document_list]}")
for document in document_list: for document in document_list:
ListenerManagement.embedding_by_document(document.id) ListenerManagement.embedding_by_document(document.id, embedding_model=embedding_model)
except Exception as e: except Exception as e:
max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}') max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
finally: finally:
@ -224,14 +265,22 @@ class ListenerManagement:
@staticmethod @staticmethod
def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs): def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs):
if args.target_embedding_model is None:
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list, VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
{'dataset_id': args.target_dataset_id}) {'dataset_id': args.target_dataset_id})
else:
ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
embedding_model=args.target_embedding_model)
@staticmethod @staticmethod
def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs): def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs):
if args.target_embedding_model is None:
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list, VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
{'document_id': args.target_document_id, {'document_id': args.target_document_id,
'dataset_id': args.target_dataset_id}) 'dataset_id': args.target_dataset_id})
else:
ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
embedding_model=args.target_embedding_model)
@staticmethod @staticmethod
def delete_embedding_by_source_ids(source_ids: List[str]): def delete_embedding_by_source_ids(source_ids: List[str]):
@ -245,11 +294,6 @@ class ListenerManagement:
def delete_embedding_by_dataset_id_list(source_ids: List[str]): def delete_embedding_by_dataset_id_list(source_ids: List[str]):
VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids) VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids)
@staticmethod
@poxy
def init_embedding_model(ages):
EmbeddingModel.get_embedding_model()
def run(self): def run(self):
# 添加向量 根据问题id # 添加向量 根据问题id
ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem) ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)
@ -276,8 +320,7 @@ class ListenerManagement:
ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph) ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
# 启动段落向量 # 启动段落向量
ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph) ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
# 初始化向量化模型
ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)
# 同步web站点知识库 # 同步web站点知识库
ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset) ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
# 同步web站点 文档 # 同步web站点 文档

View File

@ -0,0 +1,21 @@
# Generated by Django 4.2.13 on 2024-07-17 13:56
import dataset.models.data_set
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('setting', '0005_model_permission_type'),
('dataset', '0005_file'),
]
operations = [
migrations.AddField(
model_name='dataset',
name='embedding_mode',
field=models.ForeignKey(default=dataset.models.data_set.default_model, on_delete=django.db.models.deletion.DO_NOTHING, to='setting.model', verbose_name='向量模型'),
),
]

View File

@ -9,9 +9,11 @@
import uuid import uuid
from django.db import models from django.db import models
from django.db.models import QuerySet
from common.db.sql_execute import select_one from common.db.sql_execute import select_one
from common.mixins.app_model_mixin import AppModelMixin from common.mixins.app_model_mixin import AppModelMixin
from setting.models import Model
from users.models import User from users.models import User
@ -33,6 +35,10 @@ class HitHandlingMethod(models.TextChoices):
directly_return = 'directly_return', '直接返回' directly_return = 'directly_return', '直接返回'
def default_model():
return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab')
class DataSet(AppModelMixin): class DataSet(AppModelMixin):
""" """
数据集表 数据集表
@ -43,7 +49,8 @@ class DataSet(AppModelMixin):
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户") user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户")
type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices, type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
default=Type.base) default=Type.base)
embedding_mode = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型",
default=default_model)
meta = models.JSONField(verbose_name="元数据", default=dict) meta = models.JSONField(verbose_name="元数据", default=dict)
class Meta: class Meta:

View File

@ -14,6 +14,7 @@ from django.db.models import QuerySet
from drf_yasg import openapi from drf_yasg import openapi
from rest_framework import serializers from rest_framework import serializers
from common.config.embedding_config import ModelManage
from common.db.search import native_search from common.db.search import native_search
from common.db.sql_execute import update_execute from common.db.sql_execute import update_execute
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
@ -21,7 +22,8 @@ from common.mixins.api_mixin import ApiMixin
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from common.util.fork import Fork from common.util.fork import Fork
from dataset.models import Paragraph, Problem, ProblemParagraphMapping from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet
from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
@ -130,3 +132,22 @@ class ProblemParagraphManage:
result = [problem_model for problem_model, is_create in problem_content_dict.values() if result = [problem_model for problem_model, is_create in problem_content_dict.values() if
is_create], problem_paragraph_mapping_list is_create], problem_paragraph_mapping_list
return result return result
def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
raise Exception("知识库未向量模型不一致")
if len(dataset_list) == 0:
raise Exception("知识库设置错误,请重新设置知识库")
return ModelManage.get_model(str(dataset_list[0].id),
lambda _id: get_model(dataset_list[0].embedding_mode))
def get_embedding_model_by_dataset_id(dataset_id: str):
dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
return ModelManage.get_model(dataset_id, lambda _id: get_model(dataset.embedding_mode))
def get_embedding_model_by_dataset(dataset):
return ModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode))

View File

@ -15,7 +15,6 @@ from functools import reduce
from typing import Dict, List from typing import Dict, List
from urllib.parse import urlparse from urllib.parse import urlparse
from django.conf import settings
from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.fields import ArrayField
from django.core import validators from django.core import validators
from django.db import transaction, models from django.db import transaction, models
@ -25,7 +24,7 @@ from drf_yasg import openapi
from rest_framework import serializers from rest_framework import serializers
from application.models import ApplicationDatasetMapping from application.models import ApplicationDatasetMapping
from common.config.embedding_config import VectorStore, EmbeddingModel from common.config.embedding_config import VectorStore
from common.db.search import get_dynamics_model, native_page_search, native_search from common.db.search import get_dynamics_model, native_page_search, native_search
from common.db.sql_execute import select_list from common.db.sql_execute import select_list
from common.event import ListenerManagement, SyncWebDatasetArgs from common.event import ListenerManagement, SyncWebDatasetArgs
@ -37,7 +36,8 @@ from common.util.file_util import get_file_content
from common.util.fork import ChildLink, Fork from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
get_embedding_model_by_dataset_id
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from embedding.models import SearchMode from embedding.models import SearchMode
from setting.models import AuthOperate from setting.models import AuthOperate
@ -206,6 +206,8 @@ class DataSetSerializers(serializers.ModelSerializer):
max_length=256, max_length=256,
min_length=1) min_length=1)
embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型"))
documents = DocumentInstanceSerializer(required=False, many=True) documents = DocumentInstanceSerializer(required=False, many=True)
def is_valid(self, *, raise_exception=False): def is_valid(self, *, raise_exception=False):
@ -226,6 +228,8 @@ class DataSetSerializers(serializers.ModelSerializer):
max_length=256, max_length=256,
min_length=1) min_length=1)
embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型"))
file_list = serializers.ListSerializer(required=True, file_list = serializers.ListSerializer(required=True,
error_messages=ErrMessage.list("文件列表"), error_messages=ErrMessage.list("文件列表"),
child=serializers.FileField(required=True, child=serializers.FileField(required=True,
@ -296,6 +300,8 @@ class DataSetSerializers(serializers.ModelSerializer):
min_length=1) min_length=1)
source_url = serializers.CharField(required=True, error_messages=ErrMessage.char("Web 根地址"), ) source_url = serializers.CharField(required=True, error_messages=ErrMessage.char("Web 根地址"), )
embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型"))
selector = serializers.CharField(required=False, allow_null=True, allow_blank=True, selector = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("选择器")) error_messages=ErrMessage.char("选择器"))
@ -347,6 +353,8 @@ class DataSetSerializers(serializers.ModelSerializer):
properties={ properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"), 'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
'embedding_mode_id': openapi.Schema(type=openapi.TYPE_STRING, title="向量模型id",
description="向量模型id"),
'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url", 'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url",
description="web站点url"), description="web站点url"),
'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器") 'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
@ -355,8 +363,9 @@ class DataSetSerializers(serializers.ModelSerializer):
@staticmethod @staticmethod
def post_embedding_dataset(document_list, dataset_id): def post_embedding_dataset(document_list, dataset_id):
model = get_embedding_model_by_dataset_id(dataset_id)
# 发送向量化事件 # 发送向量化事件
ListenerManagement.embedding_by_dataset_signal.send(dataset_id) ListenerManagement.embedding_by_dataset_signal.send(dataset_id, embedding_model=model)
return document_list return document_list
def save_qa(self, instance: Dict, with_valid=True): def save_qa(self, instance: Dict, with_valid=True):
@ -365,7 +374,8 @@ class DataSetSerializers(serializers.ModelSerializer):
self.CreateQASerializers(data=instance).is_valid() self.CreateQASerializers(data=instance).is_valid()
file_list = instance.get('file_list') file_list = instance.get('file_list')
document_list = flat_map([DocumentSerializers.Create.parse_qa_file(file) for file in file_list]) document_list = flat_map([DocumentSerializers.Create.parse_qa_file(file) for file in file_list])
dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list} dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list,
'embedding_mode_id': instance.get('embedding_mode_id')}
return self.save(dataset_instance, with_valid=True) return self.save(dataset_instance, with_valid=True)
@valid_license(model=DataSet, count=50, @valid_license(model=DataSet, count=50,
@ -381,7 +391,8 @@ class DataSetSerializers(serializers.ModelSerializer):
if QuerySet(DataSet).filter(user_id=user_id, name=instance.get('name')).exists(): if QuerySet(DataSet).filter(user_id=user_id, name=instance.get('name')).exists():
raise AppApiException(500, "知识库名称重复!") raise AppApiException(500, "知识库名称重复!")
dataset = DataSet( dataset = DataSet(
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id}) **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
'embedding_mode_id': instance.get('embedding_mode_id')})
document_model_list = [] document_model_list = []
paragraph_model_list = [] paragraph_model_list = []
@ -452,7 +463,8 @@ class DataSetSerializers(serializers.ModelSerializer):
dataset = DataSet( dataset = DataSet(
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id, **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
'type': Type.web, 'type': Type.web,
'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}}) 'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector'),
'embedding_mode_id': instance.get('embedding_mode_id')}})
dataset.save() dataset.save()
ListenerManagement.sync_web_dataset_signal.send( ListenerManagement.sync_web_dataset_signal.send(
SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'), SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'),
@ -500,6 +512,8 @@ class DataSetSerializers(serializers.ModelSerializer):
properties={ properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"), 'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
'embedding_mode_id': openapi.Schema(type=openapi.TYPE_STRING, title='向量模型',
description='向量模型'),
'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据", 'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据",
items=DocumentSerializers().Create.get_request_body_api() items=DocumentSerializers().Create.get_request_body_api()
) )
@ -557,12 +571,13 @@ class DataSetSerializers(serializers.ModelSerializer):
QuerySet(Document).filter( QuerySet(Document).filter(
dataset_id=self.data.get('id'), dataset_id=self.data.get('id'),
is_active=False)] is_active=False)]
model = get_embedding_model_by_dataset_id(self.data.get('id'))
# 向量库检索 # 向量库检索
hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list, hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
self.data.get('top_number'), self.data.get('top_number'),
self.data.get('similarity'), self.data.get('similarity'),
SearchMode(self.data.get('search_mode')), SearchMode(self.data.get('search_mode')),
EmbeddingModel.get_embedding_model()) model)
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {}) hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list]) p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'), return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
@ -730,7 +745,8 @@ class DataSetSerializers(serializers.ModelSerializer):
def re_embedding(self, with_valid=True): def re_embedding(self, with_valid=True):
if with_valid: if with_valid:
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id')) model = get_embedding_model_by_dataset_id(self.data.get('id'))
ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id'), embedding_model=model)
def list_application(self, with_valid=True): def list_application(self, with_valid=True):
if with_valid: if with_valid:
@ -769,6 +785,7 @@ class DataSetSerializers(serializers.ModelSerializer):
QuerySet(ApplicationDatasetMapping).filter( QuerySet(ApplicationDatasetMapping).filter(
dataset_id=self.data.get('id'))]))} dataset_id=self.data.get('id'))]))}
@transaction.atomic
def edit(self, dataset: Dict, user_id: str): def edit(self, dataset: Dict, user_id: str):
""" """
修改知识库 修改知识库
@ -782,6 +799,8 @@ class DataSetSerializers(serializers.ModelSerializer):
raise AppApiException(500, "知识库名称重复!") raise AppApiException(500, "知识库名称重复!")
_dataset = QuerySet(DataSet).get(id=self.data.get("id")) _dataset = QuerySet(DataSet).get(id=self.data.get("id"))
DataSetSerializers.Edit(data=dataset).is_valid(dataset=_dataset) DataSetSerializers.Edit(data=dataset).is_valid(dataset=_dataset)
if 'embedding_mode_id' in dataset:
_dataset.embedding_mode_id = dataset.get('embedding_mode_id')
if "name" in dataset: if "name" in dataset:
_dataset.name = dataset.get("name") _dataset.name = dataset.get("name")
if 'desc' in dataset: if 'desc' in dataset:

View File

@ -41,7 +41,8 @@ from common.util.file_util import get_file_content
from common.util.fork import Fork from common.util.fork import Fork
from common.util.split_model import get_split_model from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
get_embedding_model_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
@ -234,12 +235,17 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
meta={}) meta={})
else: else:
document_list.update(dataset_id=target_dataset_id) document_list.update(dataset_id=target_dataset_id)
# 修改向量信息 model = None
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs( if dataset.embedding_mode_id != target_dataset.embedding_mode_id:
[paragraph.id for paragraph in paragraph_list], model = get_embedding_model_by_dataset_id(target_dataset_id)
target_dataset_id))
pid_list = [paragraph.id for paragraph in paragraph_list]
# 修改段落信息 # 修改段落信息
paragraph_list.update(dataset_id=target_dataset_id) paragraph_list.update(dataset_id=target_dataset_id)
# 修改向量信息
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
pid_list,
target_dataset_id, model))
@staticmethod @staticmethod
def get_target_dataset_problem(target_dataset_id: str, def get_target_dataset_problem(target_dataset_id: str,
@ -392,7 +398,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
problem_paragraph_mapping_list) > 0 else None problem_paragraph_mapping_list) > 0 else None
# 向量化 # 向量化
if with_embedding: if with_embedding:
ListenerManagement.embedding_by_document_signal.send(document_id) model = get_embedding_model_by_dataset_id(dataset_id=document.dataset_id)
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
else: else:
document.status = Status.error document.status = Status.error
document.save() document.save()
@ -405,6 +412,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
class Operate(ApiMixin, serializers.Serializer): class Operate(ApiMixin, serializers.Serializer):
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char( document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
"文档id")) "文档id"))
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("数据集id"))
@staticmethod @staticmethod
def get_request_params_api(): def get_request_params_api():
@ -530,7 +538,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
if with_valid: if with_valid:
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
document_id = self.data.get("document_id") document_id = self.data.get("document_id")
ListenerManagement.embedding_by_document_signal.send(document_id) model = get_embedding_model_by_dataset_id(dataset_id=self.data.get('dataset_id'))
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
@transaction.atomic @transaction.atomic
def delete(self): def delete(self):
@ -599,8 +608,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
return True return True
@staticmethod @staticmethod
def post_embedding(result, document_id): def post_embedding(result, document_id, dataset_id):
ListenerManagement.embedding_by_document_signal.send(document_id) model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
return result return result
@staticmethod @staticmethod
@ -646,7 +656,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
document_id = str(document_model.id) document_id = str(document_model.id)
return DocumentSerializers.Operate( return DocumentSerializers.Operate(
data={'dataset_id': dataset_id, 'document_id': document_id}).one( data={'dataset_id': dataset_id, 'document_id': document_id}).one(
with_valid=True), document_id with_valid=True), document_id, dataset_id
@staticmethod @staticmethod
def get_sync_handler(dataset_id): def get_sync_handler(dataset_id):
@ -803,9 +813,10 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api()) return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api())
@staticmethod @staticmethod
def post_embedding(document_list): def post_embedding(document_list, dataset_id):
for document_dict in document_list: for document_dict in document_list:
ListenerManagement.embedding_by_document_signal.send(document_dict.get('id')) model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'), embedding_model=model)
return document_list return document_list
@post(post_function=post_embedding) @post(post_function=post_embedding)
@ -846,7 +857,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
return [], return [],
query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]}) query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
return native_search(query_set, select_string=get_file_content( return native_search(query_set, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False), os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')),
with_search_one=False), dataset_id
@staticmethod @staticmethod
def _batch_sync(document_id_list: List[str]): def _batch_sync(document_id_list: List[str]):

View File

@ -20,9 +20,9 @@ from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin from common.mixins.api_mixin import ApiMixin
from common.util.common import post from common.util.common import post
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \ from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
ProblemParagraphManage ProblemParagraphManage, get_embedding_model_by_dataset_id, get_embedding_model_by_dataset
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
from embedding.models import SourceType from embedding.models import SourceType
@ -132,6 +132,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
paragraph_id=self.data.get('paragraph_id'), paragraph_id=self.data.get('paragraph_id'),
dataset_id=self.data.get('dataset_id')) dataset_id=self.data.get('dataset_id'))
problem_paragraph_mapping.save() problem_paragraph_mapping.save()
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
if with_embedding: if with_embedding:
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content, ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
'is_active': True, 'is_active': True,
@ -140,7 +141,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
'document_id': self.data.get('document_id'), 'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'), 'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'), 'dataset_id': self.data.get('dataset_id'),
}) }, embedding_model=model)
return ProblemSerializers.Operate( return ProblemSerializers.Operate(
data={'dataset_id': self.data.get('dataset_id'), data={'dataset_id': self.data.get('dataset_id'),
@ -227,6 +228,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
problem_id=problem.id) problem_id=problem.id)
problem_paragraph_mapping.save() problem_paragraph_mapping.save()
if with_embedding: if with_embedding:
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content, ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
'is_active': True, 'is_active': True,
'source_type': SourceType.PROBLEM, 'source_type': SourceType.PROBLEM,
@ -234,7 +236,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
'document_id': self.data.get('document_id'), 'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'), 'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'), 'dataset_id': self.data.get('dataset_id'),
}) }, embedding_model=model)
def un_association(self, with_valid=True): def un_association(self, with_valid=True):
if with_valid: if with_valid:
@ -336,10 +338,11 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
# 修改mapping # 修改mapping
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list, QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
['document_id']) ['document_id'])
# 修改向量段落信息 # 修改向量段落信息
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs( ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
[paragraph.id for paragraph in paragraph_list], [paragraph.id for paragraph in paragraph_list],
target_document_id, target_dataset_id)) target_document_id, target_dataset_id, target_embedding_model=None))
# 修改段落信息 # 修改段落信息
paragraph_list.update(document_id=target_document_id) paragraph_list.update(document_id=target_document_id)
# 不同数据集迁移 # 不同数据集迁移
@ -366,12 +369,19 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
# 修改mapping # 修改mapping
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list, QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
['problem_id', 'dataset_id', 'document_id']) ['problem_id', 'dataset_id', 'document_id'])
# 修改向量段落信息 target_dataset = QuerySet(DataSet).filter(id=target_dataset_id).first()
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs( dataset = QuerySet(DataSet).filter(id=dataset_id).first()
[paragraph.id for paragraph in paragraph_list], embedding_model = None
target_document_id, target_dataset_id)) if target_dataset.embedding_mode_id != dataset.embedding_mode_id:
embedding_model = get_embedding_model_by_dataset(target_dataset)
pid_list = [paragraph.id for paragraph in paragraph_list]
# 修改段落信息 # 修改段落信息
paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id) paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id)
# 修改向量段落信息
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
pid_list,
target_document_id, target_dataset_id, target_embedding_model=embedding_model))
update_document_char_length(document_id) update_document_char_length(document_id)
update_document_char_length(target_document_id) update_document_char_length(target_document_id)
@ -454,13 +464,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
raise AppApiException(500, "段落id不存在") raise AppApiException(500, "段落id不存在")
@staticmethod @staticmethod
def post_embedding(paragraph, instance): def post_embedding(paragraph, instance, dataset_id):
if 'is_active' in instance and instance.get('is_active') is not None: if 'is_active' in instance and instance.get('is_active') is not None:
s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get( s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal) 'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
s.send(paragraph.get('id')) s.send(paragraph.get('id'))
else: else:
ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id')) model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'), embedding_model=model)
return paragraph return paragraph
@post(post_embedding) @post(post_embedding)
@ -508,7 +519,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
_paragraph.save() _paragraph.save()
update_document_char_length(self.data.get('document_id')) update_document_char_length(self.data.get('document_id'))
return self.one(), instance return self.one(), instance, self.data.get('dataset_id')
def get_problem_list(self): def get_problem_list(self):
ProblemParagraphMapping(ProblemParagraphMapping) ProblemParagraphMapping(ProblemParagraphMapping)
@ -582,7 +593,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
# 修改长度 # 修改长度
update_document_char_length(document_id) update_document_char_length(document_id)
if with_embedding: if with_embedding:
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id)) model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id), embedding_model=model)
return ParagraphSerializers.Operate( return ParagraphSerializers.Operate(
data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one( data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
with_valid=True) with_valid=True)

View File

@ -20,7 +20,8 @@ from common.event import ListenerManagement, UpdateProblemArgs
from common.mixins.api_mixin import ApiMixin from common.mixins.api_mixin import ApiMixin
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from dataset.models import Problem, Paragraph, ProblemParagraphMapping from dataset.models import Problem, Paragraph, ProblemParagraphMapping, DataSet
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
@ -157,6 +158,8 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
content = instance.get('content') content = instance.get('content')
problem = QuerySet(Problem).filter(id=problem_id, problem = QuerySet(Problem).filter(id=problem_id,
dataset_id=dataset_id).first() dataset_id=dataset_id).first()
QuerySet(DataSet).filter(id=dataset_id)
problem.content = content problem.content = content
problem.save() problem.save()
ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content)) model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content, model))

View File

@ -52,7 +52,6 @@ class Dataset(APIView):
@action(methods=['POST'], detail=False) @action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="创建QA知识库", @swagger_auto_schema(operation_summary="创建QA知识库",
operation_id="创建QA知识库", operation_id="创建QA知识库",
manual_parameters=DataSetSerializers.Create.CreateQASerializers.get_request_params_api(), manual_parameters=DataSetSerializers.Create.CreateQASerializers.get_request_params_api(),
responses=get_api_response( responses=get_api_response(
DataSetSerializers.Create.CreateQASerializers.get_response_body_api()), DataSetSerializers.Create.CreateQASerializers.get_response_body_api()),

View File

@ -10,9 +10,8 @@ import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Dict from typing import List, Dict
from langchain_community.embeddings import HuggingFaceEmbeddings from langchain_core.embeddings import Embeddings
from common.config.embedding_config import EmbeddingModel
from common.util.common import sub_array from common.util.common import sub_array
from embedding.models import SourceType, SearchMode from embedding.models import SourceType, SearchMode
@ -51,7 +50,7 @@ class BaseVectorStore(ABC):
def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool, is_active: bool,
embedding=None): embedding: Embeddings):
""" """
插入向量数据 插入向量数据
:param source_id: 资源id :param source_id: 资源id
@ -64,13 +63,10 @@ class BaseVectorStore(ABC):
:param paragraph_id 段落id :param paragraph_id 段落id
:return: bool :return: bool
""" """
if embedding is None:
embedding = EmbeddingModel.get_embedding_model()
self.save_pre_handler() self.save_pre_handler()
self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding) self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)
def batch_save(self, data_list: List[Dict], embedding=None): def batch_save(self, data_list: List[Dict], embedding: Embeddings):
# 获取锁 # 获取锁
lock.acquire() lock.acquire()
try: try:
@ -80,8 +76,6 @@ class BaseVectorStore(ABC):
:param embedding: 向量化处理器 :param embedding: 向量化处理器
:return: bool :return: bool
""" """
if embedding is None:
embedding = EmbeddingModel.get_embedding_model()
self.save_pre_handler() self.save_pre_handler()
result = sub_array(data_list) result = sub_array(data_list)
for child_array in result: for child_array in result:
@ -94,17 +88,17 @@ class BaseVectorStore(ABC):
@abstractmethod @abstractmethod
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool, is_active: bool,
embedding: HuggingFaceEmbeddings): embedding: Embeddings):
pass pass
@abstractmethod @abstractmethod
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings): def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
pass pass
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_list: list[str], exclude_paragraph_list: list[str],
is_active: bool, is_active: bool,
embedding: HuggingFaceEmbeddings): embedding: Embeddings):
if dataset_id_list is None or len(dataset_id_list) == 0: if dataset_id_list is None or len(dataset_id_list) == 0:
return [] return []
embedding_query = embedding.embed_query(query_text) embedding_query = embedding.embed_query(query_text)
@ -123,7 +117,7 @@ class BaseVectorStore(ABC):
def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int, def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
similarity: float, similarity: float,
search_mode: SearchMode, search_mode: SearchMode,
embedding: HuggingFaceEmbeddings): embedding: Embeddings):
pass pass
@abstractmethod @abstractmethod
@ -142,14 +136,6 @@ class BaseVectorStore(ABC):
def update_by_source_ids(self, source_ids: List[str], instance: Dict): def update_by_source_ids(self, source_ids: List[str], instance: Dict):
pass pass
@abstractmethod
def embed_documents(self, text_list: List[str]):
pass
@abstractmethod
def embed_query(self, text: str):
pass
@abstractmethod @abstractmethod
def delete_by_dataset_id(self, dataset_id: str): def delete_by_dataset_id(self, dataset_id: str):
pass pass

View File

@ -13,9 +13,8 @@ from abc import ABC, abstractmethod
from typing import Dict, List from typing import Dict, List
from django.db.models import QuerySet from django.db.models import QuerySet
from langchain_community.embeddings import HuggingFaceEmbeddings from langchain_core.embeddings import Embeddings
from common.config.embedding_config import EmbeddingModel
from common.db.search import generate_sql_by_query_dict from common.db.search import generate_sql_by_query_dict
from common.db.sql_execute import select_list from common.db.sql_execute import select_list
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
@ -33,14 +32,6 @@ class PGVector(BaseVectorStore):
def update_by_source_ids(self, source_ids: List[str], instance: Dict): def update_by_source_ids(self, source_ids: List[str], instance: Dict):
QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance) QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)
def embed_documents(self, text_list: List[str]):
embedding = EmbeddingModel.get_embedding_model()
return embedding.embed_documents(text_list)
def embed_query(self, text: str):
embedding = EmbeddingModel.get_embedding_model()
return embedding.embed_query(text)
def vector_is_create(self) -> bool: def vector_is_create(self) -> bool:
# 项目启动默认是创建好的 不需要再创建 # 项目启动默认是创建好的 不需要再创建
return True return True
@ -50,7 +41,7 @@ class PGVector(BaseVectorStore):
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool, is_active: bool,
embedding: HuggingFaceEmbeddings): embedding: Embeddings):
text_embedding = embedding.embed_query(text) text_embedding = embedding.embed_query(text)
embedding = Embedding(id=uuid.uuid1(), embedding = Embedding(id=uuid.uuid1(),
dataset_id=dataset_id, dataset_id=dataset_id,
@ -64,7 +55,7 @@ class PGVector(BaseVectorStore):
embedding.save() embedding.save()
return True return True
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings): def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
texts = [row.get('text') for row in text_list] texts = [row.get('text') for row in text_list]
embeddings = embedding.embed_documents(texts) embeddings = embedding.embed_documents(texts)
embedding_list = [Embedding(id=uuid.uuid1(), embedding_list = [Embedding(id=uuid.uuid1(),
@ -83,7 +74,7 @@ class PGVector(BaseVectorStore):
def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int, def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
similarity: float, similarity: float,
search_mode: SearchMode, search_mode: SearchMode,
embedding: HuggingFaceEmbeddings): embedding: Embeddings):
if dataset_id_list is None or len(dataset_id_list) == 0: if dataset_id_list is None or len(dataset_id_list) == 0:
return [] return []
exclude_dict = {} exclude_dict = {}

View File

@ -0,0 +1,46 @@
# Generated by Django 4.2.13 on 2024-07-15 15:23
import json
from django.db import migrations, models
from django.db.models import QuerySet
from common.util.rsa_util import rsa_long_encrypt
from setting.models import Status, PermissionType
from smartdoc.const import CONFIG
default_embedding_model_id = '42f63a3d-427e-11ef-b3ec-a8a1595801ab'
def save_default_embedding_model(apps, schema_editor):
ModelModel = apps.get_model('setting', 'Model')
cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH')
model_name = CONFIG.get('EMBEDDING_MODEL_NAME')
credential = {'cache_folder': cache_folder}
model_credential_str = json.dumps(credential)
model = ModelModel(id=default_embedding_model_id, name='maxkb-embedding', status=Status.SUCCESS,
model_type="EMBEDDING", model_name=model_name, user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab',
provider='model_local_provider',
credential=rsa_long_encrypt(model_credential_str), meta={},
permission_type=PermissionType.PUBLIC)
model.save()
def reverse_code_embedding_model(apps, schema_editor):
ModelModel = apps.get_model('setting', 'Model')
QuerySet(ModelModel).filter(id=default_embedding_model_id).delete()
class Migration(migrations.Migration):
dependencies = [
('setting', '0004_alter_model_credential'),
]
operations = [
migrations.AddField(
model_name='model',
name='permission_type',
field=models.CharField(choices=[('PUBLIC', '公开'), ('PRIVATE', '私有')], default='PRIVATE', max_length=20,
verbose_name='权限类型'),
),
migrations.RunPython(save_default_embedding_model, reverse_code_embedding_model)
]

View File

@ -22,6 +22,13 @@ class Status(models.TextChoices):
DOWNLOAD = "DOWNLOAD", '下载中' DOWNLOAD = "DOWNLOAD", '下载中'
PAUSE_DOWNLOAD = "PAUSE_DOWNLOAD", '暂停下载'
class PermissionType(models.TextChoices):
PUBLIC = "PUBLIC", '公开'
PRIVATE = "PRIVATE", "私有"
class Model(AppModelMixin): class Model(AppModelMixin):
""" """
@ -46,6 +53,9 @@ class Model(AppModelMixin):
meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict) meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict)
permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices,
default=PermissionType.PRIVATE)
class Meta: class Meta:
db_table = "model" db_table = "model"
unique_together = ['name', 'user_id'] unique_together = ['name', 'user_id']

View File

@ -6,3 +6,85 @@
@date2023/10/31 17:16 @date2023/10/31 17:16
@desc: @desc:
""" """
import json
from typing import Dict
from common.util.rsa_util import rsa_long_decrypt
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
def get_model_(provider, model_type, model_name, credential):
"""
获取模型实例
@param provider: 供应商
@param model_type: 模型类型
@param model_name: 模型名称
@param credential: 认证信息
@return: 模型实例
"""
model = get_provider(provider).get_model(model_type, model_name,
json.loads(
rsa_long_decrypt(credential)),
streaming=True)
return model
def get_model(model):
"""
获取模型实例
@param model: model 数据库Model实例对象
@return: 模型实例
"""
return get_model_(model.provider, model.model_type, model.model_name, model.credential)
def get_provider(provider):
"""
获取供应商实例
@param provider: 供应商字符串
@return: 供应商实例
"""
return ModelProvideConstants[provider].value
def get_model_list(provider, model_type):
"""
获取模型列表
@param provider: 供应商字符串
@param model_type: 模型类型
@return: 模型列表
"""
return get_provider(provider).get_model_list(model_type)
def get_model_credential(provider, model_type, model_name):
"""
获取模型认证实例
@param provider: 供应商字符串
@param model_type: 模型类型
@param model_name: 模型名称
@return: 认证实例对象
"""
return get_provider(provider).get_model_credential(model_type, model_name)
def get_model_type_list(provider):
"""
获取模型类型列表
@param provider: 供应商字符串
@return: 模型类型列表
"""
return get_provider(provider).get_model_type_list()
def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
"""
校验模型认证参数
@param provider: 供应商字符串
@param model_type: 模型类型
@param model_name: 模型名称
@param model_credential: 模型认证数据
@param raise_exception: 是否抛出错误
@return: True|False
"""
return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, raise_exception)

View File

@ -61,7 +61,7 @@ class IModelProvider(ABC):
def get_model_list(self, model_type): def get_model_list(self, model_type):
if model_type is None: if model_type is None:
raise AppApiException(500, '模型类型不能为空') raise AppApiException(500, '模型类型不能为空')
return self.get_model_info_manage().get_model_list() return self.get_model_info_manage().get_model_list_by_model_type(model_type)
def get_model_credential(self, model_type, model_name): def get_model_credential(self, model_type, model_name):
model_info = self.get_model_info_manage().get_model_info(model_type, model_name) model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
@ -191,6 +191,9 @@ class ModelInfoManage:
def get_model_list(self): def get_model_list(self):
return [model.to_dict() for model in self.model_list] return [model.to_dict() for model in self.model_list]
def get_model_list_by_model_type(self, model_type):
return [model.to_dict() for model in self.model_list if model.model_type == model_type]
def get_model_type_list(self): def get_model_type_list(self):
return [{'key': _type.value.get('message'), 'value': _type.value.get('code')} for _type in ModelTypeConst if return [{'key': _type.value.get('message'), 'value': _type.value.get('code')} for _type in ModelTypeConst if
len([model for model in self.model_list if model.model_type == _type.name]) > 0] len([model for model in self.model_list if model.model_type == _type.name]) > 0]

View File

@ -18,6 +18,7 @@ from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
from setting.models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider
class ModelProvideConstants(Enum): class ModelProvideConstants(Enum):
@ -31,3 +32,4 @@ class ModelProvideConstants(Enum):
model_xf_provider = XunFeiModelProvider() model_xf_provider = XunFeiModelProvider()
model_deepseek_provider = DeepSeekModelProvider() model_deepseek_provider = DeepSeekModelProvider()
model_gemini_provider = GeminiModelProvider() model_gemini_provider = GeminiModelProvider()
model_local_provider = LocalModelProvider()

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py
@date2024/7/10 17:48
@desc:
"""

View File

@ -0,0 +1,45 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/11 11:06
@desc:
"""
from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
if not model_type == 'EMBEDDING':
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['cache_folder']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential)
model.embed_query('你好')
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return model
cache_folder = forms.TextInputField('模型目录', required=True)

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" t="1720668342208" class="icon" viewBox="0 0 1024 1024" version="1.1" p-id="9052" width="100%" height="100%"><path d="M512.2 475.7c-8.2-0.3-16.1-2.4-23.4-6.1L192.6 330.2c-24.9-11.1-36.1-40.3-25-65.2 5-11.2 13.9-20.1 25-25l281.5-133.2a89.43 89.43 0 0 1 76 0L831.7 240c24.9 11.1 36.1 40.3 25 65.2-5 11.2-13.9 20.1-25 25L535.5 469.5c-7.2 3.8-15.2 5.9-23.3 6.2z m-76.5 452.5c-7.6 0-15.1-1.9-21.8-5.5L146.3 797.2c-17-8.9-27.5-26.5-27.3-45.6v-320c0.1-18 9.7-34.5 25.1-43.7 14.3-8.1 31.8-8.1 46.1 0l267.1 125.4c16.1 8.7 26.4 25.4 27.1 43.7v320.4c-0.2 17.9-9.6 34.4-24.9 43.7-7.2 4.3-15.4 6.8-23.8 7.1z m152.9 0c-8.3 0-16.5-2.2-23.8-6.3-15.3-9.3-24.7-25.8-24.9-43.7V556.9c0.4-18.2 10.4-34.8 26.2-43.7L835 387c14.2-7.5 31.4-7.1 45.2 1.1 15.5 9.1 25 25.7 25.1 43.7v319.8c0.4 18.9-9.7 36.5-26.2 45.6L610.5 922.8c-6.8 3.6-14.3 5.5-21.9 5.4z" p-id="9053"/></svg>

After

Width:  |  Height:  |  Size: 931 B

View File

@ -0,0 +1,38 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file zhipu_model_provider.py
@date2024/04/19 13:5
@desc:
"""
import os
from typing import Dict
from pydantic import BaseModel
from common.exception.app_exception import AppApiException
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfoManage
from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
from smartdoc.conf import PROJECT_DIR
embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
LocalEmbeddingCredential(), LocalEmbedding)
model_info_manage = (ModelInfoManage.builder().append_model_info(embedding_text2vec_base_chinese)
.append_default_model_info(embedding_text2vec_base_chinese)
.build())
class LocalModelProvider(IModelProvider):
def get_model_info_manage(self):
return model_info_manage
def get_model_provide_info(self):
return ModelProvideInfo(provider='model_local_provider', name='本地模型', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'local_model_provider', 'icon',
'local_icon_svg')))

View File

@ -0,0 +1,22 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/11 14:06
@desc:
"""
from typing import Dict
from langchain_huggingface import HuggingFaceEmbeddings
from setting.models_provider.base_model_provider import MaxKBBaseModel
class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
model_kwargs={'device': model_credential.get('device')},
encode_kwargs={'normalize_embeddings': True},
)

View File

@ -0,0 +1,45 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/12 15:10
@desc:
"""
from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
class OllamaEmbeddingModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
try:
model_list = provider.get_base_model_list(model_credential.get('api_base'))
except Exception as e:
raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
exist = [model for model in (model_list.get('models') if model_list.get('models') is not None else []) if
model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name]
if len(exist) == 0:
raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential)
model.embed_query('你好')
return True
def encryption_dict(self, model_info: Dict[str, object]):
return model_info
def build_model(self, model_info: Dict[str, object]):
for key in ['model']:
if key not in model_info:
raise AppApiException(500, f'{key} 字段为必填字段')
return self
api_base = forms.TextInputField('API 域名', required=True)

View File

@ -0,0 +1,48 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/12 15:02
@desc:
"""
from typing import Dict, List
from langchain_community.embeddings import OllamaEmbeddings
from setting.models_provider.base_model_provider import MaxKBBaseModel
class OllamaEmbedding(MaxKBBaseModel, OllamaEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return OllamaEmbedding(
model=model_name,
base_url=model_credential.get('api_base'),
)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed documents using an Ollama deployed embedding model.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
instruction_pairs = [f"{text}" for text in texts]
embeddings = self._embed(instruction_pairs)
return embeddings
def embed_query(self, text: str) -> List[float]:
"""Embed a query using a Ollama deployed embedding model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
instruction_pair = f"{text}"
embedding = self._embed([instruction_pair])[0]
return embedding

View File

@ -20,7 +20,9 @@ from common.forms import BaseForm
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode, ModelInfoManage BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode, ModelInfoManage
from setting.models_provider.impl.ollama_model_provider.credential.embedding import OllamaEmbeddingModelCredential
from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential
from setting.models_provider.impl.ollama_model_provider.model.embedding import OllamaEmbedding
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
@ -88,14 +90,25 @@ model_info_list = [
ModelInfo( ModelInfo(
'phi3', 'phi3',
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。', 'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel) ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
]
ollama_embedding_model_credential = OllamaEmbeddingModelCredential()
embedding_model_info = [
ModelInfo(
'nomic-embed-text',
'一个具有大令牌上下文窗口的高性能开放嵌入模型。',
ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding),
] ]
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_model_info_list(
embedding_model_info).append_default_model_info(
ModelInfo( ModelInfo(
'phi3', 'phi3',
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。', 'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).build() ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).append_default_model_info(ModelInfo(
'nomic-embed-text',
'一个具有大令牌上下文窗口的高性能开放嵌入模型。',
ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding), ).build()
def get_base_url(url: str): def get_base_url(url: str):
@ -139,7 +152,6 @@ def convert(response_stream) -> Iterator[DownModelChunk]:
temp = "" temp = ""
if len(temp) > 0: if len(temp) > 0:
print(temp)
rows = [t for t in temp.split("\n") if len(t) > 0] rows = [t for t in temp.split("\n") if len(t) > 0]
for row in rows: for row in rows:
yield convert_to_down_model_chunk(row, index) yield convert_to_down_model_chunk(row, index)
@ -154,9 +166,6 @@ class OllamaModelProvider(IModelProvider):
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon', os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon',
'ollama_icon_svg'))) 'ollama_icon_svg')))
def get_dialogue_number(self):
return 2
@staticmethod @staticmethod
def get_base_model_list(api_base): def get_base_model_list(api_base):
base_url = get_base_url(api_base) base_url = get_base_url(api_base)
@ -165,7 +174,7 @@ class OllamaModelProvider(IModelProvider):
return r.json() return r.json()
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]: def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
api_base = model_credential.get('api_base') api_base = model_credential.get('api_base', '')
base_url = get_base_url(api_base) base_url = get_base_url(api_base)
r = requests.request( r = requests.request(
method="POST", method="POST",

View File

@ -0,0 +1,46 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/12 16:45
@desc:
"""
from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=True):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['api_base', 'api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.embed_query('你好')
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -0,0 +1,23 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/12 17:44
@desc:
"""
from typing import Dict
from langchain_community.embeddings import OpenAIEmbeddings
from setting.models_provider.base_model_provider import MaxKBBaseModel
class OpenAIEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return OpenAIEmbeddingModel(
api_key=model_credential.get('api_key'),
model=model_name,
openai_api_base=model_credential.get('api_base'),
)

View File

@ -11,7 +11,9 @@ import os
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \ from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
ModelTypeConst, ModelInfoManage ModelTypeConst, ModelInfoManage
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
@ -58,11 +60,17 @@ model_info_list = [
ModelTypeConst.LLM, openai_llm_model_credential, ModelTypeConst.LLM, openai_llm_model_credential,
OpenAIChatModel) OpenAIChatModel)
] ]
open_ai_embedding_credential = OpenAIEmbeddingCredential()
model_info_embedding_list = [
ModelInfo('text-embedding-ada-002', '',
ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
OpenAIEmbeddingModel)]
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo随OpenAI调整而更新', ModelTypeConst.LLM, ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential, OpenAIChatModel openai_llm_model_credential, OpenAIChatModel
)).build() )).append_model_info_list(model_info_embedding_list).append_default_model_info(
model_info_embedding_list[0]).build()
class OpenAIModelProvider(IModelProvider): class OpenAIModelProvider(IModelProvider):

View File

@ -7,12 +7,14 @@
@desc: @desc:
""" """
import json import json
import re
import threading import threading
import time import time
import uuid import uuid
from typing import Dict from typing import Dict
from django.db.models import QuerySet from django.core import validators
from django.db.models import QuerySet, Q
from rest_framework import serializers from rest_framework import serializers
from application.models import Application from application.models import Application
@ -36,6 +38,9 @@ class ModelPullManage:
for chunk in response: for chunk in response:
down_model_chunk[chunk.digest] = chunk.to_dict() down_model_chunk[chunk.digest] = chunk.to_dict()
if time.time() - timestamp > 5: if time.time() - timestamp > 5:
model_new = QuerySet(Model).filter(id=model.id).first()
if model_new.status == Status.PAUSE_DOWNLOAD:
return
QuerySet(Model).filter(id=model.id).update( QuerySet(Model).filter(id=model.id).update(
meta={"down_model_chunk": list(down_model_chunk.values())}) meta={"down_model_chunk": list(down_model_chunk.values())})
timestamp = time.time() timestamp = time.time()
@ -72,7 +77,7 @@ class ModelSerializer(serializers.Serializer):
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
user_id = self.data.get('user_id') user_id = self.data.get('user_id')
name = self.data.get('name') name = self.data.get('name')
model_query_set = QuerySet(Model).filter(user_id=user_id) model_query_set = QuerySet(Model).filter((Q(user_id=user_id) | Q(permission_type='PUBLIC')))
query_params = {} query_params = {}
if name is not None: if name is not None:
query_params['name__contains'] = name query_params['name__contains'] = name
@ -85,7 +90,8 @@ class ModelSerializer(serializers.Serializer):
return [ return [
{'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name, 'status': model.status, 'meta': model.meta} for model in 'model_name': model.model_name, 'status': model.status, 'meta': model.meta,
'permission_type': model.permission_type} for model in
model_query_set.filter(**query_params).order_by("-create_time")] model_query_set.filter(**query_params).order_by("-create_time")]
class Edit(serializers.Serializer): class Edit(serializers.Serializer):
@ -96,6 +102,11 @@ class ModelSerializer(serializers.Serializer):
model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型")) model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
permission_type = serializers.CharField(required=False, error_messages=ErrMessage.char("权限"), validators=[
validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"),
message="权限只支持PUBLIC|PRIVATE", code=500)
])
model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型")) model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
credential = serializers.DictField(required=False, error_messages=ErrMessage.dict("认证信息")) credential = serializers.DictField(required=False, error_messages=ErrMessage.dict("认证信息"))
@ -135,6 +146,11 @@ class ModelSerializer(serializers.Serializer):
model_type = serializers.CharField(required=True, error_messages=ErrMessage.char("模型类型")) model_type = serializers.CharField(required=True, error_messages=ErrMessage.char("模型类型"))
permission_type = serializers.CharField(required=True, error_messages=ErrMessage.char("权限"), validators=[
validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"),
message="权限只支持PUBLIC|PRIVATE", code=500)
])
model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型")) model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型"))
credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息")) credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息"))
@ -165,10 +181,12 @@ class ModelSerializer(serializers.Serializer):
provider = self.data.get('provider') provider = self.data.get('provider')
model_type = self.data.get('model_type') model_type = self.data.get('model_type')
model_name = self.data.get('model_name') model_name = self.data.get('model_name')
permission_type = self.data.get('permission_type')
model_credential_str = json.dumps(credential) model_credential_str = json.dumps(credential)
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name, model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
credential=rsa_long_encrypt(model_credential_str), credential=rsa_long_encrypt(model_credential_str),
provider=provider, model_type=model_type, model_name=model_name) provider=provider, model_type=model_type, model_name=model_name,
permission_type=permission_type)
model.save() model.save()
if status == Status.DOWNLOAD: if status == Status.DOWNLOAD:
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential)) thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
@ -184,7 +202,8 @@ class ModelSerializer(serializers.Serializer):
'meta': model.meta, 'meta': model.meta,
'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type, 'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
model.model_name).encryption_dict( model.model_name).encryption_dict(
credential)} credential),
'permission_type': model.permission_type}
class Operate(serializers.Serializer): class Operate(serializers.Serializer):
id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
@ -210,7 +229,8 @@ class ModelSerializer(serializers.Serializer):
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name, 'model_name': model.model_name,
'status': model.status, 'status': model.status,
'meta': model.meta, } 'meta': model.meta
}
def delete(self, with_valid=True): def delete(self, with_valid=True):
if with_valid: if with_valid:
@ -221,6 +241,12 @@ class ModelSerializer(serializers.Serializer):
QuerySet(Model).filter(id=self.data.get('id')).delete() QuerySet(Model).filter(id=self.data.get('id')).delete()
return True return True
def pause_download(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
QuerySet(Model).filter(id=self.data.get('id')).update(status=Status.PAUSE_DOWNLOAD)
return True
def edit(self, instance: Dict, user_id: str, with_valid=True): def edit(self, instance: Dict, user_id: str, with_valid=True):
if with_valid: if with_valid:
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
@ -245,7 +271,7 @@ class ModelSerializer(serializers.Serializer):
model.status = Status.DOWNLOAD model.status = Status.DOWNLOAD
else: else:
raise e raise e
update_keys = ['credential', 'name', 'model_type', 'model_name'] update_keys = ['credential', 'name', 'model_type', 'model_name', 'permission_type']
for update_key in update_keys: for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None: if update_key in instance and instance.get(update_key) is not None:
if update_key == 'credential': if update_key == 'credential':

View File

@ -74,6 +74,8 @@ class ModelCreateApi(ApiMixin):
'provider': openapi.Schema(type=openapi.TYPE_STRING, 'provider': openapi.Schema(type=openapi.TYPE_STRING,
title="供应商", title="供应商",
description="供应商"), description="供应商"),
'permission_type': openapi.Schema(type=openapi.TYPE_STRING, title="权限",
description="PUBLIC|PRIVATE"),
'model_type': openapi.Schema(type=openapi.TYPE_STRING, 'model_type': openapi.Schema(type=openapi.TYPE_STRING,
title="供应商", title="供应商",
description="供应商"), description="供应商"),
@ -82,7 +84,8 @@ class ModelCreateApi(ApiMixin):
description="供应商"), description="供应商"),
'credential': openapi.Schema(type=openapi.TYPE_OBJECT, 'credential': openapi.Schema(type=openapi.TYPE_OBJECT,
title="模型证书信息", title="模型证书信息",
description="模型证书信息") description="模型证书信息"),
} }
) )

View File

@ -16,6 +16,7 @@ urlpatterns = [
name="provider/model_form"), name="provider/model_form"),
path('model', views.Model.as_view(), name='model'), path('model', views.Model.as_view(), name='model'),
path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'), path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
path('model/<str:model_id>/pause_download', views.Model.PauseDownload.as_view(), name='model/operate'),
path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'), path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),
path('email_setting', views.SystemSetting.Email.as_view(), name='email_setting'), path('email_setting', views.SystemSetting.Email.as_view(), name='email_setting'),
path('valid/<str:valid_type>/<int:valid_count>', views.Valid.as_view()) path('valid/<str:valid_type>/<int:valid_count>', views.Valid.as_view())

View File

@ -69,6 +69,18 @@ class Model(APIView):
return result.success( return result.success(
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one_meta(with_valid=True)) ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one_meta(with_valid=True))
class PauseDownload(APIView):
authentication_classes = [TokenAuth]
@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="暂停模型下载",
operation_id="暂停模型下载",
tags=["模型"])
@has_permissions(PermissionConstants.MODEL_CREATE)
def put(self, request: Request, model_id: str):
return result.success(
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).pause_download())
class Operate(APIView): class Operate(APIView):
authentication_classes = [TokenAuth] authentication_classes = [TokenAuth]

View File

@ -104,9 +104,6 @@ CACHES = {
"token_cache": { "token_cache": {
'BACKEND': 'common.cache.file_cache.FileCache', 'BACKEND': 'common.cache.file_cache.FileCache',
'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache") # 文件夹路径 'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache") # 文件夹路径
},
"chat_cache": {
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
} }
} }

View File

@ -129,11 +129,12 @@ const getDatasetDetail: (dataset_id: string, loading?: Ref<boolean>) => Promise<
"desc": true "desc": true
} }
*/ */
const putDataset: (dataset_id: string, data: any) => Promise<Result<any>> = ( const putDataset: (
dataset_id, dataset_id: string,
data: any data: any,
) => { loading?: Ref<boolean>
return put(`${prefix}/${dataset_id}`, data) ) => Promise<Result<any>> = (dataset_id, data, loading) => {
return put(`${prefix}/${dataset_id}`, data, undefined, loading)
} }
/** /**
* *

View File

@ -130,7 +130,18 @@ const getModelMetaById: (model_id: string, loading?: Ref<boolean>) => Promise<Re
) => { ) => {
return get(`${prefix}/${model_id}/meta`, {}, loading) return get(`${prefix}/${model_id}/meta`, {}, loading)
} }
/**
*
* @param model_id id
* @param loading
* @returns
*/
const pauseDownload: (model_id: string, loading?: Ref<boolean>) => Promise<Result<boolean>> = (
model_id,
loading
) => {
return put(`${prefix}/${model_id}/pause_download`, undefined, {}, loading)
}
const deleteModel: (model_id: string, loading?: Ref<boolean>) => Promise<Result<boolean>> = ( const deleteModel: (model_id: string, loading?: Ref<boolean>) => Promise<Result<boolean>> = (
model_id, model_id,
loading loading
@ -147,5 +158,6 @@ export default {
updateModel, updateModel,
deleteModel, deleteModel,
getModelById, getModelById,
getModelMetaById getModelMetaById,
pauseDownload
} }

View File

@ -3,6 +3,7 @@ interface datasetData {
desc: String desc: String
documents?: Array<any> documents?: Array<any>
type?: String type?: String
embedding_mode_id?: String
} }
export type { datasetData } export type { datasetData }

View File

@ -53,6 +53,7 @@ interface Model {
* *
*/ */
model_type: string model_type: string
permission_type: 'PUBLIC' | 'PRIVATE'
/** /**
* *
*/ */
@ -68,7 +69,7 @@ interface Model {
/** /**
* *
*/ */
status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR' status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR' | 'PAUSE_DOWNLOAD'
/** /**
* *
*/ */

View File

@ -18,7 +18,8 @@
</slot> </slot>
<slot></slot> <slot></slot>
</div> </div>
<el-checkbox v-bind:modelValue="modelValue.includes(toModelValue)"> </el-checkbox> <el-checkbox v-bind:modelValue="modelValue.includes(toModelValue)" @change="checkboxChange">
</el-checkbox>
</div> </div>
</el-card> </el-card>
</template> </template>
@ -40,7 +41,7 @@ const toModelValue = computed(() => (props.valueField ? props.data[props.valueFi
// set: (val) => val // set: (val) => val
// }) // })
const emit = defineEmits(['update:modelValue']) const emit = defineEmits(['update:modelValue', 'change'])
const checked = () => { const checked = () => {
const value = props.modelValue ? props.modelValue : [] const value = props.modelValue ? props.modelValue : []
@ -53,6 +54,10 @@ const checked = () => {
emit('update:modelValue', [...value, toModelValue.value]) emit('update:modelValue', [...value, toModelValue.value])
} }
} }
function checkboxChange() {
emit('change')
}
</script> </script>
<style lang="scss" scoped> <style lang="scss" scoped>
.card-checkbox { .card-checkbox {

View File

@ -0,0 +1,93 @@
<template>
<div class="loading-container loader">
<div class="download-loading">
<div></div>
<div></div>
<div></div>
<div></div>
<div></div>
<div></div>
<div></div>
<div></div>
</div>
</div>
</template>
<script setup lang="ts"></script>
<style lang="scss" scoped>
.loading-container {
display: -webkit-flex; /*safari弹性布局*/
justify-content: center;
display: flex;
align-items: center;
}
@-webkit-keyframes loader {
0% {
opacity: 0.3;
}
80% {
opacity: 1;
}
100% {
opacity: 1;
}
}
.download-loading {
position: relative;
}
.download-loading div {
width: 5px;
height: 12px;
background: var(--el-color-info);
position: absolute;
border-radius: 2px;
margin: 0 auto;
}
.download-loading div:nth-child(1) {
top: -20px;
left: 0;
-webkit-animation: loader 1s -0.8s infinite ease-in-out;
}
.download-loading div:nth-child(2) {
top: -13px;
left: 13px;
-webkit-transform: rotate(45deg);
-webkit-animation: loader 1s -0.6s infinite ease-in-out;
}
.download-loading div:nth-child(3) {
top: 0px;
left: 20px;
-webkit-transform: rotate(90deg);
-webkit-animation: loader 1s -0.5s infinite ease-in-out;
}
.download-loading div:nth-child(4) {
top: 13px;
left: 13px;
-webkit-transform: rotate(-45deg);
-webkit-animation: loader 1s -0.4s infinite ease-in-out;
}
.download-loading div:nth-child(5) {
top: 20px;
left: 0px;
-webkit-transform: rotate(0deg);
-webkit-animation: loader 1s -0.3s infinite ease-in-out;
}
.download-loading div:nth-child(6) {
top: 13px;
left: -13px;
-webkit-transform: rotate(45deg);
-webkit-animation: loader 1s -0.2s infinite ease-in-out;
}
.download-loading div:nth-child(7) {
top: 0px;
left: -20px;
-webkit-transform: rotate(90deg);
-webkit-animation: loader 1s -0.1s infinite ease-in-out;
}
.download-loading div:nth-child(8) {
top: -13px;
left: -13px;
-webkit-transform: rotate(-45deg);
-webkit-animation: loader 1s 0s infinite ease-in-out;
}
</style>

8
ui/src/enums/model.ts Normal file
View File

@ -0,0 +1,8 @@
export enum PermissionType {
PRIVATE = '私有',
PUBLIC = '公用'
}
export enum PermissionDesc {
PRIVATE = '仅自己使用',
PUBLIC = '所有用户都可使用,不能编辑'
}

View File

@ -99,7 +99,7 @@
</div> </div>
</template> </template>
<template v-else-if="isDataset"> <template v-else-if="isDataset">
<div class="w-full text-left cursor" @click="router.push({ path: '/dataset/create' })"> <div class="w-full text-left cursor" @click="openCreateDialog">
<el-button link> <el-button link>
<el-icon class="mr-4"><Plus /></el-icon> <el-icon class="mr-4"><Plus /></el-icon>
</el-button> </el-button>
@ -110,12 +110,14 @@
</el-dropdown> </el-dropdown>
</div> </div>
<CreateApplicationDialog ref="CreateApplicationDialogRef" @refresh="refresh" /> <CreateApplicationDialog ref="CreateApplicationDialogRef" @refresh="refresh" />
<CreateDatasetDialog ref="CreateDatasetDialogRef" @refresh="refresh" />
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref, onMounted, computed } from 'vue' import { ref, onMounted, computed } from 'vue'
import { onBeforeRouteLeave, useRouter, useRoute } from 'vue-router' import { onBeforeRouteLeave, useRouter, useRoute } from 'vue-router'
import CreateApplicationDialog from '@/views/application/component/CreateApplicationDialog.vue' import CreateApplicationDialog from '@/views/application/component/CreateApplicationDialog.vue'
import CreateDatasetDialog from '@/views/dataset/component/CreateDatasetDialog.vue'
import { isAppIcon, isWorkFlow } from '@/utils/application' import { isAppIcon, isWorkFlow } from '@/utils/application'
import useStore from '@/stores' import useStore from '@/stores'
const { common, dataset, application } = useStore() const { common, dataset, application } = useStore()
@ -130,6 +132,7 @@ onBeforeRouteLeave((to, from) => {
common.saveBreadcrumb(null) common.saveBreadcrumb(null)
}) })
const CreateDatasetDialogRef = ref()
const CreateApplicationDialogRef = ref() const CreateApplicationDialogRef = ref()
const list = ref<any[]>([]) const list = ref<any[]>([])
const loading = ref(false) const loading = ref(false)
@ -148,7 +151,11 @@ const isDataset = computed(() => {
}) })
function openCreateDialog() { function openCreateDialog() {
if (isDataset.value) {
CreateDatasetDialogRef.value.open()
} else if (isApplication.value) {
CreateApplicationDialogRef.value.open() CreateApplicationDialogRef.value.open()
}
} }
function changeMenu(id: string) { function changeMenu(id: string) {

View File

@ -13,9 +13,9 @@ const datasetRouter = {
}, },
{ {
path: '/dataset/:type', // create 或者 upload path: '/dataset/:type', // create 或者 upload
name: 'CreateDataset', name: 'UploadDocumentDataset',
meta: { activeMenu: '/dataset' }, meta: { activeMenu: '/dataset' },
component: () => import('@/views/dataset/CreateDataset.vue'), component: () => import('@/views/dataset/UploadDocumentDataset.vue'),
hidden: true hidden: true
}, },
{ {

View File

@ -1,11 +1,11 @@
import { defineStore } from 'pinia' import { defineStore } from 'pinia'
import modelApi from '@/api/model' import modelApi from '@/api/model'
import type { modelRequest, Provider } from '@/api/type/model' import type { ListModelRequest, Provider } from '@/api/type/model'
const useModelStore = defineStore({ const useModelStore = defineStore({
id: 'model', id: 'model',
state: () => ({}), state: () => ({}),
actions: { actions: {
async asyncGetModel(data?: modelRequest) { async asyncGetModel(data?: ListModelRequest) {
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
modelApi modelApi
.getModel(data) .getModel(data)

View File

@ -377,6 +377,11 @@ h5 {
color: var(--el-color-primary); color: var(--el-color-primary);
border: none; border: none;
} }
.danger-tag {
background: var(--tag-danger-bg);
color: #d03f3b;
border: none;
}
.success-tag { .success-tag {
background: var(--tag-success-bg); background: var(--tag-success-bg);
color: var(--el-color-success); color: var(--el-color-success);
@ -388,6 +393,12 @@ h5 {
border: none; border: none;
} }
.info-tag {
background: var(--app-text-color-light-1);
color: var(--app-text-color-secondary);
border: none;
}
.purple-tag { .purple-tag {
background: #f2ebfe; background: #f2ebfe;
color: #7f3bf5; color: #7f3bf5;

View File

@ -30,6 +30,7 @@
--tag-success-color: #2ca91f; --tag-success-color: #2ca91f;
--tag-warning-bg: rgba(255, 136, 0, 0.2); --tag-warning-bg: rgba(255, 136, 0, 0.2);
--tag-warning-color: #d97400; --tag-warning-color: #d97400;
--tag-danger-bg: rgba(245, 74, 69, 0.2);
/** card */ /** card */
--card-width: 330px; --card-width: 330px;

View File

@ -4,21 +4,28 @@
v-model="dialogVisible" v-model="dialogVisible"
width="600" width="600"
append-to-body append-to-body
class="addDataset-dialog"
> >
<template #header="{ titleId, titleClass }"> <template #header="{ titleId, titleClass }">
<div class="my-header flex"> <div class="flex-between mb-8">
<h4 :id="titleId" :class="titleClass"> <h4 :id="titleId" :class="titleClass">
{{ $t('views.application.applicationForm.dialogues.addDataset') }} {{ $t('views.application.applicationForm.dialogues.addDataset') }}
</h4> </h4>
<div class="flex align-center">
<el-button link class="ml-16" @click="refresh"> <el-button link class="ml-16" @click="refresh">
<el-icon class="mr-4"><Refresh /></el-icon <el-icon class="mr-4"><Refresh /></el-icon
>{{ $t('views.application.applicationForm.dialogues.refresh') }} >{{ $t('views.application.applicationForm.dialogues.refresh') }}
</el-button> </el-button>
<el-divider direction="vertical" />
</div> </div>
</div>
<el-text type="info" class="color-secondary">
所选知识库必须使用相同的 Embedding 模型
</el-text>
</template> </template>
<el-row :gutter="12" v-loading="loading"> <el-row :gutter="12" v-loading="loading">
<el-col :span="12" v-for="(item, index) in data" :key="index" class="mb-16"> <el-col :span="12" v-for="(item, index) in filterData" :key="index" class="mb-16">
<CardCheckbox value-field="id" :data="item" v-model="checkList"> <CardCheckbox value-field="id" :data="item" v-model="checkList" @change="changeHandle">
<span class="ellipsis"> <span class="ellipsis">
{{ item.name }} {{ item.name }}
</span> </span>
@ -26,7 +33,16 @@
</el-col> </el-col>
</el-row> </el-row>
<template #footer> <template #footer>
<span class="dialog-footer"> <div class="flex-between">
<div>
<el-text type="info" class="color-secondary" v-if="checkList.length > 0">
已选 {{ checkList.length }} 个知识库
</el-text>
<el-button link type="primary" v-if="checkList.length > 0" @click="clearCheck">
清空
</el-button>
</div>
<span>
<el-button @click.prevent="dialogVisible = false"> <el-button @click.prevent="dialogVisible = false">
{{ $t('views.application.applicationForm.buttons.cancel') }} {{ $t('views.application.applicationForm.buttons.cancel') }}
</el-button> </el-button>
@ -34,11 +50,12 @@
{{ $t('views.application.applicationForm.buttons.confirm') }} {{ $t('views.application.applicationForm.buttons.confirm') }}
</el-button> </el-button>
</span> </span>
</div>
</template> </template>
</el-dialog> </el-dialog>
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref, watch } from 'vue' import { computed, ref, watch } from 'vue'
const props = defineProps({ const props = defineProps({
data: { data: {
type: Array<any>, type: Array<any>,
@ -51,6 +68,13 @@ const emit = defineEmits(['addData', 'refresh'])
const dialogVisible = ref<boolean>(false) const dialogVisible = ref<boolean>(false)
const checkList = ref([]) const checkList = ref([])
const currentEmbedding = ref('')
const filterData = computed(() => {
return currentEmbedding.value
? props.data.filter((v) => v.embedding_mode_id === currentEmbedding.value)
: props.data
})
watch(dialogVisible, (bool) => { watch(dialogVisible, (bool) => {
if (!bool) { if (!bool) {
@ -58,6 +82,18 @@ watch(dialogVisible, (bool) => {
} }
}) })
function changeHandle() {
if (checkList.value.length === 1) {
currentEmbedding.value = props.data.filter(
(v) => v.id === checkList.value[0]
)[0].embedding_mode_id
}
}
function clearCheck() {
checkList.value = []
currentEmbedding.value = ''
}
const open = (checked: any) => { const open = (checked: any) => {
checkList.value = checked checkList.value = checked
dialogVisible.value = true dialogVisible.value = true
@ -73,4 +109,13 @@ const refresh = () => {
defineExpose({ open }) defineExpose({ open })
</script> </script>
<style lang="scss" scope></style> <style lang="scss" scope>
.addDataset-dialog {
.el-dialog__header.show-close {
padding-right: 15px;
}
.el-dialog__headerbtn {
top: 13px;
}
}
</style>

View File

@ -63,10 +63,10 @@
</el-form> </el-form>
<template #footer> <template #footer>
<span class="dialog-footer"> <span class="dialog-footer">
<el-button @click.prevent="dialogVisible = false"> <el-button @click.prevent="dialogVisible = false" :loading="loading">
{{ $t('views.application.applicationForm.buttons.cancel') }} {{ $t('views.application.applicationForm.buttons.cancel') }}
</el-button> </el-button>
<el-button type="primary" @click="submitValid(applicationFormRef)"> <el-button type="primary" @click="submitValid(applicationFormRef)" :loading="loading">
{{ $t('views.application.applicationForm.buttons.create') }} {{ $t('views.application.applicationForm.buttons.create') }}
</el-button> </el-button>
</span> </span>
@ -183,10 +183,7 @@ const submitValid = (formEl: FormInstance | undefined) => {
if (res?.data) { if (res?.data) {
submitHandle(formEl) submitHandle(formEl)
} else { } else {
MsgAlert( MsgAlert('提示', '社区版最多支持 5 个应用,如需拥有更多应用,请升级为专业版。')
'提示',
'社区版最多支持 5 个应用如需拥有更多应用请联系我们https://fit2cloud.com/)。'
)
} }
}) })
} }

View File

@ -1,5 +1,5 @@
<template> <template>
<div class="authentication-setting p-24"> <div class="authentication-setting p-16-24">
<h4>{{ $t('login.authentication') }}</h4> <h4>{{ $t('login.authentication') }}</h4>
<el-tabs v-model="activeName" class="demo-tabs" @tab-click="handleClick"> <el-tabs v-model="activeName" class="demo-tabs" @tab-click="handleClick">
<template v-for="(item, index) in tabList" :key="index"> <template v-for="(item, index) in tabList" :key="index">
@ -38,7 +38,7 @@ onMounted(() => {})
background-color: var(--app-view-bg-color); background-color: var(--app-view-bg-color);
box-sizing: border-box; box-sizing: border-box;
min-width: 700px; min-width: 700px;
height: calc(100vh - var(--app-header-height) - var(--app-view-padding) * 2 - 80px); height: calc(100vh - var(--app-header-height) - var(--app-view-padding) * 2 - 70px);
box-sizing: border-box; box-sizing: border-box;
.form-container { .form-container {
width: 70%; width: 70%;

View File

@ -3,6 +3,7 @@
<div class="dataset-setting main-calc-height"> <div class="dataset-setting main-calc-height">
<el-scrollbar> <el-scrollbar>
<div class="p-24" v-loading="loading"> <div class="p-24" v-loading="loading">
<h4 class="title-decoration-1 mb-16">基本信息</h4>
<BaseForm ref="BaseFormRef" :data="detail" /> <BaseForm ref="BaseFormRef" :data="detail" />
<el-form <el-form
@ -104,7 +105,7 @@ import { useRoute } from 'vue-router'
import BaseForm from '@/views/dataset/component/BaseForm.vue' import BaseForm from '@/views/dataset/component/BaseForm.vue'
import datasetApi from '@/api/dataset' import datasetApi from '@/api/dataset'
import type { ApplicationFormType } from '@/api/type/application' import type { ApplicationFormType } from '@/api/type/application'
import { MsgSuccess } from '@/utils/message' import { MsgSuccess, MsgConfirm } from '@/utils/message'
import { isAppIcon } from '@/utils/application' import { isAppIcon } from '@/utils/application'
import useStore from '@/stores' import useStore from '@/stores'
const route = useRoute() const route = useRoute()
@ -119,6 +120,8 @@ const loading = ref(false)
const detail = ref<any>({}) const detail = ref<any>({})
const application_list = ref<Array<ApplicationFormType>>([]) const application_list = ref<Array<ApplicationFormType>>([])
const application_id_list = ref([]) const application_id_list = ref([])
const cloneModelId = ref('')
const form = ref<any>({ const form = ref<any>({
source_url: '', source_url: '',
selector: '' selector: ''
@ -132,7 +135,6 @@ async function submit() {
if (await BaseFormRef.value?.validate()) { if (await BaseFormRef.value?.validate()) {
await webFormRef.value.validate((valid: any) => { await webFormRef.value.validate((valid: any) => {
if (valid) { if (valid) {
loading.value = true
const obj = const obj =
detail.value.type === '1' detail.value.type === '1'
? { ? {
@ -144,15 +146,25 @@ async function submit() {
application_id_list: application_id_list.value, application_id_list: application_id_list.value,
...BaseFormRef.value.form ...BaseFormRef.value.form
} }
datasetApi
.putDataset(id, obj) if (cloneModelId.value !== BaseFormRef.value.form.embedding_mode_id) {
.then((res) => { MsgConfirm(`提示`, `修改知识库向量模型后,需要对知识库重新向量化,是否继续保存?`, {
confirmButtonText: '重新向量化',
confirmButtonClass: 'primary'
})
.then(() => {
datasetApi.putDataset(id, obj, loading).then((res) => {
datasetApi.putReEmbeddingDataset(id).then(() => {
MsgSuccess('保存成功') MsgSuccess('保存成功')
loading.value = false
}) })
.catch(() => {
loading.value = false
}) })
})
.catch(() => {})
} else {
datasetApi.putDataset(id, obj, loading).then((res) => {
MsgSuccess('保存成功')
})
}
} }
}) })
} }
@ -161,10 +173,10 @@ async function submit() {
function getDetail() { function getDetail() {
dataset.asyncGetDatasetDetail(id, loading).then((res: any) => { dataset.asyncGetDatasetDetail(id, loading).then((res: any) => {
detail.value = res.data detail.value = res.data
cloneModelId.value = res.data?.embedding_mode_id
if (detail.value.type === '1') { if (detail.value.type === '1') {
form.value = res.data.meta form.value = res.data.meta
} }
application_id_list.value = res.data?.application_id_list application_id_list.value = res.data?.application_id_list
datasetApi.listUsableApplication(id, loading).then((ok) => { datasetApi.listUsableApplication(id, loading).then((ok) => {
application_list.value = ok.data application_list.value = ok.data

View File

@ -1,15 +1,18 @@
<template> <template>
<LayoutContainer :header="isCreate ? '创建知识库' : '上传文档'" class="create-dataset"> <LayoutContainer header="上传文档" class="create-dataset">
<template #backButton> <template #backButton>
<back-button @click="back"></back-button> <back-button @click="back"></back-button>
</template> </template>
<div class="create-dataset__main flex" v-loading="loading"> <div class="create-dataset__main flex" v-loading="loading">
<div class="create-dataset__component main-calc-height"> <div class="create-dataset__component main-calc-height">
<template v-if="active === 0"> <template v-if="active === 0">
<StepFirst ref="StepFirstRef" /> <div class="upload-document p-24">
<!-- 上传文档 -->
<UploadComponent ref="UploadComponentRef" />
</div>
</template> </template>
<template v-else-if="active === 1"> <template v-else-if="active === 1">
<StepSecond ref="StepSecondRef" /> <SetRules ref="SetRulesRef" />
</template> </template>
<template v-else-if="active === 2"> <template v-else-if="active === 2">
<ResultSuccess :data="successInfo" /> <ResultSuccess :data="successInfo" />
@ -19,12 +22,7 @@
<div class="create-dataset__footer text-right border-t" v-if="active !== 2"> <div class="create-dataset__footer text-right border-t" v-if="active !== 2">
<el-button @click="router.go(-1)" :disabled="loading">取消</el-button> <el-button @click="router.go(-1)" :disabled="loading">取消</el-button>
<el-button @click="prev" v-if="active === 1" :disabled="loading">上一步</el-button> <el-button @click="prev" v-if="active === 1" :disabled="loading">上一步</el-button>
<el-button <el-button @click="next" type="primary" v-if="active === 0" :disabled="loading">
@click="next"
type="primary"
v-if="active === 0"
:disabled="loading || StepFirstRef?.loading"
>
创建并导入 创建并导入
</el-button> </el-button>
<el-button @click="submit" type="primary" v-if="active === 1" :disabled="loading"> <el-button @click="submit" type="primary" v-if="active === 1" :disabled="loading">
@ -36,9 +34,9 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed, onUnmounted } from 'vue' import { ref, computed, onUnmounted } from 'vue'
import { useRouter, useRoute } from 'vue-router' import { useRouter, useRoute } from 'vue-router'
import StepFirst from './step/StepFirst.vue' import SetRules from './component/SetRules.vue'
import StepSecond from './step/StepSecond.vue' import ResultSuccess from './component/ResultSuccess.vue'
import ResultSuccess from './step/ResultSuccess.vue' import UploadComponent from './component/UploadComponent.vue'
import datasetApi from '@/api/dataset' import datasetApi from '@/api/dataset'
import documentApi from '@/api/document' import documentApi from '@/api/document'
import type { datasetData } from '@/api/type/dataset' import type { datasetData } from '@/api/type/dataset'
@ -46,33 +44,17 @@ import { MsgConfirm, MsgSuccess } from '@/utils/message'
import useStore from '@/stores' import useStore from '@/stores'
const { dataset, document } = useStore() const { dataset, document } = useStore()
const baseInfo = computed(() => dataset.baseInfo)
const webInfo = computed(() => dataset.webInfo)
const documentsFiles = computed(() => dataset.documentsFiles) const documentsFiles = computed(() => dataset.documentsFiles)
const documentsType = computed(() => dataset.documentsType) const documentsType = computed(() => dataset.documentsType)
const router = useRouter() const router = useRouter()
const route = useRoute() const route = useRoute()
const { const {
params: { type },
query: { id } // iddatasetIDid query: { id } // iddatasetIDid
} = route } = route
const isCreate = type === 'create'
// const steps = [
// {
// ref: 'StepFirstRef',
// name: '',
// component: StepFirst
// },
// {
// ref: 'StepSecondRef',
// name: '',
// component: StepSecond
// }
// ]
const StepFirstRef = ref() const SetRulesRef = ref()
const StepSecondRef = ref() const UploadComponentRef = ref()
const loading = ref(false) const loading = ref(false)
const disabled = ref(false) const disabled = ref(false)
@ -81,7 +63,7 @@ const successInfo = ref<any>(null)
async function next() { async function next() {
disabled.value = true disabled.value = true
if (await StepFirstRef.value?.onSubmit()) { if (await UploadComponentRef.value.validate()) {
if (documentsType.value === 'QA') { if (documentsType.value === 'QA') {
let fd = new FormData() let fd = new FormData()
documentsFiles.value.forEach((item: any) => { documentsFiles.value.forEach((item: any) => {
@ -118,16 +100,14 @@ const prev = () => {
} }
function clearStore() { function clearStore() {
dataset.saveBaseInfo(null)
dataset.saveWebInfo(null)
dataset.saveDocumentsFile([]) dataset.saveDocumentsFile([])
dataset.saveDocumentsType('') dataset.saveDocumentsType('')
} }
function submit() { function submit() {
loading.value = true loading.value = true
const documents = [] as any const documents = [] as any
StepSecondRef.value?.paragraphList.map((item: any) => { SetRulesRef.value?.paragraphList.map((item: any) => {
if (!StepSecondRef.value?.checkedConnect) { if (!SetRulesRef.value?.checkedConnect) {
item.content.map((v: any) => { item.content.map((v: any) => {
delete v['problem_list'] delete v['problem_list']
}) })
@ -159,7 +139,7 @@ function submit() {
} }
} }
function back() { function back() {
if (baseInfo.value || webInfo.value || documentsFiles.value?.length > 0) { if (documentsFiles.value?.length > 0) {
MsgConfirm(`提示`, `当前的更改尚未保存,确认退出吗?`, { MsgConfirm(`提示`, `当前的更改尚未保存,确认退出吗?`, {
confirmButtonText: '确认', confirmButtonText: '确认',
type: 'warning' type: 'warning'
@ -206,5 +186,10 @@ onUnmounted(() => {
width: 100%; width: 100%;
box-sizing: border-box; box-sizing: border-box;
} }
.upload-document {
width: 70%;
margin: 0 auto;
margin-bottom: 20px;
}
} }
</style> </style>

View File

@ -1,11 +1,11 @@
<template> <template>
<h4 class="title-decoration-1 mb-16">基本信息</h4>
<el-form <el-form
ref="FormRef" ref="FormRef"
:model="form" :model="form"
:rules="rules" :rules="rules"
label-position="top" label-position="top"
require-asterisk-position="right" require-asterisk-position="right"
v-loading="loading"
> >
<el-form-item label="知识库名称" prop="name"> <el-form-item label="知识库名称" prop="name">
<el-input <el-input
@ -27,14 +27,72 @@
@blur="form.desc = form.desc.trim()" @blur="form.desc = form.desc.trim()"
/> />
</el-form-item> </el-form-item>
<el-form-item label="Embedding模型" prop="embedding_mode_id">
<el-select
v-model="form.embedding_mode_id"
placeholder="请选择Embedding模型"
class="w-full"
popper-class="select-model"
:clearable="true"
>
<el-option-group
v-for="(value, label) in modelOptions"
:key="value"
:label="relatedObject(providerOptions, label, 'provider')?.name"
>
<el-option
v-for="item in value.filter((v: any) => v.status === 'SUCCESS')"
:key="item.id"
:label="item.name"
:value="item.id"
class="flex-between"
>
<div class="flex">
<span
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
class="model-icon mr-8"
></span>
<span>{{ item.name }}</span>
</div>
<el-icon class="check-icon" v-if="item.id === form.embedding_mode_id"
><Check
/></el-icon>
</el-option>
<!-- 不可用 -->
<el-option
v-for="item in value.filter((v: any) => v.status !== 'SUCCESS')"
:key="item.id"
:label="item.name"
:value="item.id"
class="flex-between"
disabled
>
<div class="flex">
<span
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
class="model-icon mr-8"
></span>
<span>{{ item.name }}</span>
<span class="danger">{{
$t('views.application.applicationForm.form.aiModel.unavailable')
}}</span>
</div>
<el-icon class="check-icon" v-if="item.id === form.embedding_mode_id"
><Check
/></el-icon>
</el-option>
</el-option-group>
</el-select>
</el-form-item>
</el-form> </el-form>
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref, reactive, onMounted, onUnmounted, computed, watch } from 'vue' import { ref, reactive, onMounted, onUnmounted, computed, watch } from 'vue'
import { useRoute } from 'vue-router' import { groupBy } from 'lodash'
import useStore from '@/stores' import useStore from '@/stores'
import type { datasetData } from '@/api/type/dataset' import type { datasetData } from '@/api/type/dataset'
import { isAllPropertiesEmpty } from '@/utils/utils' import { relatedObject } from '@/utils/utils'
import type { Provider } from '@/api/type/model'
const props = defineProps({ const props = defineProps({
data: { data: {
@ -42,23 +100,23 @@ const props = defineProps({
default: () => {} default: () => {}
} }
}) })
const route = useRoute() const { model } = useStore()
const {
params: { type }
} = route
const isCreate = type === 'create'
const { dataset } = useStore()
const baseInfo = computed(() => dataset.baseInfo)
const form = ref<datasetData>({ const form = ref<datasetData>({
name: '', name: '',
desc: '' desc: '',
embedding_mode_id: ''
}) })
const rules = reactive({ const rules = reactive({
name: [{ required: true, message: '请输入知识库名称', trigger: 'blur' }], name: [{ required: true, message: '请输入知识库名称', trigger: 'blur' }],
desc: [{ required: true, message: '请输入知识库描述', trigger: 'blur' }] desc: [{ required: true, message: '请输入知识库描述', trigger: 'blur' }],
embedding_mode_id: [{ required: true, message: '请输入Embedding模型', trigger: 'change' }]
}) })
const FormRef = ref() const FormRef = ref()
const loading = ref(false)
const modelOptions = ref<any>([])
const providerOptions = ref<Array<Provider>>([])
watch( watch(
() => props.data, () => props.data,
@ -66,23 +124,13 @@ watch(
if (value && JSON.stringify(value) !== '{}') { if (value && JSON.stringify(value) !== '{}') {
form.value.name = value.name form.value.name = value.name
form.value.desc = value.desc form.value.desc = value.desc
form.value.embedding_mode_id = value.embedding_mode_id
} }
}, },
{ {
immediate: true immediate: true
} }
) )
watch(form.value, (value) => {
if (isAllPropertiesEmpty(value)) {
dataset.saveBaseInfo(null)
} else {
if (isCreate) {
dataset.saveBaseInfo(value)
}
}
})
/* /*
表单校验 表单校验
*/ */
@ -93,16 +141,43 @@ function validate() {
}) })
} }
function getModel() {
loading.value = true
model
.asyncGetModel({ model_type: 'EMBEDDING' })
.then((res: any) => {
modelOptions.value = groupBy(res?.data, 'provider')
loading.value = false
})
.catch(() => {
loading.value = false
})
}
function getProvider() {
loading.value = true
model
.asyncGetProvider()
.then((res: any) => {
providerOptions.value = res?.data
loading.value = false
})
.catch(() => {
loading.value = false
})
}
onMounted(() => { onMounted(() => {
if (baseInfo.value) { getProvider()
form.value = baseInfo.value getModel()
}
}) })
onUnmounted(() => { onUnmounted(() => {
form.value = { form.value = {
name: '', name: '',
desc: '' desc: '',
embedding_mode_id: ''
} }
FormRef.value?.clearValidate()
}) })
defineExpose({ defineExpose({

View File

@ -0,0 +1,177 @@
<template>
<el-dialog title="创建知识库" v-model="dialogVisible" width="650" append-to-body>
<!-- 基本信息 -->
<BaseForm ref="BaseFormRef" v-if="dialogVisible" />
<el-form
ref="DatasetFormRef"
:rules="rules"
:model="datasetForm"
label-position="top"
require-asterisk-position="right"
>
<el-form-item label="知识库类型" required>
<el-radio-group v-model="datasetForm.type" class="card__radio" @change="radioChange">
<el-row :gutter="20">
<el-col :span="12">
<el-card
shadow="never"
class="mb-16"
:class="datasetForm.type === '0' ? 'active' : ''"
>
<el-radio value="0" size="large">
<div class="flex align-center">
<AppAvatar class="mr-8 avatar-blue" shape="square" :size="32">
<img src="@/assets/icon_document.svg" style="width: 58%" alt="" />
</AppAvatar>
<div>
<p class="mb-4">通用型</p>
<el-text type="info">可以通过上传文件或手动录入方式构建知识库</el-text>
</div>
</div>
</el-radio>
</el-card>
</el-col>
<el-col :span="12">
<el-card
shadow="never"
class="mb-16"
:class="datasetForm.type === '1' ? 'active' : ''"
>
<el-radio value="1" size="large">
<div class="flex align-center">
<AppAvatar class="mr-8 avatar-purple" shape="square" :size="32">
<img src="@/assets/icon_web.svg" style="width: 58%" alt="" />
</AppAvatar>
<div>
<p class="mb-4">Web 站点</p>
<el-text type="info">通过网站链接同步方式构建知识库 </el-text>
</div>
</div>
</el-radio>
</el-card>
</el-col>
</el-row>
</el-radio-group>
</el-form-item>
<el-form-item label="Web 根地址" prop="source_url" v-if="datasetForm.type === '1'">
<el-input
v-model="datasetForm.source_url"
placeholder="请输入 Web 根地址"
@blur="datasetForm.source_url = datasetForm.source_url.trim()"
/>
</el-form-item>
<el-form-item label="选择器" v-if="datasetForm.type === '1'">
<el-input
v-model="datasetForm.selector"
placeholder="默认为 body可输入 .classname/#idname/tagname"
@blur="datasetForm.selector = datasetForm.selector.trim()"
/>
</el-form-item>
</el-form>
<template #footer>
<span class="dialog-footer">
<el-button @click.prevent="dialogVisible = false" :loading="loading">
{{ $t('views.application.applicationForm.buttons.cancel') }}
</el-button>
<el-button type="primary" @click="submitValid" :loading="loading">
{{ $t('views.application.applicationForm.buttons.create') }}
</el-button>
</span>
</template>
</el-dialog>
</template>
<script setup lang="ts">
import { ref, watch, reactive } from 'vue'
import { useRouter, useRoute } from 'vue-router'
import BaseForm from './BaseForm.vue'
import datasetApi from '@/api/dataset'
import { MsgSuccess, MsgAlert } from '@/utils/message'
import useStore from '@/stores'
import { ValidType, ValidCount } from '@/enums/common'
const emit = defineEmits(['refresh'])
const { common, user } = useStore()
const router = useRouter()
const BaseFormRef = ref()
const DatasetFormRef = ref()
const loading = ref(false)
const dialogVisible = ref<boolean>(false)
const datasetForm = ref<any>({
type: '0',
source_url: '',
selector: ''
})
const rules = reactive({
source_url: [{ required: true, message: '请输入 Web 根地址', trigger: 'blur' }]
})
watch(dialogVisible, (bool) => {
if (!bool) {
datasetForm.value = {
type: '0',
source_url: '',
selector: ''
}
DatasetFormRef.value?.clearValidate()
}
})
const open = () => {
dialogVisible.value = true
}
const submitValid = () => {
if (user.isEnterprise()) {
submitHandle()
} else {
common.asyncGetValid(ValidType.Dataset, ValidCount.Dataset, loading).then(async (res: any) => {
if (res?.data) {
submitHandle()
} else {
MsgAlert('提示', '社区版最多支持 50 个知识库,如需拥有更多知识库,请升级为专业版。')
}
})
}
}
const submitHandle = async () => {
if (await BaseFormRef.value?.validate()) {
await DatasetFormRef.value.validate((valid: any) => {
if (valid) {
if (datasetForm.value.type === '0') {
const obj = {
...BaseFormRef.value.form,
type: datasetForm.value.type
}
datasetApi.postDataset(obj, loading).then((res) => {
MsgSuccess('创建成功')
router.push({ path: `/dataset/${res.data.id}/document` })
emit('refresh')
})
} else {
const obj = { ...BaseFormRef.value.form, ...datasetForm.value }
datasetApi.postWebDataset(obj, loading).then((res) => {
MsgSuccess('创建成功')
router.push({ path: `/dataset/${res.data.id}/document` })
emit('refresh')
})
}
} else {
return false
}
})
} else {
return false
}
}
function radioChange() {
datasetForm.value.source_url = ''
datasetForm.value.selector = ''
}
defineExpose({ open })
</script>
<style lang="scss" scope></style>

View File

@ -65,7 +65,7 @@
show-input show-input
:show-input-controls="false" :show-input-controls="false"
:min="50" :min="50"
:max="4096" :max="100000"
/> />
</div> </div>
<div class="form-item mb-16"> <div class="form-item mb-16">

View File

@ -22,7 +22,7 @@
> >
<el-row :gutter="15"> <el-row :gutter="15">
<el-col :xs="24" :sm="12" :md="8" :lg="6" :xl="4" class="mb-16"> <el-col :xs="24" :sm="12" :md="8" :lg="6" :xl="4" class="mb-16">
<CardAdd title="创建知识库" @click="router.push({ path: '/dataset/create' })" /> <CardAdd title="创建知识库" @click="openCreateDialog" />
</el-col> </el-col>
<template v-for="(item, index) in datasetList" :key="index"> <template v-for="(item, index) in datasetList" :key="index">
<el-col :xs="24" :sm="12" :md="8" :lg="6" :xl="4" class="mb-16"> <el-col :xs="24" :sm="12" :md="8" :lg="6" :xl="4" class="mb-16">
@ -107,17 +107,20 @@
</InfiniteScroll> </InfiniteScroll>
</div> </div>
<SyncWebDialog ref="SyncWebDialogRef" @refresh="refresh" /> <SyncWebDialog ref="SyncWebDialogRef" @refresh="refresh" />
<CreateDatasetDialog ref="CreateDatasetDialogRef"/>
</div> </div>
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref, onMounted, reactive, computed } from 'vue' import { ref, onMounted, reactive, computed } from 'vue'
import SyncWebDialog from '@/views/dataset/component/SyncWebDialog.vue' import SyncWebDialog from '@/views/dataset/component/SyncWebDialog.vue'
import CreateDatasetDialog from './component/CreateDatasetDialog.vue'
import datasetApi from '@/api/dataset' import datasetApi from '@/api/dataset'
import { MsgSuccess, MsgConfirm } from '@/utils/message' import { MsgSuccess, MsgConfirm } from '@/utils/message'
import { useRouter } from 'vue-router' import { useRouter } from 'vue-router'
import { numberFormat } from '@/utils/utils' import { numberFormat } from '@/utils/utils'
const router = useRouter() const router = useRouter()
const CreateDatasetDialogRef = ref()
const SyncWebDialogRef = ref() const SyncWebDialogRef = ref()
const loading = ref(false) const loading = ref(false)
const datasetList = ref<any[]>([]) const datasetList = ref<any[]>([])
@ -129,6 +132,10 @@ const paginationConfig = reactive({
const searchValue = ref('') const searchValue = ref('')
function openCreateDialog() {
CreateDatasetDialogRef.value.open()
}
function refresh() { function refresh() {
MsgSuccess('同步任务发送成功') MsgSuccess('同步任务发送成功')
} }

View File

@ -1,183 +0,0 @@
<template>
<el-scrollbar>
<div class="upload-document p-24">
<!-- 基本信息 -->
<BaseForm ref="BaseFormRef" v-if="isCreate" />
<el-form
v-if="isCreate"
ref="webFormRef"
:rules="rules"
:model="form"
label-position="top"
require-asterisk-position="right"
>
<el-form-item label="知识库类型" required>
<el-radio-group v-model="form.type" class="card__radio" @change="radioChange">
<el-row :gutter="20">
<el-col :span="12">
<el-card shadow="never" class="mb-16" :class="form.type === '0' ? 'active' : ''">
<el-radio value="0" size="large">
<div class="flex align-center">
<AppAvatar class="mr-8 avatar-blue" shape="square" :size="32">
<img src="@/assets/icon_document.svg" style="width: 58%" alt="" />
</AppAvatar>
<div>
<p class="mb-4">通用型</p>
<el-text type="info">可以通过上传文件或手动录入方式构建知识库</el-text>
</div>
</div>
</el-radio>
</el-card>
</el-col>
<el-col :span="12">
<el-card shadow="never" class="mb-16" :class="form.type === '1' ? 'active' : ''">
<el-radio value="1" size="large">
<div class="flex align-center">
<AppAvatar class="mr-8 avatar-purple" shape="square" :size="32">
<img src="@/assets/icon_web.svg" style="width: 58%" alt="" />
</AppAvatar>
<div>
<p class="mb-4">Web 站点</p>
<el-text type="info">通过网站链接同步方式构建知识库 </el-text>
</div>
</div>
</el-radio>
</el-card>
</el-col>
</el-row>
</el-radio-group>
</el-form-item>
<el-form-item label="Web 根地址" prop="source_url" v-if="form.type === '1'">
<el-input
v-model="form.source_url"
placeholder="请输入 Web 根地址"
@blur="form.source_url = form.source_url.trim()"
/>
</el-form-item>
<el-form-item label="选择器" v-if="form.type === '1'">
<el-input
v-model="form.selector"
placeholder="默认为 body可输入 .classname/#idname/tagname"
@blur="form.selector = form.selector.trim()"
/>
</el-form-item>
</el-form>
<!-- 上传文档 -->
<UploadComponent ref="UploadComponentRef" v-if="form.type === '0'" />
</div>
</el-scrollbar>
</template>
<script setup lang="ts">
import { ref, onMounted, reactive, watch } from 'vue'
import { useRouter, useRoute } from 'vue-router'
import BaseForm from '@/views/dataset/component/BaseForm.vue'
import UploadComponent from '@/views/dataset/component/UploadComponent.vue'
import { isAllPropertiesEmpty } from '@/utils/utils'
import datasetApi from '@/api/dataset'
import { MsgError, MsgSuccess } from '@/utils/message'
import useStore from '@/stores'
const { dataset } = useStore()
const route = useRoute()
const router = useRouter()
const {
params: { type }
} = route
const isCreate = type === 'create'
const BaseFormRef = ref()
const UploadComponentRef = ref()
const webFormRef = ref()
const loading = ref(false)
const form = ref<any>({
type: '0',
source_url: '',
selector: ''
})
const rules = reactive({
source_url: [{ required: true, message: '请输入 Web 根地址', trigger: 'blur' }]
})
watch(form.value, (value) => {
if (isAllPropertiesEmpty(value)) {
dataset.saveWebInfo(null)
} else {
dataset.saveWebInfo(value)
}
})
function radioChange() {
dataset.saveDocumentsFile([])
dataset.saveDocumentsType('')
form.value.source_url = ''
form.value.selector = ''
}
const onSubmit = async () => {
if (isCreate) {
if (form.value.type === '0') {
if ((await BaseFormRef.value?.validate()) && (await UploadComponentRef.value.validate())) {
if (UploadComponentRef.value.form.fileList.length > 50) {
MsgError('每次最多上传50个文件')
return false
} else {
/*
stores保存数据
*/
dataset.saveBaseInfo(BaseFormRef.value.form)
dataset.saveDocumentsType(UploadComponentRef.value.form.fileType)
dataset.saveDocumentsFile(UploadComponentRef.value.form.fileList)
return true
}
} else {
return false
}
} else {
if (await BaseFormRef.value?.validate()) {
await webFormRef.value.validate((valid: any) => {
if (valid) {
const obj = { ...BaseFormRef.value.form, ...form.value }
datasetApi.postWebDataset(obj, loading).then((res) => {
MsgSuccess('提交成功')
dataset.saveBaseInfo(null)
dataset.saveWebInfo(null)
router.push({ path: `/dataset/${res.data.id}/document` })
})
} else {
return false
}
})
} else {
return false
}
}
} else {
if (await UploadComponentRef.value.validate()) {
/*
stores保存数据
*/
dataset.saveDocumentsType(UploadComponentRef.value.form.fileType)
dataset.saveDocumentsFile(UploadComponentRef.value.form.fileList)
return true
} else {
return false
}
}
}
onMounted(() => {})
defineExpose({
onSubmit,
loading
})
</script>
<style scoped lang="scss">
.upload-document {
width: 70%;
margin: 0 auto;
margin-bottom: 20px;
}
</style>

View File

@ -141,4 +141,13 @@ const submit = async (formEl: FormInstance) => {
defineExpose({ open }) defineExpose({ open })
</script> </script>
<style lang="scss" scope></style> <style lang="scss" scope>
.edit-mark-dialog {
.el-dialog__header.show-close {
padding-right: 15px;
}
.el-dialog__headerbtn {
top: 13px;
}
}
</style>

View File

@ -23,7 +23,7 @@
v-if="isEdit" v-if="isEdit"
v-model="form.content" v-model="form.content"
placeholder="请输入分段内容" placeholder="请输入分段内容"
:maxLength="4096" :maxLength="100000"
:preview="false" :preview="false"
:toolbars="toolbars" :toolbars="toolbars"
style="height: 300px" style="height: 300px"
@ -31,7 +31,7 @@
:footers="footers" :footers="footers"
> >
<template #defFooters> <template #defFooters>
<span style="margin-left: -6px">/ 4096</span> <span style="margin-left: -6px">/ 100000</span>
</template> </template>
</MdEditor> </MdEditor>
<MdPreview <MdPreview

View File

@ -55,6 +55,32 @@
placeholder="请给基础模型设置一个名称" placeholder="请给基础模型设置一个名称"
/> />
</el-form-item> </el-form-item>
<el-form-item prop="model_type" :rules="base_form_data_rule.permission_type">
<template #label>
<span>权限</span>
</template>
<el-radio-group v-model="base_form_data.permission_type" class="card__radio">
<el-row :gutter="16">
<template v-for="(value, key) of PermissionType" :key="key">
<el-col :span="12">
<el-card
shadow="never"
class="mb-16"
:class="base_form_data.permission_type === key ? 'active' : ''"
>
<el-radio :value="key" size="large">
<p class="mb-4">{{ value }}</p>
<el-text type="info">
{{ PermissionDesc[key] }}
</el-text>
</el-radio>
</el-card>
</el-col>
</template>
</el-row>
</el-radio-group>
</el-form-item>
<el-form-item prop="model_type" :rules="base_form_data_rule.model_type"> <el-form-item prop="model_type" :rules="base_form_data_rule.model_type">
<template #label> <template #label>
<span>模型类型</span> <span>模型类型</span>
@ -74,6 +100,7 @@
></el-option> ></el-option>
</el-select> </el-select>
</el-form-item> </el-form-item>
<el-form-item prop="model_name" :rules="base_form_data_rule.model_name"> <el-form-item prop="model_name" :rules="base_form_data_rule.model_name">
<template #label> <template #label>
<div class="flex align-center" style="display: inline-flex"> <div class="flex align-center" style="display: inline-flex">
@ -135,6 +162,7 @@ import type { FormField } from '@/components/dynamics-form/type'
import DynamicsForm from '@/components/dynamics-form/index.vue' import DynamicsForm from '@/components/dynamics-form/index.vue'
import type { FormRules } from 'element-plus' import type { FormRules } from 'element-plus'
import { MsgSuccess } from '@/utils/message' import { MsgSuccess } from '@/utils/message'
import { PermissionType, PermissionDesc } from '@/enums/model'
const providerValue = ref<Provider>() const providerValue = ref<Provider>()
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>() const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
@ -150,17 +178,18 @@ const dialogVisible = ref<boolean>(false)
const base_form_data_rule = ref<FormRules>({ const base_form_data_rule = ref<FormRules>({
name: { required: true, trigger: 'blur', message: '模型名不能为空' }, name: { required: true, trigger: 'blur', message: '模型名不能为空' },
permission_type: { required: true, trigger: 'change', message: '权限不能为空' },
model_type: { required: true, trigger: 'change', message: '模型类型不能为空' }, model_type: { required: true, trigger: 'change', message: '模型类型不能为空' },
model_name: { required: true, trigger: 'change', message: '基础模型不能为空' } model_name: { required: true, trigger: 'change', message: '基础模型不能为空' }
}) })
const base_form_data = ref<{ const base_form_data = ref<{
name: string name: string
permission_type: string
model_type: string model_type: string
model_name: string model_name: string
}>({ name: '', model_type: '', model_name: '' }) }>({ name: '', model_type: '', model_name: '', permission_type: 'PRIVATE' })
const credential_form_data = ref<Dict<any>>({}) const credential_form_data = ref<Dict<any>>({})
@ -212,7 +241,7 @@ const list_base_model = (model_type: any) => {
} }
const close = () => { const close = () => {
base_form_data.value = { name: '', model_type: '', model_name: '' } base_form_data.value = { name: '', model_type: '', model_name: '', permission_type: 'PRIVATE' }
credential_form_data.value = {} credential_form_data.value = {}
model_form_field.value = [] model_form_field.value = []
base_model_list.value = [] base_model_list.value = []

View File

@ -48,6 +48,32 @@
placeholder="请给基础模型设置一个名称" placeholder="请给基础模型设置一个名称"
/> />
</el-form-item> </el-form-item>
<el-form-item prop="model_type" :rules="base_form_data_rule.permission_type">
<template #label>
<span>权限</span>
</template>
<el-radio-group v-model="base_form_data.permission_type" class="card__radio">
<el-row :gutter="16">
<template v-for="(value, key) of PermissionType" :key="key">
<el-col :span="12">
<el-card
shadow="never"
class="mb-16"
:class="base_form_data.permission_type === key ? 'active' : ''"
>
<el-radio :value="key" size="large">
<p class="mb-4">{{ value }}</p>
<el-text type="info">
{{ PermissionDesc[key] }}
</el-text>
</el-radio>
</el-card>
</el-col>
</template>
</el-row>
</el-radio-group>
</el-form-item>
<el-form-item prop="model_type" :rules="base_form_data_rule.model_type"> <el-form-item prop="model_type" :rules="base_form_data_rule.model_type">
<template #label> <template #label>
<span>模型类型</span> <span>模型类型</span>
@ -128,7 +154,7 @@ import type { FormField } from '@/components/dynamics-form/type'
import DynamicsForm from '@/components/dynamics-form/index.vue' import DynamicsForm from '@/components/dynamics-form/index.vue'
import type { FormRules } from 'element-plus' import type { FormRules } from 'element-plus'
import { MsgSuccess } from '@/utils/message' import { MsgSuccess } from '@/utils/message'
import AppIcon from '@/components/icons/AppIcon.vue' import { PermissionType, PermissionDesc } from '@/enums/model'
const providerValue = ref<Provider>() const providerValue = ref<Provider>()
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>() const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
@ -151,11 +177,11 @@ const base_form_data_rule = ref<FormRules>({
const base_form_data = ref<{ const base_form_data = ref<{
name: string name: string
permission_type: string
model_type: string model_type: string
model_name: string model_name: string
}>({ name: '', model_type: '', model_name: '' }) }>({ name: '', model_type: '', model_name: '', permission_type: 'PRIVATE' })
const credential_form_data = ref<Dict<any>>({}) const credential_form_data = ref<Dict<any>>({})
@ -204,6 +230,7 @@ const open = (provider: Provider, model: Model) => {
base_form_data.value = { base_form_data.value = {
name: model.name, name: model.name,
permission_type: model.permission_type,
model_type: model.model_type, model_type: model.model_type,
model_name: model.model_name model_name: model.model_name
} }
@ -214,7 +241,7 @@ const open = (provider: Provider, model: Model) => {
} }
const close = () => { const close = () => {
base_form_data.value = { name: '', model_type: '', model_name: '' } base_form_data.value = { name: '', model_type: '', model_name: '', permission_type: '' }
dynamicsFormRef.value?.ruleFormRef?.resetFields() dynamicsFormRef.value?.ruleFormRef?.resetFields()
credential_form_data.value = {} credential_form_data.value = {}
model_form_field.value = [] model_form_field.value = []

View File

@ -1,16 +1,30 @@
<template> <template>
<card-box :title="model.name" shadow="hover" class="model-card"> <card-box :title="model.name" shadow="hover" class="model-card">
<template #header> <template #header>
<div class="flex align-center"> <div class="flex">
<span style="height: 32px; width: 32px" :innerHTML="icon" class="mr-12"></span> <span style="height: 32px; width: 32px" :innerHTML="icon" class="mr-12"></span>
<div class="w-full">
<div class="flex" style="height: 22px">
<auto-tooltip :content="model.name" style="max-width: 40%"> <auto-tooltip :content="model.name" style="max-width: 40%">
{{ model.name }} {{ model.name }}
</auto-tooltip> </auto-tooltip>
<div class="flex align-center" v-if="currentModel.status === 'ERROR'"> <span v-if="currentModel.status === 'ERROR'">
<el-tag type="danger" class="ml-8">失败</el-tag>
<el-tooltip effect="dark" :content="errMessage" placement="top"> <el-tooltip effect="dark" :content="errMessage" placement="top">
<el-icon class="danger ml-4" size="20"><Warning /></el-icon> <el-icon class="danger ml-4" size="18"><Warning /></el-icon>
</el-tooltip> </el-tooltip>
</span>
<span v-if="currentModel.status === 'PAUSE_DOWNLOAD'">
<el-tooltip effect="dark" content="暂停下载" placement="top">
<el-icon class="danger ml-4" size="18"><Warning /></el-icon>
</el-tooltip>
</span>
</div>
<div class="mt-4">
<el-tag v-if="model.permission_type === 'PRIVATE'" type="danger" class="danger-tag"
>私有</el-tag
>
<el-tag v-else type="info" class="info-tag">公有</el-tag>
</div>
</div> </div>
</div> </div>
</template> </template>
@ -29,18 +43,14 @@
</div> </div>
<!-- progress --> <!-- progress -->
<div class="progress-mask" v-if="currentModel.status === 'DOWNLOAD'"> <div class="progress-mask" v-if="currentModel.status === 'DOWNLOAD'">
<el-progress <DownloadLoading class="percentage" />
type="circle"
:width="56" <div class="percentage-label flex-center">
color="#3370FF" 正在下载中 <span class="dotting"></span>
:percentage="progress" <el-button link type="primary" class="ml-16" @click.stop="cancelDownload"
class="percentage" >取消下载</el-button
> >
<template #default="{ percentage }"> </div>
<span class="percentage-value">{{ percentage }}%</span>
</template>
</el-progress>
<span class="percentage-label">正在下载 <span class="dotting"></span></span>
</div> </div>
<template #mouseEnter> <template #mouseEnter>
@ -48,7 +58,13 @@
<el-tooltip effect="dark" content="修改" placement="top"> <el-tooltip effect="dark" content="修改" placement="top">
<el-button text @click.stop="openEditModel"> <el-button text @click.stop="openEditModel">
<el-icon> <el-icon>
<component :is="currentModel.status === 'ERROR' ? 'RefreshRight' : 'EditPen'" /> <component
:is="
currentModel.status === 'ERROR' || currentModel.status === 'PAUSE_DOWNLOAD'
? 'RefreshRight'
: 'EditPen'
"
/>
</el-icon> </el-icon>
</el-button> </el-button>
</el-tooltip> </el-tooltip>
@ -68,6 +84,7 @@ import type { Provider, Model } from '@/api/type/model'
import ModelApi from '@/api/model' import ModelApi from '@/api/model'
import { computed, ref, onMounted, onBeforeUnmount } from 'vue' import { computed, ref, onMounted, onBeforeUnmount } from 'vue'
import EditModel from '@/views/template/component/EditModel.vue' import EditModel from '@/views/template/component/EditModel.vue'
import DownloadLoading from '@/components/loading/DownloadLoading.vue'
import { MsgConfirm } from '@/utils/message' import { MsgConfirm } from '@/utils/message'
const props = defineProps<{ const props = defineProps<{
@ -94,27 +111,6 @@ const errMessage = computed(() => {
} }
return '' return ''
}) })
const progress = computed(() => {
if (currentModel.value) {
const down_model_chunk = currentModel.value.meta['down_model_chunk']
if (down_model_chunk) {
const maxObj = down_model_chunk
.filter((chunk: any) => chunk.index > 1)
.reduce(
(prev: any, current: any) => {
return (prev.index || 0) > (current.index || 0) ? prev : current
},
{ progress: 0 }
)
if (maxObj) {
return parseFloat(maxObj.progress?.toFixed(1))
}
return 0
}
return 0
}
return 0
})
const emit = defineEmits(['change', 'update:model']) const emit = defineEmits(['change', 'update:model'])
const eidtModelRef = ref<InstanceType<typeof EditModel>>() const eidtModelRef = ref<InstanceType<typeof EditModel>>()
let interval: any let interval: any
@ -130,6 +126,13 @@ const deleteModel = () => {
}) })
.catch(() => {}) .catch(() => {})
} }
const cancelDownload = () => {
ModelApi.pauseDownload(props.model.id).then(() => {
downModel.value = undefined
emit('change')
})
}
const openEditModel = () => { const openEditModel = () => {
const provider = props.provider_list.find((p) => p.provider === props.model.provider) const provider = props.provider_list.find((p) => p.provider === props.model.provider)
if (provider) { if (provider) {
@ -197,21 +200,21 @@ onBeforeUnmount(() => {
z-index: 99; z-index: 99;
text-align: center; text-align: center;
.percentage { .percentage {
top: 50%; margin-top: 55px;
transform: translateY(-65%); margin-bottom: 16px;
} }
.percentage-value { // .percentage-value {
display: block; // display: flex;
font-size: 12px; // font-size: 13px;
color: var(--el-color-primary); // align-items: center;
} // color: var(--app-text-color-secondary);
// }
.percentage-label { .percentage-label {
display: block; margin-top: 50px;
margin-top: 45px;
margin-left: 10px; margin-left: 10px;
font-size: 12px; font-size: 13px;
color: var(--el-color-primary); color: var(--app-text-color-secondary);
} }
} }
} }

View File

@ -145,10 +145,7 @@ function createUser() {
title.value = '创建用户' title.value = '创建用户'
UserDialogRef.value.open() UserDialogRef.value.open()
} else { } else {
MsgAlert( MsgAlert('提示', '社区版最多支持 2 个用户,如需拥有更多用户,请升级为专业版。')
'提示',
'社区版最多支持 2 个用户如需拥有更多用户请联系我们https://fit2cloud.com/)。'
)
} }
}) })
} }