feat: add initial implementation of document and paragraph models with serializers

This commit is contained in:
CaptainB 2025-04-28 16:31:46 +08:00
parent 8c362b0f99
commit 770089e432
43 changed files with 2580 additions and 70 deletions

View File

@ -0,0 +1,18 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py
@date2024/7/23 17:03
@desc:
"""
from common.chunk.impl.mark_chunk_handle import MarkChunkHandle
handles = [MarkChunkHandle()]
def text_to_chunk(text: str):
    """Run *text* through every registered chunk handler and return the final chunk list."""
    chunks = [text]
    for chunk_handle in handles:
        chunks = chunk_handle.handle(chunks)
    return chunks

View File

@ -0,0 +1,16 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file i_chunk_handle.py
@date2024/7/23 16:51
@desc:
"""
from abc import ABC, abstractmethod
from typing import List
class IChunkHandle(ABC):
    """Interface for chunk post-processors: consume a list of text chunks, return a new list."""

    @abstractmethod
    def handle(self, chunk_list: List[str]):
        """Transform *chunk_list* and return the resulting list of chunks."""
        pass

View File

@ -0,0 +1,40 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file mark_chunk_handle.py
@date2024/7/23 16:52
@desc:
"""
import re
from typing import List
from common.chunk.i_chunk_handle import IChunkHandle
# Longest chunk (in characters) produced by the splitter.
max_chunk_len = 256
# Greedy pattern: up to max_chunk_len chars ending at a sentence-ish delimiter.
split_chunk_pattern = r'.{1,%d}[。| |\\.||;||!|\n]' % max_chunk_len
# Fallback pattern: hard cut at max_chunk_len characters.
max_chunk_pattern = r'.{1,%d}' % max_chunk_len


class MarkChunkHandle(IChunkHandle):
    """Split each chunk at sentence delimiters, hard-wrapping any over-long leftovers."""

    def handle(self, chunk_list: List[str]):
        result = []
        for chunk in chunk_list:
            # Pieces that end at a delimiter.
            for piece in re.findall(split_chunk_pattern, chunk, flags=re.DOTALL):
                stripped = piece.strip()
                if stripped:
                    result.append(stripped)
            # Whatever the delimiter pattern did not consume.
            for leftover in re.split(split_chunk_pattern, chunk, flags=re.DOTALL):
                if not leftover:
                    continue
                if len(leftover) < max_chunk_len:
                    stripped = leftover.strip()
                    if stripped:
                        result.append(stripped)
                else:
                    # Over-long remainder: hard-cut into max_chunk_len slices.
                    for hard_slice in re.findall(max_chunk_pattern, leftover, flags=re.DOTALL):
                        stripped = hard_slice.strip()
                        if stripped:
                            result.append(stripped)
        return result

View File

@ -47,20 +47,20 @@ class ModelManage:
ModelManage.cache.delete(_id)
# class VectorStore:
# from embedding.vector.pg_vector import PGVector
# from embedding.vector.base_vector import BaseVectorStore
# instance_map = {
# 'pg_vector': PGVector,
# }
# instance = None
#
# @staticmethod
# def get_embedding_vector() -> BaseVectorStore:
# from embedding.vector.pg_vector import PGVector
# if VectorStore.instance is None:
# from maxkb.const import CONFIG
# vector_store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"),
# PGVector)
# VectorStore.instance = vector_store_class()
# return VectorStore.instance
class VectorStore:
    """Lazy singleton factory for the configured vector-store backend."""
    from knowledge.vector.pg_vector import PGVector
    from knowledge.vector.base_vector import BaseVectorStore

    # Registry of selectable backends, keyed by the VECTOR_STORE_NAME setting.
    instance_map = {
        'pg_vector': PGVector,
    }
    instance = None

    @staticmethod
    def get_embedding_vector() -> BaseVectorStore:
        from knowledge.vector.pg_vector import PGVector
        if VectorStore.instance is not None:
            return VectorStore.instance
        from maxkb.const import CONFIG
        backend_cls = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"),
                                                   PGVector)
        VectorStore.instance = backend_cls()
        return VectorStore.instance

View File

@ -0,0 +1,30 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2023/11/10 10:43
@desc:
"""
from models_provider.models import Model, Status
from .listener_manage import *
from django.utils.translation import gettext as _
from ..db.sql_execute import update_execute
from common.lock.impl.file_lock import FileLock
# Cross-process lock guarding the one-shot startup recovery below.
lock = FileLock()
# Resets documents whose status string still contains an in-flight marker
# (PENDING '0', STARTED '1' or REVOKE '4') to FAILURE '3'.
update_document_status_sql = """
UPDATE "public"."document"
SET status ="replace"("replace"("replace"(status, '1', '3'), '0', '3'), '4', '3')
WHERE status ~ '1|0|4'
"""
def run():
    """Startup recovery: fail any model download or document task that was
    interrupted by a process restart.

    Acquires a 900s (30 * 30) file lock so only one process performs the cleanup.
    """
    if lock.try_lock('event_init', 30 * 30):
        try:
            # Interrupted model downloads -> ERROR with a user-facing message.
            # NOTE(review): QuerySet is assumed to arrive via the star import
            # from listener_manage above — confirm it is exported there.
            QuerySet(Model).filter(status=Status.DOWNLOAD).update(status=Status.ERROR, meta={'message': _( 'The download process was interrupted, please try again')})
            # Stuck documents -> FAILURE (see SQL above).
            update_execute(update_document_status_sql, [])
        finally:
            lock.un_lock('event_init')

View File

@ -0,0 +1,50 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file common.py
@date2023/11/10 10:41
@desc:
"""
from concurrent.futures import ThreadPoolExecutor
from django.core.cache.backends.locmem import LocMemCache
work_thread_pool = ThreadPoolExecutor(5)
embedding_thread_pool = ThreadPoolExecutor(3)
memory_cache = LocMemCache('task', {"OPTIONS": {"MAX_ENTRIES": 1000}})
def poxy(poxy_function):
    """Decorator: run *poxy_function* asynchronously on the shared worker pool.

    NOTE(review): ``inner`` takes a single positional ``args`` (no ``*``), so the
    wrapped function is always invoked as ``poxy_function(args, **keywords)`` —
    callers must pass exactly one positional argument. Confirm this is intended
    (compare ``embedding_poxy`` below, which collects ``*args``).
    """
    def inner(args, **keywords):
        work_thread_pool.submit(poxy_function, args, **keywords)
    return inner
def get_cache_key(poxy_function, args):
    """Build a de-duplication cache key from the function name and its argument tuple."""
    return f'{poxy_function.__name__}{args}'
def get_cache_poxy_function(poxy_function, cache_key):
    """Wrap *poxy_function* so that *cache_key* is evicted once the call finishes."""
    def wrapped(args, **keywords):
        try:
            poxy_function(args, **keywords)
        finally:
            # Always release the de-dup marker, even when the call raised.
            memory_cache.delete(cache_key)
    return wrapped
def embedding_poxy(poxy_function):
    """Decorator: run *poxy_function* on the embedding pool, de-duplicating
    concurrent submissions of the same call via ``memory_cache``.
    """
    def inner(*args, **keywords):
        key = get_cache_key(poxy_function, args)
        # cache.add() is atomic and returns False when the key already exists.
        # Using it as the sole gate closes the check-then-act race of the
        # previous has_key()/add() pair, where two threads could both pass the
        # has_key() test and submit duplicate embedding jobs.
        if not memory_cache.add(key, None):
            return
        f = get_cache_poxy_function(poxy_function, key)
        embedding_thread_pool.submit(f, args, **keywords)
    return inner

View File

@ -0,0 +1,385 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file listener_manage.py
@date2023/10/20 14:01
@desc:
"""
import logging
import os
import threading
import datetime
import traceback
from typing import List
import django.db.models
from django.db.models import QuerySet
from django.db.models.functions import Substr, Reverse
from langchain_core.embeddings import Embeddings
from common.config.embedding_config import VectorStore
from common.db.search import native_search, get_dynamics_model, native_update
from common.utils.common import get_file_content
from common.utils.lock import try_lock, un_lock
from common.utils.page_utils import page_desc
from knowledge.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State,SourceType, SearchMode
from maxkb.conf import (PROJECT_DIR)
from django.utils.translation import gettext_lazy as _
# NOTE(review): both loggers use __file__ (not __name__) and therefore resolve
# to the SAME underlying logger object — confirm this is intentional.
max_kb_error = logging.getLogger(__file__)
max_kb = logging.getLogger(__file__)
# Serialises the string-substituted SQL updates in ListenerManagement.update_status.
lock = threading.Lock()
class SyncWebKnowledgeArgs:
    """Argument bundle for synchronising a web-sourced knowledge base."""

    def __init__(self, lock_key: str, url: str, selector: str, handler):
        # lock_key: key guarding concurrent syncs; url: root page to fetch;
        # selector: CSS selector for content extraction; handler: result callback.
        self.lock_key, self.url, self.selector, self.handler = lock_key, url, selector, handler
class SyncWebDocumentArgs:
    """Argument bundle for synchronising a batch of web documents."""

    def __init__(self, source_url_list: List[str], selector: str, handler):
        # source_url_list: pages to fetch; selector: CSS selector; handler: result callback.
        self.source_url_list, self.selector, self.handler = source_url_list, selector, handler
class UpdateProblemArgs:
    """Argument bundle for re-embedding an updated problem."""

    def __init__(self, problem_id: str, problem_content: str, embedding_model: Embeddings):
        self.problem_id, self.problem_content, self.embedding_model = problem_id, problem_content, embedding_model
class UpdateEmbeddingKnowledgeIdArgs:
    """Argument bundle for re-pointing paragraph vectors at another knowledge base."""

    def __init__(self, paragraph_id_list: List[str], target_knowledge_id: str):
        self.paragraph_id_list, self.target_knowledge_id = paragraph_id_list, target_knowledge_id
class UpdateEmbeddingDocumentIdArgs:
    """Argument bundle for moving paragraph vectors to another document/knowledge base.

    When ``target_embedding_model`` is None the vectors are just re-pointed;
    otherwise the paragraphs are re-embedded with the new model (see
    ``ListenerManagement.update_embedding_document_id``).
    """

    def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_knowledge_id: str,
                 target_embedding_model: Embeddings = None):
        self.paragraph_id_list = paragraph_id_list
        self.target_document_id = target_document_id
        self.target_knowledge_id = target_knowledge_id
        self.target_embedding_model = target_embedding_model
