feat: enhance Document API with workspace ID support for get, put, and delete operations

This commit is contained in:
CaptainB 2025-05-06 14:33:59 +08:00
parent 3e9069aac1
commit e702af8c2b
7 changed files with 263 additions and 16 deletions

View File

@ -8,7 +8,7 @@ from common.handle.base_parse_table_handle import BaseParseTableHandle
max_kb = logging.getLogger("max_kb") max_kb = logging.getLogger("max_kb")
class CsvSplitHandle(BaseParseTableHandle): class CsvParseTableHandle(BaseParseTableHandle):
def support(self, file, get_buffer): def support(self, file, get_buffer):
file_name: str = file.name.lower() file_name: str = file.name.lower()
if file_name.endswith(".csv"): if file_name.endswith(".csv"):

View File

@ -8,7 +8,7 @@ from common.handle.base_parse_table_handle import BaseParseTableHandle
max_kb = logging.getLogger("max_kb") max_kb = logging.getLogger("max_kb")
class XlsSplitHandle(BaseParseTableHandle): class XlsParseTableHandle(BaseParseTableHandle):
def support(self, file, get_buffer): def support(self, file, get_buffer):
file_name: str = file.name.lower() file_name: str = file.name.lower()
buffer = get_buffer(file) buffer = get_buffer(file)

View File

@ -10,7 +10,7 @@ from common.handle.impl.common_handle import xlsx_embed_cells_images
max_kb = logging.getLogger("max_kb") max_kb = logging.getLogger("max_kb")
class XlsxSplitHandle(BaseParseTableHandle): class XlsxParseTableHandle(BaseParseTableHandle):
def support(self, file, get_buffer): def support(self, file, get_buffer):
file_name: str = file.name.lower() file_name: str = file.name.lower()
if file_name.endswith('.xlsx'): if file_name.endswith('.xlsx'):

View File

@ -4,7 +4,7 @@ from drf_spectacular.utils import OpenApiParameter
from common.mixins.api_mixin import APIMixin from common.mixins.api_mixin import APIMixin
from common.result import DefaultResultSerializer from common.result import DefaultResultSerializer
from knowledge.serializers.common import BatchSerializer from knowledge.serializers.common import BatchSerializer
from knowledge.serializers.document import DocumentInstanceSerializer from knowledge.serializers.document import DocumentInstanceSerializer, DocumentWebInstanceSerializer
class DocumentSplitAPI(APIMixin): class DocumentSplitAPI(APIMixin):
@ -176,3 +176,45 @@ class DocumentEditAPI(DocumentReadAPI):
class DocumentDeleteAPI(DocumentReadAPI): class DocumentDeleteAPI(DocumentReadAPI):
pass pass
class TableDocumentCreateAPI(APIMixin):
    """OpenAPI metadata for the endpoint that creates documents from table files."""

    @staticmethod
    def get_parameters():
        """Return the documented parameters.

        Two required path identifiers (``workspace_id``, ``knowledge_id``)
        followed by an optional binary ``file`` parameter.
        """

        def path_id(name, description):
            # Both identifiers share everything except name and description.
            return OpenApiParameter(
                name=name,
                description=description,
                type=OpenApiTypes.STR,
                location='path',
                required=True,
            )

        upload = OpenApiParameter(
            name="file",
            description="文件",
            type=OpenApiTypes.BINARY,
            location='query',
            required=False,
        )
        return [
            path_id("workspace_id", "工作空间id"),
            path_id("knowledge_id", "知识库id"),
            upload,
        ]

    @staticmethod
    def get_response():
        """Return the serializer class used to document the response body."""
        return DefaultResultSerializer
class QaDocumentCreateAPI(TableDocumentCreateAPI):
    """OpenAPI metadata for the QA import endpoint.

    Adds nothing of its own: parameters and response are inherited
    unchanged from ``TableDocumentCreateAPI``.
    """
class WebDocumentCreateAPI(APIMixin):
    """OpenAPI metadata for the endpoint that creates documents from web pages."""

    @staticmethod
    def get_request():
        """Return the serializer class that documents the request body."""
        request_serializer = DocumentWebInstanceSerializer
        return request_serializer

View File

