liqiang-fit2cloud 2024-11-26 15:15:13 +08:00
commit c0a04eee6e
28 changed files with 879 additions and 236 deletions

View File

@@ -12,7 +12,7 @@ from django.db import DEFAULT_DB_ALIAS, models, connections
 from django.db.models import QuerySet
 from common.db.compiler import AppSQLCompiler
-from common.db.sql_execute import select_one, select_list
+from common.db.sql_execute import select_one, select_list, update_execute
 from common.response.result import Page
@@ -109,6 +109,24 @@ def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
     return select_list(exec_sql, exec_params)
+def native_update(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
+                  field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None,
+                  with_table_name=False):
+    """
+    Execute a complex native update
+    :param with_table_name: whether the generated SQL should include the table name
+    :param queryset: query condition builder
+    :param select_string: statement prefix, excluding WHERE/LIMIT clauses
+    :param field_replace_dict: fields to replace
+    :return: execution result
+    """
+    if isinstance(queryset, Dict):
+        exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict,
+                                                           with_table_name)
+    else:
+        exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name)
+    return update_execute(exec_sql, exec_params)
 def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler):
     """
     Paged query
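
A minimal usage sketch for the new native_update helper (a hypothetical call; the model and SQL string are illustrative, not part of this commit). The dict key names the sub-query referenced by the SQL template, and the generated statement runs through update_execute instead of select_list:

# Hypothetical sketch, not from this commit: update documents selected by a queryset.
from django.db.models import QuerySet
from common.db.search import native_update
from dataset.models import Document

native_update({'document_custom_sql': QuerySet(Document).filter(char_length=0)},
              'UPDATE "document" SET char_length = 0',
              with_table_name=True)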

View File

@@ -9,24 +9,29 @@
 import datetime
 import logging
 import os
+import threading
 import traceback
 from typing import List

 import django.db.models
+from django.db import models
 from django.db.models import QuerySet
+from django.db.models.functions import Substr, Reverse
 from langchain_core.embeddings import Embeddings

 from common.config.embedding_config import VectorStore
-from common.db.search import native_search, get_dynamics_model
+from common.db.search import native_search, get_dynamics_model, native_update
-from common.event.common import embedding_poxy
+from common.db.sql_execute import sql_execute, update_execute
 from common.util.file_util import get_file_content
 from common.util.lock import try_lock, un_lock
-from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping
+from common.util.page_utils import page
+from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State
 from embedding.models import SourceType, SearchMode
 from smartdoc.conf import PROJECT_DIR

 max_kb_error = logging.getLogger(__file__)
 max_kb = logging.getLogger(__file__)
+lock = threading.Lock()

 class SyncWebDatasetArgs:
@@ -114,7 +119,8 @@ class ListenerManagement:
         @param embedding_model: embedding model
         """
         max_kb.info(f"Start ---> embedding paragraph: {paragraph_id}")
-        status = Status.success
+        # Move to the started state
+        ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING, State.STARTED)
         try:
             data_list = native_search(
                 {'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
@@ -125,16 +131,22 @@ class ListenerManagement:
             # Delete the paragraph's vectors
             VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)

-            def is_save_function():
-                return QuerySet(Paragraph).filter(id=paragraph_id).exists()
+            def is_the_task_interrupted():
+                _paragraph = QuerySet(Paragraph).filter(id=paragraph_id).first()
+                if _paragraph is None or Status(_paragraph.status)[TaskType.EMBEDDING] == State.REVOKE:
+                    return True
+                return False

             # Batch embedding
-            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
+            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_the_task_interrupted)
+            # Move to the success state
+            ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
+                                             State.SUCCESS)
         except Exception as e:
             max_kb_error.error(f'Embedding paragraph {paragraph_id} failed: {str(e)}{traceback.format_exc()}')
-            status = Status.error
+            ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
+                                             State.FAILURE)
         finally:
-            QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': status})
             max_kb.info(f'End ---> embedding paragraph: {paragraph_id}')
     @staticmethod
@@ -142,6 +154,89 @@ class ListenerManagement:
         # Batch embedding
         VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, lambda: True)
+    @staticmethod
+    def get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, post_apply=lambda: None):
+        def embedding_paragraph_apply(paragraph_list):
+            for paragraph in paragraph_list:
+                if is_the_task_interrupted():
+                    break
+                ListenerManagement.embedding_by_paragraph(str(paragraph.get('id')), embedding_model)
+            post_apply()
+        return embedding_paragraph_apply
+
+    @staticmethod
+    def get_aggregation_document_status(document_id):
+        def aggregation_document_status():
+            sql = get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql'))
+            native_update({'document_custom_sql': QuerySet(Document).filter(id=document_id)}, sql,
+                          with_table_name=True)
+        return aggregation_document_status
+
+    @staticmethod
+    def get_aggregation_document_status_by_dataset_id(dataset_id):
+        def aggregation_document_status():
+            sql = get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql'))
+            native_update({'document_custom_sql': QuerySet(Document).filter(dataset_id=dataset_id)}, sql)
+        return aggregation_document_status
+
+    @staticmethod
+    def get_aggregation_document_status_by_query_set(queryset):
+        def aggregation_document_status():
+            sql = get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql'))
+            native_update({'document_custom_sql': queryset}, sql)
+        return aggregation_document_status
+
+    @staticmethod
+    def post_update_document_status(document_id, task_type: TaskType):
+        _document = QuerySet(Document).filter(id=document_id).first()
+        status = Status(_document.status)
+        if status[task_type] == State.REVOKE:
+            status[task_type] = State.REVOKED
+        else:
+            status[task_type] = State.SUCCESS
+        for item in _document.status_meta.get('aggs', []):
+            agg_status = item.get('status')
+            agg_count = item.get('count')
+            if Status(agg_status)[task_type] == State.FAILURE and agg_count > 0:
+                status[task_type] = State.FAILURE
+        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), task_type, status[task_type])
+        ListenerManagement.update_status(QuerySet(Paragraph).annotate(
+            reversed_status=Reverse('status'),
+            task_type_status=Substr('reversed_status', task_type.value,
+                                    task_type.value),
+        ).filter(task_type_status=State.REVOKE.value).filter(document_id=document_id).values('id'),
+            task_type,
+            State.REVOKED)
+
+    @staticmethod
+    def update_status(query_set: QuerySet, taskType: TaskType, state: State):
+        exec_sql = get_file_content(
+            os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_paragraph_status.sql'))
+        bit_number = len(TaskType)
+        up_index = taskType.value - 1
+        next_index = taskType.value + 1
+        current_index = taskType.value
+        status_number = state.value
+        params_dict = {'${bit_number}': bit_number, '${up_index}': up_index,
+                       '${status_number}': status_number, '${next_index}': next_index,
+                       '${table_name}': query_set.model._meta.db_table, '${current_index}': current_index}
+        for key in params_dict:
+            _value_ = params_dict[key]
+            exec_sql = exec_sql.replace(key, str(_value_))
+        lock.acquire()
+        try:
+            native_update(query_set, exec_sql)
+        finally:
+            lock.release()
     @staticmethod
     def embedding_by_document(document_id, embedding_model: Embeddings):
         """
@@ -153,33 +248,29 @@ class ListenerManagement:
         if not try_lock('embedding' + str(document_id)):
             return
         max_kb.info(f"Start ---> embedding document: {document_id}")
-        QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding})
-        QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.embedding})
-        status = Status.success
+        # Batch-update the status to STARTED
+        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.STARTED)
         try:
-            data_list = native_search(
-                {'problem': QuerySet(
-                    get_dynamics_model({'paragraph.document_id': django.db.models.CharField()})).filter(
-                    **{'paragraph.document_id': document_id}),
-                    'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
-                select_string=get_file_content(
-                    os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
             # Delete the document's vector data
             VectorStore.get_embedding_vector().delete_by_document_id(document_id)

-            def is_save_function():
-                return QuerySet(Document).filter(id=document_id).exists()
+            def is_the_task_interrupted():
+                document = QuerySet(Document).filter(id=document_id).first()
+                if document is None or Status(document.status)[TaskType.EMBEDDING] == State.REVOKE:
+                    return True
+                return False

-            # Batch embedding
-            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
+            # Embed the paragraphs page by page
+            page(QuerySet(Paragraph).filter(document_id=document_id).values('id'), 5,
+                 ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
+                                                                  ListenerManagement.get_aggregation_document_status(
+                                                                      document_id)),
+                 is_the_task_interrupted)
         except Exception as e:
             max_kb_error.error(f'Embedding document {document_id} failed: {str(e)}{traceback.format_exc()}')
-            status = Status.error
         finally:
-            # Update the status
-            QuerySet(Document).filter(id=document_id).update(
-                **{'status': status, 'update_time': datetime.datetime.now()})
-            QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status})
+            ListenerManagement.post_update_document_status(document_id, TaskType.EMBEDDING)
+            ListenerManagement.get_aggregation_document_status(document_id)()
             max_kb.info(f"End ---> embedding document: {document_id}")
             un_lock('embedding' + str(document_id))
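
update_status fills the ${...} placeholders of update_paragraph_status.sql by plain string replacement before handing the SQL to native_update. A standalone sketch of that substitution for TaskType.EMBEDDING moving to State.STARTED (the template line is abbreviated from the SQL file):

# Standalone sketch of the substitution done by update_status; runnable as-is.
template = ('UPDATE "${table_name}" SET status = reverse(SUBSTRING(reverse(LPAD(status, ${bit_number}, '
            "'n'))::TEXT FROM 1 FOR ${up_index}) || ${status_number} || "
            "SUBSTRING(reverse(LPAD(status, ${bit_number}, 'n'))::TEXT FROM ${next_index}))")
params_dict = {'${table_name}': 'paragraph', '${bit_number}': 3,  # len(TaskType)
               '${up_index}': 0, '${next_index}': 2, '${status_number}': '1'}
for key, value in params_dict.items():
    template = template.replace(key, str(value))
print(template)  # a concrete UPDATE statement, ready to execute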

View File

@@ -0,0 +1,27 @@
# coding=utf-8
"""
    @project: MaxKB
    @Author:
    @file: page_utils.py
    @date: 2024/11/21 10:32
    @desc:
"""
from math import ceil


def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
    """
    @param query_set: queryset to page over
    @param page_size: size of each page
    @param handler: handler applied to each page of records
    @param is_the_task_interrupted: whether the task has been interrupted
    @return:
    """
    count = query_set.count()
    for i in range(0, ceil(count / page_size)):
        if is_the_task_interrupted():
            return
        offset = i * page_size
        paragraph_list = query_set[offset: offset + page_size]
        handler(paragraph_list)
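
A quick illustration of the paging contract, using a list-backed stand-in for the queryset (a real call passes a Django QuerySet, whose count() and slicing translate to COUNT and LIMIT/OFFSET queries):

# Toy stand-in: supports count() and slicing, enough to exercise page() outside Django.
class FakeQuerySet(list):
    def count(self):
        return len(self)

page(FakeQuerySet(range(12)), 5, lambda chunk: print(list(chunk)))
# [0, 1, 2, 3, 4]
# [5, 6, 7, 8, 9]
# [10, 11]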

View File

@@ -0,0 +1,47 @@
# Generated by Django 4.2.15 on 2024-11-22 14:44
from django.db.models import QuerySet
from django.db import migrations, models

import dataset
from common.event import ListenerManagement
from dataset.models import State, TaskType


def updateDocumentStatus(apps, schema_editor):
    ParagraphModel = apps.get_model('dataset', 'Paragraph')
    DocumentModel = apps.get_model('dataset', 'Document')
    success_list = QuerySet(DocumentModel).filter(status='2')
    ListenerManagement.update_status(QuerySet(ParagraphModel).filter(document_id__in=[d.id for d in success_list]),
                                     TaskType.EMBEDDING, State.SUCCESS)
    ListenerManagement.get_aggregation_document_status_by_query_set(QuerySet(DocumentModel))()


class Migration(migrations.Migration):
    dependencies = [
        ('dataset', '0010_file_meta'),
    ]

    operations = [
        migrations.AddField(
            model_name='document',
            name='status_meta',
            field=models.JSONField(default=dataset.models.data_set.default_status_meta,
                                   verbose_name='status statistics'),
        ),
        migrations.AddField(
            model_name='paragraph',
            name='status_meta',
            field=models.JSONField(default=dataset.models.data_set.default_status_meta, verbose_name='status data'),
        ),
        migrations.AlterField(
            model_name='document',
            name='status',
            field=models.CharField(default=dataset.models.data_set.Status.__str__, max_length=20,
                                   verbose_name='status'),
        ),
        migrations.AlterField(
            model_name='paragraph',
            name='status',
            field=models.CharField(default=dataset.models.data_set.Status.__str__, max_length=20,
                                   verbose_name='status'),
        ),
        migrations.RunPython(updateDocumentStatus)
    ]

View File

@@ -7,6 +7,7 @@
 @desc: dataset
 """
 import uuid
+from enum import Enum

 from django.db import models
 from django.db.models.signals import pre_delete
@@ -18,13 +19,62 @@ from setting.models import Model
 from users.models import User
-class Status(models.TextChoices):
-    """Order type"""
-    embedding = 0, 'importing'
-    success = 1, 'completed'
-    error = 2, 'import failed'
-    queue_up = 3, 'queued'
-    generating = 4, 'generating questions'
+class TaskType(Enum):
+    # Embedding
+    EMBEDDING = 1
+    # Generate questions
+    GENERATE_PROBLEM = 2
+    # Sync
+    SYNC = 3
+
+
+class State(Enum):
+    # Waiting
+    PENDING = '0'
+    # Running
+    STARTED = '1'
+    # Succeeded
+    SUCCESS = '2'
+    # Failed
+    FAILURE = '3'
+    # Cancellation requested
+    REVOKE = '4'
+    # Cancelled
+    REVOKED = '5'
+    # Ignored
+    IGNORED = 'n'
+
+
+class Status:
+    type_cls = TaskType
+    state_cls = State
+
+    def __init__(self, status: str = None):
+        self.task_status = {}
+        status_list = list(status[::-1] if status is not None else '')
+        for _type in self.type_cls:
+            index = _type.value - 1
+            _state = self.state_cls(status_list[index] if len(status_list) > index else 'n')
+            self.task_status[_type] = _state
+
+    @staticmethod
+    def of(status: str):
+        return Status(status)
+
+    def __str__(self):
+        result = []
+        for _type in sorted(self.type_cls, key=lambda item: item.value, reverse=True):
+            result.insert(len(self.type_cls) - _type.value, self.task_status[_type].value)
+        return ''.join(result)
+
+    def __setitem__(self, key, value):
+        self.task_status[key] = value
+
+    def __getitem__(self, item):
+        return self.task_status[item]
+
+    def update_status(self, task_type: TaskType, state: State):
+        self.task_status[task_type] = state
 class Type(models.TextChoices):
@@ -42,6 +92,10 @@ def default_model():
     return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab')

+def default_status_meta():
+    return {"state_time": {}}

 class DataSet(AppModelMixin):
     """
     Dataset table
@@ -68,8 +122,8 @@ class Document(AppModelMixin):
     dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
     name = models.CharField(max_length=150, verbose_name="document name")
     char_length = models.IntegerField(verbose_name="document character count (denormalized)")
-    status = models.CharField(verbose_name='status', max_length=1, choices=Status.choices,
-                              default=Status.queue_up)
+    status = models.CharField(verbose_name='status', max_length=20, default=Status('').__str__)
+    status_meta = models.JSONField(verbose_name="status statistics", default=default_status_meta)
     is_active = models.BooleanField(default=True)

     type = models.CharField(verbose_name='type', max_length=1, choices=Type.choices,
@@ -94,8 +148,8 @@ class Paragraph(AppModelMixin):
     dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
     content = models.CharField(max_length=102400, verbose_name="paragraph content")
     title = models.CharField(max_length=256, verbose_name="title", default="")
-    status = models.CharField(verbose_name='status', max_length=1, choices=Status.choices,
-                              default=Status.embedding)
+    status = models.CharField(verbose_name='status', max_length=20, default=Status('').__str__)
+    status_meta = models.JSONField(verbose_name="status data", default=default_status_meta)
     hit_num = models.IntegerField(verbose_name="hit count", default=0)
     is_active = models.BooleanField(default=True)
@@ -145,7 +199,6 @@ class File(AppModelMixin):
     meta = models.JSONField(verbose_name="file metadata", default=dict)

     class Meta:
         db_table = "file"
@@ -161,7 +214,6 @@ class File(AppModelMixin):
         return result['data']

 @receiver(pre_delete, sender=File)
 def on_delete_file(sender, instance, **kwargs):
     select_one(f'SELECT lo_unlink({instance.loid})', [])
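
The status column defined above is thus a little-endian string of one-character task states: the rightmost slot is EMBEDDING (value 1), then GENERATE_PROBLEM, then SYNC, with 'n' for a slot that was never set. A round-trip sketch using the classes above:

# Round-trip sketch for the Status string encoding.
s = Status('')                             # every task starts as State.IGNORED ('n')
s[TaskType.EMBEDDING] = State.STARTED
s[TaskType.SYNC] = State.SUCCESS
print(str(s))                              # '2n1' -> SYNC, GENERATE_PROBLEM, EMBEDDING
print(Status('2n1')[TaskType.EMBEDDING])   # State.STARTED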

View File

@@ -27,6 +27,7 @@ from application.models import ApplicationDatasetMapping
 from common.config.embedding_config import VectorStore
 from common.db.search import get_dynamics_model, native_page_search, native_search
 from common.db.sql_execute import select_list
+from common.event import ListenerManagement
 from common.exception.app_exception import AppApiException
 from common.mixins.api_mixin import ApiMixin
 from common.util.common import post, flat_map, valid_license
@@ -34,7 +35,8 @@ from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from common.util.fork import ChildLink, Fork
 from common.util.split_model import get_split_model
-from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Status
+from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Status, \
+    TaskType, State
 from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
     get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id
 from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
@@ -733,9 +735,13 @@ class DataSetSerializers(serializers.ModelSerializer):
         def re_embedding(self, with_valid=True):
             if with_valid:
                 self.is_valid(raise_exception=True)
-            QuerySet(Document).filter(dataset_id=self.data.get('id')).update(**{'status': Status.queue_up})
-            QuerySet(Paragraph).filter(dataset_id=self.data.get('id')).update(**{'status': Status.queue_up})
+            ListenerManagement.update_status(QuerySet(Document).filter(dataset_id=self.data.get('id')),
+                                             TaskType.EMBEDDING,
+                                             State.PENDING)
+            ListenerManagement.update_status(QuerySet(Paragraph).filter(dataset_id=self.data.get('id')),
+                                             TaskType.EMBEDDING,
+                                             State.PENDING)
+            ListenerManagement.get_aggregation_document_status_by_dataset_id(self.data.get('id'))()
             embedding_model_id = get_embedding_model_id_by_dataset_id(self.data.get('id'))
             embedding_by_dataset.delay(self.data.get('id'), embedding_model_id)

View File

@@ -19,6 +19,7 @@ from celery_once import AlreadyQueued
 from django.core import validators
 from django.db import transaction
 from django.db.models import QuerySet
+from django.db.models.functions import Substr, Reverse
 from django.http import HttpResponse
 from drf_yasg import openapi
 from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE
@@ -26,6 +27,7 @@ from rest_framework import serializers
 from xlwt import Utils

 from common.db.search import native_search, native_page_search
+from common.event import ListenerManagement
 from common.event.common import work_thread_pool
 from common.exception.app_exception import AppApiException
 from common.handle.impl.doc_split_handle import DocSplitHandle
@@ -44,7 +46,8 @@ from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from common.util.fork import Fork
 from common.util.split_model import get_split_model
-from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
+from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Image, \
+    TaskType, State
 from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
     get_embedding_model_id_by_dataset_id
 from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
@@ -67,6 +70,19 @@ class FileBufferHandle:
         return self.buffer

+class CancelInstanceSerializer(serializers.Serializer):
+    type = serializers.IntegerField(required=True, error_messages=ErrMessage.boolean(
+        "task type"))
+
+    def is_valid(self, *, raise_exception=False):
+        super().is_valid(raise_exception=True)
+        _type = self.data.get('type')
+        try:
+            TaskType(_type)
+        except Exception as e:
+            raise AppApiException(500, 'Unsupported task type')
 class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer):
     meta = serializers.DictField(required=False)
     name = serializers.CharField(required=False, max_length=128, min_length=1,
@@ -278,7 +294,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
             # Update vector data
             if model_id:
                 delete_embedding_by_paragraph_ids(pid_list)
-                QuerySet(Document).filter(id__in=document_id_list).update(status=Status.queue_up)
+                ListenerManagement.update_status(QuerySet(Document).filter(id__in=document_id_list),
+                                                 TaskType.EMBEDDING,
+                                                 State.PENDING)
                 embedding_by_document_list.delay(document_id_list, model_id)
             else:
                 update_embedding_dataset_id(pid_list, target_dataset_id)
@@ -404,11 +422,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                 self.is_valid(raise_exception=True)
             document_id = self.data.get('document_id')
             document = QuerySet(Document).filter(id=document_id).first()
+            state = State.SUCCESS
             if document.type != Type.web:
                 return True
             try:
-                document.status = Status.queue_up
-                document.save()
+                ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
+                                                 TaskType.SYNC,
+                                                 State.PENDING)
                 source_url = document.meta.get('source_url')
                 selector_list = document.meta.get('selector').split(
                     " ") if 'selector' in document.meta and document.meta.get('selector') is not None else []
@@ -442,13 +462,18 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                     if with_embedding:
                         embedding_model_id = get_embedding_model_id_by_dataset_id(document.dataset_id)
                         embedding_by_document.delay(document_id, embedding_model_id)
                 else:
-                    document.status = Status.error
-                    document.save()
+                    state = State.FAILURE
             except Exception as e:
                 logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
-                document.status = Status.error
-                document.save()
+                state = State.FAILURE
+            ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
+                                             TaskType.SYNC,
+                                             state)
+            ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id),
+                                             TaskType.SYNC,
+                                             state)
             return True
 class Operate(ApiMixin, serializers.Serializer):
@@ -586,14 +611,35 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
             if with_valid:
                 self.is_valid(raise_exception=True)
             document_id = self.data.get("document_id")
-            QuerySet(Document).filter(id=document_id).update(**{'status': Status.queue_up})
-            QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.queue_up})
+            ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
+                                             State.PENDING)
+            ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id), TaskType.EMBEDDING,
+                                             State.PENDING)
+            ListenerManagement.get_aggregation_document_status(document_id)()
             embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id'))
             try:
                 embedding_by_document.delay(document_id, embedding_model_id)
             except AlreadyQueued as e:
                 raise AppApiException(500, "The task is already running; do not submit it again")
+        def cancel(self, instance, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+                CancelInstanceSerializer(data=instance).is_valid()
+            document_id = self.data.get("document_id")
+            ListenerManagement.update_status(QuerySet(Paragraph).annotate(
+                reversed_status=Reverse('status'),
+                task_type_status=Substr('reversed_status', TaskType(instance.get('type')).value,
+                                        TaskType(instance.get('type')).value),
+            ).filter(task_type_status__in=[State.PENDING.value, State.STARTED.value]).filter(
+                document_id=document_id).values('id'),
+                TaskType(instance.get('type')),
+                State.REVOKE)
+            ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType(instance.get('type')),
+                                             State.REVOKE)
+            return True
         @transaction.atomic
         def delete(self):
             document_id = self.data.get("document_id")
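
Because each per-task state lives at a fixed position counted from the right of status, cancel reverses the column in SQL and reads a substring at the task-type index. A hedged sketch of the equivalent standalone query for embedding tasks (length 1 here, since each slot is a single character):

# Sketch: paragraphs of a document whose EMBEDDING slot is PENDING or STARTED.
from django.db.models import QuerySet
from django.db.models.functions import Reverse, Substr
from dataset.models import Paragraph, TaskType, State

running = QuerySet(Paragraph).annotate(
    reversed_status=Reverse('status'),
    task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value, 1),
).filter(task_type_status__in=[State.PENDING.value, State.STARTED.value])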
@@ -955,15 +1001,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                 self.is_valid(raise_exception=True)
             document_id_list = instance.get("id_list")
             with transaction.atomic():
-                Document.objects.filter(id__in=document_id_list).update(status=Status.queue_up)
-                Paragraph.objects.filter(document_id__in=document_id_list).update(status=Status.queue_up)
                 dataset_id = self.data.get('dataset_id')
-                embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=dataset_id)
                 for document_id in document_id_list:
                     try:
-                        embedding_by_document.delay(document_id, embedding_model_id)
+                        DocumentSerializers.Operate(
+                            data={'dataset_id': dataset_id, 'document_id': document_id}).refresh()
                     except AlreadyQueued as e:
-                        raise AppApiException(500, "The task is already running; do not submit it again")
+                        pass
 class GenerateRelated(ApiMixin, serializers.Serializer):
     document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("document id"))
@@ -978,7 +1022,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
         if with_valid:
             self.is_valid(raise_exception=True)
         document_id = self.data.get('document_id')
-        QuerySet(Document).filter(id=document_id).update(status=Status.queue_up)
+        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
+                                         TaskType.GENERATE_PROBLEM,
+                                         State.PENDING)
+        ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id),
+                                         TaskType.GENERATE_PROBLEM,
+                                         State.PENDING)
+        ListenerManagement.get_aggregation_document_status(document_id)()
         try:
             generate_related_by_document_id.delay(document_id, model_id, prompt)
         except AlreadyQueued as e:

View File

@@ -16,11 +16,12 @@ from drf_yasg import openapi
 from rest_framework import serializers

 from common.db.search import page_search
+from common.event import ListenerManagement
 from common.exception.app_exception import AppApiException
 from common.mixins.api_mixin import ApiMixin
 from common.util.common import post
 from common.util.field_message import ErrMessage
-from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet
+from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet, TaskType, State
 from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
     ProblemParagraphManage, get_embedding_model_id_by_dataset_id
 from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
@@ -722,7 +723,6 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
             }
         )

     class BatchGenerateRelated(ApiMixin, serializers.Serializer):
         dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("knowledge base id"))
         document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("document id"))
@@ -734,10 +734,16 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
             paragraph_id_list = instance.get("paragraph_id_list")
             model_id = instance.get("model_id")
             prompt = instance.get("prompt")
+            document_id = self.data.get('document_id')
+            ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
+                                             TaskType.GENERATE_PROBLEM,
+                                             State.PENDING)
+            ListenerManagement.update_status(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
+                                             TaskType.GENERATE_PROBLEM,
+                                             State.PENDING)
+            ListenerManagement.get_aggregation_document_status(document_id)()
             try:
-                generate_related_by_paragraph_id_list.delay(paragraph_id_list, model_id, prompt)
+                generate_related_by_paragraph_id_list.delay(document_id, paragraph_id_list, model_id,
+                                                            prompt)
             except AlreadyQueued as e:
                 raise AppApiException(500, "The task is already running; do not submit it again")

View File

@@ -1,6 +1,7 @@
 SELECT
   "document".*,
   to_json("document"."meta") as meta,
+  to_json("document"."status_meta") as status_meta,
   (SELECT "count"("id") FROM "paragraph" WHERE document_id="document"."id") as "paragraph_count"
 FROM
   "document" "document"

View File

@@ -0,0 +1,25 @@
UPDATE "document" "document"
SET status_meta = jsonb_set ( "document".status_meta, '{aggs}', tmp.status_meta )
FROM
(
SELECT COALESCE
( jsonb_agg ( jsonb_delete ( ( row_to_json ( record ) :: JSONB ), 'document_id' ) ), '[]' :: JSONB ) AS status_meta,
document_id AS document_id
FROM
(
SELECT
"paragraph".status,
"count" ( "paragraph"."id" ),
"document"."id" AS document_id
FROM
"document" "document"
LEFT JOIN "paragraph" "paragraph" ON "document"."id" = paragraph.document_id
${document_custom_sql}
GROUP BY
"paragraph".status,
"document"."id"
) record
GROUP BY
document_id
) tmp
WHERE "document".id="tmp".document_id

View File

@@ -0,0 +1,13 @@
UPDATE "${table_name}"
SET status = reverse (
SUBSTRING ( reverse ( LPAD( status, ${bit_number}, 'n' ) ) :: TEXT FROM 1 FOR ${up_index} ) || ${status_number} || SUBSTRING ( reverse ( LPAD( status, ${bit_number}, 'n' ) ) :: TEXT FROM ${next_index} )
),
status_meta = jsonb_set (
"${table_name}".status_meta,
'{state_time,${current_index}}',
jsonb_set (
COALESCE ( "${table_name}".status_meta #> '{state_time,${current_index}}', jsonb_build_object ( '${status_number}', now( ) ) ),
'{${status_number}}',
CONCAT ( '"', now( ), '"' ) :: JSONB
)
)
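
Re-expressed in Python for a concrete case, the SET status expression pads the string to one slot per task type, reverses it so slot 1 comes first, splices in the new state character, and reverses back (a sketch of the semantics, not the executed code):

def set_state(status: str, index: int, state: str, bit_number: int = 3) -> str:
    # Mirror of the SQL: LPAD with 'n', reverse, replace the 1-based slot, reverse back.
    little_endian = status.rjust(bit_number, 'n')[::-1]
    little_endian = little_endian[:index - 1] + state + little_endian[index:]
    return little_endian[::-1]

assert set_state('nn0', 1, '1') == 'nn1'  # EMBEDDING: PENDING -> STARTED
assert set_state('nn1', 3, '2') == '2n1'  # SYNC slot set to SUCCESS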

View File

@@ -26,3 +26,14 @@ class DocumentApi(ApiMixin):
             'directly_return_similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title="direct-return similarity")
         }
     )

+    class Cancel(ApiMixin):
+        @staticmethod
+        def get_request_body_api():
+            return openapi.Schema(
+                type=openapi.TYPE_OBJECT,
+                properties={
+                    'type': openapi.Schema(type=openapi.TYPE_INTEGER, title="task type",
+                                           description="1|2|3  1: embedding | 2: generate questions | 3: sync document")
+                }
+            )

View File

@@ -1,12 +1,14 @@
 import logging
-from math import ceil
+import traceback

 from celery_once import QueueOnce
 from django.db.models import QuerySet
 from langchain_core.messages import HumanMessage

 from common.config.embedding_config import ModelManage
-from dataset.models import Paragraph, Document, Status
+from common.event import ListenerManagement
+from common.util.page_utils import page
+from dataset.models import Paragraph, Document, Status, TaskType, State
 from dataset.task.tools import save_problem
 from ops import celery_app
 from setting.models import Model
@@ -21,44 +23,79 @@ def get_llm_model(model_id):
     return ModelManage.get_model(model_id, lambda _id: get_model(model))
+def generate_problem_by_paragraph(paragraph, llm_model, prompt):
+    try:
+        ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM,
+                                         State.STARTED)
+        res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))])
+        if (res.content is None) or (len(res.content) == 0):
+            return
+        problems = res.content.split('\n')
+        for problem in problems:
+            save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem)
+        ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM,
+                                         State.SUCCESS)
+    except Exception as e:
+        ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM,
+                                         State.FAILURE)
+
+
+def get_generate_problem(llm_model, prompt, post_apply=lambda: None, is_the_task_interrupted=lambda: False):
+    def generate_problem(paragraph_list):
+        for paragraph in paragraph_list:
+            if is_the_task_interrupted():
+                return
+            generate_problem_by_paragraph(paragraph, llm_model, prompt)
+        post_apply()
+    return generate_problem
 @celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
                  name='celery:generate_related_by_document')
 def generate_related_by_document_id(document_id, model_id, prompt):
-    llm_model = get_llm_model(model_id)
-    offset = 0
-    page_size = 10
-    QuerySet(Document).filter(id=document_id).update(status=Status.generating)
-    count = QuerySet(Paragraph).filter(document_id=document_id).count()
-    for i in range(0, ceil(count / page_size)):
-        paragraph_list = QuerySet(Paragraph).filter(document_id=document_id).all()[offset:offset + page_size]
-        offset += page_size
-        for paragraph in paragraph_list:
-            res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))])
-            if (res.content is None) or (len(res.content) == 0):
-                continue
-            problems = res.content.split('\n')
-            for problem in problems:
-                save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem)
-    QuerySet(Document).filter(id=document_id).update(status=Status.success)
+    try:
+        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
+                                         TaskType.GENERATE_PROBLEM,
+                                         State.STARTED)
+        llm_model = get_llm_model(model_id)
+
+        def is_the_task_interrupted():
+            document = QuerySet(Document).filter(id=document_id).first()
+            if document is None or Status(document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE:
+                return True
+            return False
+
+        # Question-generation handler
+        generate_problem = get_generate_problem(llm_model, prompt,
+                                                ListenerManagement.get_aggregation_document_status(
+                                                    document_id), is_the_task_interrupted)
+        page(QuerySet(Paragraph).filter(document_id=document_id), 10, generate_problem, is_the_task_interrupted)
+    except Exception as e:
+        max_kb_error.error(f'Generating questions for document {document_id} failed: {str(e)}{traceback.format_exc()}')
+    finally:
+        ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM)
+        max_kb.info(f"End ---> generating questions: {document_id}")
 @celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']},
                  name='celery:generate_related_by_paragraph_list')
-def generate_related_by_paragraph_id_list(paragraph_id_list, model_id, prompt):
-    llm_model = get_llm_model(model_id)
-    offset = 0
-    page_size = 10
-    count = QuerySet(Paragraph).filter(id__in=paragraph_id_list).count()
-    for i in range(0, ceil(count / page_size)):
-        paragraph_list = QuerySet(Paragraph).filter(id__in=paragraph_id_list).all()[offset:offset + page_size]
-        offset += page_size
-        for paragraph in paragraph_list:
-            res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))])
-            if (res.content is None) or (len(res.content) == 0):
-                continue
-            problems = res.content.split('\n')
-            for problem in problems:
-                save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem)
+def generate_related_by_paragraph_id_list(document_id, paragraph_id_list, model_id, prompt):
+    try:
+        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
+                                         TaskType.GENERATE_PROBLEM,
+                                         State.STARTED)
+        llm_model = get_llm_model(model_id)
+        # Question-generation handler
+        generate_problem = get_generate_problem(llm_model, prompt, ListenerManagement.get_aggregation_document_status(
+            document_id))
+
+        def is_the_task_interrupted():
+            document = QuerySet(Document).filter(id=document_id).first()
+            if document is None or Status(document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE:
+                return True
+            return False
+
+        page(QuerySet(Paragraph).filter(id__in=paragraph_id_list), 10, generate_problem, is_the_task_interrupted)
+    finally:
+        ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM)
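
Both tasks share one shape: a factory builds a per-page handler that checks the revoke flag between paragraphs and refreshes the aggregated document status after each page. A generic sketch of that pattern (names are illustrative, not project API):

# Generic shape of the handler factories used above.
def make_page_handler(process_item, post_apply=lambda: None,
                      is_the_task_interrupted=lambda: False):
    def handle(page_items):
        for item in page_items:
            if is_the_task_interrupted():
                return
            process_item(item)
        post_apply()
    return handle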

View File

@@ -37,6 +37,7 @@ urlpatterns = [
          name="document_export"),
     path('dataset/<str:dataset_id>/document/<str:document_id>/sync', views.Document.SyncWeb.as_view()),
     path('dataset/<str:dataset_id>/document/<str:document_id>/refresh', views.Document.Refresh.as_view()),
+    path('dataset/<str:dataset_id>/document/<str:document_id>/cancel_task', views.Document.CancelTask.as_view()),
     path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph', views.Paragraph.as_view()),
     path('dataset/<str:dataset_id>/document/batch_generate_related', views.Document.BatchGenerateRelated.as_view()),
     path(
@@ -45,7 +46,8 @@ urlpatterns = [
     path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/_batch', views.Paragraph.Batch.as_view()),
     path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<int:current_page>/<int:page_size>',
          views.Paragraph.Page.as_view(), name='paragraph_page'),
-    path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/batch_generate_related', views.Paragraph.BatchGenerateRelated.as_view()),
+    path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/batch_generate_related',
+         views.Paragraph.BatchGenerateRelated.as_view()),
     path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<paragraph_id>',
          views.Paragraph.Operate.as_view()),
     path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem',

View File

@@ -218,6 +218,26 @@ class Document(APIView):
                 DocumentSerializers.Sync(data={'document_id': document_id, 'dataset_id': dataset_id}).sync(
                 ))

+    class CancelTask(APIView):
+        authentication_classes = [TokenAuth]
+
+        @action(methods=['PUT'], detail=False)
+        @swagger_auto_schema(operation_summary="Cancel task",
+                             operation_id="Cancel task",
+                             manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
+                             request_body=DocumentApi.Cancel.get_request_body_api(),
+                             responses=result.get_default_response(),
+                             tags=["Knowledge Base/Document"])
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def put(self, request: Request, dataset_id: str, document_id: str):
+            return result.success(
+                DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).cancel(
+                    request.data
+                ))

     class Refresh(APIView):
         authentication_classes = [TokenAuth]

View File

@@ -86,20 +86,20 @@ class BaseVectorStore(ABC):
         for child_array in result:
             self._batch_save(child_array, embedding, lambda: True)

-    def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_save_function):
+    def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
         """
         Batch insert
         @param data_list: data list
         @param embedding: embedding handler
-        @param is_save_function:
+        @param is_the_task_interrupted: whether the task has been interrupted
         :return: bool
         """
         self.save_pre_handler()
         chunk_list = chunk_data_list(data_list)
         result = sub_array(chunk_list)
         for child_array in result:
-            if is_save_function():
-                self._batch_save(child_array, embedding, is_save_function)
+            if not is_the_task_interrupted():
+                self._batch_save(child_array, embedding, is_the_task_interrupted)
             else:
                 break

@@ -110,7 +110,7 @@ class BaseVectorStore(ABC):
         pass

     @abstractmethod
-    def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function):
+    def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
         pass

     def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],

View File

@@ -57,7 +57,7 @@ class PGVector(BaseVectorStore):
         embedding.save()
         return True

-    def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function):
+    def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
         texts = [row.get('text') for row in text_list]
         embeddings = embedding.embed_documents(texts)
         embedding_list = [Embedding(id=uuid.uuid1(),
@@ -70,7 +70,7 @@ class PGVector(BaseVectorStore):
                                     embedding=embeddings[index],
                                     search_vector=to_ts_vector(text_list[index]['text'])) for index in
                           range(0, len(texts))]
-        if is_save_function():
+        if not is_the_task_interrupted():
             QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
         return True

View File

@@ -208,6 +208,7 @@ class CeleryThreadTaskFileHandler(CeleryThreadingLoggerHandler):
             f.flush()

     def handle_task_start(self, task_id):
+        print('handle_task_start')
         log_path = get_celery_task_log_path(task_id)
         thread_id = self.get_current_thread_id()
         self.task_id_thread_id_mapper[task_id] = thread_id
@@ -215,6 +216,7 @@ class CeleryThreadTaskFileHandler(CeleryThreadingLoggerHandler):
         self.thread_id_fd_mapper[thread_id] = f

     def handle_task_end(self, task_id):
+        print('handle_task_end')
         ident_id = self.task_id_thread_id_mapper.get(task_id, '')
         f = self.thread_id_fd_mapper.pop(ident_id, None)
         if f and not f.closed:

View File

@@ -5,7 +5,7 @@ import os

 from celery import subtask
 from celery.signals import (
-    worker_ready, worker_shutdown, after_setup_logger
+    worker_ready, worker_shutdown, after_setup_logger, task_revoked, task_prerun
 )
 from django.core.cache import cache
 from django_celery_beat.models import PeriodicTask
@@ -61,3 +61,15 @@ def add_celery_logger_handler(sender=None, logger=None, loglevel=None, format=None):
     formatter = logging.Formatter(format)
     task_handler.setFormatter(formatter)
     logger.addHandler(task_handler)

+@task_revoked.connect
+def on_task_revoked(request, terminated, signum, expired, **kwargs):
+    print('task_revoked', terminated)
+
+
+@task_prerun.connect
+def on_taskaa_start(sender, task_id, **kwargs):
+    pass
+    # sender.update_state(state='REVOKED',
+    #                     meta={'exc_type': 'Exception', 'exc': 'Exception', 'message': 'Pause task', 'exc_message': ''})

View File

@@ -322,8 +322,17 @@ const batchGenerateRelated: (
   data: any,
   loading?: Ref<boolean>
 ) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
+  return put(`${prefix}/${dataset_id}/document/batch_generate_related`, data, undefined, loading)
+}
+
+const cancelTask: (
+  dataset_id: string,
+  document_id: string,
+  data: any,
+  loading?: Ref<boolean>
+) => Promise<Result<boolean>> = (dataset_id, document_id, data, loading) => {
   return put(
-    `${prefix}/${dataset_id}/document/batch_generate_related`,
+    `${prefix}/${dataset_id}/document/${document_id}/cancel_task`,
     data,
     undefined,
     loading
@@ -352,5 +361,6 @@ export default {
   postTableDocument,
   exportDocument,
   batchRefresh,
-  batchGenerateRelated
+  batchGenerateRelated,
+  cancelTask
 }

View File

@@ -28,7 +28,12 @@
           <div class="flex align-center">
             <span
               class="mr-16 color-secondary"
-              v-if="item.type === WorkflowType.Question || item.type === WorkflowType.AiChat || item.type === WorkflowType.ImageUnderstandNode"
+              v-if="
+                item.type === WorkflowType.Question ||
+                item.type === WorkflowType.AiChat ||
+                item.type === WorkflowType.ImageUnderstandNode ||
+                item.type === WorkflowType.Application
+              "
             >{{ item?.message_tokens + item?.answer_tokens }} tokens</span
             >
             <span class="mr-16 color-secondary">{{ item?.run_time?.toFixed(2) || 0.0 }} s</span>
@@ -166,7 +171,10 @@
               <template v-else> - </template>
             </div>
           </div>
-          <div class="card-never border-r-4 mt-8">
+          <div
+            class="card-never border-r-4 mt-8"
+            v-if="item.type !== WorkflowType.Application"
+          >
             <h5 class="p-8-12">This conversation</h5>
             <div class="p-8-12 border-t-dashed lighter pre-wrap">
               {{ item.question || '-' }}

View File

@@ -8,30 +8,31 @@
       >
     </div>
     <div class="mt-8" v-if="!isWorkFlow(props.type)">
-      <el-space wrap>
-        <div v-for="(paragraph, index) in uniqueParagraphList" :key="index">
-          <el-icon class="mr-4" :size="25">
-            <img :src="getIconPath(paragraph.document_name)" style="width: 90%" alt="" />
-          </el-icon>
-          <span
-            v-if="!paragraph.source_url"
-            class="ellipsis"
-            :title="paragraph?.document_name?.trim()"
-          >
-            {{ paragraph?.document_name }}
-          </span>
-          <a
-            v-else
-            @click="openLink(paragraph.source_url)"
-            class="ellipsis"
-            :title="paragraph?.document_name?.trim()"
-          >
-            <span :title="paragraph?.document_name?.trim()">
-              {{ paragraph?.document_name }}
-            </span>
-          </a>
-        </div>
-      </el-space>
+      <el-row :gutter="8" v-if="uniqueParagraphList?.length">
+        <template v-for="(item, index) in uniqueParagraphList" :key="index">
+          <el-col :span="12" class="mb-8">
+            <el-card shadow="never" class="file-List-card" data-width="40">
+              <div class="flex-between">
+                <div class="flex">
+                  <img :src="getImgUrl(item && item?.document_name)" alt="" width="20" />
+                  <div class="ml-4" v-if="!item.source_url">
+                    <p>{{ item && item?.document_name }}</p>
+                  </div>
+                  <div class="ml-8" v-else>
+                    <a
+                      @click="openLink(item.source_url)"
+                      class="ellipsis"
+                      :title="item?.document_name?.trim()"
+                    >
+                      <span :title="item?.document_name?.trim()">{{ item?.document_name }}</span>
+                    </a>
+                  </div>
+                </div>
+              </div>
+            </el-card>
+          </el-col>
+        </template>
+      </el-row>
     </div>

     <div class="border-t color-secondary flex-between mt-12" style="padding-top: 12px">
@@ -59,7 +60,7 @@ import { computed, ref } from 'vue'
 import ParagraphSourceDialog from './ParagraphSourceDialog.vue'
 import ExecutionDetailDialog from './ExecutionDetailDialog.vue'
 import { isWorkFlow } from '@/utils/application'
+import { getImgUrl } from '@/utils/utils'

 const props = defineProps({
   data: {
     type: Object,
@@ -70,15 +71,6 @@ const props = defineProps({
     default: ''
   }
 })
-const iconMap: { [key: string]: string } = {
-  doc: '../../assets/doc-icon.svg',
-  docx: '../../assets/docx-icon.svg',
-  pdf: '../../assets/pdf-icon.svg',
-  md: '../../assets/md-icon.svg',
-  txt: '../../assets/txt-icon.svg',
-  xls: '../../assets/xls-icon.svg',
-  xlsx: '../../assets/xlsx-icon.svg'
-}

 const ParagraphSourceDialogRef = ref()
 const ExecutionDetailDialogRef = ref()
@@ -107,14 +99,6 @@ const uniqueParagraphList = computed(() => {
   )
 })

-function getIconPath(documentName: string) {
-  const extension = documentName.split('.').pop()?.toLowerCase()
-  if (!documentName || !extension) return new URL(`${iconMap['doc']}`, import.meta.url).href
-  if (iconMap && extension && iconMap[extension]) {
-    return new URL(`${iconMap[extension]}`, import.meta.url).href
-  }
-  return new URL(`${iconMap['doc']}`, import.meta.url).href
-}

 function openLink(url: string) {
   // url
   if (url && !url.endsWith('/')) {

View File

@ -18,45 +18,8 @@
<template #footer> <template #footer>
<div class="footer-content flex-between"> <div class="footer-content flex-between">
<el-text class="flex align-center" style="width: 70%"> <el-text class="flex align-center" style="width: 70%">
<el-icon class="mr-4" :size="25"> <img :src="getImgUrl(data?.document_name?.trim())" alt="" width="20" class="mr-4" />
<img
src="@/assets/doc-icon.svg"
style="width: 90%"
alt=""
v-if="data?.document_name?.includes('doc')"
/>
<img
src="@/assets/docx-icon.svg"
style="width: 90%"
alt=""
v-else-if="data?.document_name?.includes('docx')"
/>
<img
src="@/assets/pdf-icon.svg"
style="width: 90%"
alt=""
v-else-if="data?.document_name?.includes('pdf')"
/>
<img
src="@/assets/md-icon.svg"
style="width: 90%"
alt=""
v-else-if="data?.document_name?.includes('md')"
/>
<img
src="@/assets/xls-icon.svg"
style="width: 90%"
alt=""
v-else-if="data?.document_name?.includes('xls')"
/>
<img
src="@/assets/txt-icon.svg"
style="width: 90%"
alt=""
v-else-if="data?.document_name?.includes('txt')"
/>
<img src="@/assets/doc-icon.svg" style="width: 90%" alt="" v-else />
</el-icon>
<span class="ellipsis" :title="data?.document_name?.trim()"> <span class="ellipsis" :title="data?.document_name?.trim()">
{{ data?.document_name.trim() }}</span {{ data?.document_name.trim() }}</span
> >
@@ -73,6 +36,7 @@
   </CardBox>
 </template>
 <script setup lang="ts">
+import { getImgUrl } from '@/utils/utils'
 const props = defineProps({
   data: {
     type: Object,
@@ -83,23 +47,6 @@ const props = defineProps({
     default: 0
   }
 })
-const iconMap: { [key: string]: string } = {
-  doc: '../../assets/doc-icon.svg',
-  docx: '../../assets/docx-icon.svg',
-  pdf: '../../assets/pdf-icon.svg',
-  md: '../../assets/md-icon.svg',
-  txt: '../../assets/txt-icon.svg',
-  xls: '../../assets/xls-icon.svg',
-  xlsx: '../../assets/xlsx-icon.svg'
-}
-function getIconPath(documentName: string) {
-  const extension = documentName.split('.').pop()?.toLowerCase()
-  if (!documentName || !extension) return new URL(`${iconMap['doc']}`, import.meta.url).href
-  if (iconMap && extension && iconMap[extension]) {
-    return new URL(`${iconMap[extension]}`, import.meta.url).href
-  }
-  return new URL(`${iconMap['doc']}`, import.meta.url).href
-}
 </script>
 <style lang="scss" scoped>
 .paragraph-source-card-height {


@@ -89,7 +89,7 @@
           props.applicationDetails.file_upload_setting.maxFiles
         }}每个文件限制
         {{ props.applicationDetails.file_upload_setting.fileLimit }}MB<br />文件类型{{
-          getAcceptList()
+          getAcceptList().replace(/\./g, '').replace(/,/g, '、').toUpperCase()
         }}</template
       >
       <el-button text>
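(For reference, a minimal TypeScript sketch of what this chained transform renders. The return value of getAcceptList() below is hypothetical; only the replace/toUpperCase chain comes from the diff above.)

// assume getAcceptList() returns a comma-separated extension list such as ".doc,.pdf,.txt"
const getAcceptList = () => '.doc,.pdf,.txt'
const label = getAcceptList()
  .replace(/\./g, '')  // drop the leading dots -> "doc,pdf,txt"
  .replace(/,/g, '、') // CJK list separator    -> "doc、pdf、txt"
  .toUpperCase()       // display form          -> "DOC、PDF、TXT"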


@@ -1307,5 +1307,26 @@ export const iconMap: any = {
         )
       ])
     }
-  }
+  },
+  'app-close': {
+    iconReader: () => {
+      return h('i', [
+        h(
+          'svg',
+          {
+            style: { height: '100%', width: '100%' },
+            viewBox: '0 0 16 16',
+            version: '1.1',
+            xmlns: 'http://www.w3.org/2000/svg'
+          },
+          [
+            h('path', {
+              d: 'M7.96141 6.98572L12.4398 2.50738C12.5699 2.3772 12.781 2.3772 12.9112 2.50738L13.3826 2.97878C13.5127 3.10895 13.5127 3.32001 13.3826 3.45018L8.90422 7.92853L13.3826 12.4069C13.5127 12.537 13.5127 12.7481 13.3826 12.8783L12.9112 13.3497C12.781 13.4799 12.5699 13.4799 12.4398 13.3497L7.96141 8.87134L3.48307 13.3497C3.35289 13.4799 3.14184 13.4799 3.01166 13.3497L2.54026 12.8783C2.41008 12.7481 2.41008 12.537 2.54026 12.4069L7.0186 7.92853L2.54026 3.45018C2.41008 3.32001 2.41008 3.10895 2.54026 2.97878L3.01166 2.50738C3.14184 2.3772 3.35289 2.3772 3.48307 2.50738L7.96141 6.98572Z',
+              fill: 'currentColor'
+            })
+          ]
+        )
+      ])
+    }
+  }
 }

ui/src/utils/status.ts (new file, 68 additions)

@@ -0,0 +1,68 @@
import { type Dict } from '@/api/type/common'
interface TaskTypeInterface {
  // embedding (vectorization)
  EMBEDDING: number
  // question generation
  GENERATE_PROBLEM: number
  // sync
  SYNC: number
}
interface StateInterface {
  // waiting in queue
  PENDING: '0'
  // running
  STARTED: '1'
  // succeeded
  SUCCESS: '2'
  // failed
  FAILURE: '3'
  // cancellation requested
  REVOKE: '4'
  // cancelled
  REVOKED: '5'
  // task not scheduled for this type
  IGNORED: 'n'
}
const TaskType: TaskTypeInterface = {
EMBEDDING: 1,
GENERATE_PROBLEM: 2,
SYNC: 3
}
const State: StateInterface = {
  // waiting in queue
  PENDING: '0',
  // running
  STARTED: '1',
  // succeeded
  SUCCESS: '2',
  // failed
  FAILURE: '3',
  // cancellation requested
  REVOKE: '4',
  // cancelled
  REVOKED: '5',
  IGNORED: 'n'
}
// Parses a composite task-status string: one state character per task type,
// stored right to left (the last character belongs to TaskType 1).
class Status {
  task_status: Dict<any>
  constructor(status?: string) {
    if (!status) {
      status = ''
    }
    // reverse so that index (TaskType value - 1) addresses the right character
    status = status.split('').reverse().join('')
    this.task_status = {}
    for (let key in TaskType) {
      const value = TaskType[key as keyof TaskTypeInterface]
      const index = value - 1
      // positions missing from the string default to 'n' (IGNORED)
      this.task_status[value] = status[index] ? status[index] : 'n'
    }
  }
  toString() {
    const r = []
    for (let key in TaskType) {
      const value = TaskType[key as keyof TaskTypeInterface]
      r.push(this.task_status[value])
    }
    // reverse back into the right-to-left storage order
    return r.reverse().join('')
  }
}
export { Status, State, TaskType, type TaskTypeInterface, type StateInterface }
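(A quick usage sketch of this encoding; the status value is illustrative. The raw string is read right to left: the last character belongs to TaskType 1, the one before it to TaskType 2, and so on, with missing positions defaulting to IGNORED.)

import { Status, TaskType } from '@/utils/status'

const s = new Status('n31')
s.task_status[TaskType.EMBEDDING]        // '1' (State.STARTED)
s.task_status[TaskType.GENERATE_PROBLEM] // '3' (State.FAILURE)
s.task_status[TaskType.SYNC]             // 'n' (State.IGNORED)
s.toString()                             // 'n31' (round-trips)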

ui/src/views/document/component/Status.vue (new file, 167 additions)

@@ -0,0 +1,167 @@
<template>
<el-popover placement="top" :width="450" trigger="hover">
<template #default>
<el-row :gutter="3" v-for="status in statusTable" :key="status.type">
<el-col :span="4">{{ taskTypeMap[status.type] }} </el-col>
<el-col :span="4">
<el-text v-if="status.state === State.SUCCESS || status.state === State.REVOKED">
<el-icon class="success"><SuccessFilled /></el-icon>
{{ stateMap[status.state](status.type) }}
</el-text>
<el-text v-else-if="status.state === State.FAILURE">
<el-icon class="danger"><CircleCloseFilled /></el-icon>
{{ stateMap[status.state](status.type) }}
</el-text>
<el-text v-else-if="status.state === State.STARTED">
<el-icon class="is-loading primary"><Loading /></el-icon>
{{ stateMap[status.state](status.type) }}
</el-text>
<el-text v-else-if="status.state === State.PENDING">
<el-icon class="is-loading primary"><Loading /></el-icon>
{{ stateMap[status.state](status.type) }}
</el-text>
<el-text v-else-if="aggStatus?.value === State.REVOKE">
<el-icon class="is-loading primary"><Loading /></el-icon>
{{ stateMap[aggStatus.value](aggStatus.key) }}
</el-text>
</el-col>
<el-col :span="5">
完成
{{
Object.keys(status.aggs ? status.aggs : {})
.filter((k) => k == State.SUCCESS)
.map((k) => status.aggs[k])
.reduce((x: any, y: any) => x + y, 0)
}}/{{
Object.values(status.aggs ? status.aggs : {}).reduce((x: any, y: any) => x + y, 0)
}}
</el-col>
<el-col :span="9">
{{
status.time
? status.time[
status.state == State.REVOKED ? State.REVOKED : State.PENDING
]?.substring(0, 19)
: undefined
}}
</el-col>
</el-row>
</template>
<template #reference>
<el-text v-if="aggStatus?.value === State.SUCCESS || aggStatus?.value === State.REVOKED">
<el-icon class="success"><SuccessFilled /></el-icon>
{{ stateMap[aggStatus.value](aggStatus.key) }}
</el-text>
<el-text v-else-if="aggStatus?.value === State.FAILURE">
<el-icon class="danger"><CircleCloseFilled /></el-icon>
{{ stateMap[aggStatus.value](aggStatus.key) }}
</el-text>
<el-text v-else-if="aggStatus?.value === State.STARTED">
<el-icon class="is-loading primary"><Loading /></el-icon>
{{ stateMap[aggStatus.value](aggStatus.key) }}
</el-text>
<el-text v-else-if="aggStatus?.value === State.PENDING">
<el-icon class="is-loading primary"><Loading /></el-icon>
{{ stateMap[aggStatus.value](aggStatus.key) }}
</el-text>
<el-text v-else-if="aggStatus?.value === State.REVOKE">
<el-icon class="is-loading primary"><Loading /></el-icon>
{{ stateMap[aggStatus.value](aggStatus.key) }}
</el-text>
</template>
</el-popover>
</template>
<script setup lang="ts">
import { computed } from 'vue'
import { Status, TaskType, State, type TaskTypeInterface } from '@/utils/status'
import { mergeWith } from 'lodash'
const props = defineProps<{ status: string; statusMeta: any }>()
// priority order for aggregation: the first state found in the status string wins
const checkList: Array<string> = [
  State.REVOKE,
  State.STARTED,
  State.PENDING,
  State.REVOKED,
  State.FAILURE,
  State.SUCCESS
]
const aggStatus = computed(() => {
for (const i in checkList) {
const state = checkList[i]
const index = props.status.indexOf(state)
if (index > -1) {
return { key: props.status.length - index, value: state }
}
}
})
const startedMap = {
[TaskType.EMBEDDING]: '索引中',
[TaskType.GENERATE_PROBLEM]: '生成中',
[TaskType.SYNC]: '同步中'
}
const taskTypeMap = {
[TaskType.EMBEDDING]: '向量化',
[TaskType.GENERATE_PROBLEM]: '生成问题',
[TaskType.SYNC]: '同步'
}
const stateMap: any = {
[State.PENDING]: (type: number) => '排队中',
[State.STARTED]: (type: number) => startedMap[type],
[State.REVOKE]: (type: number) => '取消中',
[State.REVOKED]: (type: number) => '成功',
[State.FAILURE]: (type: number) => '失败',
[State.SUCCESS]: (type: number) => '成功'
}
const parseAgg = (agg: { count: number; status: string }) => {
const status = new Status(agg.status)
return Object.keys(TaskType)
.map((key) => {
const value = TaskType[key as keyof TaskTypeInterface]
return { [value]: { [status.task_status[value]]: agg.count } }
})
.reduce((x, y) => ({ ...x, ...y }), {})
}
const customizer: (x: any, y: any) => any = (objValue: any, srcValue: any) => {
  if (objValue == undefined && srcValue) {
    return srcValue
  }
  if (srcValue == undefined && objValue) {
    return objValue
  }
  // both sides present
  if (typeof objValue === 'object' && typeof srcValue === 'object') {
    // objects: merge recursively
    return mergeWith(objValue, srcValue, customizer)
  } else {
    // leaf counts: add them up
    return objValue + srcValue
  }
}
const aggs = computed(() => {
return (props.statusMeta.aggs ? props.statusMeta.aggs : [])
.map((agg: any) => {
return parseAgg(agg)
})
.reduce((x: any, y: any) => {
return mergeWith(x, y, customizer)
}, {})
})
const statusTable = computed(() => {
return Object.keys(TaskType)
.map((key) => {
const value = TaskType[key as keyof TaskTypeInterface]
const parseStatus = new Status(props.status)
return {
type: value,
state: parseStatus.task_status[value],
aggs: aggs.value[value],
time: props.statusMeta.state_time[value]
}
})
.filter((item) => item.state !== State.IGNORED)
})
</script>
<style lang="scss" scoped></style>


@@ -134,21 +134,7 @@
         </div>
       </template>
       <template #default="{ row }">
-        <el-text v-if="row.status === '1'">
-          <el-icon class="success"><SuccessFilled /></el-icon>
-        </el-text>
-        <el-text v-else-if="row.status === '2'">
-          <el-icon class="danger"><CircleCloseFilled /></el-icon>
-        </el-text>
-        <el-text v-else-if="row.status === '0'">
-          <el-icon class="is-loading primary"><Loading /></el-icon>
-        </el-text>
-        <el-text v-else-if="row.status === '3'">
-          <el-icon class="is-loading primary"><Loading /></el-icon>
-        </el-text>
-        <el-text v-else-if="row.status === '4'">
-          <el-icon class="is-loading primary"><Loading /></el-icon>
-        </el-text>
+        <StatusVue :status="row.status" :status-meta="row.status_meta"></StatusVue>
       </template>
     </el-table-column>
     <el-table-column width="130">
@@ -249,7 +235,7 @@
       <template #default="{ row }">
         <div v-if="datasetDetail.type === '0'">
           <span class="mr-4">
-            <el-tooltip effect="dark" content="重新向量化" placement="top">
+            <el-tooltip effect="dark" content="向量化" placement="top">
               <el-button type="primary" text @click.stop="refreshDocument(row)">
                 <AppIcon iconName="app-document-refresh" style="font-size: 16px"></AppIcon>
               </el-button>
@@ -298,7 +284,22 @@
             </el-tooltip>
           </span>
           <span class="mr-4">
-            <el-tooltip effect="dark" content="重新向量化" placement="top">
+            <el-tooltip
+              effect="dark"
+              v-if="getTaskState(row.status, TaskType.EMBEDDING) == State.STARTED"
+              content="取消向量化"
+              placement="top"
+            >
+              <el-button
+                type="primary"
+                text
+                @click.stop="cancelTask(row, TaskType.EMBEDDING)"
+              >
+                <AppIcon iconName="app-close" style="font-size: 16px"></AppIcon>
+              </el-button>
+            </el-tooltip>
+            <el-tooltip effect="dark" v-else content="向量化" placement="top">
               <el-button type="primary" text @click.stop="refreshDocument(row)">
                 <AppIcon iconName="app-document-refresh" style="font-size: 16px"></AppIcon>
               </el-button>
@@ -315,9 +316,18 @@
             <el-dropdown-item icon="Setting" @click="settingDoc(row)"
               >设置</el-dropdown-item
             >
-            <el-dropdown-item @click="openGenerateDialog(row)">
+            <el-dropdown-item
+              v-if="
+                getTaskState(row.status, TaskType.GENERATE_PROBLEM) == State.STARTED
+              "
+              @click="cancelTask(row, TaskType.GENERATE_PROBLEM)"
+            >
               <el-icon><Connection /></el-icon>
-              生成关联问题
+              取消生成问题
+            </el-dropdown-item>
+            <el-dropdown-item v-else @click="openGenerateDialog(row)">
+              <el-icon><Connection /></el-icon>
+              生成问题
             </el-dropdown-item>
             <el-dropdown-item @click="openDatasetDialog(row)">
               <AppIcon iconName="app-migrate"></AppIcon>
@@ -360,7 +370,9 @@ import { datetimeFormat } from '@/utils/time'
 import { hitHandlingMethod } from '@/enums/document'
 import { MsgSuccess, MsgConfirm, MsgError } from '@/utils/message'
 import useStore from '@/stores'
+import StatusVue from '@/views/document/component/Status.vue'
 import GenerateRelatedDialog from '@/views/document/component/GenerateRelatedDialog.vue'
+import { TaskType, State } from '@/utils/status'
 const router = useRouter()
 const route = useRoute()
 const {
@@ -368,9 +380,11 @@ const {
 } = route as any
 const { common, dataset, document } = useStore()
 const storeKey = 'documents'
+// read the state character for one task type from the composite status string
+const getTaskState = (status: string, taskType: number) => {
+  const statusList = status.split('').reverse()
+  return taskType - 1 >= statusList.length ? 'n' : statusList[taskType - 1]
+}
 onBeforeRouteUpdate(() => {
   common.savePage(storeKey, null)
   common.saveCondition(storeKey, null)
@@ -441,7 +455,11 @@ function beforeCommand(attr: string, val: any) {
     command: val
   }
 }
+const cancelTask = (row: any, task_type: number) => {
+  documentApi.cancelTask(row.dataset_id, row.id, { type: task_type }).then(() => {
+    MsgSuccess('发送成功')
+  })
+}
 function syncDataset() {
   SyncWebDialogRef.value.open(id)
 }
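(For reference, how the template guards above resolve for an illustrative row; the status value is hypothetical.)

// row.status = 'n1' decodes as: EMBEDDING running, GENERATE_PROBLEM ignored, SYNC ignored
getTaskState('n1', TaskType.EMBEDDING)        // '1' (State.STARTED) -> show "取消向量化"
getTaskState('n1', TaskType.GENERATE_PROBLEM) // 'n' (State.IGNORED) -> show "生成问题"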