class ListenerManagement:
    """Static callbacks that keep the vector store in sync with the knowledge models.

    All members are @staticmethods; the class is a namespace, not a stateful object.
    Every method operates through ``VectorStore.get_embedding_vector()``.
    """

    @staticmethod
    def embedding_by_problem(args, embedding_model: Embeddings):
        # *args* is a dict of row fields forwarded verbatim to the vector store's save().
        VectorStore.get_embedding_vector().save(**args, embedding=embedding_model)

    @staticmethod
    def embedding_by_paragraph_list(paragraph_id_list, embedding_model: Embeddings):
        """Fetch embedding source rows for the given paragraphs and re-embed them in one batch."""
        try:
            # Collect paragraph + problem source texts via the shared SQL template.
            data_list = native_search(
                {'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
                    **{'paragraph.id__in': paragraph_id_list}),
                    'paragraph': QuerySet(Paragraph).filter(id__in=paragraph_id_list)},
                select_string=get_file_content(
                    os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
            ListenerManagement.embedding_by_paragraph_data_list(data_list, paragraph_id_list=paragraph_id_list,
                                                                embedding_model=embedding_model)
        except Exception as e:
            max_kb_error.error(_('Query vector data: {paragraph_id_list} error {error} {traceback}').format(
                paragraph_id_list=paragraph_id_list, error=str(e), traceback=traceback.format_exc()))

    @staticmethod
    def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings):
        """Replace the stored vectors of *paragraph_id_list* with fresh embeddings of *data_list*."""
        max_kb.info(_('Start--->Embedding paragraph: {paragraph_id_list}').format(paragraph_id_list=paragraph_id_list))
        # NOTE(review): Status here is knowledge.models.Status, which does not
        # visibly define `.success` / `.error` attributes — confirm these exist.
        status = Status.success
        try:
            # Drop stale vectors before re-inserting.
            VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_id_list)

            def is_save_function():
                # Only persist while the paragraphs still exist.
                return QuerySet(Paragraph).filter(id__in=paragraph_id_list).exists()

            # Batch vectorisation
            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
        except Exception as e:
            max_kb_error.error(_('Vectorized paragraph: {paragraph_id_list} error {error} {traceback}').format(
                paragraph_id_list=paragraph_id_list, error=str(e), traceback=traceback.format_exc()))
            status = Status.error
        finally:
            QuerySet(Paragraph).filter(id__in=paragraph_id_list).update(**{'status': status})
            max_kb.info(
                _('End--->Embedding paragraph: {paragraph_id_list}').format(paragraph_id_list=paragraph_id_list))

    @staticmethod
    def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
        """
        Vectorise a single paragraph by id.
        @param paragraph_id: paragraph id
        @param embedding_model: embedding model
        """
        max_kb.info(_('Start--->Embedding paragraph: {paragraph_id}').format(paragraph_id=paragraph_id))
        # Mark the embedding task as STARTED.
        ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING, State.STARTED)
        try:
            data_list = native_search(
                {'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
                    **{'paragraph.id': paragraph_id}),
                    'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)},
                select_string=get_file_content(
                    os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
            # Drop the paragraph's stale vectors.
            VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)

            def is_the_task_interrupted():
                # Interrupt when the paragraph vanished or the task was revoked.
                _paragraph = QuerySet(Paragraph).filter(id=paragraph_id).first()
                if _paragraph is None or Status(_paragraph.status)[TaskType.EMBEDDING] == State.REVOKE:
                    return True
                return False

            # Batch vectorisation
            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_the_task_interrupted)
            # Mark the embedding task as finished successfully.
            ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
                                             State.SUCCESS)
        except Exception as e:
            max_kb_error.error(_('Vectorized paragraph: {paragraph_id} error {error} {traceback}').format(
                paragraph_id=paragraph_id, error=str(e), traceback=traceback.format_exc()))
            ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
                                             State.FAILURE)
        finally:
            max_kb.info(_('End--->Embedding paragraph: {paragraph_id}').format(paragraph_id=paragraph_id))

    @staticmethod
    def embedding_by_data_list(data_list: List, embedding_model: Embeddings):
        # Batch vectorisation; the always-True callback means "always persist".
        VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, lambda: True)

    @staticmethod
    def get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, post_apply=lambda: None):
        """Build a page handler that embeds each paragraph row, honouring interruption."""
        def embedding_paragraph_apply(paragraph_list):
            for paragraph in paragraph_list:
                if is_the_task_interrupted():
                    break
                ListenerManagement.embedding_by_paragraph(str(paragraph.get('id')), embedding_model)
            post_apply()
        return embedding_paragraph_apply

    @staticmethod
    def get_aggregation_document_status(document_id):
        """Return a closure that re-aggregates one document's status_meta via SQL."""
        def aggregation_document_status():
            pass
            # NOTE(review): the `pass` above looks redundant — confirm the SQL
            # statements below are intended to run inside this closure
            # (they must be, given calls like get_aggregation_document_status(id)()).
            sql = get_file_content(
                os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_document_status_meta.sql'))
            native_update({'document_custom_sql': QuerySet(Document).filter(id=document_id)}, sql, with_table_name=True)

        return aggregation_document_status

    @staticmethod
    def get_aggregation_document_status_by_knowledge_id(knowledge_id):
        """Like get_aggregation_document_status, but for every document of a knowledge base."""
        def aggregation_document_status():
            sql = get_file_content(
                os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_document_status_meta.sql'))
            native_update({'document_custom_sql': QuerySet(Document).filter(knowledge_id=knowledge_id)}, sql,
                          with_table_name=True)

        return aggregation_document_status

    @staticmethod
    def get_aggregation_document_status_by_query_set(queryset):
        """Like get_aggregation_document_status, but for an arbitrary Document queryset."""
        def aggregation_document_status():
            sql = get_file_content(
                os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_document_status_meta.sql'))
            native_update({'document_custom_sql': queryset}, sql, with_table_name=True)

        return aggregation_document_status

    @staticmethod
    def post_update_document_status(document_id, task_type: TaskType):
        """Finalise *task_type* on a document and propagate REVOKED to its paragraphs."""
        _document = QuerySet(Document).filter(id=document_id).first()
        status = Status(_document.status)
        if status[task_type] == State.REVOKE:
            status[task_type] = State.REVOKED
        else:
            status[task_type] = State.SUCCESS
        # A single failed paragraph (from the aggregated counters) fails the document.
        for item in _document.status_meta.get('aggs', []):
            agg_status = item.get('status')
            agg_count = item.get('count')
            if Status(agg_status)[task_type] == State.FAILURE and agg_count > 0:
                status[task_type] = State.FAILURE
        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), task_type, status[task_type])
        # Paragraphs still flagged REVOKE for this task become REVOKED.
        # NOTE(review): Substr length is task_type.value here but 1 in
        # embedding_by_document — confirm this difference is intended.
        ListenerManagement.update_status(QuerySet(Paragraph).annotate(
            reversed_status=Reverse('status'),
            task_type_status=Substr('reversed_status', task_type.value,
                                    task_type.value),
        ).filter(task_type_status=State.REVOKE.value).filter(document_id=document_id).values('id'),
            task_type,
            State.REVOKED)

    @staticmethod
    def update_status(query_set: QuerySet, taskType: TaskType, state: State):
        """Set the *taskType* character of every matched row's status string to *state*.

        The SQL template is parameterised by plain string substitution, so the
        whole call is serialised with the module-level lock.
        """
        exec_sql = get_file_content(
            os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_paragraph_status.sql'))
        bit_number = len(TaskType)
        up_index = taskType.value - 1
        next_index = taskType.value + 1
        current_index = taskType.value
        status_number = state.value
        current_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') + '+00'
        params_dict = {'${bit_number}': bit_number, '${up_index}': up_index,
                       '${status_number}': status_number, '${next_index}': next_index,
                       '${table_name}': query_set.model._meta.db_table, '${current_index}': current_index,
                       '${current_time}': current_time}
        for key in params_dict:
            _value_ = params_dict[key]
            exec_sql = exec_sql.replace(key, str(_value_))
        lock.acquire()
        try:
            native_update(query_set, exec_sql)
        finally:
            lock.release()

    @staticmethod
    def embedding_by_document(document_id, embedding_model: Embeddings, state_list=None):
        """
        Vectorise a whole document.
        @param state_list: paragraph task states eligible for (re-)embedding
        @param document_id: document id
        @param embedding_model: embedding model
        :return: None
        """
        if state_list is None:
            state_list = [State.PENDING, State.SUCCESS, State.FAILURE, State.REVOKE, State.REVOKED]
        # Per-document lock so the same document is never embedded twice concurrently.
        if not try_lock('embedding' + str(document_id)):
            return
        try:
            def is_the_task_interrupted():
                # Interrupt when the document vanished or the task was revoked.
                document = QuerySet(Document).filter(id=document_id).first()
                if document is None or Status(document.status)[TaskType.EMBEDDING] == State.REVOKE:
                    return True
                return False

            if is_the_task_interrupted():
                return
            max_kb.info(_('Start--->Embedding document: {document_id}').format(document_id=document_id)
                        )
            # Mark the document's embedding task as STARTED.
            ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
                                             State.STARTED)
            # Embed paragraph by paragraph, 5 ids per page, pages in descending order.
            page_desc(QuerySet(Paragraph)
                      .annotate(
                reversed_status=Reverse('status'),
                task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value,
                                        1),
            ).filter(task_type_status__in=state_list, document_id=document_id)
                      .values('id'), 5,
                      ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
                                                                       ListenerManagement.get_aggregation_document_status(
                                                                           document_id)),
                      is_the_task_interrupted)
        except Exception as e:
            max_kb_error.error(_('Vectorized document: {document_id} error {error} {traceback}').format(
                document_id=document_id, error=str(e), traceback=traceback.format_exc()))
        finally:
            ListenerManagement.post_update_document_status(document_id, TaskType.EMBEDDING)
            ListenerManagement.get_aggregation_document_status(document_id)()
            max_kb.info(_('End--->Embedding document: {document_id}').format(document_id=document_id))
            un_lock('embedding' + str(document_id))

    @staticmethod
    def embedding_by_knowledge(knowledge_id, embedding_model: Embeddings):
        """
        Vectorise an entire knowledge base.
        @param knowledge_id: knowledge base id
        @param embedding_model: embedding model
        :return: None
        """
        max_kb.info(_('Start--->Embedding knowledge: {knowledge_id}').format(knowledge_id=knowledge_id))
        try:
            # Wipe all of the knowledge base's vectors, then rebuild per document.
            ListenerManagement.delete_embedding_by_knowledge(knowledge_id)
            document_list = QuerySet(Document).filter(knowledge_id=knowledge_id)
            max_kb.info(_('Start--->Embedding document: {document_list}').format(document_list=document_list))
            for document in document_list:
                ListenerManagement.embedding_by_document(document.id, embedding_model=embedding_model)
        except Exception as e:
            max_kb_error.error(_('Vectorized knowledge: {knowledge_id} error {error} {traceback}').format(
                knowledge_id=knowledge_id, error=str(e), traceback=traceback.format_exc()))
        finally:
            max_kb.info(_('End--->Embedding knowledge: {knowledge_id}').format(knowledge_id=knowledge_id))

    @staticmethod
    def delete_embedding_by_document(document_id):
        # Remove every vector belonging to one document.
        VectorStore.get_embedding_vector().delete_by_document_id(document_id)

    @staticmethod
    def delete_embedding_by_document_list(document_id_list: List[str]):
        # Remove vectors for a batch of documents.
        VectorStore.get_embedding_vector().delete_by_document_id_list(document_id_list)

    @staticmethod
    def delete_embedding_by_knowledge(knowledge_id):
        # Remove every vector belonging to one knowledge base.
        VectorStore.get_embedding_vector().delete_by_knowledge_id(knowledge_id)

    @staticmethod
    def delete_embedding_by_paragraph(paragraph_id):
        VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)

    @staticmethod
    def delete_embedding_by_source(source_id):
        # Source rows here are problem mappings (SourceType.PROBLEM).
        VectorStore.get_embedding_vector().delete_by_source_id(source_id, SourceType.PROBLEM)

    @staticmethod
    def disable_embedding_by_paragraph(paragraph_id):
        # Soft-disable: vectors stay stored but are flagged inactive.
        VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': False})

    @staticmethod
    def enable_embedding_by_paragraph(paragraph_id):
        VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True})

    @staticmethod
    def update_problem(args: UpdateProblemArgs):
        """Re-embed an edited problem and update every vector derived from it."""
        problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(problem_id=args.problem_id)
        embed_value = args.embedding_model.embed_query(args.problem_content)
        VectorStore.get_embedding_vector().update_by_source_ids([v.id for v in problem_paragraph_mapping_list],
                                                                {'embedding': embed_value})

    @staticmethod
    def update_embedding_knowledge_id(args: UpdateEmbeddingKnowledgeIdArgs):
        # Re-point paragraph vectors at another knowledge base.
        VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
                                                                   {'knowledge_id': args.target_knowledge_id})

    @staticmethod
    def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs):
        """Move paragraph vectors to another document; re-embed when the model changes."""
        if args.target_embedding_model is None:
            VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
                                                                       {'document_id': args.target_document_id,
                                                                        'knowledge_id': args.target_knowledge_id})
        else:
            ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
                                                           embedding_model=args.target_embedding_model)

    @staticmethod
    def delete_embedding_by_source_ids(source_ids: List[str]):
        VectorStore.get_embedding_vector().delete_by_source_ids(source_ids, SourceType.PROBLEM)

    @staticmethod
    def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]):
        VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_ids)

    @staticmethod
    def delete_embedding_by_knowledge_id_list(source_ids: List[str]):
        VectorStore.get_embedding_vector().delete_by_knowledge_id_list(source_ids)

    @staticmethod
    def hit_test(query_text, knowledge_id: list[str], exclude_document_id_list: list[str], top_number: int,
                 similarity: float,
                 search_mode: SearchMode,
                 embedding: Embeddings):
        """Similarity search over the vector store, excluding the given documents."""
        return VectorStore.get_embedding_vector().hit_test(query_text, knowledge_id, exclude_document_id_list,
                                                           top_number,
                                                           similarity, search_mode, embedding)

View File

@ -0,0 +1,20 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file base_lock.py
@date2024/8/20 10:33
@desc:
"""
from abc import ABC, abstractmethod
class BaseLock(ABC):
    """Abstract lock backend: non-blocking acquire with a timeout, explicit release."""

    @abstractmethod
    def try_lock(self, key, timeout):
        """Attempt to acquire the lock *key*; implementations report success/failure."""
        pass

    @abstractmethod
    def un_lock(self, key):
        """Release the lock *key*."""
        pass

View File

@ -0,0 +1,77 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file file_lock.py
@date2024/8/20 10:48
@desc:
"""
import errno
import hashlib
import os
import time
import six
from common.lock.base_lock import BaseLock
from maxkb.const import PROJECT_DIR
def key_to_lock_name(key):
    """
    Combine part of a key with its hash to prevent very long filenames.

    Keeps the first ``MAX_LENGTH - 33`` characters of *key* for readability and
    appends the md5 hex digest for uniqueness, so the result is at most 50 chars.
    """
    MAX_LENGTH = 50
    # six.b(key) is key.encode("latin-1") on Python 3; inline it and drop the
    # py2-compat `six` dependency — output is byte-for-byte identical.
    key_hash = hashlib.md5(key.encode('latin-1')).hexdigest()
    lock_name = key[:MAX_LENGTH - len(key_hash) - 1] + '_' + key_hash
    return lock_name
class FileLock(BaseLock):
    """
    File locking backend.

    A lock is a plain file under ``<PROJECT_DIR>/data/lock``; existence of the
    file means "locked", and the file's mtime is used to expire stale holders.
    """

    def __init__(self, settings=None):
        # settings: optional dict; only the 'location' key (lock directory) is used.
        if settings is None:
            settings = {}
        self.location = settings.get('location')
        if self.location is None:
            self.location = os.path.join(PROJECT_DIR, 'data', 'lock')
        try:
            os.makedirs(self.location)
        except OSError as error:
            # Directory exists?
            if error.errno != errno.EEXIST:
                # Re-raise unexpected OSError
                raise

    def _get_lock_path(self, key):
        # Map an arbitrary key to a safe, bounded-length file name.
        lock_name = key_to_lock_name(key)
        return os.path.join(self.location, lock_name)

    def try_lock(self, key, timeout):
        """Try to acquire *key*; return True on success.

        A holder older than *timeout* seconds is considered stale and taken over.
        """
        lock_path = self._get_lock_path(key)
        try:
            # Create the lock file; O_EXCL makes this fail if it already exists.
            fd = os.open(lock_path, os.O_CREAT | os.O_EXCL)
        except OSError as error:
            if error.errno == errno.EEXIST:
                # File already exists, check its modification time
                mtime = os.path.getmtime(lock_path)
                ttl = mtime + timeout - time.time()
                if ttl > 0:
                    return False
                else:
                    # Stale lock: refresh its mtime and treat it as acquired.
                    # NOTE(review): this takeover is not atomic — two processes
                    # can both pass the ttl check and "acquire" simultaneously.
                    os.utime(lock_path, None)
                    return True
            else:
                return False
        else:
            os.close(fd)
            return True

    def un_lock(self, key):
        """Release *key* by deleting its lock file (raises OSError if absent)."""
        lock_path = self._get_lock_path(key)
        os.remove(lock_path)

View File

@ -124,6 +124,18 @@ def get_file_content(path):
content = file.read()
return content
def sub_array(array: List, item_num=10):
    """Split *array* into consecutive sub-lists of at most *item_num* items each."""
    chunks = []
    current = []
    for element in array:
        current.append(element)
        if len(current) >= item_num:
            chunks.append(current)
            current = []
    # Flush the trailing partial chunk, if any.
    if current:
        chunks.append(current)
    return chunks
def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
content_type, _ = mimetypes.guess_type(file_name)
@ -233,3 +245,15 @@ def valid_license(model=None, count=None, message=None):
return run
return inner
def post(post_function):
    """Decorator factory: unpack the wrapped function's return value into
    *post_function* and return its result.

    The wrapped function must return an iterable of positional arguments.
    """
    def inner(func):
        def run(*args, **kwargs):
            intermediate = func(*args, **kwargs)
            return post_function(*intermediate)
        return run
    return inner

53
apps/common/utils/lock.py Normal file
View File

@ -0,0 +1,53 @@
# coding=utf-8
"""
@project: qabot
@Author
@file lock.py
@date2023/9/11 11:45
@desc:
"""
from datetime import timedelta
from django.core.cache import caches
memory_cache = caches['default']
def try_lock(key: str, timeout=None):
    """
    Acquire the lock.
    :param key: lock key
    :param timeout: lock expiry in seconds; defaults to one hour when omitted
    :return: whether the lock was acquired
    """
    # Bug fix: the condition was inverted ("... if timeout is not None else timeout"),
    # which discarded any caller-supplied timeout (always using 1h) and passed
    # timeout=None (never expires) when no timeout was given. Fall back to one
    # hour only when the caller did not provide a timeout.
    return memory_cache.add(key, 'lock',
                            timeout=timeout if timeout is not None else timedelta(hours=1).total_seconds())
def un_lock(key: str):
    """
    Release the lock.
    :param key: lock key
    :return: whether the key was present and deleted
    """
    return memory_cache.delete(key)
