fix: The knowledge base could still be vectorized with a vector model that its owner is not authorized to use (#2080)

This commit is contained in:
shaohuzhang1 2025-01-23 10:53:12 +08:00 committed by GitHub
parent 40cfa33556
commit 5ebcad7cde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 4 deletions

View File

@ -47,7 +47,7 @@ from dataset.serializers.document_serializers import DocumentSerializers, Docume
from dataset.task import sync_web_dataset, sync_replace_web_dataset from dataset.task import sync_web_dataset, sync_replace_web_dataset
from embedding.models import SearchMode from embedding.models import SearchMode
from embedding.task import embedding_by_dataset, delete_embedding_by_dataset from embedding.task import embedding_by_dataset, delete_embedding_by_dataset
from setting.models import AuthOperate from setting.models import AuthOperate, Model
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -792,6 +792,15 @@ class DataSetSerializers(serializers.ModelSerializer):
def re_embedding(self, with_valid=True): def re_embedding(self, with_valid=True):
if with_valid: if with_valid:
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
dataset_id = self.data.get('id')
dataset = QuerySet(DataSet).filter(id=dataset_id).first()
embedding_model_id = dataset.embedding_mode_id
dataset_user_id = dataset.user_id
embedding_model = QuerySet(Model).filter(id=embedding_model_id).first()
if embedding_model is None:
raise AppApiException(500, _('Model does not exist'))
if embedding_model.permission_type == 'PRIVATE' and dataset_user_id != embedding_model.user_id:
raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}")
ListenerManagement.update_status(QuerySet(Document).filter(dataset_id=self.data.get('id')), ListenerManagement.update_status(QuerySet(Document).filter(dataset_id=self.data.get('id')),
TaskType.EMBEDDING, TaskType.EMBEDDING,
State.PENDING) State.PENDING)
@ -801,7 +810,7 @@ class DataSetSerializers(serializers.ModelSerializer):
ListenerManagement.get_aggregation_document_status_by_dataset_id(self.data.get('id'))() ListenerManagement.get_aggregation_document_status_by_dataset_id(self.data.get('id'))()
embedding_model_id = get_embedding_model_id_by_dataset_id(self.data.get('id')) embedding_model_id = get_embedding_model_id_by_dataset_id(self.data.get('id'))
try: try:
embedding_by_dataset.delay(self.data.get('id'), embedding_model_id) embedding_by_dataset.delay(dataset_id, embedding_model_id)
except AlreadyQueued as e: except AlreadyQueued as e:
raise AppApiException(500, _('Failed to send the vectorization task, please try again later!')) raise AppApiException(500, _('Failed to send the vectorization task, please try again later!'))

View File

@ -23,6 +23,7 @@ from django.db import transaction
from django.db.models import QuerySet, Count from django.db.models import QuerySet, Count
from django.db.models.functions import Substr, Reverse from django.db.models.functions import Substr, Reverse
from django.http import HttpResponse from django.http import HttpResponse
from django.utils.translation import gettext_lazy as _, gettext
from drf_yasg import openapi from drf_yasg import openapi
from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE
from rest_framework import serializers from rest_framework import serializers
@ -62,8 +63,8 @@ from dataset.task import sync_web_document, generate_related_by_document_id
from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \ from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
delete_embedding_by_document, update_embedding_dataset_id, delete_embedding_by_paragraph_ids, \ delete_embedding_by_document, update_embedding_dataset_id, delete_embedding_by_paragraph_ids, \
embedding_by_document_list embedding_by_document_list
from setting.models import Model
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _, gettext
parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()] parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()]
parse_table_handle_list = [CsvSplitTableHandle(), XlsSplitTableHandle(), XlsxSplitTableHandle()] parse_table_handle_list = [CsvSplitTableHandle(), XlsSplitTableHandle(), XlsxSplitTableHandle()]
@ -716,6 +717,14 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
State.REVOKED.value, State.IGNORED.value] State.REVOKED.value, State.IGNORED.value]
if with_valid: if with_valid:
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
dataset = QuerySet(DataSet).filter(id=self.data.get('dataset_id')).first()
embedding_model_id = dataset.embedding_mode_id
dataset_user_id = dataset.user_id
embedding_model = QuerySet(Model).filter(id=embedding_model_id).first()
if embedding_model is None:
raise AppApiException(500, _('Model does not exist'))
if embedding_model.permission_type == 'PRIVATE' and dataset_user_id != embedding_model.user_id:
raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}")
document_id = self.data.get("document_id") document_id = self.data.get("document_id")
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
State.PENDING) State.PENDING)
@ -728,7 +737,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
TaskType.EMBEDDING, TaskType.EMBEDDING,
State.PENDING) State.PENDING)
ListenerManagement.get_aggregation_document_status(document_id)() ListenerManagement.get_aggregation_document_status(document_id)()
embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id'))
try: try:
embedding_by_document.delay(document_id, embedding_model_id, state_list) embedding_by_document.delay(document_id, embedding_model_id, state_list)
except AlreadyQueued as e: except AlreadyQueued as e: