feat: add SyncWeb and GenerateRelated APIs for knowledge base synchronization and related generation

This commit is contained in:
CaptainB 2025-05-08 16:19:22 +08:00
parent 0ae489a50b
commit 7dcd1a71e8
5 changed files with 335 additions and 5 deletions

View File

@ -2,7 +2,8 @@ from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter from drf_spectacular.utils import OpenApiParameter
from common.mixins.api_mixin import APIMixin from common.mixins.api_mixin import APIMixin
from common.result import ResultSerializer from common.result import ResultSerializer, DefaultResultSerializer
from knowledge.serializers.common import GenerateRelatedSerializer
from knowledge.serializers.knowledge import KnowledgeBaseCreateRequest, KnowledgeModelSerializer, KnowledgeEditRequest, \ from knowledge.serializers.knowledge import KnowledgeBaseCreateRequest, KnowledgeModelSerializer, KnowledgeEditRequest, \
KnowledgeWebCreateRequest KnowledgeWebCreateRequest
@ -206,3 +207,34 @@ class KnowledgePageAPI(KnowledgeReadAPI):
required=False, required=False,
), ),
] ]
class SyncWebAPI(APIMixin):
@staticmethod
def get_parameters():
return [
OpenApiParameter(
name="workspace_id",
description="工作空间id",
type=OpenApiTypes.STR,
location='path',
required=True,
),
OpenApiParameter(
name="knowledge_id",
description="知识库id",
type=OpenApiTypes.STR,
location='path',
required=True,
),
]
@staticmethod
def get_response():
return DefaultResultSerializer
class GenerateRelatedAPI(SyncWebAPI):
@staticmethod
def get_request():
return GenerateRelatedSerializer

View File