def lock(lock_key):
    """
    Serialise calls to a function by key.
    :param lock_key: string key, or a callable deriving the key from the call arguments
    :return: decorator; calls sharing a key are mutually exclusive — a call that
             fails to acquire the lock returns None without running the function
    """
    def inner(func):
        def run(*args, **kwargs):
            key = lock_key(*args, **kwargs) if callable(lock_key) else lock_key
            # Bug fix: un_lock used to sit in a `finally` wrapping the whole
            # body, so a call that FAILED to acquire the lock still deleted the
            # key — releasing a lock held by another caller. Release only after
            # a successful acquisition.
            if try_lock(key=key):
                try:
                    return func(*args, **kwargs)
                finally:
                    un_lock(key=key)
        return run
    return inner

View File

@ -0,0 +1,47 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file page_utils.py
@date2024/11/21 10:32
@desc:
"""
from math import ceil
def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
    """
    Iterate *query_set* in ascending id order, feeding *handler* one page at a time.
    @param query_set: source queryset
    @param page_size: rows fetched per page
    @param handler: callback receiving each page's row list
    @param is_the_task_interrupted: polled before each page; True aborts the walk
    @return:
    """
    ordered = query_set.order_by("id")
    total = query_set.count()
    for page_index in range(ceil(total / page_size)):
        if is_the_task_interrupted():
            return
        start = page_index * page_size
        handler(ordered.all()[start: start + page_size])
def page_desc(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
    """
    Like page(), but visits the pages from last to first (rows within a page
    still come back in ascending id order).
    @param query_set: source queryset
    @param page_size: rows fetched per page
    @param handler: callback receiving each page's row list
    @param is_the_task_interrupted: polled before each page; True aborts the walk
    @return:
    """
    ordered = query_set.order_by("id")
    total = query_set.count()
    for page_index in reversed(range(ceil(total / page_size))):
        if is_the_task_interrupted():
            return
        start = page_index * page_size
        handler(ordered.all()[start: start + page_size])

View File

@ -3,7 +3,7 @@
import os
import subprocess
import sys
import uuid_utils as uuid
import uuid_utils.compat as uuid
from textwrap import dedent
from diskcache import Cache

View File

@ -0,0 +1,88 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file ts_vecto_util.py
@date2024/4/16 15:26
@desc:
"""
import re
import uuid_utils.compat as uuid
from typing import List
import jieba
import jieba.posseg
# Pool of single-character placeholder tokens (chr 38..83) pre-registered with
# jieba as '#<char>#' so they survive tokenisation intact.
jieba_word_list_cache = [chr(item) for item in range(38, 84)]
for jieba_word in jieba_word_list_cache:
    jieba.add_word('#' + jieba_word + '#')
# r"(?i)\b(?:https?|ftp|tcp|file)://[^\s]+\b",
# Patterns for data that must NOT be word-split (kept verbatim as tokens).
# r'"([^"]*)"'
# NOTE(review): in the version pattern the dots are unescaped (match any char),
# and '[A-Z|a-z]' also matches a literal '|' — confirm both are intended.
word_pattern_list = [r"v\d+.\d+.\d+",
                     r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}"]
# Characters stripped from text before indexing.
remove_chars = '\n , :\'<>@#¥%……&*!@#$%^&*() /"./'
# jieba POS flags to drop: 'x' (non-word) and 'w' (punctuation).
jieba_remove_flag_list = ['x', 'w']
def get_word_list(text: str):
    """Extract protected tokens (versions, e-mail addresses, ...) from *text*.

    Matches containing ':' are split on it, because ':' cannot appear in a
    stored token.
    """
    result = []
    for pattern in word_pattern_list:
        for match in re.findall(pattern, text):
            # findall yields tuples when the pattern has groups.
            words = match if isinstance(match, tuple) else [match]
            for word in words:
                if ':' in word:
                    result.extend(word.split(":"))
                else:
                    result.append(word)
    return result
def replace_word(word_dict, text: str):
    """Replace each dict value found in *text* with its placeholder key.

    Negative look-around skips occurrences already wrapped in '#', so spans
    substituted earlier stay intact.
    """
    for placeholder in word_dict:
        pattern = '(?<!#)' + re.escape(word_dict[placeholder]) + '(?!#)'
        text = re.sub(pattern, placeholder, text)
    return text
def get_word_key(text: str, use_word_list):
    """Pick a placeholder character absent from *text* and from every string in
    *use_word_list*; fall back to a freshly minted uuid7 token."""
    candidate = next((ch for ch in jieba_word_list_cache
                      if ch not in text and all(ch not in used for used in use_word_list)),
                     None)
    if candidate is not None:
        return candidate
    # No cached character is free: mint a unique token and teach jieba about it.
    generated = str(uuid.uuid7())
    jieba.add_word(generated)
    return generated
def to_word_dict(word_list: List, text: str):
    """Map each word to a '#<key>#' placeholder unique w.r.t. *text* and the keys chosen so far."""
    word_dict = {}
    for word in word_list:
        key = get_word_key(text, set(word_dict))
        word_dict[f'#{key}#'] = word
    return word_dict
def get_key_by_word_dict(key, word_dict):
    """Resolve *key* through *word_dict*; unknown (or None-mapped) keys resolve to themselves."""
    value = word_dict.get(key)
    return key if value is None else value
def to_ts_vector(text: str):
    """Tokenise *text* with jieba full mode into a space-joined ts_vector string."""
    return " ".join(jieba.lcut(text, cut_all=True))
def to_query(text: str):
    """Tokenise *text* with jieba full mode for use as a full-text query string."""
    tokens = jieba.lcut(text, cut_all=True)
    return " ".join(tokens)

View File

@ -0,0 +1,34 @@
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter
from common.mixins.api_mixin import APIMixin
from common.result import DefaultResultSerializer, ResultSerializer
from knowledge.serializers.document import DocumentCreateRequest
class DocumentCreateResponse(ResultSerializer):
    """OpenAPI response schema for document creation."""

    @staticmethod
    def get_data():
        # Payload schema inside the standard result envelope.
        return DefaultResultSerializer()
class DocumentCreateAPI(APIMixin):
    """OpenAPI contract (parameters / request / response) for the document-creation endpoint."""

    @staticmethod
    def get_parameters():
        # Path parameters required by the endpoint.
        return [
            OpenApiParameter(
                name="workspace_id",
                description="工作空间id",
                type=OpenApiTypes.STR,
                location='path',
                required=True,
            )
        ]

    @staticmethod
    def get_request():
        # Request body schema.
        return DocumentCreateRequest

    @staticmethod
    def get_response():
        return DocumentCreateResponse

View File

View File

@ -56,7 +56,7 @@ class Migration(migrations.Migration):
('tree_id', models.PositiveIntegerField(db_index=True, editable=False)),
('level', models.PositiveIntegerField(editable=False)),
('parent',
mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE,
mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING,
related_name='children', to='knowledge.knowledgefolder')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user',
verbose_name='用户id')),
@ -85,7 +85,7 @@ class Migration(migrations.Migration):
models.CharField(choices=[('SHARED', '共享'), ('WORKSPACE', '工作空间可用')], default='WORKSPACE',
max_length=20, verbose_name='可用范围')),
('folder',
models.ForeignKey(default='root', on_delete=django.db.models.deletion.CASCADE,
models.ForeignKey(default='root', on_delete=django.db.models.deletion.DO_NOTHING,
to='knowledge.knowledgefolder',
verbose_name='文件夹id')),
('embedding_model', models.ForeignKey(default=knowledge.models.knowledge.default_model,

View File

@ -1,4 +1,7 @@
from enum import Enum
import uuid_utils.compat as uuid
from django.contrib.postgres.search import SearchVectorField
from django.db import models
from django.db.models.signals import pre_delete
from django.dispatch import receiver
@ -18,11 +21,78 @@ class KnowledgeType(models.IntegerChoices):
YUQUE = 3, '语雀类型'
class TaskType(Enum):
    """Background task kinds; the value is the 1-based position of the task's
    state character in a reversed Status string (see Status below)."""
    # Vectorise (embedding)
    EMBEDDING = 1
    # Generate questions
    GENERATE_PROBLEM = 2
    # Synchronise source
    SYNC = 3
class State(Enum):
    """Single-character task states stored inside Status strings."""
    # Waiting
    PENDING = '0'
    # Running
    STARTED = '1'
    # Succeeded
    SUCCESS = '2'
    # Failed
    FAILURE = '3'
    # Cancellation requested
    REVOKE = '4'
    # Cancelled
    REVOKED = '5'
    # Ignored / not applicable
    IGNORED = 'n'
class KnowledgeScope(models.TextChoices):
    """Visibility scope of a knowledge base."""
    # Shared across workspaces
    SHARED = "SHARED", '共享'
    # Available inside its own workspace only
    WORKSPACE = "WORKSPACE", "工作空间可用"
class HitHandlingMethod(models.TextChoices):
    """How a paragraph hit is handled when answering."""
    # Feed the hit to the model for optimisation
    optimization = 'optimization', '模型优化'
    # Return the paragraph content directly
    directly_return = 'directly_return', '直接返回'
class Status:
    """Mutable per-task state bundle serialised to/from a short string.

    The string is stored reversed: the character at index ``task.value - 1`` of
    the reversed string is that task's State; missing positions read as
    State.IGNORED ('n').
    """
    type_cls = TaskType
    state_cls = State

    def __init__(self, status: str = None):
        self.task_status = {}
        reversed_chars = list(status[::-1] if status is not None else '')
        for task in self.type_cls:
            index = task.value - 1
            char = reversed_chars[index] if len(reversed_chars) > index else 'n'
            self.task_status[task] = self.state_cls(char)

    @staticmethod
    def of(status: str):
        """Alternate constructor, equivalent to ``Status(status)``."""
        return Status(status)

    def __str__(self):
        # Emit tasks highest-value first so the on-disk string stays reversed.
        chars = []
        for task in sorted(self.type_cls, key=lambda item: item.value, reverse=True):
            chars.insert(len(self.type_cls) - task.value, self.task_status[task].value)
        return ''.join(chars)

    def __setitem__(self, key, value):
        self.task_status[key] = value

    def __getitem__(self, item):
        return self.task_status[item]

    def update_status(self, task_type: TaskType, state: State):
        """Set *task_type*'s state (same effect as ``status[task_type] = state``)."""
        self.task_status[task_type] = state
def default_status_meta():
    """Default value factory for status_meta JSON fields."""
    return {"state_time": {}}
def default_model():
    """Default embedding-model primary key for Knowledge rows."""
    # TODO: look the default model up from the database instead of hard-coding it
    return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab')
@ -33,7 +103,7 @@ class KnowledgeFolder(MPTTModel, AppModelMixin):
name = models.CharField(max_length=64, verbose_name="文件夹名称")
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="用户id")
workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True)
parent = TreeForeignKey('self', on_delete=models.CASCADE, null=True, blank=True, related_name='children')
parent = TreeForeignKey('self', on_delete=models.DO_NOTHING, null=True, blank=True, related_name='children')
class Meta:
db_table = "knowledge_folder"
@ -42,24 +112,127 @@ class KnowledgeFolder(MPTTModel, AppModelMixin):
order_insertion_by = ['name']
class Knowledge(AppModelMixin):
    """
    Knowledge base table (知识库表).
    """
    # Primary key (uuid7).
    id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
    name = models.CharField(max_length=150, verbose_name="知识库名称")
    workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True)
    desc = models.CharField(max_length=256, verbose_name="描述")
    user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户")
    type = models.IntegerField(verbose_name='类型', choices=KnowledgeType.choices, default=KnowledgeType.BASE)
    scope = models.CharField(max_length=20, verbose_name='可用范围', choices=KnowledgeScope.choices,
                             default=KnowledgeScope.WORKSPACE)
    folder = models.ForeignKey(KnowledgeFolder, on_delete=models.DO_NOTHING, verbose_name="文件夹id", default='root')
    embedding_model = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型",
                                        default=default_model)
    meta = models.JSONField(verbose_name="元数据", default=dict)

    class Meta:
        db_table = "knowledge"
class Document(AppModelMixin):
    """
    Document table (文档表).
    """
    # Primary key (uuid7).
    id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
    knowledge = models.ForeignKey(Knowledge, on_delete=models.DO_NOTHING, verbose_name="知识库id")
    name = models.CharField(max_length=150, verbose_name="文档名称")
    # Denormalised character count of the document content.
    char_length = models.IntegerField(verbose_name="文档字符数 冗余字段")
    # Per-task state string serialised by the Status class above.
    status = models.CharField(verbose_name='状态', max_length=20, default=Status('').__str__)
    status_meta = models.JSONField(verbose_name="状态统计数据", default=default_status_meta)
    is_active = models.BooleanField(default=True)
    type = models.IntegerField(verbose_name='类型', choices=KnowledgeType.choices, default=KnowledgeType.BASE)
    hit_handling_method = models.CharField(verbose_name='命中处理方式', max_length=20,
                                           choices=HitHandlingMethod.choices,
                                           default=HitHandlingMethod.optimization)
    directly_return_similarity = models.FloatField(verbose_name='直接回答相似度', default=0.9)
    meta = models.JSONField(verbose_name="元数据", default=dict)

    class Meta:
        db_table = "document"
class Paragraph(AppModelMixin):
    """Paragraph table: a chunk of a document's content."""
    id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
    # db_constraint=False: no DB-level FK constraint; integrity handled in app code.
    document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False)
    knowledge = models.ForeignKey(Knowledge, on_delete=models.DO_NOTHING)
    content = models.CharField(max_length=102400, verbose_name="段落内容")
    title = models.CharField(max_length=256, verbose_name="标题", default="")
    # Default status string produced by the Status helper (defined elsewhere).
    status = models.CharField(verbose_name='状态', max_length=20, default=Status('').__str__)
    status_meta = models.JSONField(verbose_name="状态数据", default=default_status_meta)
    hit_num = models.IntegerField(verbose_name="命中次数", default=0)
    is_active = models.BooleanField(default=True)

    class Meta:
        db_table = "paragraph"
class Problem(AppModelMixin):
    """Problem table: a question belonging to a knowledge base (linked to
    paragraphs via ProblemParagraphMapping)."""
    id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
    knowledge = models.ForeignKey(Knowledge, on_delete=models.DO_NOTHING, db_constraint=False)
    content = models.CharField(max_length=256, verbose_name="问题内容")
    hit_num = models.IntegerField(verbose_name="命中次数", default=0)

    class Meta:
        db_table = "problem"
class ProblemParagraphMapping(AppModelMixin):
    """Many-to-many link between problems and paragraphs, denormalized with the
    owning document and knowledge base ids."""
    id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
    knowledge = models.ForeignKey(Knowledge, on_delete=models.DO_NOTHING, db_constraint=False)
    document = models.ForeignKey(Document, on_delete=models.DO_NOTHING)
    problem = models.ForeignKey(Problem, on_delete=models.DO_NOTHING, db_constraint=False)
    paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False)

    class Meta:
        db_table = "problem_paragraph_mapping"
class SourceType(models.IntegerChoices):
    """Origin of an embedding row: problem, paragraph or title.

    NOTE(review): the original docstring said "订单类型" (order type) — looks
    like a copy-paste leftover.
    """
    PROBLEM = 0, '问题'
    PARAGRAPH = 1, '段落'
    TITLE = 2, '标题'
