feat: 添加问题管理相关接口,兼容历史版本
This commit is contained in:
parent
1691e56da5
commit
b470b1b6e5
@ -21,7 +21,7 @@ from common.event.common import poxy, embedding_poxy
|
|||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from common.util.fork import ForkManage, Fork
|
from common.util.fork import ForkManage, Fork
|
||||||
from common.util.lock import try_lock, un_lock
|
from common.util.lock import try_lock, un_lock
|
||||||
from dataset.models import Paragraph, Status, Document
|
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping
|
||||||
from embedding.models import SourceType
|
from embedding.models import SourceType
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
@ -44,6 +44,12 @@ class SyncWebDocumentArgs:
|
|||||||
self.handler = handler
|
self.handler = handler
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateProblemArgs:
    """Argument bundle passed to ``ListenerManagement.update_problem``.

    Attributes:
        problem_id: id of the problem whose embedding rows should be refreshed.
        problem_content: text that gets embedded for that problem.
    """

    def __init__(self, problem_id: str, problem_content: str):
        # Plain value holder: the two values are stored as given, unvalidated.
        self.problem_content = problem_content
        self.problem_id = problem_id
|
||||||
|
|
||||||
|
|
||||||
class ListenerManagement:
|
class ListenerManagement:
|
||||||
embedding_by_problem_signal = signal("embedding_by_problem")
|
embedding_by_problem_signal = signal("embedding_by_problem")
|
||||||
embedding_by_paragraph_signal = signal("embedding_by_paragraph")
|
embedding_by_paragraph_signal = signal("embedding_by_paragraph")
|
||||||
@ -59,6 +65,8 @@ class ListenerManagement:
|
|||||||
init_embedding_model_signal = signal('init_embedding_model')
|
init_embedding_model_signal = signal('init_embedding_model')
|
||||||
sync_web_dataset_signal = signal('sync_web_dataset')
|
sync_web_dataset_signal = signal('sync_web_dataset')
|
||||||
sync_web_document_signal = signal('sync_web_document')
|
sync_web_document_signal = signal('sync_web_document')
|
||||||
|
update_problem_signal = signal('update_problem')
|
||||||
|
delete_embedding_by_source_ids_signal = signal('delete_embedding_by_source_ids')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def embedding_by_problem(args):
|
def embedding_by_problem(args):
|
||||||
@ -76,8 +84,8 @@ class ListenerManagement:
|
|||||||
status = Status.success
|
status = Status.success
|
||||||
try:
|
try:
|
||||||
data_list = native_search(
|
data_list = native_search(
|
||||||
{'problem': QuerySet(get_dynamics_model({'problem.paragraph_id': django.db.models.CharField()})).filter(
|
{'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
|
||||||
**{'problem.paragraph_id': paragraph_id}),
|
**{'paragraph.id': paragraph_id}),
|
||||||
'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)},
|
'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)},
|
||||||
select_string=get_file_content(
|
select_string=get_file_content(
|
||||||
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
|
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
|
||||||
@ -104,8 +112,9 @@ class ListenerManagement:
|
|||||||
status = Status.success
|
status = Status.success
|
||||||
try:
|
try:
|
||||||
data_list = native_search(
|
data_list = native_search(
|
||||||
{'problem': QuerySet(get_dynamics_model({'problem.document_id': django.db.models.CharField()})).filter(
|
{'problem': QuerySet(
|
||||||
**{'problem.document_id': document_id}),
|
get_dynamics_model({'paragraph.document_id': django.db.models.CharField()})).filter(
|
||||||
|
**{'paragraph.document_id': document_id}),
|
||||||
'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
|
'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
|
||||||
select_string=get_file_content(
|
select_string=get_file_content(
|
||||||
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
|
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
|
||||||
@ -188,6 +197,17 @@ class ListenerManagement:
|
|||||||
finally:
|
finally:
|
||||||
un_lock('sync_web_dataset' + args.lock_key)
|
un_lock('sync_web_dataset' + args.lock_key)
|
||||||
|
|
||||||
|
@staticmethod
def update_problem(args: UpdateProblemArgs):
    """Recompute a problem's embedding and write it to its vector rows.

    Vector rows are addressed by the ids of the ``ProblemParagraphMapping``
    rows that reference ``args.problem_id``; each one receives the embedding
    of ``args.problem_content``.

    :param args: problem id and the text to embed
    """
    mappings = QuerySet(ProblemParagraphMapping).filter(problem_id=args.problem_id)
    new_vector = VectorStore.get_embedding_vector().embed_query(args.problem_content)
    vector_store = VectorStore.get_embedding_vector()
    # The mapping id (not the problem id) is the source id of a problem embedding.
    mapping_ids = [mapping.id for mapping in mappings]
    vector_store.update_by_source_ids(mapping_ids, {'embedding': new_vector})
|
||||||
|
|
||||||
|
@staticmethod
def delete_embedding_by_source_ids(source_ids: List[str]):
    """Delete embedding rows for the given source ids.

    The deletion is issued with ``SourceType.PROBLEM`` as the source type.

    :param source_ids: source ids handed to the vector store
    """
    vector_store = VectorStore.get_embedding_vector()
    vector_store.delete_by_source_ids(source_ids, SourceType.PROBLEM)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@poxy
|
@poxy
|
||||||
def init_embedding_model(ags):
|
def init_embedding_model(ags):
|
||||||
@ -225,3 +245,6 @@ class ListenerManagement:
|
|||||||
ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
|
ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
|
||||||
# 同步web站点 文档
|
# 同步web站点 文档
|
||||||
ListenerManagement.sync_web_document_signal.connect(self.sync_web_document)
|
ListenerManagement.sync_web_document_signal.connect(self.sync_web_document)
|
||||||
|
# 更新问题向量
|
||||||
|
ListenerManagement.update_problem_signal.connect(self.update_problem)
|
||||||
|
ListenerManagement.delete_embedding_by_source_ids_signal.connect(self.delete_embedding_by_source_ids)
|
||||||
|
|||||||
@ -1,14 +1,15 @@
|
|||||||
SELECT
|
SELECT
|
||||||
problem."id" AS "source_id",
|
problem_paragraph_mapping."id" AS "source_id",
|
||||||
problem.document_id AS document_id,
|
paragraph.document_id AS document_id,
|
||||||
problem.paragraph_id AS paragraph_id,
|
paragraph."id" AS paragraph_id,
|
||||||
problem.dataset_id AS dataset_id,
|
problem.dataset_id AS dataset_id,
|
||||||
0 AS source_type,
|
0 AS source_type,
|
||||||
problem."content" AS "text",
|
problem."content" AS "text",
|
||||||
paragraph.is_active AS is_active
|
paragraph.is_active AS is_active
|
||||||
FROM
|
FROM
|
||||||
problem problem
|
problem problem
|
||||||
LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id
|
LEFT JOIN problem_paragraph_mapping problem_paragraph_mapping ON problem_paragraph_mapping.problem_id=problem."id"
|
||||||
|
LEFT JOIN paragraph paragraph ON paragraph."id" = problem_paragraph_mapping.paragraph_id
|
||||||
${problem}
|
${problem}
|
||||||
|
|
||||||
UNION
|
UNION
|
||||||
|
|||||||
@ -0,0 +1,59 @@
|
|||||||
|
# Generated by Django 4.1.10 on 2024-03-08 18:29
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
import django.db.models.deletion
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from embedding.models import SourceType
|
||||||
|
|
||||||
|
|
||||||
|
def delete_problem_embedding(apps, schema_editor):
    """Data-migration step: delete every embedding row of source type PROBLEM.

    :param apps: historical app registry supplied by ``RunPython``
    :param schema_editor: schema editor supplied by ``RunPython`` (unused)
    """
    # NOTE(review): this looks up a model from the ``embedding`` app, but the
    # migration lists no dependency on that app's migrations - confirm the
    # model is always available when this runs.
    embedding_model = apps.get_model('embedding', 'Embedding')
    problem_rows = embedding_model.objects.filter(source_type=SourceType.PROBLEM)
    problem_rows.delete()
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
    """Move the problem/paragraph link from columns on ``problem`` to a mapping table.

    Drops ``problem.document`` and ``problem.paragraph``, adds ``hit_num`` to
    ``paragraph`` and ``problem``, creates ``problem_paragraph_mapping`` and
    finally deletes the existing problem embeddings.
    """

    dependencies = [
        ('dataset', '0004_remove_paragraph_hit_num_remove_paragraph_star_num_and_more'),
    ]

    operations = [
        # NOTE(review): the two link columns are dropped before the mapping
        # table exists and nothing here copies the old links into it -
        # confirm that discarding existing associations is intended.
        migrations.RemoveField(model_name='problem', name='document'),
        migrations.RemoveField(model_name='problem', name='paragraph'),
        migrations.AddField(
            model_name='paragraph',
            name='hit_num',
            field=models.IntegerField(default=0, verbose_name='命中次数'),
        ),
        migrations.AddField(
            model_name='problem',
            name='hit_num',
            field=models.IntegerField(default=0, verbose_name='命中次数'),
        ),
        migrations.CreateModel(
            name='ProblemParagraphMapping',
            fields=[
                ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
                ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
                ('id', models.UUIDField(
                    default=uuid.uuid1,
                    editable=False,
                    primary_key=True,
                    serialize=False,
                    verbose_name='主键id',
                )),
                ('dataset', models.ForeignKey(
                    db_constraint=False,
                    on_delete=django.db.models.deletion.DO_NOTHING,
                    to='dataset.dataset',
                )),
                # Unlike its siblings, this foreign key keeps its database constraint.
                ('document', models.ForeignKey(
                    on_delete=django.db.models.deletion.DO_NOTHING,
                    to='dataset.document',
                )),
                ('paragraph', models.ForeignKey(
                    db_constraint=False,
                    on_delete=django.db.models.deletion.DO_NOTHING,
                    to='dataset.paragraph',
                )),
                ('problem', models.ForeignKey(
                    db_constraint=False,
                    on_delete=django.db.models.deletion.DO_NOTHING,
                    to='dataset.problem',
                )),
            ],
            options={
                'db_table': 'problem_paragraph_mapping',
            },
        ),
        # No reverse_code is supplied for this step.
        migrations.RunPython(delete_problem_embedding),
    ]
|
||||||
@ -76,6 +76,7 @@ class Paragraph(AppModelMixin):
|
|||||||
title = models.CharField(max_length=256, verbose_name="标题", default="")
|
title = models.CharField(max_length=256, verbose_name="标题", default="")
|
||||||
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
|
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
|
||||||
default=Status.embedding)
|
default=Status.embedding)
|
||||||
|
hit_num = models.IntegerField(verbose_name="命中次数", default=0)
|
||||||
is_active = models.BooleanField(default=True)
|
is_active = models.BooleanField(default=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
@ -87,10 +88,20 @@ class Problem(AppModelMixin):
|
|||||||
问题表
|
问题表
|
||||||
"""
|
"""
|
||||||
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
|
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
|
||||||
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False)
|
|
||||||
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False)
|
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||||
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False)
|
|
||||||
content = models.CharField(max_length=256, verbose_name="问题内容")
|
content = models.CharField(max_length=256, verbose_name="问题内容")
|
||||||
|
hit_num = models.IntegerField(verbose_name="命中次数", default=0)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
db_table = "problem"
|
db_table = "problem"
|
||||||
|
|
||||||
|
|
||||||
|
class ProblemParagraphMapping(AppModelMixin):
    """
    Link row joining a problem to a paragraph, stored in ``problem_paragraph_mapping``.
    Each row also records the dataset and document it belongs to.
    """
    id = models.UUIDField(
        primary_key=True,
        max_length=128,
        default=uuid.uuid1,
        editable=False,
        verbose_name="主键id",
    )
    dataset = models.ForeignKey(DataSet, db_constraint=False, on_delete=models.DO_NOTHING)
    # NOTE(review): the only foreign key here without db_constraint=False - confirm intended.
    document = models.ForeignKey(Document, on_delete=models.DO_NOTHING)
    problem = models.ForeignKey(Problem, db_constraint=False, on_delete=models.DO_NOTHING)
    paragraph = models.ForeignKey(Paragraph, db_constraint=False, on_delete=models.DO_NOTHING)

    class Meta:
        db_table = "problem_paragraph_mapping"
|
||||||
|
|||||||
@ -34,7 +34,7 @@ from common.util.field_message import ErrMessage
|
|||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from common.util.fork import ChildLink, Fork
|
from common.util.fork import ChildLink, Fork
|
||||||
from common.util.split_model import get_split_model
|
from common.util.split_model import get_split_model
|
||||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
|
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
|
||||||
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer
|
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer
|
||||||
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
|
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
|
||||||
from setting.models import AuthOperate
|
from setting.models import AuthOperate
|
||||||
@ -303,6 +303,7 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
document_model_list = []
|
document_model_list = []
|
||||||
paragraph_model_list = []
|
paragraph_model_list = []
|
||||||
problem_model_list = []
|
problem_model_list = []
|
||||||
|
problem_paragraph_mapping_list = []
|
||||||
# 插入文档
|
# 插入文档
|
||||||
for document in instance.get('documents') if 'documents' in instance else []:
|
for document in instance.get('documents') if 'documents' in instance else []:
|
||||||
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
|
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
|
||||||
@ -312,6 +313,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
paragraph_model_list.append(paragraph)
|
paragraph_model_list.append(paragraph)
|
||||||
for problem in document_paragraph_dict_model.get('problem_model_list'):
|
for problem in document_paragraph_dict_model.get('problem_model_list'):
|
||||||
problem_model_list.append(problem)
|
problem_model_list.append(problem)
|
||||||
|
for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'):
|
||||||
|
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
|
||||||
|
|
||||||
# 插入知识库
|
# 插入知识库
|
||||||
dataset.save()
|
dataset.save()
|
||||||
@ -321,6 +324,9 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
||||||
# 批量插入问题
|
# 批量插入问题
|
||||||
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||||
|
# 批量插入关联问题
|
||||||
|
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
|
||||||
|
problem_paragraph_mapping_list) > 0 else None
|
||||||
|
|
||||||
# 响应数据
|
# 响应数据
|
||||||
return {**DataSetSerializers(dataset).data,
|
return {**DataSetSerializers(dataset).data,
|
||||||
|
|||||||
@ -28,7 +28,7 @@ from common.util.field_message import ErrMessage
|
|||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from common.util.fork import Fork
|
from common.util.fork import Fork
|
||||||
from common.util.split_model import SplitModel, get_split_model
|
from common.util.split_model import SplitModel, get_split_model
|
||||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status
|
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping
|
||||||
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer
|
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer
|
||||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
|
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
@ -179,7 +179,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
# 删除段落
|
# 删除段落
|
||||||
QuerySet(model=Paragraph).filter(document_id=document_id).delete()
|
QuerySet(model=Paragraph).filter(document_id=document_id).delete()
|
||||||
# 删除问题
|
# 删除问题
|
||||||
QuerySet(model=Problem).filter(document_id=document_id).delete()
|
QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete()
|
||||||
# 删除向量库
|
# 删除向量库
|
||||||
ListenerManagement.delete_embedding_by_document_signal.send(document_id)
|
ListenerManagement.delete_embedding_by_document_signal.send(document_id)
|
||||||
paragraphs = get_split_model('web.md').parse(result.content)
|
paragraphs = get_split_model('web.md').parse(result.content)
|
||||||
@ -191,10 +191,14 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
|
|
||||||
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
|
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
|
||||||
problem_model_list = document_paragraph_model.get('problem_model_list')
|
problem_model_list = document_paragraph_model.get('problem_model_list')
|
||||||
|
problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list')
|
||||||
# 批量插入段落
|
# 批量插入段落
|
||||||
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
||||||
# 批量插入问题
|
# 批量插入问题
|
||||||
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||||
|
# 插入关联问题
|
||||||
|
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
|
||||||
|
problem_paragraph_mapping_list) > 0 else None
|
||||||
# 向量化
|
# 向量化
|
||||||
if with_embedding:
|
if with_embedding:
|
||||||
ListenerManagement.embedding_by_document_signal.send(document_id)
|
ListenerManagement.embedding_by_document_signal.send(document_id)
|
||||||
@ -273,7 +277,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
# 删除段落
|
# 删除段落
|
||||||
QuerySet(model=Paragraph).filter(document_id=document_id).delete()
|
QuerySet(model=Paragraph).filter(document_id=document_id).delete()
|
||||||
# 删除问题
|
# 删除问题
|
||||||
QuerySet(model=Problem).filter(document_id=document_id).delete()
|
QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete()
|
||||||
# 删除向量库
|
# 删除向量库
|
||||||
ListenerManagement.delete_embedding_by_document_signal.send(document_id)
|
ListenerManagement.delete_embedding_by_document_signal.send(document_id)
|
||||||
return True
|
return True
|
||||||
@ -344,12 +348,17 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
document_model = document_paragraph_model.get('document')
|
document_model = document_paragraph_model.get('document')
|
||||||
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
|
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
|
||||||
problem_model_list = document_paragraph_model.get('problem_model_list')
|
problem_model_list = document_paragraph_model.get('problem_model_list')
|
||||||
|
problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list')
|
||||||
|
|
||||||
# 插入文档
|
# 插入文档
|
||||||
document_model.save()
|
document_model.save()
|
||||||
# 批量插入段落
|
# 批量插入段落
|
||||||
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
||||||
# 批量插入问题
|
# 批量插入问题
|
||||||
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||||
|
# 批量插入关联问题
|
||||||
|
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
|
||||||
|
problem_paragraph_mapping_list) > 0 else None
|
||||||
document_id = str(document_model.id)
|
document_id = str(document_model.id)
|
||||||
return DocumentSerializers.Operate(
|
return DocumentSerializers.Operate(
|
||||||
data={'dataset_id': dataset_id, 'document_id': document_id}).one(
|
data={'dataset_id': dataset_id, 'document_id': document_id}).one(
|
||||||
@ -396,14 +405,18 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
|
|
||||||
paragraph_model_list = []
|
paragraph_model_list = []
|
||||||
problem_model_list = []
|
problem_model_list = []
|
||||||
|
problem_paragraph_mapping_list = []
|
||||||
for paragraphs in paragraph_model_dict_list:
|
for paragraphs in paragraph_model_dict_list:
|
||||||
paragraph = paragraphs.get('paragraph')
|
paragraph = paragraphs.get('paragraph')
|
||||||
for problem_model in paragraphs.get('problem_model_list'):
|
for problem_model in paragraphs.get('problem_model_list'):
|
||||||
problem_model_list.append(problem_model)
|
problem_model_list.append(problem_model)
|
||||||
|
for problem_paragraph_mapping in paragraphs.get('problem_paragraph_mapping_list'):
|
||||||
|
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
|
||||||
paragraph_model_list.append(paragraph)
|
paragraph_model_list.append(paragraph)
|
||||||
|
|
||||||
return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
|
return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
|
||||||
'problem_model_list': problem_model_list}
|
'problem_model_list': problem_model_list,
|
||||||
|
'problem_paragraph_mapping_list': problem_paragraph_mapping_list}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_document_paragraph_model(dataset_id, instance: Dict):
|
def get_document_paragraph_model(dataset_id, instance: Dict):
|
||||||
@ -523,6 +536,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
document_model_list = []
|
document_model_list = []
|
||||||
paragraph_model_list = []
|
paragraph_model_list = []
|
||||||
problem_model_list = []
|
problem_model_list = []
|
||||||
|
problem_paragraph_mapping_list = []
|
||||||
# 插入文档
|
# 插入文档
|
||||||
for document in instance_list:
|
for document in instance_list:
|
||||||
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
|
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
|
||||||
@ -532,6 +546,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
paragraph_model_list.append(paragraph)
|
paragraph_model_list.append(paragraph)
|
||||||
for problem in document_paragraph_dict_model.get('problem_model_list'):
|
for problem in document_paragraph_dict_model.get('problem_model_list'):
|
||||||
problem_model_list.append(problem)
|
problem_model_list.append(problem)
|
||||||
|
for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'):
|
||||||
|
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
|
||||||
|
|
||||||
# 插入文档
|
# 插入文档
|
||||||
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
|
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
|
||||||
@ -539,6 +555,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||||||
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
||||||
# 批量插入问题
|
# 批量插入问题
|
||||||
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||||
|
# 批量插入关联问题
|
||||||
|
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
|
||||||
|
problem_paragraph_mapping_list) > 0 else None
|
||||||
# 查询文档
|
# 查询文档
|
||||||
query_set = QuerySet(model=Document)
|
query_set = QuerySet(model=Document)
|
||||||
query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
|
query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
|
||||||
|
|||||||
@ -20,9 +20,10 @@ from common.exception.app_exception import AppApiException
|
|||||||
from common.mixins.api_mixin import ApiMixin
|
from common.mixins.api_mixin import ApiMixin
|
||||||
from common.util.common import post
|
from common.util.common import post
|
||||||
from common.util.field_message import ErrMessage
|
from common.util.field_message import ErrMessage
|
||||||
from dataset.models import Paragraph, Problem, Document
|
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
|
||||||
from dataset.serializers.common_serializers import update_document_char_length
|
from dataset.serializers.common_serializers import update_document_char_length
|
||||||
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer
|
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
|
||||||
|
from embedding.models import SourceType
|
||||||
|
|
||||||
|
|
||||||
class ParagraphSerializer(serializers.ModelSerializer):
|
class ParagraphSerializer(serializers.ModelSerializer):
|
||||||
@ -84,6 +85,193 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
content = serializers.CharField(required=True, max_length=4096, error_messages=ErrMessage.char(
|
content = serializers.CharField(required=True, max_length=4096, error_messages=ErrMessage.char(
|
||||||
"分段内容"))
|
"分段内容"))
|
||||||
|
|
||||||
|
class Problem(ApiMixin, serializers.Serializer):
|
||||||
|
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
|
||||||
|
|
||||||
|
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
|
||||||
|
|
||||||
|
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
|
||||||
|
|
||||||
|
def is_valid(self, *, raise_exception=False):
    """
    Validate the path ids and check that the paragraph exists in the dataset.

    Field validation always raises on failure, whatever ``raise_exception`` is.
    :param raise_exception: accepted for interface compatibility
    :return: True when validation passes
    :raises AppApiException: when no paragraph with this id exists in the dataset
    """
    super().is_valid(raise_exception=True)
    # Scope the lookup to the dataset, as Association.is_valid does, so a
    # paragraph id from another dataset is rejected.
    paragraph_exists = QuerySet(Paragraph).filter(
        id=self.data.get('paragraph_id'),
        dataset_id=self.data.get('dataset_id')).exists()
    if not paragraph_exists:
        raise AppApiException(500, "段落id不存在")
    # Serializer.is_valid is expected to return a bool; returning None made
    # `if serializer.is_valid():` falsy even on success.
    return True
|
||||||
|
|
||||||
|
def list(self, with_valid=False):
    """
    List the problems linked to the paragraph.

    :param with_valid: validate the serializer first when true
    :return: list of serialized problems
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    mapping_rows = QuerySet(ProblemParagraphMapping).filter(
        dataset_id=self.data.get("dataset_id"),
        paragraph_id=self.data.get('paragraph_id'))
    # Resolve the mapping rows to problem ids, then load those problems.
    linked_problem_ids = [mapping.problem_id for mapping in mapping_rows]
    problems = QuerySet(Problem).filter(id__in=linked_problem_ids)
    return [ProblemSerializer(problem).data for problem in problems]
|
||||||
|
|
||||||
|
@transaction.atomic
def save(self, instance: Dict, with_valid=True, with_embedding=True):
    """
    Attach a problem (looked up by content, created if missing) to the paragraph.

    :param instance: request body; its 'content' is the problem text
    :param with_valid: validate the path ids and the request body first
    :param with_embedding: send the embedding signal for the new link
    :return: result of ``ProblemSerializers.Operate(...).one`` for the problem
    :raises AppApiException: when the problem is already linked to the paragraph
    """
    if with_valid:
        self.is_valid()
        ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)

    dataset_id = self.data.get('dataset_id')
    content = instance.get('content')

    # Reuse a problem with the same text in this dataset, else create one.
    problem = QuerySet(Problem).filter(dataset_id=dataset_id, content=content).first()
    if problem is None:
        problem = Problem(id=uuid.uuid1(), dataset_id=dataset_id, content=content)
        problem.save()

    already_linked = QuerySet(ProblemParagraphMapping).filter(
        dataset_id=dataset_id,
        problem_id=problem.id,
        paragraph_id=self.data.get('paragraph_id')).exists()
    if already_linked:
        raise AppApiException(500, "已经关联,请勿重复关联")

    mapping = ProblemParagraphMapping(
        id=uuid.uuid1(),
        problem_id=problem.id,
        document_id=self.data.get('document_id'),
        paragraph_id=self.data.get('paragraph_id'),
        dataset_id=dataset_id)
    mapping.save()

    if with_embedding:
        # The embedding row is keyed by the mapping id, not the problem id.
        embedding_payload = {
            'text': problem.content,
            'is_active': True,
            'source_type': SourceType.PROBLEM,
            'source_id': mapping.id,
            'document_id': self.data.get('document_id'),
            'paragraph_id': self.data.get('paragraph_id'),
            'dataset_id': dataset_id,
        }
        ListenerManagement.embedding_by_problem_signal.send(embedding_payload)

    operate = ProblemSerializers.Operate(
        data={'dataset_id': dataset_id, 'problem_id': problem.id})
    return operate.one(with_valid=True)
|
||||||
|
|
||||||
|
@staticmethod
def get_request_params_api():
    """
    Describe the path parameters of the paragraph problem endpoints.

    :return: list of required string path parameters, in URL order
    """
    path_params = (
        ('dataset_id', '知识库id'),
        ('document_id', '文档id'),
        ('paragraph_id', '段落id'),
    )
    return [
        openapi.Parameter(name=param_name,
                          in_=openapi.IN_PATH,
                          type=openapi.TYPE_STRING,
                          required=True,
                          description=param_description)
        for param_name, param_description in path_params
    ]
|
||||||
|
|
||||||
|
@staticmethod
def get_request_body_api():
    """
    Describe the request body: an object with one required string, 'content'.

    :return: openapi schema of the request body
    """
    content_schema = openapi.Schema(type=openapi.TYPE_STRING, title="内容")
    return openapi.Schema(
        type=openapi.TYPE_OBJECT,
        required=["content"],
        properties={'content': content_schema},
    )
|
||||||
|
|
||||||
|
@staticmethod
def get_response_body_api():
    """
    Describe the response body: one problem object.

    :return: openapi schema of the response body
    """
    string_type = openapi.TYPE_STRING
    response_properties = {
        'id': openapi.Schema(type=string_type, title="id",
                             description="id", default="xx"),
        'content': openapi.Schema(type=string_type, title="问题内容",
                                  description="问题内容", default='问题内容'),
        'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量",
                                  description="命中数量", default=1),
        'dataset_id': openapi.Schema(type=string_type, title="知识库id",
                                     description="知识库id", default='xxx'),
        'update_time': openapi.Schema(type=string_type, title="修改时间",
                                      description="修改时间",
                                      default="1970-01-01 00:00:00"),
        'create_time': openapi.Schema(type=string_type, title="创建时间",
                                      description="创建时间",
                                      default="1970-01-01 00:00:00"),
    }
    return openapi.Schema(
        type=openapi.TYPE_OBJECT,
        required=['id', 'content', 'hit_num', 'dataset_id', 'create_time', 'update_time'],
        properties=response_properties,
    )
|
||||||
|
|
||||||
|
class Association(ApiMixin, serializers.Serializer):
|
||||||
|
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
|
||||||
|
|
||||||
|
problem_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("问题id"))
|
||||||
|
|
||||||
|
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
|
||||||
|
|
||||||
|
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
|
||||||
|
|
||||||
|
def is_valid(self, *, raise_exception=True):
    """
    Validate the path ids and check that paragraph and problem exist in the dataset.

    Field validation always raises on failure, whatever ``raise_exception`` is.
    :param raise_exception: accepted for interface compatibility
    :return: True when validation passes
    :raises AppApiException: when the paragraph or the problem is not found in the dataset
    """
    super().is_valid(raise_exception=True)
    dataset_id = self.data.get('dataset_id')
    paragraph_id = self.data.get('paragraph_id')
    problem_id = self.data.get("problem_id")
    if not QuerySet(Paragraph).filter(dataset_id=dataset_id, id=paragraph_id).exists():
        raise AppApiException(500, "段落不存在")
    if not QuerySet(Problem).filter(dataset_id=dataset_id, id=problem_id).exists():
        raise AppApiException(500, "问题不存在")
    # Serializer.is_valid is expected to return a bool; returning None made
    # `if serializer.is_valid():` falsy even on success.
    return True
|
||||||
|
|
||||||
|
def association(self, with_valid=True, with_embedding=True):
|
||||||
|
if with_valid:
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
problem = QuerySet(Problem).filter(id=self.data.get("problem_id"))
|
||||||
|
problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(),
|
||||||
|
document_id=self.data.get('document_id'),
|
||||||
|
paragraph_id=self.data.get('paragraph_id'),
|
||||||
|
dataset_id=self.data.get('dataset_id'),
|
||||||
|
problem_id=problem.id)
|
||||||
|
problem_paragraph_mapping.save()
|
||||||
|
if with_embedding:
|
||||||
|
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
||||||
|
'is_active': True,
|
||||||
|
'source_type': SourceType.PROBLEM,
|
||||||
|
'source_id': problem_paragraph_mapping.id,
|
||||||
|
'document_id': self.data.get('document_id'),
|
||||||
|
'paragraph_id': self.data.get('paragraph_id'),
|
||||||
|
'dataset_id': self.data.get('dataset_id'),
|
||||||
|
})
|
||||||
|
|
||||||
|
def un_association(self, with_valid=True):
|
||||||
|
if with_valid:
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
|
||||||
|
paragraph_id=self.data.get('paragraph_id'),
|
||||||
|
dataset_id=self.data.get('dataset_id'),
|
||||||
|
problem_id=self.data.get(
|
||||||
|
'problem_id')).first()
|
||||||
|
problem_paragraph_mapping_id = problem_paragraph_mapping.id
|
||||||
|
problem_paragraph_mapping.delete()
|
||||||
|
ListenerManagement.delete_embedding_by_source_signal.send(problem_paragraph_mapping_id)
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_request_params_api():
|
||||||
|
return [openapi.Parameter(name='dataset_id',
|
||||||
|
in_=openapi.IN_PATH,
|
||||||
|
type=openapi.TYPE_STRING,
|
||||||
|
required=True,
|
||||||
|
description='知识库id'),
|
||||||
|
openapi.Parameter(name='document_id',
|
||||||
|
in_=openapi.IN_PATH,
|
||||||
|
type=openapi.TYPE_STRING,
|
||||||
|
required=True,
|
||||||
|
description='文档id')
|
||||||
|
, openapi.Parameter(name='paragraph_id',
|
||||||
|
in_=openapi.IN_PATH,
|
||||||
|
type=openapi.TYPE_STRING,
|
||||||
|
required=True,
|
||||||
|
description='段落id'),
|
||||||
|
openapi.Parameter(name='problem_id',
|
||||||
|
in_=openapi.IN_PATH,
|
||||||
|
type=openapi.TYPE_STRING,
|
||||||
|
required=True,
|
||||||
|
description='问题id')
|
||||||
|
]
|
||||||
|
|
||||||
class Operate(ApiMixin, serializers.Serializer):
|
class Operate(ApiMixin, serializers.Serializer):
|
||||||
# 段落id
|
# 段落id
|
||||||
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
|
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
|
||||||
@ -158,8 +346,13 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
return self.one(), instance
|
return self.one(), instance
|
||||||
|
|
||||||
def get_problem_list(self):
    """Return the serialized problems linked to this paragraph.

    Problems are resolved through ``ProblemParagraphMapping`` rows whose
    ``paragraph_id`` matches the serializer's ``paragraph_id``.

    :return: list of ``ProblemSerializer`` data, empty when nothing is linked
    """
    # The stray ``ProblemParagraphMapping(ProblemParagraphMapping)`` statement
    # only built a throw-away model instance and has been dropped.
    problem_id_list = [ppm.problem_id for ppm in
                       QuerySet(ProblemParagraphMapping).filter(
                           paragraph_id=self.data.get("paragraph_id"))]
    if not problem_id_list:
        return []
    return [ProblemSerializer(problem).data for problem in
            QuerySet(Problem).filter(id__in=problem_id_list)]
|
||||||
|
|
||||||
def one(self, with_valid=False):
|
def one(self, with_valid=False):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
@ -172,7 +365,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
paragraph_id = self.data.get('paragraph_id')
|
paragraph_id = self.data.get('paragraph_id')
|
||||||
QuerySet(Paragraph).filter(id=paragraph_id).delete()
|
QuerySet(Paragraph).filter(id=paragraph_id).delete()
|
||||||
QuerySet(Problem).filter(paragraph_id=paragraph_id).delete()
|
QuerySet(ProblemParagraphMapping).filter(paragraph_id=paragraph_id).delete()
|
||||||
ListenerManagement.delete_embedding_by_paragraph_signal.send(paragraph_id)
|
ListenerManagement.delete_embedding_by_paragraph_signal.send(paragraph_id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -210,10 +403,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance)
|
paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance)
|
||||||
paragraph = paragraph_problem_model.get('paragraph')
|
paragraph = paragraph_problem_model.get('paragraph')
|
||||||
problem_model_list = paragraph_problem_model.get('problem_model_list')
|
problem_model_list = paragraph_problem_model.get('problem_model_list')
|
||||||
|
problem_paragraph_mapping_list = paragraph_problem_model.get('problem_paragraph_mapping_list')
|
||||||
# 插入段落
|
# 插入段落
|
||||||
paragraph_problem_model.get('paragraph').save()
|
paragraph_problem_model.get('paragraph').save()
|
||||||
# 插入問題
|
# 插入問題
|
||||||
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||||
|
# 插入问题关联关系
|
||||||
|
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
|
||||||
|
problem_paragraph_mapping_list) > 0 else None
|
||||||
# 修改长度
|
# 修改长度
|
||||||
update_document_char_length(document_id)
|
update_document_char_length(document_id)
|
||||||
if with_embedding:
|
if with_embedding:
|
||||||
@ -229,12 +426,35 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||||||
content=instance.get("content"),
|
content=instance.get("content"),
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
title=instance.get("title") if 'title' in instance else '')
|
title=instance.get("title") if 'title' in instance else '')
|
||||||
|
problem_list = instance.get('problem_list')
|
||||||
|
exists_problem_list = []
|
||||||
|
if 'problem_list' in instance and len(problem_list) > 0:
|
||||||
|
exists_problem_list = QuerySet(Problem).filter(dataset_id=dataset_id,
|
||||||
|
content__in=[p.get('content') for p in
|
||||||
|
problem_list]).all()
|
||||||
|
|
||||||
problem_model_list = [Problem(id=uuid.uuid1(), content=problem.get('content'), paragraph_id=paragraph.id,
|
problem_model_list = [
|
||||||
document_id=document_id, dataset_id=dataset_id) for problem in (
|
ParagraphSerializers.Create.or_get(exists_problem_list, problem.get('content'), dataset_id) for
|
||||||
instance.get('problem_list') if 'problem_list' in instance else [])]
|
problem in (
|
||||||
|
instance.get('problem_list') if 'problem_list' in instance else [])]
|
||||||
|
|
||||||
return {'paragraph': paragraph, 'problem_model_list': problem_model_list}
|
problem_paragraph_mapping_list = [
|
||||||
|
ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id,
|
||||||
|
paragraph_id=paragraph.id,
|
||||||
|
dataset_id=dataset_id) for
|
||||||
|
problem_model in problem_model_list]
|
||||||
|
return {'paragraph': paragraph,
|
||||||
|
'problem_model_list': [problem_model for problem_model in problem_model_list if
|
||||||
|
not list(exists_problem_list).__contains__(problem_model)],
|
||||||
|
'problem_paragraph_mapping_list': problem_paragraph_mapping_list}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def or_get(exists_problem_list, content, dataset_id):
|
||||||
|
exists = [row for row in exists_problem_list if row.content == content]
|
||||||
|
if len(exists) > 0:
|
||||||
|
return exists[0]
|
||||||
|
else:
|
||||||
|
return Problem(id=uuid.uuid1(), content=content, dataset_id=dataset_id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_request_body_api():
|
def get_request_body_api():
|
||||||
|
|||||||
@ -6,6 +6,7 @@
|
|||||||
@date:2023/10/23 13:55
|
@date:2023/10/23 13:55
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
@ -14,19 +15,19 @@ from django.db.models import QuerySet
|
|||||||
from drf_yasg import openapi
|
from drf_yasg import openapi
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
|
||||||
from common.event.listener_manage import ListenerManagement
|
from common.db.search import native_search, native_page_search
|
||||||
from common.exception.app_exception import AppApiException
|
from common.event import ListenerManagement, UpdateProblemArgs
|
||||||
from common.mixins.api_mixin import ApiMixin
|
from common.mixins.api_mixin import ApiMixin
|
||||||
from common.util.field_message import ErrMessage
|
from common.util.field_message import ErrMessage
|
||||||
from dataset.models import Problem, Paragraph
|
from common.util.file_util import get_file_content
|
||||||
from embedding.models import SourceType
|
from dataset.models import Problem, Paragraph, ProblemParagraphMapping
|
||||||
from embedding.vector.pg_vector import PGVector
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
class ProblemSerializer(serializers.ModelSerializer):
    """Model serializer exposing the basic columns of a ``Problem`` row."""

    class Meta:
        model = Problem
        # ``document_id`` is no longer part of the serialized fields.
        fields = ['id', 'content', 'dataset_id',
                  'create_time', 'update_time']
|
||||||
|
|
||||||
|
|
||||||
@ -49,186 +50,92 @@ class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
|
|||||||
|
|
||||||
|
|
||||||
class ProblemSerializers(ApiMixin, serializers.Serializer):
    """Serializers for the dataset-level problem operations (create, query, operate)."""

    class Create(serializers.Serializer):
        """Batch creation of problems inside one dataset."""
        dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
        problem_list = serializers.ListField(required=True, error_messages=ErrMessage.list("问题列表"),
                                             child=serializers.CharField(required=True,
                                                                         error_messages=ErrMessage.char("问题")))

        def batch(self, with_valid=True):
            """Insert every problem text that the dataset does not contain yet.

            Texts already stored for the dataset are skipped, and a text
            repeated inside the request is inserted only once.

            :param with_valid: run ``is_valid`` first
            :return: serialized data of the newly created problems
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            dataset_id = self.data.get('dataset_id')
            # dict.fromkeys drops repeated texts while keeping the request order.
            problem_content_list = list(dict.fromkeys(self.data.get('problem_list')))
            exists_problem_content_set = {problem.content for problem in
                                          QuerySet(Problem).filter(dataset_id=dataset_id,
                                                                   content__in=problem_content_list)}
            problem_instance_list = [Problem(id=uuid.uuid1(), dataset_id=dataset_id, content=problem_content)
                                     for problem_content in problem_content_list
                                     if problem_content not in exists_problem_content_set]
            if problem_instance_list:
                QuerySet(Problem).bulk_create(problem_instance_list)
            return [ProblemSerializer(problem_instance).data for problem_instance in problem_instance_list]

    class Query(serializers.Serializer):
        """List or page the problems of a dataset, optionally filtered by text."""
        dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
        content = serializers.CharField(required=False, error_messages=ErrMessage.char("问题"))

        def get_query_set(self):
            """Build the ``Problem`` queryset for the dataset and optional ``content`` filter."""
            query_set = QuerySet(model=Problem)
            query_set = query_set.filter(
                **{'dataset_id': self.data.get('dataset_id')})
            if 'content' in self.data:
                query_set = query_set.filter(**{'content__contains': self.data.get('content')})
            return query_set

        def list(self):
            """Run ``list_problem.sql`` over the filtered queryset and return all rows."""
            query_set = self.get_query_set()
            return native_search(query_set, select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem.sql')))

        def page(self, current_page, page_size):
            """Run ``list_problem.sql`` over the filtered queryset and return one page.

            :param current_page: page number forwarded to ``native_page_search``
            :param page_size: page size forwarded to ``native_page_search``
            """
            query_set = self.get_query_set()
            return native_page_search(current_page, page_size, query_set, select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem.sql')))

    class Operate(serializers.Serializer):
        """Operations on a single problem identified by dataset id and problem id."""
        dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))

        problem_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("问题id"))

        def list_paragraph(self, with_valid=True):
            """Return the paragraphs linked to the problem, selected with ``list_paragraph.sql``.

            :param with_valid: run ``is_valid`` first
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(dataset_id=self.data.get("dataset_id"),
                                                                                 problem_id=self.data.get("problem_id"))
            return native_search(
                QuerySet(Paragraph).filter(id__in=[row.paragraph_id for row in problem_paragraph_mapping]),
                select_string=get_file_content(
                    os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph.sql')))

        def one(self, with_valid=True):
            """Return the serialized problem.

            :param with_valid: run ``is_valid`` first
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            return ProblemInstanceSerializer(QuerySet(Problem).get(**{'id': self.data.get('problem_id')})).data

        @transaction.atomic
        def delete(self, with_valid=True):
            """Delete the problem, its paragraph links and the embeddings of those links.

            :param with_valid: run ``is_valid`` first
            :return: True
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(
                dataset_id=self.data.get('dataset_id'),
                problem_id=self.data.get('problem_id'))
            # Collect the ids first: they are the embedding source ids and are
            # gone once the rows are deleted.
            source_ids = [row.id for row in problem_paragraph_mapping_list]
            # Remove the links explicitly, as paragraph deletion does for its
            # own mappings. NOTE(review): redundant but harmless if the
            # foreign key already cascades - confirm against the model.
            problem_paragraph_mapping_list.delete()
            QuerySet(Problem).filter(id=self.data.get('problem_id')).delete()
            ListenerManagement.delete_embedding_by_source_ids_signal.send(source_ids)
            return True

        @transaction.atomic
        def edit(self, instance: Dict, with_valid=True):
            """Change the problem text and send ``update_problem_signal``.

            :param instance: request body; its ``content`` value becomes the new text
            :param with_valid: run ``is_valid`` first
            :raises AppApiException: when the problem does not exist in the dataset
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            problem_id = self.data.get('problem_id')
            dataset_id = self.data.get('dataset_id')
            content = instance.get('content')
            problem = QuerySet(Problem).filter(id=problem_id,
                                               dataset_id=dataset_id).first()
            if problem is None:
                # Imported locally because this module no longer imports it at
                # file level; avoids an AttributeError on None below.
                from common.exception.app_exception import AppApiException
                raise AppApiException(500, "问题不存在")
            problem.content = content
            problem.save()
            ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content))
|
|
||||||
|
|||||||
@ -1,10 +1,5 @@
|
|||||||
SELECT
|
SELECT
|
||||||
problem."id",
|
problem.*,
|
||||||
problem."content",
|
(SELECT "count"("id") FROM "problem_paragraph_mapping" WHERE problem_id="problem"."id") as "paragraph_count"
|
||||||
problem_paragraph_mapping.hit_num,
|
|
||||||
problem_paragraph_mapping.star_num,
|
|
||||||
problem_paragraph_mapping.trample_num,
|
|
||||||
problem_paragraph_mapping.paragraph_id
|
|
||||||
FROM
|
FROM
|
||||||
problem problem
|
problem problem
|
||||||
LEFT JOIN problem_paragraph_mapping problem_paragraph_mapping ON problem."id" = problem_paragraph_mapping.problem_id
|
|
||||||
|
|||||||
127
apps/dataset/swagger_api/problem_api.py
Normal file
127
apps/dataset/swagger_api/problem_api.py
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: problem_api.py
|
||||||
|
@date:2024/3/11 10:49
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from drf_yasg import openapi
|
||||||
|
|
||||||
|
from common.mixins.api_mixin import ApiMixin
|
||||||
|
|
||||||
|
|
||||||
|
class ProblemApi(ApiMixin):
    """Swagger (drf_yasg) schema and parameter descriptions for the problem endpoints."""

    @staticmethod
    def get_response_body_api():
        """Return the response schema describing one problem object."""
        properties = {
            'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                 description="id", default="xx"),
            'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
                                      description="问题内容", default='问题内容'),
            'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
                                      default=1),
            'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id",
                                         description="知识库id", default='xxx'),
            'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                          description="修改时间",
                                          default="1970-01-01 00:00:00"),
            'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                          description="创建时间",
                                          default="1970-01-01 00:00:00"),
        }
        required_fields = ['id', 'content', 'hit_num', 'dataset_id', 'create_time', 'update_time']
        return openapi.Schema(type=openapi.TYPE_OBJECT,
                              required=required_fields,
                              properties=properties)

    class Operate(ApiMixin):
        """Descriptions for the endpoints addressing a single problem."""

        @staticmethod
        def get_request_params_api():
            """Return the ``dataset_id`` and ``problem_id`` path parameters."""
            dataset_param = openapi.Parameter(name='dataset_id',
                                              in_=openapi.IN_PATH,
                                              type=openapi.TYPE_STRING,
                                              required=True,
                                              description='知识库id')
            problem_param = openapi.Parameter(name='problem_id',
                                              in_=openapi.IN_PATH,
                                              type=openapi.TYPE_STRING,
                                              required=True,
                                              description='问题id')
            return [dataset_param, problem_param]

        @staticmethod
        def get_request_body_api():
            """Return the request body schema: an object with a required ``content``."""
            content_schema = openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
                                            description="问题内容")
            return openapi.Schema(type=openapi.TYPE_OBJECT,
                                  required=['content'],
                                  properties={'content': content_schema})

    class Paragraph(ApiMixin):
        """Descriptions for listing the paragraphs linked to a problem."""

        @staticmethod
        def get_request_params_api():
            """Reuse the path parameters of :class:`ProblemApi.Operate`."""
            return ProblemApi.Operate.get_request_params_api()

        @staticmethod
        def get_response_body_api():
            """Return the response schema describing one paragraph object."""
            properties = {
                'content': openapi.Schema(type=openapi.TYPE_STRING, max_length=4096, title="分段内容",
                                          description="分段内容"),
                'title': openapi.Schema(type=openapi.TYPE_STRING, max_length=256, title="分段标题",
                                        description="分段标题"),
                'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
                'hit_num': openapi.Schema(type=openapi.TYPE_NUMBER, title="命中次数", description="命中次数"),
                'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                              description="修改时间",
                                              default="1970-01-01 00:00:00"),
                'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                              description="创建时间",
                                              default="1970-01-01 00:00:00"),
            }
            return openapi.Schema(type=openapi.TYPE_OBJECT,
                                  required=['content'],
                                  properties=properties)

    class Query(ApiMixin):
        """Descriptions for the problem list / page endpoints."""

        @staticmethod
        def get_request_params_api():
            """Return the ``dataset_id`` path parameter and the optional ``content`` query filter."""
            dataset_param = openapi.Parameter(name='dataset_id',
                                              in_=openapi.IN_PATH,
                                              type=openapi.TYPE_STRING,
                                              required=True,
                                              description='知识库id')
            content_param = openapi.Parameter(name='content',
                                              in_=openapi.IN_QUERY,
                                              type=openapi.TYPE_STRING,
                                              required=False,
                                              description='问题')
            return [dataset_param, content_param]

    class BatchCreate(ApiMixin):
        """Descriptions for the batch-create endpoint."""

        @staticmethod
        def get_request_body_api():
            """Return an array schema whose items use the ``Create`` body schema."""
            item_schema = ProblemApi.Create.get_request_body_api()
            return openapi.Schema(type=openapi.TYPE_ARRAY, items=item_schema)

        @staticmethod
        def get_request_params_api():
            """Reuse the path parameters of :class:`ProblemApi.Create`."""
            return ProblemApi.Create.get_request_params_api()

    class Create(ApiMixin):
        """Descriptions for creating a single problem."""

        @staticmethod
        def get_request_body_api():
            """Return the request body schema: the problem text as a string."""
            return openapi.Schema(type=openapi.TYPE_STRING, description="问题文本")

        @staticmethod
        def get_request_params_api():
            """Return the ``dataset_id`` path parameter."""
            dataset_param = openapi.Parameter(name='dataset_id',
                                              in_=openapi.IN_PATH,
                                              type=openapi.TYPE_STRING,
                                              required=True,
                                              description='知识库id')
            return [dataset_param]