@ -1,23 +1,34 @@
import logging
import os import os
import re
import traceback
from functools import reduce from functools import reduce
from typing import Dict from typing import Dict
import uuid_utils.compat as uuid import uuid_utils.compat as uuid
from celery_once import AlreadyQueued
from django.core import validators
from django.db import transaction, models from django.db import transaction, models
from django.db.models import QuerySet from django.db.models import QuerySet
from django.db.models.functions import Reverse, Substr
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
from common.db.search import native_search, get_dynamics_model, native_page_search from common.db.search import native_search, get_dynamics_model, native_page_search
from common.db.sql_execute import select_list from common.db.sql_execute import select_list
from common.event import ListenerManagement
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
from common.utils.common import valid_license, post, get_file_content from common.utils.common import valid_license, post, get_file_content
from common.utils.fork import Fork, ChildLink
from common.utils.split_model import get_split_model
from knowledge.models import Knowledge, KnowledgeScope, KnowledgeType, Document, Paragraph, Problem, \ from knowledge.models import Knowledge, KnowledgeScope, KnowledgeType, Document, Paragraph, Problem, \
ProblemParagraphMapping, ApplicationKnowledgeMapping ProblemParagraphMapping, ApplicationKnowledgeMapping, TaskType, State
from knowledge.serializers.common import ProblemParagraphManage, get_embedding_model_id_by_knowledge_id, MetaSerializer from knowledge.serializers.common import ProblemParagraphManage, get_embedding_model_id_by_knowledge_id, MetaSerializer, \
GenerateRelatedSerializer
from knowledge.serializers.document import DocumentSerializers from knowledge.serializers.document import DocumentSerializers
from knowledge.task.embedding import embedding_by_knowledge, delete_embedding_by_knowledge from knowledge.task.embedding import embedding_by_knowledge, delete_embedding_by_knowledge
from knowledge.task.sync import sync_web_knowledge from knowledge.task.generate import generate_related_by_knowledge_id
from knowledge.task.sync import sync_web_knowledge, sync_replace_web_knowledge
from maxkb.conf import PROJECT_DIR from maxkb.conf import PROJECT_DIR
@ -137,6 +148,35 @@ class KnowledgeSerializer(serializers.Serializer):
workspace_id = serializers.CharField(required=True, label=_('workspace id')) workspace_id = serializers.CharField(required=True, label=_('workspace id'))
knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id')) knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
def generate_related(self, instance: Dict, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
GenerateRelatedSerializer(data=instance).is_valid(raise_exception=True)
knowledge_id = self.data.get('id')
model_id = instance.get("model_id")
prompt = instance.get("prompt")
state_list = instance.get('state_list')
ListenerManagement.update_status(
QuerySet(Document).filter(knowledge_id=knowledge_id),
TaskType.GENERATE_PROBLEM,
State.PENDING
)
ListenerManagement.update_status(
QuerySet(Paragraph).annotate(
reversed_status=Reverse('status'),
task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value, 1),
).filter(
task_type_status__in=state_list, knowledge_id=knowledge_id
).values('id'),
TaskType.GENERATE_PROBLEM,
State.PENDING
)
ListenerManagement.get_aggregation_document_status_by_knowledge_id(knowledge_id)()
try:
generate_related_by_knowledge_id.delay(knowledge_id, model_id, prompt, state_list)
except AlreadyQueued as e:
raise AppApiException(500, _('Failed to send the vectorization task, please try again later!'))
def list_application(self, with_valid=True): def list_application(self, with_valid=True):
if with_valid: if with_valid:
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
@ -340,3 +380,80 @@ class KnowledgeSerializer(serializers.Serializer):
knowledge.save() knowledge.save()
sync_web_knowledge.delay(str(knowledge_id), instance.get('source_url'), instance.get('selector')) sync_web_knowledge.delay(str(knowledge_id), instance.get('source_url'), instance.get('selector'))
return {**KnowledgeModelSerializer(knowledge).data, 'document_list': []} return {**KnowledgeModelSerializer(knowledge).data, 'document_list': []}
class SyncWeb(serializers.Serializer):
id = serializers.CharField(required=True, label=_('knowledge id'))
user_id = serializers.UUIDField(required=False, label=_('user id'))
sync_type = serializers.CharField(required=True, label=_('sync type'), validators=[
validators.RegexValidator(regex=re.compile("^replace|complete$"),
message=_('The synchronization type only supports:replace|complete'), code=500)])
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
first = QuerySet(Knowledge).filter(id=self.data.get("id")).first()
if first is None:
raise AppApiException(300, _('id does not exist'))
if first.type != KnowledgeType.WEB:
raise AppApiException(500, _('Synchronization is only supported for web site types'))
def sync(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
sync_type = self.data.get('sync_type')
knowledge_id = self.data.get('id')
knowledge = QuerySet(Knowledge).get(id=knowledge_id)
self.__getattribute__(sync_type + '_sync')(knowledge)
return True
@staticmethod
def get_sync_handler(knowledge):
def handler(child_link: ChildLink, response: Fork.Response):
if response.status == 200:
try:
document_name = child_link.tag.text if child_link.tag is not None and len(
child_link.tag.text.strip()) > 0 else child_link.url
paragraphs = get_split_model('web.md').parse(response.content)
print(child_link.url.strip())
first = QuerySet(Document).filter(
meta__source_url=child_link.url.strip(),
knowledge=knowledge
).first()
if first is not None:
# 如果存在,使用文档同步
DocumentSerializers.Sync(data={'document_id': first.id}).sync()
else:
# 插入
DocumentSerializers.Create(data={'knowledge_id': knowledge.id}).save(
{'name': document_name, 'paragraphs': paragraphs,
'meta': {'source_url': child_link.url.strip(),
'selector': knowledge.meta.get('selector')},
'type': Knowledge.WEB}, with_valid=True)
except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
return handler
def replace_sync(self, knowledge):
"""
替换同步
:return:
"""
url = knowledge.meta.get('source_url')
selector = knowledge.meta.get('selector') if 'selector' in knowledge.meta else None
sync_replace_web_knowledge.delay(str(knowledge.id), url, selector)
def complete_sync(self, knowledge):
"""
完整同步 删掉当前数据集下所有的文档,再进行同步
:return:
"""
# 删除关联问题
QuerySet(ProblemParagraphMapping).filter(knowledge=knowledge).delete()
# 删除文档
QuerySet(Document).filter(knowledge=knowledge).delete()
# 删除段落
QuerySet(Paragraph).filter(knowledge=knowledge).delete()
# 删除向量
delete_embedding_by_knowledge(self.data.get('id'))
# 同步
self.replace_sync(knowledge)

View File

@ -0,0 +1,139 @@
import logging
import traceback
from celery_once import QueueOnce
from django.db.models import QuerySet
from django.db.models.functions import Reverse, Substr
from django.utils.translation import gettext_lazy as _
from langchain_core.messages import HumanMessage
from common.config.embedding_config import ModelManage
from common.event import ListenerManagement
from common.utils.page_utils import page, page_desc
from knowledge.models import Paragraph, Document, Status, TaskType, State
from knowledge.task.handler import save_problem
from models_provider.models import Model
from models_provider.tools import get_model
from ops import celery_app
max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")
def get_llm_model(model_id):
model = QuerySet(Model).filter(id=model_id).first()
return ModelManage.get_model(model_id, lambda _id: get_model(model))
def generate_problem_by_paragraph(paragraph, llm_model, prompt):
try:
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM,
State.STARTED)
res = llm_model.invoke(
[HumanMessage(content=prompt.replace('{data}', paragraph.content).replace('{title}', paragraph.title))])
if (res.content is None) or (len(res.content) == 0):
return
problems = res.content.split('\n')
for problem in problems:
save_problem(paragraph.knowledge_id, paragraph.document_id, paragraph.id, problem)
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM,
State.SUCCESS)
except Exception as e:
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM,
State.FAILURE)
def get_generate_problem(llm_model, prompt, post_apply=lambda: None, is_the_task_interrupted=lambda: False):
def generate_problem(paragraph_list):
for paragraph in paragraph_list:
if is_the_task_interrupted():
return
generate_problem_by_paragraph(paragraph, llm_model, prompt)
post_apply()
return generate_problem
def get_is_the_task_interrupted(document_id):
def is_the_task_interrupted():
document = QuerySet(Document).filter(id=document_id).first()
if document is None or Status(document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE:
return True
return False
return is_the_task_interrupted
@celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']},
name='celery:generate_related_by_knowledge')
def generate_related_by_knowledge_id(knowledge_id, model_id, prompt, state_list=None):
document_list = QuerySet(Document).filter(knowledge_id=knowledge_id)
for document in document_list:
try:
generate_related_by_document_id.delay(document.id, model_id, prompt, state_list)
except Exception as e:
pass
@celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
name='celery:generate_related_by_document')
def generate_related_by_document_id(document_id, model_id, prompt, state_list=None):
if state_list is None:
state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
State.REVOKE.value,
State.REVOKED.value, State.IGNORED.value]
try:
is_the_task_interrupted = get_is_the_task_interrupted(document_id)
if is_the_task_interrupted():
return
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
TaskType.GENERATE_PROBLEM,
State.STARTED)
llm_model = get_llm_model(model_id)
# 生成问题函数
generate_problem = get_generate_problem(llm_model, prompt,
ListenerManagement.get_aggregation_document_status(
document_id), is_the_task_interrupted)
query_set = QuerySet(Paragraph).annotate(
reversed_status=Reverse('status'),
task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
1),
).filter(task_type_status__in=state_list, document_id=document_id)
page_desc(query_set, 10, generate_problem, is_the_task_interrupted)
except Exception as e:
max_kb_error.error(f'根据文档生成问题:{document_id}出现错误{str(e)}{traceback.format_exc()}')
max_kb_error.error(_('Generate issue based on document: {document_id} error {error}{traceback}').format(
document_id=document_id, error=str(e), traceback=traceback.format_exc()))
finally:
ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM)
max_kb.info(_('End--->Generate problem: {document_id}').format(document_id=document_id))
@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']},
name='celery:generate_related_by_paragraph_list')
def generate_related_by_paragraph_id_list(document_id, paragraph_id_list, model_id, prompt):
try:
is_the_task_interrupted = get_is_the_task_interrupted(document_id)
if is_the_task_interrupted():
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
TaskType.GENERATE_PROBLEM,
State.REVOKED)
return
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
TaskType.GENERATE_PROBLEM,
State.STARTED)
llm_model = get_llm_model(model_id)
# 生成问题函数
generate_problem = get_generate_problem(llm_model, prompt, ListenerManagement.get_aggregation_document_status(
document_id))
def is_the_task_interrupted():
document = QuerySet(Document).filter(id=document_id).first()
if document is None or Status(document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE:
return True
return False
page(QuerySet(Paragraph).filter(id__in=paragraph_id_list), 10, generate_problem, is_the_task_interrupted)
finally:
ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM)