class SearchMode(models.TextChoices):
    """Retrieval modes: embedding (vector), keywords, or a blend of the two."""
    embedding = 'embedding'
    keywords = 'keywords'
    blend = 'blend'
class VectorField(models.Field):
    """Custom model field mapped to the database `vector` column type
    (presumably the PostgreSQL pgvector extension — confirm)."""

    def db_type(self, connection):
        return 'vector'
class Embedding(models.Model):
    """Embedding table: one vector per source row (problem/paragraph/title)."""
    id = models.CharField(max_length=128, primary_key=True, verbose_name="主键id")
    source_id = models.CharField(max_length=128, verbose_name="资源id")
    # NOTE(review): CharField(max_length=5) holding IntegerChoices values with an
    # integer default (SourceType.PROBLEM) — the int is stored as text; confirm.
    source_type = models.CharField(verbose_name='资源类型', max_length=5, choices=SourceType.choices,
                                   default=SourceType.PROBLEM)
    is_active = models.BooleanField(verbose_name="是否可用", max_length=1, default=True)
    knowledge = models.ForeignKey(Knowledge, on_delete=models.DO_NOTHING, verbose_name="文档关联", db_constraint=False)
    document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, verbose_name="文档关联", db_constraint=False)
    paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, verbose_name="段落关联", db_constraint=False)
    # Stored via VectorField (database `vector` column type).
    embedding = VectorField(verbose_name="向量")
    # Postgres full-text search vector used for keyword search.
    search_vector = SearchVectorField(verbose_name="分词", default="")
    meta = models.JSONField(verbose_name="元数据", default=dict)

    class Meta:
        db_table = "embedding"
class File(AppModelMixin):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
file_name = models.CharField(max_length=256, verbose_name="文件名称", default="")

View File

@ -0,0 +1,215 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file common_serializers.py
@date2023/11/17 11:00
@desc:
"""
import os
import re
import uuid_utils.compat as uuid
import zipfile
from typing import List
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from common.config.embedding_config import ModelManage
from common.db.search import native_search
from common.db.sql_execute import update_execute
from common.exception.app_exception import AppApiException
from common.utils.common import get_file_content
from common.utils.fork import Fork
from knowledge.models import Paragraph, Problem, ProblemParagraphMapping, Knowledge, File
from maxkb.conf import PROJECT_DIR
from models_provider.tools import get_model
def zip_dir(zip_path, output=None):
    """Zip every file under *zip_path* (recursively) into *output*.

    Archive member names are paths relative to *zip_path*. When *output* is
    omitted, "<basename of zip_path>.zip" is created in the current directory.
    """
    output = output or os.path.basename(zip_path) + '.zip'
    # Fixes: the original bound the archive to a variable named `zip` (shadowing
    # the builtin) and never closed it on error; it also built arcnames with
    # `root.replace(zip_path, '')`, which leaves a leading separator and breaks
    # if zip_path occurs again mid-path. relpath handles both correctly.
    with zipfile.ZipFile(output, 'w', zipfile.ZIP_DEFLATED) as archive:
        for root, _dirs, files in os.walk(zip_path):
            for filename in files:
                absolute_path = os.path.join(root, filename)
                archive.write(absolute_path, os.path.relpath(absolute_path, zip_path))
def is_valid_uuid(s):
    """Return True if *s* can be parsed as a UUID, False otherwise.

    Fix: the original caught only ValueError, so non-string inputs (None, int)
    escaped as TypeError/AttributeError instead of returning False.
    """
    try:
        uuid.UUID(s)
        return True
    except (ValueError, TypeError, AttributeError):
        return False
def write_image(zip_path: str, image_list: List[str]):
    """Copy files referenced by markdown image links into *zip_path*.

    Each entry of *image_list* is markdown text containing a link target like
    ``(/api/file/<uuid>)``; the referenced File row's bytes are written under
    ``<zip_path>/api/file/<uuid>`` so a later zip of *zip_path* ships the
    images alongside the exported documents.
    """
    for image in image_list:
        # Raw string fixes the invalid escape sequence in the old "\(.*\)".
        search = re.search(r"\(.*\)", image)
        if not search:
            continue
        text = search.group()
        if not text.startswith('(/api/file/'):
            continue
        file_id = text.replace('(/api/file/', '').replace(')', '')
        file_id = file_id.strip().split(" ")[0]
        # NOTE(review): `break` (aborting the whole loop on the first bad
        # reference) preserves the original behavior — `continue` may be the
        # real intent; confirm.
        if not is_valid_uuid(file_id):
            break
        file = QuerySet(File).filter(id=file_id).first()
        if file is None:
            break
        file_path = os.path.join(zip_path, 'api', 'file', file_id)
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        # Fix: the original opened os.path.join(zip_path, file_path), joining
        # zip_path twice and writing to the wrong location.
        with open(file_path, 'wb') as f:
            f.write(file.get_bytes())
def update_document_char_length(document_id: str):
    """Recompute a document's redundant ``char_length`` column via raw SQL.

    The SQL template consumes the document id twice, hence the repeated
    parameter in the tuple.
    """
    update_execute(get_file_content(
        os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'update_document_char_length.sql')),
        (document_id, document_id))
def list_paragraph(paragraph_list: List[str]):
    """Fetch full paragraph rows for the given ids via list_paragraph.sql.

    Returns an empty list when no ids are supplied.
    """
    if not paragraph_list:
        return []
    select_sql = get_file_content(
        os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_paragraph.sql'))
    return native_search(QuerySet(Paragraph).filter(id__in=paragraph_list), select_sql)
class MetaSerializer(serializers.Serializer):
    """Validators for knowledge `meta` payloads, keyed by knowledge type."""

    class WebMeta(serializers.Serializer):
        # Web knowledge: the crawled site root and an optional CSS selector.
        source_url = serializers.CharField(required=True, label=_('source url'))
        selector = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('selector'))

        def is_valid(self, *, raise_exception=False):
            """Field validation plus a one-shot probe of the source URL."""
            super().is_valid(raise_exception=True)
            source_url = self.data.get('source_url')
            # A 500 status from the forker means the URL could not be parsed.
            response = Fork(source_url, []).fork()
            if response.status == 500:
                raise AppApiException(500, _('URL error, cannot parse [{source_url}]').format(source_url=source_url))

    class BaseMeta(serializers.Serializer):
        # Plain knowledge bases carry no extra metadata; only run field validation.
        def is_valid(self, *, raise_exception=False):
            super().is_valid(raise_exception=True)
class BatchSerializer(serializers.Serializer):
    """A batch of ids; optionally verified to all exist for a given model."""
    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list'))

    def is_valid(self, *, model=None, raise_exception=False):
        """Validate the field, then (when *model* is given) check every id exists."""
        super().is_valid(raise_exception=True)
        if model is None:
            return
        id_list = self.data.get('id_list')
        found_rows = QuerySet(model).filter(id__in=id_list)
        if len(found_rows) == len(id_list):
            return
        found_ids = [str(row.id) for row in found_rows]
        missing_ids = [row_id for row_id in id_list if row_id not in found_ids]
        raise AppApiException(500, _('The following id does not exist: {error_id_list}').format(
            error_id_list=missing_ids))
class ProblemParagraphObject:
    """Value object linking one generated problem content to the paragraph
    (and document/knowledge base) it belongs to, before any rows are saved."""

    def __init__(self, knowledge_id: str, document_id: str, paragraph_id: str, problem_content: str):
        self.knowledge_id = knowledge_id
        self.document_id = document_id
        self.paragraph_id = paragraph_id
        self.problem_content = problem_content

    def __repr__(self):
        # Added for debuggability; plain objects previously printed as <object at 0x...>.
        return (f"{type(self).__name__}(knowledge_id={self.knowledge_id!r}, "
                f"document_id={self.document_id!r}, paragraph_id={self.paragraph_id!r}, "
                f"problem_content={self.problem_content!r})")
def or_get(exists_problem_list, content, knowledge_id, document_id, paragraph_id, problem_content_dict):
    """Return ``(problem, document_id, paragraph_id)`` for *content*.

    Reuses a problem already seen in *problem_content_dict* (batch-level cache),
    then one from *exists_problem_list* (already persisted), and only otherwise
    builds a new unsaved Problem. The cache value is ``(problem, is_created)``.
    """
    cached = problem_content_dict.get(content)
    if cached is not None:
        return cached[0], document_id, paragraph_id
    matched = next((row for row in exists_problem_list if row.content == content), None)
    if matched is not None:
        problem_content_dict[content] = matched, False
        return matched, document_id, paragraph_id
    created = Problem(id=uuid.uuid7(), content=content, knowledge_id=knowledge_id)
    problem_content_dict[content] = created, True
    return created, document_id, paragraph_id
class ProblemParagraphManage:
    """Batch-converts ProblemParagraphObject items into Problem models and
    problem<->paragraph mapping rows for one knowledge base."""

    def __init__(self, problem_paragraph_object_list: List[ProblemParagraphObject], knowledge_id):
        self.knowledge_id = knowledge_id
        self.problem_paragraph_object_list = problem_paragraph_object_list

    def to_problem_model_list(self):
        """Return ``(new_problem_models, problem_paragraph_mapping_models)``.

        Problems already stored for this knowledge base (or repeated within the
        batch) are reused rather than duplicated; only rows flagged as newly
        created are returned in the first element.
        """
        problem_list = [item.problem_content for item in self.problem_paragraph_object_list]
        exists_problem_list = []
        if len(self.problem_paragraph_object_list) > 0:
            # Fetch the problems that already exist for these contents.
            exists_problem_list = QuerySet(Problem).filter(knowledge_id=self.knowledge_id,
                                                           content__in=problem_list).all()
        # content -> (problem_model, is_newly_created); shared across or_get calls
        # so duplicate contents in the batch map to a single Problem row.
        problem_content_dict = {}
        problem_model_list = [
            or_get(exists_problem_list, problemParagraphObject.problem_content, problemParagraphObject.knowledge_id,
                   problemParagraphObject.document_id, problemParagraphObject.paragraph_id, problem_content_dict) for
            problemParagraphObject in self.problem_paragraph_object_list]
        # One mapping row per (problem, document, paragraph) triple.
        problem_paragraph_mapping_list = [
            ProblemParagraphMapping(id=uuid.uuid7(), document_id=document_id, problem_id=problem_model.id,
                                    paragraph_id=paragraph_id,
                                    knowledge_id=self.knowledge_id) for
            problem_model, document_id, paragraph_id in problem_model_list]
        result = [problem_model for problem_model, is_create in problem_content_dict.values() if
                  is_create], problem_paragraph_mapping_list
        return result
def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
    """Load (via the ModelManage cache) the embedding model shared by a set of
    knowledge bases.

    Raises when the knowledge bases are configured with different embedding
    models, or when none of the ids match a row.
    """
    knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
    distinct_model_ids = {knowledge.embedding_model_id for knowledge in knowledge_list}
    if len(distinct_model_ids) > 1:
        raise Exception(_('The knowledge base is inconsistent with the vector model'))
    if len(knowledge_list) == 0:
        raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
    first_knowledge = knowledge_list[0]
    return ModelManage.get_model(str(first_knowledge.embedding_model_id),
                                 lambda _id: get_model(first_knowledge.embedding_model))
def get_embedding_model_by_knowledge_id(knowledge_id: str):
    """Load (via the ModelManage cache) the embedding model of one knowledge base.

    NOTE(review): raises AttributeError when *knowledge_id* matches no row —
    confirm callers guarantee existence.
    """
    knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first()
    return ModelManage.get_model(str(knowledge.embedding_model_id), lambda _id: get_model(knowledge.embedding_model))
def get_embedding_model_by_knowledge(knowledge):
    """Load (via the ModelManage cache) the embedding model of an already-fetched
    knowledge row."""
    return ModelManage.get_model(str(knowledge.embedding_model_id), lambda _id: get_model(knowledge.embedding_model))
def get_embedding_model_id_by_knowledge_id(knowledge_id):
    """Return the embedding model id (as str) configured on one knowledge base.

    NOTE(review): raises AttributeError when *knowledge_id* matches no row.
    """
    knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first()
    return str(knowledge.embedding_model_id)
def get_embedding_model_id_by_knowledge_id_list(knowledge_id_list: List):
    """Return the embedding model id (as str) shared by a set of knowledge bases.

    Raises when the knowledge bases use different embedding models, or when no
    id matches a row.
    """
    knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
    distinct_model_ids = {knowledge.embedding_model_id for knowledge in knowledge_list}
    if len(distinct_model_ids) > 1:
        raise Exception(_('The knowledge base is inconsistent with the vector model'))
    if len(knowledge_list) == 0:
        raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
    return str(knowledge_list[0].embedding_model_id)
class GenerateRelatedSerializer(serializers.Serializer):
    """Request body for the generate-related action: a model id, a prompt, and
    an optional list of task states to (re)process."""
    model_id = serializers.UUIDField(required=True, label=_('Model id'))
    prompt = serializers.CharField(required=True, label=_('Prompt word'))
    state_list = serializers.ListField(required=False, child=serializers.CharField(required=True),
                                       label=_("state list"))

View File

@ -0,0 +1,172 @@
import os
from functools import reduce
from typing import Dict, List
import uuid_utils.compat as uuid
from celery_once import AlreadyQueued
from django.db import transaction
from django.db.models import QuerySet, Model
from django.db.models.functions import Substr, Reverse
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from common.db.search import native_search
from common.event import ListenerManagement
from common.exception.app_exception import AppApiException
from common.utils.common import post, get_file_content
from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \
TaskType
from knowledge.serializers.common import ProblemParagraphManage
from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer
from knowledge.task import embedding_by_document
from maxkb.const import PROJECT_DIR
class DocumentInstanceSerializer(serializers.Serializer):
    """One document payload: a name plus optional paragraph payloads."""
    name = serializers.CharField(required=True, label=_('document name'), max_length=128, min_length=1)
    paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
class DocumentCreateRequest(serializers.Serializer):
    """Create-request payload carrying name/desc/embedding model and documents.

    NOTE(review): labels say 'knowledge name'/'knowledge description' although
    the class is named DocumentCreateRequest — confirm the intended payload.
    """
    name = serializers.CharField(required=True, label=_('knowledge name'), max_length=64, min_length=1)
    desc = serializers.CharField(required=True, label=_('knowledge description'), max_length=256, min_length=1)
    embedding_model_id = serializers.UUIDField(required=True, label=_('embedding model'))
    documents = DocumentInstanceSerializer(required=False, many=True)
class DocumentSerializers(serializers.Serializer):
    """Container for document-level operations: query one, refresh embedding,
    and create a document with its paragraphs/problems."""

    class Operate(serializers.Serializer):
        # Operations addressed at one existing document.
        document_id = serializers.UUIDField(required=True, label=_('document id'))
        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))

        def is_valid(self, *, raise_exception=False):
            """Field validation plus an existence check on the document."""
            super().is_valid(raise_exception=True)
            document_id = self.data.get('document_id')
            if not QuerySet(Document).filter(id=document_id).exists():
                raise AppApiException(500, _('document id not exist'))

        def one(self, with_valid=False):
            """Return one document row enriched by the list_document.sql projection."""
            if with_valid:
                self.is_valid(raise_exception=True)
            query_set = QuerySet(model=Document)
            query_set = query_set.filter(**{'id': self.data.get("document_id")})
            return native_search({
                'document_custom_sql': query_set,
                'order_by_query': QuerySet(Document).order_by('-create_time', 'id')
            }, select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')), with_search_one=True)

        def refresh(self, state_list=None, with_valid=True):
            """Re-queue embedding for this document's paragraphs in *state_list* states."""
            if state_list is None:
                # Default: re-embed paragraphs in every known task state.
                state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
                              State.REVOKE.value,
                              State.REVOKED.value, State.IGNORED.value]
            if with_valid:
                self.is_valid(raise_exception=True)
            knowledge = QuerySet(Knowledge).filter(id=self.data.get('knowledge_id')).first()
            embedding_model_id = knowledge.embedding_model_id
            knowledge_user_id = knowledge.user_id
            # NOTE(review): `Model` in this module is django.db.models.Model (see
            # the file imports), not the models_provider Model — this query
            # looks wrong; confirm.
            embedding_model = QuerySet(Model).filter(id=embedding_model_id).first()
            if embedding_model is None:
                raise AppApiException(500, _('Model does not exist'))
            # Private models may only be used by their owner.
            if embedding_model.permission_type == 'PRIVATE' and knowledge_user_id != embedding_model.user_id:
                raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}")
            document_id = self.data.get("document_id")
            # Mark the document and the matching paragraphs as PENDING before queueing.
            ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
                                             State.PENDING)
            # Paragraph.status is a per-task-type character string; the EMBEDDING
            # slot is selected positionally from the reversed status string.
            ListenerManagement.update_status(QuerySet(Paragraph).annotate(
                reversed_status=Reverse('status'),
                task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value, 1),
            ).filter(task_type_status__in=state_list, document_id=document_id).values('id'),
                                             TaskType.EMBEDDING, State.PENDING)
            ListenerManagement.get_aggregation_document_status(document_id)()
            try:
                embedding_by_document.delay(document_id, embedding_model_id, state_list)
            except AlreadyQueued as e:
                # The celery_once lock is still held for this document.
                raise AppApiException(500, _('The task is being executed, please do not send it repeatedly.'))

    class Create(serializers.Serializer):
        # NOTE(review): label says 'document id' but the field is the knowledge id.
        knowledge_id = serializers.UUIDField(required=True, label=_('document id'))

        def is_valid(self, *, raise_exception=False):
            """Field validation plus an existence check on the knowledge base."""
            super().is_valid(raise_exception=True)
            if not QuerySet(Knowledge).filter(id=self.data.get('knowledge_id')).exists():
                raise AppApiException(10000, _('knowledge id not exist'))
            return True

        @staticmethod
        def post_embedding(result, document_id, knowledge_id):
            # Runs after save() via the @post decorator: queue embedding through refresh().
            DocumentSerializers.Operate(
                data={'knowledge_id': knowledge_id, 'document_id': document_id}).refresh()
            return result

        @post(post_function=post_embedding)
        @transaction.atomic
        def save(self, instance: Dict, with_valid=False, **kwargs):
            """Create one document plus its paragraphs/problems in a single transaction."""
            if with_valid:
                DocumentCreateRequest(data=instance).is_valid(raise_exception=True)
                self.is_valid(raise_exception=True)
            knowledge_id = self.data.get('knowledge_id')
            document_paragraph_model = self.get_document_paragraph_model(knowledge_id, instance)
            document_model = document_paragraph_model.get('document')
            paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
            problem_paragraph_object_list = document_paragraph_model.get('problem_paragraph_object_list')
            problem_model_list, problem_paragraph_mapping_list = (
                ProblemParagraphManage(problem_paragraph_object_list, knowledge_id).to_problem_model_list())
            # Insert the document.
            document_model.save()
            # Bulk-insert paragraphs.
            QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
            # Bulk-insert problems.
            QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
            # Bulk-insert problem<->paragraph mappings.
            QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
                problem_paragraph_mapping_list) > 0 else None
            document_id = str(document_model.id)
            # Tuple is consumed by post_embedding() above.
            return (DocumentSerializers.Operate(
                data={'knowledge_id': knowledge_id, 'document_id': document_id}
            ).one(with_valid=True), document_id, knowledge_id)

        @staticmethod
        def get_paragraph_model(document_model, paragraph_list: List):
            """Build unsaved paragraph models and problem-link objects for a document."""
            knowledge_id = document_model.knowledge_id
            paragraph_model_dict_list = [
                ParagraphSerializers.Create(
                    data={
                        'knowledge_id': knowledge_id, 'document_id': str(document_model.id)
                    }).get_paragraph_problem_model(knowledge_id, document_model.id, paragraph)
                for paragraph in paragraph_list]
            paragraph_model_list = []
            problem_paragraph_object_list = []
            # Flatten the per-paragraph dicts into two parallel lists.
            for paragraphs in paragraph_model_dict_list:
                paragraph = paragraphs.get('paragraph')
                for problem_model in paragraphs.get('problem_paragraph_object_list'):
                    problem_paragraph_object_list.append(problem_model)
                paragraph_model_list.append(paragraph)
            return {
                'document': document_model,
                'paragraph_model_list': paragraph_model_list,
                'problem_paragraph_object_list': problem_paragraph_object_list
            }

        @staticmethod
        def get_document_paragraph_model(knowledge_id, instance: Dict):
            """Build an unsaved Document (char_length summed from its paragraphs)
            together with its paragraph/problem models."""
            document_model = Document(
                **{
                    'knowledge_id': knowledge_id,
                    'id': uuid.uuid7(),
                    'name': instance.get('name'),
                    'char_length': reduce(lambda x, y: x + y,
                                          [len(p.get('content')) for p in instance.get('paragraphs', [])],
                                          0),
                    'meta': instance.get('meta') if instance.get('meta') is not None else {},
                    'type': instance.get('type') if instance.get('type') is not None else KnowledgeType.BASE
                })
            return DocumentSerializers.Create.get_paragraph_model(document_model,
                                                                  instance.get('paragraphs') if
                                                                  'paragraphs' in instance else [])

