feat: implement batch processing for document creation, synchronization, and deletion

CaptainB 2025-04-30 17:48:31 +08:00
parent 43bef216d5
commit 0d3eb431f6
5 changed files with 272 additions and 14 deletions

View File

@@ -263,3 +263,10 @@ def parse_md_image(content: str):
    image_list = [match.group() for match in matches]
    return image_list


def bulk_create_in_batches(model, data, batch_size=1000):
    if len(data) == 0:
        return
    for i in range(0, len(data), batch_size):
        batch = data[i:i + batch_size]
        model.objects.bulk_create(batch)
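For orientation, a minimal sketch of how the new helper behaves; the Example model and row values below are hypothetical and only the chunking logic comes from the diff:

# Hypothetical Django model used purely for illustration.
# 2,500 unsaved instances are persisted in three bulk_create calls
# (1000 + 1000 + 500 rows) instead of one oversized INSERT.
rows = [Example(name=f"row-{i}") for i in range(2500)]
bulk_create_in_batches(Example, rows, batch_size=1000)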

View File

@@ -2,16 +2,11 @@ from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter
from common.mixins.api_mixin import APIMixin
-from common.result import DefaultResultSerializer, ResultSerializer
+from common.result import DefaultResultSerializer
+from knowledge.serializers.common import BatchSerializer
from knowledge.serializers.document import DocumentCreateRequest


-class DocumentCreateResponse(ResultSerializer):
-    @staticmethod
-    def get_data():
-        return DefaultResultSerializer()


class DocumentCreateAPI(APIMixin):
    @staticmethod
    def get_parameters():
@@ -31,7 +26,7 @@ class DocumentCreateAPI(APIMixin):
    @staticmethod
    def get_response():
-        return DocumentCreateResponse
+        return DefaultResultSerializer


class DocumentSplitAPI(APIMixin):
@@ -75,3 +70,31 @@ class DocumentSplitAPI(APIMixin):
            ),
        ]

class DocumentBatchAPI(APIMixin):
    @staticmethod
    def get_parameters():
        return [
            OpenApiParameter(
                name="workspace_id",
                description="Workspace ID",
                type=OpenApiTypes.STR,
                location='path',
                required=True,
            ),
            OpenApiParameter(
                name="knowledge_id",
                description="Knowledge base ID",
                type=OpenApiTypes.STR,
                location='path',
                required=True,
            ),
        ]

    @staticmethod
    def get_request():
        return BatchSerializer

    @staticmethod
    def get_response():
        return DefaultResultSerializer

View File

@@ -12,6 +12,7 @@ from rest_framework import serializers
from common.db.search import native_search
from common.event import ListenerManagement
+from common.event.common import work_thread_pool
from common.exception.app_exception import AppApiException
from common.handle.impl.text.csv_split_handle import CsvSplitHandle
from common.handle.impl.text.doc_split_handle import DocSplitHandle
@@ -21,12 +22,13 @@ from common.handle.impl.text.text_split_handle import TextSplitHandle
from common.handle.impl.text.xls_split_handle import XlsSplitHandle
from common.handle.impl.text.xlsx_split_handle import XlsxSplitHandle
from common.handle.impl.text.zip_split_handle import ZipSplitHandle
-from common.utils.common import post, get_file_content
+from common.utils.common import post, get_file_content, bulk_create_in_batches
from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \
    TaskType, File
-from knowledge.serializers.common import ProblemParagraphManage
+from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer
-from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer
+from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer, \
+    delete_problems_and_mappings
-from knowledge.task import embedding_by_document
+from knowledge.task import embedding_by_document, delete_embedding_by_document_list
from maxkb.const import PROJECT_DIR


default_split_handle = TextSplitHandle()
@@ -42,6 +44,19 @@ split_handles = [
]


class BatchCancelInstanceSerializer(serializers.Serializer):
    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=('id list'))
    type = serializers.IntegerField(required=True, label=_('task type'))

    def is_valid(self, *, raise_exception=False):
        super().is_valid(raise_exception=True)
        _type = self.data.get('type')
        try:
            TaskType(_type)
        except Exception as e:
            raise AppApiException(500, _('task type not support'))


class DocumentInstanceSerializer(serializers.Serializer):
    name = serializers.CharField(required=True, label=_('document name'), max_length=128, min_length=1)
    paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
@@ -65,6 +80,17 @@ class DocumentSplitRequest(serializers.Serializer):
    with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))


class DocumentBatchRequest(serializers.Serializer):
    file = serializers.ListField(required=True, label=_('file list'))
    limit = serializers.IntegerField(required=False, label=_('limit'))
    patterns = serializers.ListField(
        required=False,
        child=serializers.CharField(required=True, label=_('patterns')),
        label=_('patterns')
    )
    with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))


class DocumentSerializers(serializers.Serializer):
    class Operate(serializers.Serializer):
        document_id = serializers.UUIDField(required=True, label=_('document id'))
@@ -264,6 +290,150 @@ class DocumentSerializers(serializers.Serializer):
                return result
            return [result]

    class Batch(serializers.Serializer):
        workspace_id = serializers.UUIDField(required=True, label=_('workspace id'))
        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))

        @staticmethod
        def post_embedding(document_list, knowledge_id):
            for document_dict in document_list:
                DocumentSerializers.Operate(
                    data={'knowledge_id': knowledge_id, 'document_id': document_dict.get('id')}).refresh()
            return document_list

        @post(post_function=post_embedding)
        @transaction.atomic
        def batch_save(self, instance_list: List[Dict], with_valid=True):
            if with_valid:
                self.is_valid(raise_exception=True)
                DocumentInstanceSerializer(many=True, data=instance_list).is_valid(raise_exception=True)
            knowledge_id = self.data.get("knowledge_id")
            document_model_list = []
            paragraph_model_list = []
            problem_paragraph_object_list = []
            # Build the model objects for each incoming document
            for document in instance_list:
                document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(knowledge_id,
                                                                                                         document)
                document_model_list.append(document_paragraph_dict_model.get('document'))
                for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
                    paragraph_model_list.append(paragraph)
                for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
                    problem_paragraph_object_list.append(problem_paragraph_object)
            problem_model_list, problem_paragraph_mapping_list = (
                ProblemParagraphManage(problem_paragraph_object_list, knowledge_id).to_problem_model_list()
            )
            # Insert documents
            QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
            # Bulk insert paragraphs
            bulk_create_in_batches(Paragraph, paragraph_model_list, batch_size=1000)
            # Bulk insert problems
            bulk_create_in_batches(Problem, problem_model_list, batch_size=1000)
            # Bulk insert problem-paragraph mappings
            bulk_create_in_batches(ProblemParagraphMapping, problem_paragraph_mapping_list, batch_size=1000)
            # Query the newly created documents
            query_set = QuerySet(model=Document)
            if len(document_model_list) == 0:
                return [], knowledge_id
            query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
            return native_search(
                {
                    'document_custom_sql': query_set,
                    'order_by_query': QuerySet(Document).order_by('-create_time', 'id')
                },
                select_string=get_file_content(
                    os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')
                ),
                with_search_one=False
            ), knowledge_id

        @staticmethod
        def _batch_sync(document_id_list: List[str]):
            for document_id in document_id_list:
                DocumentSerializers.Sync(data={'document_id': document_id}).sync()

        def batch_sync(self, instance: Dict, with_valid=True):
            if with_valid:
                BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
                self.is_valid(raise_exception=True)
            # Sync asynchronously on the worker thread pool
            work_thread_pool.submit(self._batch_sync, instance.get('id_list'))
            return True

        @transaction.atomic
        def batch_delete(self, instance: Dict, with_valid=True):
            if with_valid:
                BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
                self.is_valid(raise_exception=True)
            document_id_list = instance.get("id_list")
            QuerySet(Document).filter(id__in=document_id_list).delete()
            QuerySet(Paragraph).filter(document_id__in=document_id_list).delete()
            delete_problems_and_mappings(document_id_list)
            # Delete the documents' embeddings from the vector store
            delete_embedding_by_document_list(document_id_list)
            return True

        def batch_cancel(self, instance: Dict, with_valid=True):
            if with_valid:
                self.is_valid(raise_exception=True)
                BatchCancelInstanceSerializer(data=instance).is_valid(raise_exception=True)
            document_id_list = instance.get("id_list")
            ListenerManagement.update_status(
                QuerySet(Paragraph).annotate(
                    reversed_status=Reverse('status'),
                    task_type_status=Substr('reversed_status', TaskType(instance.get('type')).value, 1),
                ).filter(
                    task_type_status__in=[State.PENDING.value, State.STARTED.value]
                ).filter(
                    document_id__in=document_id_list
                ).values('id'),
                TaskType(instance.get('type')),
                State.REVOKE
            )
            ListenerManagement.update_status(
                QuerySet(Document).annotate(
                    reversed_status=Reverse('status'),
                    task_type_status=Substr('reversed_status', TaskType(instance.get('type')).value, 1),
                ).filter(
                    task_type_status__in=[State.PENDING.value, State.STARTED.value]
                ).filter(
                    id__in=document_id_list
                ).values('id'),
                TaskType(instance.get('type')),
                State.REVOKE
            )

        def batch_edit_hit_handling(self, instance: Dict, with_valid=True):
            if with_valid:
                BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
                hit_handling_method = instance.get('hit_handling_method')
                if hit_handling_method is None:
                    raise AppApiException(500, _('Hit handling method is required'))
                if hit_handling_method != 'optimization' and hit_handling_method != 'directly_return':
                    raise AppApiException(500, _('The hit processing method must be directly_return|optimization'))
                self.is_valid(raise_exception=True)
            document_id_list = instance.get("id_list")
            hit_handling_method = instance.get('hit_handling_method')
            directly_return_similarity = instance.get('directly_return_similarity')
            update_dict = {'hit_handling_method': hit_handling_method}
            if directly_return_similarity is not None:
                update_dict['directly_return_similarity'] = directly_return_similarity
            QuerySet(Document).filter(id__in=document_id_list).update(**update_dict)

        def batch_refresh(self, instance: Dict, with_valid=True):
            if with_valid:
                self.is_valid(raise_exception=True)
            document_id_list = instance.get("id_list")
            state_list = instance.get("state_list")
            knowledge_id = self.data.get('knowledge_id')
            for document_id in document_id_list:
                try:
                    DocumentSerializers.Operate(
                        data={'knowledge_id': knowledge_id, 'document_id': document_id}).refresh(state_list)
                except AlreadyQueued as e:
                    pass

class FileBufferHandle:
    buffer = None
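For orientation, a minimal sketch of how the new Batch serializer is driven; the UUIDs are placeholders, and the id_list payload shape simply mirrors what batch_sync and batch_delete read from instance:

# Hypothetical invocation, mirroring what the view layer does below.
serializer = DocumentSerializers.Batch(data={
    'workspace_id': '0c6d0c60-0000-4000-8000-000000000001',  # placeholder workspace UUID
    'knowledge_id': '0c6d0c60-0000-4000-8000-000000000002',  # placeholder knowledge base UUID
})
# Queue a background re-sync for two documents on the worker thread pool,
# then delete them together with paragraphs, problems, mappings and embeddings.
serializer.batch_sync({'id_list': ['<document-uuid-1>', '<document-uuid-2>']})
serializer.batch_delete({'id_list': ['<document-uuid-1>', '<document-uuid-2>']})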

View File

@@ -9,5 +9,6 @@ urlpatterns = [
    path('workspace/<str:workspace_id>/knowledge/web', views.KnowledgeWebView.as_view()),
    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>', views.KnowledgeView.Operate.as_view()),
    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),
+    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch', views.DocumentView.Batch.as_view()),
    path('workspace/<str:workspace_id>/knowledge/<int:current_page>/<int:page_size>', views.KnowledgeView.Page.as_view()),
]

View File

@@ -6,9 +6,9 @@ from rest_framework.views import APIView
from common.auth import TokenAuth
from common.auth.authentication import has_permissions
-from common.constants.permission_constants import PermissionConstants, CompareConstants
+from common.constants.permission_constants import PermissionConstants
from common.result import result
-from knowledge.api.document import DocumentSplitAPI
+from knowledge.api.document import DocumentSplitAPI, DocumentBatchAPI
from knowledge.api.knowledge import KnowledgeTreeReadAPI
from knowledge.serializers.document import DocumentSerializers
from knowledge.serializers.knowledge import KnowledgeSerializer
@@ -68,3 +68,60 @@ class DocumentView(APIView):
                'workspace_id': workspace_id,
                'knowledge_id': knowledge_id,
            }).parse(split_data))

    class Batch(APIView):
        authentication_classes = [TokenAuth]

        @extend_schema(
            methods=['POST'],
            description=_('Create documents in batches'),
            operation_id=_('Create documents in batches'),
            request=DocumentBatchAPI.get_request(),
            parameters=DocumentBatchAPI.get_parameters(),
            responses=DocumentBatchAPI.get_response(),
            tags=[_('Knowledge Base/Documentation')]
        )
        @has_permissions([
            PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(),
            PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(),
        ])
        def post(self, request: Request, workspace_id: str, knowledge_id: str):
            return result.success(DocumentSerializers.Batch(
                data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
            ).batch_save(request.data))

        @extend_schema(
            methods=['PUT'],
            description=_('Batch sync documents'),
            operation_id=_('Batch sync documents'),
            request=DocumentBatchAPI.get_request(),
            parameters=DocumentBatchAPI.get_parameters(),
            responses=DocumentBatchAPI.get_response(),
            tags=[_('Knowledge Base/Documentation')]
        )
        @has_permissions([
            PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(),
            PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(),
        ])
        def put(self, request: Request, workspace_id: str, knowledge_id: str):
            return result.success(DocumentSerializers.Batch(
                data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
            ).batch_sync(request.data))

        @extend_schema(
            methods=['DELETE'],
            description=_('Delete documents in batches'),
            operation_id=_('Delete documents in batches'),
            request=DocumentBatchAPI.get_request(),
            parameters=DocumentBatchAPI.get_parameters(),
            responses=DocumentBatchAPI.get_response(),
            tags=[_('Knowledge Base/Documentation')]
        )
        @has_permissions([
            PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(),
            PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(),
        ])
        def delete(self, request: Request, workspace_id: str, knowledge_id: str):
            return result.success(DocumentSerializers.Batch(
                data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id}
            ).batch_delete(request.data))
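Taken together, the new route and view expose create, sync and delete on a single batch endpoint. A rough client-side sketch follows; the URL prefix, auth header and payload values are assumptions for illustration, not taken from this commit:

# Hypothetical client calls against the new endpoint (python-requests).
import requests

base = 'http://localhost:8080/api/workspace/ws-1/knowledge/kb-1/document/batch'  # assumed prefix
headers = {'Authorization': 'Bearer <token>'}  # placeholder credentials

# POST: create documents in batches (a list of DocumentInstanceSerializer items)
requests.post(base, json=[{'name': 'manual.md', 'paragraphs': []}], headers=headers)

# PUT: queue a background sync for existing documents
requests.put(base, json={'id_list': ['<document-uuid-1>', '<document-uuid-2>']}, headers=headers)

# DELETE: remove documents and their derived data in one call
requests.delete(base, json={'id_list': ['<document-uuid-1>', '<document-uuid-2>']}, headers=headers)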