@ -1,11 +1,13 @@
import logging import logging
import os import os
import re
import traceback import traceback
from functools import reduce from functools import reduce
from typing import Dict, List from typing import Dict, List
import uuid_utils.compat as uuid import uuid_utils.compat as uuid
from celery_once import AlreadyQueued from celery_once import AlreadyQueued
from django.core import validators
from django.db import transaction, models from django.db import transaction, models
from django.db.models import QuerySet, Model from django.db.models import QuerySet, Model
from django.db.models.functions import Substr, Reverse from django.db.models.functions import Substr, Reverse
@ -16,6 +18,13 @@ from common.db.search import native_search, get_dynamics_model, native_page_sear
from common.event import ListenerManagement from common.event import ListenerManagement
from common.event.common import work_thread_pool from common.event.common import work_thread_pool
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
from common.handle.impl.qa.zip_parse_qa_handle import ZipParseQAHandle
from common.handle.impl.table.csv_parse_table_handle import CsvParseTableHandle
from common.handle.impl.table.xls_parse_table_handle import XlsParseTableHandle
from common.handle.impl.table.xlsx_parse_table_handle import XlsxParseTableHandle
from common.handle.impl.text.csv_split_handle import CsvSplitHandle from common.handle.impl.text.csv_split_handle import CsvSplitHandle
from common.handle.impl.text.doc_split_handle import DocSplitHandle from common.handle.impl.text.doc_split_handle import DocSplitHandle
from common.handle.impl.text.html_split_handle import HTMLSplitHandle from common.handle.impl.text.html_split_handle import HTMLSplitHandle
@ -26,14 +35,16 @@ from common.handle.impl.text.xlsx_split_handle import XlsxSplitHandle
from common.handle.impl.text.zip_split_handle import ZipSplitHandle from common.handle.impl.text.zip_split_handle import ZipSplitHandle
from common.utils.common import post, get_file_content, bulk_create_in_batches from common.utils.common import post, get_file_content, bulk_create_in_batches
from common.utils.fork import Fork from common.utils.fork import Fork
from common.utils.split_model import get_split_model from common.utils.split_model import get_split_model, flat_map
from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \ from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \
TaskType, File TaskType, File
from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer, get_embedding_model_id_by_knowledge_id from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer, \
get_embedding_model_id_by_knowledge_id, MetaSerializer
from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer, \ from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer, \
delete_problems_and_mappings delete_problems_and_mappings
from knowledge.task.embedding import embedding_by_document, delete_embedding_by_document_list, \ from knowledge.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
delete_embedding_by_document delete_embedding_by_document
from knowledge.task.sync import sync_web_document
from maxkb.const import PROJECT_DIR from maxkb.const import PROJECT_DIR
default_split_handle = TextSplitHandle() default_split_handle = TextSplitHandle()
@ -48,6 +59,9 @@ split_handles = [
default_split_handle default_split_handle
] ]
# Parsers for QA imports. They are probed in list order and the first one
# whose ``support`` accepts the file is used.
parse_qa_handle_list = [
    XlsParseQAHandle(),
    CsvParseQAHandle(),
    XlsxParseQAHandle(),
    ZipParseQAHandle(),
]
# Parsers for table imports, probed in list order in the same way.
parse_table_handle_list = [
    CsvParseTableHandle(),
    XlsParseTableHandle(),
    XlsxParseTableHandle(),
]
class BatchCancelInstanceSerializer(serializers.Serializer): class BatchCancelInstanceSerializer(serializers.Serializer):
id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=('id list')) id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=('id list'))
@ -67,6 +81,36 @@ class DocumentInstanceSerializer(serializers.Serializer):
paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True) paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
class DocumentEditInstanceSerializer(serializers.Serializer):
    # Validates the payload of a document edit request. Every field is
    # optional; ``meta`` is additionally validated against a serializer
    # chosen by the type of the document being edited (see ``is_valid``).
    meta = serializers.DictField(required=False)
    name = serializers.CharField(required=False, max_length=128, min_length=1, label=_('document name'))
    hit_handling_method = serializers.CharField(required=False, validators=[
        # The alternation is grouped so that both anchors apply to both
        # options. Ungrouped, "^optimization|directly_return$" accepts any
        # value that merely starts with "optimization" or ends with
        # "directly_return".
        validators.RegexValidator(regex=re.compile("^(optimization|directly_return)$"),
                                  message=_('The type only supports optimization|directly_return'),
                                  code=500)
    ], label=_('hit handling method'))
    directly_return_similarity = serializers.FloatField(required=False, max_value=2, min_value=0,
                                                        label=_('directly return similarity'))
    is_active = serializers.BooleanField(required=False, label=_('document is active'))

    @staticmethod
    def get_meta_valid_map():
        """Return the mapping from document type to the serializer that validates ``meta``."""
        return {
            KnowledgeType.BASE: MetaSerializer.BaseMeta,
            KnowledgeType.WEB: MetaSerializer.WebMeta
        }

    def is_valid(self, *, document: Document = None):
        """Validate the edit payload, including ``meta`` when it is present.

        :param document: the document being edited; its ``type`` selects the
            serializer used to validate ``meta``.
        :return: ``True`` when validation passes; invalid data raises because
            every check runs with ``raise_exception=True``.
        """
        super().is_valid(raise_exception=True)
        meta = self.data.get('meta')
        if meta is not None:
            # NOTE(review): a document type missing from the map yields None
            # here and the call below fails; confirm whether other types can
            # reach this path.
            valid_class = self.get_meta_valid_map().get(document.type)
            valid_class(data=meta).is_valid(raise_exception=True)
        return True
