# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: document_serializers.py
    @date:2023/9/22 13:43
    @desc: Serializers for dataset documents: query/page listing, web-site sync,
           single-document operations (get/edit/refresh/delete), creation with
           paragraphs, and file splitting helpers.
"""
import logging
import os
import traceback
import uuid
from functools import reduce
from typing import List, Dict

from django.core import validators
from django.db import transaction
from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers

from common.db.search import native_search, native_page_search
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.file_util import get_file_content
from common.util.fork import Fork
from common.util.split_model import SplitModel, get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR


class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
    """Validates a single document payload: a name plus an optional list of paragraphs."""

    # Document name, 1-128 characters.
    # NOTE(review): the MinLengthValidator message says "知识库名称" (dataset name) while the
    # MaxLengthValidator says "文档名称" (document name) — likely a copy-paste slip; the
    # messages are runtime strings so they are left unchanged here.
    name = serializers.CharField(required=True, validators=[
        validators.MaxLengthValidator(limit_value=128, message="文档名称在1-128个字符之间"),
        validators.MinLengthValidator(limit_value=1, message="知识库名称在1-128个字符之间")
    ])
    # Optional list of paragraph payloads attached to the document.
    paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)

    @staticmethod
    def get_request_body_api():
        """Return the drf_yasg request-body schema for creating a document."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['name', 'paragraphs'],
            properties={
                'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"),
                'paragraphs': openapi.Schema(type=openapi.TYPE_ARRAY, title="段落列表", description="段落列表",
                                             items=ParagraphSerializers.Create.get_request_body_api())
            }
        )


