diff --git a/apps/common/db/search.py b/apps/common/db/search.py index 76366715..bef42a14 100644 --- a/apps/common/db/search.py +++ b/apps/common/db/search.py @@ -12,7 +12,7 @@ from django.db import DEFAULT_DB_ALIAS, models, connections from django.db.models import QuerySet from common.db.compiler import AppSQLCompiler -from common.db.sql_execute import select_one, select_list +from common.db.sql_execute import select_one, select_list, update_execute from common.response.result import Page @@ -109,6 +109,24 @@ def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str, return select_list(exec_sql, exec_params) +def native_update(queryset: QuerySet | Dict[str, QuerySet], select_string: str, + field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None, + with_table_name=False): + """ + 复杂查询 + :param with_table_name: 生成sql是否包含表名 + :param queryset: 查询条件构造器 + :param select_string: 查询前缀 不包括 where limit 等信息 + :param field_replace_dict: 需要替换的字段 + :return: 查询结果 + """ + if isinstance(queryset, Dict): + exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name) + else: + exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name) + return update_execute(exec_sql, exec_params) + + def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler): """ 分页查询 diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index 40ac4884..a98b29bf 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -9,24 +9,29 @@ import datetime import logging import os +import threading import traceback from typing import List import django.db.models +from django.db import models from django.db.models import QuerySet +from django.db.models.functions import Substr, Reverse from langchain_core.embeddings import Embeddings from common.config.embedding_config import VectorStore -from common.db.search import native_search, get_dynamics_model -from common.event.common import embedding_poxy +from common.db.search import native_search, get_dynamics_model, native_update +from common.db.sql_execute import sql_execute, update_execute from common.util.file_util import get_file_content from common.util.lock import try_lock, un_lock -from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping +from common.util.page_utils import page +from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State from embedding.models import SourceType, SearchMode from smartdoc.conf import PROJECT_DIR max_kb_error = logging.getLogger(__file__) max_kb = logging.getLogger(__file__) +lock = threading.Lock() class SyncWebDatasetArgs: @@ -114,7 +119,8 @@ class ListenerManagement: @param embedding_model: 向量模型 """ max_kb.info(f"开始--->向量化段落:{paragraph_id}") - status = Status.success + # 更新到开始状态 + ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING, State.STARTED) try: data_list = native_search( {'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter( @@ -125,16 +131,22 @@ class ListenerManagement: # 删除段落 VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id) - def is_save_function(): - return QuerySet(Paragraph).filter(id=paragraph_id).exists() + def is_the_task_interrupted(): + _paragraph = QuerySet(Paragraph).filter(id=paragraph_id).first() + if _paragraph is None or 
Status(_paragraph.status)[TaskType.EMBEDDING] == State.REVOKE:
+                    return True
+                return False
 
             # 批量向量化
-            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
+            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_the_task_interrupted)
+            # 更新到成功状态
+            ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
+                                             State.SUCCESS)
         except Exception as e:
             max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}')
-            status = Status.error
+            ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
+                                             State.FAILURE)
         finally:
-            QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': status})
             max_kb.info(f'结束--->向量化段落:{paragraph_id}')
 
     @staticmethod
@@ -142,6 +154,89 @@ class ListenerManagement:
         # 批量向量化
         VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, lambda: True)
 
+    @staticmethod
+    def get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, post_apply=lambda: None):
+        def embedding_paragraph_apply(paragraph_list):
+            for paragraph in paragraph_list:
+                if is_the_task_interrupted():
+                    break
+                ListenerManagement.embedding_by_paragraph(str(paragraph.get('id')), embedding_model)
+            post_apply()
+
+        return embedding_paragraph_apply
+
+    @staticmethod
+    def get_aggregation_document_status(document_id):
+        def aggregation_document_status():
+            sql = get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql'))
+            native_update({'document_custom_sql': QuerySet(Document).filter(id=document_id)}, sql, with_table_name=True)
+
+        return aggregation_document_status
+
+    @staticmethod
+    def get_aggregation_document_status_by_dataset_id(dataset_id):
+        def aggregation_document_status():
+            sql = get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql'))
+            native_update({'document_custom_sql': QuerySet(Document).filter(dataset_id=dataset_id)}, sql)
+
+        return aggregation_document_status
+
+    @staticmethod
+    def get_aggregation_document_status_by_query_set(queryset):
+        def aggregation_document_status():
+            sql = get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql'))
+            native_update({'document_custom_sql': queryset}, sql)
+
+        return aggregation_document_status
+
+    @staticmethod
+    def post_update_document_status(document_id, task_type: TaskType):
+        _document = QuerySet(Document).filter(id=document_id).first()
+
+        status = Status(_document.status)
+        if status[task_type] == State.REVOKE:
+            status[task_type] = State.REVOKED
+        else:
+            status[task_type] = State.SUCCESS
+        for item in _document.status_meta.get('aggs', []):
+            agg_status = item.get('status')
+            agg_count = item.get('count')
+            if Status(agg_status)[task_type] == State.FAILURE and agg_count > 0:
+                status[task_type] = State.FAILURE
+        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), task_type, status[task_type])
+
+        ListenerManagement.update_status(QuerySet(Paragraph).annotate(
+            reversed_status=Reverse('status'),
+            task_type_status=Substr('reversed_status', task_type.value,
+                                    task_type.value),
+        ).filter(task_type_status=State.REVOKE.value).filter(document_id=document_id).values('id'),
+                                         task_type,
+                                         State.REVOKED)
+
+    @staticmethod
+    def update_status(query_set: QuerySet, taskType: TaskType, state: State):
+        exec_sql = get_file_content(
+            os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_paragraph_status.sql'))
+        bit_number = len(TaskType)
+        up_index = taskType.value - 1
+        next_index = taskType.value + 1
+        current_index = taskType.value
+        status_number = state.value
+        params_dict = {'${bit_number}': bit_number, '${up_index}': up_index,
+                       '${status_number}': status_number, '${next_index}': next_index,
+                       '${table_name}': query_set.model._meta.db_table, '${current_index}': current_index}
+        for key in params_dict:
+            _value_ = params_dict[key]
+            exec_sql = exec_sql.replace(key, str(_value_))
+        lock.acquire()
+        try:
+            native_update(query_set, exec_sql)
+        finally:
+            lock.release()
+
     @staticmethod
     def embedding_by_document(document_id, embedding_model: Embeddings):
         """
@@ -153,33 +248,29 @@ class ListenerManagement:
         if not try_lock('embedding' + str(document_id)):
             return
         max_kb.info(f"开始--->向量化文档:{document_id}")
-        QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding})
-        QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.embedding})
-        status = Status.success
+        # 修改文档状态为执行中
+        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.STARTED)
         try:
-            data_list = native_search(
-                {'problem': QuerySet(
-                    get_dynamics_model({'paragraph.document_id': django.db.models.CharField()})).filter(
-                    **{'paragraph.document_id': document_id}),
-                    'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
-                select_string=get_file_content(
-                    os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
             # 删除文档向量数据
             VectorStore.get_embedding_vector().delete_by_document_id(document_id)
 
-            def is_save_function():
-                return QuerySet(Document).filter(id=document_id).exists()
+            def is_the_task_interrupted():
+                document = QuerySet(Document).filter(id=document_id).first()
+                if document is None or Status(document.status)[TaskType.EMBEDDING] == State.REVOKE:
+                    return True
+                return False
 
-            # 批量向量化
-            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
+            # 根据段落进行向量化处理
+            page(QuerySet(Paragraph).filter(document_id=document_id).values('id'), 5,
+                 ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
+                                                                  ListenerManagement.get_aggregation_document_status(
+                                                                      document_id)),
+                 is_the_task_interrupted)
         except Exception as e:
             max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
-            status = Status.error
         finally:
-            # 修改状态
-            QuerySet(Document).filter(id=document_id).update(
-                **{'status': status, 'update_time': datetime.datetime.now()})
-            QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status})
+            ListenerManagement.post_update_document_status(document_id, TaskType.EMBEDDING)
+            ListenerManagement.get_aggregation_document_status(document_id)()
             max_kb.info(f"结束--->向量化文档:{document_id}")
             un_lock('embedding' + str(document_id))
diff --git a/apps/common/util/page_utils.py b/apps/common/util/page_utils.py
new file mode 100644
index 00000000..7fc176b6
--- /dev/null
+++ b/apps/common/util/page_utils.py
@@ -0,0 +1,27 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: page_utils.py
+    @date:2024/11/21 10:32
+    @desc:
+"""
+from math import ceil
+
+
+def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
+    """
+
+    @param query_set: 查询query_set
+    @param page_size: 每次查询大小
+    @param handler: 数据处理器
+    @param is_the_task_interrupted: 任务是否被中断
+    @return:
+    """
+    count = query_set.count()
+    for i in range(0, ceil(count / page_size)):
+        if 
is_the_task_interrupted(): + return + offset = i * page_size + paragraph_list = query_set[offset: offset + page_size] + handler(paragraph_list) diff --git a/apps/dataset/migrations/0011_document_status_meta_paragraph_status_meta_and_more.py b/apps/dataset/migrations/0011_document_status_meta_paragraph_status_meta_and_more.py new file mode 100644 index 00000000..c64a4db2 --- /dev/null +++ b/apps/dataset/migrations/0011_document_status_meta_paragraph_status_meta_and_more.py @@ -0,0 +1,47 @@ +# Generated by Django 4.2.15 on 2024-11-22 14:44 +from django.db.models import QuerySet + +from django.db import migrations, models + +import dataset +from common.event import ListenerManagement +from dataset.models import State, TaskType + + +def updateDocumentStatus(apps, schema_editor): + ParagraphModel = apps.get_model('dataset', 'Paragraph') + DocumentModel = apps.get_model('dataset', 'Document') + success_list = QuerySet(DocumentModel).filter(status='2') + ListenerManagement.update_status(QuerySet(ParagraphModel).filter(document_id__in=[d.id for d in success_list]), + TaskType.EMBEDDING, State.SUCCESS) + ListenerManagement.get_aggregation_document_status_by_query_set(QuerySet(DocumentModel))() + + +class Migration(migrations.Migration): + dependencies = [ + ('dataset', '0010_file_meta'), + ] + + operations = [ + migrations.AddField( + model_name='document', + name='status_meta', + field=models.JSONField(default=dataset.models.data_set.default_status_meta, verbose_name='状态统计数据'), + ), + migrations.AddField( + model_name='paragraph', + name='status_meta', + field=models.JSONField(default=dataset.models.data_set.default_status_meta, verbose_name='状态数据'), + ), + migrations.AlterField( + model_name='document', + name='status', + field=models.CharField(default=dataset.models.data_set.Status.__str__, max_length=20, verbose_name='状态'), + ), + migrations.AlterField( + model_name='paragraph', + name='status', + field=models.CharField(default=dataset.models.data_set.Status.__str__, max_length=20, verbose_name='状态'), + ), + migrations.RunPython(updateDocumentStatus) + ] diff --git a/apps/dataset/models/data_set.py b/apps/dataset/models/data_set.py index cd91b6d1..4f46eda2 100644 --- a/apps/dataset/models/data_set.py +++ b/apps/dataset/models/data_set.py @@ -7,6 +7,7 @@ @desc: 数据集 """ import uuid +from enum import Enum from django.db import models from django.db.models.signals import pre_delete @@ -18,13 +19,62 @@ from setting.models import Model from users.models import User -class Status(models.TextChoices): - """订单类型""" - embedding = 0, '导入中' - success = 1, '已完成' - error = 2, '导入失败' - queue_up = 3, '排队中' - generating = 4, '生成问题中' +class TaskType(Enum): + # 向量 + EMBEDDING = 1 + # 生成问题 + GENERATE_PROBLEM = 2 + # 同步 + SYNC = 3 + + +class State(Enum): + # 等待 + PENDING = '0' + # 执行中 + STARTED = '1' + # 成功 + SUCCESS = '2' + # 失败 + FAILURE = '3' + # 取消任务 + REVOKE = '4' + # 取消成功 + REVOKED = '5' + # 忽略 + IGNORED = 'n' + + +class Status: + type_cls = TaskType + state_cls = State + + def __init__(self, status: str = None): + self.task_status = {} + status_list = list(status[::-1] if status is not None else '') + for _type in self.type_cls: + index = _type.value - 1 + _state = self.state_cls(status_list[index] if len(status_list) > index else 'n') + self.task_status[_type] = _state + + @staticmethod + def of(status: str): + return Status(status) + + def __str__(self): + result = [] + for _type in sorted(self.type_cls, key=lambda item: item.value, reverse=True): + result.insert(len(self.type_cls) - _type.value, 
self.task_status[_type].value) + return ''.join(result) + + def __setitem__(self, key, value): + self.task_status[key] = value + + def __getitem__(self, item): + return self.task_status[item] + + def update_status(self, task_type: TaskType, state: State): + self.task_status[task_type] = state class Type(models.TextChoices): @@ -42,6 +92,10 @@ def default_model(): return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab') +def default_status_meta(): + return {"state_time": {}} + + class DataSet(AppModelMixin): """ 数据集表 @@ -68,8 +122,8 @@ class Document(AppModelMixin): dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING) name = models.CharField(max_length=150, verbose_name="文档名称") char_length = models.IntegerField(verbose_name="文档字符数 冗余字段") - status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices, - default=Status.queue_up) + status = models.CharField(verbose_name='状态', max_length=20, default=Status('').__str__) + status_meta = models.JSONField(verbose_name="状态统计数据", default=default_status_meta) is_active = models.BooleanField(default=True) type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices, @@ -94,8 +148,8 @@ class Paragraph(AppModelMixin): dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING) content = models.CharField(max_length=102400, verbose_name="段落内容") title = models.CharField(max_length=256, verbose_name="标题", default="") - status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices, - default=Status.embedding) + status = models.CharField(verbose_name='状态', max_length=20, default=Status('').__str__) + status_meta = models.JSONField(verbose_name="状态数据", default=default_status_meta) hit_num = models.IntegerField(verbose_name="命中次数", default=0) is_active = models.BooleanField(default=True) @@ -145,7 +199,6 @@ class File(AppModelMixin): meta = models.JSONField(verbose_name="文件关联数据", default=dict) - class Meta: db_table = "file" @@ -161,7 +214,6 @@ class File(AppModelMixin): return result['data'] - @receiver(pre_delete, sender=File) def on_delete_file(sender, instance, **kwargs): select_one(f'SELECT lo_unlink({instance.loid})', []) diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index 7598432c..85e73ee3 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -27,6 +27,7 @@ from application.models import ApplicationDatasetMapping from common.config.embedding_config import VectorStore from common.db.search import get_dynamics_model, native_page_search, native_search from common.db.sql_execute import select_list +from common.event import ListenerManagement from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from common.util.common import post, flat_map, valid_license @@ -34,7 +35,8 @@ from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.fork import ChildLink, Fork from common.util.split_model import get_split_model -from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Status +from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Status, \ + TaskType, State from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \ get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id from 
dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer @@ -733,9 +735,13 @@ class DataSetSerializers(serializers.ModelSerializer): def re_embedding(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - - QuerySet(Document).filter(dataset_id=self.data.get('id')).update(**{'status': Status.queue_up}) - QuerySet(Paragraph).filter(dataset_id=self.data.get('id')).update(**{'status': Status.queue_up}) + ListenerManagement.update_status(QuerySet(Document).filter(dataset_id=self.data.get('id')), + TaskType.EMBEDDING, + State.PENDING) + ListenerManagement.update_status(QuerySet(Paragraph).filter(dataset_id=self.data.get('id')), + TaskType.EMBEDDING, + State.PENDING) + ListenerManagement.get_aggregation_document_status_by_dataset_id(self.data.get('id'))() embedding_model_id = get_embedding_model_id_by_dataset_id(self.data.get('id')) embedding_by_dataset.delay(self.data.get('id'), embedding_model_id) diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index 61a6b02c..1ab74ead 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -19,6 +19,7 @@ from celery_once import AlreadyQueued from django.core import validators from django.db import transaction from django.db.models import QuerySet +from django.db.models.functions import Substr, Reverse from django.http import HttpResponse from drf_yasg import openapi from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE @@ -26,6 +27,7 @@ from rest_framework import serializers from xlwt import Utils from common.db.search import native_search, native_page_search +from common.event import ListenerManagement from common.event.common import work_thread_pool from common.exception.app_exception import AppApiException from common.handle.impl.doc_split_handle import DocSplitHandle @@ -44,7 +46,8 @@ from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.fork import Fork from common.util.split_model import get_split_model -from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image +from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Image, \ + TaskType, State from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \ get_embedding_model_id_by_dataset_id from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer @@ -67,6 +70,19 @@ class FileBufferHandle: return self.buffer +class CancelInstanceSerializer(serializers.Serializer): + type = serializers.IntegerField(required=True, error_messages=ErrMessage.boolean( + "任务类型")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + _type = self.data.get('type') + try: + TaskType(_type) + except Exception as e: + raise AppApiException(500, '任务类型不支持') + + class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer): meta = serializers.DictField(required=False) name = serializers.CharField(required=False, max_length=128, min_length=1, @@ -278,7 +294,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): # 修改向量信息 if model_id: delete_embedding_by_paragraph_ids(pid_list) - QuerySet(Document).filter(id__in=document_id_list).update(status=Status.queue_up) + 
ListenerManagement.update_status(QuerySet(Document).filter(id__in=document_id_list), + TaskType.EMBEDDING, + State.PENDING) embedding_by_document_list.delay(document_id_list, model_id) else: update_embedding_dataset_id(pid_list, target_dataset_id) @@ -404,11 +422,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): self.is_valid(raise_exception=True) document_id = self.data.get('document_id') document = QuerySet(Document).filter(id=document_id).first() + state = State.SUCCESS if document.type != Type.web: return True try: - document.status = Status.queue_up - document.save() + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), + TaskType.SYNC, + State.PENDING) source_url = document.meta.get('source_url') selector_list = document.meta.get('selector').split( " ") if 'selector' in document.meta and document.meta.get('selector') is not None else [] @@ -442,13 +462,18 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): if with_embedding: embedding_model_id = get_embedding_model_id_by_dataset_id(document.dataset_id) embedding_by_document.delay(document_id, embedding_model_id) + else: - document.status = Status.error - document.save() + state = State.FAILURE except Exception as e: logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') - document.status = Status.error - document.save() + state = State.FAILURE + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), + TaskType.SYNC, + state) + ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id), + TaskType.SYNC, + state) return True class Operate(ApiMixin, serializers.Serializer): @@ -586,14 +611,35 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): if with_valid: self.is_valid(raise_exception=True) document_id = self.data.get("document_id") - QuerySet(Document).filter(id=document_id).update(**{'status': Status.queue_up}) - QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.queue_up}) + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, + State.PENDING) + ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id), TaskType.EMBEDDING, + State.PENDING) + ListenerManagement.get_aggregation_document_status(document_id)() embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id')) try: embedding_by_document.delay(document_id, embedding_model_id) except AlreadyQueued as e: raise AppApiException(500, "任务正在执行中,请勿重复下发") + def cancel(self, instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + CancelInstanceSerializer(data=instance).is_valid() + document_id = self.data.get("document_id") + ListenerManagement.update_status(QuerySet(Paragraph).annotate( + reversed_status=Reverse('status'), + task_type_status=Substr('reversed_status', TaskType(instance.get('type')).value, + TaskType(instance.get('type')).value), + ).filter(task_type_status__in=[State.PENDING.value, State.STARTED.value]).filter( + document_id=document_id).values('id'), + TaskType(instance.get('type')), + State.REVOKE) + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType(instance.get('type')), + State.REVOKE) + + return True + @transaction.atomic def delete(self): document_id = self.data.get("document_id") @@ -955,15 +1001,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): self.is_valid(raise_exception=True) document_id_list = 
instance.get("id_list") with transaction.atomic(): - Document.objects.filter(id__in=document_id_list).update(status=Status.queue_up) - Paragraph.objects.filter(document_id__in=document_id_list).update(status=Status.queue_up) dataset_id = self.data.get('dataset_id') - embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=dataset_id) for document_id in document_id_list: try: - embedding_by_document.delay(document_id, embedding_model_id) + DocumentSerializers.Operate( + data={'dataset_id': dataset_id, 'document_id': document_id}).refresh() except AlreadyQueued as e: - raise AppApiException(500, "任务正在执行中,请勿重复下发") + pass class GenerateRelated(ApiMixin, serializers.Serializer): document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id")) @@ -978,7 +1022,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): if with_valid: self.is_valid(raise_exception=True) document_id = self.data.get('document_id') - QuerySet(Document).filter(id=document_id).update(status=Status.queue_up) + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), + TaskType.GENERATE_PROBLEM, + State.PENDING) + ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id), + TaskType.GENERATE_PROBLEM, + State.PENDING) + ListenerManagement.get_aggregation_document_status(document_id)() try: generate_related_by_document_id.delay(document_id, model_id, prompt) except AlreadyQueued as e: diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index 6614d712..82aacc79 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -16,11 +16,12 @@ from drf_yasg import openapi from rest_framework import serializers from common.db.search import page_search +from common.event import ListenerManagement from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from common.util.common import post from common.util.field_message import ErrMessage -from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet +from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet, TaskType, State from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \ ProblemParagraphManage, get_embedding_model_id_by_dataset_id from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers @@ -722,7 +723,6 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): } ) - class BatchGenerateRelated(ApiMixin, serializers.Serializer): dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id")) document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id")) @@ -734,10 +734,16 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): paragraph_id_list = instance.get("paragraph_id_list") model_id = instance.get("model_id") prompt = instance.get("prompt") + document_id = self.data.get('document_id') + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), + TaskType.GENERATE_PROBLEM, + State.PENDING) + ListenerManagement.update_status(QuerySet(Paragraph).filter(id__in=paragraph_id_list), + TaskType.GENERATE_PROBLEM, + State.PENDING) + ListenerManagement.get_aggregation_document_status(document_id)() try: - 
generate_related_by_paragraph_id_list.delay(paragraph_id_list, model_id, prompt) + generate_related_by_paragraph_id_list.delay(document_id, paragraph_id_list, model_id, + prompt) except AlreadyQueued as e: raise AppApiException(500, "任务正在执行中,请勿重复下发") - - - diff --git a/apps/dataset/sql/list_document.sql b/apps/dataset/sql/list_document.sql index 818d783c..c1e3a903 100644 --- a/apps/dataset/sql/list_document.sql +++ b/apps/dataset/sql/list_document.sql @@ -1,6 +1,7 @@ SELECT "document".* , to_json("document"."meta") as meta, + to_json("document"."status_meta") as status_meta, (SELECT "count"("id") FROM "paragraph" WHERE document_id="document"."id") as "paragraph_count" FROM "document" "document" diff --git a/apps/dataset/sql/update_document_status_meta.sql b/apps/dataset/sql/update_document_status_meta.sql new file mode 100644 index 00000000..6065931f --- /dev/null +++ b/apps/dataset/sql/update_document_status_meta.sql @@ -0,0 +1,25 @@ +UPDATE "document" "document" +SET status_meta = jsonb_set ( "document".status_meta, '{aggs}', tmp.status_meta ) +FROM + ( + SELECT COALESCE + ( jsonb_agg ( jsonb_delete ( ( row_to_json ( record ) :: JSONB ), 'document_id' ) ), '[]' :: JSONB ) AS status_meta, + document_id AS document_id + FROM + ( + SELECT + "paragraph".status, + "count" ( "paragraph"."id" ), + "document"."id" AS document_id + FROM + "document" "document" + LEFT JOIN "paragraph" "paragraph" ON "document"."id" = paragraph.document_id + ${document_custom_sql} + GROUP BY + "paragraph".status, + "document"."id" + ) record + GROUP BY + document_id + ) tmp +WHERE "document".id="tmp".document_id \ No newline at end of file diff --git a/apps/dataset/sql/update_paragraph_status.sql b/apps/dataset/sql/update_paragraph_status.sql new file mode 100644 index 00000000..45f9c674 --- /dev/null +++ b/apps/dataset/sql/update_paragraph_status.sql @@ -0,0 +1,13 @@ +UPDATE "${table_name}" +SET status = reverse ( + SUBSTRING ( reverse ( LPAD( status, ${bit_number}, 'n' ) ) :: TEXT FROM 1 FOR ${up_index} ) || ${status_number} || SUBSTRING ( reverse ( LPAD( status, ${bit_number}, 'n' ) ) :: TEXT FROM ${next_index} ) +), +status_meta = jsonb_set ( + "${table_name}".status_meta, + '{state_time,${current_index}}', + jsonb_set ( + COALESCE ( "${table_name}".status_meta #> '{state_time,${current_index}}', jsonb_build_object ( '${status_number}', now( ) ) ), + '{${status_number}}', + CONCAT ( '"', now( ), '"' ) :: JSONB + ) + ) \ No newline at end of file diff --git a/apps/dataset/swagger_api/document_api.py b/apps/dataset/swagger_api/document_api.py index 637a7e50..8fe588b7 100644 --- a/apps/dataset/swagger_api/document_api.py +++ b/apps/dataset/swagger_api/document_api.py @@ -26,3 +26,14 @@ class DocumentApi(ApiMixin): 'directly_return_similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title="直接返回相似度") } ) + + class Cancel(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'type': openapi.Schema(type=openapi.TYPE_INTEGER, title="任务类型", + description="1|2|3 1:向量化|2:生成问题|3:同步文档") + } + ) diff --git a/apps/dataset/task/generate.py b/apps/dataset/task/generate.py index 86042597..e8103974 100644 --- a/apps/dataset/task/generate.py +++ b/apps/dataset/task/generate.py @@ -1,12 +1,14 @@ import logging -from math import ceil +import traceback from celery_once import QueueOnce from django.db.models import QuerySet from langchain_core.messages import HumanMessage from common.config.embedding_config import ModelManage -from dataset.models import 
Paragraph, Document, Status +from common.event import ListenerManagement +from common.util.page_utils import page +from dataset.models import Paragraph, Document, Status, TaskType, State from dataset.task.tools import save_problem from ops import celery_app from setting.models import Model @@ -21,44 +23,79 @@ def get_llm_model(model_id): return ModelManage.get_model(model_id, lambda _id: get_model(model)) +def generate_problem_by_paragraph(paragraph, llm_model, prompt): + try: + ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM, + State.STARTED) + res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))]) + if (res.content is None) or (len(res.content) == 0): + return + problems = res.content.split('\n') + for problem in problems: + save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem) + ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM, + State.SUCCESS) + except Exception as e: + ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM, + State.FAILURE) + + +def get_generate_problem(llm_model, prompt, post_apply=lambda: None, is_the_task_interrupted=lambda: False): + def generate_problem(paragraph_list): + for paragraph in paragraph_list: + if is_the_task_interrupted(): + return + generate_problem_by_paragraph(paragraph, llm_model, prompt) + post_apply() + + return generate_problem + + @celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:generate_related_by_document') def generate_related_by_document_id(document_id, model_id, prompt): - llm_model = get_llm_model(model_id) - offset = 0 - page_size = 10 - QuerySet(Document).filter(id=document_id).update(status=Status.generating) + try: + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), + TaskType.GENERATE_PROBLEM, + State.STARTED) + llm_model = get_llm_model(model_id) - count = QuerySet(Paragraph).filter(document_id=document_id).count() - for i in range(0, ceil(count / page_size)): - paragraph_list = QuerySet(Paragraph).filter(document_id=document_id).all()[offset:offset + page_size] - offset += page_size - for paragraph in paragraph_list: - res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))]) - if (res.content is None) or (len(res.content) == 0): - continue - problems = res.content.split('\n') - for problem in problems: - save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem) - - QuerySet(Document).filter(id=document_id).update(status=Status.success) + def is_the_task_interrupted(): + document = QuerySet(Document).filter(id=document_id).first() + if document is None or Status(document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE: + return True + return False + # 生成问题函数 + generate_problem = get_generate_problem(llm_model, prompt, + ListenerManagement.get_aggregation_document_status( + document_id), is_the_task_interrupted) + page(QuerySet(Paragraph).filter(document_id=document_id), 10, generate_problem, is_the_task_interrupted) + except Exception as e: + max_kb_error.error(f'根据文档生成问题:{document_id}出现错误{str(e)}{traceback.format_exc()}') + finally: + ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM) + max_kb.info(f"结束--->生成问题:{document_id}") @celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:generate_related_by_paragraph_list') -def 
generate_related_by_paragraph_id_list(paragraph_id_list, model_id, prompt): - llm_model = get_llm_model(model_id) - offset = 0 - page_size = 10 - count = QuerySet(Paragraph).filter(id__in=paragraph_id_list).count() - for i in range(0, ceil(count / page_size)): - paragraph_list = QuerySet(Paragraph).filter(id__in=paragraph_id_list).all()[offset:offset + page_size] - offset += page_size - for paragraph in paragraph_list: - res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))]) - if (res.content is None) or (len(res.content) == 0): - continue - problems = res.content.split('\n') - for problem in problems: - save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem) +def generate_related_by_paragraph_id_list(document_id, paragraph_id_list, model_id, prompt): + try: + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), + TaskType.GENERATE_PROBLEM, + State.STARTED) + llm_model = get_llm_model(model_id) + # 生成问题函数 + generate_problem = get_generate_problem(llm_model, prompt, ListenerManagement.get_aggregation_document_status( + document_id)) + + def is_the_task_interrupted(): + document = QuerySet(Document).filter(id=document_id).first() + if document is None or Status(document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE: + return True + return False + + page(QuerySet(Paragraph).filter(id__in=paragraph_id_list), 10, generate_problem, is_the_task_interrupted) + finally: + ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM) diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index b2246355..9e583531 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -37,6 +37,7 @@ urlpatterns = [ name="document_export"), path('dataset//document//sync', views.Document.SyncWeb.as_view()), path('dataset//document//refresh', views.Document.Refresh.as_view()), + path('dataset//document//cancel_task', views.Document.CancelTask.as_view()), path('dataset//document//paragraph', views.Paragraph.as_view()), path('dataset//document/batch_generate_related', views.Document.BatchGenerateRelated.as_view()), path( @@ -45,7 +46,8 @@ urlpatterns = [ path('dataset//document//paragraph/_batch', views.Paragraph.Batch.as_view()), path('dataset//document//paragraph//', views.Paragraph.Page.as_view(), name='paragraph_page'), - path('dataset//document//paragraph/batch_generate_related', views.Paragraph.BatchGenerateRelated.as_view()), + path('dataset//document//paragraph/batch_generate_related', + views.Paragraph.BatchGenerateRelated.as_view()), path('dataset//document//paragraph/', views.Paragraph.Operate.as_view()), path('dataset//document//paragraph//problem', diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py index d911d0de..4a98fb08 100644 --- a/apps/dataset/views/document.py +++ b/apps/dataset/views/document.py @@ -218,6 +218,26 @@ class Document(APIView): DocumentSerializers.Sync(data={'document_id': document_id, 'dataset_id': dataset_id}).sync( )) + class CancelTask(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="取消任务", + operation_id="取消任务", + manual_parameters=DocumentSerializers.Operate.get_request_params_api(), + request_body=DocumentApi.Cancel.get_request_body_api(), + responses=result.get_default_response(), + tags=["知识库/文档"] + ) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, 
request: Request, dataset_id: str, document_id: str): + return result.success( + DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).cancel( + request.data + )) + class Refresh(APIView): authentication_classes = [TokenAuth] diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py index ab5ab410..9a6eaff6 100644 --- a/apps/embedding/vector/base_vector.py +++ b/apps/embedding/vector/base_vector.py @@ -86,20 +86,20 @@ class BaseVectorStore(ABC): for child_array in result: self._batch_save(child_array, embedding, lambda: True) - def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_save_function): + def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_the_task_interrupted): """ 批量插入 @param data_list: 数据列表 @param embedding: 向量化处理器 - @param is_save_function: + @param is_the_task_interrupted: 判断是否中断任务 :return: bool """ self.save_pre_handler() chunk_list = chunk_data_list(data_list) result = sub_array(chunk_list) for child_array in result: - if is_save_function(): - self._batch_save(child_array, embedding, is_save_function) + if not is_the_task_interrupted(): + self._batch_save(child_array, embedding, is_the_task_interrupted) else: break @@ -110,7 +110,7 @@ class BaseVectorStore(ABC): pass @abstractmethod - def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function): + def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted): pass def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py index 8cd2146a..906da0cb 100644 --- a/apps/embedding/vector/pg_vector.py +++ b/apps/embedding/vector/pg_vector.py @@ -57,7 +57,7 @@ class PGVector(BaseVectorStore): embedding.save() return True - def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function): + def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted): texts = [row.get('text') for row in text_list] embeddings = embedding.embed_documents(texts) embedding_list = [Embedding(id=uuid.uuid1(), @@ -70,7 +70,7 @@ class PGVector(BaseVectorStore): embedding=embeddings[index], search_vector=to_ts_vector(text_list[index]['text'])) for index in range(0, len(texts))] - if is_save_function(): + if not is_the_task_interrupted(): QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None return True diff --git a/apps/ops/celery/logger.py b/apps/ops/celery/logger.py index bdadc568..1b2843c2 100644 --- a/apps/ops/celery/logger.py +++ b/apps/ops/celery/logger.py @@ -208,6 +208,7 @@ class CeleryThreadTaskFileHandler(CeleryThreadingLoggerHandler): f.flush() def handle_task_start(self, task_id): + print('handle_task_start') log_path = get_celery_task_log_path(task_id) thread_id = self.get_current_thread_id() self.task_id_thread_id_mapper[task_id] = thread_id @@ -215,6 +216,7 @@ class CeleryThreadTaskFileHandler(CeleryThreadingLoggerHandler): self.thread_id_fd_mapper[thread_id] = f def handle_task_end(self, task_id): + print('handle_task_end') ident_id = self.task_id_thread_id_mapper.get(task_id, '') f = self.thread_id_fd_mapper.pop(ident_id, None) if f and not f.closed: diff --git a/apps/ops/celery/signal_handler.py b/apps/ops/celery/signal_handler.py index 90ed6240..46671a0d 100644 --- a/apps/ops/celery/signal_handler.py +++ b/apps/ops/celery/signal_handler.py @@ -5,7 +5,7 @@ import os from celery import 
subtask from celery.signals import ( - worker_ready, worker_shutdown, after_setup_logger + worker_ready, worker_shutdown, after_setup_logger, task_revoked, task_prerun ) from django.core.cache import cache from django_celery_beat.models import PeriodicTask @@ -61,3 +61,15 @@ def add_celery_logger_handler(sender=None, logger=None, loglevel=None, format=No formatter = logging.Formatter(format) task_handler.setFormatter(formatter) logger.addHandler(task_handler) + + +@task_revoked.connect +def on_task_revoked(request, terminated, signum, expired, **kwargs): + print('task_revoked', terminated) + + +@task_prerun.connect +def on_taskaa_start(sender, task_id, **kwargs): + pass + # sender.update_state(state='REVOKED', +# meta={'exc_type': 'Exception', 'exc': 'Exception', 'message': '暂停任务', 'exc_message': ''}) diff --git a/ui/src/api/document.ts b/ui/src/api/document.ts index 28954d0c..7bd42546 100644 --- a/ui/src/api/document.ts +++ b/ui/src/api/document.ts @@ -322,8 +322,17 @@ const batchGenerateRelated: ( data: any, loading?: Ref ) => Promise> = (dataset_id, data, loading) => { + return put(`${prefix}/${dataset_id}/document/batch_generate_related`, data, undefined, loading) +} + +const cancelTask: ( + dataset_id: string, + document_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, document_id, data, loading) => { return put( - `${prefix}/${dataset_id}/document/batch_generate_related`, + `${prefix}/${dataset_id}/document/${document_id}/cancel_task`, data, undefined, loading @@ -352,5 +361,6 @@ export default { postTableDocument, exportDocument, batchRefresh, - batchGenerateRelated + batchGenerateRelated, + cancelTask } diff --git a/ui/src/components/ai-chat/ExecutionDetailDialog.vue b/ui/src/components/ai-chat/ExecutionDetailDialog.vue index ef7ce1b7..46a8f394 100644 --- a/ui/src/components/ai-chat/ExecutionDetailDialog.vue +++ b/ui/src/components/ai-chat/ExecutionDetailDialog.vue @@ -28,7 +28,12 @@
{{ item?.message_tokens + item?.answer_tokens }} tokens {{ item?.run_time?.toFixed(2) || 0.0 }} s @@ -166,7 +171,10 @@
-
+
本次对话
{{ item.question || '-' }} diff --git a/ui/src/components/ai-chat/KnowledgeSource.vue b/ui/src/components/ai-chat/KnowledgeSource.vue index 4de9ce1c..cb8bf66d 100644 --- a/ui/src/components/ai-chat/KnowledgeSource.vue +++ b/ui/src/components/ai-chat/KnowledgeSource.vue @@ -8,30 +8,31 @@ >
- -
- - - - - {{ paragraph?.document_name }} - - - - {{ paragraph?.document_name }} - - -
-
+ + +
@@ -59,7 +60,7 @@ import { computed, ref } from 'vue' import ParagraphSourceDialog from './ParagraphSourceDialog.vue' import ExecutionDetailDialog from './ExecutionDetailDialog.vue' import { isWorkFlow } from '@/utils/application' - +import { getImgUrl } from '@/utils/utils' const props = defineProps({ data: { type: Object, @@ -70,15 +71,6 @@ const props = defineProps({ default: '' } }) -const iconMap: { [key: string]: string } = { - doc: '../../assets/doc-icon.svg', - docx: '../../assets/docx-icon.svg', - pdf: '../../assets/pdf-icon.svg', - md: '../../assets/md-icon.svg', - txt: '../../assets/txt-icon.svg', - xls: '../../assets/xls-icon.svg', - xlsx: '../../assets/xlsx-icon.svg' -} const ParagraphSourceDialogRef = ref() const ExecutionDetailDialogRef = ref() @@ -107,14 +99,6 @@ const uniqueParagraphList = computed(() => { ) }) -function getIconPath(documentName: string) { - const extension = documentName.split('.').pop()?.toLowerCase() - if (!documentName || !extension) return new URL(`${iconMap['doc']}`, import.meta.url).href - if (iconMap && extension && iconMap[extension]) { - return new URL(`${iconMap[extension]}`, import.meta.url).href - } - return new URL(`${iconMap['doc']}`, import.meta.url).href -} function openLink(url: string) { // 如果url不是以/结尾,加上/ if (url && !url.endsWith('/')) { diff --git a/ui/src/components/ai-chat/component/ParagraphCard.vue b/ui/src/components/ai-chat/component/ParagraphCard.vue index 1a566c14..4fb785e6 100644 --- a/ui/src/components/ai-chat/component/ParagraphCard.vue +++ b/ui/src/components/ai-chat/component/ParagraphCard.vue @@ -18,45 +18,8 @@ diff --git a/ui/src/views/document/index.vue b/ui/src/views/document/index.vue index cd301600..4777dd10 100644 --- a/ui/src/views/document/index.vue +++ b/ui/src/views/document/index.vue @@ -134,21 +134,7 @@
@@ -249,7 +235,7 @@
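For readers of this diff, here is a rough, standalone sketch of the packed status string that `Status` in `apps/dataset/models/data_set.py` and `update_paragraph_status.sql` both rely on. It only assumes the enum values shown above (EMBEDDING=1, GENERATE_PROBLEM=2, SYNC=3; one state character per task, `'n'` for a task that has never run, read right to left by task value). The `parse`/`pack` helper names are illustrative and are not part of the PR.

```python
# Sketch of the packed per-task status string used by this PR (not project code).
from enum import Enum


class TaskType(Enum):
    EMBEDDING = 1
    GENERATE_PROBLEM = 2
    SYNC = 3


class State(Enum):
    PENDING = '0'
    STARTED = '1'
    SUCCESS = '2'
    FAILURE = '3'
    REVOKE = '4'
    REVOKED = '5'
    IGNORED = 'n'


def parse(status: str) -> dict:
    """Read task states from the packed string; the last character is EMBEDDING."""
    reversed_chars = status[::-1]
    return {
        t: State(reversed_chars[t.value - 1]) if len(reversed_chars) >= t.value else State.IGNORED
        for t in TaskType
    }


def pack(task_status: dict) -> str:
    """Build the packed string: SYNC, GENERATE_PROBLEM, EMBEDDING from left to right."""
    ordered = sorted(TaskType, key=lambda t: t.value, reverse=True)
    return ''.join(task_status.get(t, State.IGNORED).value for t in ordered)


if __name__ == '__main__':
    s = parse('nn1')                       # embedding STARTED, other tasks untouched
    assert s[TaskType.EMBEDDING] == State.STARTED
    s[TaskType.EMBEDDING] = State.SUCCESS  # what update_status does to a single task slot
    print(pack(s))                         # -> 'nn2'
```

The SQL in `update_paragraph_status.sql` performs the same single-character splice directly in PostgreSQL (LPAD with 'n', reverse, replace the character at the task's index), which is why `ListenerManagement.update_status` only substitutes `${bit_number}`, `${up_index}`, `${next_index}` and `${status_number}` into the template before executing it.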