feat: Document vectorization supports processing based on status (#1984)
This commit is contained in:
parent
9a310bfb98
commit
54381ffaf3
@ -6,26 +6,22 @@
|
|||||||
@date:2023/10/20 14:01
|
@date:2023/10/20 14:01
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
import datetime
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
import traceback
|
import traceback
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import django.db.models
|
import django.db.models
|
||||||
from django.db import models, transaction
|
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
from django.db.models.functions import Substr, Reverse
|
from django.db.models.functions import Substr, Reverse
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
|
|
||||||
from common.config.embedding_config import VectorStore
|
from common.config.embedding_config import VectorStore
|
||||||
from common.db.search import native_search, get_dynamics_model, native_update
|
from common.db.search import native_search, get_dynamics_model, native_update
|
||||||
from common.db.sql_execute import sql_execute, update_execute
|
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from common.util.lock import try_lock, un_lock
|
from common.util.lock import try_lock, un_lock
|
||||||
from common.util.page_utils import page
|
from common.util.page_utils import page_desc
|
||||||
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State
|
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State
|
||||||
from embedding.models import SourceType, SearchMode
|
from embedding.models import SourceType, SearchMode
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
@ -241,13 +237,16 @@ class ListenerManagement:
|
|||||||
lock.release()
|
lock.release()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def embedding_by_document(document_id, embedding_model: Embeddings):
|
def embedding_by_document(document_id, embedding_model: Embeddings, state_list=None):
|
||||||
"""
|
"""
|
||||||
向量化文档
|
向量化文档
|
||||||
|
@param state_list:
|
||||||
@param document_id: 文档id
|
@param document_id: 文档id
|
||||||
@param embedding_model 向量模型
|
@param embedding_model 向量模型
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
|
if state_list is None:
|
||||||
|
state_list = [State.PENDING, State.SUCCESS, State.FAILURE, State.REVOKE, State.REVOKED]
|
||||||
if not try_lock('embedding' + str(document_id)):
|
if not try_lock('embedding' + str(document_id)):
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
@ -268,7 +267,13 @@ class ListenerManagement:
|
|||||||
VectorStore.get_embedding_vector().delete_by_document_id(document_id)
|
VectorStore.get_embedding_vector().delete_by_document_id(document_id)
|
||||||
|
|
||||||
# 根据段落进行向量化处理
|
# 根据段落进行向量化处理
|
||||||
page(QuerySet(Paragraph).filter(document_id=document_id).values('id'), 5,
|
page_desc(QuerySet(Paragraph)
|
||||||
|
.annotate(
|
||||||
|
reversed_status=Reverse('status'),
|
||||||
|
task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value,
|
||||||
|
1),
|
||||||
|
).filter(task_type_status__in=state_list, document_id=document_id)
|
||||||
|
.values('id'), 5,
|
||||||
ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
|
ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
|
||||||
ListenerManagement.get_aggregation_document_status(
|
ListenerManagement.get_aggregation_document_status(
|
||||||
document_id)),
|
document_id)),
|
||||||
|
|||||||
@ -26,3 +26,22 @@ def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
|
|||||||
offset = i * page_size
|
offset = i * page_size
|
||||||
paragraph_list = query.all()[offset: offset + page_size]
|
paragraph_list = query.all()[offset: offset + page_size]
|
||||||
handler(paragraph_list)
|
handler(paragraph_list)
|
||||||
|
|
||||||
|
|
||||||
|
def page_desc(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
    """
    Process a query set page by page, visiting the pages from last to first.

    The query set is ordered by ``id`` and its row count is read once, before
    any page is handled. Page offsets are then visited in descending order.
    NOTE(review): presumably the back-to-front order exists because ``handler``
    may change rows so that they no longer match the query set's filter, which
    would shift the offsets of later pages when walking forwards - confirm.

    @param query_set: query set to page through
    @param page_size: number of rows fetched per page
    @param handler: callable invoked once per page with that page's slice
    @param is_the_task_interrupted: zero-argument callable checked before every
                                    page; a truthy result stops the processing
    @return: None
    """
    query = query_set.order_by("id")
    count = query_set.count()
    # reversed(range(...)) walks the page indexes lazily from the last page to
    # the first, without building and sorting an intermediate list.
    for i in reversed(range(ceil(count / page_size))):
        if is_the_task_interrupted():
            return
        offset = i * page_size
        paragraph_list = query.all()[offset: offset + page_size]
        handler(paragraph_list)
|
||||||
|
|||||||
@ -700,20 +700,24 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
_document.save()
|
_document.save()
|
||||||
return self.one()
|
return self.one()
|
||||||
|
|
||||||
@transaction.atomic
|
def refresh(self, state_list, with_valid=True):
|
||||||
def refresh(self, with_valid=True):
|
|
||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
document_id = self.data.get("document_id")
|
document_id = self.data.get("document_id")
|
||||||
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
|
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
|
||||||
State.PENDING)
|
State.PENDING)
|
||||||
ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id),
|
ListenerManagement.update_status(QuerySet(Paragraph).annotate(
|
||||||
|
reversed_status=Reverse('status'),
|
||||||
|
task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value,
|
||||||
|
1),
|
||||||
|
).filter(task_type_status__in=state_list, document_id=document_id)
|
||||||
|
.values('id'),
|
||||||
TaskType.EMBEDDING,
|
TaskType.EMBEDDING,
|
||||||
State.PENDING)
|
State.PENDING)
|
||||||
ListenerManagement.get_aggregation_document_status(document_id)()
|
ListenerManagement.get_aggregation_document_status(document_id)()
|
||||||
embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id'))
|
embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id'))
|
||||||
try:
|
try:
|
||||||
embedding_by_document.delay(document_id, embedding_model_id)
|
embedding_by_document.delay(document_id, embedding_model_id, state_list)
|
||||||
except AlreadyQueued as e:
|
except AlreadyQueued as e:
|
||||||
raise AppApiException(500, "任务正在执行中,请勿重复下发")
|
raise AppApiException(500, "任务正在执行中,请勿重复下发")
|
||||||
|
|
||||||
@ -1122,12 +1126,12 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
document_id_list = instance.get("id_list")
|
document_id_list = instance.get("id_list")
|
||||||
with transaction.atomic():
|
state_list = instance.get("state_list")
|
||||||
dataset_id = self.data.get('dataset_id')
|
dataset_id = self.data.get('dataset_id')
|
||||||
for document_id in document_id_list:
|
for document_id in document_id_list:
|
||||||
try:
|
try:
|
||||||
DocumentSerializers.Operate(
|
DocumentSerializers.Operate(
|
||||||
data={'dataset_id': dataset_id, 'document_id': document_id}).refresh()
|
data={'dataset_id': dataset_id, 'document_id': document_id}).refresh(state_list)
|
||||||
except AlreadyQueued as e:
|
except AlreadyQueued as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@ -51,3 +51,16 @@ class DocumentApi(ApiMixin):
|
|||||||
description="1|2|3 1:向量化|2:生成问题|3:同步文档", default=1)
|
description="1|2|3 1:向量化|2:生成问题|3:同步文档", default=1)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class EmbeddingState(ApiMixin):
    """Swagger description of a request body that carries a ``state_list``."""

    @staticmethod
    def get_request_body_api():
        """
        Build the OpenAPI schema of the request body.

        @return: object schema with a single ``state_list`` property, which is
                 an array of strings
        """
        state_item = openapi.Schema(type=openapi.TYPE_STRING)
        state_list = openapi.Schema(type=openapi.TYPE_ARRAY,
                                    items=state_item,
                                    title="状态列表",
                                    description="状态列表")
        return openapi.Schema(type=openapi.TYPE_OBJECT,
                              properties={'state_list': state_list})