View File

@ -8,6 +8,8 @@ urlpatterns = [
path('workspace/<str:workspace_id>/knowledge/base', views.KnowledgeBaseView.as_view()), path('workspace/<str:workspace_id>/knowledge/base', views.KnowledgeBaseView.as_view()),
path('workspace/<str:workspace_id>/knowledge/web', views.KnowledgeWebView.as_view()), path('workspace/<str:workspace_id>/knowledge/web', views.KnowledgeWebView.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>', views.KnowledgeView.Operate.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>', views.KnowledgeView.Operate.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/sync', views.KnowledgeView.SyncWeb.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/generate_related', views.KnowledgeView.GenerateRelated.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document', views.DocumentView.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document', views.DocumentView.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split_pattern', views.DocumentView.SplitPattern.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split_pattern', views.DocumentView.SplitPattern.as_view()),

View File

@ -8,7 +8,7 @@ from common.auth.authentication import has_permissions
from common.constants.permission_constants import PermissionConstants from common.constants.permission_constants import PermissionConstants
from common.result import result from common.result import result
from knowledge.api.knowledge import KnowledgeBaseCreateAPI, KnowledgeWebCreateAPI, KnowledgeTreeReadAPI, \ from knowledge.api.knowledge import KnowledgeBaseCreateAPI, KnowledgeWebCreateAPI, KnowledgeTreeReadAPI, \
KnowledgeEditAPI, KnowledgeReadAPI, KnowledgePageAPI KnowledgeEditAPI, KnowledgeReadAPI, KnowledgePageAPI, SyncWebAPI, GenerateRelatedAPI
from knowledge.serializers.knowledge import KnowledgeSerializer from knowledge.serializers.knowledge import KnowledgeSerializer
@ -110,6 +110,46 @@ class KnowledgeView(APIView):
} }
).page(current_page, page_size)) ).page(current_page, page_size))
class SyncWeb(APIView):
authentication_classes = [TokenAuth]
@extend_schema(
methods=['PUT'],
summary=_("Synchronize the knowledge base of the website"),
description=_("Synchronize the knowledge base of the website"),
operation_id=_("Synchronize the knowledge base of the website"),
parameters=SyncWebAPI.get_parameters(),
responses=SyncWebAPI.get_response(),
tags=[_('Knowledge Base')]
)
@has_permissions(PermissionConstants.KNOWLEDGE_EDIT.get_workspace_permission())
def put(self, request: Request, workspace_id: str, knowledge_id: str):
return result.success(KnowledgeSerializer.SyncWeb(
data={
'workspace_id': workspace_id,
'sync_type': request.query_params.get('sync_type'),
'id': knowledge_id,
'user_id': str(request.user.id)
}
).sync())
class GenerateRelated(APIView):
authentication_classes = [TokenAuth]
@extend_schema(
methods=['PUT'],
summary=_('Generate related'),
description=_('Generate related'),
operation_id=_('Generate related'),
parameters=GenerateRelatedAPI.get_parameters(),
request=GenerateRelatedAPI.get_request(),
tags=[_('Knowledge Base')]
)
def put(self, request: Request, workspace_id: str, knowledge_id: str):
return result.success(KnowledgeSerializer.Operate(
data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id, 'user_id': request.user.id}
).generate_related(request.data))
class KnowledgeBaseView(APIView): class KnowledgeBaseView(APIView):
authentication_classes = [TokenAuth] authentication_classes = [TokenAuth]