|
||||||
@ -28,7 +28,16 @@ urlpatterns = [
|
|||||||
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<paragraph_id>',
|
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<paragraph_id>',
|
||||||
views.Paragraph.Operate.as_view()),
|
views.Paragraph.Operate.as_view()),
|
||||||
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem',
|
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem',
|
||||||
views.Problem.as_view()),
|
views.Paragraph.Problem.as_view()),
|
||||||
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem/<str:problem_id>',
|
path(
|
||||||
views.Problem.Operate.as_view())
|
'dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem/<str:problem_id>/un_association',
|
||||||
|
views.Paragraph.Problem.UnAssociation.as_view()),
|
||||||
|
path(
|
||||||
|
'dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem/<str:problem_id>/association',
|
||||||
|
views.Paragraph.Problem.Association.as_view()),
|
||||||
|
path('dataset/<str:dataset_id>/problem', views.Problem.as_view()),
|
||||||
|
path('dataset/<str:dataset_id>/problem/<int:current_page>/<int:page_size>', views.Problem.Page.as_view()),
|
||||||
|
path('dataset/<str:dataset_id>/problem/<str:problem_id>', views.Problem.Operate.as_view()),
|
||||||
|
path('dataset/<str:dataset_id>/problem/<str:problem_id>/paragraph', views.Problem.Paragraph.as_view()),
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|||||||
@ -52,6 +52,73 @@ class Paragraph(APIView):
|
|||||||
return result.success(
|
return result.success(
|
||||||
ParagraphSerializers.Create(data={'dataset_id': dataset_id, 'document_id': document_id}).save(request.data))
|
ParagraphSerializers.Create(data={'dataset_id': dataset_id, 'document_id': document_id}).save(request.data))
|
||||||
|
|
||||||
|
class Problem(APIView):
    """Endpoints for the problems attached to one paragraph."""
    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(
        operation_summary="添加关联问题",
        operation_id="添加段落关联问题",
        manual_parameters=ParagraphSerializers.Problem.get_request_params_api(),
        request_body=ParagraphSerializers.Problem.get_request_body_api(),
        responses=result.get_api_response(ParagraphSerializers.Problem.get_response_body_api()),
        tags=["知识库/文档/段落"])
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def post(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
        """Save the problem from the request body for the paragraph (MANAGE permission)."""
        path_data = {"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}
        serializer = ParagraphSerializers.Problem(data=path_data)
        return result.success(serializer.save(request.data, with_valid=True))

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(
        operation_summary="获取段落问题列表",
        operation_id="获取段落问题列表",
        manual_parameters=ParagraphSerializers.Problem.get_request_params_api(),
        responses=result.get_api_array_response(
            ParagraphSerializers.Problem.get_response_body_api()),
        tags=["知识库/文档/段落"])
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
                                dynamic_tag=k.get('dataset_id')))
    def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
        """List the problems of the paragraph (USE permission)."""
        path_data = {"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}
        serializer = ParagraphSerializers.Problem(data=path_data)
        return result.success(serializer.list(with_valid=True))

    class UnAssociation(APIView):
        """Endpoint removing the link between a problem and a paragraph."""
        authentication_classes = [TokenAuth]

        @action(methods=['PUT'], detail=False)
        @swagger_auto_schema(
            operation_summary="解除关联问题",
            operation_id="解除关联问题",
            manual_parameters=ParagraphSerializers.Association.get_request_params_api(),
            responses=result.get_default_response(),
            tags=["知识库/文档/段落"])
        @has_permissions(
            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                    dynamic_tag=k.get('dataset_id')))
        def put(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str):
            """Unlink the problem from the paragraph (MANAGE permission)."""
            path_data = {'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id,
                         'problem_id': problem_id}
            serializer = ParagraphSerializers.Association(data=path_data)
            return result.success(serializer.un_association())

    class Association(APIView):
        """Endpoint linking an existing problem to a paragraph."""
        authentication_classes = [TokenAuth]

        @action(methods=['PUT'], detail=False)
        @swagger_auto_schema(
            operation_summary="关联问题",
            operation_id="关联问题",
            manual_parameters=ParagraphSerializers.Association.get_request_params_api(),
            responses=result.get_default_response(),
            tags=["知识库/文档/段落"])
        @has_permissions(
            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                    dynamic_tag=k.get('dataset_id')))
        def put(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str):
            """Link the problem to the paragraph (MANAGE permission)."""
            path_data = {'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id,
                         'problem_id': problem_id}
            serializer = ParagraphSerializers.Association(data=path_data)
            return result.success(serializer.association())