View File

@ -1,14 +1,19 @@
from functools import reduce
from typing import Dict
import uuid_utils as uuid
import uuid_utils.compat as uuid
from django.db import transaction
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from common.exception.app_exception import AppApiException
from common.utils.common import valid_license
from knowledge.models import Knowledge, KnowledgeScope, KnowledgeType
from common.utils.common import valid_license, post
from knowledge.models import Knowledge, KnowledgeScope, KnowledgeType, Document, Paragraph, Problem, \
ProblemParagraphMapping
from knowledge.serializers.common import ProblemParagraphManage, get_embedding_model_id_by_knowledge_id
from knowledge.serializers.document import DocumentSerializers
from knowledge.task import sync_web_knowledge, embedding_by_knowledge
class KnowledgeModelSerializer(serializers.ModelSerializer):
@ -38,10 +43,17 @@ class KnowledgeSerializer(serializers.Serializer):
user_id = serializers.UUIDField(required=True, label=_('user id'))
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
@staticmethod
def post_embedding_knowledge(document_list, knowledge_id):
# todo 发送向量化事件
model_id = get_embedding_model_id_by_knowledge_id(knowledge_id)
embedding_by_knowledge.delay(knowledge_id, model_id)
return document_list
@valid_license(model=Knowledge, count=50,
message=_(
'The community version supports up to 50 knowledge bases. If you need more knowledge bases, please contact us (https://fit2cloud.com/).'))
# @post(post_function=post_embedding_dataset)
@post(post_function=post_embedding_knowledge)
@transaction.atomic
def save_base(self, instance, with_valid=True):
if with_valid:
@ -51,8 +63,9 @@ class KnowledgeSerializer(serializers.Serializer):
name=instance.get('name')).exists():
raise AppApiException(500, _('Knowledge base name duplicate!'))
knowledge_id = uuid.uuid7()
knowledge = Knowledge(
id=uuid.uuid7(),
id=knowledge_id,
name=instance.get('name'),
workspace_id=self.data.get('workspace_id'),
desc=instance.get('desc'),
@ -63,8 +76,42 @@ class KnowledgeSerializer(serializers.Serializer):
embedding_model_id=instance.get('embedding'),
meta=instance.get('meta', {}),
)
document_model_list = []
paragraph_model_list = []
problem_paragraph_object_list = []
# 插入文档
for document in instance.get('documents') if 'documents' in instance else []:
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(knowledge_id,
document)
document_model_list.append(document_paragraph_dict_model.get('document'))
for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
paragraph_model_list.append(paragraph)
for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
problem_paragraph_object_list.append(problem_paragraph_object)
problem_model_list, problem_paragraph_mapping_list = (
ProblemParagraphManage(problem_paragraph_object_list, knowledge_id)
.to_problem_model_list())
# 插入知识库
knowledge.save()
return KnowledgeModelSerializer(knowledge).data
# 插入文档
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
# 批量插入段落
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
# 批量插入问题
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
# 批量插入关联问题
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
problem_paragraph_mapping_list) > 0 else None
return {
**KnowledgeModelSerializer(knowledge).data,
'user_id': self.data.get('user_id'),
'document_list': document_model_list,
"document_count": len(document_model_list),
"char_length": reduce(lambda x, y: x + y, [d.char_length for d in document_model_list], 0)
}, knowledge_id
def save_web(self, instance: Dict, with_valid=True):
if with_valid:
@ -92,9 +139,8 @@ class KnowledgeSerializer(serializers.Serializer):
},
)
knowledge.save()
# sync_web_knowledge.delay(str(knowledge_id), instance.get('source_url'), instance.get('selector'))
return {**KnowledgeModelSerializer(knowledge).data,
'document_list': []}
sync_web_knowledge.delay(str(knowledge_id), instance.get('source_url'), instance.get('selector'))
return {**KnowledgeModelSerializer(knowledge).data, 'document_list': []}
class KnowledgeTreeSerializer(serializers.Serializer):

View File

@ -0,0 +1,221 @@
# coding=utf-8
from typing import Dict
import uuid_utils.compat as uuid
from django.db import transaction
from django.db.models import QuerySet, Count
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from common.exception.app_exception import AppApiException
from common.utils.common import post
from knowledge.models import Paragraph, Problem, Document, ProblemParagraphMapping
from knowledge.serializers.common import ProblemParagraphObject, ProblemParagraphManage, \
get_embedding_model_id_by_knowledge_id, update_document_char_length
from knowledge.serializers.problem import ProblemInstanceSerializer
from knowledge.task import embedding_by_paragraph, enable_embedding_by_paragraph, disable_embedding_by_paragraph, \
delete_embedding_by_paragraph
class ParagraphSerializer(serializers.ModelSerializer):
    """Read serializer exposing a paragraph's core columns."""

    class Meta:
        model = Paragraph
        fields = ['id', 'content', 'is_active', 'document_id', 'title', 'create_time', 'update_time']
class ParagraphInstanceSerializer(serializers.Serializer):
    """
    Paragraph instance payload: content plus optional title, problem list and
    active flag.
    """
    content = serializers.CharField(required=True, label=_('content'), max_length=102400, min_length=1, allow_null=True,
                                    allow_blank=True)
    title = serializers.CharField(required=False, max_length=256, label=_('section title'), allow_null=True,
                                  allow_blank=True)
    problem_list = ProblemInstanceSerializer(required=False, many=True)
    is_active = serializers.BooleanField(required=False, label=_('Is active'))
class EditParagraphSerializers(serializers.Serializer):
    """Request body for editing a paragraph; every field is optional."""
    title = serializers.CharField(required=False, max_length=256, label=_('section title'), allow_null=True,
                                  allow_blank=True)
    # Fix: the label was copy-pasted as _('section title'); this field is the content.
    content = serializers.CharField(required=False, max_length=102400, allow_null=True, allow_blank=True,
                                    label=_('content'))
    problem_list = ProblemInstanceSerializer(required=False, many=True)
