class SyncWebAPI(APIMixin):
    """OpenAPI description of the knowledge base "sync" endpoint."""

    @staticmethod
    def get_parameters():
        """Return the required string path parameters of the endpoint.

        Both ``workspace_id`` and ``knowledge_id`` are declared as required
        ``path`` parameters of type string.
        """
        # (parameter name, description shown in the generated schema)
        path_arguments = (
            ("workspace_id", "工作空间id"),
            ("knowledge_id", "知识库id"),
        )
        return [
            OpenApiParameter(
                name=argument_name,
                description=argument_description,
                type=OpenApiTypes.STR,
                location='path',
                required=True,
            )
            for argument_name, argument_description in path_arguments
        ]

    @staticmethod
    def get_response():
        """Return the serializer class that documents the response body."""
        return DefaultResultSerializer


class GenerateRelatedAPI(SyncWebAPI):
    """OpenAPI description of the "generate related" endpoint.

    Path parameters and response are inherited from ``SyncWebAPI``; only the
    request body differs.
    """

    @staticmethod
    def get_request():
        """Return the serializer class that documents the request body."""
        return GenerateRelatedSerializer
def generate_related(self, instance: Dict, with_valid=True):
    """Queue generation of related problems for a whole knowledge base.

    Steps, in order:

    1. Optionally validate this serializer and the request payload.
    2. Mark every document of the knowledge base as ``PENDING`` for the
       ``GENERATE_PROBLEM`` task type.
    3. Mark as ``PENDING`` every paragraph of the knowledge base whose current
       ``GENERATE_PROBLEM`` state is listed in ``state_list``.
    4. Re-aggregate the document status for the knowledge base.
    5. Dispatch the ``generate_related_by_knowledge_id`` Celery task.

    :param instance: request payload; ``model_id``, ``prompt`` and
        ``state_list`` are read from it
    :param with_valid: validate the serializer and the payload first
    :raises AppApiException: (code 500) when an identical task is already queued
    """
    if with_valid:
        self.is_valid(raise_exception=True)
        GenerateRelatedSerializer(data=instance).is_valid(raise_exception=True)
    # The serializer field is named ``knowledge_id``. The previous lookup of
    # ``id`` always returned None, so no row was updated and the task was
    # dispatched without a knowledge base id.
    knowledge_id = self.data.get('knowledge_id')
    model_id = instance.get("model_id")
    prompt = instance.get("prompt")
    state_list = instance.get('state_list')
    ListenerManagement.update_status(
        QuerySet(Document).filter(knowledge_id=knowledge_id),
        TaskType.GENERATE_PROBLEM,
        State.PENDING
    )
    # The state of one task type is a single character of ``status``, located
    # by reversing the string and taking the character at the task type's value.
    ListenerManagement.update_status(
        QuerySet(Paragraph).annotate(
            reversed_status=Reverse('status'),
            task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value, 1),
        ).filter(
            task_type_status__in=state_list, knowledge_id=knowledge_id
        ).values('id'),
        TaskType.GENERATE_PROBLEM,
        State.PENDING
    )
    ListenerManagement.get_aggregation_document_status_by_knowledge_id(knowledge_id)()
    try:
        generate_related_by_knowledge_id.delay(knowledge_id, model_id, prompt, state_list)
    except AlreadyQueued as e:
        # Keep the original cause attached for easier debugging.
        raise AppApiException(
            500, _('Failed to send the vectorization task, please try again later!')) from e