|
||||||
|
|||||||
@ -262,6 +262,7 @@ class Document(APIView):
|
|||||||
@action(methods=['PUT'], detail=False)
|
@action(methods=['PUT'], detail=False)
|
||||||
@swagger_auto_schema(operation_summary="刷新文档向量库",
|
@swagger_auto_schema(operation_summary="刷新文档向量库",
|
||||||
operation_id="刷新文档向量库",
|
operation_id="刷新文档向量库",
|
||||||
|
request_body=DocumentApi.EmbeddingState.get_request_body_api(),
|
||||||
manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
|
manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
|
||||||
responses=result.get_default_response(),
|
responses=result.get_default_response(),
|
||||||
tags=["知识库/文档"]
|
tags=["知识库/文档"]
|
||||||
@ -272,6 +273,7 @@ class Document(APIView):
|
|||||||
def put(self, request: Request, dataset_id: str, document_id: str):
|
def put(self, request: Request, dataset_id: str, document_id: str):
|
||||||
return result.success(
|
return result.success(
|
||||||
DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).refresh(
|
DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).refresh(
|
||||||
|
request.data.get('state_list')
|
||||||
))
|
))
|
||||||
|
|
||||||
class BatchRefresh(APIView):
|
class BatchRefresh(APIView):
|
||||||
|
|||||||
@ -56,14 +56,20 @@ def embedding_by_paragraph_list(paragraph_id_list, model_id):
|
|||||||
|
|
||||||
|
|
||||||
@celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:embedding_by_document')
|
@celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:embedding_by_document')
|
||||||
def embedding_by_document(document_id, model_id):
|
def embedding_by_document(document_id, model_id, state_list=None):
|
||||||
"""
|
"""
|
||||||
向量化文档
|
向量化文档
|
||||||
|
@param state_list:
|
||||||
@param document_id: 文档id
|
@param document_id: 文档id
|
||||||
@param model_id 向量模型
|
@param model_id 向量模型
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if state_list is None:
|
||||||
|
state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
|
||||||
|
State.REVOKE.value,
|
||||||
|
State.REVOKED.value, State.IGNORED.value]
|
||||||
|
|
||||||
def exception_handler(e):
|
def exception_handler(e):
|
||||||
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
|
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
|
||||||
State.FAILURE)
|
State.FAILURE)
|
||||||
@ -71,7 +77,7 @@ def embedding_by_document(document_id, model_id):
|
|||||||
f'获取向量模型失败:{str(e)}{traceback.format_exc()}')
|
f'获取向量模型失败:{str(e)}{traceback.format_exc()}')
|
||||||
|
|
||||||
embedding_model = get_embedding_model(model_id, exception_handler)
|
embedding_model = get_embedding_model(model_id, exception_handler)
|
||||||
ListenerManagement.embedding_by_document(document_id, embedding_model)
|
ListenerManagement.embedding_by_document(document_id, embedding_model, state_list)
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name='celery:embedding_by_document_list')
|
@celery_app.task(name='celery:embedding_by_document_list')
|
||||||
|
|||||||
@ -129,11 +129,12 @@ const delMulDocument: (
|
|||||||
const batchRefresh: (
|
const batchRefresh: (
|
||||||
dataset_id: string,
|
dataset_id: string,
|
||||||
data: any,
|
data: any,
|
||||||
|
stateList: Array<string>,
|
||||||
loading?: Ref<boolean>
|
loading?: Ref<boolean>
|
||||||
) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
|
) => Promise<Result<boolean>> = (dataset_id, data, stateList, loading) => {
|
||||||
return put(
|
return put(
|
||||||
`${prefix}/${dataset_id}/document/batch_refresh`,
|
`${prefix}/${dataset_id}/document/batch_refresh`,
|
||||||
{ id_list: data },
|
{ id_list: data, state_list: stateList },
|
||||||
undefined,
|
undefined,
|
||||||
loading
|
loading
|
||||||
)
|
)
|
||||||
@ -157,11 +158,12 @@ const getDocumentDetail: (dataset_id: string, document_id: string) => Promise<Re
|
|||||||
const putDocumentRefresh: (
|
const putDocumentRefresh: (
|
||||||
dataset_id: string,
|
dataset_id: string,
|
||||||
document_id: string,
|
document_id: string,
|
||||||
|
state_list: Array<string>,
|
||||||
loading?: Ref<boolean>
|
loading?: Ref<boolean>
|
||||||
) => Promise<Result<any>> = (dataset_id, document_id, loading) => {
|
) => Promise<Result<any>> = (dataset_id, document_id, state_list, loading) => {
|
||||||
return put(
|
return put(
|
||||||
`${prefix}/${dataset_id}/document/${document_id}/refresh`,
|
`${prefix}/${dataset_id}/document/${document_id}/refresh`,
|
||||||
undefined,
|
{ state_list },
|
||||||
undefined,
|
undefined,
|
||||||
loading
|
loading
|
||||||
)
|
)
|
||||||
|
|||||||
41
ui/src/views/document/component/EmbeddingContentDialog.vue
Normal file
41
ui/src/views/document/component/EmbeddingContentDialog.vue
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
<template>
  <!-- Dialog that asks which paragraphs should be (re-)embedded before the request is sent. -->
  <el-dialog v-model="dialogVisible" title="选择向量化内容" width="500" :before-close="close">
    <el-radio-group v-model="state">
      <el-radio value="error" size="large">向量化未成功的分段</el-radio>
      <el-radio value="all" size="large">全部分段</el-radio>
    </el-radio-group>
    <template #footer>
      <div class="dialog-footer">
        <el-button @click="close">取消</el-button>
        <el-button type="primary" @click="submit"> 提交 </el-button>
      </div>
    </template>
  </el-dialog>
</template>
<script setup lang="ts">
import { ref } from 'vue'

type EmbeddingScope = 'all' | 'error'
type SubmitHandler = (stateList: Array<string>) => void

// Status codes handed to the submit handler for each radio option.
// 'error' is the same list as 'all' minus '2'
// (presumably the "success" status - confirm against the backend State enum).
const stateMap: Record<EmbeddingScope, Array<string>> = {
  all: ['0', '1', '2', '3', '4', '5', 'n'],
  error: ['0', '1', '3', '4', '5', 'n']
}

const dialogVisible = ref<boolean>(false)
const state = ref<EmbeddingScope>('error')
// Callback supplied by the caller of open(); cleared again whenever the dialog closes.
const submit_handle = ref<SubmitHandler>()

/** Show the dialog and remember the callback to run on submit. */
function open(handle: SubmitHandler) {
  submit_handle.value = handle
  dialogVisible.value = true
}

/** Hide the dialog and drop the pending callback. */
function close() {
  submit_handle.value = undefined
  dialogVisible.value = false
}

/** Pass the status list of the selected option to the callback (if any), then close. */
function submit() {
  submit_handle.value?.(stateMap[state.value])
  close()
}

defineExpose({ open, close })
</script>
<style lang="scss" scoped></style>
|
||||||
@ -422,6 +422,7 @@
|
|||||||
</el-text>
|
</el-text>
|
||||||
<el-button class="ml-16" type="primary" link @click="clearSelection"> 清空 </el-button>
|
<el-button class="ml-16" type="primary" link @click="clearSelection"> 清空 </el-button>
|
||||||
</div>
|
</div>
|
||||||
|
<EmbeddingContentDialog ref="embeddingContentDialogRef"></EmbeddingContentDialog>
|
||||||
</LayoutContainer>
|
</LayoutContainer>
|
||||||
</template>
|
</template>
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
@ -439,6 +440,7 @@ import { MsgSuccess, MsgConfirm, MsgError } from '@/utils/message'
|
|||||||
import useStore from '@/stores'
|
import useStore from '@/stores'
|
||||||
import StatusVlue from '@/views/document/component/Status.vue'
|
import StatusVlue from '@/views/document/component/Status.vue'
|
||||||
import GenerateRelatedDialog from '@/components/generate-related-dialog/index.vue'
|
import GenerateRelatedDialog from '@/components/generate-related-dialog/index.vue'
|
||||||
|
import EmbeddingContentDialog from '@/views/document/component/EmbeddingContentDialog.vue'
|
||||||
import { TaskType, State } from '@/utils/status'
|
import { TaskType, State } from '@/utils/status'
|
||||||
const router = useRouter()
|
const router = useRouter()
|
||||||
const route = useRoute()
|
const route = useRoute()
|
||||||
@ -469,7 +471,7 @@ onBeforeRouteLeave((to: any) => {
|
|||||||
})
|
})
|
||||||
const beforePagination = computed(() => common.paginationConfig[storeKey])
|
const beforePagination = computed(() => common.paginationConfig[storeKey])
|
||||||
const beforeSearch = computed(() => common.search[storeKey])
|
const beforeSearch = computed(() => common.search[storeKey])
|
||||||
|
const embeddingContentDialogRef = ref<InstanceType<typeof EmbeddingContentDialog>>()
|
||||||
const SyncWebDialogRef = ref()
|
const SyncWebDialogRef = ref()
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
let interval: any
|
let interval: any
|
||||||
@ -621,11 +623,15 @@ function syncDocument(row: any) {
|
|||||||
.catch(() => {})
|
.catch(() => {})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function refreshDocument(row: any) {
|
function refreshDocument(row: any) {
|
||||||
documentApi.putDocumentRefresh(row.dataset_id, row.id).then(() => {
|
const embeddingDocument = (stateList: Array<string>) => {
|
||||||
|
return documentApi.putDocumentRefresh(row.dataset_id, row.id, stateList).then(() => {
|
||||||
getList()
|
getList()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
embeddingContentDialogRef.value?.open(embeddingDocument)
|
||||||
|
}
|
||||||
|
|
||||||
function rowClickHandle(row: any, column: any) {
|
function rowClickHandle(row: any, column: any) {
|
||||||
if (column && column.type === 'selection') {
|
if (column && column.type === 'selection') {
|
||||||
@ -691,18 +697,15 @@ function deleteMulDocument() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function batchRefresh() {
|
function batchRefresh() {
|
||||||
const arr: string[] = []
|
const arr: string[] = multipleSelection.value.map((v) => v.id)
|
||||||
multipleSelection.value.map((v) => {
|
const embeddingBatchDocument = (stateList: Array<string>) => {
|
||||||
if (v) {
|
documentApi.batchRefresh(id, arr, stateList, loading).then(() => {
|
||||||
arr.push(v.id)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
documentApi.batchRefresh(id, arr, loading).then(() => {
|
|
||||||
MsgSuccess('批量向量化成功')
|
MsgSuccess('批量向量化成功')
|
||||||
multipleTableRef.value?.clearSelection()
|
multipleTableRef.value?.clearSelection()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
embeddingContentDialogRef.value?.open(embeddingBatchDocument)
|
||||||
|
}
|
||||||
|
|
||||||
function deleteDocument(row: any) {
|
function deleteDocument(row: any) {
|
||||||
MsgConfirm(
|
MsgConfirm(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user