feat: add BatchGenerateRelated API for batch processing of related paragraphs

This commit is contained in:
CaptainB 2025-05-08 19:28:58 +08:00
parent ac698e0c4c
commit a75501737f
4 changed files with 69 additions and 3 deletions

View File

@ -4,7 +4,7 @@ from drf_spectacular.utils import OpenApiParameter
from common.mixins.api_mixin import APIMixin from common.mixins.api_mixin import APIMixin
from common.result import DefaultResultSerializer, ResultSerializer from common.result import DefaultResultSerializer, ResultSerializer
from knowledge.serializers.common import BatchSerializer from knowledge.serializers.common import BatchSerializer
from knowledge.serializers.paragraph import ParagraphSerializer from knowledge.serializers.paragraph import ParagraphSerializer, ParagraphBatchGenerateRelatedSerializer
from knowledge.serializers.problem import ProblemSerializer from knowledge.serializers.problem import ProblemSerializer
@ -106,6 +106,16 @@ class ParagraphBatchDeleteAPI(ParagraphCreateAPI):
return DefaultResultSerializer return DefaultResultSerializer
class ParagraphBatchGenerateRelatedAPI(ParagraphCreateAPI):
    # OpenAPI descriptor for the batch "generate related" paragraph endpoint.
    # Anything not overridden here (e.g. get_parameters) comes from
    # ParagraphCreateAPI.

    @staticmethod
    def get_request():
        """Return the serializer class that describes the request body."""
        return ParagraphBatchGenerateRelatedSerializer

    @staticmethod
    def get_response():
        """Return the serializer class that describes the response body."""
        return DefaultResultSerializer
class ParagraphGetAPI(APIMixin): class ParagraphGetAPI(APIMixin):
@staticmethod @staticmethod
def get_parameters(): def get_parameters():

View File

@ -3,15 +3,17 @@
from typing import Dict from typing import Dict
import uuid_utils.compat as uuid import uuid_utils.compat as uuid
from celery_once import AlreadyQueued
from django.db import transaction from django.db import transaction
from django.db.models import QuerySet, Count from django.db.models import QuerySet, Count
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
from common.db.search import page_search from common.db.search import page_search
from common.event import ListenerManagement
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
from common.utils.common import post from common.utils.common import post
from knowledge.models import Paragraph, Problem, Document, ProblemParagraphMapping, SourceType from knowledge.models import Paragraph, Problem, Document, ProblemParagraphMapping, SourceType, TaskType, State
from knowledge.serializers.common import ProblemParagraphObject, ProblemParagraphManage, \ from knowledge.serializers.common import ProblemParagraphObject, ProblemParagraphManage, \
get_embedding_model_id_by_knowledge_id, update_document_char_length, BatchSerializer get_embedding_model_id_by_knowledge_id, update_document_char_length, BatchSerializer
from knowledge.serializers.problem import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers from knowledge.serializers.problem import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
@ -19,6 +21,7 @@ from knowledge.task.embedding import embedding_by_paragraph, enable_embedding_by
disable_embedding_by_paragraph, \ disable_embedding_by_paragraph, \
delete_embedding_by_paragraph, embedding_by_problem as embedding_by_problem_task, delete_embedding_by_paragraph_ids, \ delete_embedding_by_paragraph, embedding_by_problem as embedding_by_problem_task, delete_embedding_by_paragraph_ids, \
embedding_by_problem, delete_embedding_by_source embedding_by_problem, delete_embedding_by_source
from knowledge.task.generate import generate_related_by_paragraph_id_list
class ParagraphSerializer(serializers.ModelSerializer): class ParagraphSerializer(serializers.ModelSerializer):
@ -47,6 +50,15 @@ class EditParagraphSerializers(serializers.Serializer):
problem_list = ProblemInstanceSerializer(required=False, many=True) problem_list = ProblemInstanceSerializer(required=False, many=True)
class ParagraphBatchGenerateRelatedSerializer(serializers.Serializer):
    # Request schema for batch generation of related problems.

    # Ids of the paragraphs to process; each entry must be a UUID.
    paragraph_id_list = serializers.ListField(
        required=True,
        label=_('paragraph id list'),
        child=serializers.UUIDField(required=True, label=_('paragraph id')),
    )
    # Id of the model passed on to the generation task.
    model_id = serializers.UUIDField(required=True, label=_('model id'))
    # The key must be present, but its value may be null or an empty string.
    prompt = serializers.CharField(
        required=True,
        label=_('prompt'),
        max_length=102400,
        allow_null=True,
        allow_blank=True,
    )
    # Id of the document the paragraphs belong to.
    document_id = serializers.UUIDField(required=True, label=_('document id'))
