diff --git a/apps/common/chunk/__init__.py b/apps/common/chunk/__init__.py new file mode 100644 index 00000000..a4babde7 --- /dev/null +++ b/apps/common/chunk/__init__.py @@ -0,0 +1,18 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/7/23 17:03 + @desc: +""" +from common.chunk.impl.mark_chunk_handle import MarkChunkHandle + +handles = [MarkChunkHandle()] + + +def text_to_chunk(text: str): + chunk_list = [text] + for handle in handles: + chunk_list = handle.handle(chunk_list) + return chunk_list diff --git a/apps/common/chunk/i_chunk_handle.py b/apps/common/chunk/i_chunk_handle.py new file mode 100644 index 00000000..d53575d1 --- /dev/null +++ b/apps/common/chunk/i_chunk_handle.py @@ -0,0 +1,16 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: i_chunk_handle.py + @date:2024/7/23 16:51 + @desc: +""" +from abc import ABC, abstractmethod +from typing import List + + +class IChunkHandle(ABC): + @abstractmethod + def handle(self, chunk_list: List[str]): + pass diff --git a/apps/common/chunk/impl/mark_chunk_handle.py b/apps/common/chunk/impl/mark_chunk_handle.py new file mode 100644 index 00000000..5bca2f44 --- /dev/null +++ b/apps/common/chunk/impl/mark_chunk_handle.py @@ -0,0 +1,40 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: mark_chunk_handle.py + @date:2024/7/23 16:52 + @desc: +""" +import re +from typing import List + +from common.chunk.i_chunk_handle import IChunkHandle + +max_chunk_len = 256 +split_chunk_pattern = r'.{1,%d}[。| |\\.|!|;|;|!|\n]' % max_chunk_len +max_chunk_pattern = r'.{1,%d}' % max_chunk_len + + +class MarkChunkHandle(IChunkHandle): + def handle(self, chunk_list: List[str]): + result = [] + for chunk in chunk_list: + chunk_result = re.findall(split_chunk_pattern, chunk, flags=re.DOTALL) + for c_r in chunk_result: + if len(c_r.strip()) > 0: + result.append(c_r.strip()) + + other_chunk_list = re.split(split_chunk_pattern, chunk, flags=re.DOTALL) + for other_chunk in 
other_chunk_list: + if len(other_chunk) > 0: + if len(other_chunk) < max_chunk_len: + if len(other_chunk.strip()) > 0: + result.append(other_chunk.strip()) + else: + max_chunk_list = re.findall(max_chunk_pattern, other_chunk, flags=re.DOTALL) + for m_c in max_chunk_list: + if len(m_c.strip()) > 0: + result.append(m_c.strip()) + + return result diff --git a/apps/common/config/embedding_config.py b/apps/common/config/embedding_config.py index 6b66569d..50ed5e23 100644 --- a/apps/common/config/embedding_config.py +++ b/apps/common/config/embedding_config.py @@ -47,20 +47,20 @@ class ModelManage: ModelManage.cache.delete(_id) -# class VectorStore: -# from embedding.vector.pg_vector import PGVector -# from embedding.vector.base_vector import BaseVectorStore -# instance_map = { -# 'pg_vector': PGVector, -# } -# instance = None -# -# @staticmethod -# def get_embedding_vector() -> BaseVectorStore: -# from embedding.vector.pg_vector import PGVector -# if VectorStore.instance is None: -# from maxkb.const import CONFIG -# vector_store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"), -# PGVector) -# VectorStore.instance = vector_store_class() -# return VectorStore.instance +class VectorStore: + from knowledge.vector.pg_vector import PGVector + from knowledge.vector.base_vector import BaseVectorStore + instance_map = { + 'pg_vector': PGVector, + } + instance = None + + @staticmethod + def get_embedding_vector() -> BaseVectorStore: + from knowledge.vector.pg_vector import PGVector + if VectorStore.instance is None: + from maxkb.const import CONFIG + vector_store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"), + PGVector) + VectorStore.instance = vector_store_class() + return VectorStore.instance diff --git a/apps/common/event/__init__.py b/apps/common/event/__init__.py new file mode 100644 index 00000000..c23c4537 --- /dev/null +++ b/apps/common/event/__init__.py @@ -0,0 +1,30 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + 
@file: __init__.py + @date:2023/11/10 10:43 + @desc: +""" +from models_provider.models import Model, Status +from .listener_manage import * +from django.utils.translation import gettext as _ + +from ..db.sql_execute import update_execute +from common.lock.impl.file_lock import FileLock + +lock = FileLock() +update_document_status_sql = """ +UPDATE "public"."document" +SET status ="replace"("replace"("replace"(status, '1', '3'), '0', '3'), '4', '3') +WHERE status ~ '1|0|4' +""" + + +def run(): + if lock.try_lock('event_init', 30 * 30): + try: + QuerySet(Model).filter(status=Status.DOWNLOAD).update(status=Status.ERROR, meta={'message': _( 'The download process was interrupted, please try again')}) + update_execute(update_document_status_sql, []) + finally: + lock.un_lock('event_init') diff --git a/apps/common/event/common.py b/apps/common/event/common.py new file mode 100644 index 00000000..a54d24df --- /dev/null +++ b/apps/common/event/common.py @@ -0,0 +1,50 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: common.py + @date:2023/11/10 10:41 + @desc: +""" +from concurrent.futures import ThreadPoolExecutor + +from django.core.cache.backends.locmem import LocMemCache + +work_thread_pool = ThreadPoolExecutor(5) + +embedding_thread_pool = ThreadPoolExecutor(3) + +memory_cache = LocMemCache('task', {"OPTIONS": {"MAX_ENTRIES": 1000}}) + + +def poxy(poxy_function): + def inner(args, **keywords): + work_thread_pool.submit(poxy_function, args, **keywords) + + return inner + + +def get_cache_key(poxy_function, args): + return poxy_function.__name__ + str(args) + + +def get_cache_poxy_function(poxy_function, cache_key): + def fun(args, **keywords): + try: + poxy_function(args, **keywords) + finally: + memory_cache.delete(cache_key) + + return fun + + +def embedding_poxy(poxy_function): + def inner(*args, **keywords): + key = get_cache_key(poxy_function, args) + if memory_cache.has_key(key): + return + memory_cache.add(key, None) + f = 
get_cache_poxy_function(poxy_function, key) + embedding_thread_pool.submit(f, args, **keywords) + + return inner diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py new file mode 100644 index 00000000..9c52e00e --- /dev/null +++ b/apps/common/event/listener_manage.py @@ -0,0 +1,385 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: listener_manage.py + @date:2023/10/20 14:01 + @desc: +""" +import logging +import os +import threading +import datetime +import traceback +from typing import List + +import django.db.models +from django.db.models import QuerySet +from django.db.models.functions import Substr, Reverse +from langchain_core.embeddings import Embeddings + +from common.config.embedding_config import VectorStore +from common.db.search import native_search, get_dynamics_model, native_update +from common.utils.common import get_file_content +from common.utils.lock import try_lock, un_lock +from common.utils.page_utils import page_desc +from knowledge.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State,SourceType, SearchMode +from maxkb.conf import (PROJECT_DIR) +from django.utils.translation import gettext_lazy as _ + +max_kb_error = logging.getLogger(__file__) +max_kb = logging.getLogger(__file__) +lock = threading.Lock() + + +class SyncWebKnowledgeArgs: + def __init__(self, lock_key: str, url: str, selector: str, handler): + self.lock_key = lock_key + self.url = url + self.selector = selector + self.handler = handler + + +class SyncWebDocumentArgs: + def __init__(self, source_url_list: List[str], selector: str, handler): + self.source_url_list = source_url_list + self.selector = selector + self.handler = handler + + +class UpdateProblemArgs: + def __init__(self, problem_id: str, problem_content: str, embedding_model: Embeddings): + self.problem_id = problem_id + self.problem_content = problem_content + self.embedding_model = embedding_model + + +class 
UpdateEmbeddingKnowledgeIdArgs: + def __init__(self, paragraph_id_list: List[str], target_knowledge_id: str): + self.paragraph_id_list = paragraph_id_list + self.target_knowledge_id = target_knowledge_id + + +class UpdateEmbeddingDocumentIdArgs: + def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_knowledge_id: str, + target_embedding_model: Embeddings = None): + self.paragraph_id_list = paragraph_id_list + self.target_document_id = target_document_id + self.target_knowledge_id = target_knowledge_id + self.target_embedding_model = target_embedding_model + + +class ListenerManagement: + + @staticmethod + def embedding_by_problem(args, embedding_model: Embeddings): + VectorStore.get_embedding_vector().save(**args, embedding=embedding_model) + + @staticmethod + def embedding_by_paragraph_list(paragraph_id_list, embedding_model: Embeddings): + try: + data_list = native_search( + {'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter( + **{'paragraph.id__in': paragraph_id_list}), + 'paragraph': QuerySet(Paragraph).filter(id__in=paragraph_id_list)}, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql'))) + ListenerManagement.embedding_by_paragraph_data_list(data_list, paragraph_id_list=paragraph_id_list, + embedding_model=embedding_model) + except Exception as e: + max_kb_error.error(_('Query vector data: {paragraph_id_list} error {error} {traceback}').format( + paragraph_id_list=paragraph_id_list, error=str(e), traceback=traceback.format_exc())) + + @staticmethod + def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings): + max_kb.info(_('Start--->Embedding paragraph: {paragraph_id_list}').format(paragraph_id_list=paragraph_id_list)) + status = Status.success + try: + # 删除段落 + VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_id_list) + + def is_save_function(): + return 
QuerySet(Paragraph).filter(id__in=paragraph_id_list).exists() + + # 批量向量化 + VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function) + except Exception as e: + max_kb_error.error(_('Vectorized paragraph: {paragraph_id_list} error {error} {traceback}').format( + paragraph_id_list=paragraph_id_list, error=str(e), traceback=traceback.format_exc())) + status = Status.error + finally: + QuerySet(Paragraph).filter(id__in=paragraph_id_list).update(**{'status': status}) + max_kb.info( + _('End--->Embedding paragraph: {paragraph_id_list}').format(paragraph_id_list=paragraph_id_list)) + + @staticmethod + def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings): + """ + 向量化段落 根据段落id + @param paragraph_id: 段落id + @param embedding_model: 向量模型 + """ + max_kb.info(_('Start--->Embedding paragraph: {paragraph_id}').format(paragraph_id=paragraph_id)) + # 更新到开始状态 + ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING, State.STARTED) + try: + data_list = native_search( + {'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter( + **{'paragraph.id': paragraph_id}), + 'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)}, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql'))) + # 删除段落 + VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id) + + def is_the_task_interrupted(): + _paragraph = QuerySet(Paragraph).filter(id=paragraph_id).first() + if _paragraph is None or Status(_paragraph.status)[TaskType.EMBEDDING] == State.REVOKE: + return True + return False + + # 批量向量化 + VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_the_task_interrupted) + # 更新到开始状态 + ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING, + State.SUCCESS) + except Exception as e: + max_kb_error.error(_('Vectorized paragraph: 
{paragraph_id} error {error} {traceback}').format( + paragraph_id=paragraph_id, error=str(e), traceback=traceback.format_exc())) + ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING, + State.FAILURE) + finally: + max_kb.info(_('End--->Embedding paragraph: {paragraph_id}').format(paragraph_id=paragraph_id)) + + @staticmethod + def embedding_by_data_list(data_list: List, embedding_model: Embeddings): + # 批量向量化 + VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, lambda: True) + + @staticmethod + def get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, post_apply=lambda: None): + def embedding_paragraph_apply(paragraph_list): + for paragraph in paragraph_list: + if is_the_task_interrupted(): + break + ListenerManagement.embedding_by_paragraph(str(paragraph.get('id')), embedding_model) + post_apply() + + return embedding_paragraph_apply + + @staticmethod + def get_aggregation_document_status(document_id): + def aggregation_document_status(): + pass + sql = get_file_content( + os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_document_status_meta.sql')) + native_update({'document_custom_sql': QuerySet(Document).filter(id=document_id)}, sql, with_table_name=True) + + return aggregation_document_status + + @staticmethod + def get_aggregation_document_status_by_knowledge_id(knowledge_id): + def aggregation_document_status(): + sql = get_file_content( + os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_document_status_meta.sql')) + native_update({'document_custom_sql': QuerySet(Document).filter(knowledge_id=knowledge_id)}, sql, + with_table_name=True) + + return aggregation_document_status + + @staticmethod + def get_aggregation_document_status_by_query_set(queryset): + def aggregation_document_status(): + sql = get_file_content( + os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_document_status_meta.sql')) + native_update({'document_custom_sql': 
queryset}, sql, with_table_name=True) + + return aggregation_document_status + + @staticmethod + def post_update_document_status(document_id, task_type: TaskType): + _document = QuerySet(Document).filter(id=document_id).first() + + status = Status(_document.status) + if status[task_type] == State.REVOKE: + status[task_type] = State.REVOKED + else: + status[task_type] = State.SUCCESS + for item in _document.status_meta.get('aggs', []): + agg_status = item.get('status') + agg_count = item.get('count') + if Status(agg_status)[task_type] == State.FAILURE and agg_count > 0: + status[task_type] = State.FAILURE + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), task_type, status[task_type]) + + ListenerManagement.update_status(QuerySet(Paragraph).annotate( + reversed_status=Reverse('status'), + task_type_status=Substr('reversed_status', task_type.value, + task_type.value), + ).filter(task_type_status=State.REVOKE.value).filter(document_id=document_id).values('id'), + task_type, + State.REVOKED) + + @staticmethod + def update_status(query_set: QuerySet, taskType: TaskType, state: State): + exec_sql = get_file_content( + os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_paragraph_status.sql')) + bit_number = len(TaskType) + up_index = taskType.value - 1 + next_index = taskType.value + 1 + current_index = taskType.value + status_number = state.value + current_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') + '+00' + params_dict = {'${bit_number}': bit_number, '${up_index}': up_index, + '${status_number}': status_number, '${next_index}': next_index, + '${table_name}': query_set.model._meta.db_table, '${current_index}': current_index, + '${current_time}': current_time} + for key in params_dict: + _value_ = params_dict[key] + exec_sql = exec_sql.replace(key, str(_value_)) + lock.acquire() + try: + native_update(query_set, exec_sql) + finally: + lock.release() + + @staticmethod + def embedding_by_document(document_id, 
embedding_model: Embeddings, state_list=None): + """ + 向量化文档 + @param state_list: + @param document_id: 文档id + @param embedding_model 向量模型 + :return: None + """ + if state_list is None: + state_list = [State.PENDING, State.SUCCESS, State.FAILURE, State.REVOKE, State.REVOKED] + if not try_lock('embedding' + str(document_id)): + return + try: + def is_the_task_interrupted(): + document = QuerySet(Document).filter(id=document_id).first() + if document is None or Status(document.status)[TaskType.EMBEDDING] == State.REVOKE: + return True + return False + + if is_the_task_interrupted(): + return + max_kb.info(_('Start--->Embedding document: {document_id}').format(document_id=document_id) + ) + # 批量修改状态为PADDING + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, + State.STARTED) + + + # 根据段落进行向量化处理 + page_desc(QuerySet(Paragraph) + .annotate( + reversed_status=Reverse('status'), + task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value, + 1), + ).filter(task_type_status__in=state_list, document_id=document_id) + .values('id'), 5, + ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, + ListenerManagement.get_aggregation_document_status( + document_id)), + is_the_task_interrupted) + except Exception as e: + max_kb_error.error(_('Vectorized document: {document_id} error {error} {traceback}').format( + document_id=document_id, error=str(e), traceback=traceback.format_exc())) + finally: + ListenerManagement.post_update_document_status(document_id, TaskType.EMBEDDING) + ListenerManagement.get_aggregation_document_status(document_id)() + max_kb.info(_('End--->Embedding document: {document_id}').format(document_id=document_id)) + un_lock('embedding' + str(document_id)) + + @staticmethod + def embedding_by_knowledge(knowledge_id, embedding_model: Embeddings): + """ + 向量化知识库 + @param knowledge_id: 知识库id + @param embedding_model 向量模型 + :return: None + """ + 
max_kb.info(_('Start--->Embedding knowledge: {knowledge_id}').format(knowledge_id=knowledge_id)) + try: + ListenerManagement.delete_embedding_by_knowledge(knowledge_id) + document_list = QuerySet(Document).filter(knowledge_id=knowledge_id) + max_kb.info(_('Start--->Embedding document: {document_list}').format(document_list=document_list)) + for document in document_list: + ListenerManagement.embedding_by_document(document.id, embedding_model=embedding_model) + except Exception as e: + max_kb_error.error(_('Vectorized knowledge: {knowledge_id} error {error} {traceback}').format( + knowledge_id=knowledge_id, error=str(e), traceback=traceback.format_exc())) + finally: + max_kb.info(_('End--->Embedding knowledge: {knowledge_id}').format(knowledge_id=knowledge_id)) + + @staticmethod + def delete_embedding_by_document(document_id): + VectorStore.get_embedding_vector().delete_by_document_id(document_id) + + @staticmethod + def delete_embedding_by_document_list(document_id_list: List[str]): + VectorStore.get_embedding_vector().delete_by_document_id_list(document_id_list) + + @staticmethod + def delete_embedding_by_knowledge(knowledge_id): + VectorStore.get_embedding_vector().delete_by_knowledge_id(knowledge_id) + + @staticmethod + def delete_embedding_by_paragraph(paragraph_id): + VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id) + + @staticmethod + def delete_embedding_by_source(source_id): + VectorStore.get_embedding_vector().delete_by_source_id(source_id, SourceType.PROBLEM) + + @staticmethod + def disable_embedding_by_paragraph(paragraph_id): + VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': False}) + + @staticmethod + def enable_embedding_by_paragraph(paragraph_id): + VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True}) + + @staticmethod + def update_problem(args: UpdateProblemArgs): + problem_paragraph_mapping_list = 
QuerySet(ProblemParagraphMapping).filter(problem_id=args.problem_id) + embed_value = args.embedding_model.embed_query(args.problem_content) + VectorStore.get_embedding_vector().update_by_source_ids([v.id for v in problem_paragraph_mapping_list], + {'embedding': embed_value}) + + @staticmethod + def update_embedding_knowledge_id(args: UpdateEmbeddingKnowledgeIdArgs): + VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list, + {'knowledge_id': args.target_knowledge_id}) + + @staticmethod + def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs): + if args.target_embedding_model is None: + VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list, + {'document_id': args.target_document_id, + 'knowledge_id': args.target_knowledge_id}) + else: + ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list, + embedding_model=args.target_embedding_model) + + @staticmethod + def delete_embedding_by_source_ids(source_ids: List[str]): + VectorStore.get_embedding_vector().delete_by_source_ids(source_ids, SourceType.PROBLEM) + + @staticmethod + def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]): + VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_ids) + + @staticmethod + def delete_embedding_by_knowledge_id_list(source_ids: List[str]): + VectorStore.get_embedding_vector().delete_by_knowledge_id_list(source_ids) + + @staticmethod + def hit_test(query_text, knowledge_id: list[str], exclude_document_id_list: list[str], top_number: int, + similarity: float, + search_mode: SearchMode, + embedding: Embeddings): + return VectorStore.get_embedding_vector().hit_test(query_text, knowledge_id, exclude_document_id_list, top_number, + similarity, search_mode, embedding) diff --git a/apps/common/lock/base_lock.py b/apps/common/lock/base_lock.py new file mode 100644 index 00000000..2ca5b21d --- /dev/null +++ b/apps/common/lock/base_lock.py @@ -0,0 +1,20 @@ +# coding=utf-8 +""" + 
@project: MaxKB + @Author:虎 + @file: base_lock.py + @date:2024/8/20 10:33 + @desc: +""" + +from abc import ABC, abstractmethod + + +class BaseLock(ABC): + @abstractmethod + def try_lock(self, key, timeout): + pass + + @abstractmethod + def un_lock(self, key): + pass diff --git a/apps/common/lock/impl/file_lock.py b/apps/common/lock/impl/file_lock.py new file mode 100644 index 00000000..492728f9 --- /dev/null +++ b/apps/common/lock/impl/file_lock.py @@ -0,0 +1,77 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: file_lock.py + @date:2024/8/20 10:48 + @desc: +""" +import errno +import hashlib +import os +import time + +import six + +from common.lock.base_lock import BaseLock +from maxkb.const import PROJECT_DIR + + +def key_to_lock_name(key): + """ + Combine part of a key with its hash to prevent very long filenames + """ + MAX_LENGTH = 50 + key_hash = hashlib.md5(six.b(key)).hexdigest() + lock_name = key[:MAX_LENGTH - len(key_hash) - 1] + '_' + key_hash + return lock_name + + +class FileLock(BaseLock): + """ + File locking backend. + """ + + def __init__(self, settings=None): + if settings is None: + settings = {} + self.location = settings.get('location') + if self.location is None: + self.location = os.path.join(PROJECT_DIR, 'data', 'lock') + try: + os.makedirs(self.location) + except OSError as error: + # Directory exists? 
+ if error.errno != errno.EEXIST: + # Re-raise unexpected OSError + raise + + def _get_lock_path(self, key): + lock_name = key_to_lock_name(key) + return os.path.join(self.location, lock_name) + + def try_lock(self, key, timeout): + lock_path = self._get_lock_path(key) + try: + # 创建锁文件,如果没创建成功则拿不到 + fd = os.open(lock_path, os.O_CREAT | os.O_EXCL) + except OSError as error: + if error.errno == errno.EEXIST: + # File already exists, check its modification time + mtime = os.path.getmtime(lock_path) + ttl = mtime + timeout - time.time() + if ttl > 0: + return False + else: + # 如果超时时间已到,直接上锁成功继续执行 + os.utime(lock_path, None) + return True + else: + return False + else: + os.close(fd) + return True + + def un_lock(self, key): + lock_path = self._get_lock_path(key) + os.remove(lock_path) diff --git a/apps/common/utils/common.py b/apps/common/utils/common.py index 48acd33d..e4c6c344 100644 --- a/apps/common/utils/common.py +++ b/apps/common/utils/common.py @@ -124,6 +124,18 @@ def get_file_content(path): content = file.read() return content +def sub_array(array: List, item_num=10): + result = [] + temp = [] + for item in array: + temp.append(item) + if len(temp) >= item_num: + result.append(temp) + temp = [] + if len(temp) > 0: + result.append(temp) + return result + def bytes_to_uploaded_file(file_bytes, file_name="file.txt"): content_type, _ = mimetypes.guess_type(file_name) @@ -233,3 +245,15 @@ def valid_license(model=None, count=None, message=None): return run return inner + + +def post(post_function): + def inner(func): + def run(*args, **kwargs): + result = func(*args, **kwargs) + return post_function(*result) + + return run + + return inner + diff --git a/apps/common/utils/lock.py b/apps/common/utils/lock.py new file mode 100644 index 00000000..4276f1c6 --- /dev/null +++ b/apps/common/utils/lock.py @@ -0,0 +1,53 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: lock.py + @date:2023/9/11 11:45 + @desc: +""" +from datetime import timedelta + +from 
django.core.cache import caches + +memory_cache = caches['default'] + + +def try_lock(key: str, timeout=None): + """ + 获取锁 + :param key: 获取锁 key + :param timeout 超时时间 + :return: 是否获取到锁 + """ + return memory_cache.add(key, 'lock', timeout=timedelta(hours=1).total_seconds() if timeout is not None else timeout) + + +def un_lock(key: str): + """ + 解锁 + :param key: 解锁 key + :return: 是否解锁成功 + """ + return memory_cache.delete(key) + + +def lock(lock_key): + """ + 给一个函数上锁 + :param lock_key: 上锁key 字符串|函数 函数返回值为字符串 + :return: 装饰器函数 当前装饰器主要限制一个key只能一个线程去调用 相同key只能阻塞等待上一个任务执行完毕 不同key不需要等待 + """ + + def inner(func): + def run(*args, **kwargs): + key = lock_key(*args, **kwargs) if callable(lock_key) else lock_key + try: + if try_lock(key=key): + return func(*args, **kwargs) + finally: + un_lock(key=key) + + return run + + return inner diff --git a/apps/common/utils/page_utils.py b/apps/common/utils/page_utils.py new file mode 100644 index 00000000..61c52920 --- /dev/null +++ b/apps/common/utils/page_utils.py @@ -0,0 +1,47 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: page_utils.py + @date:2024/11/21 10:32 + @desc: +""" +from math import ceil + + +def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False): + """ + + @param query_set: 查询query_set + @param page_size: 每次查询大小 + @param handler: 数据处理器 + @param is_the_task_interrupted: 任务是否被中断 + @return: + """ + query = query_set.order_by("id") + count = query_set.count() + for i in range(0, ceil(count / page_size)): + if is_the_task_interrupted(): + return + offset = i * page_size + paragraph_list = query.all()[offset: offset + page_size] + handler(paragraph_list) + + +def page_desc(query_set, page_size, handler, is_the_task_interrupted=lambda: False): + """ + + @param query_set: 查询query_set + @param page_size: 每次查询大小 + @param handler: 数据处理器 + @param is_the_task_interrupted: 任务是否被中断 + @return: + """ + query = query_set.order_by("id") + count = query_set.count() + for i in sorted(range(0, ceil(count 
/ page_size)), reverse=True): + if is_the_task_interrupted(): + return + offset = i * page_size + paragraph_list = query.all()[offset: offset + page_size] + handler(paragraph_list) diff --git a/apps/common/utils/tool_code.py b/apps/common/utils/tool_code.py index 6eb44c6f..d0995868 100644 --- a/apps/common/utils/tool_code.py +++ b/apps/common/utils/tool_code.py @@ -3,7 +3,7 @@ import os import subprocess import sys -import uuid_utils as uuid +import uuid_utils.compat as uuid from textwrap import dedent from diskcache import Cache diff --git a/apps/common/utils/ts_vecto_util.py b/apps/common/utils/ts_vecto_util.py new file mode 100644 index 00000000..ce09306b --- /dev/null +++ b/apps/common/utils/ts_vecto_util.py @@ -0,0 +1,88 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: ts_vecto_util.py + @date:2024/4/16 15:26 + @desc: +""" +import re +import uuid_utils.compat as uuid +from typing import List + +import jieba +import jieba.posseg + +jieba_word_list_cache = [chr(item) for item in range(38, 84)] + +for jieba_word in jieba_word_list_cache: + jieba.add_word('#' + jieba_word + '#') +# r"(?i)\b(?:https?|ftp|tcp|file)://[^\s]+\b", +# 某些不分词数据 +# r'"([^"]*)"' +word_pattern_list = [r"v\d+.\d+.\d+", + r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}"] + +remove_chars = '\n , :\'<>!@#¥%……&*()!@#$%^&*(): ;,/"./' + +jieba_remove_flag_list = ['x', 'w'] + + +def get_word_list(text: str): + result = [] + for pattern in word_pattern_list: + word_list = re.findall(pattern, text) + for child_list in word_list: + for word in child_list if isinstance(child_list, tuple) else [child_list]: + # 不能有: 所以再使用: 进行分割 + if word.__contains__(':'): + item_list = word.split(":") + for w in item_list: + result.append(w) + else: + result.append(word) + return result + + +def replace_word(word_dict, text: str): + for key in word_dict: + pattern = '(? 
index else 'n') + self.task_status[_type] = _state + + @staticmethod + def of(status: str): + return Status(status) + + def __str__(self): + result = [] + for _type in sorted(self.type_cls, key=lambda item: item.value, reverse=True): + result.insert(len(self.type_cls) - _type.value, self.task_status[_type].value) + return ''.join(result) + + def __setitem__(self, key, value): + self.task_status[key] = value + + def __getitem__(self, item): + return self.task_status[item] + + def update_status(self, task_type: TaskType, state: State): + self.task_status[task_type] = state + + +def default_status_meta(): + return {"state_time": {}} + + def default_model(): # todo : 这里需要从数据库中获取默认的模型 return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab') @@ -33,7 +103,7 @@ class KnowledgeFolder(MPTTModel, AppModelMixin): name = models.CharField(max_length=64, verbose_name="文件夹名称") user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="用户id") workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True) - parent = TreeForeignKey('self', on_delete=models.CASCADE, null=True, blank=True, related_name='children') + parent = TreeForeignKey('self', on_delete=models.DO_NOTHING, null=True, blank=True, related_name='children') class Meta: db_table = "knowledge_folder" @@ -42,24 +112,127 @@ class KnowledgeFolder(MPTTModel, AppModelMixin): order_insertion_by = ['name'] - class Knowledge(AppModelMixin): + """ + 知识库表 + """ id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id") name = models.CharField(max_length=150, verbose_name="知识库名称") workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True) desc = models.CharField(max_length=256, verbose_name="描述") user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户") type = models.IntegerField(verbose_name='类型', choices=KnowledgeType.choices, default=KnowledgeType.BASE) 
- scope = models.CharField(max_length=20, verbose_name='可用范围', choices=KnowledgeScope.choices, default=KnowledgeScope.WORKSPACE) - folder = models.ForeignKey(KnowledgeFolder, on_delete=models.CASCADE, verbose_name="文件夹id", default='root') + scope = models.CharField(max_length=20, verbose_name='可用范围', choices=KnowledgeScope.choices, + default=KnowledgeScope.WORKSPACE) + folder = models.ForeignKey(KnowledgeFolder, on_delete=models.DO_NOTHING, verbose_name="文件夹id", default='root') embedding_model = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型", - default=default_model) + default=default_model) meta = models.JSONField(verbose_name="元数据", default=dict) class Meta: db_table = "knowledge" +class Document(AppModelMixin): + """ + 文档表 + """ + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id") + knowledge = models.ForeignKey(Knowledge, on_delete=models.DO_NOTHING, verbose_name="知识库id") + name = models.CharField(max_length=150, verbose_name="文档名称") + char_length = models.IntegerField(verbose_name="文档字符数 冗余字段") + status = models.CharField(verbose_name='状态', max_length=20, default=Status('').__str__) + status_meta = models.JSONField(verbose_name="状态统计数据", default=default_status_meta) + is_active = models.BooleanField(default=True) + type = models.IntegerField(verbose_name='类型', choices=KnowledgeType.choices, default=KnowledgeType.BASE) + hit_handling_method = models.CharField(verbose_name='命中处理方式', max_length=20, + choices=HitHandlingMethod.choices, + default=HitHandlingMethod.optimization) + directly_return_similarity = models.FloatField(verbose_name='直接回答相似度', default=0.9) + + meta = models.JSONField(verbose_name="元数据", default=dict) + + class Meta: + db_table = "document" + + +class Paragraph(AppModelMixin): + """ + 段落表 + """ + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id") + document = models.ForeignKey(Document, 
on_delete=models.DO_NOTHING, db_constraint=False) + knowledge = models.ForeignKey(Knowledge, on_delete=models.DO_NOTHING) + content = models.CharField(max_length=102400, verbose_name="段落内容") + title = models.CharField(max_length=256, verbose_name="标题", default="") + status = models.CharField(verbose_name='状态', max_length=20, default=Status('').__str__) + status_meta = models.JSONField(verbose_name="状态数据", default=default_status_meta) + hit_num = models.IntegerField(verbose_name="命中次数", default=0) + is_active = models.BooleanField(default=True) + + class Meta: + db_table = "paragraph" + + +class Problem(AppModelMixin): + """ + 问题表 + """ + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id") + knowledge = models.ForeignKey(Knowledge, on_delete=models.DO_NOTHING, db_constraint=False) + content = models.CharField(max_length=256, verbose_name="问题内容") + hit_num = models.IntegerField(verbose_name="命中次数", default=0) + + class Meta: + db_table = "problem" + + +class ProblemParagraphMapping(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id") + knowledge = models.ForeignKey(Knowledge, on_delete=models.DO_NOTHING, db_constraint=False) + document = models.ForeignKey(Document, on_delete=models.DO_NOTHING) + problem = models.ForeignKey(Problem, on_delete=models.DO_NOTHING, db_constraint=False) + paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False) + + class Meta: + db_table = "problem_paragraph_mapping" + + +class SourceType(models.IntegerChoices): + """订单类型""" + PROBLEM = 0, '问题' + PARAGRAPH = 1, '段落' + TITLE = 2, '标题' + + +class SearchMode(models.TextChoices): + embedding = 'embedding' + keywords = 'keywords' + blend = 'blend' + + +class VectorField(models.Field): + def db_type(self, connection): + return 'vector' + + +class Embedding(models.Model): + id = models.CharField(max_length=128, 
primary_key=True, verbose_name="主键id") + source_id = models.CharField(max_length=128, verbose_name="资源id") + source_type = models.CharField(verbose_name='资源类型', max_length=5, choices=SourceType.choices, + default=SourceType.PROBLEM) + is_active = models.BooleanField(verbose_name="是否可用", max_length=1, default=True) + knowledge = models.ForeignKey(Knowledge, on_delete=models.DO_NOTHING, verbose_name="文档关联", db_constraint=False) + document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, verbose_name="文档关联", db_constraint=False) + paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, verbose_name="段落关联", db_constraint=False) + embedding = VectorField(verbose_name="向量") + search_vector = SearchVectorField(verbose_name="分词", default="") + meta = models.JSONField(verbose_name="元数据", default=dict) + + class Meta: + db_table = "embedding" + + class File(AppModelMixin): id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id") file_name = models.CharField(max_length=256, verbose_name="文件名称", default="") diff --git a/apps/knowledge/serializers/__init__.py b/apps/knowledge/serializers/__init__.py index 9bad5790..bf893c06 100644 --- a/apps/knowledge/serializers/__init__.py +++ b/apps/knowledge/serializers/__init__.py @@ -1 +1 @@ -# coding=utf-8 +# coding=utf-8 \ No newline at end of file diff --git a/apps/knowledge/serializers/common.py b/apps/knowledge/serializers/common.py new file mode 100644 index 00000000..e3182643 --- /dev/null +++ b/apps/knowledge/serializers/common.py @@ -0,0 +1,215 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: common_serializers.py + @date:2023/11/17 11:00 + @desc: +""" +import os +import re +import uuid_utils.compat as uuid +import zipfile +from typing import List + +from django.db.models import QuerySet +from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers + +from common.config.embedding_config import 
ModelManage +from common.db.search import native_search +from common.db.sql_execute import update_execute +from common.exception.app_exception import AppApiException +from common.utils.common import get_file_content +from common.utils.fork import Fork +from knowledge.models import Paragraph, Problem, ProblemParagraphMapping, Knowledge, File +from maxkb.conf import PROJECT_DIR +from models_provider.tools import get_model + + +def zip_dir(zip_path, output=None): + output = output or os.path.basename(zip_path) + '.zip' + zip = zipfile.ZipFile(output, 'w', zipfile.ZIP_DEFLATED) + for root, dirs, files in os.walk(zip_path): + relative_root = '' if root == zip_path else root.replace(zip_path, '') + os.sep + for filename in files: + zip.write(os.path.join(root, filename), relative_root + filename) + zip.close() + + +def is_valid_uuid(s): + try: + uuid.UUID(s) + return True + except ValueError: + return False + + +def write_image(zip_path: str, image_list: List[str]): + for image in image_list: + search = re.search("\(.*\)", image) + if search: + text = search.group() + if text.startswith('(/api/file/'): + r = text.replace('(/api/file/', '').replace(')', '') + r = r.strip().split(" ")[0] + if not is_valid_uuid(r): + break + file = QuerySet(File).filter(id=r).first() + if file is None: + break + zip_inner_path = os.path.join('api', 'file', r) + file_path = os.path.join(zip_path, zip_inner_path) + if not os.path.exists(os.path.dirname(file_path)): + os.makedirs(os.path.dirname(file_path)) + with open(os.path.join(zip_path, file_path), 'wb') as f: + f.write(file.get_bytes()) + # else: + # r = text.replace('(/api/image/', '').replace(')', '') + # r = r.strip().split(" ")[0] + # if not is_valid_uuid(r): + # break + # image_model = QuerySet(Image).filter(id=r).first() + # if image_model is None: + # break + # zip_inner_path = os.path.join('api', 'image', r) + # file_path = os.path.join(zip_path, zip_inner_path) + # if not os.path.exists(os.path.dirname(file_path)): + # 
os.makedirs(os.path.dirname(file_path)) + # with open(file_path, 'wb') as f: + # f.write(image_model.image) + + +def update_document_char_length(document_id: str): + update_execute(get_file_content( + os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_document_char_length.sql')), + (document_id, document_id)) + + +def list_paragraph(paragraph_list: List[str]): + if paragraph_list is None or len(paragraph_list) == 0: + return [] + return native_search(QuerySet(Paragraph).filter(id__in=paragraph_list), get_file_content( + os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_paragraph.sql'))) + + +class MetaSerializer(serializers.Serializer): + class WebMeta(serializers.Serializer): + source_url = serializers.CharField(required=True, label=_('source url')) + selector = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('selector')) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + source_url = self.data.get('source_url') + response = Fork(source_url, []).fork() + if response.status == 500: + raise AppApiException(500, _('URL error, cannot parse [{source_url}]').format(source_url=source_url)) + + class BaseMeta(serializers.Serializer): + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + + +class BatchSerializer(serializers.Serializer): + id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list')) + + def is_valid(self, *, model=None, raise_exception=False): + super().is_valid(raise_exception=True) + if model is not None: + id_list = self.data.get('id_list') + model_list = QuerySet(model).filter(id__in=id_list) + if len(model_list) != len(id_list): + model_id_list = [str(m.id) for m in model_list] + error_id_list = list(filter(lambda row_id: not model_id_list.__contains__(row_id), id_list)) + raise AppApiException(500, _('The following id does not exist: {error_id_list}').format( + 
error_id_list=error_id_list)) + + +class ProblemParagraphObject: + def __init__(self, knowledge_id: str, document_id: str, paragraph_id: str, problem_content: str): + self.knowledge_id = knowledge_id + self.document_id = document_id + self.paragraph_id = paragraph_id + self.problem_content = problem_content + + +def or_get(exists_problem_list, content, knowledge_id, document_id, paragraph_id, problem_content_dict): + if content in problem_content_dict: + return problem_content_dict.get(content)[0], document_id, paragraph_id + exists = [row for row in exists_problem_list if row.content == content] + if len(exists) > 0: + problem_content_dict[content] = exists[0], False + return exists[0], document_id, paragraph_id + else: + problem = Problem(id=uuid.uuid7(), content=content, knowledge_id=knowledge_id) + problem_content_dict[content] = problem, True + return problem, document_id, paragraph_id + + +class ProblemParagraphManage: + def __init__(self, problem_paragraph_object_list: List[ProblemParagraphObject], knowledge_id): + self.knowledge_id = knowledge_id + self.problem_paragraph_object_list = problem_paragraph_object_list + + def to_problem_model_list(self): + problem_list = [item.problem_content for item in self.problem_paragraph_object_list] + exists_problem_list = [] + if len(self.problem_paragraph_object_list) > 0: + # 查询到已存在的问题列表 + exists_problem_list = QuerySet(Problem).filter(knowledge_id=self.knowledge_id, + content__in=problem_list).all() + problem_content_dict = {} + problem_model_list = [ + or_get(exists_problem_list, problemParagraphObject.problem_content, problemParagraphObject.knowledge_id, + problemParagraphObject.document_id, problemParagraphObject.paragraph_id, problem_content_dict) for + problemParagraphObject in self.problem_paragraph_object_list] + + problem_paragraph_mapping_list = [ + ProblemParagraphMapping(id=uuid.uuid7(), document_id=document_id, problem_id=problem_model.id, + paragraph_id=paragraph_id, + knowledge_id=self.knowledge_id) for 
+ problem_model, document_id, paragraph_id in problem_model_list] + + result = [problem_model for problem_model, is_create in problem_content_dict.values() if + is_create], problem_paragraph_mapping_list + return result + + +def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List): + knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list) + if len(set([knowledge.embedding_model_id for knowledge in knowledge_list])) > 1: + raise Exception(_('The knowledge base is inconsistent with the vector model')) + if len(knowledge_list) == 0: + raise Exception(_('Knowledge base setting error, please reset the knowledge base')) + return ModelManage.get_model(str(knowledge_list[0].embedding_model_id), + lambda _id: get_model(knowledge_list[0].embedding_model)) + + +def get_embedding_model_by_knowledge_id(knowledge_id: str): + knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first() + return ModelManage.get_model(str(knowledge.embedding_model_id), lambda _id: get_model(knowledge.embedding_model)) + + +def get_embedding_model_by_knowledge(knowledge): + return ModelManage.get_model(str(knowledge.embedding_model_id), lambda _id: get_model(knowledge.embedding_model)) + + +def get_embedding_model_id_by_knowledge_id(knowledge_id): + knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first() + return str(knowledge.embedding_model_id) + + +def get_embedding_model_id_by_knowledge_id_list(knowledge_id_list: List): + knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list) + if len(set([knowledge.embedding_model_id for knowledge in knowledge_list])) > 1: + raise Exception(_('The knowledge base is inconsistent with the vector model')) + if len(knowledge_list) == 0: + raise Exception(_('Knowledge base setting error, please reset the knowledge base')) + return str(knowledge_list[0].embedding_model_id) + + +class GenerateRelatedSerializer(serializers.Serializer): + model_id = 
serializers.UUIDField(required=True, label=_('Model id')) + prompt = serializers.CharField(required=True, label=_('Prompt word')) + state_list = serializers.ListField(required=False, child=serializers.CharField(required=True), + label=_("state list")) diff --git a/apps/knowledge/serializers/document.py b/apps/knowledge/serializers/document.py new file mode 100644 index 00000000..76ec4afe --- /dev/null +++ b/apps/knowledge/serializers/document.py @@ -0,0 +1,172 @@ +import os +from functools import reduce +from typing import Dict, List + +import uuid_utils.compat as uuid +from celery_once import AlreadyQueued +from django.db import transaction +from django.db.models import QuerySet, Model +from django.db.models.functions import Substr, Reverse +from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers + +from common.db.search import native_search +from common.event import ListenerManagement +from common.exception.app_exception import AppApiException +from common.utils.common import post, get_file_content +from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \ + TaskType +from knowledge.serializers.common import ProblemParagraphManage +from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer +from knowledge.task import embedding_by_document +from maxkb.const import PROJECT_DIR + + +class DocumentInstanceSerializer(serializers.Serializer): + name = serializers.CharField(required=True, label=_('document name'), max_length=128, min_length=1) + paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True) + + +class DocumentCreateRequest(serializers.Serializer): + name = serializers.CharField(required=True, label=_('knowledge name'), max_length=64, min_length=1) + desc = serializers.CharField(required=True, label=_('knowledge description'), max_length=256, min_length=1) + embedding_model_id = 
serializers.UUIDField(required=True, label=_('embedding model')) + documents = DocumentInstanceSerializer(required=False, many=True) + + +class DocumentSerializers(serializers.Serializer): + class Operate(serializers.Serializer): + document_id = serializers.UUIDField(required=True, label=_('document id')) + knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id')) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + document_id = self.data.get('document_id') + if not QuerySet(Document).filter(id=document_id).exists(): + raise AppApiException(500, _('document id not exist')) + + def one(self, with_valid=False): + if with_valid: + self.is_valid(raise_exception=True) + query_set = QuerySet(model=Document) + query_set = query_set.filter(**{'id': self.data.get("document_id")}) + return native_search({ + 'document_custom_sql': query_set, + 'order_by_query': QuerySet(Document).order_by('-create_time', 'id') + }, select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')), with_search_one=True) + + def refresh(self, state_list=None, with_valid=True): + if state_list is None: + state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value, + State.REVOKE.value, + State.REVOKED.value, State.IGNORED.value] + if with_valid: + self.is_valid(raise_exception=True) + knowledge = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id')).first() + embedding_model_id = knowledge.embedding_model_id + knowledge_user_id = knowledge.user_id + embedding_model = QuerySet(Model).filter(id=embedding_model_id).first() + if embedding_model is None: + raise AppApiException(500, _('Model does not exist')) + if embedding_model.permission_type == 'PRIVATE' and knowledge_user_id != embedding_model.user_id: + raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}") + document_id = self.data.get("document_id") + 
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, + State.PENDING) + ListenerManagement.update_status(QuerySet(Paragraph).annotate( + reversed_status=Reverse('status'), + task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value, 1), + ).filter(task_type_status__in=state_list, document_id=document_id).values('id'), + TaskType.EMBEDDING, State.PENDING) + ListenerManagement.get_aggregation_document_status(document_id)() + + try: + embedding_by_document.delay(document_id, embedding_model_id, state_list) + except AlreadyQueued as e: + raise AppApiException(500, _('The task is being executed, please do not send it repeatedly.')) + + class Create(serializers.Serializer): + knowledge_id = serializers.UUIDField(required=True, label=_('document id')) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + if not QuerySet(Knowledge).filter(id=self.data.get('knowledge_id')).exists(): + raise AppApiException(10000, _('knowledge id not exist')) + return True + + @staticmethod + def post_embedding(result, document_id, knowledge_id): + DocumentSerializers.Operate( + data={'knowledge_id': knowledge_id, 'document_id': document_id}).refresh() + return result + + @post(post_function=post_embedding) + @transaction.atomic + def save(self, instance: Dict, with_valid=False, **kwargs): + if with_valid: + DocumentCreateRequest(data=instance).is_valid(raise_exception=True) + self.is_valid(raise_exception=True) + knowledge_id = self.data.get('knowledge_id') + document_paragraph_model = self.get_document_paragraph_model(knowledge_id, instance) + + document_model = document_paragraph_model.get('document') + paragraph_model_list = document_paragraph_model.get('paragraph_model_list') + problem_paragraph_object_list = document_paragraph_model.get('problem_paragraph_object_list') + problem_model_list, problem_paragraph_mapping_list = ( + ProblemParagraphManage(problem_paragraph_object_list, 
knowledge_id).to_problem_model_list()) + # 插入文档 + document_model.save() + # 批量插入段落 + QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None + # 批量插入问题 + QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None + # 批量插入关联问题 + QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len( + problem_paragraph_mapping_list) > 0 else None + document_id = str(document_model.id) + return (DocumentSerializers.Operate( + data={'knowledge_id': knowledge_id, 'document_id': document_id} + ).one(with_valid=True), document_id, knowledge_id) + + @staticmethod + def get_paragraph_model(document_model, paragraph_list: List): + knowledge_id = document_model.knowledge_id + paragraph_model_dict_list = [ + ParagraphSerializers.Create( + data={ + 'knowledge_id': knowledge_id, 'document_id': str(document_model.id) + }).get_paragraph_problem_model(knowledge_id, document_model.id, paragraph) + for paragraph in paragraph_list] + + paragraph_model_list = [] + problem_paragraph_object_list = [] + for paragraphs in paragraph_model_dict_list: + paragraph = paragraphs.get('paragraph') + for problem_model in paragraphs.get('problem_paragraph_object_list'): + problem_paragraph_object_list.append(problem_model) + paragraph_model_list.append(paragraph) + + return { + 'document': document_model, + 'paragraph_model_list': paragraph_model_list, + 'problem_paragraph_object_list': problem_paragraph_object_list + } + + @staticmethod + def get_document_paragraph_model(knowledge_id, instance: Dict): + document_model = Document( + **{ + 'knowledge_id': knowledge_id, + 'id': uuid.uuid7(), + 'name': instance.get('name'), + 'char_length': reduce(lambda x, y: x + y, + [len(p.get('content')) for p in instance.get('paragraphs', [])], + 0), + 'meta': instance.get('meta') if instance.get('meta') is not None else {}, + 'type': instance.get('type') if instance.get('type') is not None else KnowledgeType.BASE + }) + + 
return DocumentSerializers.Create.get_paragraph_model(document_model, + instance.get('paragraphs') if + 'paragraphs' in instance else []) diff --git a/apps/knowledge/serializers/knowledge.py b/apps/knowledge/serializers/knowledge.py index 75e6e45f..ebceb522 100644 --- a/apps/knowledge/serializers/knowledge.py +++ b/apps/knowledge/serializers/knowledge.py @@ -1,14 +1,19 @@ +from functools import reduce from typing import Dict -import uuid_utils as uuid +import uuid_utils.compat as uuid from django.db import transaction from django.db.models import QuerySet from django.utils.translation import gettext_lazy as _ from rest_framework import serializers from common.exception.app_exception import AppApiException -from common.utils.common import valid_license -from knowledge.models import Knowledge, KnowledgeScope, KnowledgeType +from common.utils.common import valid_license, post +from knowledge.models import Knowledge, KnowledgeScope, KnowledgeType, Document, Paragraph, Problem, \ + ProblemParagraphMapping +from knowledge.serializers.common import ProblemParagraphManage, get_embedding_model_id_by_knowledge_id +from knowledge.serializers.document import DocumentSerializers +from knowledge.task import sync_web_knowledge, embedding_by_knowledge class KnowledgeModelSerializer(serializers.ModelSerializer): @@ -38,10 +43,17 @@ class KnowledgeSerializer(serializers.Serializer): user_id = serializers.UUIDField(required=True, label=_('user id')) workspace_id = serializers.CharField(required=True, label=_('workspace id')) + @staticmethod + def post_embedding_knowledge(document_list, knowledge_id): + # todo 发送向量化事件 + model_id = get_embedding_model_id_by_knowledge_id(knowledge_id) + embedding_by_knowledge.delay(knowledge_id, model_id) + return document_list + @valid_license(model=Knowledge, count=50, message=_( 'The community version supports up to 50 knowledge bases. 
If you need more knowledge bases, please contact us (https://fit2cloud.com/).')) - # @post(post_function=post_embedding_dataset) + @post(post_function=post_embedding_knowledge) @transaction.atomic def save_base(self, instance, with_valid=True): if with_valid: @@ -51,8 +63,9 @@ class KnowledgeSerializer(serializers.Serializer): name=instance.get('name')).exists(): raise AppApiException(500, _('Knowledge base name duplicate!')) + knowledge_id = uuid.uuid7() knowledge = Knowledge( - id=uuid.uuid7(), + id=knowledge_id, name=instance.get('name'), workspace_id=self.data.get('workspace_id'), desc=instance.get('desc'), @@ -63,8 +76,42 @@ class KnowledgeSerializer(serializers.Serializer): embedding_model_id=instance.get('embedding'), meta=instance.get('meta', {}), ) + + document_model_list = [] + paragraph_model_list = [] + problem_paragraph_object_list = [] + # 插入文档 + for document in instance.get('documents') if 'documents' in instance else []: + document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(knowledge_id, + document) + document_model_list.append(document_paragraph_dict_model.get('document')) + for paragraph in document_paragraph_dict_model.get('paragraph_model_list'): + paragraph_model_list.append(paragraph) + for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'): + problem_paragraph_object_list.append(problem_paragraph_object) + + problem_model_list, problem_paragraph_mapping_list = ( + ProblemParagraphManage(problem_paragraph_object_list, knowledge_id) + .to_problem_model_list()) + # 插入知识库 knowledge.save() - return KnowledgeModelSerializer(knowledge).data + # 插入文档 + QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None + # 批量插入段落 + QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None + # 批量插入问题 + QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None + # 批量插入关联问题 + 
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len( + problem_paragraph_mapping_list) > 0 else None + + return { + **KnowledgeModelSerializer(knowledge).data, + 'user_id': self.data.get('user_id'), + 'document_list': document_model_list, + "document_count": len(document_model_list), + "char_length": reduce(lambda x, y: x + y, [d.char_length for d in document_model_list], 0) + }, knowledge_id def save_web(self, instance: Dict, with_valid=True): if with_valid: @@ -92,9 +139,8 @@ class KnowledgeSerializer(serializers.Serializer): }, ) knowledge.save() - # sync_web_knowledge.delay(str(knowledge_id), instance.get('source_url'), instance.get('selector')) - return {**KnowledgeModelSerializer(knowledge).data, - 'document_list': []} + sync_web_knowledge.delay(str(knowledge_id), instance.get('source_url'), instance.get('selector')) + return {**KnowledgeModelSerializer(knowledge).data, 'document_list': []} class KnowledgeTreeSerializer(serializers.Serializer): diff --git a/apps/knowledge/serializers/paragraph.py b/apps/knowledge/serializers/paragraph.py new file mode 100644 index 00000000..a297e07d --- /dev/null +++ b/apps/knowledge/serializers/paragraph.py @@ -0,0 +1,221 @@ +# coding=utf-8 + +from typing import Dict + +import uuid_utils.compat as uuid +from django.db import transaction +from django.db.models import QuerySet, Count +from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers + +from common.exception.app_exception import AppApiException +from common.utils.common import post +from knowledge.models import Paragraph, Problem, Document, ProblemParagraphMapping +from knowledge.serializers.common import ProblemParagraphObject, ProblemParagraphManage, \ + get_embedding_model_id_by_knowledge_id, update_document_char_length +from knowledge.serializers.problem import ProblemInstanceSerializer +from knowledge.task import embedding_by_paragraph, enable_embedding_by_paragraph, 
disable_embedding_by_paragraph, \ + delete_embedding_by_paragraph + + +class ParagraphSerializer(serializers.ModelSerializer): + class Meta: + model = Paragraph + fields = ['id', 'content', 'is_active', 'document_id', 'title', 'create_time', 'update_time'] + + +class ParagraphInstanceSerializer(serializers.Serializer): + """ + 段落实例对象 + """ + content = serializers.CharField(required=True, label=_('content'), max_length=102400, min_length=1, allow_null=True, + allow_blank=True) + title = serializers.CharField(required=False, max_length=256, label=_('section title'), allow_null=True, + allow_blank=True) + problem_list = ProblemInstanceSerializer(required=False, many=True) + is_active = serializers.BooleanField(required=False, label=_('Is active')) + + +class EditParagraphSerializers(serializers.Serializer): + title = serializers.CharField(required=False, max_length=256, label=_('section title'), allow_null=True, + allow_blank=True) + content = serializers.CharField(required=False, max_length=102400, allow_null=True, allow_blank=True, + label=_('section title')) + problem_list = ProblemInstanceSerializer(required=False, many=True) + + +class ParagraphSerializers(serializers.Serializer): + title = serializers.CharField(required=False, max_length=256, label=_('section title'), allow_null=True, + allow_blank=True) + content = serializers.CharField(required=True, max_length=102400, label=_('section title')) + + class Operate(serializers.Serializer): + # 段落id + paragraph_id = serializers.UUIDField(required=True, label=_('paragraph id')) + # 知识库id + dataset_id = serializers.UUIDField(required=True, label=_('dataset id')) + # 文档id + document_id = serializers.UUIDField(required=True, label=_('document id')) + + def is_valid(self, *, raise_exception=True): + super().is_valid(raise_exception=True) + if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists(): + raise AppApiException(500, _('Paragraph id does not exist')) + + @staticmethod + def 
post_embedding(paragraph, instance, knowledge_id): + if 'is_active' in instance and instance.get('is_active') is not None: + (enable_embedding_by_paragraph if instance.get( + 'is_active') else disable_embedding_by_paragraph)(paragraph.get('id')) + + else: + model_id = get_embedding_model_id_by_knowledge_id(knowledge_id) + embedding_by_paragraph(paragraph.get('id'), model_id) + return paragraph + + @post(post_embedding) + @transaction.atomic + def edit(self, instance: Dict): + self.is_valid() + EditParagraphSerializers(data=instance).is_valid(raise_exception=True) + _paragraph = QuerySet(Paragraph).get(id=self.data.get("paragraph_id")) + update_keys = ['title', 'content', 'is_active'] + for update_key in update_keys: + if update_key in instance and instance.get(update_key) is not None: + _paragraph.__setattr__(update_key, instance.get(update_key)) + + if 'problem_list' in instance: + update_problem_list = list( + filter(lambda row: 'id' in row and row.get('id') is not None, instance.get('problem_list'))) + + create_problem_list = list(filter(lambda row: row.get('id') is None, instance.get('problem_list'))) + + # 问题集合 + problem_list = QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id")) + + # 校验前端 携带过来的id + for update_problem in update_problem_list: + if not set([str(row.id) for row in problem_list]).__contains__(update_problem.get('id')): + raise AppApiException(500, _('Problem id does not exist')) + # 对比需要删除的问题 + delete_problem_list = list(filter( + lambda row: not [str(update_row.get('id')) for update_row in update_problem_list].__contains__( + str(row.id)), problem_list)) if len(update_problem_list) > 0 else [] + # 删除问题 + QuerySet(Problem).filter(id__in=[row.id for row in delete_problem_list]).delete() if len( + delete_problem_list) > 0 else None + # 插入新的问题 + QuerySet(Problem).bulk_create( + [Problem(id=uuid.uuid1(), content=p.get('content'), paragraph_id=self.data.get('paragraph_id'), + dataset_id=self.data.get('dataset_id'), 
document_id=self.data.get('document_id')) for + p in create_problem_list]) if len(create_problem_list) else None + + # 修改问题集合 + QuerySet(Problem).bulk_update( + [Problem(id=row.get('id'), content=row.get('content')) for row in update_problem_list], + ['content']) if len( + update_problem_list) > 0 else None + + _paragraph.save() + update_document_char_length(self.data.get('document_id')) + return self.one(), instance, self.data.get('dataset_id') + + def get_problem_list(self): + ProblemParagraphMapping(ProblemParagraphMapping) + problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter( + paragraph_id=self.data.get("paragraph_id")) + if len(problem_paragraph_mapping) > 0: + return [ProblemSerializer(problem).data for problem in + QuerySet(Problem).filter(id__in=[ppm.problem_id for ppm in problem_paragraph_mapping])] + return [] + + def one(self, with_valid=False): + if with_valid: + self.is_valid(raise_exception=True) + return {**ParagraphSerializer(QuerySet(model=Paragraph).get(id=self.data.get('paragraph_id'))).data, + 'problem_list': self.get_problem_list()} + + def delete(self, with_valid=False): + if with_valid: + self.is_valid(raise_exception=True) + paragraph_id = self.data.get('paragraph_id') + Paragraph.objects.filter(id=paragraph_id).delete() + delete_problems_and_mappings([paragraph_id]) + + update_document_char_length(self.data.get('document_id')) + delete_embedding_by_paragraph(paragraph_id) + + class Create(serializers.Serializer): + knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id')) + document_id = serializers.UUIDField(required=True, label=_('document id')) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + if not QuerySet(Document).filter(id=self.data.get('document_id'), + knowledge_id=self.data.get('knowledge_id')).exists(): + raise AppApiException(500, _('The document id is incorrect')) + + def save(self, instance: Dict, with_valid=True, with_embedding=True): + if 
with_valid: + ParagraphSerializers(data=instance).is_valid(raise_exception=True) + self.is_valid() + knowledge_id = self.data.get("knowledge_id") + document_id = self.data.get('document_id') + paragraph_problem_model = self.get_paragraph_problem_model(knowledge_id, document_id, instance) + paragraph = paragraph_problem_model.get('paragraph') + problem_paragraph_object_list = paragraph_problem_model.get('problem_paragraph_object_list') + problem_model_list, problem_paragraph_mapping_list = ( + ProblemParagraphManage(problem_paragraph_object_list, knowledge_id) + .to_problem_model_list()) + # 插入段落 + paragraph_problem_model.get('paragraph').save() + # 插入問題 + QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None + # 插入问题关联关系 + QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len( + problem_paragraph_mapping_list) > 0 else None + # 修改长度 + update_document_char_length(document_id) + if with_embedding: + model_id = get_embedding_model_id_by_knowledge_id(knowledge_id) + embedding_by_paragraph(str(paragraph.id), model_id) + return ParagraphSerializers.Operate( + data={'paragraph_id': str(paragraph.id), 'knowledge_id': knowledge_id, 'document_id': document_id} + ).one(with_valid=True) + + @staticmethod + def get_paragraph_problem_model(knowledge_id: str, document_id: str, instance: Dict): + paragraph = Paragraph(id=uuid.uuid7(), + document_id=document_id, + content=instance.get("content"), + knowledge_id=knowledge_id, + title=instance.get("title") if 'title' in instance else '') + problem_paragraph_object_list = [ + ProblemParagraphObject(knowledge_id, document_id, paragraph.id, problem.get('content')) for problem in + (instance.get('problem_list') if 'problem_list' in instance else [])] + + return {'paragraph': paragraph, + 'problem_paragraph_object_list': problem_paragraph_object_list} + + @staticmethod + def or_get(exists_problem_list, content, knowledge_id): + exists = [row for row in exists_problem_list 
def delete_problems_and_mappings(paragraph_ids):
    """Delete problem<->paragraph mappings for *paragraph_ids* and any orphaned problems.

    A problem is deleted only when, after removing these mappings, no other
    paragraph still references it.

    :param paragraph_ids: iterable of paragraph primary keys.
    :return: None
    """
    mappings = ProblemParagraphMapping.objects.filter(paragraph_id__in=paragraph_ids)
    # Materialize the ids BEFORE deleting, or the queryset would be empty.
    problem_ids = set(mappings.values_list('problem_id', flat=True))
    mappings.delete()
    if problem_ids:
        # Problems still referenced by other paragraphs must survive; a plain
        # values_list of the surviving mappings suffices (the original's
        # Count annotation was never used beyond existence).
        still_referenced = set(
            ProblemParagraphMapping.objects.filter(problem_id__in=problem_ids)
            .values_list('problem_id', flat=True)
        )
        Problem.objects.filter(id__in=problem_ids - still_referenced).delete()