@staticmethod
def get_sync_handler(knowledge):
    """Build the callback invoked for every page fetched while crawling a web knowledge base.

    The returned handler ignores responses whose status is not 200. For a
    successful response it re-synchronises the document already registered
    under the page URL for this knowledge base, or creates a new web document
    from the parsed page when none exists. Any error is logged to the
    ``max_kb_error`` logger and never propagated.

    :param knowledge: knowledge base the crawled pages belong to
    :return: ``handler(child_link, response)`` callable
    """

    def handler(child_link: ChildLink, response: Fork.Response):
        if response.status != 200:
            return
        try:
            source_url = child_link.url.strip()
            # ``tag`` may be missing and its ``text`` may be None or blank;
            # fall back to the URL as the document name in all those cases.
            tag_text = child_link.tag.text if child_link.tag is not None else None
            document_name = tag_text if tag_text and tag_text.strip() else child_link.url
            paragraphs = get_split_model('web.md').parse(response.content)
            existing_document = QuerySet(Document).filter(
                meta__source_url=source_url,
                knowledge=knowledge
            ).first()
            if existing_document is not None:
                # The page is already known: use the document sync
                DocumentSerializers.Sync(data={'document_id': existing_document.id}).sync()
            else:
                # New page: insert a document. The type comes from the
                # ``KnowledgeType`` enum, the one this serializer already
                # compares against (was ``Knowledge.WEB``).
                DocumentSerializers.Create(data={'knowledge_id': knowledge.id}).save(
                    {'name': document_name, 'paragraphs': paragraphs,
                     'meta': {'source_url': source_url,
                              'selector': knowledge.meta.get('selector')},
                     'type': KnowledgeType.WEB}, with_valid=True)
        except Exception as e:
            # Best effort: one failing page must not abort the whole crawl.
            logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')

    return handler
def generate_problem_by_paragraph(paragraph, llm_model, prompt):
    """Ask the LLM for problems about one paragraph and store them.

    The paragraph's ``GENERATE_PROBLEM`` state is set to ``STARTED`` before the
    model is invoked, to ``SUCCESS`` once every returned line has been passed
    to ``save_problem``, and to ``FAILURE`` when anything raises. Failures are
    logged and not propagated, so the caller can continue with the next
    paragraph.

    :param paragraph: paragraph whose ``content`` and ``title`` fill the
        ``{data}`` and ``{title}`` placeholders of the prompt
    :param llm_model: chat model exposing ``invoke``
    :param prompt: prompt template
    """

    def set_state(state):
        # Build a fresh queryset for every status transition.
        ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id),
                                         TaskType.GENERATE_PROBLEM,
                                         state)

    try:
        set_state(State.STARTED)
        message = HumanMessage(
            content=prompt.replace('{data}', paragraph.content).replace('{title}', paragraph.title))
        res = llm_model.invoke([message])
        if (res.content is None) or (len(res.content) == 0):
            # NOTE(review): an empty answer leaves the paragraph in STARTED, as
            # before; confirm whether it should become SUCCESS or FAILURE.
            return
        # One problem per line of the model output.
        for problem in res.content.split('\n'):
            save_problem(paragraph.knowledge_id, paragraph.document_id, paragraph.id, problem)
        set_state(State.SUCCESS)
    except Exception as e:
        # Previously the error was dropped silently; record it before flagging
        # the paragraph as failed.
        max_kb_error.error(
            f'Generate problem for paragraph {paragraph.id} failed: {str(e)}{traceback.format_exc()}')
        set_state(State.FAILURE)
@celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']},
                 name='celery:generate_related_by_knowledge')
def generate_related_by_knowledge_id(knowledge_id, model_id, prompt, state_list=None):
    """Fan out problem generation to one Celery task per document of a knowledge base.

    Dispatching is best effort: a document whose task cannot be queued is
    logged and skipped so the remaining documents are still processed.

    :param knowledge_id: knowledge base whose documents are processed
    :param model_id: id of the LLM model, forwarded to the per-document task
    :param prompt: prompt template, forwarded to the per-document task
    :param state_list: paragraph states to include, forwarded to the per-document task
    """
    for document in QuerySet(Document).filter(knowledge_id=knowledge_id):
        try:
            generate_related_by_document_id.delay(document.id, model_id, prompt, state_list)
        except Exception as e:
            # Was ``pass``: keep going, but leave a trace of what failed.
            max_kb_error.error(
                f'Failed to dispatch problem generation for document {document.id}: '
                f'{str(e)}{traceback.format_exc()}')