class ParagraphSerializers(serializers.Serializer): class ParagraphSerializers(serializers.Serializer):
title = serializers.CharField(required=False, max_length=256, label=_('section title'), allow_null=True, title = serializers.CharField(required=False, max_length=256, label=_('section title'), allow_null=True,
allow_blank=True) allow_blank=True)
@ -379,6 +391,29 @@ class ParagraphSerializers(serializers.Serializer):
delete_embedding_by_paragraph_ids(paragraph_id_list) delete_embedding_by_paragraph_ids(paragraph_id_list)
return True return True
def batch_generate_related(self, instance: Dict, with_valid=True):
    """
    Queue the related-problem generation task for a batch of paragraphs.

    Sets the GENERATE_PROBLEM task state of the document and of the selected
    paragraphs to PENDING, invokes the callable returned by
    ``ListenerManagement.get_aggregation_document_status`` for the document,
    and then dispatches ``generate_related_by_paragraph_id_list`` through
    ``.delay``.

    :param instance: request payload; the keys read are ``paragraph_id_list``,
                     ``model_id`` and ``prompt`` (missing keys yield ``None``)
    :param with_valid: when true, validate this serializer's own data first
    :raises AppApiException: code 500 when the task is already queued
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    paragraph_id_list = instance.get("paragraph_id_list")
    model_id = instance.get("model_id")
    prompt = instance.get("prompt")
    # The document id comes from the serializer's own data, not from the payload.
    document_id = self.data.get('document_id')

    # Same order as before: document first, then its paragraphs.
    pending_targets = (
        QuerySet(Document).filter(id=document_id),
        QuerySet(Paragraph).filter(id__in=paragraph_id_list),
    )
    for target in pending_targets:
        ListenerManagement.update_status(target, TaskType.GENERATE_PROBLEM, State.PENDING)
    ListenerManagement.get_aggregation_document_status(document_id)()

    try:
        generate_related_by_paragraph_id_list.delay(document_id, paragraph_id_list, model_id, prompt)
    except AlreadyQueued as e:
        # Chain the original error so the duplicate-task cause stays visible in tracebacks.
        raise AppApiException(500, _('The task is being executed, please do not send it again.')) from e
def delete_problems_and_mappings(paragraph_ids): def delete_problems_and_mappings(paragraph_ids):
problem_paragraph_mappings = ProblemParagraphMapping.objects.filter(paragraph_id__in=paragraph_ids) problem_paragraph_mappings = ProblemParagraphMapping.objects.filter(paragraph_id__in=paragraph_ids)

View File

@ -31,6 +31,7 @@ urlpatterns = [
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/batch_cancel_task', views.DocumentView.BatchCancelTask.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/batch_cancel_task', views.DocumentView.BatchCancelTask.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph', views.ParagraphView.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph', views.ParagraphView.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/batch_delete', views.ParagraphView.BatchDelete.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/batch_delete', views.ParagraphView.BatchDelete.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/batch_generate_related', views.ParagraphView.BatchGenerateRelated.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/<str:paragraph_id>', views.ParagraphView.Operate.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/<str:paragraph_id>', views.ParagraphView.Operate.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem', views.ParagraphView.Problem.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem', views.ParagraphView.Problem.as_view()),
path( 'workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/<int:current_page>/<int:page_size>', views.ParagraphView.Page.as_view()), path( 'workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/<int:current_page>/<int:page_size>', views.ParagraphView.Page.as_view()),

View File

@ -9,7 +9,8 @@ from common.constants.permission_constants import PermissionConstants
from common.result import result from common.result import result
from common.utils.common import query_params_to_single_dict from common.utils.common import query_params_to_single_dict
from knowledge.api.paragraph import ParagraphReadAPI, ParagraphCreateAPI, ParagraphBatchDeleteAPI, ParagraphEditAPI, \ from knowledge.api.paragraph import ParagraphReadAPI, ParagraphCreateAPI, ParagraphBatchDeleteAPI, ParagraphEditAPI, \
ParagraphGetAPI, ProblemCreateAPI, UnAssociationAPI, AssociationAPI, ParagraphPageAPI ParagraphGetAPI, ProblemCreateAPI, UnAssociationAPI, AssociationAPI, ParagraphPageAPI, \
ParagraphBatchGenerateRelatedAPI
from knowledge.serializers.paragraph import ParagraphSerializers from knowledge.serializers.paragraph import ParagraphSerializers
@ -70,6 +71,25 @@ class ParagraphView(APIView):
data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id, 'document_id': document_id} data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id, 'document_id': document_id}
).batch_delete(request.data)) ).batch_delete(request.data))
class BatchGenerateRelated(APIView):
    # PUT endpoint that starts related-problem generation for several
    # paragraphs of a single document.
    authentication_classes = [TokenAuth]

    @extend_schema(
        methods=['PUT'],
        summary=_('Batch Generate Related'),
        description=_('Batch Generate Related'),
        operation_id=_('Batch Generate Related'),
        parameters=ParagraphBatchGenerateRelatedAPI.get_parameters(),
        request=ParagraphBatchGenerateRelatedAPI.get_request(),
        responses=ParagraphBatchGenerateRelatedAPI.get_response(),
        tags=[_('Knowledge Base/Documentation/Paragraph')]
    )
    @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
    def put(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
        # Path parameters identify the target document; the request body
        # carries the paragraph ids, model id and prompt.
        path_params = {
            'workspace_id': workspace_id,
            'knowledge_id': knowledge_id,
            'document_id': document_id,
        }
        batch_serializer = ParagraphSerializers.Batch(data=path_params)
        outcome = batch_serializer.batch_generate_related(request.data)
        return result.success(outcome)
class Operate(APIView): class Operate(APIView):
authentication_classes = [TokenAuth] authentication_classes = [TokenAuth]