document_id="document"."id") as "paragraph_count" +FROM + "document" "document" +${document_custom_sql} +) temp +${order_by_query} \ No newline at end of file diff --git a/apps/knowledge/sql/list_knowledge.sql b/apps/knowledge/sql/list_knowledge.sql new file mode 100644 index 00000000..61b48eab --- /dev/null +++ b/apps/knowledge/sql/list_knowledge.sql @@ -0,0 +1,35 @@ +SELECT + *, + to_json(meta) as meta +FROM + ( + SELECT + "temp_knowledge".*, + "document_temp"."char_length", + CASE + WHEN + "app_knowledge_temp"."count" IS NULL THEN 0 ELSE "app_knowledge_temp"."count" END AS application_mapping_count, + "document_temp".document_count FROM ( + SELECT knowledge.* + FROM + knowledge knowledge + ${knowledge_custom_sql} + UNION + SELECT + * + FROM + knowledge + WHERE + knowledge."id" IN ( + SELECT + team_member_permission.target + FROM + team_member team_member + LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id" + ${team_member_permission_custom_sql} + ) + ) temp_knowledge + LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", knowledge_id FROM "document" GROUP BY knowledge_id ) "document_temp" ON temp_knowledge."id" = "document_temp".knowledge_id + LEFT JOIN (SELECT "count"("id"),knowledge_id FROM application_knowledge_mapping GROUP BY knowledge_id) app_knowledge_temp ON temp_knowledge."id" = "app_knowledge_temp".knowledge_id + ) temp + ${default_sql} \ No newline at end of file diff --git a/apps/knowledge/sql/list_knowledge_application.sql b/apps/knowledge/sql/list_knowledge_application.sql new file mode 100644 index 00000000..9da36a3c --- /dev/null +++ b/apps/knowledge/sql/list_knowledge_application.sql @@ -0,0 +1,20 @@ +SELECT + * +FROM + application +WHERE + user_id = %s UNION +SELECT + * +FROM + application +WHERE + "id" IN ( + SELECT + team_member_permission.target + FROM + team_member team_member + LEFT JOIN team_member_permission team_member_permission ON 
team_member_permission.member_id = team_member."id" + WHERE + ( "team_member_permission"."auth_target_type" = 'APPLICATION' AND "team_member_permission"."operate"::text[] @> ARRAY['USE'] AND team_member.team_id = %s AND team_member.user_id =%s ) + ) \ No newline at end of file diff --git a/apps/knowledge/sql/list_paragraph.sql b/apps/knowledge/sql/list_paragraph.sql new file mode 100644 index 00000000..366b7fe4 --- /dev/null +++ b/apps/knowledge/sql/list_paragraph.sql @@ -0,0 +1,6 @@ +SELECT + (SELECT "name" FROM "document" WHERE "id"=document_id) as document_name, + (SELECT "name" FROM "knowledge" WHERE "id"=knowledge_id) as knowledge_name, + * +FROM + "paragraph" diff --git a/apps/knowledge/sql/list_paragraph_document_name.sql b/apps/knowledge/sql/list_paragraph_document_name.sql new file mode 100644 index 00000000..a95209bf --- /dev/null +++ b/apps/knowledge/sql/list_paragraph_document_name.sql @@ -0,0 +1,5 @@ +SELECT + (SELECT "name" FROM "document" WHERE "id"=document_id) as document_name, + * +FROM + "paragraph" diff --git a/apps/knowledge/sql/list_problem.sql b/apps/knowledge/sql/list_problem.sql new file mode 100644 index 00000000..affb5133 --- /dev/null +++ b/apps/knowledge/sql/list_problem.sql @@ -0,0 +1,5 @@ +SELECT + problem.*, + (SELECT "count"("id") FROM "problem_paragraph_mapping" WHERE problem_id="problem"."id") as "paragraph_count" + FROM + problem problem diff --git a/apps/knowledge/sql/list_problem_mapping.sql b/apps/knowledge/sql/list_problem_mapping.sql new file mode 100644 index 00000000..8c8ac3c3 --- /dev/null +++ b/apps/knowledge/sql/list_problem_mapping.sql @@ -0,0 +1,2 @@ +SELECT "problem"."content",problem_paragraph_mapping.paragraph_id FROM problem problem +LEFT JOIN problem_paragraph_mapping problem_paragraph_mapping ON problem_paragraph_mapping.problem_id=problem."id" \ No newline at end of file diff --git a/apps/knowledge/sql/update_document_char_length.sql b/apps/knowledge/sql/update_document_char_length.sql new file mode 100644 
index 00000000..4a4060cd --- /dev/null +++ b/apps/knowledge/sql/update_document_char_length.sql @@ -0,0 +1,7 @@ +UPDATE "document" +SET "char_length" = ( SELECT CASE WHEN + "sum" ( "char_length" ( "content" ) ) IS NULL THEN + 0 ELSE "sum" ( "char_length" ( "content" ) ) + END FROM paragraph WHERE "document_id" = %s ) +WHERE + "id" = %s \ No newline at end of file diff --git a/apps/knowledge/sql/update_document_status_meta.sql b/apps/knowledge/sql/update_document_status_meta.sql new file mode 100644 index 00000000..6065931f --- /dev/null +++ b/apps/knowledge/sql/update_document_status_meta.sql @@ -0,0 +1,25 @@ +UPDATE "document" "document" +SET status_meta = jsonb_set ( "document".status_meta, '{aggs}', tmp.status_meta ) +FROM + ( + SELECT COALESCE + ( jsonb_agg ( jsonb_delete ( ( row_to_json ( record ) :: JSONB ), 'document_id' ) ), '[]' :: JSONB ) AS status_meta, + document_id AS document_id + FROM + ( + SELECT + "paragraph".status, + "count" ( "paragraph"."id" ), + "document"."id" AS document_id + FROM + "document" "document" + LEFT JOIN "paragraph" "paragraph" ON "document"."id" = paragraph.document_id + ${document_custom_sql} + GROUP BY + "paragraph".status, + "document"."id" + ) record + GROUP BY + document_id + ) tmp +WHERE "document".id="tmp".document_id \ No newline at end of file diff --git a/apps/knowledge/sql/update_paragraph_status.sql b/apps/knowledge/sql/update_paragraph_status.sql new file mode 100644 index 00000000..1e2fc6f0 --- /dev/null +++ b/apps/knowledge/sql/update_paragraph_status.sql @@ -0,0 +1,13 @@ +UPDATE "${table_name}" +SET status = reverse ( + SUBSTRING ( reverse ( LPAD( status, ${bit_number}, 'n' ) ) :: TEXT FROM 1 FOR ${up_index} ) || ${status_number} || SUBSTRING ( reverse ( LPAD( status, ${bit_number}, 'n' ) ) :: TEXT FROM ${next_index} ) +), +status_meta = jsonb_set ( + "${table_name}".status_meta, + '{state_time,${current_index}}', + jsonb_set ( + COALESCE ( "${table_name}".status_meta #> '{state_time,${current_index}}', 
def get_embedding_model(model_id, exception_handler=lambda e: max_kb_error.error(
    _('Failed to obtain vector model: {error} {traceback}').format(
        error=str(e),
        traceback=traceback.format_exc()
    ))):
    """Resolve (and cache via ModelManage) the embedding model for *model_id*.

    :param model_id: primary key of the Model row holding the embedding config.
    :param exception_handler: callback invoked with the exception before it is
        re-raised; defaults to logging the failure.
    :return: the instantiated embedding model.
    :raises Exception: whatever resolution raised, after exception_handler ran.
    """
    try:
        # NOTE(review): the row is fetched eagerly even on a cache hit; ModelManage
        # may need it for staleness checks -- confirm before moving the query into
        # the loader lambda.
        model = QuerySet(Model).filter(id=model_id).first()
        embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
    except Exception as e:
        exception_handler(e)
        # Bare `raise` preserves the original traceback (vs. `raise e`, which
        # appends the re-raise frame).
        raise
    return embedding_model
@celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']},
                 name='celery:embedding_by_knowledge')
def embedding_by_knowledge(knowledge_id, model_id):
    """Vectorize a whole knowledge base: wipe its vectors, then fan out one
    embedding task per document.

    :param knowledge_id: knowledge base id.
    :param model_id: embedding model id.
    :return: None
    """
    max_kb.info(_('Start--->Vectorized knowledge: {knowledge_id}').format(knowledge_id=knowledge_id))
    try:
        # Drop all existing vectors for this knowledge base before re-embedding.
        ListenerManagement.delete_embedding_by_knowledge(knowledge_id)
        document_list = QuerySet(Document).filter(knowledge_id=knowledge_id)
        max_kb.info(_('Knowledge documentation: {document_names}').format(
            document_names=", ".join([d.name for d in document_list])))
        for document in document_list:
            try:
                embedding_by_document.delay(document.id, model_id)
            except Exception as e:
                # Best-effort fan-out: log (instead of silently swallowing) and
                # keep dispatching the remaining documents.
                max_kb_error.error(f'Failed to dispatch embedding task for document {document.id}: '
                                   f'{str(e)} {traceback.format_exc()}')
    except Exception as e:
        # Bug fix: `.format()` must run on the *translated* template; formatting the
        # literal inside `_()` defeats the gettext catalog lookup.
        max_kb_error.error(
            _('Vectorized knowledge: {knowledge_id} error {error} {traceback}').format(
                knowledge_id=knowledge_id,
                error=str(e),
                traceback=traceback.format_exc()))
    finally:
        max_kb.info(_('End--->Vectorized knowledge: {knowledge_id}').format(knowledge_id=knowledge_id))
def update_problem_embedding(problem_id: str, problem_content: str, model_id):
    """Re-embed a problem after its content changed.

    :param problem_id: id of the problem to update.
    :param problem_content: the new problem text to embed.
    :param model_id: embedding model id.
    :return: None
    """
    embedding_model = get_embedding_model(model_id)
    args = UpdateProblemArgs(problem_id, problem_content, embedding_model)
    ListenerManagement.update_problem(args)
def delete_embedding_by_knowledge_id_list(knowledge_id_list):
    """Remove the vector rows of every knowledge base in *knowledge_id_list*.

    :param knowledge_id_list: ids of the knowledge bases whose vectors are dropped.
    :return: None
    """
    ListenerManagement.delete_embedding_by_knowledge_id_list(knowledge_id_list)
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save( + DocumentSerializers.Create(data={'knowledge_id': knowledge_id}).save( {'name': document_name, 'paragraphs': paragraphs, 'meta': {'source_url': child_link.url, 'selector': selector}, 'type': KnowledgeType.WEB}, with_valid=True) @@ -41,9 +35,9 @@ def get_save_handler(dataset_id, selector): return handler -def get_sync_handler(dataset_id): - from knowledge.serializers.document_serializers import DocumentSerializers - dataset = QuerySet(DataSet).filter(id=dataset_id).first() +def get_sync_handler(knowledge_id): + from knowledge.serializers import DocumentSerializers + knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first() def handler(child_link: ChildLink, response: Fork.Response): if response.status == 200: @@ -52,32 +46,31 @@ def get_sync_handler(dataset_id): document_name = child_link.tag.text if child_link.tag is not None and len( child_link.tag.text.strip()) > 0 else child_link.url paragraphs = get_split_model('web.md').parse(response.content) - first = QuerySet(Document).filter(meta__source_url=child_link.url.strip(), - dataset=dataset).first() + first = QuerySet(Document).filter(meta__source_url=child_link.url.strip(), knowledge=knowledge).first() if first is not None: # 如果存在,使用文档同步 DocumentSerializers.Sync(data={'document_id': first.id}).sync() else: # 插入 - DocumentSerializers.Create(data={'dataset_id': dataset.id}).save( + DocumentSerializers.Create(data={'knowledge_id': knowledge.id}).save( {'name': document_name, 'paragraphs': paragraphs, - 'meta': {'source_url': child_link.url.strip(), 'selector': dataset.meta.get('selector')}, - 'type': Type.web}, with_valid=True) + 'meta': {'source_url': child_link.url.strip(), 'selector': knowledge.meta.get('selector')}, + 'type': KnowledgeType.WEB}, with_valid=True) except Exception as e: logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') return handler -def get_sync_web_document_handler(dataset_id): - from 
knowledge.serializers.document_serializers import DocumentSerializers +def get_sync_web_document_handler(knowledge_id): + from knowledge.serializers import DocumentSerializers def handler(source_url: str, selector, response: Fork.Response): if response.status == 200: try: paragraphs = get_split_model('web.md').parse(response.content) # 插入 - DocumentSerializers.Create(data={'dataset_id': dataset_id}).save( + DocumentSerializers.Create(data={'knowledge_id': knowledge_id}).save( {'name': source_url[0:128], 'paragraphs': paragraphs, 'meta': {'source_url': source_url, 'selector': selector}, 'type': KnowledgeType.WEB}, with_valid=True) @@ -85,7 +78,7 @@ def get_sync_web_document_handler(dataset_id): logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') else: Document(name=source_url[0:128], - dataset_id=dataset_id, + knowledge_id=knowledge_id, meta={'source_url': source_url, 'selector': selector}, type=KnowledgeType.WEB, char_length=0, @@ -94,9 +87,9 @@ def get_sync_web_document_handler(dataset_id): return handler -def save_problem(dataset_id, document_id, paragraph_id, problem): - from knowledge.serializers.paragraph_serializers import ParagraphSerializers - # print(f"dataset_id: {dataset_id}") +def save_problem(knowledge_id, document_id, paragraph_id, problem): + from knowledge.serializers import ParagraphSerializers + # print(f"knowledge_id: {knowledge_id}") # print(f"document_id: {document_id}") # print(f"paragraph_id: {paragraph_id}") # print(f"problem: {problem}") @@ -108,7 +101,7 @@ def save_problem(dataset_id, document_id, paragraph_id, problem): return try: ParagraphSerializers.Problem( - data={"dataset_id": dataset_id, 'document_id': document_id, + data={"knowledge_id": knowledge_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).save(instance={"content": problem}, with_valid=True) except Exception as e: max_kb_error.error(_('Association problem failed {error}').format(error=str(e))) diff --git a/apps/knowledge/task/sync.py 
def chunk_data_list(data_list: List[Dict]):
    """Chunk every entry in *data_list* and return one flat list.

    Replaces the original ``reduce(lambda x, y: [*x, *y], ...)`` which rebuilt
    the accumulator list on every step (quadratic); a flat comprehension is O(n)
    and needs no extra imports.

    :param data_list: embedding payload dicts (see chunk_data).
    :return: flat list of chunked payload dicts.
    """
    return [chunk for data in data_list for chunk in chunk_data(data)]
""" + pass + + @abstractmethod + def vector_create(self): + """ + 创建 向量库 + :return: + """ + pass + + def save_pre_handler(self): + """ + 插入前置处理器 主要是判断向量库是否创建 + :return: True + """ + if not BaseVectorStore.vector_exists: + if not self.vector_is_create(): + self.vector_create() + BaseVectorStore.vector_exists = True + return True + + def save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str, source_id: str, + is_active: bool, + embedding: Embeddings): + """ + 插入向量数据 + :param source_id: 资源id + :param knowledge_id: 知识库id + :param text: 文本 + :param source_type: 资源类型 + :param document_id: 文档id + :param is_active: 是否禁用 + :param embedding: 向量化处理器 + :param paragraph_id 段落id + :return: bool + """ + self.save_pre_handler() + data = {'document_id': document_id, 'paragraph_id': paragraph_id, 'knowledge_id': knowledge_id, + 'is_active': is_active, 'source_id': source_id, 'source_type': source_type, 'text': text} + chunk_list = chunk_data(data) + result = sub_array(chunk_list) + for child_array in result: + self._batch_save(child_array, embedding, lambda: False) + + def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_the_task_interrupted): + """ + 批量插入 + @param data_list: 数据列表 + @param embedding: 向量化处理器 + @param is_the_task_interrupted: 判断是否中断任务 + :return: bool + """ + self.save_pre_handler() + chunk_list = chunk_data_list(data_list) + result = sub_array(chunk_list) + for child_array in result: + if not is_the_task_interrupted(): + self._batch_save(child_array, embedding, is_the_task_interrupted) + else: + break + + @abstractmethod + def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str, source_id: str, + is_active: bool, + embedding: Embeddings): + pass + + @abstractmethod + def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted): + pass + + def search(self, query_text, knowledge_id_list: list[str], exclude_document_id_list: list[str], 
+ exclude_paragraph_list: list[str], + is_active: bool, + embedding: Embeddings): + if knowledge_id_list is None or len(knowledge_id_list) == 0: + return [] + embedding_query = embedding.embed_query(query_text) + result = self.query(embedding_query, knowledge_id_list, exclude_document_id_list, exclude_paragraph_list, + is_active, 1, 3, 0.65) + return result[0] + + @abstractmethod + def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str], + exclude_document_id_list: list[str], + exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float, + search_mode: SearchMode): + pass + + @abstractmethod + def hit_test(self, query_text, knowledge_id: list[str], exclude_document_id_list: list[str], top_number: int, + similarity: float, + search_mode: SearchMode, + embedding: Embeddings): + pass + + @abstractmethod + def update_by_paragraph_id(self, paragraph_id: str, instance: Dict): + pass + + @abstractmethod + def update_by_paragraph_ids(self, paragraph_ids: str, instance: Dict): + pass + + @abstractmethod + def update_by_source_id(self, source_id: str, instance: Dict): + pass + + @abstractmethod + def update_by_source_ids(self, source_ids: List[str], instance: Dict): + pass + + @abstractmethod + def delete_by_knowledge_id(self, knowledge_id: str): + pass + + @abstractmethod + def delete_by_document_id(self, document_id: str): + pass + + @abstractmethod + def delete_by_document_id_list(self, document_id_list: List[str]): + pass + + @abstractmethod + def delete_by_knowledge_id_list(self, knowledge_id_list: List[str]): + pass + + @abstractmethod + def delete_by_source_id(self, source_id: str, source_type: str): + pass + + @abstractmethod + def delete_by_source_ids(self, source_ids: List[str], source_type: str): + pass + + @abstractmethod + def delete_by_paragraph_id(self, paragraph_id: str): + pass + + @abstractmethod + def delete_by_paragraph_ids(self, paragraph_ids: List[str]): + pass diff --git 
a/apps/knowledge/vector/pg_vector.py b/apps/knowledge/vector/pg_vector.py new file mode 100644 index 00000000..f6572c66 --- /dev/null +++ b/apps/knowledge/vector/pg_vector.py @@ -0,0 +1,222 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: pg_vector.py + @date:2023/10/19 15:28 + @desc: +""" +import json +import os +from abc import ABC, abstractmethod +from typing import Dict, List + +import uuid_utils.compat as uuid +from common.utils.ts_vecto_util import to_ts_vector, to_query +from django.contrib.postgres.search import SearchVector +from django.db.models import QuerySet, Value +from langchain_core.embeddings import Embeddings + +from common.db.search import generate_sql_by_query_dict +from common.db.sql_execute import select_list +from common.utils.common import get_file_content +from knowledge.models import Embedding, SearchMode, SourceType +from knowledge.vector.base_vector import BaseVectorStore +from maxkb.conf import PROJECT_DIR + + +class PGVector(BaseVectorStore): + + def delete_by_source_ids(self, source_ids: List[str], source_type: str): + if len(source_ids) == 0: + return + QuerySet(Embedding).filter(source_id__in=source_ids, source_type=source_type).delete() + + def update_by_source_ids(self, source_ids: List[str], instance: Dict): + QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance) + + def vector_is_create(self) -> bool: + # 项目启动默认是创建好的 不需要再创建 + return True + + def vector_create(self): + return True + + def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str, source_id: str, + is_active: bool, + embedding: Embeddings): + text_embedding = embedding.embed_query(text) + embedding = Embedding(id=uuid.uuid7(), + knowledge_id=knowledge_id, + document_id=document_id, + is_active=is_active, + paragraph_id=paragraph_id, + source_id=source_id, + embedding=text_embedding, + source_type=source_type, + search_vector=to_ts_vector(text)) + embedding.save() + return True + + def 
_batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted): + texts = [row.get('text') for row in text_list] + embeddings = embedding.embed_documents(texts) + embedding_list = [Embedding(id=uuid.uuid7(), + document_id=text_list[index].get('document_id'), + paragraph_id=text_list[index].get('paragraph_id'), + knowledge_id=text_list[index].get('knowledge_id'), + is_active=text_list[index].get('is_active', True), + source_id=text_list[index].get('source_id'), + source_type=text_list[index].get('source_type'), + embedding=embeddings[index], + search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))) for + index in + range(0, len(texts))] + if not is_the_task_interrupted(): + QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None + return True + + def hit_test(self, query_text, knowledge_id_list: list[str], exclude_document_id_list: list[str], top_number: int, + similarity: float, + search_mode: SearchMode, + embedding: Embeddings): + if knowledge_id_list is None or len(knowledge_id_list) == 0: + return [] + exclude_dict = {} + embedding_query = embedding.embed_query(query_text) + query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=True) + if exclude_document_id_list is not None and len(exclude_document_id_list) > 0: + exclude_dict.__setitem__('document_id__in', exclude_document_id_list) + query_set = query_set.exclude(**exclude_dict) + for search_handle in search_handle_list: + if search_handle.support(search_mode): + return search_handle.handle(query_set, query_text, embedding_query, top_number, similarity, search_mode) + + def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str], + exclude_document_id_list: list[str], + exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float, + search_mode: SearchMode): + exclude_dict = {} + if knowledge_id_list is None or len(knowledge_id_list) == 0: + return [] + 
query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=is_active) + if exclude_document_id_list is not None and len(exclude_document_id_list) > 0: + query_set = query_set.exclude(document_id__in=exclude_document_id_list) + if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0: + query_set = query_set.exclude(paragraph_id__in=exclude_paragraph_list) + query_set = query_set.exclude(**exclude_dict) + for search_handle in search_handle_list: + if search_handle.support(search_mode): + return search_handle.handle(query_set, query_text, query_embedding, top_n, similarity, search_mode) + + def update_by_source_id(self, source_id: str, instance: Dict): + QuerySet(Embedding).filter(source_id=source_id).update(**instance) + + def update_by_paragraph_id(self, paragraph_id: str, instance: Dict): + QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance) + + def update_by_paragraph_ids(self, paragraph_id: str, instance: Dict): + QuerySet(Embedding).filter(paragraph_id__in=paragraph_id).update(**instance) + + def delete_by_knowledge_id(self, knowledge_id: str): + QuerySet(Embedding).filter(knowledge_id=knowledge_id).delete() + + def delete_by_knowledge_id_list(self, knowledge_id_list: List[str]): + QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list).delete() + + def delete_by_document_id(self, document_id: str): + QuerySet(Embedding).filter(document_id=document_id).delete() + return True + + def delete_by_document_id_list(self, document_id_list: List[str]): + if len(document_id_list) == 0: + return True + return QuerySet(Embedding).filter(document_id__in=document_id_list).delete() + + def delete_by_source_id(self, source_id: str, source_type: str): + QuerySet(Embedding).filter(source_id=source_id, source_type=source_type).delete() + return True + + def delete_by_paragraph_id(self, paragraph_id: str): + QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete() + + def delete_by_paragraph_ids(self, 
paragraph_ids: List[str]): + QuerySet(Embedding).filter(paragraph_id__in=paragraph_ids).delete() + + +class ISearch(ABC): + @abstractmethod + def support(self, search_mode: SearchMode): + pass + + @abstractmethod + def handle(self, query_set, query_text, query_embedding, top_number: int, + similarity: float, search_mode: SearchMode): + pass + + +class EmbeddingSearch(ISearch): + def handle(self, + query_set, + query_text, + query_embedding, + top_number: int, + similarity: float, + search_mode: SearchMode): + exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set}, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', + 'embedding_search.sql')), + with_table_name=True) + embedding_model = select_list(exec_sql, + [json.dumps(query_embedding), *exec_params, similarity, top_number]) + return embedding_model + + def support(self, search_mode: SearchMode): + return search_mode.value == SearchMode.embedding.value + + +class KeywordsSearch(ISearch): + def handle(self, + query_set, + query_text, + query_embedding, + top_number: int, + similarity: float, + search_mode: SearchMode): + exec_sql, exec_params = generate_sql_by_query_dict({'keywords_query': query_set}, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', + 'keywords_search.sql')), + with_table_name=True) + embedding_model = select_list(exec_sql, + [to_query(query_text), *exec_params, similarity, top_number]) + return embedding_model + + def support(self, search_mode: SearchMode): + return search_mode.value == SearchMode.keywords.value + + +class BlendSearch(ISearch): + def handle(self, + query_set, + query_text, + query_embedding, + top_number: int, + similarity: float, + search_mode: SearchMode): + exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set}, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', + 'blend_search.sql')), + 
with_table_name=True) + embedding_model = select_list(exec_sql, + [json.dumps(query_embedding), to_query(query_text), *exec_params, similarity, + top_number]) + return embedding_model + + def support(self, search_mode: SearchMode): + return search_mode.value == SearchMode.blend.value + + +search_handle_list = [EmbeddingSearch(), KeywordsSearch(), BlendSearch()] diff --git a/apps/tools/migrations/0001_initial.py b/apps/tools/migrations/0001_initial.py index eaf20949..e6887994 100644 --- a/apps/tools/migrations/0001_initial.py +++ b/apps/tools/migrations/0001_initial.py @@ -36,7 +36,7 @@ class Migration(migrations.Migration): ('tree_id', models.PositiveIntegerField(db_index=True, editable=False)), ('level', models.PositiveIntegerField(editable=False)), ('parent', - mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, + mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING, related_name='children', to='tools.toolfolder')), ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user', verbose_name='用户id')), @@ -73,7 +73,7 @@ class Migration(migrations.Migration): ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user', verbose_name='用户id')), ('folder', - models.ForeignKey(default='root', on_delete=django.db.models.deletion.CASCADE, to='tools.toolfolder', + models.ForeignKey(default='root', on_delete=django.db.models.deletion.DO_NOTHING, to='tools.toolfolder', verbose_name='文件夹id')), ], options={ diff --git a/apps/tools/models/tool.py b/apps/tools/models/tool.py index 5cd20404..8ae633c1 100644 --- a/apps/tools/models/tool.py +++ b/apps/tools/models/tool.py @@ -12,7 +12,7 @@ class ToolFolder(MPTTModel, AppModelMixin): name = models.CharField(max_length=64, verbose_name="文件夹名称") user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="用户id") workspace_id = models.CharField(max_length=64, 
verbose_name="工作空间id", default="default", db_index=True) - parent = TreeForeignKey('self', on_delete=models.CASCADE, null=True, blank=True, related_name='children') + parent = TreeForeignKey('self', on_delete=models.DO_NOTHING, null=True, blank=True, related_name='children') class Meta: db_table = "tool_folder" @@ -46,7 +46,7 @@ class Tool(AppModelMixin): tool_type = models.CharField(max_length=20, verbose_name='工具类型', choices=ToolType.choices, default=ToolType.CUSTOM, db_index=True) template_id = models.UUIDField(max_length=128, verbose_name="模版id", null=True, default=None) - folder = models.ForeignKey(ToolFolder, on_delete=models.CASCADE, verbose_name="文件夹id", default='root') + folder = models.ForeignKey(ToolFolder, on_delete=models.DO_NOTHING, verbose_name="文件夹id", default='root') workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True) init_params = models.CharField(max_length=102400, verbose_name="初始化参数", null=True)