feat: add initial implementation of document and paragraph models with serializers

parent 8c362b0f99
commit 770089e432
apps/common/chunk/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: __init__.py
@date:2024/7/23 17:03
@desc:
"""
from common.chunk.impl.mark_chunk_handle import MarkChunkHandle

handles = [MarkChunkHandle()]


def text_to_chunk(text: str):
    chunk_list = [text]
    for handle in handles:
        chunk_list = handle.handle(chunk_list)
    return chunk_list
apps/common/chunk/i_chunk_handle.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: i_chunk_handle.py
@date:2024/7/23 16:51
@desc:
"""
from abc import ABC, abstractmethod
from typing import List


class IChunkHandle(ABC):
    @abstractmethod
    def handle(self, chunk_list: List[str]):
        pass
apps/common/chunk/impl/mark_chunk_handle.py (new file, 40 lines)
@@ -0,0 +1,40 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: mark_chunk_handle.py
@date:2024/7/23 16:52
@desc:
"""
import re
from typing import List

from common.chunk.i_chunk_handle import IChunkHandle

max_chunk_len = 256
split_chunk_pattern = r'.{1,%d}[。| |\\.|!|;|;|!|\n]' % max_chunk_len
max_chunk_pattern = r'.{1,%d}' % max_chunk_len


class MarkChunkHandle(IChunkHandle):
    def handle(self, chunk_list: List[str]):
        result = []
        for chunk in chunk_list:
            chunk_result = re.findall(split_chunk_pattern, chunk, flags=re.DOTALL)
            for c_r in chunk_result:
                if len(c_r.strip()) > 0:
                    result.append(c_r.strip())

            other_chunk_list = re.split(split_chunk_pattern, chunk, flags=re.DOTALL)
            for other_chunk in other_chunk_list:
                if len(other_chunk) > 0:
                    if len(other_chunk) < max_chunk_len:
                        if len(other_chunk.strip()) > 0:
                            result.append(other_chunk.strip())
                    else:
                        max_chunk_list = re.findall(max_chunk_pattern, other_chunk, flags=re.DOTALL)
                        for m_c in max_chunk_list:
                            if len(m_c.strip()) > 0:
                                result.append(m_c.strip())

        return result
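Note (not part of the commit): a minimal sketch of what the split pattern does. The regex greedily takes up to max_chunk_len characters ending at a separator, so short texts stay whole and long texts are cut at the last separator inside each window; the sample text below is illustrative.

    # demo: exercises the same pattern used by MarkChunkHandle
    import re

    max_chunk_len = 256
    split_chunk_pattern = r'.{1,%d}[。| |\\.|!|;|;|!|\n]' % max_chunk_len

    text = "短句。" * 200  # 600 chars, forces several windows
    chunks = re.findall(split_chunk_pattern, text, flags=re.DOTALL)
    print([len(c) for c in chunks])
    # Each chunk is at most max_chunk_len + 1 chars and ends at one of the
    # separators (。 space . ! ; ; ! | or newline); any unterminated tail is
    # recovered by re.split plus the max_chunk_pattern fallback in handle().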
@@ -47,20 +47,20 @@ class ModelManage:
         ModelManage.cache.delete(_id)
 
 
-# class VectorStore:
-#     from embedding.vector.pg_vector import PGVector
-#     from embedding.vector.base_vector import BaseVectorStore
-#     instance_map = {
-#         'pg_vector': PGVector,
-#     }
-#     instance = None
-#
-#     @staticmethod
-#     def get_embedding_vector() -> BaseVectorStore:
-#         from embedding.vector.pg_vector import PGVector
-#         if VectorStore.instance is None:
-#             from maxkb.const import CONFIG
-#             vector_store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"),
-#                                                               PGVector)
-#             VectorStore.instance = vector_store_class()
-#         return VectorStore.instance
+class VectorStore:
+    from knowledge.vector.pg_vector import PGVector
+    from knowledge.vector.base_vector import BaseVectorStore
+    instance_map = {
+        'pg_vector': PGVector,
+    }
+    instance = None
+
+    @staticmethod
+    def get_embedding_vector() -> BaseVectorStore:
+        from knowledge.vector.pg_vector import PGVector
+        if VectorStore.instance is None:
+            from maxkb.const import CONFIG
+            vector_store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"),
+                                                              PGVector)
+            VectorStore.instance = vector_store_class()
+        return VectorStore.instance
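Note (not part of the commit): a sketch of the lazy singleton behavior, assuming Django settings and the knowledge app are importable; the call site is illustrative.

    # The first call resolves the backend class from CONFIG's
    # VECTOR_STORE_NAME (falling back to PGVector) and caches the
    # instance on the class; later calls reuse it.
    store = VectorStore.get_embedding_vector()
    assert store is VectorStore.get_embedding_vector()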
apps/common/event/__init__.py (new file, 30 lines)
@@ -0,0 +1,30 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: __init__.py
@date:2023/11/10 10:43
@desc:
"""
from models_provider.models import Model, Status
from .listener_manage import *
from django.utils.translation import gettext as _

from ..db.sql_execute import update_execute
from common.lock.impl.file_lock import FileLock

lock = FileLock()
update_document_status_sql = """
UPDATE "public"."document"
SET status = "replace"("replace"("replace"(status, '1', '3'), '0', '3'), '4', '3')
WHERE status ~ '1|0|4'
"""


def run():
    if lock.try_lock('event_init', 30 * 30):
        try:
            QuerySet(Model).filter(status=Status.DOWNLOAD).update(
                status=Status.ERROR,
                meta={'message': _('The download process was interrupted, please try again')})
            update_execute(update_document_status_sql, [])
        finally:
            lock.un_lock('event_init')
apps/common/event/common.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: common.py
@date:2023/11/10 10:41
@desc:
"""
from concurrent.futures import ThreadPoolExecutor

from django.core.cache.backends.locmem import LocMemCache

work_thread_pool = ThreadPoolExecutor(5)

embedding_thread_pool = ThreadPoolExecutor(3)

memory_cache = LocMemCache('task', {"OPTIONS": {"MAX_ENTRIES": 1000}})


def poxy(poxy_function):
    def inner(args, **keywords):
        work_thread_pool.submit(poxy_function, args, **keywords)

    return inner


def get_cache_key(poxy_function, args):
    return poxy_function.__name__ + str(args)


def get_cache_poxy_function(poxy_function, cache_key):
    def fun(args, **keywords):
        try:
            poxy_function(args, **keywords)
        finally:
            memory_cache.delete(cache_key)

    return fun


def embedding_poxy(poxy_function):
    def inner(*args, **keywords):
        key = get_cache_key(poxy_function, args)
        if memory_cache.has_key(key):
            return
        memory_cache.add(key, None)
        f = get_cache_poxy_function(poxy_function, key)
        embedding_thread_pool.submit(f, args, **keywords)

    return inner
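Note (not part of the commit): a small sketch of the de-duplication embedding_poxy buys; the function body and argument are illustrative.

    # Two identical submissions, one execution: the cache key derived from
    # the function name and args blocks duplicates until the first finishes.
    @embedding_poxy
    def embed(args):
        doc_id, = args
        print('embedding', doc_id)

    embed('doc-1')  # queued on embedding_thread_pool
    embed('doc-1')  # dropped while the first is still queued/running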
apps/common/event/listener_manage.py (new file, 385 lines)
@@ -0,0 +1,385 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: listener_manage.py
@date:2023/10/20 14:01
@desc:
"""
import logging
import os
import threading
import datetime
import traceback
from typing import List

import django.db.models
from django.db.models import QuerySet
from django.db.models.functions import Substr, Reverse
from langchain_core.embeddings import Embeddings

from common.config.embedding_config import VectorStore
from common.db.search import native_search, get_dynamics_model, native_update
from common.utils.common import get_file_content
from common.utils.lock import try_lock, un_lock
from common.utils.page_utils import page_desc
from knowledge.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State, SourceType, \
    SearchMode
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _

max_kb_error = logging.getLogger(__file__)
max_kb = logging.getLogger(__file__)
lock = threading.Lock()


class SyncWebKnowledgeArgs:
    def __init__(self, lock_key: str, url: str, selector: str, handler):
        self.lock_key = lock_key
        self.url = url
        self.selector = selector
        self.handler = handler


class SyncWebDocumentArgs:
    def __init__(self, source_url_list: List[str], selector: str, handler):
        self.source_url_list = source_url_list
        self.selector = selector
        self.handler = handler


class UpdateProblemArgs:
    def __init__(self, problem_id: str, problem_content: str, embedding_model: Embeddings):
        self.problem_id = problem_id
        self.problem_content = problem_content
        self.embedding_model = embedding_model


class UpdateEmbeddingKnowledgeIdArgs:
    def __init__(self, paragraph_id_list: List[str], target_knowledge_id: str):
        self.paragraph_id_list = paragraph_id_list
        self.target_knowledge_id = target_knowledge_id


class UpdateEmbeddingDocumentIdArgs:
    def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_knowledge_id: str,
                 target_embedding_model: Embeddings = None):
        self.paragraph_id_list = paragraph_id_list
        self.target_document_id = target_document_id
        self.target_knowledge_id = target_knowledge_id
        self.target_embedding_model = target_embedding_model


class ListenerManagement:

    @staticmethod
    def embedding_by_problem(args, embedding_model: Embeddings):
        VectorStore.get_embedding_vector().save(**args, embedding=embedding_model)

    @staticmethod
    def embedding_by_paragraph_list(paragraph_id_list, embedding_model: Embeddings):
        try:
            data_list = native_search(
                {'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
                    **{'paragraph.id__in': paragraph_id_list}),
                 'paragraph': QuerySet(Paragraph).filter(id__in=paragraph_id_list)},
                select_string=get_file_content(
                    os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
            ListenerManagement.embedding_by_paragraph_data_list(data_list, paragraph_id_list=paragraph_id_list,
                                                                embedding_model=embedding_model)
        except Exception as e:
            max_kb_error.error(_('Query vector data: {paragraph_id_list} error {error} {traceback}').format(
                paragraph_id_list=paragraph_id_list, error=str(e), traceback=traceback.format_exc()))

    @staticmethod
    def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings):
        max_kb.info(_('Start--->Embedding paragraph: {paragraph_id_list}').format(paragraph_id_list=paragraph_id_list))
        status = Status.success
        try:
            # Delete the existing paragraph vectors
            VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_id_list)

            def is_save_function():
                return QuerySet(Paragraph).filter(id__in=paragraph_id_list).exists()

            # Batch embedding
            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
        except Exception as e:
            max_kb_error.error(_('Vectorized paragraph: {paragraph_id_list} error {error} {traceback}').format(
                paragraph_id_list=paragraph_id_list, error=str(e), traceback=traceback.format_exc()))
            status = Status.error
        finally:
            QuerySet(Paragraph).filter(id__in=paragraph_id_list).update(**{'status': status})
            max_kb.info(
                _('End--->Embedding paragraph: {paragraph_id_list}').format(paragraph_id_list=paragraph_id_list))

    @staticmethod
    def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
        """
        Embed a paragraph by its id
        @param paragraph_id: paragraph id
        @param embedding_model: embedding model
        """
        max_kb.info(_('Start--->Embedding paragraph: {paragraph_id}').format(paragraph_id=paragraph_id))
        # Update to the started state
        ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
                                         State.STARTED)
        try:
            data_list = native_search(
                {'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
                    **{'paragraph.id': paragraph_id}),
                 'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)},
                select_string=get_file_content(
                    os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
            # Delete the existing paragraph vectors
            VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)

            def is_the_task_interrupted():
                _paragraph = QuerySet(Paragraph).filter(id=paragraph_id).first()
                if _paragraph is None or Status(_paragraph.status)[TaskType.EMBEDDING] == State.REVOKE:
                    return True
                return False

            # Batch embedding
            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_the_task_interrupted)
            # Update to the success state
            ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
                                             State.SUCCESS)
        except Exception as e:
            max_kb_error.error(_('Vectorized paragraph: {paragraph_id} error {error} {traceback}').format(
                paragraph_id=paragraph_id, error=str(e), traceback=traceback.format_exc()))
            ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
                                             State.FAILURE)
        finally:
            max_kb.info(_('End--->Embedding paragraph: {paragraph_id}').format(paragraph_id=paragraph_id))

    @staticmethod
    def embedding_by_data_list(data_list: List, embedding_model: Embeddings):
        # Batch embedding
        VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, lambda: True)

    @staticmethod
    def get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, post_apply=lambda: None):
        def embedding_paragraph_apply(paragraph_list):
            for paragraph in paragraph_list:
                if is_the_task_interrupted():
                    break
                ListenerManagement.embedding_by_paragraph(str(paragraph.get('id')), embedding_model)
            post_apply()

        return embedding_paragraph_apply

    @staticmethod
    def get_aggregation_document_status(document_id):
        def aggregation_document_status():
            sql = get_file_content(
                os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_document_status_meta.sql'))
            native_update({'document_custom_sql': QuerySet(Document).filter(id=document_id)}, sql,
                          with_table_name=True)

        return aggregation_document_status

    @staticmethod
    def get_aggregation_document_status_by_knowledge_id(knowledge_id):
        def aggregation_document_status():
            sql = get_file_content(
                os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_document_status_meta.sql'))
            native_update({'document_custom_sql': QuerySet(Document).filter(knowledge_id=knowledge_id)}, sql,
                          with_table_name=True)

        return aggregation_document_status

    @staticmethod
    def get_aggregation_document_status_by_query_set(queryset):
        def aggregation_document_status():
            sql = get_file_content(
                os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_document_status_meta.sql'))
            native_update({'document_custom_sql': queryset}, sql, with_table_name=True)

        return aggregation_document_status

    @staticmethod
    def post_update_document_status(document_id, task_type: TaskType):
        _document = QuerySet(Document).filter(id=document_id).first()

        status = Status(_document.status)
        if status[task_type] == State.REVOKE:
            status[task_type] = State.REVOKED
        else:
            status[task_type] = State.SUCCESS
        for item in _document.status_meta.get('aggs', []):
            agg_status = item.get('status')
            agg_count = item.get('count')
            if Status(agg_status)[task_type] == State.FAILURE and agg_count > 0:
                status[task_type] = State.FAILURE
        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), task_type, status[task_type])

        ListenerManagement.update_status(QuerySet(Paragraph).annotate(
            reversed_status=Reverse('status'),
            task_type_status=Substr('reversed_status', task_type.value,
                                    task_type.value),
        ).filter(task_type_status=State.REVOKE.value).filter(document_id=document_id).values('id'),
                                         task_type,
                                         State.REVOKED)

    @staticmethod
    def update_status(query_set: QuerySet, taskType: TaskType, state: State):
        exec_sql = get_file_content(
            os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_paragraph_status.sql'))
        bit_number = len(TaskType)
        up_index = taskType.value - 1
        next_index = taskType.value + 1
        current_index = taskType.value
        status_number = state.value
        current_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') + '+00'
        params_dict = {'${bit_number}': bit_number, '${up_index}': up_index,
                       '${status_number}': status_number, '${next_index}': next_index,
                       '${table_name}': query_set.model._meta.db_table, '${current_index}': current_index,
                       '${current_time}': current_time}
        for key in params_dict:
            _value_ = params_dict[key]
            exec_sql = exec_sql.replace(key, str(_value_))
        lock.acquire()
        try:
            native_update(query_set, exec_sql)
        finally:
            lock.release()

    @staticmethod
    def embedding_by_document(document_id, embedding_model: Embeddings, state_list=None):
        """
        Embed a document
        @param state_list: paragraph states eligible for (re)embedding
        @param document_id: document id
        @param embedding_model: embedding model
        :return: None
        """
        if state_list is None:
            state_list = [State.PENDING, State.SUCCESS, State.FAILURE, State.REVOKE, State.REVOKED]
        if not try_lock('embedding' + str(document_id)):
            return
        try:
            def is_the_task_interrupted():
                document = QuerySet(Document).filter(id=document_id).first()
                if document is None or Status(document.status)[TaskType.EMBEDDING] == State.REVOKE:
                    return True
                return False

            if is_the_task_interrupted():
                return
            max_kb.info(_('Start--->Embedding document: {document_id}').format(document_id=document_id))
            # Batch update the document status to STARTED
            ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
                                             State.STARTED)

            # Embed paragraph by paragraph
            page_desc(QuerySet(Paragraph)
                      .annotate(
                reversed_status=Reverse('status'),
                task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value,
                                        1),
            ).filter(task_type_status__in=state_list, document_id=document_id)
                      .values('id'), 5,
                      ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
                                                                       ListenerManagement.get_aggregation_document_status(
                                                                           document_id)),
                      is_the_task_interrupted)
        except Exception as e:
            max_kb_error.error(_('Vectorized document: {document_id} error {error} {traceback}').format(
                document_id=document_id, error=str(e), traceback=traceback.format_exc()))
        finally:
            ListenerManagement.post_update_document_status(document_id, TaskType.EMBEDDING)
            ListenerManagement.get_aggregation_document_status(document_id)()
            max_kb.info(_('End--->Embedding document: {document_id}').format(document_id=document_id))
            un_lock('embedding' + str(document_id))

    @staticmethod
    def embedding_by_knowledge(knowledge_id, embedding_model: Embeddings):
        """
        Embed a whole knowledge base
        @param knowledge_id: knowledge base id
        @param embedding_model: embedding model
        :return: None
        """
        max_kb.info(_('Start--->Embedding knowledge: {knowledge_id}').format(knowledge_id=knowledge_id))
        try:
            ListenerManagement.delete_embedding_by_knowledge(knowledge_id)
            document_list = QuerySet(Document).filter(knowledge_id=knowledge_id)
            max_kb.info(_('Start--->Embedding document: {document_list}').format(document_list=document_list))
            for document in document_list:
                ListenerManagement.embedding_by_document(document.id, embedding_model=embedding_model)
        except Exception as e:
            max_kb_error.error(_('Vectorized knowledge: {knowledge_id} error {error} {traceback}').format(
                knowledge_id=knowledge_id, error=str(e), traceback=traceback.format_exc()))
        finally:
            max_kb.info(_('End--->Embedding knowledge: {knowledge_id}').format(knowledge_id=knowledge_id))

    @staticmethod
    def delete_embedding_by_document(document_id):
        VectorStore.get_embedding_vector().delete_by_document_id(document_id)

    @staticmethod
    def delete_embedding_by_document_list(document_id_list: List[str]):
        VectorStore.get_embedding_vector().delete_by_document_id_list(document_id_list)

    @staticmethod
    def delete_embedding_by_knowledge(knowledge_id):
        VectorStore.get_embedding_vector().delete_by_knowledge_id(knowledge_id)

    @staticmethod
    def delete_embedding_by_paragraph(paragraph_id):
        VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)

    @staticmethod
    def delete_embedding_by_source(source_id):
        VectorStore.get_embedding_vector().delete_by_source_id(source_id, SourceType.PROBLEM)

    @staticmethod
    def disable_embedding_by_paragraph(paragraph_id):
        VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': False})

    @staticmethod
    def enable_embedding_by_paragraph(paragraph_id):
        VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True})

    @staticmethod
    def update_problem(args: UpdateProblemArgs):
        problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(problem_id=args.problem_id)
        embed_value = args.embedding_model.embed_query(args.problem_content)
        VectorStore.get_embedding_vector().update_by_source_ids([v.id for v in problem_paragraph_mapping_list],
                                                                {'embedding': embed_value})

    @staticmethod
    def update_embedding_knowledge_id(args: UpdateEmbeddingKnowledgeIdArgs):
        VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
                                                                   {'knowledge_id': args.target_knowledge_id})

    @staticmethod
    def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs):
        if args.target_embedding_model is None:
            VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
                                                                       {'document_id': args.target_document_id,
                                                                        'knowledge_id': args.target_knowledge_id})
        else:
            ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
                                                           embedding_model=args.target_embedding_model)

    @staticmethod
    def delete_embedding_by_source_ids(source_ids: List[str]):
        VectorStore.get_embedding_vector().delete_by_source_ids(source_ids, SourceType.PROBLEM)

    @staticmethod
    def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]):
        VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_ids)

    @staticmethod
    def delete_embedding_by_knowledge_id_list(source_ids: List[str]):
        VectorStore.get_embedding_vector().delete_by_knowledge_id_list(source_ids)

    @staticmethod
    def hit_test(query_text, knowledge_id: list[str], exclude_document_id_list: list[str], top_number: int,
                 similarity: float,
                 search_mode: SearchMode,
                 embedding: Embeddings):
        return VectorStore.get_embedding_vector().hit_test(query_text, knowledge_id, exclude_document_id_list,
                                                           top_number, similarity, search_mode, embedding)
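Note (not part of the commit): a pure-Python sketch of the Reverse/Substr trick used above, assuming knowledge.models is importable; the status string is illustrative. See the Status class introduced further down in this commit.

    # The packed status string stores TaskType 1 in the LAST character,
    # so it is reversed before indexing; Substr(reversed, n, 1) then
    # reads the state character for task n (Substr is 1-indexed).
    status = '2n1'                  # SYNC='2', GENERATE_PROBLEM='n', EMBEDDING='1'
    reversed_status = status[::-1]  # '1n2'
    print(reversed_status[TaskType.EMBEDDING.value - 1])  # '1' -> State.STARTED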
apps/common/lock/base_lock.py (new file, 20 lines)
@@ -0,0 +1,20 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: base_lock.py
@date:2024/8/20 10:33
@desc:
"""

from abc import ABC, abstractmethod


class BaseLock(ABC):
    @abstractmethod
    def try_lock(self, key, timeout):
        pass

    @abstractmethod
    def un_lock(self, key):
        pass
apps/common/lock/impl/file_lock.py (new file, 77 lines)
@@ -0,0 +1,77 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: file_lock.py
@date:2024/8/20 10:48
@desc:
"""
import errno
import hashlib
import os
import time

import six

from common.lock.base_lock import BaseLock
from maxkb.const import PROJECT_DIR


def key_to_lock_name(key):
    """
    Combine part of a key with its hash to prevent very long filenames
    """
    MAX_LENGTH = 50
    key_hash = hashlib.md5(six.b(key)).hexdigest()
    lock_name = key[:MAX_LENGTH - len(key_hash) - 1] + '_' + key_hash
    return lock_name


class FileLock(BaseLock):
    """
    File locking backend.
    """

    def __init__(self, settings=None):
        if settings is None:
            settings = {}
        self.location = settings.get('location')
        if self.location is None:
            self.location = os.path.join(PROJECT_DIR, 'data', 'lock')
        try:
            os.makedirs(self.location)
        except OSError as error:
            # Directory exists?
            if error.errno != errno.EEXIST:
                # Re-raise unexpected OSError
                raise

    def _get_lock_path(self, key):
        lock_name = key_to_lock_name(key)
        return os.path.join(self.location, lock_name)

    def try_lock(self, key, timeout):
        lock_path = self._get_lock_path(key)
        try:
            # Create the lock file exclusively; if creation fails,
            # another process already holds the lock.
            fd = os.open(lock_path, os.O_CREAT | os.O_EXCL)
        except OSError as error:
            if error.errno == errno.EEXIST:
                # File already exists, check its modification time
                mtime = os.path.getmtime(lock_path)
                ttl = mtime + timeout - time.time()
                if ttl > 0:
                    return False
                else:
                    # The previous holder timed out: take the lock over
                    # and continue.
                    os.utime(lock_path, None)
                    return True
            else:
                return False
        else:
            os.close(fd)
            return True

    def un_lock(self, key):
        lock_path = self._get_lock_path(key)
        os.remove(lock_path)
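Note (not part of the commit): a usage sketch mirroring how apps/common/event/__init__.py consumes this class; the key and timeout are illustrative.

    # Serialize a startup task across processes sharing the same
    # data/lock directory; a stale lock older than `timeout` seconds
    # is taken over by the next caller.
    file_lock = FileLock()
    if file_lock.try_lock('event_init', timeout=900):
        try:
            ...  # runs in exactly one process at a time
        finally:
            file_lock.un_lock('event_init')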
@@ -124,6 +124,18 @@ def get_file_content(path):
         content = file.read()
     return content
 
+
+def sub_array(array: List, item_num=10):
+    result = []
+    temp = []
+    for item in array:
+        temp.append(item)
+        if len(temp) >= item_num:
+            result.append(temp)
+            temp = []
+    if len(temp) > 0:
+        result.append(temp)
+    return result
 
 
 def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
     content_type, _ = mimetypes.guess_type(file_name)
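Note (not part of the commit): a quick check of sub_array's batching behavior.

    print(sub_array([1, 2, 3, 4, 5, 6, 7], item_num=3))
    # [[1, 2, 3], [4, 5, 6], [7]] - a trailing partial batch is kept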
@@ -233,3 +245,15 @@ def valid_license(model=None, count=None, message=None):
             return run
 
     return inner
+
+
+def post(post_function):
+    def inner(func):
+        def run(*args, **kwargs):
+            result = func(*args, **kwargs)
+            return post_function(*result)
+
+        return run
+
+    return inner
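Note (not part of the commit): a small sketch of the @post combinator; the names are illustrative. The decorated function's returned tuple is splatted into post_function, which is how DocumentSerializers.Create chains embedding after creation later in this commit.

    def announce(total, count):
        print(f'sum={total} over {count} values')
        return total

    @post(post_function=announce)
    def summarize(values):
        return sum(values), len(values)  # tuple is unpacked into announce(...)

    summarize([1, 2, 3])  # prints 'sum=6 over 3 values', returns 6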
apps/common/utils/lock.py (new file, 53 lines)
@@ -0,0 +1,53 @@
# coding=utf-8
"""
@project: qabot
@Author:虎
@file: lock.py
@date:2023/9/11 11:45
@desc:
"""
from datetime import timedelta

from django.core.cache import caches

memory_cache = caches['default']


def try_lock(key: str, timeout=None):
    """
    Acquire a lock
    :param key: lock key
    :param timeout: expiry in seconds
    :return: whether the lock was acquired
    """
    # Default to a one-hour expiry when no timeout is supplied.
    return memory_cache.add(key, 'lock',
                            timeout=timedelta(hours=1).total_seconds() if timeout is None else timeout)


def un_lock(key: str):
    """
    Release a lock
    :param key: lock key
    :return: whether the key was deleted
    """
    return memory_cache.delete(key)


def lock(lock_key):
    """
    Serialize calls to a function by key
    :param lock_key: lock key, either a string or a callable on the call's
        arguments that returns a string
    :return: decorator; only one caller at a time may run for a given key,
        while calls under different keys proceed independently
    """

    def inner(func):
        def run(*args, **kwargs):
            key = lock_key(*args, **kwargs) if callable(lock_key) else lock_key
            try:
                if try_lock(key=key):
                    return func(*args, **kwargs)
            finally:
                un_lock(key=key)

        return run

    return inner
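Note (not part of the commit): a usage sketch of the @lock decorator; the key function is illustrative.

    # Concurrent calls that map to the same key are skipped while one is
    # in flight; calls under different keys run independently.
    @lock(lock_key=lambda document_id: 'embedding' + str(document_id))
    def refresh(document_id):
        ...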
apps/common/utils/page_utils.py (new file, 47 lines)
@@ -0,0 +1,47 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: page_utils.py
@date:2024/11/21 10:32
@desc:
"""
from math import ceil


def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
    """
    @param query_set: queryset to page over
    @param page_size: rows fetched per page
    @param handler: called with each page of rows
    @param is_the_task_interrupted: returns True if the task was interrupted
    @return:
    """
    query = query_set.order_by("id")
    count = query_set.count()
    for i in range(0, ceil(count / page_size)):
        if is_the_task_interrupted():
            return
        offset = i * page_size
        paragraph_list = query.all()[offset: offset + page_size]
        handler(paragraph_list)


def page_desc(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
    """
    @param query_set: queryset to page over
    @param page_size: rows fetched per page
    @param handler: called with each page of rows
    @param is_the_task_interrupted: returns True if the task was interrupted
    @return:
    """
    query = query_set.order_by("id")
    count = query_set.count()
    for i in sorted(range(0, ceil(count / page_size)), reverse=True):
        if is_the_task_interrupted():
            return
        offset = i * page_size
        paragraph_list = query.all()[offset: offset + page_size]
        handler(paragraph_list)
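Note (not part of the commit): a self-contained sketch of the visiting order; the list-backed stand-in below only fakes the queryset methods these helpers touch.

    class FakeQuerySet(list):
        def order_by(self, *_): return self
        def count(self): return len(self)
        def all(self): return self

    rows = FakeQuerySet(range(12))
    page(rows, 5, print)       # [0..4], [5..9], [10, 11]
    page_desc(rows, 5, print)  # [10, 11], [5..9], [0..4] - last page first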
@@ -3,7 +3,7 @@
 import os
 import subprocess
 import sys
-import uuid_utils as uuid
+import uuid_utils.compat as uuid
 from textwrap import dedent
 
 from diskcache import Cache
apps/common/utils/ts_vecto_util.py (new file, 88 lines)
@@ -0,0 +1,88 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: ts_vecto_util.py
@date:2024/4/16 15:26
@desc:
"""
import re
import uuid_utils.compat as uuid
from typing import List

import jieba
import jieba.posseg

jieba_word_list_cache = [chr(item) for item in range(38, 84)]

for jieba_word in jieba_word_list_cache:
    jieba.add_word('#' + jieba_word + '#')

# Patterns whose matches must be kept whole by the tokenizer, e.g.:
# r"(?i)\b(?:https?|ftp|tcp|file)://[^\s]+\b",
# r'"([^"]*)"'
word_pattern_list = [r"v\d+.\d+.\d+",
                     r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}"]

remove_chars = '\n , :\'<>!@#¥%……&*()!@#$%^&*(): ;,/"./'

jieba_remove_flag_list = ['x', 'w']


def get_word_list(text: str):
    result = []
    for pattern in word_pattern_list:
        word_list = re.findall(pattern, text)
        for child_list in word_list:
            for word in child_list if isinstance(child_list, tuple) else [child_list]:
                # ':' cannot appear in a protected word, so split on it
                if word.__contains__(':'):
                    item_list = word.split(":")
                    for w in item_list:
                        result.append(w)
                else:
                    result.append(word)
    return result


def replace_word(word_dict, text: str):
    for key in word_dict:
        pattern = '(?<!#)' + re.escape(word_dict[key]) + '(?!#)'
        text = re.sub(pattern, key, text)
    return text


def get_word_key(text: str, use_word_list):
    j_word = next((j for j in jieba_word_list_cache if j not in text and all(j not in used for used in use_word_list)),
                  None)
    if j_word:
        return j_word
    j_word = str(uuid.uuid7())
    jieba.add_word(j_word)
    return j_word


def to_word_dict(word_list: List, text: str):
    word_dict = {}
    for word in word_list:
        key = get_word_key(text, set(word_dict))
        word_dict['#' + key + '#'] = word
    return word_dict


def get_key_by_word_dict(key, word_dict):
    v = word_dict.get(key)
    if v is None:
        return key
    return v


def to_ts_vector(text: str):
    # Tokenize
    result = jieba.lcut(text, cut_all=True)
    return " ".join(result)


def to_query(text: str):
    extract_tags = jieba.lcut(text, cut_all=True)
    result = " ".join(extract_tags)
    return result
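Note (not part of the commit): a quick sketch of the output shape. It requires jieba, and the exact segmentation depends on its dictionary.

    print(to_ts_vector('MaxKB 支持全文检索'))
    # e.g. 'MaxKB   支持 全文 检索' - space-joined tokens, suitable as
    # input to PostgreSQL's to_tsvector()/to_tsquery() machinery.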
apps/knowledge/api/document.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter

from common.mixins.api_mixin import APIMixin
from common.result import DefaultResultSerializer, ResultSerializer
from knowledge.serializers.document import DocumentCreateRequest


class DocumentCreateResponse(ResultSerializer):
    @staticmethod
    def get_data():
        return DefaultResultSerializer()


class DocumentCreateAPI(APIMixin):
    @staticmethod
    def get_parameters():
        return [
            OpenApiParameter(
                name="workspace_id",
                description="工作空间id",
                type=OpenApiTypes.STR,
                location='path',
                required=True,
            )
        ]

    @staticmethod
    def get_request():
        return DocumentCreateRequest

    @staticmethod
    def get_response():
        return DocumentCreateResponse
apps/knowledge/api/problem.py (new file, empty)
@@ -56,7 +56,7 @@ class Migration(migrations.Migration):
                 ('tree_id', models.PositiveIntegerField(db_index=True, editable=False)),
                 ('level', models.PositiveIntegerField(editable=False)),
                 ('parent',
-                 mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE,
+                 mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING,
                                             related_name='children', to='knowledge.knowledgefolder')),
                 ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user',
                                            verbose_name='用户id')),
@@ -85,7 +85,7 @@ class Migration(migrations.Migration):
                  models.CharField(choices=[('SHARED', '共享'), ('WORKSPACE', '工作空间可用')], default='WORKSPACE',
                                   max_length=20, verbose_name='可用范围')),
                 ('folder',
-                 models.ForeignKey(default='root', on_delete=django.db.models.deletion.CASCADE,
+                 models.ForeignKey(default='root', on_delete=django.db.models.deletion.DO_NOTHING,
                                    to='knowledge.knowledgefolder',
                                    verbose_name='文件夹id')),
                 ('embedding_model', models.ForeignKey(default=knowledge.models.knowledge.default_model,
@@ -1,4 +1,7 @@
+from enum import Enum
+
+import uuid_utils.compat as uuid
 from django.contrib.postgres.search import SearchVectorField
 from django.db import models
 from django.db.models.signals import pre_delete
 from django.dispatch import receiver
@@ -18,11 +21,78 @@ class KnowledgeType(models.IntegerChoices):
    YUQUE = 3, '语雀类型'


class TaskType(Enum):
    # Embedding
    EMBEDDING = 1
    # Generate problems
    GENERATE_PROBLEM = 2
    # Sync
    SYNC = 3


class State(Enum):
    # Waiting
    PENDING = '0'
    # Running
    STARTED = '1'
    # Succeeded
    SUCCESS = '2'
    # Failed
    FAILURE = '3'
    # Revoke requested
    REVOKE = '4'
    # Revoked
    REVOKED = '5'
    # Ignored
    IGNORED = 'n'


class KnowledgeScope(models.TextChoices):
    SHARED = "SHARED", '共享'
    WORKSPACE = "WORKSPACE", "工作空间可用"


class HitHandlingMethod(models.TextChoices):
    optimization = 'optimization', '模型优化'
    directly_return = 'directly_return', '直接返回'


class Status:
    type_cls = TaskType
    state_cls = State

    def __init__(self, status: str = None):
        self.task_status = {}
        status_list = list(status[::-1] if status is not None else '')
        for _type in self.type_cls:
            index = _type.value - 1
            _state = self.state_cls(status_list[index] if len(status_list) > index else 'n')
            self.task_status[_type] = _state

    @staticmethod
    def of(status: str):
        return Status(status)

    def __str__(self):
        result = []
        for _type in sorted(self.type_cls, key=lambda item: item.value, reverse=True):
            result.insert(len(self.type_cls) - _type.value, self.task_status[_type].value)
        return ''.join(result)

    def __setitem__(self, key, value):
        self.task_status[key] = value

    def __getitem__(self, item):
        return self.task_status[item]

    def update_status(self, task_type: TaskType, state: State):
        self.task_status[task_type] = state


def default_status_meta():
    return {"state_time": {}}


def default_model():
    # todo: fetch the default model from the database
    return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab')
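Note (not part of the commit): a worked sketch of the packed status string; the values are illustrative.

    # One character per TaskType, stored with TaskType 1 in the last
    # position; 'n' (IGNORED) fills tasks with no recorded state.
    s = Status('')                        # all tasks IGNORED -> str(s) == 'nnn'
    s[TaskType.EMBEDDING] = State.STARTED
    s[TaskType.SYNC] = State.SUCCESS
    print(str(s))                         # '2n1': SYNC, GENERATE_PROBLEM, EMBEDDING
    print(Status.of('2n1')[TaskType.EMBEDDING])  # State.STARTED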
@@ -33,7 +103,7 @@ class KnowledgeFolder(MPTTModel, AppModelMixin):
     name = models.CharField(max_length=64, verbose_name="文件夹名称")
     user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="用户id")
     workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True)
-    parent = TreeForeignKey('self', on_delete=models.CASCADE, null=True, blank=True, related_name='children')
+    parent = TreeForeignKey('self', on_delete=models.DO_NOTHING, null=True, blank=True, related_name='children')
 
     class Meta:
         db_table = "knowledge_folder"
@@ -42,24 +112,127 @@ class KnowledgeFolder(MPTTModel, AppModelMixin):
        order_insertion_by = ['name']


class Knowledge(AppModelMixin):
    """
    Knowledge base table
    """
    id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
    name = models.CharField(max_length=150, verbose_name="知识库名称")
    workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True)
    desc = models.CharField(max_length=256, verbose_name="描述")
    user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户")
    type = models.IntegerField(verbose_name='类型', choices=KnowledgeType.choices, default=KnowledgeType.BASE)
    scope = models.CharField(max_length=20, verbose_name='可用范围', choices=KnowledgeScope.choices,
                             default=KnowledgeScope.WORKSPACE)
    folder = models.ForeignKey(KnowledgeFolder, on_delete=models.DO_NOTHING, verbose_name="文件夹id", default='root')
    embedding_model = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型",
                                        default=default_model)
    meta = models.JSONField(verbose_name="元数据", default=dict)

    class Meta:
        db_table = "knowledge"


class Document(AppModelMixin):
    """
    Document table
    """
    id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
    knowledge = models.ForeignKey(Knowledge, on_delete=models.DO_NOTHING, verbose_name="知识库id")
    name = models.CharField(max_length=150, verbose_name="文档名称")
    char_length = models.IntegerField(verbose_name="文档字符数 冗余字段")
    status = models.CharField(verbose_name='状态', max_length=20, default=Status('').__str__)
    status_meta = models.JSONField(verbose_name="状态统计数据", default=default_status_meta)
    is_active = models.BooleanField(default=True)
    type = models.IntegerField(verbose_name='类型', choices=KnowledgeType.choices, default=KnowledgeType.BASE)
    hit_handling_method = models.CharField(verbose_name='命中处理方式', max_length=20,
                                           choices=HitHandlingMethod.choices,
                                           default=HitHandlingMethod.optimization)
    directly_return_similarity = models.FloatField(verbose_name='直接回答相似度', default=0.9)

    meta = models.JSONField(verbose_name="元数据", default=dict)

    class Meta:
        db_table = "document"


class Paragraph(AppModelMixin):
    """
    Paragraph table
    """
    id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
    document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False)
    knowledge = models.ForeignKey(Knowledge, on_delete=models.DO_NOTHING)
    content = models.CharField(max_length=102400, verbose_name="段落内容")
    title = models.CharField(max_length=256, verbose_name="标题", default="")
    status = models.CharField(verbose_name='状态', max_length=20, default=Status('').__str__)
    status_meta = models.JSONField(verbose_name="状态数据", default=default_status_meta)
    hit_num = models.IntegerField(verbose_name="命中次数", default=0)
    is_active = models.BooleanField(default=True)

    class Meta:
        db_table = "paragraph"


class Problem(AppModelMixin):
    """
    Problem table
    """
    id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
    knowledge = models.ForeignKey(Knowledge, on_delete=models.DO_NOTHING, db_constraint=False)
    content = models.CharField(max_length=256, verbose_name="问题内容")
    hit_num = models.IntegerField(verbose_name="命中次数", default=0)

    class Meta:
        db_table = "problem"


class ProblemParagraphMapping(AppModelMixin):
    id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
    knowledge = models.ForeignKey(Knowledge, on_delete=models.DO_NOTHING, db_constraint=False)
    document = models.ForeignKey(Document, on_delete=models.DO_NOTHING)
    problem = models.ForeignKey(Problem, on_delete=models.DO_NOTHING, db_constraint=False)
    paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False)

    class Meta:
        db_table = "problem_paragraph_mapping"


class SourceType(models.IntegerChoices):
    """Source type"""
    PROBLEM = 0, '问题'
    PARAGRAPH = 1, '段落'
    TITLE = 2, '标题'


class SearchMode(models.TextChoices):
    embedding = 'embedding'
    keywords = 'keywords'
    blend = 'blend'


class VectorField(models.Field):
    def db_type(self, connection):
        return 'vector'


class Embedding(models.Model):
    id = models.CharField(max_length=128, primary_key=True, verbose_name="主键id")
    source_id = models.CharField(max_length=128, verbose_name="资源id")
    source_type = models.CharField(verbose_name='资源类型', max_length=5, choices=SourceType.choices,
                                   default=SourceType.PROBLEM)
    is_active = models.BooleanField(verbose_name="是否可用", max_length=1, default=True)
    knowledge = models.ForeignKey(Knowledge, on_delete=models.DO_NOTHING, verbose_name="知识库关联", db_constraint=False)
    document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, verbose_name="文档关联", db_constraint=False)
    paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, verbose_name="段落关联", db_constraint=False)
    embedding = VectorField(verbose_name="向量")
    search_vector = SearchVectorField(verbose_name="分词", default="")
    meta = models.JSONField(verbose_name="元数据", default=dict)

    class Meta:
        db_table = "embedding"


class File(AppModelMixin):
    id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
    file_name = models.CharField(max_length=256, verbose_name="文件名称", default="")
apps/knowledge/serializers/common.py (new file, 215 lines)
@@ -0,0 +1,215 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: common_serializers.py
@date:2023/11/17 11:00
@desc:
"""
import os
import re
import uuid_utils.compat as uuid
import zipfile
from typing import List

from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers

from common.config.embedding_config import ModelManage
from common.db.search import native_search
from common.db.sql_execute import update_execute
from common.exception.app_exception import AppApiException
from common.utils.common import get_file_content
from common.utils.fork import Fork
from knowledge.models import Paragraph, Problem, ProblemParagraphMapping, Knowledge, File
from maxkb.conf import PROJECT_DIR
from models_provider.tools import get_model


def zip_dir(zip_path, output=None):
    output = output or os.path.basename(zip_path) + '.zip'
    zip = zipfile.ZipFile(output, 'w', zipfile.ZIP_DEFLATED)
    for root, dirs, files in os.walk(zip_path):
        relative_root = '' if root == zip_path else root.replace(zip_path, '') + os.sep
        for filename in files:
            zip.write(os.path.join(root, filename), relative_root + filename)
    zip.close()


def is_valid_uuid(s):
    try:
        uuid.UUID(s)
        return True
    except ValueError:
        return False


def write_image(zip_path: str, image_list: List[str]):
    for image in image_list:
        search = re.search("\(.*\)", image)
        if search:
            text = search.group()
            if text.startswith('(/api/file/'):
                r = text.replace('(/api/file/', '').replace(')', '')
                r = r.strip().split(" ")[0]
                if not is_valid_uuid(r):
                    break
                file = QuerySet(File).filter(id=r).first()
                if file is None:
                    break
                zip_inner_path = os.path.join('api', 'file', r)
                file_path = os.path.join(zip_path, zip_inner_path)
                if not os.path.exists(os.path.dirname(file_path)):
                    os.makedirs(os.path.dirname(file_path))
                with open(os.path.join(zip_path, file_path), 'wb') as f:
                    f.write(file.get_bytes())
            # else:
            #     r = text.replace('(/api/image/', '').replace(')', '')
            #     r = r.strip().split(" ")[0]
            #     if not is_valid_uuid(r):
            #         break
            #     image_model = QuerySet(Image).filter(id=r).first()
            #     if image_model is None:
            #         break
            #     zip_inner_path = os.path.join('api', 'image', r)
            #     file_path = os.path.join(zip_path, zip_inner_path)
            #     if not os.path.exists(os.path.dirname(file_path)):
            #         os.makedirs(os.path.dirname(file_path))
            #     with open(file_path, 'wb') as f:
            #         f.write(image_model.image)


def update_document_char_length(document_id: str):
    update_execute(get_file_content(
        os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_document_char_length.sql')),
        (document_id, document_id))


def list_paragraph(paragraph_list: List[str]):
    if paragraph_list is None or len(paragraph_list) == 0:
        return []
    return native_search(QuerySet(Paragraph).filter(id__in=paragraph_list), get_file_content(
        os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_paragraph.sql')))


class MetaSerializer(serializers.Serializer):
    class WebMeta(serializers.Serializer):
        source_url = serializers.CharField(required=True, label=_('source url'))
        selector = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('selector'))

        def is_valid(self, *, raise_exception=False):
            super().is_valid(raise_exception=True)
            source_url = self.data.get('source_url')
            response = Fork(source_url, []).fork()
            if response.status == 500:
                raise AppApiException(500, _('URL error, cannot parse [{source_url}]').format(source_url=source_url))

    class BaseMeta(serializers.Serializer):
        def is_valid(self, *, raise_exception=False):
            super().is_valid(raise_exception=True)


class BatchSerializer(serializers.Serializer):
    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list'))

    def is_valid(self, *, model=None, raise_exception=False):
        super().is_valid(raise_exception=True)
        if model is not None:
            id_list = self.data.get('id_list')
            model_list = QuerySet(model).filter(id__in=id_list)
            if len(model_list) != len(id_list):
                model_id_list = [str(m.id) for m in model_list]
                error_id_list = list(filter(lambda row_id: not model_id_list.__contains__(row_id), id_list))
                raise AppApiException(500, _('The following id does not exist: {error_id_list}').format(
                    error_id_list=error_id_list))


class ProblemParagraphObject:
    def __init__(self, knowledge_id: str, document_id: str, paragraph_id: str, problem_content: str):
        self.knowledge_id = knowledge_id
        self.document_id = document_id
        self.paragraph_id = paragraph_id
        self.problem_content = problem_content


def or_get(exists_problem_list, content, knowledge_id, document_id, paragraph_id, problem_content_dict):
    if content in problem_content_dict:
        return problem_content_dict.get(content)[0], document_id, paragraph_id
    exists = [row for row in exists_problem_list if row.content == content]
    if len(exists) > 0:
        problem_content_dict[content] = exists[0], False
        return exists[0], document_id, paragraph_id
    else:
        problem = Problem(id=uuid.uuid7(), content=content, knowledge_id=knowledge_id)
        problem_content_dict[content] = problem, True
        return problem, document_id, paragraph_id


class ProblemParagraphManage:
    def __init__(self, problem_paragraph_object_list: List[ProblemParagraphObject], knowledge_id):
        self.knowledge_id = knowledge_id
        self.problem_paragraph_object_list = problem_paragraph_object_list

    def to_problem_model_list(self):
        problem_list = [item.problem_content for item in self.problem_paragraph_object_list]
        exists_problem_list = []
        if len(self.problem_paragraph_object_list) > 0:
            # Fetch the problems that already exist
            exists_problem_list = QuerySet(Problem).filter(knowledge_id=self.knowledge_id,
                                                           content__in=problem_list).all()
        problem_content_dict = {}
        problem_model_list = [
            or_get(exists_problem_list, problemParagraphObject.problem_content, problemParagraphObject.knowledge_id,
                   problemParagraphObject.document_id, problemParagraphObject.paragraph_id, problem_content_dict) for
            problemParagraphObject in self.problem_paragraph_object_list]

        problem_paragraph_mapping_list = [
            ProblemParagraphMapping(id=uuid.uuid7(), document_id=document_id, problem_id=problem_model.id,
                                    paragraph_id=paragraph_id,
                                    knowledge_id=self.knowledge_id) for
            problem_model, document_id, paragraph_id in problem_model_list]

        result = [problem_model for problem_model, is_create in problem_content_dict.values() if
                  is_create], problem_paragraph_mapping_list
        return result


def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
    knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
    if len(set([knowledge.embedding_model_id for knowledge in knowledge_list])) > 1:
        raise Exception(_('The knowledge base is inconsistent with the vector model'))
    if len(knowledge_list) == 0:
        raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
    return ModelManage.get_model(str(knowledge_list[0].embedding_model_id),
                                 lambda _id: get_model(knowledge_list[0].embedding_model))


def get_embedding_model_by_knowledge_id(knowledge_id: str):
    knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first()
    return ModelManage.get_model(str(knowledge.embedding_model_id), lambda _id: get_model(knowledge.embedding_model))


def get_embedding_model_by_knowledge(knowledge):
    return ModelManage.get_model(str(knowledge.embedding_model_id), lambda _id: get_model(knowledge.embedding_model))


def get_embedding_model_id_by_knowledge_id(knowledge_id):
    knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first()
    return str(knowledge.embedding_model_id)


def get_embedding_model_id_by_knowledge_id_list(knowledge_id_list: List):
    knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
    if len(set([knowledge.embedding_model_id for knowledge in knowledge_list])) > 1:
        raise Exception(_('The knowledge base is inconsistent with the vector model'))
    if len(knowledge_list) == 0:
        raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
    return str(knowledge_list[0].embedding_model_id)


class GenerateRelatedSerializer(serializers.Serializer):
    model_id = serializers.UUIDField(required=True, label=_('Model id'))
    prompt = serializers.CharField(required=True, label=_('Prompt word'))
    state_list = serializers.ListField(required=False, child=serializers.CharField(required=True),
                                       label=_("state list"))
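Note (not part of the commit): a sketch of the de-duplication or_get performs, both in memory and against the database. It assumes a configured Django environment; ids and contents are illustrative.

    kb_id = str(uuid.uuid7())
    objects = [
        ProblemParagraphObject(kb_id, 'doc-1', 'p-1', 'What is MaxKB?'),
        ProblemParagraphObject(kb_id, 'doc-1', 'p-2', 'What is MaxKB?'),  # duplicate content
    ]
    to_create, mappings = ProblemParagraphManage(objects, kb_id).to_problem_model_list()
    # len(to_create) == 1: duplicate content maps to one shared Problem row;
    # len(mappings) == 2: one ProblemParagraphMapping per paragraph.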
172
apps/knowledge/serializers/document.py
Normal file
172
apps/knowledge/serializers/document.py
Normal file
@ -0,0 +1,172 @@
import os
from functools import reduce
from typing import Dict, List

import uuid_utils.compat as uuid
from celery_once import AlreadyQueued
from django.db import transaction
from django.db.models import QuerySet, Model
from django.db.models.functions import Substr, Reverse
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers

from common.db.search import native_search
from common.event import ListenerManagement
from common.exception.app_exception import AppApiException
from common.utils.common import post, get_file_content
from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \
    TaskType
from knowledge.serializers.common import ProblemParagraphManage
from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer
from knowledge.task import embedding_by_document
from maxkb.const import PROJECT_DIR


class DocumentInstanceSerializer(serializers.Serializer):
    name = serializers.CharField(required=True, label=_('document name'), max_length=128, min_length=1)
    paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)


class DocumentCreateRequest(serializers.Serializer):
    name = serializers.CharField(required=True, label=_('knowledge name'), max_length=64, min_length=1)
    desc = serializers.CharField(required=True, label=_('knowledge description'), max_length=256, min_length=1)
    embedding_model_id = serializers.UUIDField(required=True, label=_('embedding model'))
    documents = DocumentInstanceSerializer(required=False, many=True)


class DocumentSerializers(serializers.Serializer):
    class Operate(serializers.Serializer):
        document_id = serializers.UUIDField(required=True, label=_('document id'))
        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))

        def is_valid(self, *, raise_exception=False):
            super().is_valid(raise_exception=True)
            document_id = self.data.get('document_id')
            if not QuerySet(Document).filter(id=document_id).exists():
                raise AppApiException(500, _('document id not exist'))

        def one(self, with_valid=False):
            if with_valid:
                self.is_valid(raise_exception=True)
            query_set = QuerySet(model=Document)
            query_set = query_set.filter(**{'id': self.data.get("document_id")})
            return native_search({
                'document_custom_sql': query_set,
                'order_by_query': QuerySet(Document).order_by('-create_time', 'id')
            }, select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')), with_search_one=True)

        def refresh(self, state_list=None, with_valid=True):
            if state_list is None:
                state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
                              State.REVOKE.value,
                              State.REVOKED.value, State.IGNORED.value]
            if with_valid:
                self.is_valid(raise_exception=True)
            knowledge = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id')).first()
            embedding_model_id = knowledge.embedding_model_id
            knowledge_user_id = knowledge.user_id
            embedding_model = QuerySet(Model).filter(id=embedding_model_id).first()
            if embedding_model is None:
                raise AppApiException(500, _('Model does not exist'))
            if embedding_model.permission_type == 'PRIVATE' and knowledge_user_id != embedding_model.user_id:
                raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}")
            document_id = self.data.get("document_id")
            ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
                                             State.PENDING)
            ListenerManagement.update_status(QuerySet(Paragraph).annotate(
                reversed_status=Reverse('status'),
                task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value, 1),
            ).filter(task_type_status__in=state_list, document_id=document_id).values('id'),
                                             TaskType.EMBEDDING, State.PENDING)
            ListenerManagement.get_aggregation_document_status(document_id)()

            try:
                embedding_by_document.delay(document_id, embedding_model_id, state_list)
            except AlreadyQueued:
                raise AppApiException(500, _('The task is being executed, please do not send it repeatedly.'))

    class Create(serializers.Serializer):
        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))

        def is_valid(self, *, raise_exception=False):
            super().is_valid(raise_exception=True)
            if not QuerySet(Knowledge).filter(id=self.data.get('knowledge_id')).exists():
                raise AppApiException(10000, _('knowledge id not exist'))
            return True

        @staticmethod
        def post_embedding(result, document_id, knowledge_id):
            DocumentSerializers.Operate(
                data={'knowledge_id': knowledge_id, 'document_id': document_id}).refresh()
            return result

        @post(post_function=post_embedding)
        @transaction.atomic
        def save(self, instance: Dict, with_valid=False, **kwargs):
            if with_valid:
                DocumentCreateRequest(data=instance).is_valid(raise_exception=True)
                self.is_valid(raise_exception=True)
            knowledge_id = self.data.get('knowledge_id')
            document_paragraph_model = self.get_document_paragraph_model(knowledge_id, instance)

            document_model = document_paragraph_model.get('document')
            paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
            problem_paragraph_object_list = document_paragraph_model.get('problem_paragraph_object_list')
            problem_model_list, problem_paragraph_mapping_list = (
                ProblemParagraphManage(problem_paragraph_object_list, knowledge_id).to_problem_model_list())
            # insert the document
            document_model.save()
            # bulk insert paragraphs
            QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
            # bulk insert problems
            QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
            # bulk insert problem-paragraph mappings
            QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
                problem_paragraph_mapping_list) > 0 else None
            document_id = str(document_model.id)
            return (DocumentSerializers.Operate(
                data={'knowledge_id': knowledge_id, 'document_id': document_id}
            ).one(with_valid=True), document_id, knowledge_id)

        @staticmethod
        def get_paragraph_model(document_model, paragraph_list: List):
            knowledge_id = document_model.knowledge_id
            paragraph_model_dict_list = [
                ParagraphSerializers.Create(
                    data={
                        'knowledge_id': knowledge_id, 'document_id': str(document_model.id)
                    }).get_paragraph_problem_model(knowledge_id, document_model.id, paragraph)
                for paragraph in paragraph_list]

            paragraph_model_list = []
            problem_paragraph_object_list = []
            for paragraphs in paragraph_model_dict_list:
                paragraph = paragraphs.get('paragraph')
                for problem_model in paragraphs.get('problem_paragraph_object_list'):
                    problem_paragraph_object_list.append(problem_model)
                paragraph_model_list.append(paragraph)

            return {
                'document': document_model,
                'paragraph_model_list': paragraph_model_list,
                'problem_paragraph_object_list': problem_paragraph_object_list
            }

        @staticmethod
        def get_document_paragraph_model(knowledge_id, instance: Dict):
            document_model = Document(
                **{
                    'knowledge_id': knowledge_id,
                    'id': uuid.uuid7(),
                    'name': instance.get('name'),
                    'char_length': reduce(lambda x, y: x + y,
                                          [len(p.get('content')) for p in instance.get('paragraphs', [])],
                                          0),
                    'meta': instance.get('meta') if instance.get('meta') is not None else {},
                    'type': instance.get('type') if instance.get('type') is not None else KnowledgeType.BASE
                })

            return DocumentSerializers.Create.get_paragraph_model(document_model,
                                                                  instance.get('paragraphs') if
                                                                  'paragraphs' in instance else [])
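A hedged usage sketch of the create flow above (the payload shape follows DocumentInstanceSerializer and ParagraphInstanceSerializer; the UUID is a placeholder, and it is assumed that the @post decorator hands save's return tuple to post_embedding and returns its result):

# sketch: create a document with one paragraph and an attached problem
result = DocumentSerializers.Create(data={'knowledge_id': '0190a0b0-...'}).save({  # placeholder id
    'name': 'getting-started.md',
    'paragraphs': [{
        'title': 'Intro',
        'content': 'MaxKB is a knowledge base QA system.',
        'problem_list': [{'content': 'What is MaxKB?'}],
    }],
}, with_valid=True)
# post_embedding then calls Operate.refresh(), which queues
# embedding_by_document for the newly created document.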
@ -1,14 +1,19 @@
from functools import reduce
from typing import Dict

import uuid_utils as uuid
import uuid_utils.compat as uuid
from django.db import transaction
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers

from common.exception.app_exception import AppApiException
from common.utils.common import valid_license
from knowledge.models import Knowledge, KnowledgeScope, KnowledgeType
from common.utils.common import valid_license, post
from knowledge.models import Knowledge, KnowledgeScope, KnowledgeType, Document, Paragraph, Problem, \
    ProblemParagraphMapping
from knowledge.serializers.common import ProblemParagraphManage, get_embedding_model_id_by_knowledge_id
from knowledge.serializers.document import DocumentSerializers
from knowledge.task import sync_web_knowledge, embedding_by_knowledge


class KnowledgeModelSerializer(serializers.ModelSerializer):
@ -38,10 +43,17 @@ class KnowledgeSerializer(serializers.Serializer):
    user_id = serializers.UUIDField(required=True, label=_('user id'))
    workspace_id = serializers.CharField(required=True, label=_('workspace id'))

    @staticmethod
    def post_embedding_knowledge(document_list, knowledge_id):
        # todo: emit the embedding event
        model_id = get_embedding_model_id_by_knowledge_id(knowledge_id)
        embedding_by_knowledge.delay(knowledge_id, model_id)
        return document_list

    @valid_license(model=Knowledge, count=50,
                   message=_(
                       'The community version supports up to 50 knowledge bases. If you need more knowledge bases, please contact us (https://fit2cloud.com/).'))
    # @post(post_function=post_embedding_dataset)
    @post(post_function=post_embedding_knowledge)
    @transaction.atomic
    def save_base(self, instance, with_valid=True):
        if with_valid:
@ -51,8 +63,9 @@ class KnowledgeSerializer(serializers.Serializer):
                name=instance.get('name')).exists():
            raise AppApiException(500, _('Knowledge base name duplicate!'))

        knowledge_id = uuid.uuid7()
        knowledge = Knowledge(
            id=uuid.uuid7(),
            id=knowledge_id,
            name=instance.get('name'),
            workspace_id=self.data.get('workspace_id'),
            desc=instance.get('desc'),
@ -63,8 +76,42 @@ class KnowledgeSerializer(serializers.Serializer):
            embedding_model_id=instance.get('embedding'),
            meta=instance.get('meta', {}),
        )

        document_model_list = []
        paragraph_model_list = []
        problem_paragraph_object_list = []
        # build the document models
        for document in instance.get('documents') if 'documents' in instance else []:
            document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(knowledge_id,
                                                                                                    document)
            document_model_list.append(document_paragraph_dict_model.get('document'))
            for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
                paragraph_model_list.append(paragraph)
            for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
                problem_paragraph_object_list.append(problem_paragraph_object)

        problem_model_list, problem_paragraph_mapping_list = (
            ProblemParagraphManage(problem_paragraph_object_list, knowledge_id)
            .to_problem_model_list())
        # insert the knowledge base
        knowledge.save()
        return KnowledgeModelSerializer(knowledge).data
        # bulk insert documents
        QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
        # bulk insert paragraphs
        QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
        # bulk insert problems
        QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
        # bulk insert problem-paragraph mappings
        QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
            problem_paragraph_mapping_list) > 0 else None

        return {
            **KnowledgeModelSerializer(knowledge).data,
            'user_id': self.data.get('user_id'),
            'document_list': document_model_list,
            "document_count": len(document_model_list),
            "char_length": reduce(lambda x, y: x + y, [d.char_length for d in document_model_list], 0)
        }, knowledge_id

    def save_web(self, instance: Dict, with_valid=True):
        if with_valid:
@ -92,9 +139,8 @@ class KnowledgeSerializer(serializers.Serializer):
            },
        )
        knowledge.save()
        # sync_web_knowledge.delay(str(knowledge_id), instance.get('source_url'), instance.get('selector'))
        return {**KnowledgeModelSerializer(knowledge).data,
                'document_list': []}
        sync_web_knowledge.delay(str(knowledge_id), instance.get('source_url'), instance.get('selector'))
        return {**KnowledgeModelSerializer(knowledge).data, 'document_list': []}


class KnowledgeTreeSerializer(serializers.Serializer):
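The @post(post_function=...) chaining used by save_base and save above relies on the wrapped method returning a tuple that the decorator unpacks into the post function. A minimal sketch of such a decorator (an assumption about common.utils.common.post, not its actual implementation):

import functools

def post(post_function):
    # hypothetical sketch: run the wrapped method, then feed its result
    # (unpacked when it is a tuple) into post_function
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            if isinstance(result, tuple):
                return post_function(*result)
            return post_function(result)
        return wrapper
    return decorator

Under this reading, save_base's ({...}, knowledge_id) tuple becomes the two arguments of post_embedding_knowledge.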
221
apps/knowledge/serializers/paragraph.py
Normal file
@ -0,0 +1,221 @@
# coding=utf-8

from typing import Dict

import uuid_utils.compat as uuid
from django.db import transaction
from django.db.models import QuerySet, Count
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers

from common.exception.app_exception import AppApiException
from common.utils.common import post
from knowledge.models import Paragraph, Problem, Document, ProblemParagraphMapping
from knowledge.serializers.common import ProblemParagraphObject, ProblemParagraphManage, \
    get_embedding_model_id_by_knowledge_id, update_document_char_length
from knowledge.serializers.problem import ProblemInstanceSerializer, ProblemSerializer
from knowledge.task import embedding_by_paragraph, enable_embedding_by_paragraph, disable_embedding_by_paragraph, \
    delete_embedding_by_paragraph


class ParagraphSerializer(serializers.ModelSerializer):
    class Meta:
        model = Paragraph
        fields = ['id', 'content', 'is_active', 'document_id', 'title', 'create_time', 'update_time']


class ParagraphInstanceSerializer(serializers.Serializer):
    """
    Paragraph instance object
    """
    content = serializers.CharField(required=True, label=_('content'), max_length=102400, min_length=1,
                                    allow_null=True, allow_blank=True)
    title = serializers.CharField(required=False, max_length=256, label=_('section title'), allow_null=True,
                                  allow_blank=True)
    problem_list = ProblemInstanceSerializer(required=False, many=True)
    is_active = serializers.BooleanField(required=False, label=_('Is active'))


class EditParagraphSerializers(serializers.Serializer):
    title = serializers.CharField(required=False, max_length=256, label=_('section title'), allow_null=True,
                                  allow_blank=True)
    content = serializers.CharField(required=False, max_length=102400, allow_null=True, allow_blank=True,
                                    label=_('section content'))
    problem_list = ProblemInstanceSerializer(required=False, many=True)


class ParagraphSerializers(serializers.Serializer):
    title = serializers.CharField(required=False, max_length=256, label=_('section title'), allow_null=True,
                                  allow_blank=True)
    content = serializers.CharField(required=True, max_length=102400, label=_('section content'))

    class Operate(serializers.Serializer):
        # paragraph id
        paragraph_id = serializers.UUIDField(required=True, label=_('paragraph id'))
        # knowledge base id (renamed from dataset_id to match the knowledge_id key the callers below pass in)
        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
        # document id
        document_id = serializers.UUIDField(required=True, label=_('document id'))

        def is_valid(self, *, raise_exception=True):
            super().is_valid(raise_exception=True)
            if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
                raise AppApiException(500, _('Paragraph id does not exist'))

        @staticmethod
        def post_embedding(paragraph, instance, knowledge_id):
            if 'is_active' in instance and instance.get('is_active') is not None:
                (enable_embedding_by_paragraph if instance.get(
                    'is_active') else disable_embedding_by_paragraph)(paragraph.get('id'))
            else:
                model_id = get_embedding_model_id_by_knowledge_id(knowledge_id)
                embedding_by_paragraph(paragraph.get('id'), model_id)
            return paragraph

        @post(post_embedding)
        @transaction.atomic
        def edit(self, instance: Dict):
            self.is_valid()
            EditParagraphSerializers(data=instance).is_valid(raise_exception=True)
            _paragraph = QuerySet(Paragraph).get(id=self.data.get("paragraph_id"))
            update_keys = ['title', 'content', 'is_active']
            for update_key in update_keys:
                if update_key in instance and instance.get(update_key) is not None:
                    _paragraph.__setattr__(update_key, instance.get(update_key))

            if 'problem_list' in instance:
                update_problem_list = list(
                    filter(lambda row: 'id' in row and row.get('id') is not None, instance.get('problem_list')))
                create_problem_list = list(filter(lambda row: row.get('id') is None, instance.get('problem_list')))

                # the existing problem set
                problem_list = QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id"))

                # validate the ids passed in from the front end
                for update_problem in update_problem_list:
                    if not set([str(row.id) for row in problem_list]).__contains__(update_problem.get('id')):
                        raise AppApiException(500, _('Problem id does not exist'))
                # work out which problems need to be deleted
                delete_problem_list = list(filter(
                    lambda row: not [str(update_row.get('id')) for update_row in update_problem_list].__contains__(
                        str(row.id)), problem_list)) if len(update_problem_list) > 0 else []
                # delete problems
                QuerySet(Problem).filter(id__in=[row.id for row in delete_problem_list]).delete() if len(
                    delete_problem_list) > 0 else None
                # insert new problems
                QuerySet(Problem).bulk_create(
                    [Problem(id=uuid.uuid7(), content=p.get('content'), paragraph_id=self.data.get('paragraph_id'),
                             knowledge_id=self.data.get('knowledge_id'), document_id=self.data.get('document_id')) for
                     p in create_problem_list]) if len(create_problem_list) else None

                # update the existing problems
                QuerySet(Problem).bulk_update(
                    [Problem(id=row.get('id'), content=row.get('content')) for row in update_problem_list],
                    ['content']) if len(
                    update_problem_list) > 0 else None

            _paragraph.save()
            update_document_char_length(self.data.get('document_id'))
            return self.one(), instance, self.data.get('knowledge_id')

        def get_problem_list(self):
            problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
                paragraph_id=self.data.get("paragraph_id"))
            if len(problem_paragraph_mapping) > 0:
                return [ProblemSerializer(problem).data for problem in
                        QuerySet(Problem).filter(id__in=[ppm.problem_id for ppm in problem_paragraph_mapping])]
            return []

        def one(self, with_valid=False):
            if with_valid:
                self.is_valid(raise_exception=True)
            return {**ParagraphSerializer(QuerySet(model=Paragraph).get(id=self.data.get('paragraph_id'))).data,
                    'problem_list': self.get_problem_list()}

        def delete(self, with_valid=False):
            if with_valid:
                self.is_valid(raise_exception=True)
            paragraph_id = self.data.get('paragraph_id')
            Paragraph.objects.filter(id=paragraph_id).delete()
            delete_problems_and_mappings([paragraph_id])

            update_document_char_length(self.data.get('document_id'))
            delete_embedding_by_paragraph(paragraph_id)

    class Create(serializers.Serializer):
        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
        document_id = serializers.UUIDField(required=True, label=_('document id'))

        def is_valid(self, *, raise_exception=False):
            super().is_valid(raise_exception=True)
            if not QuerySet(Document).filter(id=self.data.get('document_id'),
                                             knowledge_id=self.data.get('knowledge_id')).exists():
                raise AppApiException(500, _('The document id is incorrect'))

        def save(self, instance: Dict, with_valid=True, with_embedding=True):
            if with_valid:
                ParagraphSerializers(data=instance).is_valid(raise_exception=True)
                self.is_valid()
            knowledge_id = self.data.get("knowledge_id")
            document_id = self.data.get('document_id')
            paragraph_problem_model = self.get_paragraph_problem_model(knowledge_id, document_id, instance)
            paragraph = paragraph_problem_model.get('paragraph')
            problem_paragraph_object_list = paragraph_problem_model.get('problem_paragraph_object_list')
            problem_model_list, problem_paragraph_mapping_list = (
                ProblemParagraphManage(problem_paragraph_object_list, knowledge_id)
                .to_problem_model_list())
            # insert the paragraph
            paragraph_problem_model.get('paragraph').save()
            # insert problems
            QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
            # insert problem-paragraph mappings
            QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
                problem_paragraph_mapping_list) > 0 else None
            # update the document char length
            update_document_char_length(document_id)
            if with_embedding:
                model_id = get_embedding_model_id_by_knowledge_id(knowledge_id)
                embedding_by_paragraph(str(paragraph.id), model_id)
            return ParagraphSerializers.Operate(
                data={'paragraph_id': str(paragraph.id), 'knowledge_id': knowledge_id, 'document_id': document_id}
            ).one(with_valid=True)

        @staticmethod
        def get_paragraph_problem_model(knowledge_id: str, document_id: str, instance: Dict):
            paragraph = Paragraph(id=uuid.uuid7(),
                                  document_id=document_id,
                                  content=instance.get("content"),
                                  knowledge_id=knowledge_id,
                                  title=instance.get("title") if 'title' in instance else '')
            problem_paragraph_object_list = [
                ProblemParagraphObject(knowledge_id, document_id, paragraph.id, problem.get('content')) for problem in
                (instance.get('problem_list') if 'problem_list' in instance else [])]

            return {'paragraph': paragraph,
                    'problem_paragraph_object_list': problem_paragraph_object_list}

        @staticmethod
        def or_get(exists_problem_list, content, knowledge_id):
            exists = [row for row in exists_problem_list if row.content == content]
            if len(exists) > 0:
                return exists[0]
            else:
                return Problem(id=uuid.uuid7(), content=content, knowledge_id=knowledge_id)


def delete_problems_and_mappings(paragraph_ids):
    problem_paragraph_mappings = ProblemParagraphMapping.objects.filter(paragraph_id__in=paragraph_ids)
    problem_ids = set(problem_paragraph_mappings.values_list('problem_id', flat=True))

    if problem_ids:
        problem_paragraph_mappings.delete()
        remaining_problem_counts = ProblemParagraphMapping.objects.filter(problem_id__in=problem_ids).values(
            'problem_id').annotate(count=Count('problem_id'))
        remaining_problem_ids = {pc['problem_id'] for pc in remaining_problem_counts}
        problem_ids_to_delete = problem_ids - remaining_problem_ids
        Problem.objects.filter(id__in=problem_ids_to_delete).delete()
    else:
        problem_paragraph_mappings.delete()
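delete_problems_and_mappings keeps a problem alive only while at least one mapping still references it. A standalone sketch of the same set arithmetic (plain data, no ORM):

# sketch: which problems become orphans once a paragraph's mappings are removed?
mappings = [('p1', 'prob-a'), ('p1', 'prob-b'), ('p2', 'prob-b')]  # (paragraph_id, problem_id)
deleted_paragraphs = {'p1'}

touched = {prob for para, prob in mappings if para in deleted_paragraphs}        # {'prob-a', 'prob-b'}
remaining = {prob for para, prob in mappings if para not in deleted_paragraphs}  # {'prob-b'}
orphans = touched - remaining                                                    # {'prob-a'} -> safe to delete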
15
apps/knowledge/serializers/problem.py
Normal file
@ -0,0 +1,15 @@
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers

from knowledge.models import Problem


class ProblemSerializer(serializers.ModelSerializer):
    class Meta:
        model = Problem
        fields = ['id', 'content', 'knowledge_id', 'create_time', 'update_time']


class ProblemInstanceSerializer(serializers.Serializer):
    id = serializers.CharField(required=False, label=_('problem id'))
    content = serializers.CharField(required=True, max_length=256, label=_('content'))
11
apps/knowledge/sql/list_document.sql
Normal file
@ -0,0 +1,11 @@
SELECT * from (
    SELECT
        "document".* ,
        to_json("document"."meta") as meta,
        to_json("document"."status_meta") as status_meta,
        (SELECT "count"("id") FROM "paragraph" WHERE document_id="document"."id") as "paragraph_count"
    FROM
        "document" "document"
    ${document_custom_sql}
) temp
${order_by_query}
35
apps/knowledge/sql/list_knowledge.sql
Normal file
@ -0,0 +1,35 @@
SELECT
    *,
    to_json(meta) as meta
FROM
    (
        SELECT
            "temp_knowledge".*,
            "document_temp"."char_length",
            CASE
                WHEN
                    "app_knowledge_temp"."count" IS NULL THEN 0 ELSE "app_knowledge_temp"."count" END AS application_mapping_count,
            "document_temp".document_count
        FROM (
            SELECT knowledge.*
            FROM
                knowledge knowledge
            ${knowledge_custom_sql}
            UNION
            SELECT
                *
            FROM
                knowledge
            WHERE
                knowledge."id" IN (
                    SELECT
                        team_member_permission.target
                    FROM
                        team_member team_member
                        LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id"
                    ${team_member_permission_custom_sql}
                )
        ) temp_knowledge
        LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", knowledge_id FROM "document" GROUP BY knowledge_id ) "document_temp" ON temp_knowledge."id" = "document_temp".knowledge_id
        LEFT JOIN ( SELECT "count"("id"), knowledge_id FROM application_knowledge_mapping GROUP BY knowledge_id ) app_knowledge_temp ON temp_knowledge."id" = "app_knowledge_temp".knowledge_id
    ) temp
${default_sql}
20
apps/knowledge/sql/list_knowledge_application.sql
Normal file
@ -0,0 +1,20 @@
SELECT
    *
FROM
    application
WHERE
    user_id = %s
UNION
SELECT
    *
FROM
    application
WHERE
    "id" IN (
        SELECT
            team_member_permission.target
        FROM
            team_member team_member
            LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id"
        WHERE
            ( "team_member_permission"."auth_target_type" = 'APPLICATION' AND "team_member_permission"."operate"::text[] @> ARRAY['USE'] AND team_member.team_id = %s AND team_member.user_id = %s )
    )
6
apps/knowledge/sql/list_paragraph.sql
Normal file
@ -0,0 +1,6 @@
SELECT
    (SELECT "name" FROM "document" WHERE "id"=document_id) as document_name,
    (SELECT "name" FROM "knowledge" WHERE "id"=knowledge_id) as knowledge_name,
    *
FROM
    "paragraph"
5
apps/knowledge/sql/list_paragraph_document_name.sql
Normal file
@ -0,0 +1,5 @@
SELECT
    (SELECT "name" FROM "document" WHERE "id"=document_id) as document_name,
    *
FROM
    "paragraph"
5
apps/knowledge/sql/list_problem.sql
Normal file
@ -0,0 +1,5 @@
SELECT
    problem.*,
    (SELECT "count"("id") FROM "problem_paragraph_mapping" WHERE problem_id="problem"."id") as "paragraph_count"
FROM
    problem problem
2
apps/knowledge/sql/list_problem_mapping.sql
Normal file
@ -0,0 +1,2 @@
SELECT "problem"."content", problem_paragraph_mapping.paragraph_id FROM problem problem
LEFT JOIN problem_paragraph_mapping problem_paragraph_mapping ON problem_paragraph_mapping.problem_id = problem."id"
7
apps/knowledge/sql/update_document_char_length.sql
Normal file
@ -0,0 +1,7 @@
UPDATE "document"
SET "char_length" = ( SELECT CASE WHEN
        "sum" ( "char_length" ( "content" ) ) IS NULL THEN
        0 ELSE "sum" ( "char_length" ( "content" ) )
    END FROM paragraph WHERE "document_id" = %s )
WHERE
    "id" = %s
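Both %s placeholders bind to ids: the first scopes the paragraph sum, the second selects the document row. A hedged sketch of how the update_document_char_length helper imported elsewhere in this commit presumably executes it (the executor name update_execute is a hypothetical stand-in):

# sketch: update_document_char_length(document_id), assumed wiring
import os

from common.db.sql_execute import update_execute  # hypothetical executor name
from common.utils.common import get_file_content
from maxkb.const import PROJECT_DIR

def update_document_char_length(document_id):
    sql = get_file_content(os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql',
                                        'update_document_char_length.sql'))
    # first %s scopes the paragraph sum, second %s selects the document row
    update_execute(sql, [document_id, document_id])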
25
apps/knowledge/sql/update_document_status_meta.sql
Normal file
@ -0,0 +1,25 @@
UPDATE "document" "document"
SET status_meta = jsonb_set ( "document".status_meta, '{aggs}', tmp.status_meta )
FROM
    (
        SELECT COALESCE
            ( jsonb_agg ( jsonb_delete ( ( row_to_json ( record ) :: JSONB ), 'document_id' ) ), '[]' :: JSONB ) AS status_meta,
            document_id AS document_id
        FROM
            (
                SELECT
                    "paragraph".status,
                    "count" ( "paragraph"."id" ),
                    "document"."id" AS document_id
                FROM
                    "document" "document"
                    LEFT JOIN "paragraph" "paragraph" ON "document"."id" = paragraph.document_id
                ${document_custom_sql}
                GROUP BY
                    "paragraph".status,
                    "document"."id"
            ) record
        GROUP BY
            document_id
    ) tmp
WHERE "document".id="tmp".document_id
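The net effect: each matched document's status_meta gains an 'aggs' array with one entry per distinct paragraph status. A sketch of the resulting shape (values illustrative):

# sketch: document.status_meta after the update, for a document whose
# paragraphs fall into two distinct packed-status strings
status_meta = {
    "aggs": [
        {"status": "<packed-state-string-1>", "count": 3},
        {"status": "<packed-state-string-2>", "count": 1},
    ],
    # other keys (e.g. state_time) survive, since jsonb_set only
    # replaces the '{aggs}' path
}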
13
apps/knowledge/sql/update_paragraph_status.sql
Normal file
@ -0,0 +1,13 @@
UPDATE "${table_name}"
SET status = reverse (
        SUBSTRING ( reverse ( LPAD( status, ${bit_number}, 'n' ) ) :: TEXT FROM 1 FOR ${up_index} ) || ${status_number} || SUBSTRING ( reverse ( LPAD( status, ${bit_number}, 'n' ) ) :: TEXT FROM ${next_index} )
    ),
    status_meta = jsonb_set (
        "${table_name}".status_meta,
        '{state_time,${current_index}}',
        jsonb_set (
            COALESCE ( "${table_name}".status_meta #> '{state_time,${current_index}}', jsonb_build_object ( '${status_number}', '${current_time}' ) ),
            '{${status_number}}',
            CONCAT ( '"', '${current_time}', '"' ) :: JSONB
        )
    )
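The status column packs one task-state character per task type, indexed from the right (which is why DocumentSerializers.Operate.refresh annotates Reverse('status') before Substr). A hedged Python re-implementation of the same string surgery; bit_number=8 is an assumed pad width, and the names mirror the template placeholders:

# sketch: overwrite the state character for one task type in the packed status string
def set_task_state(status: str, task_type_index: int, state_char: str, bit_number: int = 8) -> str:
    padded = status.rjust(bit_number, 'n')      # LPAD(status, bit_number, 'n')
    reversed_status = padded[::-1]              # reverse(...)
    up_index = task_type_index - 1              # characters kept before the slot
    spliced = reversed_status[:up_index] + state_char + reversed_status[task_type_index:]
    return spliced[::-1]                        # reverse back

# e.g. mark task type 1 (rightmost slot) as state '1' on a fresh status
assert set_task_state('nnnnnnnn', 1, '1') == 'nnnnnnn1'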
@ -1 +1,2 @@
from .sync import *
from .embedding import *
255
apps/knowledge/task/embedding.py
Normal file
@ -0,0 +1,255 @@
# coding=utf-8

import logging
import traceback
from typing import List

from celery_once import QueueOnce
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _

from common.config.embedding_config import ModelManage
from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingKnowledgeIdArgs, \
    UpdateEmbeddingDocumentIdArgs
from knowledge.models import Document, TaskType, State
from models_provider.tools import get_model
from models_provider.models import Model
from ops import celery_app

max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")


def get_embedding_model(model_id, exception_handler=lambda e: max_kb_error.error(
    _('Failed to obtain vector model: {error} {traceback}').format(
        error=str(e),
        traceback=traceback.format_exc()
    ))):
    try:
        model = QuerySet(Model).filter(id=model_id).first()
        embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
    except Exception as e:
        exception_handler(e)
        raise e
    return embedding_model


@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id']}, name='celery:embedding_by_paragraph')
def embedding_by_paragraph(paragraph_id, model_id):
    embedding_model = get_embedding_model(model_id)
    ListenerManagement.embedding_by_paragraph(paragraph_id, embedding_model)


@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:embedding_by_paragraph_data_list')
def embedding_by_paragraph_data_list(data_list, paragraph_id_list, model_id):
    embedding_model = get_embedding_model(model_id)
    ListenerManagement.embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model)


@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:embedding_by_paragraph_list')
def embedding_by_paragraph_list(paragraph_id_list, model_id):
    embedding_model = get_embedding_model(model_id)
    ListenerManagement.embedding_by_paragraph_list(paragraph_id_list, embedding_model)


@celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:embedding_by_document')
def embedding_by_document(document_id, model_id, state_list=None):
    """
    Vectorize a document.
    @param state_list: paragraph states eligible for (re-)embedding
    @param document_id: document id
    @param model_id: embedding model id
    :return: None
    """
    if state_list is None:
        state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
                      State.REVOKE.value,
                      State.REVOKED.value, State.IGNORED.value]

    def exception_handler(e):
        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
                                         State.FAILURE)
        max_kb_error.error(
            _('Failed to obtain vector model: {error} {traceback}').format(
                error=str(e),
                traceback=traceback.format_exc()
            ))

    embedding_model = get_embedding_model(model_id, exception_handler)
    ListenerManagement.embedding_by_document(document_id, embedding_model, state_list)


@celery_app.task(name='celery:embedding_by_document_list')
def embedding_by_document_list(document_id_list, model_id):
    """
    Vectorize a list of documents.
    @param document_id_list: document id list
    @param model_id: embedding model id
    :return: None
    """
    for document_id in document_id_list:
        embedding_by_document.delay(document_id, model_id)


@celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']}, name='celery:embedding_by_knowledge')
def embedding_by_knowledge(knowledge_id, model_id):
    """
    Vectorize a knowledge base.
    @param knowledge_id: knowledge base id
    @param model_id: embedding model id
    :return: None
    """
    max_kb.info(_('Start--->Vectorized knowledge: {knowledge_id}').format(knowledge_id=knowledge_id))
    try:
        ListenerManagement.delete_embedding_by_knowledge(knowledge_id)
        document_list = QuerySet(Document).filter(knowledge_id=knowledge_id)
        max_kb.info(_('Knowledge documentation: {document_names}').format(
            document_names=", ".join([d.name for d in document_list])))
        for document in document_list:
            try:
                embedding_by_document.delay(document.id, model_id)
            except Exception:
                # one document failing to queue should not stop the others
                pass
    except Exception as e:
        max_kb_error.error(
            _('Vectorized knowledge: {knowledge_id} error {error} {traceback}').format(knowledge_id=knowledge_id,
                                                                                       error=str(e),
                                                                                       traceback=traceback.format_exc()))
    finally:
        max_kb.info(_('End--->Vectorized knowledge: {knowledge_id}').format(knowledge_id=knowledge_id))


def embedding_by_problem(args, model_id):
    """
    Vectorize a problem.
    @param args: problem object
    @param model_id: model id
    @return:
    """
    embedding_model = get_embedding_model(model_id)
    ListenerManagement.embedding_by_problem(args, embedding_model)


def embedding_by_data_list(args: List, model_id):
    embedding_model = get_embedding_model(model_id)
    ListenerManagement.embedding_by_data_list(args, embedding_model)


def delete_embedding_by_document(document_id):
    """
    Delete the vectors of the given document.
    @param document_id: document id
    @return: None
    """
    ListenerManagement.delete_embedding_by_document(document_id)


def delete_embedding_by_document_list(document_id_list: List[str]):
    """
    Delete the vector data of the given document list.
    @param document_id_list: document id list
    @return: None
    """
    ListenerManagement.delete_embedding_by_document_list(document_id_list)


def delete_embedding_by_knowledge(knowledge_id):
    """
    Delete the vector data of the given knowledge base.
    @param knowledge_id: knowledge base id
    @return: None
    """
    ListenerManagement.delete_embedding_by_knowledge(knowledge_id)


def delete_embedding_by_paragraph(paragraph_id):
    """
    Delete the vector data of the given paragraph.
    @param paragraph_id: paragraph id
    @return: None
    """
    ListenerManagement.delete_embedding_by_paragraph(paragraph_id)


def delete_embedding_by_source(source_id):
    """
    Delete the vector data of the given source.
    @param source_id: source id
    @return: None
    """
    ListenerManagement.delete_embedding_by_source(source_id)


def disable_embedding_by_paragraph(paragraph_id):
    """
    Disable the vectors of the given paragraph.
    @param paragraph_id: paragraph id
    @return: None
    """
    ListenerManagement.disable_embedding_by_paragraph(paragraph_id)


def enable_embedding_by_paragraph(paragraph_id):
    """
    Enable the vectors of the given paragraph.
    @param paragraph_id: paragraph id
    @return: None
    """
    ListenerManagement.enable_embedding_by_paragraph(paragraph_id)


def delete_embedding_by_source_ids(source_ids: List[str]):
    """
    Delete vectors by source id list.
    @param source_ids:
    @return:
    """
    ListenerManagement.delete_embedding_by_source_ids(source_ids)


def update_problem_embedding(problem_id: str, problem_content: str, model_id):
    """
    Update a problem's embedding.
    @param problem_id:
    @param problem_content:
    @param model_id:
    @return:
    """
    model = get_embedding_model(model_id)
    ListenerManagement.update_problem(UpdateProblemArgs(problem_id, problem_content, model))


def update_embedding_knowledge_id(paragraph_id_list, target_knowledge_id):
    """
    Move the vector data of the given paragraphs to another knowledge base.
    @param paragraph_id_list: paragraphs whose vectors should move
    @param target_knowledge_id: target knowledge base id
    @return:
    """
    ListenerManagement.update_embedding_knowledge_id(
        UpdateEmbeddingKnowledgeIdArgs(paragraph_id_list, target_knowledge_id))


def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]):
    """
    Delete the vector data of the given paragraph list.
    @param paragraph_ids: paragraph id list
    @return: None
    """
    ListenerManagement.delete_embedding_by_paragraph_ids(paragraph_ids)


def update_embedding_document_id(paragraph_id_list, target_document_id, target_knowledge_id,
                                 target_embedding_model_id=None):
    target_embedding_model = get_embedding_model(
        target_embedding_model_id) if target_embedding_model_id is not None else None
    ListenerManagement.update_embedding_document_id(
        UpdateEmbeddingDocumentIdArgs(paragraph_id_list, target_document_id, target_knowledge_id,
                                      target_embedding_model))


def delete_embedding_by_knowledge_id_list(knowledge_id_list):
    ListenerManagement.delete_embedding_by_knowledge_id_list(knowledge_id_list)
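The QueueOnce base with once={'keys': [...]} is what makes the refresh guard in DocumentSerializers.Operate work: a second enqueue with the same key raises AlreadyQueued. A hedged sketch of the calling pattern (ids are placeholders):

# sketch: de-duplicated enqueueing with celery_once
from celery_once import AlreadyQueued

document_id, model_id = 'doc-uuid', 'model-uuid'  # placeholders
try:
    embedding_by_document.delay(document_id, model_id)  # first call queues the task
    embedding_by_document.delay(document_id, model_id)  # same 'document_id' key -> raises
except AlreadyQueued:
    # the API layer (Operate.refresh) translates this into a
    # 'task is being executed' error for the client
    pass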
@ -1,29 +1,23 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: tools.py
@date:2024/8/20 21:48
@desc:
"""

import logging
import re
import traceback

from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _

from common.utils.fork import ChildLink, Fork
from common.utils.split_model import get_split_model
from knowledge.models.knowledge import KnowledgeType, Document, DataSet, Status
from django.utils.translation import gettext_lazy as _
from knowledge.models.knowledge import KnowledgeType, Document, Knowledge, Status

max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")


def get_save_handler(dataset_id, selector):
    from knowledge.serializers.document_serializers import DocumentSerializers
def get_save_handler(knowledge_id, selector):
    from knowledge.serializers import DocumentSerializers

    def handler(child_link: ChildLink, response: Fork.Response):
        if response.status == 200:
@ -31,7 +25,7 @@ def get_save_handler(dataset_id, selector):
            document_name = child_link.tag.text if child_link.tag is not None and len(
                child_link.tag.text.strip()) > 0 else child_link.url
            paragraphs = get_split_model('web.md').parse(response.content)
            DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
            DocumentSerializers.Create(data={'knowledge_id': knowledge_id}).save(
                {'name': document_name, 'paragraphs': paragraphs,
                 'meta': {'source_url': child_link.url, 'selector': selector},
                 'type': KnowledgeType.WEB}, with_valid=True)
@ -41,9 +35,9 @@ def get_save_handler(dataset_id, selector):
    return handler


def get_sync_handler(dataset_id):
    from knowledge.serializers.document_serializers import DocumentSerializers
    dataset = QuerySet(DataSet).filter(id=dataset_id).first()
def get_sync_handler(knowledge_id):
    from knowledge.serializers import DocumentSerializers
    knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first()

    def handler(child_link: ChildLink, response: Fork.Response):
        if response.status == 200:
@ -52,32 +46,31 @@ def get_sync_handler(dataset_id):
            document_name = child_link.tag.text if child_link.tag is not None and len(
                child_link.tag.text.strip()) > 0 else child_link.url
            paragraphs = get_split_model('web.md').parse(response.content)
            first = QuerySet(Document).filter(meta__source_url=child_link.url.strip(),
                                              dataset=dataset).first()
            first = QuerySet(Document).filter(meta__source_url=child_link.url.strip(), knowledge=knowledge).first()
            if first is not None:
                # if the document already exists, sync it
                DocumentSerializers.Sync(data={'document_id': first.id}).sync()
            else:
                # insert
                DocumentSerializers.Create(data={'dataset_id': dataset.id}).save(
                DocumentSerializers.Create(data={'knowledge_id': knowledge.id}).save(
                    {'name': document_name, 'paragraphs': paragraphs,
                     'meta': {'source_url': child_link.url.strip(), 'selector': dataset.meta.get('selector')},
                     'type': Type.web}, with_valid=True)
                     'meta': {'source_url': child_link.url.strip(), 'selector': knowledge.meta.get('selector')},
                     'type': KnowledgeType.WEB}, with_valid=True)
        except Exception as e:
            logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')

    return handler


def get_sync_web_document_handler(dataset_id):
    from knowledge.serializers.document_serializers import DocumentSerializers
def get_sync_web_document_handler(knowledge_id):
    from knowledge.serializers import DocumentSerializers

    def handler(source_url: str, selector, response: Fork.Response):
        if response.status == 200:
            try:
                paragraphs = get_split_model('web.md').parse(response.content)
                # insert
                DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
                DocumentSerializers.Create(data={'knowledge_id': knowledge_id}).save(
                    {'name': source_url[0:128], 'paragraphs': paragraphs,
                     'meta': {'source_url': source_url, 'selector': selector},
                     'type': KnowledgeType.WEB}, with_valid=True)
@ -85,7 +78,7 @@ def get_sync_web_document_handler(dataset_id):
                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
        else:
            Document(name=source_url[0:128],
                     dataset_id=dataset_id,
                     knowledge_id=knowledge_id,
                     meta={'source_url': source_url, 'selector': selector},
                     type=KnowledgeType.WEB,
                     char_length=0,
@ -94,9 +87,9 @@ def get_sync_web_document_handler(dataset_id):
    return handler


def save_problem(dataset_id, document_id, paragraph_id, problem):
    from knowledge.serializers.paragraph_serializers import ParagraphSerializers
    # print(f"dataset_id: {dataset_id}")
def save_problem(knowledge_id, document_id, paragraph_id, problem):
    from knowledge.serializers import ParagraphSerializers
    # print(f"knowledge_id: {knowledge_id}")
    # print(f"document_id: {document_id}")
    # print(f"paragraph_id: {paragraph_id}")
    # print(f"problem: {problem}")
@ -108,7 +101,7 @@ def save_problem(dataset_id, document_id, paragraph_id, problem):
            return
    try:
        ParagraphSerializers.Problem(
            data={"dataset_id": dataset_id, 'document_id': document_id,
            data={"knowledge_id": knowledge_id, 'document_id': document_id,
                  'paragraph_id': paragraph_id}).save(instance={"content": problem}, with_valid=True)
    except Exception as e:
        max_kb_error.error(_('Association problem failed {error}').format(error=str(e)))
@ -12,12 +12,11 @@
import traceback
from typing import List

from celery_once import QueueOnce
from django.utils.translation import gettext_lazy as _

from common.utils.fork import ForkManage, Fork
from .tools import get_save_handler, get_sync_web_document_handler, get_sync_handler

from ops import celery_app
from django.utils.translation import gettext_lazy as _
from .handler import get_save_handler, get_sync_web_document_handler, get_sync_handler

max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")
0
apps/knowledge/vector/__init__.py
Normal file
187
apps/knowledge/vector/base_vector.py
Normal file
@ -0,0 +1,187 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: base_vector.py
@date:2023/10/18 19:16
@desc:
"""
import threading
from abc import ABC, abstractmethod
from functools import reduce
from typing import List, Dict

from langchain_core.embeddings import Embeddings

from common.chunk import text_to_chunk
from common.utils.common import sub_array
from knowledge.models import SourceType, SearchMode

lock = threading.Lock()


def chunk_data(data: Dict):
    if str(data.get('source_type')) == SourceType.PARAGRAPH.value:
        text = data.get('text')
        chunk_list = text_to_chunk(text)
        return [{**data, 'text': chunk} for chunk in chunk_list]
    return [data]


def chunk_data_list(data_list: List[Dict]):
    result = [chunk_data(data) for data in data_list]
    return reduce(lambda x, y: [*x, *y], result, [])


class BaseVectorStore(ABC):
    vector_exists = False

    @abstractmethod
    def vector_is_create(self) -> bool:
        """
        Check whether the vector store has been created.
        :return: whether the vector store exists
        """
        pass

    @abstractmethod
    def vector_create(self):
        """
        Create the vector store.
        :return:
        """
        pass

    def save_pre_handler(self):
        """
        Pre-insert handler: mainly checks that the vector store has been created.
        :return: True
        """
        if not BaseVectorStore.vector_exists:
            if not self.vector_is_create():
                self.vector_create()
            BaseVectorStore.vector_exists = True
        return True

    def save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
             source_id: str,
             is_active: bool,
             embedding: Embeddings):
        """
        Insert vector data.
        :param source_id: source id
        :param knowledge_id: knowledge base id
        :param text: text
        :param source_type: source type
        :param document_id: document id
        :param is_active: whether the row is active (not disabled)
        :param embedding: embedding handler
        :param paragraph_id: paragraph id
        :return: bool
        """
        self.save_pre_handler()
        data = {'document_id': document_id, 'paragraph_id': paragraph_id, 'knowledge_id': knowledge_id,
                'is_active': is_active, 'source_id': source_id, 'source_type': source_type, 'text': text}
        chunk_list = chunk_data(data)
        result = sub_array(chunk_list)
        for child_array in result:
            self._batch_save(child_array, embedding, lambda: False)

    def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
        """
        Bulk insert.
        @param data_list: data list
        @param embedding: embedding handler
        @param is_the_task_interrupted: callback reporting whether the task was interrupted
        :return: bool
        """
        self.save_pre_handler()
        chunk_list = chunk_data_list(data_list)
        result = sub_array(chunk_list)
        for child_array in result:
            if not is_the_task_interrupted():
                self._batch_save(child_array, embedding, is_the_task_interrupted)
            else:
                break

    @abstractmethod
    def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
              source_id: str,
              is_active: bool,
              embedding: Embeddings):
        pass

    @abstractmethod
    def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
        pass

    def search(self, query_text, knowledge_id_list: list[str], exclude_document_id_list: list[str],
               exclude_paragraph_list: list[str],
               is_active: bool,
               embedding: Embeddings):
        if knowledge_id_list is None or len(knowledge_id_list) == 0:
            return []
        embedding_query = embedding.embed_query(query_text)
        # aligned with the query() signature below; top_n=3, similarity=0.65 and
        # SearchMode.embedding are assumed from the original positional call (1, 3, 0.65)
        result = self.query(query_text, embedding_query, knowledge_id_list, exclude_document_id_list,
                            exclude_paragraph_list, is_active, 3, 0.65, SearchMode.embedding)
        return result[0]

    @abstractmethod
    def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str],
              exclude_document_id_list: list[str],
              exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
              search_mode: SearchMode):
        pass

    @abstractmethod
    def hit_test(self, query_text, knowledge_id_list: list[str], exclude_document_id_list: list[str],
                 top_number: int,
                 similarity: float,
                 search_mode: SearchMode,
                 embedding: Embeddings):
        pass

    @abstractmethod
    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        pass

    @abstractmethod
    def update_by_paragraph_ids(self, paragraph_ids: List[str], instance: Dict):
        pass

    @abstractmethod
    def update_by_source_id(self, source_id: str, instance: Dict):
        pass

    @abstractmethod
    def update_by_source_ids(self, source_ids: List[str], instance: Dict):
        pass

    @abstractmethod
    def delete_by_knowledge_id(self, knowledge_id: str):
        pass

    @abstractmethod
    def delete_by_document_id(self, document_id: str):
        pass

    @abstractmethod
    def delete_by_document_id_list(self, document_id_list: List[str]):
        pass

    @abstractmethod
    def delete_by_knowledge_id_list(self, knowledge_id_list: List[str]):
        pass

    @abstractmethod
    def delete_by_source_id(self, source_id: str, source_type: str):
        pass

    @abstractmethod
    def delete_by_source_ids(self, source_ids: List[str], source_type: str):
        pass

    @abstractmethod
    def delete_by_paragraph_id(self, paragraph_id: str):
        pass

    @abstractmethod
    def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
        pass
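chunk_data re-chunks only paragraph-sourced rows through text_to_chunk (and therefore MarkChunkHandle's 256-char split pattern); everything else passes through untouched. A small sketch of the fan-out vs. pass-through behaviour:

# sketch: paragraph rows fan out into one row per chunk, other rows pass through
row = {'source_type': SourceType.PARAGRAPH.value, 'is_active': True, 'paragraph_id': 'p1',
       'text': 'First sentence. Second sentence.'}
for chunk_row in chunk_data(row):
    print(chunk_row['text'])  # one row per chunk produced by text_to_chunk

other = {'source_type': 'problem', 'text': 'What is MaxKB?'}  # any non-paragraph source_type
assert chunk_data(other) == [other]  # passed through unchanged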
222
apps/knowledge/vector/pg_vector.py
Normal file
@ -0,0 +1,222 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: pg_vector.py
|
||||
@date:2023/10/19 15:28
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List
|
||||
|
||||
import uuid_utils.compat as uuid
|
||||
from common.utils.ts_vecto_util import to_ts_vector, to_query
|
||||
from django.contrib.postgres.search import SearchVector
|
||||
from django.db.models import QuerySet, Value
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from common.db.search import generate_sql_by_query_dict
|
||||
from common.db.sql_execute import select_list
|
||||
from common.utils.common import get_file_content
|
||||
from knowledge.models import Embedding, SearchMode, SourceType
|
||||
from knowledge.vector.base_vector import BaseVectorStore
|
||||
from maxkb.conf import PROJECT_DIR
|
||||
|
||||
|
||||
class PGVector(BaseVectorStore):
|
||||
|
||||
def delete_by_source_ids(self, source_ids: List[str], source_type: str):
|
||||
if len(source_ids) == 0:
|
||||
return
|
||||
QuerySet(Embedding).filter(source_id__in=source_ids, source_type=source_type).delete()
|
||||
|
||||
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
|
||||
QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)
|
||||
|
||||
def vector_is_create(self) -> bool:
|
||||
# 项目启动默认是创建好的 不需要再创建
|
||||
return True
|
||||
|
||||
def vector_create(self):
|
||||
return True
|
||||
|
||||
def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str, source_id: str,
|
||||
is_active: bool,
|
||||
embedding: Embeddings):
|
||||
text_embedding = embedding.embed_query(text)
|
||||
embedding = Embedding(id=uuid.uuid7(),
|
||||
knowledge_id=knowledge_id,
|
||||
document_id=document_id,
|
||||
is_active=is_active,
|
||||
paragraph_id=paragraph_id,
|
||||
source_id=source_id,
|
||||
embedding=text_embedding,
|
||||
source_type=source_type,
|
||||
search_vector=to_ts_vector(text))
|
||||
embedding.save()
|
||||
return True
|
||||
|
||||
def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
|
||||
texts = [row.get('text') for row in text_list]
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
embedding_list = [Embedding(id=uuid.uuid7(),
|
||||
document_id=text_list[index].get('document_id'),
|
||||
paragraph_id=text_list[index].get('paragraph_id'),
|
||||
knowledge_id=text_list[index].get('knowledge_id'),
|
||||
is_active=text_list[index].get('is_active', True),
|
||||
source_id=text_list[index].get('source_id'),
|
||||
source_type=text_list[index].get('source_type'),
|
||||
embedding=embeddings[index],
|
||||
search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))) for
|
||||
index in
|
||||
range(0, len(texts))]
|
||||
if not is_the_task_interrupted():
|
||||
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
|
||||
return True
|
||||
|
||||
def hit_test(self, query_text, knowledge_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
|
||||
similarity: float,
|
||||
search_mode: SearchMode,
|
||||
embedding: Embeddings):
|
||||
if knowledge_id_list is None or len(knowledge_id_list) == 0:
|
||||
return []
|
||||
exclude_dict = {}
|
||||
embedding_query = embedding.embed_query(query_text)
|
||||
query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=True)
|
||||
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
|
||||
exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
|
||||
query_set = query_set.exclude(**exclude_dict)
|
||||
for search_handle in search_handle_list:
|
||||
if search_handle.support(search_mode):
|
||||
return search_handle.handle(query_set, query_text, embedding_query, top_number, similarity, search_mode)
|
||||
|
||||
def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str],
|
||||
exclude_document_id_list: list[str],
|
||||
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
|
||||
search_mode: SearchMode):
|
||||
exclude_dict = {}
|
||||
if knowledge_id_list is None or len(knowledge_id_list) == 0:
|
||||
return []
|
||||
query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=is_active)
|
||||
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
|
||||
query_set = query_set.exclude(document_id__in=exclude_document_id_list)
|
||||
if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
|
||||
query_set = query_set.exclude(paragraph_id__in=exclude_paragraph_list)
|
||||
query_set = query_set.exclude(**exclude_dict)
|
||||
for search_handle in search_handle_list:
|
||||
if search_handle.support(search_mode):
|
||||
return search_handle.handle(query_set, query_text, query_embedding, top_n, similarity, search_mode)
|
||||
|
||||
    def update_by_source_id(self, source_id: str, instance: Dict):
        QuerySet(Embedding).filter(source_id=source_id).update(**instance)

    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance)

    def update_by_paragraph_ids(self, paragraph_ids: List[str], instance: Dict):
        QuerySet(Embedding).filter(paragraph_id__in=paragraph_ids).update(**instance)

    def delete_by_knowledge_id(self, knowledge_id: str):
        QuerySet(Embedding).filter(knowledge_id=knowledge_id).delete()

    def delete_by_knowledge_id_list(self, knowledge_id_list: List[str]):
        QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list).delete()

    def delete_by_document_id(self, document_id: str):
        QuerySet(Embedding).filter(document_id=document_id).delete()
        return True

    def delete_by_document_id_list(self, document_id_list: List[str]):
        if len(document_id_list) == 0:
            return True
        return QuerySet(Embedding).filter(document_id__in=document_id_list).delete()

    def delete_by_source_id(self, source_id: str, source_type: str):
        QuerySet(Embedding).filter(source_id=source_id, source_type=source_type).delete()
        return True

    def delete_by_paragraph_id(self, paragraph_id: str):
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()

    def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
        QuerySet(Embedding).filter(paragraph_id__in=paragraph_ids).delete()


class ISearch(ABC):
    @abstractmethod
    def support(self, search_mode: SearchMode):
        pass

    @abstractmethod
    def handle(self, query_set, query_text, query_embedding, top_number: int,
               similarity: float, search_mode: SearchMode):
        pass


class EmbeddingSearch(ISearch):
    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
                                                           select_string=get_file_content(
                                                               os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
                                                                            'embedding_search.sql')),
                                                           with_table_name=True)
        embedding_model = select_list(exec_sql,
                                      [json.dumps(query_embedding), *exec_params, similarity, top_number])
        return embedding_model

    def support(self, search_mode: SearchMode):
        return search_mode.value == SearchMode.embedding.value


class KeywordsSearch(ISearch):
    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        exec_sql, exec_params = generate_sql_by_query_dict({'keywords_query': query_set},
                                                           select_string=get_file_content(
                                                               os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
                                                                            'keywords_search.sql')),
                                                           with_table_name=True)
        embedding_model = select_list(exec_sql,
                                      [to_query(query_text), *exec_params, similarity, top_number])
        return embedding_model

    def support(self, search_mode: SearchMode):
        return search_mode.value == SearchMode.keywords.value


class BlendSearch(ISearch):
    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
                                                           select_string=get_file_content(
                                                               os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
                                                                            'blend_search.sql')),
                                                           with_table_name=True)
        embedding_model = select_list(exec_sql,
                                      [json.dumps(query_embedding), to_query(query_text), *exec_params, similarity,
                                       top_number])
        return embedding_model

    def support(self, search_mode: SearchMode):
        return search_mode.value == SearchMode.blend.value


search_handle_list = [EmbeddingSearch(), KeywordsSearch(), BlendSearch()]
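
# A minimal sketch (not part of the original module) of how a further strategy
# could be plugged in: implement ISearch and append an instance to
# search_handle_list. The 'rerank' mode and the comprehensive_score key used
# below are hypothetical.
#
# class RerankSearch(ISearch):
#     def support(self, search_mode: SearchMode):
#         return getattr(search_mode, 'value', search_mode) == 'rerank'
#
#     def handle(self, query_set, query_text, query_embedding, top_number: int,
#                similarity: float, search_mode: SearchMode):
#         # Over-fetch with the embedding strategy, then keep the best rows
#         # after re-scoring them however the deployment prefers.
#         candidates = EmbeddingSearch().handle(query_set, query_text, query_embedding,
#                                               top_number * 4, similarity, search_mode)
#         return sorted(candidates, key=lambda row: row.get('comprehensive_score', 0),
#                       reverse=True)[:top_number]
#
# search_handle_list.append(RerankSearch())
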
@@ -36,7 +36,7 @@ class Migration(migrations.Migration):
                 ('tree_id', models.PositiveIntegerField(db_index=True, editable=False)),
                 ('level', models.PositiveIntegerField(editable=False)),
                 ('parent',
-                 mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE,
+                 mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING,
                                             related_name='children', to='tools.toolfolder')),
                 ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user',
                                            verbose_name='user id')),
@@ -73,7 +73,7 @@ class Migration(migrations.Migration):
                 ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user',
                                            verbose_name='user id')),
                 ('folder',
-                 models.ForeignKey(default='root', on_delete=django.db.models.deletion.CASCADE, to='tools.toolfolder',
+                 models.ForeignKey(default='root', on_delete=django.db.models.deletion.DO_NOTHING, to='tools.toolfolder',
                                    verbose_name='folder id')),
             ],
             options={

@@ -12,7 +12,7 @@ class ToolFolder(MPTTModel, AppModelMixin):
     name = models.CharField(max_length=64, verbose_name="folder name")
     user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="user id")
     workspace_id = models.CharField(max_length=64, verbose_name="workspace id", default="default", db_index=True)
-    parent = TreeForeignKey('self', on_delete=models.CASCADE, null=True, blank=True, related_name='children')
+    parent = TreeForeignKey('self', on_delete=models.DO_NOTHING, null=True, blank=True, related_name='children')

     class Meta:
         db_table = "tool_folder"
@@ -46,7 +46,7 @@ class Tool(AppModelMixin):
     tool_type = models.CharField(max_length=20, verbose_name='tool type', choices=ToolType.choices,
                                  default=ToolType.CUSTOM, db_index=True)
     template_id = models.UUIDField(max_length=128, verbose_name="template id", null=True, default=None)
-    folder = models.ForeignKey(ToolFolder, on_delete=models.CASCADE, verbose_name="folder id", default='root')
+    folder = models.ForeignKey(ToolFolder, on_delete=models.DO_NOTHING, verbose_name="folder id", default='root')
     workspace_id = models.CharField(max_length=64, verbose_name="workspace id", default="default", db_index=True)
     init_params = models.CharField(max_length=102400, verbose_name="initialization parameters", null=True)
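
All four hunks above make the same change: folder relations move from CASCADE to DO_NOTHING, so the database no longer deletes tools or subfolders automatically when their folder row goes away. Presumably the service layer now owns that cleanup; a minimal sketch of what such a caller could look like (the helper name and the re-homing policy are assumptions, not part of this commit):

# Hypothetical cleanup a caller would own under DO_NOTHING: re-home children
# first, then delete the folder itself so no dangling FK remains.
# def delete_folder(folder: ToolFolder):
#     QuerySet(Tool).filter(folder=folder).update(folder_id='root')  # re-home tools
#     folder.get_children().update(parent=folder.parent)  # re-parent subfolders (MPTT may need a rebuild)
#     folder.delete()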