class DocumentSerializers(ApiMixin, serializers.Serializer):
    """Namespace of serializers for document-level operations on a dataset."""

    class Query(ApiMixin, serializers.Serializer):
        """List/page documents of a dataset, optionally filtered by name substring."""

        # Dataset (knowledge base) id.
        dataset_id = serializers.UUIDField(required=True)
        # Optional name filter, 1-128 characters (same validator messages as
        # DocumentInstanceSerializer.name — see the review note there).
        name = serializers.CharField(required=False, validators=[
            validators.MaxLengthValidator(limit_value=128, message="文档名称在1-128个字符之间"),
            validators.MinLengthValidator(limit_value=1, message="知识库名称在1-128个字符之间")
        ])

        def get_query_set(self):
            """Build the Document queryset: filter by dataset, optional name
            substring match, newest first."""
            query_set = QuerySet(model=Document)
            query_set = query_set.filter(**{'dataset_id': self.data.get("dataset_id")})
            if 'name' in self.data and self.data.get('name') is not None:
                query_set = query_set.filter(**{'name__contains': self.data.get('name')})
            query_set = query_set.order_by('-create_time')
            return query_set

        def list(self, with_valid=False):
            """Return all matching documents via the raw list_document.sql query."""
            if with_valid:
                self.is_valid(raise_exception=True)
            query_set = self.get_query_set()
            return native_search(query_set, select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')))

        def page(self, current_page, page_size):
            """Return one page of matching documents via the raw list_document.sql query.

            NOTE(review): unlike list(), this never calls is_valid() — presumably the
            caller validates first; verify against the view layer.
            """
            query_set = self.get_query_set()
            return native_page_search(current_page, page_size, query_set, select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')))

        @staticmethod
        def get_request_params_api():
            """Query-string parameters for the document list endpoint."""
            return [openapi.Parameter(name='name', in_=openapi.IN_QUERY, type=openapi.TYPE_STRING, required=False,
                                      description='文档名称')]

        @staticmethod
        def get_response_body_api():
            """Response schema: an array of single-document schemas."""
            return openapi.Schema(type=openapi.TYPE_ARRAY, title="文档列表", description="文档列表",
                                  items=DocumentSerializers.Operate.get_response_body_api())

    class Sync(ApiMixin, serializers.Serializer):
        """Re-fetch a web-type document from its source URL and rebuild its paragraphs."""

        document_id = serializers.UUIDField(required=True)

        def is_valid(self, *, raise_exception=False):
            """Validate that the document exists and is of web type.

            NOTE(review): raise_exception is accepted for signature compatibility but
            ignored — super().is_valid is always called with raise_exception=True.
            """
            super().is_valid(raise_exception=True)
            document_id = self.data.get('document_id')
            first = QuerySet(Document).filter(id=document_id).first()
            if first is None:
                raise AppApiException(500, "文档id不存在")
            if first.type != Type.web:
                raise AppApiException(500, "只有web站点类型才支持同步")

        def sync(self, with_valid=True):
            """Fork the document's source_url, replace its paragraphs/problems/vectors
            with freshly split content, and update its char_length.

            On any failure (fork error or exception) the document status is set to
            Status.error. Always returns True.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            document_id = self.data.get('document_id')
            document = QuerySet(Document).filter(id=document_id).first()
            try:
                # Mark as embedding while the re-sync is in progress.
                document.status = Status.embedding
                document.save()
                source_url = document.meta.get('source_url')
                # NOTE(review): if meta contains 'selector' with a None value this raises
                # AttributeError (swallowed by the except below, leaving status=error).
                selector_list = document.meta.get('selector').split(" ") if 'selector' in document.meta else []
                result = Fork(source_url, selector_list).fork()
                if result.status == 200:
                    # Delete existing paragraphs
                    QuerySet(model=Paragraph).filter(document_id=document_id).delete()
                    # Delete existing problems
                    QuerySet(model=Problem).filter(document_id=document_id).delete()
                    # Delete vector-store entries for this document
                    ListenerManagement.delete_embedding_by_document_signal.send(document_id)
                    paragraphs = get_split_model('web.md').parse(result.content)
                    document.char_length = reduce(lambda x, y: x + y, [len(p.get('content')) for p in paragraphs], 0)
                    document.save()
                    document_paragraph_model = DocumentSerializers.Create.get_paragraph_model(document, paragraphs)
                    paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
                    problem_model_list = document_paragraph_model.get('problem_model_list')
                    # Bulk-insert paragraphs
                    QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
                    # Bulk-insert problems
                    QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
                else:
                    document.status = Status.error
                    document.save()
            except Exception as e:
                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
                document.status = Status.error
                document.save()
            return True

    class Operate(ApiMixin, serializers.Serializer):
        """Operations on a single document: get, edit, refresh (re-embed), delete."""

        document_id = serializers.UUIDField(required=True)

        @staticmethod
        def get_request_params_api():
            """Path parameters shared by the single-document endpoints."""
            return [openapi.Parameter(name='dataset_id', in_=openapi.IN_PATH, type=openapi.TYPE_STRING,
                                      required=True, description='知识库id'),
                    openapi.Parameter(name='document_id', in_=openapi.IN_PATH, type=openapi.TYPE_STRING,
                                      required=True, description='文档id')
                    ]

        def is_valid(self, *, raise_exception=False):
            """Validate field types and that the document id exists (raise_exception is
            accepted for signature compatibility but effectively always True)."""
            super().is_valid(raise_exception=True)
            document_id = self.data.get('document_id')
            if not QuerySet(Document).filter(id=document_id).exists():
                raise AppApiException(500, "文档id不存在")

        def one(self, with_valid=False):
            """Return the single document row via the raw list_document.sql query."""
            if with_valid:
                self.is_valid(raise_exception=True)
            query_set = QuerySet(model=Document)
            query_set = query_set.filter(**{'id': self.data.get("document_id")})
            return native_search(query_set, select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=True)

        def edit(self, instance: Dict, with_valid=False):
            """Update the whitelisted fields ('name', 'is_active') from *instance*
            and return the refreshed row."""
            if with_valid:
                self.is_valid()
            _document = QuerySet(Document).get(id=self.data.get("document_id"))
            update_keys = ['name', 'is_active']
            for update_key in update_keys:
                if update_key in instance and instance.get(update_key) is not None:
                    _document.__setattr__(update_key, instance.get(update_key))
            _document.save()
            return self.one()

        def refresh(self, with_valid=True):
            """Re-embed the document; for web-type documents, re-sync from source first."""
            if with_valid:
                self.is_valid(raise_exception=True)
            document_id = self.data.get("document_id")
            document = QuerySet(Document).filter(id=document_id).first()
            if document.type == Type.web:
                # Web-site documents are synchronized from their source first.
                DocumentSerializers.Sync(data={'document_id': document_id}).sync()
            ListenerManagement.embedding_by_document_signal.send(document_id)
            return True

        @transaction.atomic
        def delete(self):
            """Delete the document and all dependent data in one transaction.
            Always returns True."""
            document_id = self.data.get("document_id")
            QuerySet(model=Document).filter(id=document_id).delete()
            # Delete paragraphs
            QuerySet(model=Paragraph).filter(document_id=document_id).delete()
            # Delete problems
            QuerySet(model=Problem).filter(document_id=document_id).delete()
            # Delete vector-store entries
            ListenerManagement.delete_embedding_by_document_signal.send(document_id)
            return True

        @staticmethod
        def get_response_body_api():
            """Response schema for a single document row.

            NOTE(review): in the required list below, a comma is missing between
            'is_active' and 'update_time' — Python concatenates the adjacent string
            literals into 'is_activeupdate_time', so neither field is actually marked
            required. Should be two separate entries.
            """
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=['id', 'name', 'char_length', 'user_id', 'paragraph_count', 'is_active'
                                                                                    'update_time', 'create_time'],
                properties={
                    'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                         description="id", default="xx"),
                    'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
                                           description="名称", default="测试知识库"),
                    'char_length': openapi.Schema(type=openapi.TYPE_INTEGER, title="字符数",
                                                  description="字符数", default=10),
                    'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="用户id",
                                              description="用户id"),
                    'paragraph_count': openapi.Schema(type=openapi.TYPE_INTEGER, title="文档数量",
                                                      description="文档数量", default=1),
                    'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
                                                description="是否可用", default=True),
                    'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                                  description="修改时间",
                                                  default="1970-01-01 00:00:00"),
                    'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                                  description="创建时间",
                                                  default="1970-01-01 00:00:00"
                                                  )
                }
            )

        @staticmethod
        def get_request_body_api():
            """Request-body schema for editing a document (name / is_active)."""
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                properties={
                    'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"),
                    'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
                }
            )

    class Create(ApiMixin, serializers.Serializer):
        """Create a document (and its paragraphs/problems) inside a dataset."""

        dataset_id = serializers.UUIDField(required=True)

        def is_valid(self, *, raise_exception=False):
            """Validate field types and that the dataset exists."""
            super().is_valid(raise_exception=True)
            if not QuerySet(DataSet).filter(id=self.data.get('dataset_id')).exists():
                raise AppApiException(10000, "知识库id不存在")
            return True

        @staticmethod
        def post_embedding(result, document_id):
            """Post hook: fire the embedding signal for the freshly created document."""
            ListenerManagement.embedding_by_document_signal.send(document_id)
            return result

        @post(post_function=post_embedding)
        @transaction.atomic
        def save(self, instance: Dict, with_valid=False, **kwargs):
            """Persist the document plus its paragraph/problem models atomically.

            Returns (document row, document_id); the @post decorator then triggers
            embedding with that document_id.
            """
            if with_valid:
                DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True)
                self.is_valid(raise_exception=True)
            dataset_id = self.data.get('dataset_id')
            document_paragraph_model = self.get_document_paragraph_model(dataset_id, instance)
            document_model = document_paragraph_model.get('document')
            paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
            problem_model_list = document_paragraph_model.get('problem_model_list')
            # Insert the document
            document_model.save()
            # Bulk-insert paragraphs
            QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
            # Bulk-insert problems
            QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
            document_id = str(document_model.id)
            return DocumentSerializers.Operate(
                data={'dataset_id': dataset_id, 'document_id': document_id}).one(
                with_valid=True), document_id

        @staticmethod
        def get_paragraph_model(document_model, paragraph_list: List):
            """Build (but do not save) paragraph and problem model instances for
            *document_model* from raw paragraph dicts."""
            dataset_id = document_model.dataset_id
            paragraph_model_dict_list = [ParagraphSerializers.Create(
                data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).get_paragraph_problem_model(
                dataset_id, document_model.id, paragraph) for paragraph in paragraph_list]
            paragraph_model_list = []
            problem_model_list = []
            for paragraphs in paragraph_model_dict_list:
                paragraph = paragraphs.get('paragraph')
                for problem_model in paragraphs.get('problem_model_list'):
                    problem_model_list.append(problem_model)
                paragraph_model_list.append(paragraph)
            return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
                    'problem_model_list': problem_model_list}

        @staticmethod
        def get_document_paragraph_model(dataset_id, instance: Dict):
            """Build the unsaved Document model (char_length summed over paragraph
            contents; type defaults to Type.base) plus its paragraph/problem models."""
            document_model = Document(
                **{'dataset_id': dataset_id,
                   'id': uuid.uuid1(),
                   'name': instance.get('name'),
                   'char_length': reduce(lambda x, y: x + y,
                                         [len(p.get('content')) for p in instance.get('paragraphs', [])],
                                         0),
                   'meta': instance.get('meta') if instance.get('meta') is not None else {},
                   'type': instance.get('type') if instance.get('type') is not None else Type.base})
            return DocumentSerializers.Create.get_paragraph_model(
                document_model,
                instance.get('paragraphs') if 'paragraphs' in instance else [])

        @staticmethod
        def get_request_body_api():
            """Request-body schema: same as a single document instance payload."""
            return DocumentInstanceSerializer.get_request_body_api()

        @staticmethod
        def get_request_params_api():
            """Path parameters for the create-document endpoint."""
            return [openapi.Parameter(name='dataset_id', in_=openapi.IN_PATH, type=openapi.TYPE_STRING,
                                      required=True, description='知识库id')
                    ]

    class Split(ApiMixin, serializers.Serializer):
        """Split uploaded files into paragraphs, optionally with custom regex patterns."""

        # Uploaded files (multipart); each item is expected to expose .size/.name/.read().
        file = serializers.ListField(required=True)
        # Optional max segment length.
        limit = serializers.IntegerField(required=False)
        # Optional list of split regex patterns.
        patterns = serializers.ListField(required=False,
                                         child=serializers.CharField(required=True))
        # Whether to strip special characters.
        with_filter = serializers.BooleanField(required=False)

        def is_valid(self, *, raise_exception=True):
            """Validate field types and enforce a 10 MB per-file size limit."""
            super().is_valid(raise_exception=True)
            files = self.data.get('file')
            for f in files:
                if f.size > 1024 * 1024 * 10:
                    raise AppApiException(500, "上传文件最大不能超过10m")

        @staticmethod
        def get_request_params_api():
            """Form parameters for the split endpoint."""
            return [
                openapi.Parameter(name='file', in_=openapi.IN_FORM, type=openapi.TYPE_ARRAY,
                                  items=openapi.Items(type=openapi.TYPE_FILE), required=True,
                                  description='上传文件'),
                openapi.Parameter(name='limit', in_=openapi.IN_FORM, required=False,
                                  type=openapi.TYPE_INTEGER, title="分段长度", description="分段长度"),
                openapi.Parameter(name='patterns', in_=openapi.IN_FORM, required=False,
                                  type=openapi.TYPE_ARRAY, items=openapi.Items(type=openapi.TYPE_STRING),
                                  title="分段正则列表", description="分段正则列表"),
                openapi.Parameter(name='with_filter', in_=openapi.IN_FORM, required=False,
                                  type=openapi.TYPE_BOOLEAN, title="是否清除特殊字符", description="是否清除特殊字符"),
            ]

        def parse(self):
            """Split every uploaded file and return the per-file paragraph dicts."""
            file_list = self.data.get("file")
            return list(
                map(lambda f: file_to_paragraph(f, self.data.get("patterns", None), self.data.get("with_filter", None),
                                                self.data.get("limit", None)), file_list))

    class SplitPattern(ApiMixin, serializers.Serializer):
        """Expose the built-in list of named split patterns."""

        @staticmethod
        def list():
            # NOTE(review): SOURCE is truncated in the middle of this return statement
            # (the second pattern's regex string is cut off at "(?", and everything
            # between here and the batch-save tail below is missing from this view).
            return [{'key': "#", 'value': '^# .*'}, {'key': '##', 'value': '(?
            # NOTE(review): SOURCE is truncated immediately before this point — the lines
            # below are the tail of a batch-save method whose enclosing class/def and the
            # start of its first statement (presumably
            # "QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) >")
            # are not visible in this view; indentation here is reconstructed and must be
            # verified against the full file.
            0 else None
            # Bulk-insert paragraphs
            QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
            # Bulk-insert problems
            QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
            # Query the documents just created
            query_set = QuerySet(model=Document)
            query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
            # NOTE(review): the trailing comma after native_search(...) makes this return a
            # 1-tuple wrapping the result — looks unintentional; verify against callers.
            return native_search(query_set, select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False),


def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):
    """Read an uploaded file and split its text into paragraphs.

    :param file: uploaded file object exposing .read() and .name
    :param pattern_list: custom split regex patterns; when non-empty a SplitModel is
        built from them, otherwise a split model is chosen from the file name
    :param with_filter: whether to strip special characters during splitting
    :param limit: maximum segment length
    :return: {'name': <file name>, 'content': <list of parsed paragraphs>}; on a
        decode failure 'content' is an empty list.
    """
    data = file.read()
    if pattern_list is not None and len(pattern_list) > 0:
        split_model = SplitModel(pattern_list, with_filter, limit)
    else:
        split_model = get_split_model(file.name, with_filter=with_filter, limit=limit)
    try:
        # Assumes UTF-8 text; any undecodable upload silently yields empty content.
        content = data.decode('utf-8')
    except BaseException as e:
        # NOTE(review): BaseException is overly broad (catches KeyboardInterrupt etc.)
        # and 'e' is unused — UnicodeDecodeError would be the precise exception here.
        return {'name': file.name, 'content': []}
    return {'name': file.name,
            'content': split_model.parse(content)
            }