@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']},
                 name='celery:generate_related_by_paragraph_list')
def generate_related_by_paragraph_id_list(document_id, paragraph_id_list, model_id, prompt):
    """Generate problems for an explicit list of paragraphs of one document.

    If the document is missing or its ``GENERATE_PROBLEM`` state is ``REVOKE``
    the document is marked ``REVOKED`` and nothing is generated. Otherwise the
    document is marked ``STARTED`` and the paragraphs are processed in pages
    of 10. The document status is always post-updated at the end, including
    when an exception propagates.

    :param document_id: document the paragraphs belong to
    :param paragraph_id_list: ids of the paragraphs to process
    :param model_id: id of the LLM model to use
    :param prompt: prompt template
    """
    try:
        is_the_task_interrupted = get_is_the_task_interrupted(document_id)
        if is_the_task_interrupted():
            ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
                                             TaskType.GENERATE_PROBLEM,
                                             State.REVOKED)
            return
        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
                                         TaskType.GENERATE_PROBLEM,
                                         State.STARTED)
        llm_model = get_llm_model(model_id)
        # Problem generation function. The interruption check is handed over as
        # well, as the document-level task does, so a revoke also stops work
        # between two paragraphs of the same page. The former inner copy of the
        # check duplicated ``get_is_the_task_interrupted`` and is gone.
        generate_problem = get_generate_problem(llm_model, prompt,
                                                ListenerManagement.get_aggregation_document_status(
                                                    document_id), is_the_task_interrupted)
        page(QuerySet(Paragraph).filter(id__in=paragraph_id_list), 10, generate_problem, is_the_task_interrupted)
    finally:
        ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM)
class SyncWeb(APIView):
    """Endpoint that triggers a synchronisation of a web knowledge base."""
    authentication_classes = [TokenAuth]

    @extend_schema(
        methods=['PUT'],
        summary=_("Synchronize the knowledge base of the website"),
        description=_("Synchronize the knowledge base of the website"),
        operation_id=_("Synchronize the knowledge base of the website"),
        parameters=SyncWebAPI.get_parameters(),
        responses=SyncWebAPI.get_response(),
        tags=[_('Knowledge Base')]
    )
    @has_permissions(PermissionConstants.KNOWLEDGE_EDIT.get_workspace_permission())
    def put(self, request: Request, workspace_id: str, knowledge_id: str):
        """Start a sync; the mode is read from the ``sync_type`` query parameter."""
        return result.success(KnowledgeSerializer.SyncWeb(
            data={
                'workspace_id': workspace_id,
                'sync_type': request.query_params.get('sync_type'),
                'id': knowledge_id,
                'user_id': str(request.user.id)
            }
        ).sync())


class GenerateRelated(APIView):
    """Endpoint that queues generation of related problems for a knowledge base."""
    authentication_classes = [TokenAuth]

    @extend_schema(
        methods=['PUT'],
        summary=_('Generate related'),
        description=_('Generate related'),
        operation_id=_('Generate related'),
        parameters=GenerateRelatedAPI.get_parameters(),
        request=GenerateRelatedAPI.get_request(),
        # ``GenerateRelatedAPI`` inherits ``get_response`` from ``SyncWebAPI``.
        responses=GenerateRelatedAPI.get_response(),
        tags=[_('Knowledge Base')]
    )
    # This write endpoint previously had no permission check at all; it now
    # requires the same permission as the sibling sync endpoint.
    # NOTE(review): confirm whether a more specific permission constant exists.
    @has_permissions(PermissionConstants.KNOWLEDGE_EDIT.get_workspace_permission())
    def put(self, request: Request, workspace_id: str, knowledge_id: str):
        """Validate the request body and queue the generation task."""
        return result.success(KnowledgeSerializer.Operate(
            data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id, 'user_id': request.user.id}
        ).generate_related(request.data))