fix: 【知识库】知识库上传 有关联问题的会阻塞 (#676)

This commit is contained in:
shaohuzhang1 2024-07-01 19:39:07 +08:00 committed by GitHub
parent 60181d6f83
commit fc6da6a484
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 89 additions and 71 deletions

View File

@ -7,6 +7,7 @@
@desc: @desc:
""" """
import os import os
import uuid
from typing import List from typing import List
from django.db.models import QuerySet from django.db.models import QuerySet
@ -20,7 +21,7 @@ from common.mixins.api_mixin import ApiMixin
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from common.util.fork import Fork from common.util.fork import Fork
from dataset.models import Paragraph from dataset.models import Paragraph, Problem, ProblemParagraphMapping
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
@ -79,3 +80,53 @@ class BatchSerializer(ApiMixin, serializers.Serializer):
description="主键id列表") description="主键id列表")
} }
) )
class ProblemParagraphObject:
def __init__(self, dataset_id: str, document_id: str, paragraph_id: str, problem_content: str):
self.dataset_id = dataset_id
self.document_id = document_id
self.paragraph_id = paragraph_id
self.problem_content = problem_content
def or_get(exists_problem_list, content, dataset_id, document_id, paragraph_id, problem_content_dict):
if content in problem_content_dict:
return problem_content_dict.get(content)[0], document_id, paragraph_id
exists = [row for row in exists_problem_list if row.content == content]
if len(exists) > 0:
problem_content_dict[content] = exists[0], False
return exists[0], document_id, paragraph_id
else:
problem = Problem(id=uuid.uuid1(), content=content, dataset_id=dataset_id)
problem_content_dict[content] = problem, True
return problem, document_id, paragraph_id
class ProblemParagraphManage:
def __init__(self, problemParagraphObjectList: [ProblemParagraphObject], dataset_id):
self.dataset_id = dataset_id
self.problemParagraphObjectList = problemParagraphObjectList
def to_problem_model_list(self):
problem_list = [item.problem_content for item in self.problemParagraphObjectList]
exists_problem_list = []
if len(self.problemParagraphObjectList) > 0:
# 查询到已存在的问题列表
exists_problem_list = QuerySet(Problem).filter(dataset_id=self.dataset_id,
content__in=problem_list).all()
problem_content_dict = {}
problem_model_list = [
or_get(exists_problem_list, problemParagraphObject.problem_content, problemParagraphObject.dataset_id,
problemParagraphObject.document_id, problemParagraphObject.paragraph_id, problem_content_dict) for
problemParagraphObject in self.problemParagraphObjectList]
problem_paragraph_mapping_list = [
ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id,
paragraph_id=paragraph_id,
dataset_id=self.dataset_id) for
problem_model, document_id, paragraph_id in problem_model_list]
result = [problem_model for problem_model, is_create in problem_content_dict.values() if
is_create], problem_paragraph_mapping_list
return result

View File

@ -37,7 +37,7 @@ from common.util.file_util import get_file_content
from common.util.fork import ChildLink, Fork from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from embedding.models import SearchMode from embedding.models import SearchMode
from setting.models import AuthOperate from setting.models import AuthOperate
@ -383,8 +383,7 @@ class DataSetSerializers(serializers.ModelSerializer):
document_model_list = [] document_model_list = []
paragraph_model_list = [] paragraph_model_list = []
problem_model_list = [] problem_paragraph_object_list = []
problem_paragraph_mapping_list = []
# 插入文档 # 插入文档
for document in instance.get('documents') if 'documents' in instance else []: for document in instance.get('documents') if 'documents' in instance else []:
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id, document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
@ -392,12 +391,12 @@ class DataSetSerializers(serializers.ModelSerializer):
document_model_list.append(document_paragraph_dict_model.get('document')) document_model_list.append(document_paragraph_dict_model.get('document'))
for paragraph in document_paragraph_dict_model.get('paragraph_model_list'): for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
paragraph_model_list.append(paragraph) paragraph_model_list.append(paragraph)
for problem in document_paragraph_dict_model.get('problem_model_list'): for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
problem_model_list.append(problem) problem_paragraph_object_list.append(problem_paragraph_object)
for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'):
problem_paragraph_mapping_list.append(problem_paragraph_mapping) problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
problem_model_list, problem_paragraph_mapping_list = DocumentSerializers.Create.reset_problem_model( dataset_id)
problem_model_list, problem_paragraph_mapping_list) .to_problem_model_list())
# 插入知识库 # 插入知识库
dataset.save() dataset.save()
# 插入文档 # 插入文档

