fix: 【知识库】知识库上传 有关联问题的会阻塞 (#676)
This commit is contained in:
parent
60181d6f83
commit
fc6da6a484
@ -7,6 +7,7 @@
|
|||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
@ -20,7 +21,7 @@ from common.mixins.api_mixin import ApiMixin
|
|||||||
from common.util.field_message import ErrMessage
|
from common.util.field_message import ErrMessage
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from common.util.fork import Fork
|
from common.util.fork import Fork
|
||||||
from dataset.models import Paragraph
|
from dataset.models import Paragraph, Problem, ProblemParagraphMapping
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
@ -79,3 +80,53 @@ class BatchSerializer(ApiMixin, serializers.Serializer):
|
|||||||
description="主键id列表")
|
description="主键id列表")
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ProblemParagraphObject:
    """Value object pairing one problem content string with the dataset,
    document and paragraph it is associated with.

    Instances are accumulated into a batch and later resolved into Problem
    models and paragraph mappings by ProblemParagraphManage.
    """

    def __init__(self, dataset_id: str, document_id: str, paragraph_id: str, problem_content: str):
        # Store everything verbatim; no validation or normalisation happens here.
        self.dataset_id, self.document_id = dataset_id, document_id
        self.paragraph_id, self.problem_content = paragraph_id, problem_content
|
||||||
|
|
||||||
|
|
||||||
|
def or_get(exists_problem_list, content, dataset_id, document_id, paragraph_id, problem_content_dict):
    """Resolve which Problem model to use for *content*, de-duplicating across calls.

    Lookup order:
      1. ``problem_content_dict`` — problems already resolved earlier in this batch;
      2. ``exists_problem_list`` — problems already persisted for the dataset;
      3. otherwise a new (unsaved) ``Problem`` model is created.

    ``problem_content_dict`` maps content -> (problem, is_create) and is mutated
    in place, so later calls with the same content reuse the same model instance.

    :return: ``(problem, document_id, paragraph_id)`` tuple.
    """
    # Single lookup instead of `in` check followed by a second .get().
    cached = problem_content_dict.get(content)
    if cached is not None:
        return cached[0], document_id, paragraph_id
    # Stop at the first match instead of materialising the whole filtered list.
    existing = next((row for row in exists_problem_list if row.content == content), None)
    if existing is not None:
        problem_content_dict[content] = existing, False
        return existing, document_id, paragraph_id
    problem = Problem(id=uuid.uuid1(), content=content, dataset_id=dataset_id)
    problem_content_dict[content] = problem, True
    return problem, document_id, paragraph_id
|
||||||
|
|
||||||
|
|
||||||
|
class ProblemParagraphManage:
    """Turns a batch of ProblemParagraphObject entries into de-duplicated
    Problem models plus ProblemParagraphMapping rows, so repeated problem
    contents within an upload share a single Problem record instead of
    blocking on duplicate inserts.
    """

    def __init__(self, problemParagraphObjectList: List['ProblemParagraphObject'], dataset_id):
        """
        :param problemParagraphObjectList: batch of problem/paragraph associations
        :param dataset_id: dataset every mapping row is attached to
        """
        self.dataset_id = dataset_id
        self.problemParagraphObjectList = problemParagraphObjectList

    def to_problem_model_list(self):
        """Build the models to persist.

        :return: tuple ``(problem_model_list, problem_paragraph_mapping_list)``
            where ``problem_model_list`` contains only problems that do not yet
            exist in the database (``is_create`` is True), and the mapping list
            links every paragraph to its (new or existing) problem.
        """
        contents = [item.problem_content for item in self.problemParagraphObjectList]
        exists_problem_list = []
        if contents:
            # Fetch every already-stored problem for this dataset in one query.
            exists_problem_list = QuerySet(Problem).filter(dataset_id=self.dataset_id,
                                                           content__in=contents).all()
        problem_content_dict = {}
        problem_model_list = [
            or_get(exists_problem_list, item.problem_content, item.dataset_id,
                   item.document_id, item.paragraph_id, problem_content_dict)
            for item in self.problemParagraphObjectList]

        problem_paragraph_mapping_list = [
            ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id,
                                    paragraph_id=paragraph_id, dataset_id=self.dataset_id)
            for problem_model, document_id, paragraph_id in problem_model_list]

        # Only brand-new problems (is_create=True) need to be inserted by the caller.
        result = [problem_model for problem_model, is_create in problem_content_dict.values()
                  if is_create], problem_paragraph_mapping_list
        return result