class DocumentSplitRequest(serializers.Serializer): class DocumentSplitRequest(serializers.Serializer):
file = serializers.ListField(required=True, label=_('file list')) file = serializers.ListField(required=True, label=_('file list'))
limit = serializers.IntegerField(required=False, label=_('limit')) limit = serializers.IntegerField(required=False, label=_('limit'))
@ -78,6 +122,22 @@ class DocumentSplitRequest(serializers.Serializer):
with_filter = serializers.BooleanField(required=False, label=_('Auto Clean')) with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))
class DocumentWebInstanceSerializer(serializers.Serializer):
    # Request body for creating documents from web pages: a required list of
    # URLs plus an optional selector (may be null or blank).
    source_url_list = serializers.ListField(
        child=serializers.CharField(required=True, label=_('document url list')),
        required=True,
        label=_('document url list'),
    )
    selector = serializers.CharField(
        required=False,
        allow_null=True,
        allow_blank=True,
        label=_('selector'),
    )
class DocumentInstanceQASerializer(serializers.Serializer):
    # Payload for a QA import: a required list of uploaded files.
    file_list = serializers.ListSerializer(
        child=serializers.FileField(required=True, label=_('file')),
        required=True,
        label=_('file list'),
    )
class DocumentInstanceTableSerializer(serializers.Serializer):
    # Payload for a table import: a required list of uploaded files.
    file_list = serializers.ListSerializer(
        child=serializers.FileField(required=True, label=_('file')),
        required=True,
        label=_('file list'),
    )