|
||||||
|
|
||||||
class Operate(APIView):
|
class Operate(APIView):
|
||||||
authentication_classes = [TokenAuth]
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
@ -61,7 +128,7 @@ class Paragraph(APIView):
|
|||||||
manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
|
manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
|
||||||
request_body=ParagraphSerializers.Operate.get_request_body_api(),
|
request_body=ParagraphSerializers.Operate.get_request_body_api(),
|
||||||
responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api())
|
responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api())
|
||||||
,tags=["知识库/文档/段落"])
|
, tags=["知识库/文档/段落"])
|
||||||
@has_permissions(
|
@has_permissions(
|
||||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||||
dynamic_tag=k.get('dataset_id')))
|
dynamic_tag=k.get('dataset_id')))
|
||||||
|
|||||||
@ -8,61 +8,115 @@
|
|||||||
"""
|
"""
|
||||||
from drf_yasg.utils import swagger_auto_schema
|
from drf_yasg.utils import swagger_auto_schema
|
||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
from rest_framework.request import Request
|
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
|
from rest_framework.views import Request
|
||||||
|
|
||||||
from common.auth import TokenAuth, has_permissions
|
from common.auth import TokenAuth, has_permissions
|
||||||
from common.constants.permission_constants import Permission, Group, Operate
|
from common.constants.permission_constants import Permission, Group, Operate
|
||||||
from common.response import result
|
from common.response import result
|
||||||
|
from common.util.common import query_params_to_single_dict
|
||||||
from dataset.serializers.problem_serializers import ProblemSerializers
|
from dataset.serializers.problem_serializers import ProblemSerializers
|
||||||
|
from dataset.swagger_api.problem_api import ProblemApi
|
||||||
|
|
||||||
|
|
||||||
class Problem(APIView):
|
class Problem(APIView):
|
||||||
authentication_classes = [TokenAuth]
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
|
@action(methods=['GET'], detail=False)
|
||||||
|
@swagger_auto_schema(operation_summary="问题列表",
|
||||||
|
operation_id="问题列表",
|
||||||
|
manual_parameters=ProblemApi.Query.get_request_params_api(),
|
||||||
|
responses=result.get_api_array_response(ProblemApi.get_response_body_api()),
|
||||||
|
tags=["知识库/文档/段落/问题"]
|
||||||
|
)
|
||||||
|
@has_permissions(
|
||||||
|
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||||
|
dynamic_tag=k.get('dataset_id')))
|
||||||
|
def get(self, request: Request, dataset_id: str):
|
||||||
|
q = ProblemSerializers.Query(
|
||||||
|
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id})
|
||||||
|
q.is_valid(raise_exception=True)
|
||||||
|
return result.success(q.list())
|
||||||
|
|
||||||
@action(methods=['POST'], detail=False)
|
@action(methods=['POST'], detail=False)
|
||||||
@swagger_auto_schema(operation_summary="添加关联问题",
|
@swagger_auto_schema(operation_summary="创建问题",
|
||||||
operation_id="添加段落关联问题",
|
operation_id="创建问题",
|
||||||
manual_parameters=ProblemSerializers.Create.get_request_params_api(),
|
manual_parameters=ProblemApi.BatchCreate.get_request_params_api(),
|
||||||
request_body=ProblemSerializers.Create.get_request_body_api(),
|
request_body=ProblemApi.BatchCreate.get_request_body_api(),
|
||||||
responses=result.get_api_response(ProblemSerializers.Operate.get_response_body_api()),
|
responses=result.get_api_response(ProblemApi.Query.get_response_body_api()),
|
||||||
tags=["知识库/文档/段落/问题"])
|
tags=["知识库/文档/段落/问题"])
|
||||||
@has_permissions(
|
@has_permissions(
|
||||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||||
dynamic_tag=k.get('dataset_id')))
|
dynamic_tag=k.get('dataset_id')))
|
||||||
def post(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
|
def post(self, request: Request, dataset_id: str):
|
||||||
return result.success(ProblemSerializers.Create(
|
return result.success(
|
||||||
data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).save(
|
ProblemSerializers.Create(
|
||||||
request.data, with_valid=True))
|
data={'dataset_id': dataset_id, 'problem_list': request.query_params.get('problem_list')}).save())
|
||||||
|
|
||||||
@action(methods=['GET'], detail=False)
|
class Paragraph(APIView):
|
||||||
@swagger_auto_schema(operation_summary="获取段落问题列表",
|
authentication_classes = [TokenAuth]
|
||||||
operation_id="获取段落问题列表",
|
|
||||||
manual_parameters=ProblemSerializers.Query.get_request_params_api(),
|
@action(methods=['GET'], detail=False)
|
||||||
responses=result.get_api_array_response(ProblemSerializers.Operate.get_response_body_api()),
|
@swagger_auto_schema(operation_summary="获取关联段落列表",
|
||||||
tags=["知识库/文档/段落/问题"])
|
operation_id="获取关联段落列表",
|
||||||
@has_permissions(
|
manual_parameters=ProblemApi.Paragraph.get_request_params_api(),
|
||||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
responses=result.get_api_array_response(ProblemApi.Paragraph.get_response_body_api()),
|
||||||
dynamic_tag=k.get('dataset_id')))
|
tags=["知识库/文档/段落/问题"])
|
||||||
def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
|
@has_permissions(
|
||||||
return result.success(ProblemSerializers.Query(
|
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||||
data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).list(
|
dynamic_tag=k.get('dataset_id')))
|
||||||
with_valid=True))
|
def get(self, request: Request, dataset_id: str, problem_id: str):
|
||||||
|
return result.success(ProblemSerializers.Operate(
|
||||||
|
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
|
||||||
|
'problem_id': problem_id}).list_paragraph())
|
||||||
|
|
||||||
class Operate(APIView):
|
class Operate(APIView):
|
||||||
authentication_classes = [TokenAuth]
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
@action(methods=['DELETE'], detail=False)
|
@action(methods=['DELETE'], detail=False)
|
||||||
@swagger_auto_schema(operation_summary="删除段落问题",
|
@swagger_auto_schema(operation_summary="删除问题",
|
||||||
operation_id="删除段落问题",
|
operation_id="删除问题",
|
||||||
manual_parameters=ProblemSerializers.Query.get_request_params_api(),
|
manual_parameters=ProblemApi.Operate.get_request_params_api(),
|
||||||
responses=result.get_default_response(),
|
responses=result.get_default_response(),
|
||||||
tags=["知识库/文档/段落/问题"])
|
tags=["知识库/文档/段落/问题"])
|
||||||
@has_permissions(
|
@has_permissions(
|
||||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||||
dynamic_tag=k.get('dataset_id')))
|
dynamic_tag=k.get('dataset_id')))
|
||||||
def delete(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str):
|
def delete(self, request: Request, dataset_id: str, problem_id: str):
|
||||||
o = ProblemSerializers.Operate(
|
return result.success(ProblemSerializers.Operate(
|
||||||
data={'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id,
|
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
|
||||||
'problem_id': problem_id})
|
'problem_id': problem_id}).delete())
|
||||||
return result.success(o.delete(with_valid=True))
|
|
||||||
|
@action(methods=['PUT'], detail=False)
|
||||||
|
@swagger_auto_schema(operation_summary="修改问题",
|
||||||
|
operation_id="修改问题",
|
||||||
|
manual_parameters=ProblemApi.Operate.get_request_params_api(),
|
||||||
|
request_body=ProblemApi.Operate.get_request_body_api(),
|
||||||
|
responses=result.get_api_response(ProblemApi.get_response_body_api()),
|
||||||
|
tags=["知识库/文档/段落/问题"])
|
||||||
|
@has_permissions(
|
||||||
|
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||||
|
dynamic_tag=k.get('dataset_id')))
|
||||||
|
def put(self, request: Request, dataset_id: str, problem_id: str):
|
||||||
|
return result.success(ProblemSerializers.Operate(
|
||||||
|
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
|
||||||
|
'problem_id': problem_id}).edit(request.data))
|
||||||
|
|
||||||
|
class Page(APIView):
|
||||||
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
|
@action(methods=['GET'], detail=False)
|
||||||
|
@swagger_auto_schema(operation_summary="分页获取问题列表",
|
||||||
|
operation_id="分页获取问题列表",
|
||||||
|
manual_parameters=result.get_page_request_params(
|
||||||
|
ProblemApi.Query.get_request_params_api()),
|
||||||
|
responses=result.get_page_api_response(ProblemApi.get_response_body_api()),
|
||||||
|
tags=["知识库/文档/段落/问题"])
|
||||||
|
@has_permissions(
|
||||||
|
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||||
|
dynamic_tag=k.get('dataset_id')))
|
||||||
|
def get(self, request: Request, dataset_id: str, current_page, page_size):
|
||||||
|
d = ProblemSerializers.Query(
|
||||||
|
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id})
|
||||||
|
d.is_valid(raise_exception=True)
|
||||||
|
return result.success(d.page(current_page, page_size))
|
||||||
|
|||||||
@ -131,6 +131,18 @@ class BaseVectorStore(ABC):
|
|||||||
def update_by_source_id(self, source_id: str, instance: Dict):
|
def update_by_source_id(self, source_id: str, instance: Dict):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def embed_documents(self, text_list: List[str]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def embed_query(self, text: str):
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete_by_dataset_id(self, dataset_id: str):
|
def delete_by_dataset_id(self, dataset_id: str):
|
||||||
pass
|
pass
|
||||||
@ -147,6 +159,10 @@ class BaseVectorStore(ABC):
|
|||||||
def delete_by_source_id(self, source_id: str, source_type: str):
|
def delete_by_source_id(self, source_id: str, source_type: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete_by_source_ids(self, source_ids: List[str], source_type: str):
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete_by_paragraph_id(self, paragraph_id: str):
|
def delete_by_paragraph_id(self, paragraph_id: str):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from typing import Dict, List
|
|||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||||
|
|
||||||
|
from common.config.embedding_config import EmbeddingModel
|
||||||
from common.db.search import native_search, generate_sql_by_query_dict
|
from common.db.search import native_search, generate_sql_by_query_dict
|
||||||
from common.db.sql_execute import select_one, select_list
|
from common.db.sql_execute import select_one, select_list
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
@ -24,6 +25,20 @@ from smartdoc.conf import PROJECT_DIR
|
|||||||
|
|
||||||
class PGVector(BaseVectorStore):
|
class PGVector(BaseVectorStore):
|
||||||
|
|
||||||
|
def delete_by_source_ids(self, source_ids: List[str], source_type: str):
|
||||||
|
QuerySet(Embedding).filter(source_id__in=source_ids, source_type=source_type).delete()
|
||||||
|
|
||||||
|
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
|
||||||
|
QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)
|
||||||
|
|
||||||
|
def embed_documents(self, text_list: List[str]):
|
||||||
|
embedding = EmbeddingModel.get_embedding_model()
|
||||||
|
return embedding.embed_documents(text_list)
|
||||||
|
|
||||||
|
def embed_query(self, text: str):
|
||||||
|
embedding = EmbeddingModel.get_embedding_model()
|
||||||
|
return embedding.embed_query(text)
|
||||||
|
|
||||||
def vector_is_create(self) -> bool:
|
def vector_is_create(self) -> bool:
|
||||||
# 项目启动默认是创建好的 不需要再创建
|
# 项目启动默认是创建好的 不需要再创建
|
||||||
return True
|
return True
|
||||||
|
|||||||
@ -137,7 +137,7 @@ const postProblem: (
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* 删除问题
|
* 解除关联问题
|
||||||
* @param 参数 dataset_id, document_id, paragraph_id,problem_id
|
* @param 参数 dataset_id, document_id, paragraph_id,problem_id
|
||||||
*/
|
*/
|
||||||
const delProblem: (
|
const delProblem: (
|
||||||
@ -146,8 +146,8 @@ const delProblem: (
|
|||||||
paragraph_id: string,
|
paragraph_id: string,
|
||||||
problem_id: string
|
problem_id: string
|
||||||
) => Promise<Result<boolean>> = (dataset_id, document_id, paragraph_id, problem_id) => {
|
) => Promise<Result<boolean>> = (dataset_id, document_id, paragraph_id, problem_id) => {
|
||||||
return del(
|
return put(
|
||||||
`${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem/${problem_id}`
|
`${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem/${problem_id}/un_association`
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user