|
||||||
|
|||||||
@ -37,7 +37,7 @@ from common.util.file_util import get_file_content
|
|||||||
from common.util.fork import ChildLink, Fork
|
from common.util.fork import ChildLink, Fork
|
||||||
from common.util.split_model import get_split_model
|
from common.util.split_model import get_split_model
|
||||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
|
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
|
||||||
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer
|
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage
|
||||||
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
|
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
|
||||||
from embedding.models import SearchMode
|
from embedding.models import SearchMode
|
||||||
from setting.models import AuthOperate
|
from setting.models import AuthOperate
|
||||||
@ -383,8 +383,7 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
|
|
||||||
document_model_list = []
|
document_model_list = []
|
||||||
paragraph_model_list = []
|
paragraph_model_list = []
|
||||||
problem_model_list = []
|
problem_paragraph_object_list = []
|
||||||
problem_paragraph_mapping_list = []
|
|
||||||
# 插入文档
|
# 插入文档
|
||||||
for document in instance.get('documents') if 'documents' in instance else []:
|
for document in instance.get('documents') if 'documents' in instance else []:
|
||||||
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
|
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
|
||||||
@ -392,12 +391,12 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
document_model_list.append(document_paragraph_dict_model.get('document'))
|
document_model_list.append(document_paragraph_dict_model.get('document'))
|
||||||
for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
|
for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
|
||||||
paragraph_model_list.append(paragraph)
|
paragraph_model_list.append(paragraph)
|
||||||
for problem in document_paragraph_dict_model.get('problem_model_list'):
|
for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
|
||||||
problem_model_list.append(problem)
|
problem_paragraph_object_list.append(problem_paragraph_object)
|
||||||
for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'):
|
|
||||||
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
|
problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
|
||||||
problem_model_list, problem_paragraph_mapping_list = DocumentSerializers.Create.reset_problem_model(
|
dataset_id)
|
||||||
problem_model_list, problem_paragraph_mapping_list)
|
.to_problem_model_list())
|
||||||
# 插入知识库
|
# 插入知识库
|
||||||
dataset.save()
|
dataset.save()
|
||||||
# 插入文档
|
# 插入文档
|
||||||
|
|||||||
@ -41,7 +41,7 @@ from common.util.file_util import get_file_content
|
|||||||
from common.util.fork import Fork
|
from common.util.fork import Fork
|
||||||
from common.util.split_model import get_split_model
|
from common.util.split_model import get_split_model
|
||||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
|
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
|
||||||
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer
|
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage
|
||||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
|
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
@ -380,8 +380,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
document_paragraph_model = DocumentSerializers.Create.get_paragraph_model(document, paragraphs)
|
document_paragraph_model = DocumentSerializers.Create.get_paragraph_model(document, paragraphs)
|
||||||
|
|
||||||
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
|
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
|
||||||
problem_model_list = document_paragraph_model.get('problem_model_list')
|
problem_paragraph_object_list = document_paragraph_model.get('problem_paragraph_object_list')
|
||||||
problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list')
|
problem_model_list, problem_paragraph_mapping_list = ProblemParagraphManage(
|
||||||
|
problem_paragraph_object_list, document.dataset_id).to_problem_model_list()
|
||||||
# 批量插入段落
|
# 批量插入段落
|
||||||
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
||||||
# 批量插入问题
|
# 批量插入问题
|
||||||
@ -626,11 +627,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
dataset_id = self.data.get('dataset_id')
|
dataset_id = self.data.get('dataset_id')
|
||||||
document_paragraph_model = self.get_document_paragraph_model(dataset_id, instance)
|
document_paragraph_model = self.get_document_paragraph_model(dataset_id, instance)
|
||||||
|
|
||||||
document_model = document_paragraph_model.get('document')
|
document_model = document_paragraph_model.get('document')
|
||||||
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
|
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
|
||||||
problem_model_list = document_paragraph_model.get('problem_model_list')
|
problem_paragraph_object_list = document_paragraph_model.get('problem_paragraph_object_list')
|
||||||
problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list')
|
problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
|
||||||
|
dataset_id)
|
||||||
|
.to_problem_model_list())
|
||||||
# 插入文档
|
# 插入文档
|
||||||
document_model.save()
|
document_model.save()
|
||||||
# 批量插入段落
|
# 批量插入段落
|
||||||
@ -685,35 +688,15 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
dataset_id, document_model.id, paragraph) for paragraph in paragraph_list]
|
dataset_id, document_model.id, paragraph) for paragraph in paragraph_list]
|
||||||
|
|
||||||
paragraph_model_list = []
|
paragraph_model_list = []
|
||||||
problem_model_list = []
|
problem_paragraph_object_list = []
|
||||||
problem_paragraph_mapping_list = []
|
|
||||||
for paragraphs in paragraph_model_dict_list:
|
for paragraphs in paragraph_model_dict_list:
|
||||||
paragraph = paragraphs.get('paragraph')
|
paragraph = paragraphs.get('paragraph')
|
||||||
for problem_model in paragraphs.get('problem_model_list'):
|
for problem_model in paragraphs.get('problem_paragraph_object_list'):
|
||||||
problem_model_list.append(problem_model)
|
problem_paragraph_object_list.append(problem_model)
|
||||||
for problem_paragraph_mapping in paragraphs.get('problem_paragraph_mapping_list'):
|
|
||||||
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
|
|
||||||
paragraph_model_list.append(paragraph)
|
paragraph_model_list.append(paragraph)
|
||||||
|
|
||||||
problem_model_list, problem_paragraph_mapping_list = DocumentSerializers.Create.reset_problem_model(
|
|
||||||
problem_model_list, problem_paragraph_mapping_list)
|
|
||||||
|
|
||||||
return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
|
return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
|
||||||
'problem_model_list': problem_model_list,
|
'problem_paragraph_object_list': problem_paragraph_object_list}
|
||||||
'problem_paragraph_mapping_list': problem_paragraph_mapping_list}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def reset_problem_model(problem_model_list, problem_paragraph_mapping_list):
|
|
||||||
new_problem_model_list = [x for i, x in enumerate(problem_model_list) if
|
|
||||||
len([item for item in problem_model_list[:i] if item.content == x.content]) <= 0]
|
|
||||||
|
|
||||||
for new_problem_model in new_problem_model_list:
|
|
||||||
old_model_list = [problem.id for problem in problem_model_list if
|
|
||||||
problem.content == new_problem_model.content]
|
|
||||||
for problem_paragraph_mapping in problem_paragraph_mapping_list:
|
|
||||||
if old_model_list.__contains__(problem_paragraph_mapping.problem_id):
|
|
||||||
problem_paragraph_mapping.problem_id = new_problem_model.id
|
|
||||||
return new_problem_model_list, problem_paragraph_mapping_list
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_document_paragraph_model(dataset_id, instance: Dict):
|
def get_document_paragraph_model(dataset_id, instance: Dict):
|
||||||
@ -834,8 +817,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
dataset_id = self.data.get("dataset_id")
|
dataset_id = self.data.get("dataset_id")
|
||||||
document_model_list = []
|
document_model_list = []
|
||||||
paragraph_model_list = []
|
paragraph_model_list = []
|
||||||
problem_model_list = []
|
problem_paragraph_object_list = []
|
||||||
problem_paragraph_mapping_list = []
|
|
||||||
# 插入文档
|
# 插入文档
|
||||||
for document in instance_list:
|
for document in instance_list:
|
||||||
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
|
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
|
||||||
@ -843,11 +825,12 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
document_model_list.append(document_paragraph_dict_model.get('document'))
|
document_model_list.append(document_paragraph_dict_model.get('document'))
|
||||||
for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
|
for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
|
||||||
paragraph_model_list.append(paragraph)
|
paragraph_model_list.append(paragraph)
|
||||||
for problem in document_paragraph_dict_model.get('problem_model_list'):
|
for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
|
||||||
problem_model_list.append(problem)
|
problem_paragraph_object_list.append(problem_paragraph_object)
|
||||||
for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'):
|
|
||||||
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
|
|
||||||
|
|
||||||
|
problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
|
||||||
|
dataset_id)
|
||||||
|
.to_problem_model_list())
|
||||||
# 插入文档
|
# 插入文档
|
||||||
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
|
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
|
||||||
# 批量插入段落
|
# 批量插入段落
|
||||||
|
|||||||
@ -21,7 +21,8 @@ from common.mixins.api_mixin import ApiMixin
|
|||||||
from common.util.common import post
|
from common.util.common import post
|
||||||
from common.util.field_message import ErrMessage
|
from common.util.field_message import ErrMessage
|
||||||
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
|
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
|
||||||
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer
|
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
|
||||||
|
ProblemParagraphManage
|
||||||
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
|
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
|
||||||
from embedding.models import SourceType
|
from embedding.models import SourceType
|
||||||
|
|
||||||
@ -567,8 +568,10 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
document_id = self.data.get('document_id')
|
document_id = self.data.get('document_id')
|
||||||
paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance)
|
paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance)
|
||||||
paragraph = paragraph_problem_model.get('paragraph')
|
paragraph = paragraph_problem_model.get('paragraph')
|
||||||
problem_model_list = paragraph_problem_model.get('problem_model_list')
|
problem_paragraph_object_list = paragraph_problem_model.get('problem_paragraph_object_list')
|
||||||
problem_paragraph_mapping_list = paragraph_problem_model.get('problem_paragraph_mapping_list')
|
problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
|
||||||
|
dataset_id).
|
||||||
|
to_problem_model_list())
|
||||||
# 插入段落
|
# 插入段落
|
||||||
paragraph_problem_model.get('paragraph').save()
|
paragraph_problem_model.get('paragraph').save()
|
||||||
# 插入問題
|
# 插入問題
|
||||||
@ -591,30 +594,12 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
content=instance.get("content"),
|
content=instance.get("content"),
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
title=instance.get("title") if 'title' in instance else '')
|
title=instance.get("title") if 'title' in instance else '')
|
||||||
problem_list = instance.get('problem_list')
|
problem_paragraph_object_list = [
|
||||||
exists_problem_list = []
|
ProblemParagraphObject(dataset_id, document_id, paragraph.id, problem.get('content')) for problem in
|
||||||
if 'problem_list' in instance and len(problem_list) > 0:
|
(instance.get('problem_list') if 'problem_list' in instance else [])]
|
||||||
exists_problem_list = QuerySet(Problem).filter(dataset_id=dataset_id,
|
|
||||||
content__in=[p.get('content') for p in
|
|
||||||
problem_list]).all()
|
|
||||||
|
|
||||||
problem_model_list = [
|
|
||||||
ParagraphSerializers.Create.or_get(exists_problem_list, problem.get('content'), dataset_id) for
|
|
||||||
problem in (
|
|
||||||
instance.get('problem_list') if 'problem_list' in instance else [])]
|
|
||||||
# 问题去重
|
|
||||||
problem_model_list = [x for i, x in enumerate(problem_model_list) if
|
|
||||||
len([item for item in problem_model_list[:i] if item.content == x.content]) <= 0]
|
|
||||||
|
|
||||||
problem_paragraph_mapping_list = [
|
|
||||||
ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id,
|
|
||||||
paragraph_id=paragraph.id,
|
|
||||||
dataset_id=dataset_id) for
|
|
||||||
problem_model in problem_model_list]
|
|
||||||
return {'paragraph': paragraph,
|
return {'paragraph': paragraph,
|
||||||
'problem_model_list': [problem_model for problem_model in problem_model_list if
|
'problem_paragraph_object_list': problem_paragraph_object_list}
|
||||||
not list(exists_problem_list).__contains__(problem_model)],
|
|
||||||
'problem_paragraph_mapping_list': problem_paragraph_mapping_list}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def or_get(exists_problem_list, content, dataset_id):
|
def or_get(exists_problem_list, content, dataset_id):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user