View File

@ -41,7 +41,7 @@ from common.util.file_util import get_file_content
from common.util.fork import Fork from common.util.fork import Fork
from common.util.split_model import get_split_model from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
@ -380,8 +380,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
document_paragraph_model = DocumentSerializers.Create.get_paragraph_model(document, paragraphs) document_paragraph_model = DocumentSerializers.Create.get_paragraph_model(document, paragraphs)
paragraph_model_list = document_paragraph_model.get('paragraph_model_list') paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
problem_model_list = document_paragraph_model.get('problem_model_list') problem_paragraph_object_list = document_paragraph_model.get('problem_paragraph_object_list')
problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list') problem_model_list, problem_paragraph_mapping_list = ProblemParagraphManage(
problem_paragraph_object_list, document.dataset_id).to_problem_model_list()
# 批量插入段落 # 批量插入段落
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
# 批量插入问题 # 批量插入问题
@ -626,11 +627,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
dataset_id = self.data.get('dataset_id') dataset_id = self.data.get('dataset_id')
document_paragraph_model = self.get_document_paragraph_model(dataset_id, instance) document_paragraph_model = self.get_document_paragraph_model(dataset_id, instance)
document_model = document_paragraph_model.get('document') document_model = document_paragraph_model.get('document')
paragraph_model_list = document_paragraph_model.get('paragraph_model_list') paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
problem_model_list = document_paragraph_model.get('problem_model_list') problem_paragraph_object_list = document_paragraph_model.get('problem_paragraph_object_list')
problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list') problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
dataset_id)
.to_problem_model_list())
# 插入文档 # 插入文档
document_model.save() document_model.save()
# 批量插入段落 # 批量插入段落
@ -685,35 +688,15 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
dataset_id, document_model.id, paragraph) for paragraph in paragraph_list] dataset_id, document_model.id, paragraph) for paragraph in paragraph_list]
paragraph_model_list = [] paragraph_model_list = []
problem_model_list = [] problem_paragraph_object_list = []
problem_paragraph_mapping_list = []
for paragraphs in paragraph_model_dict_list: for paragraphs in paragraph_model_dict_list:
paragraph = paragraphs.get('paragraph') paragraph = paragraphs.get('paragraph')
for problem_model in paragraphs.get('problem_model_list'): for problem_model in paragraphs.get('problem_paragraph_object_list'):
problem_model_list.append(problem_model) problem_paragraph_object_list.append(problem_model)
for problem_paragraph_mapping in paragraphs.get('problem_paragraph_mapping_list'):
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
paragraph_model_list.append(paragraph) paragraph_model_list.append(paragraph)
problem_model_list, problem_paragraph_mapping_list = DocumentSerializers.Create.reset_problem_model(
problem_model_list, problem_paragraph_mapping_list)
return {'document': document_model, 'paragraph_model_list': paragraph_model_list, return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
'problem_model_list': problem_model_list, 'problem_paragraph_object_list': problem_paragraph_object_list}
'problem_paragraph_mapping_list': problem_paragraph_mapping_list}
@staticmethod
def reset_problem_model(problem_model_list, problem_paragraph_mapping_list):
new_problem_model_list = [x for i, x in enumerate(problem_model_list) if
len([item for item in problem_model_list[:i] if item.content == x.content]) <= 0]
for new_problem_model in new_problem_model_list:
old_model_list = [problem.id for problem in problem_model_list if
problem.content == new_problem_model.content]
for problem_paragraph_mapping in problem_paragraph_mapping_list:
if old_model_list.__contains__(problem_paragraph_mapping.problem_id):
problem_paragraph_mapping.problem_id = new_problem_model.id
return new_problem_model_list, problem_paragraph_mapping_list
@staticmethod @staticmethod
def get_document_paragraph_model(dataset_id, instance: Dict): def get_document_paragraph_model(dataset_id, instance: Dict):
@ -834,8 +817,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
dataset_id = self.data.get("dataset_id") dataset_id = self.data.get("dataset_id")
document_model_list = [] document_model_list = []
paragraph_model_list = [] paragraph_model_list = []
problem_model_list = [] problem_paragraph_object_list = []
problem_paragraph_mapping_list = []
# 插入文档 # 插入文档
for document in instance_list: for document in instance_list:
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id, document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
@ -843,11 +825,12 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
document_model_list.append(document_paragraph_dict_model.get('document')) document_model_list.append(document_paragraph_dict_model.get('document'))
for paragraph in document_paragraph_dict_model.get('paragraph_model_list'): for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
paragraph_model_list.append(paragraph) paragraph_model_list.append(paragraph)
for problem in document_paragraph_dict_model.get('problem_model_list'): for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
problem_model_list.append(problem) problem_paragraph_object_list.append(problem_paragraph_object)
for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'):
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
dataset_id)
.to_problem_model_list())
# 插入文档 # 插入文档
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
# 批量插入段落 # 批量插入段落

