diff --git a/apps/knowledge/serializers/common.py b/apps/knowledge/serializers/common.py index 571ab410..1a9ff065 100644 --- a/apps/knowledge/serializers/common.py +++ b/apps/knowledge/serializers/common.py @@ -8,10 +8,10 @@ """ import os import re -import uuid_utils.compat as uuid import zipfile from typing import List +import uuid_utils.compat as uuid from django.db.models import QuerySet from django.utils.translation import gettext_lazy as _ from rest_framework import serializers @@ -27,72 +27,6 @@ from maxkb.conf import PROJECT_DIR from models_provider.tools import get_model -def zip_dir(zip_path, output=None): - output = output or os.path.basename(zip_path) + '.zip' - zip = zipfile.ZipFile(output, 'w', zipfile.ZIP_DEFLATED) - for root, dirs, files in os.walk(zip_path): - relative_root = '' if root == zip_path else root.replace(zip_path, '') + os.sep - for filename in files: - zip.write(os.path.join(root, filename), relative_root + filename) - zip.close() - - -def is_valid_uuid(s): - try: - uuid.UUID(s) - return True - except ValueError: - return False - - -def write_image(zip_path: str, image_list: List[str]): - for image in image_list: - search = re.search("\(.*\)", image) - if search: - text = search.group() - if text.startswith('(/api/file/'): - r = text.replace('(/api/file/', '').replace(')', '') - r = r.strip().split(" ")[0] - if not is_valid_uuid(r): - break - file = QuerySet(File).filter(id=r).first() - if file is None: - break - zip_inner_path = os.path.join('api', 'file', r) - file_path = os.path.join(zip_path, zip_inner_path) - if not os.path.exists(os.path.dirname(file_path)): - os.makedirs(os.path.dirname(file_path)) - with open(os.path.join(zip_path, file_path), 'wb') as f: - f.write(file.get_bytes()) - # else: - # r = text.replace('(/api/image/', '').replace(')', '') - # r = r.strip().split(" ")[0] - # if not is_valid_uuid(r): - # break - # image_model = QuerySet(Image).filter(id=r).first() - # if image_model is None: - # break - # zip_inner_path = os.path.join('api', 'image', r) - # file_path = os.path.join(zip_path, zip_inner_path) - # if not os.path.exists(os.path.dirname(file_path)): - # os.makedirs(os.path.dirname(file_path)) - # with open(file_path, 'wb') as f: - # f.write(image_model.image) - - -def update_document_char_length(document_id: str): - update_execute(get_file_content( - os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_document_char_length.sql')), - (document_id, document_id)) - - -def list_paragraph(paragraph_list: List[str]): - if paragraph_list is None or len(paragraph_list) == 0: - return [] - return native_search(QuerySet(Paragraph).filter(id__in=paragraph_list), get_file_content( - os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_paragraph.sql'))) - - class MetaSerializer(serializers.Serializer): class WebMeta(serializers.Serializer): source_url = serializers.CharField(required=True, label=_('source url')) @@ -133,17 +67,11 @@ class ProblemParagraphObject: self.problem_content = problem_content -def or_get(exists_problem_list, content, knowledge_id, document_id, paragraph_id, problem_content_dict): - if content in problem_content_dict: - return problem_content_dict.get(content)[0], document_id, paragraph_id - exists = [row for row in exists_problem_list if row.content == content] - if len(exists) > 0: - problem_content_dict[content] = exists[0], False - return exists[0], document_id, paragraph_id - else: - problem = Problem(id=uuid.uuid7(), content=content, knowledge_id=knowledge_id) - problem_content_dict[content] = problem, True - return problem, document_id, paragraph_id +class GenerateRelatedSerializer(serializers.Serializer): + model_id = serializers.UUIDField(required=True, label=_('Model id')) + prompt = serializers.CharField(required=True, label=_('Prompt word')) + state_list = serializers.ListField(required=False, child=serializers.CharField(required=True), + label=_("state list")) class ProblemParagraphManage: @@ -216,8 +144,80 @@ def get_embedding_model_id_by_knowledge_id_list(knowledge_id_list: List): return str(knowledge_list[0].embedding_model_id) -class GenerateRelatedSerializer(serializers.Serializer): - model_id = serializers.UUIDField(required=True, label=_('Model id')) - prompt = serializers.CharField(required=True, label=_('Prompt word')) - state_list = serializers.ListField(required=False, child=serializers.CharField(required=True), - label=_("state list")) +def zip_dir(zip_path, output=None): + output = output or os.path.basename(zip_path) + '.zip' + zip = zipfile.ZipFile(output, 'w', zipfile.ZIP_DEFLATED) + for root, dirs, files in os.walk(zip_path): + relative_root = '' if root == zip_path else root.replace(zip_path, '') + os.sep + for filename in files: + zip.write(os.path.join(root, filename), relative_root + filename) + zip.close() + + +def is_valid_uuid(s): + try: + uuid.UUID(s) + return True + except ValueError: + return False + + +def write_image(zip_path: str, image_list: List[str]): + for image in image_list: + search = re.search("\(.*\)", image) + if search: + text = search.group() + if text.startswith('(/api/file/'): + r = text.replace('(/api/file/', '').replace(')', '') + r = r.strip().split(" ")[0] + if not is_valid_uuid(r): + break + file = QuerySet(File).filter(id=r).first() + if file is None: + break + zip_inner_path = os.path.join('api', 'file', r) + file_path = os.path.join(zip_path, zip_inner_path) + if not os.path.exists(os.path.dirname(file_path)): + os.makedirs(os.path.dirname(file_path)) + with open(os.path.join(zip_path, file_path), 'wb') as f: + f.write(file.get_bytes()) + # else: + # r = text.replace('(/api/image/', '').replace(')', '') + # r = r.strip().split(" ")[0] + # if not is_valid_uuid(r): + # break + # image_model = QuerySet(Image).filter(id=r).first() + # if image_model is None: + # break + # zip_inner_path = os.path.join('api', 'image', r) + # file_path = os.path.join(zip_path, zip_inner_path) + # if not os.path.exists(os.path.dirname(file_path)): + # os.makedirs(os.path.dirname(file_path)) + # with open(file_path, 'wb') as f: + # f.write(image_model.image) + + +def update_document_char_length(document_id: str): + update_execute(get_file_content( + os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_document_char_length.sql')), + (document_id, document_id)) + + +def list_paragraph(paragraph_list: List[str]): + if paragraph_list is None or len(paragraph_list) == 0: + return [] + return native_search(QuerySet(Paragraph).filter(id__in=paragraph_list), get_file_content( + os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_paragraph.sql'))) + + +def or_get(exists_problem_list, content, knowledge_id, document_id, paragraph_id, problem_content_dict): + if content in problem_content_dict: + return problem_content_dict.get(content)[0], document_id, paragraph_id + exists = [row for row in exists_problem_list if row.content == content] + if len(exists) > 0: + problem_content_dict[content] = exists[0], False + return exists[0], document_id, paragraph_id + else: + problem = Problem(id=uuid.uuid7(), content=content, knowledge_id=knowledge_id) + problem_content_dict[content] = problem, True + return problem, document_id, paragraph_id diff --git a/apps/knowledge/serializers/problem.py b/apps/knowledge/serializers/problem.py index 530716ab..ab95d7a6 100644 --- a/apps/knowledge/serializers/problem.py +++ b/apps/knowledge/serializers/problem.py @@ -26,6 +26,7 @@ class ProblemInstanceSerializer(serializers.Serializer): id = serializers.CharField(required=False, label=_('problem id')) content = serializers.CharField(required=True, max_length=256, label=_('content')) + class ProblemEditSerializer(serializers.Serializer): content = serializers.CharField(required=True, max_length=256, label=_('content')) @@ -56,25 +57,6 @@ class BatchAssociation(serializers.Serializer): paragraph_list = AssociationParagraph(many=True) -def is_exits(exits_problem_paragraph_mapping_list, new_paragraph_mapping): - filter_list = [exits_problem_paragraph_mapping for exits_problem_paragraph_mapping in - exits_problem_paragraph_mapping_list if - str(exits_problem_paragraph_mapping.paragraph_id) == new_paragraph_mapping.paragraph_id - and str(exits_problem_paragraph_mapping.problem_id) == new_paragraph_mapping.problem_id - and str(exits_problem_paragraph_mapping.knowledge_id) == new_paragraph_mapping.knowledge_id] - return len(filter_list) > 0 - - -def to_problem_paragraph_mapping(problem, document_id: str, paragraph_id: str, knowledge_id: str): - return ProblemParagraphMapping( - id=uuid.uuid7(), - document_id=document_id, - paragraph_id=paragraph_id, - knowledge_id=knowledge_id, - problem_id=str(problem.id) - ), problem - - class ProblemSerializers(serializers.Serializer): class BatchOperate(serializers.Serializer): knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id')) @@ -241,3 +223,22 @@ class ProblemSerializers(serializers.Serializer): query_set = self.get_query_set() return native_page_search(current_page, page_size, query_set, select_string=get_file_content( os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_problem.sql'))) + + +def is_exits(exits_problem_paragraph_mapping_list, new_paragraph_mapping): + filter_list = [exits_problem_paragraph_mapping for exits_problem_paragraph_mapping in + exits_problem_paragraph_mapping_list if + str(exits_problem_paragraph_mapping.paragraph_id) == new_paragraph_mapping.paragraph_id + and str(exits_problem_paragraph_mapping.problem_id) == new_paragraph_mapping.problem_id + and str(exits_problem_paragraph_mapping.knowledge_id) == new_paragraph_mapping.knowledge_id] + return len(filter_list) > 0 + + +def to_problem_paragraph_mapping(problem, document_id: str, paragraph_id: str, knowledge_id: str): + return ProblemParagraphMapping( + id=uuid.uuid7(), + document_id=document_id, + paragraph_id=paragraph_id, + knowledge_id=knowledge_id, + problem_id=str(problem.id) + ), problem