feat: implement batch processing for document creation, synchronization, and deletion
parent 43bef216d5
commit 0d3eb431f6
@@ -263,3 +263,10 @@ def parse_md_image(content: str):
     image_list = [match.group() for match in matches]
     return image_list
 
+
+def bulk_create_in_batches(model, data, batch_size=1000):
+    if len(data) == 0:
+        return
+    for i in range(0, len(data), batch_size):
+        batch = data[i:i + batch_size]
+        model.objects.bulk_create(batch)
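Note: a minimal usage sketch of the new helper, showing how it slices a large list into fixed-size bulk_create calls. The Paragraph field values and row count are hypothetical stand-ins.

    # Hypothetical data; assumes a configured Django environment where
    # Paragraph is a regular models.Model subclass.
    from common.utils.common import bulk_create_in_batches
    from knowledge.models import Paragraph

    paragraphs = [Paragraph(content=f'chunk {i}') for i in range(5000)]

    # Issues five INSERTs of 1,000 rows each instead of one 5,000-row
    # statement, keeping each query below database parameter limits.
    bulk_create_in_batches(Paragraph, paragraphs, batch_size=1000)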
@@ -2,16 +2,11 @@ from drf_spectacular.types import OpenApiTypes
 from drf_spectacular.utils import OpenApiParameter
 
 from common.mixins.api_mixin import APIMixin
-from common.result import DefaultResultSerializer, ResultSerializer
+from common.result import DefaultResultSerializer
+from knowledge.serializers.common import BatchSerializer
 from knowledge.serializers.document import DocumentCreateRequest
 
 
-class DocumentCreateResponse(ResultSerializer):
-    @staticmethod
-    def get_data():
-        return DefaultResultSerializer()
-
-
 class DocumentCreateAPI(APIMixin):
     @staticmethod
     def get_parameters():
@@ -31,7 +26,7 @@ class DocumentCreateAPI(APIMixin):
 
     @staticmethod
     def get_response():
-        return DocumentCreateResponse
+        return DefaultResultSerializer
 
 
 class DocumentSplitAPI(APIMixin):
@@ -75,3 +70,31 @@ class DocumentSplitAPI(APIMixin):
         ),
     ]
+
+
+class DocumentBatchAPI(APIMixin):
+    @staticmethod
+    def get_parameters():
+        return [
+            OpenApiParameter(
+                name="workspace_id",
+                description="workspace id",
+                type=OpenApiTypes.STR,
+                location='path',
+                required=True,
+            ),
+            OpenApiParameter(
+                name="knowledge_id",
+                description="knowledge base id",
+                type=OpenApiTypes.STR,
+                location='path',
+                required=True,
+            ),
+        ]
+
+    @staticmethod
+    def get_request():
+        return BatchSerializer
+
+    @staticmethod
+    def get_response():
+        return DefaultResultSerializer
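Note: BatchSerializer is defined outside this diff (knowledge/serializers/common.py); the code here relies only on it validating an id_list of UUIDs that must exist on the given model. A hedged sketch of the expected request body:

    # Illustrative payload; the exact field set of BatchSerializer is not
    # shown in this commit, only the 'id_list' contract used below.
    payload = {
        'id_list': [
            '8c3f9a1e-0000-0000-0000-000000000001',
            '8c3f9a1e-0000-0000-0000-000000000002',
        ]
    }
    BatchSerializer(data=payload).is_valid(model=Document, raise_exception=True)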
@@ -12,6 +12,7 @@ from rest_framework import serializers
 
 from common.db.search import native_search
 from common.event import ListenerManagement
+from common.event.common import work_thread_pool
 from common.exception.app_exception import AppApiException
 from common.handle.impl.text.csv_split_handle import CsvSplitHandle
 from common.handle.impl.text.doc_split_handle import DocSplitHandle
@@ -21,12 +22,13 @@ from common.handle.impl.text.text_split_handle import TextSplitHandle
 from common.handle.impl.text.xls_split_handle import XlsSplitHandle
 from common.handle.impl.text.xlsx_split_handle import XlsxSplitHandle
 from common.handle.impl.text.zip_split_handle import ZipSplitHandle
-from common.utils.common import post, get_file_content
+from common.utils.common import post, get_file_content, bulk_create_in_batches
 from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \
     TaskType, File
-from knowledge.serializers.common import ProblemParagraphManage
-from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer
-from knowledge.task import embedding_by_document
+from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer
+from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer, \
+    delete_problems_and_mappings
+from knowledge.task import embedding_by_document, delete_embedding_by_document_list
 from maxkb.const import PROJECT_DIR
 
 default_split_handle = TextSplitHandle()
@@ -42,6 +44,19 @@ split_handles = [
 ]
 
 
+class BatchCancelInstanceSerializer(serializers.Serializer):
+    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list'))
+    type = serializers.IntegerField(required=True, label=_('task type'))
+
+    def is_valid(self, *, raise_exception=False):
+        super().is_valid(raise_exception=True)
+        _type = self.data.get('type')
+        try:
+            TaskType(_type)
+        except Exception:
+            raise AppApiException(500, _('task type not supported'))
+
+
 class DocumentInstanceSerializer(serializers.Serializer):
     name = serializers.CharField(required=True, label=_('document name'), max_length=128, min_length=1)
     paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
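Note: a hedged sketch of the cancel payload the new serializer accepts. The concrete TaskType member (e.g. an EMBEDDING task with value 1) is an assumption; the code above only uses the int-valued enum contract.

    cancel_payload = {
        'id_list': ['8c3f9a1e-0000-0000-0000-000000000001'],
        'type': 1,  # must convert via TaskType(type), else AppApiException(500)
    }
    BatchCancelInstanceSerializer(data=cancel_payload).is_valid(raise_exception=True)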
@@ -65,6 +80,17 @@ class DocumentSplitRequest(serializers.Serializer):
     with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))
 
 
+class DocumentBatchRequest(serializers.Serializer):
+    file = serializers.ListField(required=True, label=_('file list'))
+    limit = serializers.IntegerField(required=False, label=_('limit'))
+    patterns = serializers.ListField(
+        required=False,
+        child=serializers.CharField(required=True, label=_('patterns')),
+        label=_('patterns')
+    )
+    with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))
+
+
 class DocumentSerializers(serializers.Serializer):
     class Operate(serializers.Serializer):
         document_id = serializers.UUIDField(required=True, label=_('document id'))
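Note: an illustrative DocumentBatchRequest body; the values are made up, only the field names and types come from the serializer above.

    request_body = {
        'file': ['a.docx', 'b.md'],  # required: list of files to import
        'limit': 4096,               # optional: paragraph size limit
        'patterns': [r'^#+\s'],      # optional: split patterns
        'with_filter': True,         # optional: auto-clean flag
    }
    DocumentBatchRequest(data=request_body).is_valid(raise_exception=True)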
@@ -264,6 +290,150 @@ class DocumentSerializers(serializers.Serializer):
             return result
         return [result]
 
+    class Batch(serializers.Serializer):
+        workspace_id = serializers.UUIDField(required=True, label=_('workspace id'))
+        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
+
+        @staticmethod
+        def post_embedding(document_list, knowledge_id):
+            for document_dict in document_list:
+                DocumentSerializers.Operate(
+                    data={'knowledge_id': knowledge_id, 'document_id': document_dict.get('id')}).refresh()
+            return document_list
+
+        @post(post_function=post_embedding)
+        @transaction.atomic
+        def batch_save(self, instance_list: List[Dict], with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+                DocumentInstanceSerializer(many=True, data=instance_list).is_valid(raise_exception=True)
+            knowledge_id = self.data.get("knowledge_id")
+            document_model_list = []
+            paragraph_model_list = []
+            problem_paragraph_object_list = []
+            # build document, paragraph and problem models
+            for document in instance_list:
+                document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(knowledge_id,
+                                                                                                        document)
+                document_model_list.append(document_paragraph_dict_model.get('document'))
+                for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
+                    paragraph_model_list.append(paragraph)
+                for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
+                    problem_paragraph_object_list.append(problem_paragraph_object)
+
+            problem_model_list, problem_paragraph_mapping_list = (
+                ProblemParagraphManage(problem_paragraph_object_list, knowledge_id).to_problem_model_list()
+            )
+            # insert documents
+            if len(document_model_list) > 0:
+                QuerySet(Document).bulk_create(document_model_list)
+            # insert paragraphs in batches
+            bulk_create_in_batches(Paragraph, paragraph_model_list, batch_size=1000)
+            # insert problems in batches
+            bulk_create_in_batches(Problem, problem_model_list, batch_size=1000)
+            # insert problem-paragraph mappings in batches
+            bulk_create_in_batches(ProblemParagraphMapping, problem_paragraph_mapping_list, batch_size=1000)
+            # query the inserted documents
+            query_set = QuerySet(model=Document)
+            if len(document_model_list) == 0:
+                return [], knowledge_id
+            query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
+            return native_search(
+                {
+                    'document_custom_sql': query_set,
+                    'order_by_query': QuerySet(Document).order_by('-create_time', 'id')
+                },
+                select_string=get_file_content(
+                    os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')
+                ),
+                with_search_one=False
+            ), knowledge_id
+
+        @staticmethod
+        def _batch_sync(document_id_list: List[str]):
+            for document_id in document_id_list:
+                DocumentSerializers.Sync(data={'document_id': document_id}).sync()
+
+        def batch_sync(self, instance: Dict, with_valid=True):
+            if with_valid:
+                BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
+                self.is_valid(raise_exception=True)
+            # sync asynchronously on the worker thread pool
+            work_thread_pool.submit(self._batch_sync, instance.get('id_list'))
+            return True
+
+        @transaction.atomic
+        def batch_delete(self, instance: Dict, with_valid=True):
+            if with_valid:
+                BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
+                self.is_valid(raise_exception=True)
+            document_id_list = instance.get("id_list")
+            QuerySet(Document).filter(id__in=document_id_list).delete()
+            QuerySet(Paragraph).filter(document_id__in=document_id_list).delete()
+            delete_problems_and_mappings(document_id_list)
+            # remove the documents from the vector store
+            delete_embedding_by_document_list(document_id_list)
+            return True
+
+        def batch_cancel(self, instance: Dict, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+                BatchCancelInstanceSerializer(data=instance).is_valid(raise_exception=True)
+            document_id_list = instance.get("id_list")
+            ListenerManagement.update_status(
+                QuerySet(Paragraph).annotate(
+                    reversed_status=Reverse('status'),
+                    task_type_status=Substr('reversed_status', TaskType(instance.get('type')).value, 1),
+                ).filter(
+                    task_type_status__in=[State.PENDING.value, State.STARTED.value]
+                ).filter(
+                    document_id__in=document_id_list
+                ).values('id'),
+                TaskType(instance.get('type')),
+                State.REVOKE
+            )
+            ListenerManagement.update_status(
+                QuerySet(Document).annotate(
+                    reversed_status=Reverse('status'),
+                    task_type_status=Substr('reversed_status', TaskType(instance.get('type')).value, 1),
+                ).filter(
+                    task_type_status__in=[State.PENDING.value, State.STARTED.value]
+                ).filter(
+                    id__in=document_id_list
+                ).values('id'),
+                TaskType(instance.get('type')),
+                State.REVOKE
+            )
+
+        def batch_edit_hit_handling(self, instance: Dict, with_valid=True):
+            if with_valid:
+                BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
+                hit_handling_method = instance.get('hit_handling_method')
+                if hit_handling_method is None:
+                    raise AppApiException(500, _('Hit handling method is required'))
+                if hit_handling_method != 'optimization' and hit_handling_method != 'directly_return':
+                    raise AppApiException(500, _('The hit handling method must be directly_return|optimization'))
+                self.is_valid(raise_exception=True)
+            document_id_list = instance.get("id_list")
+            hit_handling_method = instance.get('hit_handling_method')
+            directly_return_similarity = instance.get('directly_return_similarity')
+            update_dict = {'hit_handling_method': hit_handling_method}
+            if directly_return_similarity is not None:
+                update_dict['directly_return_similarity'] = directly_return_similarity
+            QuerySet(Document).filter(id__in=document_id_list).update(**update_dict)
+
+        def batch_refresh(self, instance: Dict, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            document_id_list = instance.get("id_list")
+            state_list = instance.get("state_list")
+            knowledge_id = self.data.get('knowledge_id')
+            for document_id in document_id_list:
+                try:
+                    DocumentSerializers.Operate(
+                        data={'knowledge_id': knowledge_id, 'document_id': document_id}).refresh(state_list)
+                except AlreadyQueued:
+                    pass
+
+
 class FileBufferHandle:
     buffer = None
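Note: batch_cancel relies on the status column packing one state character per task type, indexed from the end of the string, which is why the queryset reverses it before taking a single-character Substr. A minimal pure-Python sketch of that decoding; the concrete status encoding ('0' for PENDING, task type value 1) is a hypothetical illustration.

    def task_state(status: str, task_type_value: int) -> str:
        # Reverse first, mirroring Reverse('status') in the queryset.
        reversed_status = status[::-1]
        # SQL Substr is 1-indexed; Python slicing is 0-indexed.
        return reversed_status[task_type_value - 1]

    # A document whose task type 1 is in state '0' (hypothetical encoding):
    assert task_state('n0', 1) == '0'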
@@ -9,5 +9,6 @@ urlpatterns = [
     path('workspace/<str:workspace_id>/knowledge/web', views.KnowledgeWebView.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>', views.KnowledgeView.Operate.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),
+    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch', views.DocumentView.Batch.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<int:current_page>/<int:page_size>', views.KnowledgeView.Page.as_view()),
 ]
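Note: the new pattern resolves to a single route handling all three batch verbs, e.g.

    POST/PUT/DELETE /workspace/{workspace_id}/knowledge/{knowledge_id}/document/batch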
@@ -6,9 +6,9 @@ from rest_framework.views import APIView
 
 from common.auth import TokenAuth
 from common.auth.authentication import has_permissions
-from common.constants.permission_constants import PermissionConstants, CompareConstants
+from common.constants.permission_constants import PermissionConstants
 from common.result import result
-from knowledge.api.document import DocumentSplitAPI
+from knowledge.api.document import DocumentSplitAPI, DocumentBatchAPI
 from knowledge.api.knowledge import KnowledgeTreeReadAPI
 from knowledge.serializers.document import DocumentSerializers
 from knowledge.serializers.knowledge import KnowledgeSerializer
@@ -68,3 +68,60 @@ class DocumentView(APIView):
                 'workspace_id': workspace_id,
                 'knowledge_id': knowledge_id,
             }).parse(split_data))
+
+    class Batch(APIView):
+        authentication_classes = [TokenAuth]
+
+        @extend_schema(
+            methods=['POST'],
+            description=_('Create documents in batches'),
+            operation_id=_('Create documents in batches'),
+            request=DocumentBatchAPI.get_request(),
+            parameters=DocumentBatchAPI.get_parameters(),
+            responses=DocumentBatchAPI.get_response(),
+            tags=[_('Knowledge Base/Documentation')]
+        )
+        @has_permissions([
+            PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(),
+            PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(),
+        ])
+        def post(self, request: Request, workspace_id: str, knowledge_id: str):
+            return result.success(DocumentSerializers.Batch(
+                data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
+            ).batch_save(request.data))
+
+        @extend_schema(
+            methods=['PUT'],
+            description=_('Batch sync documents'),
+            operation_id=_('Batch sync documents'),
+            request=DocumentBatchAPI.get_request(),
+            parameters=DocumentBatchAPI.get_parameters(),
+            responses=DocumentBatchAPI.get_response(),
+            tags=[_('Knowledge Base/Documentation')]
+        )
+        @has_permissions([
+            PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(),
+            PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(),
+        ])
+        def put(self, request: Request, workspace_id: str, knowledge_id: str):
+            return result.success(DocumentSerializers.Batch(
+                data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
+            ).batch_sync(request.data))
+
+        @extend_schema(
+            methods=['DELETE'],
+            description=_('Delete documents in batches'),
+            operation_id=_('Delete documents in batches'),
+            request=DocumentBatchAPI.get_request(),
+            parameters=DocumentBatchAPI.get_parameters(),
+            responses=DocumentBatchAPI.get_response(),
+            tags=[_('Knowledge Base/Documentation')]
+        )
+        @has_permissions([
+            PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(),
+            PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(),
+        ])
+        def delete(self, request: Request, workspace_id: str, knowledge_id: str):
+            return result.success(DocumentSerializers.Batch(
+                data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id}
+            ).batch_delete(request.data))
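Note: a hedged client-side sketch of the three batch operations; the base URL, token scheme, and payload values are assumptions for illustration only.

    import requests

    base = 'http://localhost:8080/api/workspace/w1/knowledge/k1/document/batch'
    headers = {'Authorization': 'Bearer <token>'}  # TokenAuth; scheme assumed

    # POST: create documents in batches (a list of DocumentInstanceSerializer items)
    requests.post(base, headers=headers, json=[{'name': 'intro.md', 'paragraphs': []}])

    # PUT: queue an asynchronous batch sync of existing documents (BatchSerializer body)
    requests.put(base, headers=headers, json={'id_list': ['<document-uuid>']})

    # DELETE: remove documents plus their paragraphs, problems, and embeddings
    requests.delete(base, headers=headers, json={'id_list': ['<document-uuid>']})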