View File

@ -21,7 +21,8 @@ from common.mixins.api_mixin import ApiMixin
from common.util.common import post from common.util.common import post
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
ProblemParagraphManage
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
from embedding.models import SourceType from embedding.models import SourceType
@ -567,8 +568,10 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
document_id = self.data.get('document_id') document_id = self.data.get('document_id')
paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance) paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance)
paragraph = paragraph_problem_model.get('paragraph') paragraph = paragraph_problem_model.get('paragraph')
problem_model_list = paragraph_problem_model.get('problem_model_list') problem_paragraph_object_list = paragraph_problem_model.get('problem_paragraph_object_list')
problem_paragraph_mapping_list = paragraph_problem_model.get('problem_paragraph_mapping_list') problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
dataset_id).
to_problem_model_list())
# 插入段落 # 插入段落
paragraph_problem_model.get('paragraph').save() paragraph_problem_model.get('paragraph').save()
# 插入問題 # 插入問題
@ -591,30 +594,12 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
content=instance.get("content"), content=instance.get("content"),
dataset_id=dataset_id, dataset_id=dataset_id,
title=instance.get("title") if 'title' in instance else '') title=instance.get("title") if 'title' in instance else '')
problem_list = instance.get('problem_list') problem_paragraph_object_list = [
exists_problem_list = [] ProblemParagraphObject(dataset_id, document_id, paragraph.id, problem.get('content')) for problem in
if 'problem_list' in instance and len(problem_list) > 0: (instance.get('problem_list') if 'problem_list' in instance else [])]
exists_problem_list = QuerySet(Problem).filter(dataset_id=dataset_id,
content__in=[p.get('content') for p in
problem_list]).all()
problem_model_list = [
ParagraphSerializers.Create.or_get(exists_problem_list, problem.get('content'), dataset_id) for
problem in (
instance.get('problem_list') if 'problem_list' in instance else [])]
# 问题去重
problem_model_list = [x for i, x in enumerate(problem_model_list) if
len([item for item in problem_model_list[:i] if item.content == x.content]) <= 0]
problem_paragraph_mapping_list = [
ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id,
paragraph_id=paragraph.id,
dataset_id=dataset_id) for
problem_model in problem_model_list]
return {'paragraph': paragraph, return {'paragraph': paragraph,
'problem_model_list': [problem_model for problem_model in problem_model_list if 'problem_paragraph_object_list': problem_paragraph_object_list}
not list(exists_problem_list).__contains__(problem_model)],
'problem_paragraph_mapping_list': problem_paragraph_mapping_list}
@staticmethod @staticmethod
def or_get(exists_problem_list, content, dataset_id): def or_get(exists_problem_list, content, dataset_id):