class ParagraphSerializers(serializers.Serializer):
    """Paragraph create/edit/delete/query operations."""
    title = serializers.CharField(required=False, max_length=256, label=_('section title'), allow_null=True,
                                  allow_blank=True)
    # NOTE(review): label reuses _('section title') for the content field.
    content = serializers.CharField(required=True, max_length=102400, label=_('section title'))

    class Operate(serializers.Serializer):
        # Paragraph id
        paragraph_id = serializers.UUIDField(required=True, label=_('paragraph id'))
        # Knowledge base id.
        # NOTE(review): still named dataset_id after the dataset->knowledge
        # rename; Create.save() below passes 'knowledge_id' into this
        # serializer — confirm that validation actually passes.
        dataset_id = serializers.UUIDField(required=True, label=_('dataset id'))
        # Document id
        document_id = serializers.UUIDField(required=True, label=_('document id'))

        def is_valid(self, *, raise_exception=True):
            """Field validation plus an existence check on the paragraph."""
            super().is_valid(raise_exception=True)
            if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
                raise AppApiException(500, _('Paragraph id does not exist'))

        @staticmethod
        def post_embedding(paragraph, instance, knowledge_id):
            # Runs after edit() via @post: toggle embeddings when is_active
            # changed, otherwise re-embed the paragraph.
            if 'is_active' in instance and instance.get('is_active') is not None:
                (enable_embedding_by_paragraph if instance.get(
                    'is_active') else disable_embedding_by_paragraph)(paragraph.get('id'))
            else:
                model_id = get_embedding_model_id_by_knowledge_id(knowledge_id)
                embedding_by_paragraph(paragraph.get('id'), model_id)
            return paragraph

        @post(post_embedding)
        @transaction.atomic
        def edit(self, instance: Dict):
            """Update paragraph fields and reconcile its problem list."""
            self.is_valid()
            EditParagraphSerializers(data=instance).is_valid(raise_exception=True)
            _paragraph = QuerySet(Paragraph).get(id=self.data.get("paragraph_id"))
            update_keys = ['title', 'content', 'is_active']
            for update_key in update_keys:
                if update_key in instance and instance.get(update_key) is not None:
                    _paragraph.__setattr__(update_key, instance.get(update_key))
            if 'problem_list' in instance:
                # Rows with an id are updates; rows without are creations.
                update_problem_list = list(
                    filter(lambda row: 'id' in row and row.get('id') is not None, instance.get('problem_list')))
                create_problem_list = list(filter(lambda row: row.get('id') is None, instance.get('problem_list')))
                # Existing problems of this paragraph.
                # NOTE(review): Problem has no paragraph_id field in
                # knowledge.models — this filter and the Problem(...) kwargs
                # below look like pre-refactor code; confirm.
                problem_list = QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id"))
                # Validate ids sent by the client.
                for update_problem in update_problem_list:
                    if not set([str(row.id) for row in problem_list]).__contains__(update_problem.get('id')):
                        raise AppApiException(500, _('Problem id does not exist'))
                # Diff out the problems to delete.
                delete_problem_list = list(filter(
                    lambda row: not [str(update_row.get('id')) for update_row in update_problem_list].__contains__(
                        str(row.id)), problem_list)) if len(update_problem_list) > 0 else []
                # Delete removed problems.
                QuerySet(Problem).filter(id__in=[row.id for row in delete_problem_list]).delete() if len(
                    delete_problem_list) > 0 else None
                # Insert new problems.
                # NOTE(review): uuid.uuid1 here vs uuid.uuid7 everywhere else,
                # and dataset_id/document_id are not Problem fields — confirm.
                QuerySet(Problem).bulk_create(
                    [Problem(id=uuid.uuid1(), content=p.get('content'), paragraph_id=self.data.get('paragraph_id'),
                             dataset_id=self.data.get('dataset_id'), document_id=self.data.get('document_id')) for
                     p in create_problem_list]) if len(create_problem_list) else None
                # Update changed problem contents.
                QuerySet(Problem).bulk_update(
                    [Problem(id=row.get('id'), content=row.get('content')) for row in update_problem_list],
                    ['content']) if len(
                    update_problem_list) > 0 else None
            _paragraph.save()
            # Keep the document's redundant char_length in sync.
            update_document_char_length(self.data.get('document_id'))
            return self.one(), instance, self.data.get('dataset_id')

        def get_problem_list(self):
            """Return serialized problems mapped to this paragraph."""
            # NOTE(review): stray no-op statement; ProblemSerializer is also not
            # among this module's imports — confirm.
            ProblemParagraphMapping(ProblemParagraphMapping)
            problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
                paragraph_id=self.data.get("paragraph_id"))
            if len(problem_paragraph_mapping) > 0:
                return [ProblemSerializer(problem).data for problem in
                        QuerySet(Problem).filter(id__in=[ppm.problem_id for ppm in problem_paragraph_mapping])]
            return []

        def one(self, with_valid=False):
            """Return this paragraph's data plus its problem list."""
            if with_valid:
                self.is_valid(raise_exception=True)
            return {**ParagraphSerializer(QuerySet(model=Paragraph).get(id=self.data.get('paragraph_id'))).data,
                    'problem_list': self.get_problem_list()}

        def delete(self, with_valid=False):
            """Delete the paragraph, orphaned problems/mappings and its embeddings."""
            if with_valid:
                self.is_valid(raise_exception=True)
            paragraph_id = self.data.get('paragraph_id')
            Paragraph.objects.filter(id=paragraph_id).delete()
            delete_problems_and_mappings([paragraph_id])
            update_document_char_length(self.data.get('document_id'))
            delete_embedding_by_paragraph(paragraph_id)

    class Create(serializers.Serializer):
        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
        document_id = serializers.UUIDField(required=True, label=_('document id'))

        def is_valid(self, *, raise_exception=False):
            """Field validation plus a check that the document belongs to the knowledge base."""
            super().is_valid(raise_exception=True)
            if not QuerySet(Document).filter(id=self.data.get('document_id'),
                                             knowledge_id=self.data.get('knowledge_id')).exists():
                raise AppApiException(500, _('The document id is incorrect'))

        def save(self, instance: Dict, with_valid=True, with_embedding=True):
            """Create one paragraph (and its problems) under a document."""
            if with_valid:
                ParagraphSerializers(data=instance).is_valid(raise_exception=True)
                self.is_valid()
            knowledge_id = self.data.get("knowledge_id")
            document_id = self.data.get('document_id')
            paragraph_problem_model = self.get_paragraph_problem_model(knowledge_id, document_id, instance)
            paragraph = paragraph_problem_model.get('paragraph')
            problem_paragraph_object_list = paragraph_problem_model.get('problem_paragraph_object_list')
            problem_model_list, problem_paragraph_mapping_list = (
                ProblemParagraphManage(problem_paragraph_object_list, knowledge_id)
                .to_problem_model_list())
            # Insert the paragraph.
            paragraph_problem_model.get('paragraph').save()
            # Insert problems.
            QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
            # Insert problem<->paragraph mappings.
            QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
                problem_paragraph_mapping_list) > 0 else None
            # Keep the document's redundant char_length in sync.
            update_document_char_length(document_id)
            if with_embedding:
                model_id = get_embedding_model_id_by_knowledge_id(knowledge_id)
                embedding_by_paragraph(str(paragraph.id), model_id)
            # NOTE(review): Operate requires 'dataset_id' but receives
            # 'knowledge_id' here — confirm this validation passes.
            return ParagraphSerializers.Operate(
                data={'paragraph_id': str(paragraph.id), 'knowledge_id': knowledge_id, 'document_id': document_id}
            ).one(with_valid=True)

        @staticmethod
        def get_paragraph_problem_model(knowledge_id: str, document_id: str, instance: Dict):
            """Build an unsaved Paragraph plus its ProblemParagraphObject links."""
            paragraph = Paragraph(id=uuid.uuid7(),
                                  document_id=document_id,
                                  content=instance.get("content"),
                                  knowledge_id=knowledge_id,
                                  title=instance.get("title") if 'title' in instance else '')
            problem_paragraph_object_list = [
                ProblemParagraphObject(knowledge_id, document_id, paragraph.id, problem.get('content')) for problem in
                (instance.get('problem_list') if 'problem_list' in instance else [])]
            return {'paragraph': paragraph,
                    'problem_paragraph_object_list': problem_paragraph_object_list}

        @staticmethod
        def or_get(exists_problem_list, content, knowledge_id):
            """Reuse an existing Problem with matching content or build a new one."""
            exists = [row for row in exists_problem_list if row.content == content]
            if len(exists) > 0:
                return exists[0]
            else:
                return Problem(id=uuid.uuid7(), content=content, knowledge_id=knowledge_id)
def delete_problems_and_mappings(paragraph_ids):
    """Delete paragraph->problem mappings for the given paragraphs, then
    garbage-collect problems that no remaining mapping references.

    The original code called ``problem_paragraph_mappings.delete()`` in both
    branches of the ``if``; the delete is now unconditional and the orphan
    clean-up only runs when there were affected problems.

    @param paragraph_ids: iterable of paragraph ids whose mappings are removed
    @return: None
    """
    problem_paragraph_mappings = ProblemParagraphMapping.objects.filter(paragraph_id__in=paragraph_ids)
    # Snapshot affected problem ids before the mapping rows disappear.
    problem_ids = set(problem_paragraph_mappings.values_list('problem_id', flat=True))
    problem_paragraph_mappings.delete()
    if problem_ids:
        # Problems still referenced by other paragraphs must survive.
        remaining_problem_counts = ProblemParagraphMapping.objects.filter(problem_id__in=problem_ids).values(
            'problem_id').annotate(count=Count('problem_id'))
        remaining_problem_ids = {pc['problem_id'] for pc in remaining_problem_counts}
        problem_ids_to_delete = problem_ids - remaining_problem_ids
        Problem.objects.filter(id__in=problem_ids_to_delete).delete()

View File

@ -0,0 +1,15 @@
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from knowledge.models import Problem
class ProblemSerializer(serializers.ModelSerializer):
    """Model serializer exposing stored Problem rows (id, content, owning knowledge and timestamps)."""

    class Meta:
        model = Problem
        fields = ['id', 'content', 'knowledge_id', 'create_time', 'update_time']
class ProblemInstanceSerializer(serializers.Serializer):
    """Validates one incoming problem payload: optional id, required content (max 256 chars)."""

    id = serializers.CharField(required=False, label=_('problem id'))
    content = serializers.CharField(required=True, max_length=256, label=_('content'))

View File

@ -0,0 +1,11 @@
-- Document listing: every document row with JSON-serialised meta columns and
-- a per-document paragraph count. Filtering and ordering are injected via the
-- ${document_custom_sql} / ${order_by_query} template placeholders.
SELECT * from (
SELECT
"document".* ,
to_json("document"."meta") as meta,
to_json("document"."status_meta") as status_meta,
-- correlated subquery: number of paragraphs belonging to this document
(SELECT "count"("id") FROM "paragraph" WHERE document_id="document"."id") as "paragraph_count"
FROM
"document" "document"
${document_custom_sql}
) temp
${order_by_query}

View File

@ -0,0 +1,35 @@
-- Knowledge listing with aggregate statistics.
-- The inner UNION combines knowledge matched by ${knowledge_custom_sql} with
-- knowledge reachable through team member permissions; the outer joins add
-- per-knowledge document count / total char_length and the number of
-- applications mapped to each knowledge base.
SELECT
*,
to_json(meta) as meta
FROM
(
SELECT
"temp_knowledge".*,
"document_temp"."char_length",
CASE
WHEN
"app_knowledge_temp"."count" IS NULL THEN 0 ELSE "app_knowledge_temp"."count" END AS application_mapping_count,
"document_temp".document_count FROM (
SELECT knowledge.*
FROM
knowledge knowledge
${knowledge_custom_sql}
UNION
-- knowledge granted indirectly through team membership permissions
SELECT
*
FROM
knowledge
WHERE
knowledge."id" IN (
SELECT
team_member_permission.target
FROM
team_member team_member
LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id"
${team_member_permission_custom_sql}
)
) temp_knowledge
-- per-knowledge document count and summed character length
LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", knowledge_id FROM "document" GROUP BY knowledge_id ) "document_temp" ON temp_knowledge."id" = "document_temp".knowledge_id
-- per-knowledge count of mapped applications
LEFT JOIN (SELECT "count"("id"),knowledge_id FROM application_knowledge_mapping GROUP BY knowledge_id) app_knowledge_temp ON temp_knowledge."id" = "app_knowledge_temp".knowledge_id
) temp
${default_sql}

View File

@ -0,0 +1,20 @@
-- Applications visible to a user: applications owned directly (first %s)
-- unioned with applications granted the 'USE' operation through team
-- membership (team_id = second %s, user_id = third %s).
SELECT
*
FROM
application
WHERE
user_id = %s UNION
SELECT
*
FROM
application
WHERE
"id" IN (
SELECT
team_member_permission.target
FROM
team_member team_member
LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id"
WHERE
( "team_member_permission"."auth_target_type" = 'APPLICATION' AND "team_member_permission"."operate"::text[] @> ARRAY['USE'] AND team_member.team_id = %s AND team_member.user_id =%s )
)

View File

@ -0,0 +1,6 @@
-- Paragraph listing enriched with the owning document and knowledge names.
SELECT
(SELECT "name" FROM "document" WHERE "id"=document_id) as document_name,
(SELECT "name" FROM "knowledge" WHERE "id"=knowledge_id) as knowledge_name,
*
FROM
"paragraph"

View File

@ -0,0 +1,5 @@
-- Paragraph listing enriched with the owning document name.
SELECT
(SELECT "name" FROM "document" WHERE "id"=document_id) as document_name,
*
FROM
"paragraph"

View File

@ -0,0 +1,5 @@
-- Problem listing with the number of paragraphs each problem is mapped to.
SELECT
problem.*,
(SELECT "count"("id") FROM "problem_paragraph_mapping" WHERE problem_id="problem"."id") as "paragraph_count"
FROM
problem problem

View File

@ -0,0 +1,2 @@
-- Problem content joined to the paragraph ids it is mapped to.
SELECT "problem"."content",problem_paragraph_mapping.paragraph_id FROM problem problem
LEFT JOIN problem_paragraph_mapping problem_paragraph_mapping ON problem_paragraph_mapping.problem_id=problem."id"

View File

@ -0,0 +1,7 @@
-- Recompute a document's char_length as the summed content length of its
-- paragraphs (0 when it has none). Params: %s = paragraph document_id,
-- %s = document id to update.
UPDATE "document"
SET "char_length" = ( SELECT CASE WHEN
"sum" ( "char_length" ( "content" ) ) IS NULL THEN
0 ELSE "sum" ( "char_length" ( "content" ) )
END FROM paragraph WHERE "document_id" = %s )
WHERE
"id" = %s

View File

@ -0,0 +1,25 @@
-- Refresh document.status_meta->'aggs' with per-status paragraph counts.
-- The inner query groups paragraphs by status per document; each group row is
-- serialised to JSONB (document_id stripped) and aggregated into an array
-- that is written back into the document's status_meta.
UPDATE "document" "document"
SET status_meta = jsonb_set ( "document".status_meta, '{aggs}', tmp.status_meta )
FROM
(
SELECT COALESCE
( jsonb_agg ( jsonb_delete ( ( row_to_json ( record ) :: JSONB ), 'document_id' ) ), '[]' :: JSONB ) AS status_meta,
document_id AS document_id
FROM
(
SELECT
"paragraph".status,
"count" ( "paragraph"."id" ),
"document"."id" AS document_id
FROM
"document" "document"
LEFT JOIN "paragraph" "paragraph" ON "document"."id" = paragraph.document_id
${document_custom_sql}
GROUP BY
"paragraph".status,
"document"."id"
) record
GROUP BY
document_id
) tmp
WHERE "document".id="tmp".document_id

View File

@ -0,0 +1,13 @@
-- Templated status update: status is a fixed-width string of per-task-type
-- digits. This rewrites the digit at position ${up_index} to ${status_number}
-- and stamps the change time into status_meta->state_time->${current_index}.
-- All ${...} placeholders are substituted before execution.
UPDATE "${table_name}"
SET status = reverse (
SUBSTRING ( reverse ( LPAD( status, ${bit_number}, 'n' ) ) :: TEXT FROM 1 FOR ${up_index} ) || ${status_number} || SUBSTRING ( reverse ( LPAD( status, ${bit_number}, 'n' ) ) :: TEXT FROM ${next_index} )
),
status_meta = jsonb_set (
"${table_name}".status_meta,
'{state_time,${current_index}}',
jsonb_set (
COALESCE ( "${table_name}".status_meta #> '{state_time,${current_index}}', jsonb_build_object ( '${status_number}', '${current_time}' ) ),
'{${status_number}}',
CONCAT ( '"', '${current_time}', '"' ) :: JSONB
)
)

View File

@ -1 +1,2 @@
from .sync import *
from .embedding import *

View File

@ -0,0 +1,255 @@
# coding=utf-8
import logging
import traceback
from typing import List
from celery_once import QueueOnce
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from common.config.embedding_config import ModelManage
from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingKnowledgeIdArgs, \
UpdateEmbeddingDocumentIdArgs
from knowledge.models import Document, TaskType, State
from models_provider.tools import get_model
from models_provider.models import Model
from ops import celery_app
max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")
def get_embedding_model(model_id, exception_handler=lambda e: max_kb_error.error(
    _('Failed to obtain vector model: {error} {traceback}').format(
        error=str(e),
        traceback=traceback.format_exc()
    ))):
    """Resolve the embedding model for *model_id* via the ModelManage cache.

    @param model_id: id of the model row to load
    @param exception_handler: callback invoked with the exception before it is re-raised;
        defaults to logging the failure
    @return: the cached embedding model instance
    """
    try:
        model_record = QuerySet(Model).filter(id=model_id).first()
        return ModelManage.get_model(model_id, lambda _id: get_model(model_record))
    except Exception as e:
        exception_handler(e)
        raise e
@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id']}, name='celery:embedding_by_paragraph')
def embedding_by_paragraph(paragraph_id, model_id):
    """Celery task: (re)build the vector index for one paragraph."""
    ListenerManagement.embedding_by_paragraph(paragraph_id, get_embedding_model(model_id))
@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:embedding_by_paragraph_data_list')
def embedding_by_paragraph_data_list(data_list, paragraph_id_list, model_id):
    """Celery task: embed pre-fetched data rows belonging to the given paragraph ids."""
    ListenerManagement.embedding_by_paragraph_data_list(data_list, paragraph_id_list, get_embedding_model(model_id))
@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:embedding_by_paragraph_list')
def embedding_by_paragraph_list(paragraph_id_list, model_id):
    """Celery task: (re)build the vector index for a list of paragraphs."""
    ListenerManagement.embedding_by_paragraph_list(paragraph_id_list, get_embedding_model(model_id))
@celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:embedding_by_document')
def embedding_by_document(document_id, model_id, state_list=None):
    """Celery task: vectorise every paragraph of a document.

    @param document_id: id of the document to embed
    @param model_id: id of the embedding model to use
    @param state_list: task states eligible for (re-)embedding; defaults to every state
    :return: None
    """
    if state_list is None:
        state_list = [
            State.PENDING.value, State.STARTED.value, State.SUCCESS.value,
            State.FAILURE.value, State.REVOKE.value, State.REVOKED.value,
            State.IGNORED.value,
        ]

    def exception_handler(e):
        # Mark the document's embedding task as failed before logging the cause.
        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
                                         State.FAILURE)
        max_kb_error.error(
            _('Failed to obtain vector model: {error} {traceback}').format(
                error=str(e),
                traceback=traceback.format_exc()
            ))

    embedding_model = get_embedding_model(model_id, exception_handler)
    ListenerManagement.embedding_by_document(document_id, embedding_model, state_list)