class DocumentSerializers(serializers.Serializer): class DocumentSerializers(serializers.Serializer):
class Query(serializers.Serializer): class Query(serializers.Serializer):
# 知识库id # 知识库id
@ -226,6 +286,7 @@ class DocumentSerializers(serializers.Serializer):
return True return True
class Operate(serializers.Serializer): class Operate(serializers.Serializer):
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
document_id = serializers.UUIDField(required=True, label=_('document id')) document_id = serializers.UUIDField(required=True, label=_('document id'))
knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id')) knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
@ -246,6 +307,31 @@ class DocumentSerializers(serializers.Serializer):
}, select_string=get_file_content( }, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')), with_search_one=True) os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')), with_search_one=True)
def edit(self, instance: Dict, with_valid=False):
    """Apply the editable fields found in ``instance`` to the document.

    :param instance: mapping of new values; keys that are absent or whose
        value is ``None`` leave the stored value untouched.
    :param with_valid: when true, validate both this serializer and the
        edit payload before changing anything.
    :return: the refreshed document as produced by ``self.one()``.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    document = QuerySet(Document).get(id=self.data.get("document_id"))
    if with_valid:
        # Payload validation needs the loaded document (its type selects the
        # ``meta`` validator), hence the second ``with_valid`` check.
        DocumentEditInstanceSerializer(data=instance).is_valid(document=document)
    editable_fields = ('name', 'is_active', 'hit_handling_method', 'directly_return_similarity', 'meta')
    for field_name in editable_fields:
        new_value = instance.get(field_name)
        if new_value is None:
            continue
        setattr(document, field_name, new_value)
    document.save()
    return self.one()
@transaction.atomic
def delete(self):
    """Delete the document and its dependent data inside one transaction.

    Runs, in order: removal of the document row, its paragraphs, its
    problems and problem mappings, and finally its embeddings.

    :return: ``True`` once every step has run.
    """
    target_id = self.data.get("document_id")
    QuerySet(model=Document).filter(id=target_id).delete()
    # Delete the paragraphs
    QuerySet(model=Paragraph).filter(document_id=target_id).delete()
    # Delete the problems
    delete_problems_and_mappings([target_id])
    # Delete from the vector store
    delete_embedding_by_document(target_id)
    return True
def refresh(self, state_list=None, with_valid=True): def refresh(self, state_list=None, with_valid=True):
if state_list is None: if state_list is None:
state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value, state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
@ -369,6 +455,58 @@ class DocumentSerializers(serializers.Serializer):
instance.get('paragraphs') if 'paragraphs' in instance else [] instance.get('paragraphs') if 'paragraphs' in instance else []
) )
def save_web(self, instance: Dict, with_valid=True):
    """Hand the given URLs to the ``sync_web_document`` task for this knowledge base.

    :param instance: mapping with ``source_url_list`` and, optionally,
        ``selector``.
    :param with_valid: when true, validate ``instance`` and this serializer
        first; invalid data raises.
    """
    if with_valid:
        DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True)
        self.is_valid(raise_exception=True)
    # The view builds this serializer with ``knowledge_id`` and
    # ``workspace_id``. The old ``dataset_id`` key is never supplied, so
    # reading it passed ``None`` to the task.
    knowledge_id = self.data.get('knowledge_id')
    source_url_list = instance.get('source_url_list')
    selector = instance.get('selector')
    sync_web_document.delay(knowledge_id, source_url_list, selector)
def save_qa(self, instance: Dict, with_valid=True):
    """Parse the uploaded QA files and save the resulting documents in one batch.

    :param instance: mapping whose ``file_list`` holds the uploaded files.
    :param with_valid: when true, validate ``instance`` and this serializer
        first; invalid data raises.
    :return: whatever ``DocumentSerializers.Batch.batch_save`` returns.
    """
    if with_valid:
        DocumentInstanceQASerializer(data=instance).is_valid(raise_exception=True)
        self.is_valid(raise_exception=True)
    file_list = instance.get('file_list')
    document_list = flat_map([self.parse_qa_file(file) for file in file_list])
    # ``Batch`` is built elsewhere in this module's API with ``workspace_id``
    # and ``knowledge_id``. The old ``dataset_id`` key is never present in
    # ``self.data``, so it always passed ``None``.
    return DocumentSerializers.Batch(data={
        'knowledge_id': self.data.get('knowledge_id'),
        'workspace_id': self.data.get('workspace_id'),
    }).batch_save(document_list)
def save_table(self, instance: Dict, with_valid=True):
    """Parse the uploaded table files and save the resulting documents in one batch.

    :param instance: mapping whose ``file_list`` holds the uploaded files.
    :param with_valid: when true, validate ``instance`` and this serializer
        first; invalid data raises.
    :return: whatever ``DocumentSerializers.Batch.batch_save`` returns.
    """
    if with_valid:
        DocumentInstanceTableSerializer(data=instance).is_valid(raise_exception=True)
        self.is_valid(raise_exception=True)
    file_list = instance.get('file_list')
    document_list = flat_map([self.parse_table_file(file) for file in file_list])
    # ``Batch`` is built elsewhere in this module's API with ``workspace_id``
    # and ``knowledge_id``. The old ``dataset_id`` key is never present in
    # ``self.data``, so it always passed ``None``.
    return DocumentSerializers.Batch(data={
        'knowledge_id': self.data.get('knowledge_id'),
        'workspace_id': self.data.get('workspace_id'),
    }).batch_save(document_list)
def parse_qa_file(self, file):
    """Parse one QA file with the first handler in ``parse_qa_handle_list`` that supports it.

    :param file: the uploaded file to parse.
    :return: the result of the matching handler's ``handle`` call.
    :raises AppApiException: with code 500 when no handler supports the file.
    """
    read_buffer = FileBufferHandle().get_buffer
    for handler in parse_qa_handle_list:
        if not handler.support(file, read_buffer):
            continue
        return handler.handle(file, read_buffer, self.save_image)
    raise AppApiException(500, _('Unsupported file format'))
def parse_table_file(self, file):
get_buffer = FileBufferHandle().get_buffer
for parse_table_handle in parse_table_handle_list:
if parse_table_handle.support(file, get_buffer):
return parse_table_handle.handle(file, get_buffer, self.save_image)
raise AppApiException(500, _('Unsupported file format'))
def save_image(self, image_list):
    """Persist the images in ``image_list`` that are not stored yet.

    Images whose id already exists as a ``File`` row are skipped. When
    several new images share an id, only the last one is saved. Each saved
    image is tagged with this serializer's ``workspace_id`` and
    ``knowledge_id``.

    :param image_list: image objects exposing ``id`` and ``meta``; ``None``
        or an empty list is a no-op.
    """
    if not image_list:
        return
    # A set gives constant-time membership tests; the previous list was
    # scanned once per image through a direct ``__contains__`` call.
    existing_ids = {
        str(row.get('id'))
        for row in QuerySet(File).filter(id__in=[image.id for image in image_list]).values('id')
    }
    # Keying by id de-duplicates the new images; a later duplicate replaces
    # an earlier one while keeping the first occurrence's position.
    pending = {
        image.id: image
        for image in image_list
        if str(image.id) not in existing_ids
    }
    for file in pending.values():
        # The raw bytes travel in ``meta['content']``; they are removed from
        # ``meta`` and passed to ``save`` separately.
        file_bytes = file.meta.pop('content')
        file.workspace_id = self.data.get('workspace_id')
        file.meta['knowledge_id'] = self.data.get('knowledge_id')
        file.save(file_bytes)
class Split(serializers.Serializer): class Split(serializers.Serializer):
workspace_id = serializers.CharField(required=True, label=_('workspace id')) workspace_id = serializers.CharField(required=True, label=_('workspace id'))
knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id')) knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))

View File

@ -11,6 +11,9 @@ urlpatterns = [
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document', views.DocumentView.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document', views.DocumentView.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch', views.DocumentView.Batch.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch', views.DocumentView.Batch.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/web', views.WebDocumentView.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/qa', views.QaDocumentView.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/table', views.TableDocumentView.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>', views.DocumentView.Operate.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>', views.DocumentView.Operate.as_view()),
path('workspace/<str:workspace_id>/knowledge/<int:current_page>/<int:page_size>', views.KnowledgeView.Page.as_view()), path('workspace/<str:workspace_id>/knowledge/<int:current_page>/<int:page_size>', views.KnowledgeView.Page.as_view()),
] ]

View File

@ -9,7 +9,8 @@ from common.auth.authentication import has_permissions
from common.constants.permission_constants import PermissionConstants from common.constants.permission_constants import PermissionConstants
from common.result import result from common.result import result
from knowledge.api.document import DocumentSplitAPI, DocumentBatchAPI, DocumentBatchCreateAPI, DocumentCreateAPI, \ from knowledge.api.document import DocumentSplitAPI, DocumentBatchAPI, DocumentBatchCreateAPI, DocumentCreateAPI, \
DocumentReadAPI, DocumentEditAPI, DocumentDeleteAPI DocumentReadAPI, DocumentEditAPI, DocumentDeleteAPI, TableDocumentCreateAPI, QaDocumentCreateAPI, \
WebDocumentCreateAPI
from knowledge.api.knowledge import KnowledgeTreeReadAPI from knowledge.api.knowledge import KnowledgeTreeReadAPI
from knowledge.serializers.document import DocumentSerializers from knowledge.serializers.document import DocumentSerializers
@ -68,8 +69,10 @@ class DocumentView(APIView):
tags=[_('Knowledge Base/Documentation')] tags=[_('Knowledge Base/Documentation')]
) )
@has_permissions(PermissionConstants.DOCUMENT_READ.get_workspace_permission()) @has_permissions(PermissionConstants.DOCUMENT_READ.get_workspace_permission())
def get(self, request: Request, knowledge_id: str, document_id: str): def get(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
operate = DocumentSerializers.Operate(data={'document_id': document_id, 'knowledge_id': knowledge_id}) operate = DocumentSerializers.Operate(data={
'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id
})
operate.is_valid(raise_exception=True) operate.is_valid(raise_exception=True)
return result.success(operate.one()) return result.success(operate.one())
@ -83,11 +86,10 @@ class DocumentView(APIView):
tags=[_('Knowledge Base/Documentation')] tags=[_('Knowledge Base/Documentation')]
) )
@has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission()) @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
def put(self, request: Request, knowledge_id: str, document_id: str): def put(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
return result.success( return result.success(DocumentSerializers.Operate(data={
DocumentSerializers.Operate(data={'document_id': document_id, 'knowledge_id': knowledge_id}).edit( 'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id
request.data, }).edit(request.data, with_valid=True))
with_valid=True))
@extend_schema( @extend_schema(
description=_('Delete document'), description=_('Delete document'),
@ -98,8 +100,10 @@ class DocumentView(APIView):
tags=[_('Knowledge Base/Documentation')] tags=[_('Knowledge Base/Documentation')]
) )
@has_permissions(PermissionConstants.DOCUMENT_DELETE.get_workspace_permission()) @has_permissions(PermissionConstants.DOCUMENT_DELETE.get_workspace_permission())
def delete(self, request: Request, knowledge_id: str, document_id: str): def delete(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
operate = DocumentSerializers.Operate(data={'document_id': document_id, 'knowledge_id': knowledge_id}) operate = DocumentSerializers.Operate(data={
'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id
})
operate.is_valid(raise_exception=True) operate.is_valid(raise_exception=True)
return result.success(operate.delete()) return result.success(operate.delete())
@ -195,3 +199,63 @@ class DocumentView(APIView):
return result.success(DocumentSerializers.Batch( return result.success(DocumentSerializers.Batch(
data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id} data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id}
).batch_delete(request.data)) ).batch_delete(request.data))
class WebDocumentView(APIView):
    # POST endpoint that creates documents from web pages.
    authentication_classes = [TokenAuth]

    @extend_schema(
        methods=['POST'],
        description=_('Create Web site documents'),
        summary=_('Create Web site documents'),
        operation_id=_('Create Web site documents'),
        request=WebDocumentCreateAPI.get_request(),
        parameters=WebDocumentCreateAPI.get_parameters(),
        responses=WebDocumentCreateAPI.get_response(),
        tags=[_('Knowledge Base/Documentation')]
    )
    @has_permissions(PermissionConstants.DOCUMENT_CREATE.get_workspace_permission())
    def post(self, request: Request, workspace_id: str, knowledge_id: str):
        # Scope the serializer to the ids taken from the URL, then let it
        # validate and process the request body.
        creator = DocumentSerializers.Create(data={
            'knowledge_id': knowledge_id,
            'workspace_id': workspace_id,
        })
        outcome = creator.save_web(request.data, with_valid=True)
        return result.success(outcome)
class QaDocumentView(APIView):
    # POST endpoint that imports QA files (multipart upload) as documents.
    authentication_classes = [TokenAuth]
    parser_classes = [MultiPartParser]

    @extend_schema(
        summary=_('Import QA and create documentation'),
        description=_('Import QA and create documentation'),
        operation_id=_('Import QA and create documentation'),
        request=QaDocumentCreateAPI.get_request(),
        parameters=QaDocumentCreateAPI.get_parameters(),
        responses=QaDocumentCreateAPI.get_response(),
        tags=[_('Knowledge Base/Documentation')]
    )
    @has_permissions(PermissionConstants.DOCUMENT_CREATE.get_workspace_permission())
    def post(self, request: Request, workspace_id: str, knowledge_id: str):
        creator = DocumentSerializers.Create(data={
            'knowledge_id': knowledge_id,
            'workspace_id': workspace_id,
        })
        # All files are read from the multipart field named 'file'.
        payload = {'file_list': request.FILES.getlist('file')}
        return result.success(creator.save_qa(payload, with_valid=True))
class TableDocumentView(APIView):
    # POST endpoint that imports table files (multipart upload) as documents.
    authentication_classes = [TokenAuth]
    parser_classes = [MultiPartParser]

    @extend_schema(
        summary=_('Import tables and create documents'),
        description=_('Import tables and create documents'),
        operation_id=_('Import tables and create documents'),
        request=TableDocumentCreateAPI.get_request(),
        parameters=TableDocumentCreateAPI.get_parameters(),
        responses=TableDocumentCreateAPI.get_response(),
        tags=[_('Knowledge Base/Documentation')]
    )
    @has_permissions(PermissionConstants.DOCUMENT_CREATE.get_workspace_permission())
    def post(self, request: Request, workspace_id: str, knowledge_id: str):
        creator = DocumentSerializers.Create(data={
            'knowledge_id': knowledge_id,
            'workspace_id': workspace_id,
        })
        # All files are read from the multipart field named 'file'.
        payload = {'file_list': request.FILES.getlist('file')}
        return result.success(creator.save_table(payload, with_valid=True))