fix: web site batch sync, batch delete, and batch import APIs

parent 91b613c4da
commit 344c336143
@@ -9,6 +9,7 @@
 import logging
 import os
 import traceback
+from typing import List

 import django.db.models
 from blinker import signal
@@ -18,7 +19,7 @@ from common.config.embedding_config import VectorStore, EmbeddingModel
 from common.db.search import native_search, get_dynamics_model
 from common.event.common import poxy
 from common.util.file_util import get_file_content
-from common.util.fork import ForkManage
+from common.util.fork import ForkManage, Fork
 from common.util.lock import try_lock, un_lock
 from dataset.models import Paragraph, Status, Document
 from embedding.models import SourceType
@@ -36,6 +37,13 @@ class SyncWebDatasetArgs:
         self.handler = handler


+class SyncWebDocumentArgs:
+    def __init__(self, source_url_list: List[str], selector: str, handler):
+        self.source_url_list = source_url_list
+        self.selector = selector
+        self.handler = handler
+
+
 class ListenerManagement:
     embedding_by_problem_signal = signal("embedding_by_problem")
     embedding_by_paragraph_signal = signal("embedding_by_paragraph")
@@ -49,6 +57,7 @@ class ListenerManagement:
     disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph')
     init_embedding_model_signal = signal('init_embedding_model')
     sync_web_dataset_signal = signal('sync_web_dataset')
+    sync_web_document_signal = signal('sync_web_document')

     @staticmethod
     def embedding_by_problem(args):
@@ -155,6 +164,13 @@ class ListenerManagement:
     def enable_embedding_by_paragraph(paragraph_id):
         VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True})

+    @staticmethod
+    @poxy
+    def sync_web_document(args: SyncWebDocumentArgs):
+        for source_url in args.source_url_list:
+            result = Fork(base_fork_url=source_url, selector_list=args.selector.split(' ')).fork()
+            args.handler(source_url, args.selector, result)
+
     @staticmethod
     @poxy
     def sync_web_dataset(args: SyncWebDatasetArgs):
@@ -200,3 +216,5 @@ class ListenerManagement:
         ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)
         # sync web site knowledge base
         ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
+        # sync web site documents
+        ListenerManagement.sync_web_document_signal.connect(self.sync_web_document)
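The hunks above wire a new blinker signal to a receiver in the listener-management module. A minimal sketch of how a caller is expected to drive it, assuming (as the diff implies but does not show) that Fork.fork() returns a response object carrying status/content and that @poxy moves the receiver off the calling thread:

from common.event.listener_manage import ListenerManagement, SyncWebDocumentArgs


def demo_handler(source_url, selector, response):
    # Toy handler for illustration only; the real one persists a Document
    # (see get_sync_handler further down in this commit).
    print(source_url, selector, getattr(response, 'status', None))


args = SyncWebDocumentArgs(source_url_list=['https://example.com/a', 'https://example.com/b'],
                           selector='.content',  # split on ' ' into a selector list by the receiver
                           handler=demo_handler)
ListenerManagement.sync_web_document_signal.send(args)

The next hunks move to the dataset serializers.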
@@ -10,9 +10,13 @@ import os
 from typing import List

 from django.db.models import QuerySet
+from drf_yasg import openapi
+from rest_framework import serializers

 from common.db.search import native_search
 from common.db.sql_execute import update_execute
+from common.exception.app_exception import AppApiException
+from common.mixins.api_mixin import ApiMixin
 from common.util.file_util import get_file_content
 from dataset.models import Paragraph
 from smartdoc.conf import PROJECT_DIR
@@ -29,3 +33,28 @@ def list_paragraph(paragraph_list: List[str]):
         return []
     return native_search(QuerySet(Paragraph).filter(id__in=paragraph_list), get_file_content(
         os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph.sql')))
+
+
+class BatchSerializer(ApiMixin, serializers.Serializer):
+    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
+
+    def is_valid(self, *, model=None, raise_exception=False):
+        super().is_valid(raise_exception=True)
+        if model is not None:
+            id_list = self.data.get('id_list')
+            model_list = QuerySet(model).filter(id__in=id_list)
+            if len(model_list) != len(id_list):
+                model_id_list = [str(m.id) for m in model_list]
+                error_id_list = list(filter(lambda row_id: not model_id_list.__contains__(row_id), id_list))
+                raise AppApiException(500, f"Invalid id: {error_id_list}")
+
+    @staticmethod
+    def get_request_body_api():
+        return openapi.Schema(
+            type=openapi.TYPE_OBJECT,
+            properties={
+                'id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
+                                          title="Primary key id list",
+                                          description="Primary key id list")
+            }
+        )
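BatchSerializer validates eagerly: is_valid() always raises on field errors, and when a model class is passed it additionally checks that every id resolves to a row. A hypothetical call against the Document model (the uuid is a placeholder):

from dataset.models import Document

serializer = BatchSerializer(data={'id_list': ['6e1a7b52-0000-0000-0000-000000000000']})
# Raises AppApiException(500, ...) listing any ids with no matching Document row.
serializer.is_valid(model=Document, raise_exception=True)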
@@ -21,7 +21,7 @@ from rest_framework import serializers

 from common.db.search import native_search, native_page_search
 from common.event.common import work_thread_pool
-from common.event.listener_manage import ListenerManagement
+from common.event.listener_manage import ListenerManagement, SyncWebDocumentArgs
 from common.exception.app_exception import AppApiException
 from common.mixins.api_mixin import ApiMixin
 from common.util.common import post
@@ -29,10 +29,28 @@ from common.util.file_util import get_file_content
 from common.util.fork import Fork
 from common.util.split_model import SplitModel, get_split_model
 from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status
+from dataset.serializers.common_serializers import BatchSerializer
 from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
 from smartdoc.conf import PROJECT_DIR


+class DocumentWebInstanceSerializer(ApiMixin, serializers.Serializer):
+    source_url_list = serializers.ListField(required=True, child=serializers.CharField(required=True))
+    selector = serializers.CharField(required=False)
+
+    @staticmethod
+    def get_request_body_api():
+        return openapi.Schema(
+            type=openapi.TYPE_OBJECT,
+            required=['source_url_list'],
+            properties={
+                'source_url_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="Source url list",
+                                                  description="Source url list",
+                                                  items=openapi.Schema(type=openapi.TYPE_STRING)),
+                'selector': openapi.Schema(type=openapi.TYPE_STRING, title="Selector", description="Selector")
+            }
+        )
+
+
 class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
     name = serializers.CharField(required=True,
                                  validators=[
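At runtime the serializer is looser than the swagger schema suggests: only source_url_list is required, and selector is free-form text. A minimal sketch of a payload it accepts (values are placeholders):

payload = {'source_url_list': ['https://example.com/a', 'https://example.com/b'],
           'selector': '.content .post'}  # optional; space-separated CSS selectors
DocumentWebInstanceSerializer(data=payload).is_valid(raise_exception=True)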
@@ -121,6 +139,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
         self.is_valid(raise_exception=True)
         document_id = self.data.get('document_id')
         document = QuerySet(Document).filter(id=document_id).first()
+        if document.type != Type.web:
+            return True
         try:
             document.status = Status.embedding
             document.save()
@@ -301,6 +321,38 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                 data={'dataset_id': dataset_id, 'document_id': document_id}).one(
                 with_valid=True), document_id

+        @staticmethod
+        def get_sync_handler(dataset_id):
+            def handler(source_url: str, selector, response: Fork.Response):
+                if response.status == 200:
+                    try:
+                        paragraphs = get_split_model('web.md').parse(response.content)
+                        # insert
+                        DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
+                            {'name': source_url, 'paragraphs': paragraphs,
+                             'meta': {'source_url': source_url, 'selector': selector},
+                             'type': Type.web}, with_valid=True)
+                    except Exception as e:
+                        logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
+                else:
+                    Document(name=source_url,
+                             meta={'source_url': source_url, 'selector': selector},
+                             type=Type.web,
+                             char_length=0,
+                             status=Status.error).save()
+
+            return handler
+
+        def save_web(self, instance: Dict, with_valid=True):
+            if with_valid:
+                DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True)
+                self.is_valid(raise_exception=True)
+            dataset_id = self.data.get('dataset_id')
+            source_url_list = instance.get('source_url_list')
+            selector = instance.get('selector')
+            args = SyncWebDocumentArgs(source_url_list, selector, self.get_sync_handler(dataset_id))
+            ListenerManagement.sync_web_document_signal.send(args)
+
         @staticmethod
         def get_paragraph_model(document_model, paragraph_list: List):
             dataset_id = document_model.dataset_id
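End to end, save_web only enqueues work: it validates, builds SyncWebDocumentArgs with a per-dataset handler, fires the signal, and returns. A hedged sketch of a direct call (the dataset id is a placeholder):

DocumentSerializers.Create(data={'dataset_id': '<dataset-uuid>'}).save_web(
    {'source_url_list': ['https://example.com/guide'], 'selector': '.article'},
    with_valid=True)
# Fetching, splitting via get_split_model('web.md'), and inserting happen in the
# signal receiver; a non-200 fetch is recorded as a Document with status=Status.error.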
@@ -331,8 +383,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                 'meta': instance.get('meta') if instance.get('meta') is not None else {},
                 'type': instance.get('type') if instance.get('type') is not None else Type.base})

-            return DocumentSerializers.Create.get_paragraph_model(document_model, instance.get('paragraphs') if
-            'paragraphs' in instance else [])
+            return DocumentSerializers.Create.get_paragraph_model(document_model,
+                                                                  instance.get('paragraphs') if
+                                                                  'paragraphs' in instance else [])

         @staticmethod
         def get_request_body_api():
@@ -451,6 +504,27 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
             return native_search(query_set, select_string=get_file_content(
                 os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False),

+        @staticmethod
+        def _batch_sync(document_id_list: List[str]):
+            for document_id in document_id_list:
+                DocumentSerializers.Sync(data={'document_id': document_id}).sync()
+
+        def batch_sync(self, instance: Dict, with_valid=True):
+            if with_valid:
+                BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
+                self.is_valid(raise_exception=True)
+            # sync asynchronously
+            work_thread_pool.submit(self._batch_sync,
+                                    instance.get('id_list'))
+            return True
+
+        def batch_delete(self, instance: Dict, with_valid=True):
+            if with_valid:
+                BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
+                self.is_valid(raise_exception=True)
+            QuerySet(Document).filter(id__in=instance.get('id_list')).delete()
+            return True
+

 def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):
     data = file.read()
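Note the asymmetry: batch_sync hands the id list to work_thread_pool and returns True before any document has actually synced, while batch_delete removes rows inline. A hypothetical caller (ids are placeholders):

batch = DocumentSerializers.Batch(data={'dataset_id': '<dataset-uuid>'})
batch.batch_sync({'id_list': ['<doc-1>', '<doc-2>']})   # True immediately; work continues in the pool
batch.batch_delete({'id_list': ['<doc-3>']})            # rows are gone when this returns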
@@ -12,6 +12,7 @@ urlpatterns = [
    path('dataset/<str:dataset_id>/sync_web', views.Dataset.SyncWeb.as_view()),
    path('dataset/<str:dataset_id>/hit_test', views.Dataset.HitTest.as_view()),
    path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
+   path('dataset/<str:dataset_id>/document/web', views.WebDocument.as_view()),
    path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),
    path('dataset/<str:dataset_id>/document/<int:current_page>/<int:page_size>', views.Document.Page.as_view()),
    path('dataset/<str:dataset_id>/document/<str:document_id>', views.Document.Operate.as_view(),
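With the new route in place, the document endpoints this commit touches map as follows, assuming the app's usual /api prefix (the prefix is not visible in this hunk):

# POST   /api/dataset/<dataset_id>/document/web     fork web pages into documents
# PUT    /api/dataset/<dataset_id>/document/_bach   batch sync ('_bach' is the route's existing spelling)
# DELETE /api/dataset/<dataset_id>/document/_bach   batch delete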
@@ -14,11 +14,29 @@ from rest_framework.views import APIView
 from rest_framework.views import Request

 from common.auth import TokenAuth, has_permissions
-from common.constants.permission_constants import Permission, Group, Operate, PermissionConstants
-from common.event.common import work_thread_pool
+from common.constants.permission_constants import Permission, Group, Operate
 from common.response import result
 from common.util.common import query_params_to_single_dict
-from dataset.serializers.document_serializers import DocumentSerializers
+from dataset.serializers.common_serializers import BatchSerializer
+from dataset.serializers.document_serializers import DocumentSerializers, DocumentWebInstanceSerializer


+class WebDocument(APIView):
+    authentication_classes = [TokenAuth]
+
+    @action(methods=['POST'], detail=False)
+    @swagger_auto_schema(operation_summary="Create web site documents",
+                         operation_id="Create web site documents",
+                         request_body=DocumentWebInstanceSerializer.get_request_body_api(),
+                         manual_parameters=DocumentSerializers.Create.get_request_params_api(),
+                         responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()),
+                         tags=["Knowledge base/Document"])
+    @has_permissions(
+        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                dynamic_tag=k.get('dataset_id')))
+    def post(self, request: Request, dataset_id: str):
+        return result.success(
+            DocumentSerializers.Create(data={'dataset_id': dataset_id}).save_web(request.data, with_valid=True))
+
+
 class Document(APIView):
@@ -71,6 +89,34 @@ class Document(APIView):
         def post(self, request: Request, dataset_id: str):
             return result.success(DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_save(request.data))

+        @action(methods=['POST'], detail=False)
+        @swagger_auto_schema(operation_summary="Batch sync documents",
+                             operation_id="Batch sync documents",
+                             request_body=BatchSerializer.get_request_body_api(),
+                             manual_parameters=DocumentSerializers.Create.get_request_params_api(),
+                             responses=result.get_default_response(),
+                             tags=["Knowledge base/Document"])
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def put(self, request: Request, dataset_id: str):
+            return result.success(DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_sync(request.data))
+
+        @action(methods=['DELETE'], detail=False)
+        @swagger_auto_schema(operation_summary="Batch delete documents",
+                             operation_id="Batch delete documents",
+                             request_body=BatchSerializer.get_request_body_api(),
+                             manual_parameters=DocumentSerializers.Create.get_request_params_api(),
+                             responses=result.get_default_response(),
+                             tags=["Knowledge base/Document"])
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def delete(self, request: Request, dataset_id: str):
+            return result.success(DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_delete(request.data))
+
     class Refresh(APIView):
         authentication_classes = [TokenAuth]
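A hedged sketch of exercising the three endpoints over HTTP with requests; the /api prefix, host, port, and token header format are assumptions about a typical deployment, and the ids are placeholders:

import requests

BASE = 'http://localhost:8080/api'            # assumed mount point; adjust to your deployment
HEADERS = {'Authorization': '<api-token>'}    # TokenAuth credential, placeholder value

# Create documents by forking web pages
requests.post(f'{BASE}/dataset/<dataset-id>/document/web', headers=HEADERS,
              json={'source_url_list': ['https://example.com'], 'selector': '.content'})
# Batch sync previously imported web documents
requests.put(f'{BASE}/dataset/<dataset-id>/document/_bach', headers=HEADERS,
             json={'id_list': ['<doc-uuid>']})
# Batch delete
requests.delete(f'{BASE}/dataset/<dataset-id>/document/_bach', headers=HEADERS,
                json={'id_list': ['<doc-uuid>']})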