@celery_app.task(name='celery:embedding_by_document_list')
def embedding_by_document_list(document_id_list, model_id):
    """Celery task: fan out one embedding_by_document task per document id.

    @param document_id_list: ids of the documents to embed
    @param model_id: id of the embedding model to use
    :return: None
    """
    for doc_id in document_id_list:
        embedding_by_document.delay(doc_id, model_id)
@celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']}, name='celery:embedding_by_knowledge')
def embedding_by_knowledge(knowledge_id, model_id):
    """Celery task: rebuild the vector index of an entire knowledge base.

    Drops the existing vectors for the knowledge base, then schedules one
    embedding task per contained document.

    @param knowledge_id: id of the knowledge base
    @param model_id: id of the embedding model to use
    :return: None
    """
    max_kb.info(_('Start--->Vectorized knowledge: {knowledge_id}').format(knowledge_id=knowledge_id))
    try:
        ListenerManagement.delete_embedding_by_knowledge(knowledge_id)
        document_list = QuerySet(Document).filter(knowledge_id=knowledge_id)
        max_kb.info(_('Knowledge documentation: {document_names}').format(
            document_names=", ".join([d.name for d in document_list])))
        for document in document_list:
            try:
                embedding_by_document.delay(document.id, model_id)
            except Exception:
                # Best effort: one failed dispatch must not stop the others.
                pass
    except Exception as e:
        max_kb_error.error(
            _('Vectorized knowledge: {knowledge_id} error {error} {traceback}'.format(knowledge_id=knowledge_id,
                                                                                      error=str(e),
                                                                                      traceback=traceback.format_exc())))
    finally:
        max_kb.info(_('End--->Vectorized knowledge: {knowledge_id}').format(knowledge_id=knowledge_id))
def embedding_by_problem(args, model_id):
    """Vectorise a single problem record.

    @param args: problem payload to embed
    @param model_id: id of the embedding model
    @return: None
    """
    ListenerManagement.embedding_by_problem(args, get_embedding_model(model_id))
def embedding_by_data_list(args: List, model_id):
    """Vectorise a list of pre-built embedding records.

    @param args: list of embedding payloads
    @param model_id: id of the embedding model
    @return: None
    """
    ListenerManagement.embedding_by_data_list(args, get_embedding_model(model_id))
def delete_embedding_by_document(document_id):
    """Remove all embedding rows belonging to the given document.

    @param document_id: id of the document whose vectors are deleted
    @return: None
    """
    ListenerManagement.delete_embedding_by_document(document_id)
def delete_embedding_by_document_list(document_id_list: List[str]):
    """Remove embedding rows for every document in *document_id_list*.

    @param document_id_list: ids of the documents whose vectors are deleted
    @return: None
    """
    ListenerManagement.delete_embedding_by_document_list(document_id_list)
def delete_embedding_by_knowledge(knowledge_id):
    """Remove all embedding rows belonging to the given knowledge base.

    @param knowledge_id: id of the knowledge base whose vectors are deleted
    @return: None
    """
    ListenerManagement.delete_embedding_by_knowledge(knowledge_id)
def delete_embedding_by_paragraph(paragraph_id):
    """Remove all embedding rows belonging to the given paragraph.

    @param paragraph_id: id of the paragraph whose vectors are deleted
    @return: None
    """
    ListenerManagement.delete_embedding_by_paragraph(paragraph_id)
def delete_embedding_by_source(source_id):
    """Remove all embedding rows belonging to the given source.

    @param source_id: id of the source whose vectors are deleted
    @return: None
    """
    ListenerManagement.delete_embedding_by_source(source_id)
def disable_embedding_by_paragraph(paragraph_id):
    """Disable (without deleting) the vectors of the given paragraph.

    @param paragraph_id: id of the paragraph to disable
    @return: None
    """
    ListenerManagement.disable_embedding_by_paragraph(paragraph_id)
def enable_embedding_by_paragraph(paragraph_id):
    """Re-enable the vectors of the given paragraph.

    @param paragraph_id: id of the paragraph to enable
    @return: None
    """
    ListenerManagement.enable_embedding_by_paragraph(paragraph_id)
def delete_embedding_by_source_ids(source_ids: List[str]):
    """Remove embedding rows for every source id in *source_ids*.

    @param source_ids: ids of the sources whose vectors are deleted
    @return: None
    """
    ListenerManagement.delete_embedding_by_source_ids(source_ids)
def update_problem_embedding(problem_id: str, problem_content: str, model_id):
    """Re-embed a problem after its content changed.

    @param problem_id: id of the problem
    @param problem_content: new problem text
    @param model_id: id of the embedding model
    @return: None
    """
    ListenerManagement.update_problem(
        UpdateProblemArgs(problem_id, problem_content, get_embedding_model(model_id)))
def update_embedding_knowledge_id(paragraph_id_list, target_knowledge_id):
    """Move the vectors of the given paragraphs to another knowledge base.

    @param paragraph_id_list: paragraph ids whose vectors are moved
    @param target_knowledge_id: destination knowledge base id
    @return: None
    """
    args = UpdateEmbeddingKnowledgeIdArgs(paragraph_id_list, target_knowledge_id)
    ListenerManagement.update_embedding_knowledge_id(args)
def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]):
    """Remove embedding rows for every paragraph in *paragraph_ids*.

    @param paragraph_ids: ids of the paragraphs whose vectors are deleted
    @return: None
    """
    ListenerManagement.delete_embedding_by_paragraph_ids(paragraph_ids)
def update_embedding_document_id(paragraph_id_list, target_document_id, target_knowledge_id,
                                 target_embedding_model_id=None):
    """Move paragraph vectors to another document, optionally re-embedding with a new model.

    @param paragraph_id_list: paragraph ids whose vectors are moved
    @param target_document_id: destination document id
    @param target_knowledge_id: destination knowledge base id
    @param target_embedding_model_id: optional model id used to re-embed; None keeps vectors as-is
    @return: None
    """
    if target_embedding_model_id is None:
        target_embedding_model = None
    else:
        target_embedding_model = get_embedding_model(target_embedding_model_id)
    ListenerManagement.update_embedding_document_id(
        UpdateEmbeddingDocumentIdArgs(paragraph_id_list, target_document_id, target_knowledge_id,
                                      target_embedding_model))
def delete_embedding_by_knowledge_id_list(knowledge_id_list):
    """Remove embedding rows for every knowledge base in *knowledge_id_list*."""
    ListenerManagement.delete_embedding_by_knowledge_id_list(knowledge_id_list)

View File

@ -1,29 +1,23 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file tools.py
@date2024/8/20 21:48
@desc:
"""
import logging
import re
import traceback
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from common.utils.fork import ChildLink, Fork
from common.utils.split_model import get_split_model
from knowledge.models.knowledge import KnowledgeType, Document, DataSet, Status
from django.utils.translation import gettext_lazy as _
from knowledge.models.knowledge import KnowledgeType, Document, Knowledge, Status
max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")
def get_save_handler(dataset_id, selector):
from knowledge.serializers.document_serializers import DocumentSerializers
def get_save_handler(knowledge_id, selector):
from knowledge.serializers import DocumentSerializers
def handler(child_link: ChildLink, response: Fork.Response):
if response.status == 200:
@ -31,7 +25,7 @@ def get_save_handler(dataset_id, selector):
document_name = child_link.tag.text if child_link.tag is not None and len(
child_link.tag.text.strip()) > 0 else child_link.url
paragraphs = get_split_model('web.md').parse(response.content)
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
DocumentSerializers.Create(data={'knowledge_id': knowledge_id}).save(
{'name': document_name, 'paragraphs': paragraphs,
'meta': {'source_url': child_link.url, 'selector': selector},
'type': KnowledgeType.WEB}, with_valid=True)
@ -41,9 +35,9 @@ def get_save_handler(dataset_id, selector):
return handler
def get_sync_handler(dataset_id):
from knowledge.serializers.document_serializers import DocumentSerializers
dataset = QuerySet(DataSet).filter(id=dataset_id).first()
def get_sync_handler(knowledge_id):
from knowledge.serializers import DocumentSerializers
knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first()
def handler(child_link: ChildLink, response: Fork.Response):
if response.status == 200:
@ -52,32 +46,31 @@ def get_sync_handler(dataset_id):
document_name = child_link.tag.text if child_link.tag is not None and len(
child_link.tag.text.strip()) > 0 else child_link.url
paragraphs = get_split_model('web.md').parse(response.content)
first = QuerySet(Document).filter(meta__source_url=child_link.url.strip(),
dataset=dataset).first()
first = QuerySet(Document).filter(meta__source_url=child_link.url.strip(), knowledge=knowledge).first()
if first is not None:
# 如果存在,使用文档同步
DocumentSerializers.Sync(data={'document_id': first.id}).sync()
else:
# 插入
DocumentSerializers.Create(data={'dataset_id': dataset.id}).save(
DocumentSerializers.Create(data={'knowledge_id': knowledge.id}).save(
{'name': document_name, 'paragraphs': paragraphs,
'meta': {'source_url': child_link.url.strip(), 'selector': dataset.meta.get('selector')},
'type': Type.web}, with_valid=True)
'meta': {'source_url': child_link.url.strip(), 'selector': knowledge.meta.get('selector')},
'type': KnowledgeType.WEB}, with_valid=True)
except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
return handler
def get_sync_web_document_handler(dataset_id):
from knowledge.serializers.document_serializers import DocumentSerializers
def get_sync_web_document_handler(knowledge_id):
from knowledge.serializers import DocumentSerializers
def handler(source_url: str, selector, response: Fork.Response):
if response.status == 200:
try:
paragraphs = get_split_model('web.md').parse(response.content)
# 插入
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
DocumentSerializers.Create(data={'knowledge_id': knowledge_id}).save(
{'name': source_url[0:128], 'paragraphs': paragraphs,
'meta': {'source_url': source_url, 'selector': selector},
'type': KnowledgeType.WEB}, with_valid=True)
@ -85,7 +78,7 @@ def get_sync_web_document_handler(dataset_id):
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
else:
Document(name=source_url[0:128],
dataset_id=dataset_id,
knowledge_id=knowledge_id,
meta={'source_url': source_url, 'selector': selector},
type=KnowledgeType.WEB,
char_length=0,
@ -94,9 +87,9 @@ def get_sync_web_document_handler(dataset_id):
return handler
def save_problem(dataset_id, document_id, paragraph_id, problem):
from knowledge.serializers.paragraph_serializers import ParagraphSerializers
# print(f"dataset_id: {dataset_id}")
def save_problem(knowledge_id, document_id, paragraph_id, problem):
from knowledge.serializers import ParagraphSerializers
# print(f"knowledge_id: {knowledge_id}")
# print(f"document_id: {document_id}")
# print(f"paragraph_id: {paragraph_id}")
# print(f"problem: {problem}")
@ -108,7 +101,7 @@ def save_problem(dataset_id, document_id, paragraph_id, problem):
return
try:
ParagraphSerializers.Problem(
data={"dataset_id": dataset_id, 'document_id': document_id,
data={"knowledge_id": knowledge_id, 'document_id': document_id,
'paragraph_id': paragraph_id}).save(instance={"content": problem}, with_valid=True)
except Exception as e:
max_kb_error.error(_('Association problem failed {error}').format(error=str(e)))

View File

@ -12,12 +12,11 @@ import traceback
from typing import List
from celery_once import QueueOnce
from django.utils.translation import gettext_lazy as _
from common.utils.fork import ForkManage, Fork
from .tools import get_save_handler, get_sync_web_document_handler, get_sync_handler
from ops import celery_app
from django.utils.translation import gettext_lazy as _
from .handler import get_save_handler, get_sync_web_document_handler, get_sync_handler
max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")

View File

View File

@ -0,0 +1,187 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_vector.py
@date2023/10/18 19:16
@desc:
"""
import threading
from abc import ABC, abstractmethod
from functools import reduce
from typing import List, Dict
from langchain_core.embeddings import Embeddings
from common.chunk import text_to_chunk
from common.utils.common import sub_array
from knowledge.models import SourceType, SearchMode
lock = threading.Lock()
def chunk_data(data: Dict):
    """Split a paragraph-sourced record into one record per text chunk.

    Records whose source type is not PARAGRAPH are passed through unchanged
    (wrapped in a one-element list).
    """
    if str(data.get('source_type')) != SourceType.PARAGRAPH.value:
        return [data]
    return [{**data, 'text': chunk} for chunk in text_to_chunk(data.get('text'))]
def chunk_data_list(data_list: List[Dict]):
    """Chunk every record in *data_list* and return the flattened result.

    Replaces the previous reduce-based concatenation — which rebuilt the
    accumulator list on every step (quadratic) — with a single-pass flatten.

    @param data_list: records to chunk via chunk_data
    @return: flat list of chunked records
    """
    return [chunk for data in data_list for chunk in chunk_data(data)]
