feat: 优化创建数据集文档段落
This commit is contained in:
parent
5488804ea8
commit
019e133f0f
@ -212,12 +212,31 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
dataset_id = uuid.uuid1()
|
dataset_id = uuid.uuid1()
|
||||||
dataset = DataSet(
|
dataset = DataSet(
|
||||||
**{'id': dataset_id, 'name': self.data.get("name"), 'desc': self.data.get('desc'), 'user': user})
|
**{'id': dataset_id, 'name': self.data.get("name"), 'desc': self.data.get('desc'), 'user': user})
|
||||||
|
|
||||||
|
document_model_list = []
|
||||||
|
paragraph_model_list = []
|
||||||
|
problem_model_list = []
|
||||||
|
# 插入文档
|
||||||
|
for document in self.data.get('documents') if 'documents' in self.data else []:
|
||||||
|
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
|
||||||
|
document)
|
||||||
|
document_model_list.append(document_paragraph_dict_model.get('document'))
|
||||||
|
for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
|
||||||
|
paragraph_model_list.append(paragraph)
|
||||||
|
for problem in document_paragraph_dict_model.get('problem_model_list'):
|
||||||
|
problem_model_list.append(problem)
|
||||||
|
|
||||||
# 插入数据集
|
# 插入数据集
|
||||||
dataset.save()
|
dataset.save()
|
||||||
for document in self.data.get('documents') if 'documents' in self.data else []:
|
# 插入文档
|
||||||
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(document, with_valid=True,
|
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
|
||||||
with_embedding=False)
|
# 批量插入段落
|
||||||
|
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
||||||
|
# 批量插入问题
|
||||||
|
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||||
|
# 发送向量化事件
|
||||||
ListenerManagement.embedding_by_dataset_signal.send(str(dataset.id))
|
ListenerManagement.embedding_by_dataset_signal.send(str(dataset.id))
|
||||||
|
# 响应数据
|
||||||
return {**DataSetSerializers(dataset).data,
|
return {**DataSetSerializers(dataset).data,
|
||||||
'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(with_valid=True)}
|
'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(with_valid=True)}
|
||||||
|
|
||||||
|
|||||||
@ -37,7 +37,7 @@ class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
|
|||||||
message="数据集名称在1-128个字符之间")
|
message="数据集名称在1-128个字符之间")
|
||||||
])
|
])
|
||||||
|
|
||||||
paragraphs = ParagraphInstanceSerializer(required=False, many=True)
|
paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_request_body_api():
|
def get_request_body_api():
|
||||||
@ -204,7 +204,24 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True)
|
DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True)
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
dataset_id = self.data.get('dataset_id')
|
dataset_id = self.data.get('dataset_id')
|
||||||
|
document_paragraph_model = self.get_document_paragraph_model(dataset_id, instance)
|
||||||
|
document_model = document_paragraph_model.get('document')
|
||||||
|
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
|
||||||
|
problem_model_list = document_paragraph_model.get('problem_model_list')
|
||||||
|
# 插入文档
|
||||||
|
document_model.save()
|
||||||
|
# 批量插入段落
|
||||||
|
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
||||||
|
# 批量插入问题
|
||||||
|
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||||
|
if with_embedding:
|
||||||
|
ListenerManagement.embedding_by_document_signal.send(str(document_model.id))
|
||||||
|
return DocumentSerializers.Operate(
|
||||||
|
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).one(
|
||||||
|
with_valid=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_document_paragraph_model(dataset_id, instance: Dict):
|
||||||
document_model = Document(
|
document_model = Document(
|
||||||
**{'dataset_id': dataset_id,
|
**{'dataset_id': dataset_id,
|
||||||
'id': uuid.uuid1(),
|
'id': uuid.uuid1(),
|
||||||
@ -212,19 +229,22 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
'char_length': reduce(lambda x, y: x + y,
|
'char_length': reduce(lambda x, y: x + y,
|
||||||
[len(p.get('content')) for p in instance.get('paragraphs', [])],
|
[len(p.get('content')) for p in instance.get('paragraphs', [])],
|
||||||
0)})
|
0)})
|
||||||
# 插入文档
|
|
||||||
document_model.save()
|
|
||||||
|
|
||||||
for paragraph in instance.get('paragraphs') if 'paragraphs' in instance else []:
|
paragraph_model_dict_list = [ParagraphSerializers.Create(
|
||||||
ParagraphSerializers.Create(
|
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).get_paragraph_problem_model(
|
||||||
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).save(paragraph,
|
dataset_id, document_model.id, paragraph) for paragraph in (instance.get('paragraphs') if
|
||||||
with_valid=True,
|
'paragraphs' in instance else [])]
|
||||||
with_embedding=False)
|
|
||||||
if with_embedding:
|
paragraph_model_list = []
|
||||||
ListenerManagement.embedding_by_document_signal.send(str(document_model.id))
|
problem_model_list = []
|
||||||
return DocumentSerializers.Operate(
|
for paragraphs in paragraph_model_dict_list:
|
||||||
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).one(
|
paragraph = paragraphs.get('paragraph')
|
||||||
with_valid=True)
|
for problem_model in paragraphs.get('problem_model_list'):
|
||||||
|
problem_model_list.append(problem_model)
|
||||||
|
paragraph_model_list.append(paragraph)
|
||||||
|
|
||||||
|
return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
|
||||||
|
'problem_model_list': problem_model_list}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_request_body_api():
|
def get_request_body_api():
|
||||||
|
|||||||
@ -39,8 +39,8 @@ class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
|
|||||||
validators.MaxLengthValidator(limit_value=1024,
|
validators.MaxLengthValidator(limit_value=1024,
|
||||||
message="段落在1-1024个字符之间"),
|
message="段落在1-1024个字符之间"),
|
||||||
validators.MinLengthValidator(limit_value=1,
|
validators.MinLengthValidator(limit_value=1,
|
||||||
message="段落在1-1024个字符之间")
|
message="段落在1-1024个字符之间"),
|
||||||
])
|
], allow_null=True, allow_blank=True)
|
||||||
|
|
||||||
title = serializers.CharField(required=False, allow_null=True, allow_blank=True)
|
title = serializers.CharField(required=False, allow_null=True, allow_blank=True)
|
||||||
|
|
||||||
@ -179,17 +179,11 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
self.is_valid()
|
self.is_valid()
|
||||||
dataset_id = self.data.get("dataset_id")
|
dataset_id = self.data.get("dataset_id")
|
||||||
document_id = self.data.get('document_id')
|
document_id = self.data.get('document_id')
|
||||||
|
paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance)
|
||||||
paragraph = Paragraph(id=uuid.uuid1(),
|
paragraph = paragraph_problem_model.get('paragraph')
|
||||||
document_id=document_id,
|
problem_model_list = paragraph_problem_model.get('problem_model_list')
|
||||||
content=instance.get("content"),
|
|
||||||
dataset_id=dataset_id,
|
|
||||||
title=instance.get("title") if 'title' in instance else '')
|
|
||||||
# 插入段落
|
# 插入段落
|
||||||
paragraph.save()
|
paragraph_problem_model.get('paragraph').save()
|
||||||
problem_model_list = [Problem(id=uuid.uuid1(), content=problem.get('content'), paragraph_id=paragraph.id,
|
|
||||||
document_id=document_id, dataset_id=dataset_id) for problem in (
|
|
||||||
instance.get('problem_list') if 'problem_list' in instance else [])]
|
|
||||||
# 插入問題
|
# 插入問題
|
||||||
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||||
# 修改长度
|
# 修改长度
|
||||||
@ -200,6 +194,20 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
|
data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
|
||||||
with_valid=True)
|
with_valid=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_paragraph_problem_model(dataset_id: str, document_id: str, instance: Dict):
|
||||||
|
paragraph = Paragraph(id=uuid.uuid1(),
|
||||||
|
document_id=document_id,
|
||||||
|
content=instance.get("content"),
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
title=instance.get("title") if 'title' in instance else '')
|
||||||
|
|
||||||
|
problem_model_list = [Problem(id=uuid.uuid1(), content=problem.get('content'), paragraph_id=paragraph.id,
|
||||||
|
document_id=document_id, dataset_id=dataset_id) for problem in (
|
||||||
|
instance.get('problem_list') if 'problem_list' in instance else [])]
|
||||||
|
|
||||||
|
return {'paragraph': paragraph, 'problem_model_list': problem_model_list}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_request_body_api():
|
def get_request_body_api():
|
||||||
return ParagraphInstanceSerializer.get_request_body_api()
|
return ParagraphInstanceSerializer.get_request_body_api()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user