class BaseVectorStore(ABC):
    """Abstract base class for vector stores: save/search/update/delete of embedding rows."""

    # Process-wide flag: set once the backing vector store is known to exist.
    vector_exists = False

    @abstractmethod
    def vector_is_create(self) -> bool:
        """
        Return whether the vector store has been created.
        :return: True if the vector store exists
        """
        pass

    @abstractmethod
    def vector_create(self):
        """
        Create the vector store.
        :return:
        """
        pass

    def save_pre_handler(self):
        """
        Pre-insert hook: lazily create the vector store on first use.
        :return: True
        """
        if not BaseVectorStore.vector_exists:
            if not self.vector_is_create():
                self.vector_create()
            BaseVectorStore.vector_exists = True
        return True

    def save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str, source_id: str,
             is_active: bool,
             embedding: Embeddings):
        """
        Insert vector data for a single source text (chunked, then batch-saved).
        :param source_id: source id
        :param knowledge_id: knowledge base id
        :param text: text to embed
        :param source_type: source type
        :param document_id: document id
        :param is_active: whether the vectors are enabled
        :param embedding: embedding handler
        :param paragraph_id: paragraph id
        :return: bool
        """
        self.save_pre_handler()
        data = {'document_id': document_id, 'paragraph_id': paragraph_id, 'knowledge_id': knowledge_id,
                'is_active': is_active, 'source_id': source_id, 'source_type': source_type, 'text': text}
        chunk_list = chunk_data(data)
        result = sub_array(chunk_list)
        for child_array in result:
            self._batch_save(child_array, embedding, lambda: False)

    def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
        """
        Batch insert vector data.
        @param data_list: list of records to embed
        @param embedding: embedding handler
        @param is_the_task_interrupted: callable; True when the task should stop
        :return: bool
        """
        self.save_pre_handler()
        chunk_list = chunk_data_list(data_list)
        result = sub_array(chunk_list)
        for child_array in result:
            if not is_the_task_interrupted():
                self._batch_save(child_array, embedding, is_the_task_interrupted)
            else:
                break

    @abstractmethod
    def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str, source_id: str,
              is_active: bool,
              embedding: Embeddings):
        """Persist one embedded text row; implemented by the backend."""
        pass

    @abstractmethod
    def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
        """Persist a batch of embedded rows; implemented by the backend."""
        pass

    def search(self, query_text, knowledge_id_list: list[str], exclude_document_id_list: list[str],
               exclude_paragraph_list: list[str],
               is_active: bool,
               embedding: Embeddings):
        """Embed *query_text* and return the first query hit ([] for an empty knowledge list).

        NOTE(review): the self.query(...) call below passes 8 arguments while the
        abstract query() declares 9 (query_text, query_embedding, ..., search_mode);
        as written this raises TypeError in subclasses keeping the abstract
        signature — confirm the intended argument list.
        """
        if knowledge_id_list is None or len(knowledge_id_list) == 0:
            return []
        embedding_query = embedding.embed_query(query_text)
        result = self.query(embedding_query, knowledge_id_list, exclude_document_id_list, exclude_paragraph_list,
                            is_active, 1, 3, 0.65)
        return result[0]

    @abstractmethod
    def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str],
              exclude_document_id_list: list[str],
              exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
              search_mode: SearchMode):
        """Run a similarity query; implemented by the backend."""
        pass

    @abstractmethod
    def hit_test(self, query_text, knowledge_id: list[str], exclude_document_id_list: list[str], top_number: int,
                 similarity: float,
                 search_mode: SearchMode,
                 embedding: Embeddings):
        """Run a hit test against the given knowledge bases; implemented by the backend."""
        pass

    @abstractmethod
    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        """Update embedding rows belonging to one paragraph."""
        pass

    @abstractmethod
    def update_by_paragraph_ids(self, paragraph_ids: str, instance: Dict):
        """Update embedding rows for several paragraphs.

        NOTE(review): annotated ``str`` but presumably receives a list of ids — confirm.
        """
        pass

    @abstractmethod
    def update_by_source_id(self, source_id: str, instance: Dict):
        """Update embedding rows belonging to one source."""
        pass

    @abstractmethod
    def update_by_source_ids(self, source_ids: List[str], instance: Dict):
        """Update embedding rows for several sources."""
        pass

    @abstractmethod
    def delete_by_knowledge_id(self, knowledge_id: str):
        """Delete all embedding rows of one knowledge base."""
        pass

    @abstractmethod
    def delete_by_document_id(self, document_id: str):
        """Delete all embedding rows of one document."""
        pass

    @abstractmethod
    def delete_by_document_id_list(self, document_id_list: List[str]):
        """Delete embedding rows for several documents."""
        pass

    @abstractmethod
    def delete_by_knowledge_id_list(self, knowledge_id_list: List[str]):
        """Delete embedding rows for several knowledge bases."""
        pass

    @abstractmethod
    def delete_by_source_id(self, source_id: str, source_type: str):
        """Delete embedding rows of one source of the given type."""
        pass

    @abstractmethod
    def delete_by_source_ids(self, source_ids: List[str], source_type: str):
        """Delete embedding rows for several sources of the given type."""
        pass

    @abstractmethod
    def delete_by_paragraph_id(self, paragraph_id: str):
        """Delete embedding rows of one paragraph."""
        pass

    @abstractmethod
    def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
        """Delete embedding rows for several paragraphs."""
        pass

View File

@ -0,0 +1,222 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file pg_vector.py
@date2023/10/19 15:28
@desc:
"""
import json
import os
from abc import ABC, abstractmethod
from typing import Dict, List
import uuid_utils.compat as uuid
from common.utils.ts_vecto_util import to_ts_vector, to_query
from django.contrib.postgres.search import SearchVector
from django.db.models import QuerySet, Value
from langchain_core.embeddings import Embeddings
from common.db.search import generate_sql_by_query_dict
from common.db.sql_execute import select_list
from common.utils.common import get_file_content
from knowledge.models import Embedding, SearchMode, SourceType
from knowledge.vector.base_vector import BaseVectorStore
from maxkb.conf import PROJECT_DIR
class PGVector(BaseVectorStore):
    """PostgreSQL (pgvector) backed vector store persisting rows via the Embedding model."""

    def delete_by_source_ids(self, source_ids: List[str], source_type: str):
        """Delete embeddings for the given source ids of one type (no-op on empty input)."""
        if len(source_ids) == 0:
            return
        QuerySet(Embedding).filter(source_id__in=source_ids, source_type=source_type).delete()

    def update_by_source_ids(self, source_ids: List[str], instance: Dict):
        """Bulk-update embedding rows matched by source id with the fields in *instance*."""
        QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)

    def vector_is_create(self) -> bool:
        # Created at project startup by default; nothing to create here.
        return True

    def vector_create(self):
        return True

    def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str, source_id: str,
              is_active: bool,
              embedding: Embeddings):
        """Embed *text* and persist a single Embedding row; returns True."""
        text_embedding = embedding.embed_query(text)
        # NOTE: the rebinding below shadows the Embeddings handler parameter.
        embedding = Embedding(id=uuid.uuid7(),
                              knowledge_id=knowledge_id,
                              document_id=document_id,
                              is_active=is_active,
                              paragraph_id=paragraph_id,
                              source_id=source_id,
                              embedding=text_embedding,
                              source_type=source_type,
                              search_vector=to_ts_vector(text))
        embedding.save()
        return True

    def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
        """Embed all rows of *text_list* in one call and bulk-insert them.

        The interruption flag is re-checked after the (potentially slow)
        embed_documents call so an interrupted task does not write results.
        """
        texts = [row.get('text') for row in text_list]
        embeddings = embedding.embed_documents(texts)
        embedding_list = [Embedding(id=uuid.uuid7(),
                                    document_id=text_list[index].get('document_id'),
                                    paragraph_id=text_list[index].get('paragraph_id'),
                                    knowledge_id=text_list[index].get('knowledge_id'),
                                    is_active=text_list[index].get('is_active', True),
                                    source_id=text_list[index].get('source_id'),
                                    source_type=text_list[index].get('source_type'),
                                    embedding=embeddings[index],
                                    search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))) for
                          index in
                          range(0, len(texts))]
        if not is_the_task_interrupted():
            QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
        return True

    def hit_test(self, query_text, knowledge_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
                 similarity: float,
                 search_mode: SearchMode,
                 embedding: Embeddings):
        """Embed *query_text* and dispatch to the matching search strategy.

        Returns [] for an empty knowledge list.
        NOTE(review): returns None when no handler supports *search_mode* — confirm callers handle that.
        """
        if knowledge_id_list is None or len(knowledge_id_list) == 0:
            return []
        exclude_dict = {}
        embedding_query = embedding.embed_query(query_text)
        query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=True)
        if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
            exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
        query_set = query_set.exclude(**exclude_dict)
        # search_handle_list is defined at the bottom of this module; resolved at call time.
        for search_handle in search_handle_list:
            if search_handle.support(search_mode):
                return search_handle.handle(query_set, query_text, embedding_query, top_number, similarity, search_mode)

    def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str],
              exclude_document_id_list: list[str],
              exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
              search_mode: SearchMode):
        """Run a similarity query over the given knowledge bases.

        Returns [] for an empty knowledge list.
        NOTE(review): returns None when no handler supports *search_mode* — confirm callers handle that.
        """
        exclude_dict = {}
        if knowledge_id_list is None or len(knowledge_id_list) == 0:
            return []
        query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=is_active)
        if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
            query_set = query_set.exclude(document_id__in=exclude_document_id_list)
        if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
            query_set = query_set.exclude(paragraph_id__in=exclude_paragraph_list)
        query_set = query_set.exclude(**exclude_dict)
        for search_handle in search_handle_list:
            if search_handle.support(search_mode):
                return search_handle.handle(query_set, query_text, query_embedding, top_n, similarity, search_mode)

    def update_by_source_id(self, source_id: str, instance: Dict):
        """Update embedding rows matched by one source id."""
        QuerySet(Embedding).filter(source_id=source_id).update(**instance)

    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        """Update embedding rows matched by one paragraph id."""
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance)

    def update_by_paragraph_ids(self, paragraph_id: str, instance: Dict):
        # NOTE(review): despite the singular name/annotation, the __in lookup
        # shows this parameter is a list of paragraph ids.
        QuerySet(Embedding).filter(paragraph_id__in=paragraph_id).update(**instance)

    def delete_by_knowledge_id(self, knowledge_id: str):
        """Delete all embedding rows of one knowledge base."""
        QuerySet(Embedding).filter(knowledge_id=knowledge_id).delete()

    def delete_by_knowledge_id_list(self, knowledge_id_list: List[str]):
        """Delete embedding rows for several knowledge bases."""
        QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list).delete()

    def delete_by_document_id(self, document_id: str):
        """Delete all embedding rows of one document; returns True."""
        QuerySet(Embedding).filter(document_id=document_id).delete()
        return True

    def delete_by_document_id_list(self, document_id_list: List[str]):
        """Delete embedding rows for several documents (no-op True on empty input)."""
        if len(document_id_list) == 0:
            return True
        return QuerySet(Embedding).filter(document_id__in=document_id_list).delete()

    def delete_by_source_id(self, source_id: str, source_type: str):
        """Delete embedding rows of one source of the given type; returns True."""
        QuerySet(Embedding).filter(source_id=source_id, source_type=source_type).delete()
        return True

    def delete_by_paragraph_id(self, paragraph_id: str):
        """Delete embedding rows of one paragraph."""
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()

    def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
        """Delete embedding rows for several paragraphs."""
        QuerySet(Embedding).filter(paragraph_id__in=paragraph_ids).delete()
class ISearch(ABC):
    """Strategy interface: one implementation per SearchMode."""

    @abstractmethod
    def support(self, search_mode: SearchMode):
        """Return whether this strategy handles *search_mode*."""
        pass

    @abstractmethod
    def handle(self, query_set, query_text, query_embedding, top_number: int,
               similarity: float, search_mode: SearchMode):
        """Execute the search against *query_set* and return the result rows."""
        pass
class EmbeddingSearch(ISearch):
    """Pure vector-similarity search strategy."""

    def support(self, search_mode: SearchMode):
        return search_mode.value == SearchMode.embedding.value

    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        # Load the SQL template and weave the ORM filter into it.
        sql_template = get_file_content(
            os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'embedding_search.sql'))
        exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
                                                           select_string=sql_template,
                                                           with_table_name=True)
        return select_list(exec_sql, [json.dumps(query_embedding), *exec_params, similarity, top_number])
class KeywordsSearch(ISearch):
    """Full-text keyword search strategy."""

    def support(self, search_mode: SearchMode):
        return search_mode.value == SearchMode.keywords.value

    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        # Load the SQL template and weave the ORM filter into it.
        sql_template = get_file_content(
            os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'keywords_search.sql'))
        exec_sql, exec_params = generate_sql_by_query_dict({'keywords_query': query_set},
                                                           select_string=sql_template,
                                                           with_table_name=True)
        return select_list(exec_sql, [to_query(query_text), *exec_params, similarity, top_number])
class BlendSearch(ISearch):
    """Hybrid strategy combining vector similarity and keyword matching."""

    def support(self, search_mode: SearchMode):
        return search_mode.value == SearchMode.blend.value

    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        # Load the SQL template and weave the ORM filter into it.
        sql_template = get_file_content(
            os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'blend_search.sql'))
        exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
                                                           select_string=sql_template,
                                                           with_table_name=True)
        return select_list(exec_sql,
                           [json.dumps(query_embedding), to_query(query_text), *exec_params, similarity, top_number])
# Strategy registry consulted in order by PGVector.hit_test / PGVector.query.
search_handle_list = [EmbeddingSearch(), KeywordsSearch(), BlendSearch()]

View File

@ -36,7 +36,7 @@ class Migration(migrations.Migration):
('tree_id', models.PositiveIntegerField(db_index=True, editable=False)),
('level', models.PositiveIntegerField(editable=False)),
('parent',
mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE,
mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING,
related_name='children', to='tools.toolfolder')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user',
verbose_name='用户id')),
@ -73,7 +73,7 @@ class Migration(migrations.Migration):
('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user',
verbose_name='用户id')),
('folder',
models.ForeignKey(default='root', on_delete=django.db.models.deletion.CASCADE, to='tools.toolfolder',
models.ForeignKey(default='root', on_delete=django.db.models.deletion.DO_NOTHING, to='tools.toolfolder',
verbose_name='文件夹id')),
],
options={

View File

@ -12,7 +12,7 @@ class ToolFolder(MPTTModel, AppModelMixin):
name = models.CharField(max_length=64, verbose_name="文件夹名称")
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="用户id")
workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True)
parent = TreeForeignKey('self', on_delete=models.CASCADE, null=True, blank=True, related_name='children')
parent = TreeForeignKey('self', on_delete=models.DO_NOTHING, null=True, blank=True, related_name='children')
class Meta:
db_table = "tool_folder"
@ -46,7 +46,7 @@ class Tool(AppModelMixin):
tool_type = models.CharField(max_length=20, verbose_name='工具类型', choices=ToolType.choices,
default=ToolType.CUSTOM, db_index=True)
template_id = models.UUIDField(max_length=128, verbose_name="模版id", null=True, default=None)
folder = models.ForeignKey(ToolFolder, on_delete=models.CASCADE, verbose_name="文件夹id", default='root')
folder = models.ForeignKey(ToolFolder, on_delete=models.DO_NOTHING, verbose_name="文件夹id", default='root')
workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True)
init_params = models.CharField(max_length=102400, verbose_name="初始化参数", null=True)