feat: 数据集,文档,段落,问题,向量化接口

This commit is contained in:
shaohuzhang1 2023-10-24 20:24:32 +08:00
parent 64e93679f8
commit a2de9691fb
36 changed files with 1843 additions and 166 deletions

1
.gitignore vendored
View File

@ -162,6 +162,7 @@ cython_debug/
ui/node_modules ui/node_modules
ui/dist ui/dist
apps/static apps/static
models/
data data
.idea .idea
.dev .dev

View File

@ -1,4 +1,4 @@
# Generated by Django 4.1.10 on 2023-10-09 06:33 # Generated by Django 4.1.10 on 2023-10-24 12:13
import django.contrib.postgres.fields import django.contrib.postgres.fields
from django.db import migrations, models from django.db import migrations, models

View File

@ -0,0 +1,51 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file embedding_config.py
@date2023/10/23 16:03
@desc:
"""
import types
from smartdoc.const import CONFIG
from langchain.embeddings import HuggingFaceEmbeddings
class EmbeddingModel:
    """Process-wide lazy singleton holding the HuggingFace embedding model."""

    # Cached model instance; populated on first call to get_embedding_model().
    instance = None

    @staticmethod
    def get_embedding_model():
        """
        Return the shared embedding model, creating it on first use.

        Model name, cache folder and device are read from the application
        CONFIG (EMBEDDING_MODEL_NAME / EMBEDDING_MODEL_PATH / EMBEDDING_DEVICE).
        :return: the shared HuggingFaceEmbeddings instance
        """
        if EmbeddingModel.instance is None:
            # NOTE(review): not thread-safe — concurrent first calls may build
            # the model twice; acceptable if only ever primed via the
            # init_embedding_model signal. Confirm with callers.
            EmbeddingModel.instance = HuggingFaceEmbeddings(
                model_name=CONFIG.get('EMBEDDING_MODEL_NAME'),
                cache_folder=CONFIG.get('EMBEDDING_MODEL_PATH'),
                model_kwargs={'device': CONFIG.get('EMBEDDING_DEVICE')})
        return EmbeddingModel.instance
class VectorStore:
    """Lazy singleton factory for the configured vector-store backend."""

    from embedding.vector.pg_vector import PGVector
    from embedding.vector.base_vector import BaseVectorStore

    # Registry of supported backends, keyed by the VECTOR_STORE_NAME config value.
    instance_map = {
        'pg_vector': PGVector,
    }
    # Cached backend instance; populated on first call to get_embedding_vector().
    instance = None

    @staticmethod
    def get_embedding_vector() -> BaseVectorStore:
        """Return the shared vector store, building it from config on first use."""
        from embedding.vector.pg_vector import PGVector
        if VectorStore.instance is None:
            from smartdoc.const import CONFIG
            # Unknown / missing VECTOR_STORE_NAME falls back to PGVector.
            store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"), PGVector)
            VectorStore.instance = store_class()
        return VectorStore.instance

View File

@ -1,4 +1,3 @@
# coding=utf-8
""" """
@project: qabot @project: qabot
@Author @Author

View File

@ -0,0 +1,162 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file listener_manage.py
@date2023/10/20 14:01
@desc:
"""
import os
from concurrent.futures import ThreadPoolExecutor
import django.db.models
from blinker import signal
from django.db.models import QuerySet
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.db.search import native_search, get_dynamics_model
from common.util.file_util import get_file_content
from dataset.models import Paragraph, Status, Document
from embedding.models import SourceType
from smartdoc.conf import PROJECT_DIR
def poxy(poxy_function):
    """
    Decorator: run the wrapped listener asynchronously on the shared
    ListenerManagement.work_thread_pool (fire-and-forget).

    :param poxy_function: single-argument listener function
    :return: wrapper that submits the call to the thread pool and returns None
    """
    from functools import wraps

    # wraps() preserves __name__/__doc__ of the listener, which the original
    # decorator lost (every decorated handler reported as 'inner').
    @wraps(poxy_function)
    def inner(args):
        ListenerManagement.work_thread_pool.submit(poxy_function, args)

    return inner
class ListenerManagement:
    """
    Registry of blinker signals for maintaining the embedding store, plus
    their handler implementations.

    Handlers decorated with @poxy run asynchronously on work_thread_pool;
    the others run synchronously in the sender's thread.  Call run() once at
    startup to connect every signal to its handler.
    """

    # Shared pool used by @poxy-decorated handlers (5 workers).
    work_thread_pool = ThreadPoolExecutor(5)

    # --- vectorisation signals ---
    embedding_by_problem_signal = signal("embedding_by_problem")
    embedding_by_paragraph_signal = signal("embedding_by_paragraph")
    embedding_by_dataset_signal = signal("embedding_by_dataset")
    embedding_by_document_signal = signal("embedding_by_document")
    # --- deletion signals ---
    delete_embedding_by_document_signal = signal("delete_embedding_by_document")
    delete_embedding_by_dataset_signal = signal("delete_embedding_by_dataset")
    delete_embedding_by_paragraph_signal = signal("delete_embedding_by_paragraph")
    delete_embedding_by_source_signal = signal("delete_embedding_by_source")
    # --- enable/disable and model-warmup signals ---
    enable_embedding_by_paragraph_signal = signal('enable_embedding_by_paragraph')
    disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph')
    init_embedding_model_signal = signal('init_embedding_model')

    @staticmethod
    def embedding_by_problem(args):
        """
        Vectorize a single problem (synchronous).

        :param args: keyword arguments forwarded verbatim to the vector
            store's save() method
        """
        VectorStore.get_embedding_vector().save(**args)

    @staticmethod
    @poxy
    def embedding_by_paragraph(paragraph_id):
        """
        Vectorize one paragraph and its problems by paragraph id (async).

        :param paragraph_id: paragraph id
        :return: None
        """
        # Gather problem + paragraph rows for this paragraph via the shared
        # list_embedding_text.sql template.
        data_list = native_search(
            {'problem': QuerySet(get_dynamics_model({'problem.paragraph_id': django.db.models.CharField()})).filter(
                **{'problem.paragraph_id': paragraph_id}),
                'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)},
            select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
        # Batch-vectorize all collected rows.
        VectorStore.get_embedding_vector().batch_save(data_list)
        # Mark the paragraph as successfully imported.
        QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': Status.success.value})

    @staticmethod
    @poxy
    def embedding_by_document(document_id):
        """
        Vectorize a whole document by document id (async).

        :param document_id: document id
        :return: None
        """
        data_list = native_search(
            {'problem': QuerySet(get_dynamics_model({'problem.document_id': django.db.models.CharField()})).filter(
                **{'problem.document_id': document_id}),
                'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
            select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
        # Batch-vectorize all collected rows.
        VectorStore.get_embedding_vector().batch_save(data_list)
        # Mark the document and all its paragraphs as successfully imported.
        QuerySet(Document).filter(id=document_id).update(**{'status': Status.success.value})
        QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.success.value})

    @staticmethod
    @poxy
    def embedding_by_dataset(dataset_id):
        """
        Vectorize a whole dataset by dataset id (async).

        :param dataset_id: dataset id
        :return: None
        """
        data_list = native_search(
            {'problem': QuerySet(get_dynamics_model({'problem.dataset_id': django.db.models.CharField()})).filter(
                **{'problem.dataset_id': dataset_id}),
                'paragraph': QuerySet(Paragraph).filter(dataset_id=dataset_id)},
            select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
        # Batch-vectorize all collected rows.
        VectorStore.get_embedding_vector().batch_save(data_list)
        # Mark every document and paragraph of the dataset as imported.
        QuerySet(Document).filter(dataset_id=dataset_id).update(**{'status': Status.success.value})
        QuerySet(Paragraph).filter(dataset_id=dataset_id).update(**{'status': Status.success.value})

    @staticmethod
    def delete_embedding_by_document(document_id):
        """Delete all vectors belonging to a document."""
        VectorStore.get_embedding_vector().delete_by_document_id(document_id)

    @staticmethod
    def delete_embedding_by_dataset(dataset_id):
        """Delete all vectors belonging to a dataset."""
        VectorStore.get_embedding_vector().delete_by_dataset_id(dataset_id)

    @staticmethod
    def delete_embedding_by_paragraph(paragraph_id):
        """Delete all vectors belonging to a paragraph."""
        VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)

    @staticmethod
    def delete_embedding_by_source(source_id):
        """Delete vectors by source id (PROBLEM sources only)."""
        VectorStore.get_embedding_vector().delete_by_source_id(source_id, SourceType.PROBLEM)

    @staticmethod
    def disable_embedding_by_paragraph(paragraph_id):
        """Mark a paragraph's vectors inactive (excluded from search)."""
        VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': False})

    @staticmethod
    def enable_embedding_by_paragraph(paragraph_id):
        """Mark a paragraph's vectors active again."""
        VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True})

    @staticmethod
    @poxy
    def init_embedding_model(ags):
        # Warm up the embedding model on a worker thread; the signal payload
        # ('ags') is ignored.
        EmbeddingModel.get_embedding_model()

    def run(self):
        """Connect every signal to its handler. Call once at startup."""
        # add vectors by problem id
        ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)
        # add vectors by paragraph id
        ListenerManagement.embedding_by_paragraph_signal.connect(self.embedding_by_paragraph)
        # add vectors by dataset id
        ListenerManagement.embedding_by_dataset_signal.connect(
            self.embedding_by_dataset)
        # add vectors by document id
        ListenerManagement.embedding_by_document_signal.connect(
            self.embedding_by_document)
        # delete vectors by document id
        ListenerManagement.delete_embedding_by_document_signal.connect(self.delete_embedding_by_document)
        # delete vectors by dataset id
        ListenerManagement.delete_embedding_by_dataset_signal.connect(self.delete_embedding_by_dataset)
        # delete vectors by paragraph id
        ListenerManagement.delete_embedding_by_paragraph_signal.connect(
            self.delete_embedding_by_paragraph)
        # delete vectors by source id
        ListenerManagement.delete_embedding_by_source_signal.connect(self.delete_embedding_by_source)
        # disable a paragraph's vectors
        ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
        # enable a paragraph's vectors
        ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
        # warm up the embedding model
        ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)

View File

@ -14,7 +14,7 @@ from rest_framework.views import exception_handler
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
from common.response import result from common.response import result
import traceback
def to_result(key, args, parent_key=None): def to_result(key, args, parent_key=None):
""" """
将校验异常 args转换为统一数据 将校验异常 args转换为统一数据
@ -59,6 +59,7 @@ def handle_exception(exc, context):
exception_class = exc.__class__ exception_class = exc.__class__
# 先调用REST framework默认的异常处理方法获得标准错误响应对象 # 先调用REST framework默认的异常处理方法获得标准错误响应对象
response = exception_handler(exc, context) response = exception_handler(exc, context)
traceback.print_exc()
# 在此处补充自定义的异常处理 # 在此处补充自定义的异常处理
if issubclass(exception_class, ValidationError): if issubclass(exception_class, ValidationError):
return validation_error_to_result(exc) return validation_error_to_result(exc)

View File

@ -70,7 +70,9 @@ def get_page_api_response(response_data_schema: openapi.Schema):
title="总条数", title="总条数",
default=1, default=1,
description="数据总条数"), description="数据总条数"),
"records": response_data_schema, "records": openapi.Schema(
type=openapi.TYPE_ARRAY,
items=response_data_schema),
"current": openapi.Schema( "current": openapi.Schema(
type=openapi.TYPE_INTEGER, type=openapi.TYPE_INTEGER,
title="当前页", title="当前页",
@ -115,6 +117,36 @@ def get_api_response(response_data_schema: openapi.Schema):
)}) )})
def get_default_response():
return get_api_response(openapi.Schema(type=openapi.TYPE_BOOLEAN))
def get_api_array_response(response_data_schema: openapi.Schema):
"""
获取统一返回 响应Api
"""
return openapi.Responses(responses={200: openapi.Response(description="响应参数",
schema=openapi.Schema(
type=openapi.TYPE_OBJECT,
properties={
'code': openapi.Schema(
type=openapi.TYPE_INTEGER,
title="响应码",
default=200,
description="成功:200 失败:其他"),
"message": openapi.Schema(
type=openapi.TYPE_STRING,
title="提示",
default='成功',
description="错误提示"),
"data": openapi.Schema(type=openapi.TYPE_ARRAY,
items=response_data_schema)
}
),
)})
def success(data): def success(data):
""" """
获取一个成功的响应对象 获取一个成功的响应对象

View File

@ -0,0 +1,26 @@
-- Collect every row that needs a vector embedding, as a union of
-- problems (source_type = 0) and paragraphs (source_type = 1).
-- ${problem} and ${paragraph} are placeholders replaced at load time with
-- per-query WHERE clauses (see common.db.search.native_search callers).
-- NOTE(review): source_type values presumably mirror embedding.models.SourceType
-- (0 = PROBLEM, 1 = PARAGRAPH) — confirm against that enum.
SELECT
    problem."id" AS "source_id",
    problem.document_id AS document_id,
    problem.paragraph_id AS paragraph_id,
    problem.dataset_id AS dataset_id,
    0 AS source_type,
    problem."content" AS "text",
    -- a problem inherits the active flag of its paragraph
    paragraph.is_active AS is_active
FROM
    problem problem
    LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id
${problem}
UNION
SELECT
    paragraph."id" AS "source_id",
    paragraph.document_id AS document_id,
    paragraph."id" AS paragraph_id,
    paragraph.dataset_id AS dataset_id,
    1 AS source_type,
    paragraph."content" AS "text",
    paragraph.is_active AS is_active
FROM
    paragraph paragraph
${paragraph}

View File

@ -0,0 +1,19 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file common.py
@date2023/10/16 16:42
@desc:
"""
from functools import reduce
from typing import Dict
def query_params_to_single_dict(query_params: Dict) -> Dict:
    """
    Collapse query parameters whose values are sequences into a plain dict
    mapping each key to the first element of its value.

    Entries whose first element is None are dropped.

    :param query_params: mapping of key -> sequence of values
        (e.g. a parsed query string: {'name': ['foo']})
    :return: dict of key -> first value
    """
    # The original reduce/filter/map pipeline contained a conditional whose
    # two branches were identical (both returned row[1][0]); it always took
    # the first element.  This comprehension is the equivalent direct form.
    return {key: value[0] for key, value in query_params.items() if value[0] is not None}

View File

@ -7,6 +7,7 @@
@desc: @desc:
""" """
import re import re
from functools import reduce
from typing import List from typing import List
import jieba import jieba
@ -25,7 +26,7 @@ def get_level_block(text, level_content_list, level_content_index):
level_content_list) else None level_content_list) else None
start_index = text.index(start_content) start_index = text.index(start_content)
end_index = text.index(next_content) if next_content is not None else len(text) end_index = text.index(next_content) if next_content is not None else len(text)
return text[start_index:end_index] return text[start_index:end_index].replace(level_content_list[level_content_index]['content'], "")
def to_tree_obj(content, state='title'): def to_tree_obj(content, state='title'):
@ -88,7 +89,7 @@ def to_paragraph(obj: dict):
content = obj['content'] content = obj['content']
return {"keywords": get_keyword(content), return {"keywords": get_keyword(content),
'parent_chain': list(map(lambda p: p['content'], obj['parent_chain'])), 'parent_chain': list(map(lambda p: p['content'], obj['parent_chain'])),
'content': content} 'content': ",".join(list(map(lambda p: p['content'], obj['parent_chain']))) + content}
def get_keyword(content: str): def get_keyword(content: str):
@ -109,13 +110,15 @@ def titles_to_paragraph(list_title: List[dict]):
:return: 块段落 :return: 块段落
""" """
if len(list_title) > 0: if len(list_title) > 0:
content = "\n".join( content = "\n,".join(
list(map(lambda d: d['content'].strip("\r\n").strip("\n").strip("\\s"), list_title))) list(map(lambda d: d['content'].strip("\r\n").strip("\n").strip("\\s"), list_title)))
return {'keywords': '', return {'keywords': '',
'parent_chain': list( 'parent_chain': list(
map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"), list_title[0]['parent_chain'])), map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"), list_title[0]['parent_chain'])),
'content': content} 'content': ",".join(list(
map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"),
list_title[0]['parent_chain']))) + content}
return None return None
@ -144,6 +147,15 @@ def to_block_paragraph(tree_data_list: List[dict]):
return list(map(lambda level: parse_group_key(level_group_dict[level]), level_group_dict)) return list(map(lambda level: parse_group_key(level_group_dict[level]), level_group_dict))
def parse_title_level(text, content_level_pattern: List, index):
    """
    Match title lines in *text* with the pattern at *index*, falling back to
    the next pattern whenever the current one matches nothing.

    :param text: text to scan
    :param content_level_pattern: ordered list of title-level regex patterns
    :param index: position in content_level_pattern to try first
    :return: list of tree objects for the first pattern that matched, or []
    """
    # Past the last pattern: nothing left to try.
    if index == len(content_level_pattern):
        return []
    matched = parse_level(text, content_level_pattern[index])
    # Empty result and more patterns remain -> fall through to the next level.
    if not matched and index + 1 < len(content_level_pattern):
        return parse_title_level(text, content_level_pattern, index + 1)
    return matched
def parse_level(text, pattern: str): def parse_level(text, pattern: str):
""" """
获取正则匹配到的文本 获取正则匹配到的文本
@ -151,10 +163,17 @@ def parse_level(text, pattern: str):
:param pattern: 正则 :param pattern: 正则
:return: 符合正则的文本 :return: 符合正则的文本
""" """
level_content_list = list(map(to_tree_obj, re.findall(pattern, text, flags=0))) level_content_list = list(map(to_tree_obj, re_findall(pattern, text)))
return list(map(filter_special_symbol, level_content_list)) return list(map(filter_special_symbol, level_content_list))
def re_findall(pattern, text):
    """
    Run re.findall and normalise the result to a flat list of non-empty strings.

    re.findall returns tuples when the pattern has multiple capture groups;
    those tuples are flattened, and empty strings (unmatched optional groups)
    are discarded.

    :param pattern: regex pattern (string or compiled)
    :param text: text to search
    :return: flat list of non-empty matched strings
    """
    # extend() replaces the original reduce(lambda x, y: [*x, *y], ...)
    # flattening, which rebuilt the accumulator list on every step (O(n^2)).
    flattened = []
    for match in re.findall(pattern, text, flags=0):
        flattened.extend(match if isinstance(match, tuple) else [match])
    return [item for item in flattened if item is not None and len(item) > 0]
def to_flat_obj(parent_chain: List[dict], content: str, state: str): def to_flat_obj(parent_chain: List[dict], content: str, state: str):
""" """
将树形属性转换为扁平对象 将树形属性转换为扁平对象
@ -194,10 +213,79 @@ def group_by(list_source: List, key):
return result return result
def result_tree_to_paragraph(result_tree: List[dict], result, parent_chain):
    """
    Flatten a parsed document tree into paragraph dicts.

    :param result_tree: tree nodes ({'state': 'title'|'block', 'content': str,
        'children': optional list of child nodes})
    :param result: accumulator list, mutated in place across recursive calls
        (pass [] at the top level)
    :param parent_chain: ancestor title contents, used to build each
        paragraph's 'title' (pass [] at the top level)
    :return: List[{'title': 'xx', 'content': 'xx'}]
    """
    for item in result_tree:
        # Removed a leftover `print(item)` debug statement that spammed stdout
        # for every node on every import.
        if item.get('state') == 'block':
            # Leaf text block: title is the space-joined chain of ancestor titles.
            result.append({'title': " ".join(parent_chain), 'content': item.get("content")})
        children = item.get("children")
        if children is not None and len(children) > 0:
            result_tree_to_paragraph(children, result, [*parent_chain, item.get('content')])
    return result
def post_handler_paragraph(content: str, limit: int, with_filter: bool):
    """
    Split text into chunks of at most *limit* characters.

    Lines are greedily merged (newlines between merged lines are dropped);
    any single chunk still longer than the limit is cut into limit-sized
    slices with a regex.

    :param with_filter: whether to run filter_special_char on each chunk
    :param content: text to split
    :param limit: maximum characters per chunk
    :return: list of chunks
    """
    chunks = []
    buffer = ''
    for line in content.split('\n'):
        # Flush the buffer before it would exceed the limit.
        if len(buffer + line) > limit:
            chunks.append(buffer)
            buffer = ''
        buffer += line
    if buffer:
        chunks.append(buffer)
    # A single line may itself exceed the limit: slice it up.
    pattern = "[\\S\\s]{1," + str(limit) + '}'
    pieces = []
    for chunk in chunks:
        pieces.extend(re.findall(pattern, chunk))
    return [filter_special_char(piece) if with_filter else piece for piece in pieces]
# Normalisation rules applied in declaration order: collapse newline runs to
# one newline, collapse remaining whitespace runs to one space, strip '#'
# markdown header markers, drop tab runs.
replace_map = {
    re.compile('\n+'): '\n',
    re.compile('\\s+'): ' ',
    re.compile('#+'): "",
    re.compile("\t+"): ''
}


def filter_special_char(content: str):
    """
    Normalise *content* by applying every replace_map rule in order.

    :param content: raw text
    :return: normalised text
    """
    for pattern, replacement in replace_map.items():
        content = re.sub(pattern, replacement, content)
    return content
class SplitModel: class SplitModel:
def __init__(self, content_level_pattern): def __init__(self, content_level_pattern, with_filter=True, limit=1024):
self.content_level_pattern = content_level_pattern self.content_level_pattern = content_level_pattern
self.with_filter = with_filter
if limit is None or limit > 1024:
limit = 1024
if limit < 50:
limit = 50
self.limit = limit
def parse_to_tree(self, text: str, index=0): def parse_to_tree(self, text: str, index=0):
""" """
@ -208,23 +296,27 @@ class SplitModel:
""" """
if len(self.content_level_pattern) == index: if len(self.content_level_pattern) == index:
return return
level_content_list = parse_level(text, pattern=self.content_level_pattern[index]) level_content_list = parse_title_level(text, self.content_level_pattern, index)
for i in range(len(level_content_list)): for i in range(len(level_content_list)):
block = get_level_block(text, level_content_list, i) block = get_level_block(text, level_content_list, i)
children = self.parse_to_tree(text=block.replace(level_content_list[i]['content'][:-1], ""), children = self.parse_to_tree(text=block,
index=index + 1) index=index + 1)
if children is not None and len(children) > 0: if children is not None and len(children) > 0:
level_content_list[i]['children'] = children level_content_list[i]['children'] = children
else: else:
if len(block) > 0: if len(block) > 0:
level_content_list[i]['children'] = [to_tree_obj(block, 'block')] level_content_list[i]['children'] = list(
map(lambda row: to_tree_obj(row, 'block'),
post_handler_paragraph(block, with_filter=self.with_filter, limit=self.limit)))
if len(level_content_list) > 0: if len(level_content_list) > 0:
end_index = text.index(level_content_list[0].get('content')) end_index = text.index(level_content_list[0].get('content'))
if end_index == 0: if end_index == 0:
return level_content_list return level_content_list
other_content = text[0:end_index] other_content = text[0:end_index]
if len(other_content.strip()) > 0: if len(other_content.strip()) > 0:
level_content_list.append(to_tree_obj(other_content, 'block')) level_content_list = [*level_content_list, *list(
map(lambda row: to_tree_obj(row, 'block'),
post_handler_paragraph(other_content, with_filter=self.with_filter, limit=self.limit)))]
return level_content_list return level_content_list
def parse(self, text: str): def parse(self, text: str):
@ -234,4 +326,35 @@ class SplitModel:
:return: 解析后数据 {content:段落数据,keywords:[段落关键词],parent_chain:['段落父级链路']} :return: 解析后数据 {content:段落数据,keywords:[段落关键词],parent_chain:['段落父级链路']}
""" """
result_tree = self.parse_to_tree(text, 0) result_tree = self.parse_to_tree(text, 0)
return flat_map(to_block_paragraph(result_tree)) return result_tree_to_paragraph(result_tree, [], [])
# Pre-built split models by file type:
#  - 'md': split on markdown headers (# .. #####) and top-level list items
#  - 'default': split on blank-line-separated paragraphs
# NOTE(review): the '(?<!#)#####(?!#).*' pattern appears twice, and the
# '####'/'#####' patterns lack the trailing space the shallower levels have —
# looks unintentional; confirm before changing, since parse_to_tree recursion
# depth depends on this list's length.
split_model_map = {
    'md': SplitModel(
        [re.compile("^# .*"), re.compile('(?<!#)## (?!#).*'), re.compile("(?<!#)### (?!#).*"),
         re.compile("(?<!#)####(?!#).*"), re.compile("(?<!#)#####(?!#).*"),
         re.compile("(?<!#)#####(?!#).*"),
         re.compile("(?<! )- .*")]),
    'default': SplitModel([re.compile("(?<!\n)\n\n.+")])
}
def get_split_model(filename: str):
    """
    Pick the split model that matches a file name.

    :param filename: name of the file being imported
    :return: the markdown split model for *.md files, otherwise the default
    """
    key = 'md' if filename.endswith(".md") else 'default'
    return split_model_map.get(key)
def to_title_tree_string(result_tree: List):
    # Render the title hierarchy of a parsed tree as an indented outline,
    # one "├───"-prefixed row per title node.
    # NOTE(review): `flat` is not defined anywhere in this diff — presumably a
    # tree-flattening helper elsewhere in this module; confirm it exists,
    # otherwise this raises NameError at call time.
    f = flat(result_tree, [], [])
    return "\n".join(list(map(lambda r: title_tostring(r), list(filter(lambda row: row.get('state') == 'title', f)))))
def title_tostring(title_obj):
    """
    Render one title node as an outline row, indented by its depth.

    :param title_obj: dict with 'parent_chain' (list of ancestor titles)
        and 'content' (the title text)
    :return: indented string like "  ├───Section title"
    """
    # String repetition replaces the original "".join(map(lambda ...)) build:
    # one indent unit per ancestor in parent_chain.
    indent = " " * len(title_obj.get("parent_chain"))
    return indent + "├───" + title_obj.get('content')

View File

@ -1,4 +1,4 @@
# Generated by Django 4.1.10 on 2023-10-09 06:33 # Generated by Django 4.1.10 on 2023-10-24 12:13
from django.db import migrations, models from django.db import migrations, models
import django.db.models.deletion import django.db.models.deletion
@ -36,6 +36,7 @@ class Migration(migrations.Migration):
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
('name', models.CharField(max_length=150, verbose_name='文档名称')), ('name', models.CharField(max_length=150, verbose_name='文档名称')),
('char_length', models.IntegerField(verbose_name='文档字符数 冗余字段')), ('char_length', models.IntegerField(verbose_name='文档字符数 冗余字段')),
('status', models.CharField(choices=[('0', '导入中'), ('1', '已完成'), ('2', '导入失败')], default='0', max_length=1, verbose_name='状态')),
('is_active', models.BooleanField(default=True)), ('is_active', models.BooleanField(default=True)),
('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')), ('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')),
], ],
@ -50,11 +51,14 @@ class Migration(migrations.Migration):
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
('content', models.CharField(max_length=1024, verbose_name='段落内容')), ('content', models.CharField(max_length=1024, verbose_name='段落内容')),
('title', models.CharField(default='', max_length=256, verbose_name='标题')),
('hit_num', models.IntegerField(default=0, verbose_name='命中数量')), ('hit_num', models.IntegerField(default=0, verbose_name='命中数量')),
('star_num', models.IntegerField(default=0, verbose_name='点赞数')), ('star_num', models.IntegerField(default=0, verbose_name='点赞数')),
('trample_num', models.IntegerField(default=0, verbose_name='点踩数')), ('trample_num', models.IntegerField(default=0, verbose_name='点踩数')),
('status', models.CharField(choices=[('0', '导入中'), ('1', '已完成'), ('2', '导入失败')], default='0', max_length=1, verbose_name='状态')),
('is_active', models.BooleanField(default=True)), ('is_active', models.BooleanField(default=True)),
('document', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document')), ('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')),
('document', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document')),
], ],
options={ options={
'db_table': 'paragraph', 'db_table': 'paragraph',
@ -67,25 +71,15 @@ class Migration(migrations.Migration):
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
('content', models.CharField(max_length=256, verbose_name='问题内容')), ('content', models.CharField(max_length=256, verbose_name='问题内容')),
('hit_num', models.IntegerField(default=0, verbose_name='命中数量')),
('star_num', models.IntegerField(default=0, verbose_name='点赞数')),
('trample_num', models.IntegerField(default=0, verbose_name='点踩数')),
('dataset', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')),
('document', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document')),
('paragraph', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph')),
], ],
options={ options={
'db_table': 'problem', 'db_table': 'problem',
}, },
), ),
migrations.CreateModel(
name='ProblemAnswerMapping',
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
('hit_num', models.IntegerField(default=0, verbose_name='命中数量')),
('star_num', models.IntegerField(default=0, verbose_name='点赞数')),
('trample_num', models.IntegerField(default=0, verbose_name='点踩数')),
('paragraph', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph')),
('problem', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.problem')),
],
options={
'db_table': 'problem_paragraph_mapping',
},
),
] ]

View File

@ -14,6 +14,13 @@ from common.mixins.app_model_mixin import AppModelMixin
from users.models import User from users.models import User
class Status(models.TextChoices):
"""订单类型"""
embedding = 0, '导入中'
success = 1, '已完成'
error = 2, '导入失败'
class DataSet(AppModelMixin): class DataSet(AppModelMixin):
""" """
数据集表 数据集表
@ -35,6 +42,8 @@ class Document(AppModelMixin):
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING) dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
name = models.CharField(max_length=150, verbose_name="文档名称") name = models.CharField(max_length=150, verbose_name="文档名称")
char_length = models.IntegerField(verbose_name="文档字符数 冗余字段") char_length = models.IntegerField(verbose_name="文档字符数 冗余字段")
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
default=Status.embedding)
is_active = models.BooleanField(default=True) is_active = models.BooleanField(default=True)
class Meta: class Meta:
@ -46,11 +55,15 @@ class Paragraph(AppModelMixin):
段落表 段落表
""" """
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING) document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False)
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
content = models.CharField(max_length=1024, verbose_name="段落内容") content = models.CharField(max_length=1024, verbose_name="段落内容")
title = models.CharField(max_length=256, verbose_name="标题", default="")
hit_num = models.IntegerField(verbose_name="命中数量", default=0) hit_num = models.IntegerField(verbose_name="命中数量", default=0)
star_num = models.IntegerField(verbose_name="点赞数", default=0) star_num = models.IntegerField(verbose_name="点赞数", default=0)
trample_num = models.IntegerField(verbose_name="点踩数", default=0) trample_num = models.IntegerField(verbose_name="点踩数", default=0)
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
default=Status.embedding)
is_active = models.BooleanField(default=True) is_active = models.BooleanField(default=True)
class Meta: class Meta:
@ -62,23 +75,13 @@ class Problem(AppModelMixin):
问题表 问题表
""" """
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False)
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False)
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False)
content = models.CharField(max_length=256, verbose_name="问题内容") content = models.CharField(max_length=256, verbose_name="问题内容")
class Meta:
db_table = "problem"
class ProblemAnswerMapping(AppModelMixin):
"""
问题 段落 映射表
"""
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING)
problem = models.ForeignKey(Problem, on_delete=models.DO_NOTHING)
hit_num = models.IntegerField(verbose_name="命中数量", default=0) hit_num = models.IntegerField(verbose_name="命中数量", default=0)
star_num = models.IntegerField(verbose_name="点赞数", default=0) star_num = models.IntegerField(verbose_name="点赞数", default=0)
trample_num = models.IntegerField(verbose_name="点踩数", default=0) trample_num = models.IntegerField(verbose_name="点踩数", default=0)
class Meta: class Meta:
db_table = "problem_paragraph_mapping" db_table = "problem"

View File

@ -8,7 +8,6 @@
""" """
import os.path import os.path
import uuid import uuid
from functools import reduce
from typing import Dict from typing import Dict
from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.fields import ArrayField
@ -19,11 +18,12 @@ from drf_yasg import openapi
from rest_framework import serializers from rest_framework import serializers
from common.db.search import get_dynamics_model, native_page_search, native_search from common.db.search import get_dynamics_model, native_page_search, native_search
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin from common.mixins.api_mixin import ApiMixin
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from dataset.models.data_set import DataSet, Document, Paragraph from dataset.models.data_set import DataSet, Document, Paragraph, Problem
from dataset.serializers.document_serializers import CreateDocumentSerializers from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from setting.models import AuthOperate from setting.models import AuthOperate
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
from users.models import User from users.models import User
@ -81,12 +81,12 @@ class DataSetSerializers(serializers.ModelSerializer):
user_id = self.data.get("user_id") user_id = self.data.get("user_id")
query_set_dict = {} query_set_dict = {}
query_set = QuerySet(model=get_dynamics_model( query_set = QuerySet(model=get_dynamics_model(
{'dataset.name': models.CharField(), 'dataset.desc': models.CharField(), {'temp.name': models.CharField(), 'temp.desc': models.CharField(),
"document_temp.char_length": models.IntegerField()})) "document_temp.char_length": models.IntegerField()}))
if "desc" in self.data: if "desc" in self.data and self.data.get('desc') is not None:
query_set = query_set.filter(**{'dataset.desc__contains': self.data.get("desc")}) query_set = query_set.filter(**{'temp.desc__contains': self.data.get("desc")})
if "name" in self.data: if "name" in self.data and self.data.get('name') is not None:
query_set = query_set.filter(**{'dataset.name__contains': self.data.get("name")}) query_set = query_set.filter(**{'temp.name__contains': self.data.get("name")})
query_set_dict['default_sql'] = query_set query_set_dict['default_sql'] = query_set
@ -133,9 +133,7 @@ class DataSetSerializers(serializers.ModelSerializer):
@staticmethod @staticmethod
def get_response_body_api(): def get_response_body_api():
return openapi.Schema(type=openapi.TYPE_ARRAY, return DataSetSerializers.Operate.get_response_body_api()
title="数据集列表", description="数据集列表",
items=DataSetSerializers.Operate.get_response_body_api())
class Create(ApiMixin, serializers.Serializer): class Create(ApiMixin, serializers.Serializer):
""" """
@ -157,7 +155,7 @@ class DataSetSerializers(serializers.ModelSerializer):
message="数据集名称在1-256个字符之间") message="数据集名称在1-256个字符之间")
]) ])
documents = CreateDocumentSerializers(required=False, many=True) documents = DocumentInstanceSerializer(required=False, many=True)
def is_valid(self, *, raise_exception=False): def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True) super().is_valid(raise_exception=True)
@ -168,28 +166,46 @@ class DataSetSerializers(serializers.ModelSerializer):
dataset_id = uuid.uuid1() dataset_id = uuid.uuid1()
dataset = DataSet( dataset = DataSet(
**{'id': dataset_id, 'name': self.data.get("name"), 'desc': self.data.get('desc'), 'user': user}) **{'id': dataset_id, 'name': self.data.get("name"), 'desc': self.data.get('desc'), 'user': user})
document_model_list = []
paragraph_model_list = []
if 'documents' in self.data:
documents = self.data.get('documents')
for document in documents:
document_model = Document(**{'dataset': dataset, 'id': uuid.uuid1(), 'name': document.get('name'),
'char_length': reduce(lambda x, y: x + y,
list(
map(lambda p: len(p),
document.get("paragraphs"))), 0)})
document_model_list.append(document_model)
if 'paragraphs' in document:
paragraph_model_list += list(map(lambda p: Paragraph(
**{'document': document_model, 'id': uuid.uuid1(), 'content': p}),
document.get('paragraphs')))
# 插入数据集 # 插入数据集
dataset.save() dataset.save()
# 插入文档 for document in self.data.get('documents') if 'documents' in self.data else []:
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(document, with_valid=True,
# 插入段落 with_embedding=False)
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None ListenerManagement.embedding_by_dataset_signal.send(str(dataset.id))
return True return {**DataSetSerializers(dataset).data,
'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(with_valid=True)}
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['id', 'name', 'desc', 'user_id', 'char_length', 'document_count',
'update_time', 'create_time', 'document_list'],
properties={
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
description="id", default="xx"),
'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
description="名称", default="测试数据集"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="描述",
description="描述", default="测试数据集描述"),
'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户id",
description="所属用户id", default="user_xxxx"),
'char_length': openapi.Schema(type=openapi.TYPE_STRING, title="字符数",
description="字符数", default=10),
'document_count': openapi.Schema(type=openapi.TYPE_STRING, title="文档数量",
description="文档数量", default=1),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
description="修改时间",
default="1970-01-01 00:00:00"),
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
description="创建时间",
default="1970-01-01 00:00:00"
),
'document_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档列表",
description="文档列表",
items=DocumentSerializers.Operate.get_response_body_api())
}
)
@staticmethod @staticmethod
def get_request_body_api(): def get_request_body_api():
@ -200,7 +216,7 @@ class DataSetSerializers(serializers.ModelSerializer):
'name': openapi.Schema(type=openapi.TYPE_STRING, title="数据集名称", description="数据集名称"), 'name': openapi.Schema(type=openapi.TYPE_STRING, title="数据集名称", description="数据集名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="数据集描述", description="数据集描述"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="数据集描述", description="数据集描述"),
'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据", 'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据",
items=CreateDocumentSerializers().get_request_body_api() items=DocumentSerializers().Create.get_request_body_api()
) )
} }
) )
@ -217,10 +233,11 @@ class DataSetSerializers(serializers.ModelSerializer):
def delete(self): def delete(self):
self.is_valid() self.is_valid()
dataset = QuerySet(DataSet).get(id=self.data.get("id")) dataset = QuerySet(DataSet).get(id=self.data.get("id"))
document_list = QuerySet(Document).filter(dataset=dataset) QuerySet(Document).filter(dataset=dataset).delete()
QuerySet(Paragraph).filter(document__in=document_list).delete() QuerySet(Paragraph).filter(dataset=dataset).delete()
document_list.delete() QuerySet(Problem).filter(dataset=dataset).delete()
dataset.delete() dataset.delete()
ListenerManagement.delete_embedding_by_dataset_signal.send(self.data.get('id'))
return True return True
def one(self, user_id, with_valid=True): def one(self, user_id, with_valid=True):
@ -303,9 +320,9 @@ class DataSetSerializers(serializers.ModelSerializer):
@staticmethod @staticmethod
def get_request_params_api(): def get_request_params_api():
return [openapi.Parameter(name='id', return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH, in_=openapi.IN_PATH,
type=openapi.TYPE_STRING, type=openapi.TYPE_STRING,
required=False, required=True,
description='数据集id') description='数据集id')
] ]

View File

@ -6,20 +6,29 @@
@date2023/9/22 13:43 @date2023/9/22 13:43
@desc: @desc:
""" """
import os
import uuid import uuid
from functools import reduce from functools import reduce
from typing import List, Dict
from django.core import validators from django.core import validators
from django.db import transaction
from django.db.models import QuerySet from django.db.models import QuerySet
from drf_yasg import openapi from drf_yasg import openapi
from rest_framework import serializers from rest_framework import serializers
from common.db.search import native_search, native_page_search
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin from common.mixins.api_mixin import ApiMixin
from dataset.models.data_set import DataSet, Document, Paragraph from common.util.file_util import get_file_content
from common.util.split_model import SplitModel, get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR
class CreateDocumentSerializers(ApiMixin, serializers.Serializer): class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
name = serializers.CharField(required=True, name = serializers.CharField(required=True,
validators=[ validators=[
validators.MaxLengthValidator(limit_value=128, validators.MaxLengthValidator(limit_value=128,
@ -28,52 +37,265 @@ class CreateDocumentSerializers(ApiMixin, serializers.Serializer):
message="数据集名称在1-128个字符之间") message="数据集名称在1-128个字符之间")
]) ])
paragraphs = serializers.ListField(required=False, paragraphs = ParagraphInstanceSerializer(required=False, many=True)
child=serializers.CharField(required=True,
validators=[
validators.MaxLengthValidator(limit_value=256,
message="段落在1-256个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="段落在1-256个字符之间")
]))
def is_valid(self, *, dataset_id=None, raise_exception=False): @staticmethod
if not QuerySet(DataSet).filter(id=dataset_id).exists(): def get_request_body_api():
raise AppApiException(10000, "数据集id不存在")
return super().is_valid(raise_exception=True)
def save(self, dataset_id: str, **kwargs):
document_model = Document(
**{'dataset': DataSet(id=dataset_id),
'id': uuid.uuid1(),
'name': self.data.get('name'),
'char_length': reduce(lambda x, y: x + y, list(map(lambda p: len(p), self.data.get("paragraphs"))), 0)})
paragraph_model_list = list(map(lambda p: Paragraph(
**{'document': document_model, 'id': uuid.uuid1(), 'content': p}),
self.data.get('paragraphs')))
# 插入文档
document_model.save()
# 插入段落
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
return True
def get_request_body_api(self):
return openapi.Schema( return openapi.Schema(
type=openapi.TYPE_OBJECT, type=openapi.TYPE_OBJECT,
required=['name', 'paragraph'], required=['name', 'paragraphs'],
properties={ properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"), 'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"),
'paragraphs': openapi.Schema(type=openapi.TYPE_ARRAY, title="段落列表", description="段落列表", 'paragraphs': openapi.Schema(type=openapi.TYPE_ARRAY, title="段落列表", description="段落列表",
items=openapi.Schema(type=openapi.TYPE_STRING, title="段落数据", items=ParagraphSerializers.Create.get_request_body_api())
description="段落数据"))
} }
) )
def get_request_params_api(self):
return [openapi.Parameter(name='dataset_id', class DocumentSerializers(ApiMixin, serializers.Serializer):
in_=openapi.IN_PATH, class Query(ApiMixin, serializers.Serializer):
type=openapi.TYPE_STRING, # 数据集id
dataset_id = serializers.UUIDField(required=True)
name = serializers.CharField(required=False,
validators=[
validators.MaxLengthValidator(limit_value=128,
message="文档名称在1-128个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="数据集名称在1-128个字符之间")
])
def get_query_set(self):
query_set = QuerySet(model=Document)
query_set = query_set.filter(**{'dataset_id': self.data.get("dataset_id")})
if 'name' in self.data and self.data.get('name') is not None:
query_set = query_set.filter(**{'name__contains': self.data.get('name')})
return query_set
def list(self, with_valid=False):
if with_valid:
self.is_valid(raise_exception=True)
query_set = self.get_query_set()
return native_search(query_set, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')))
def page(self, current_page, page_size):
query_set = self.get_query_set()
return native_page_search(current_page, page_size, query_set, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')))
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='name',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='文档名称')]
@staticmethod
def get_response_body_api():
return openapi.Schema(type=openapi.TYPE_ARRAY,
title="文档列表", description="文档列表",
items=DocumentSerializers.Operate.get_response_body_api())
class Operate(ApiMixin, serializers.Serializer):
document_id = serializers.UUIDField(required=True)
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='数据集id'),
openapi.Parameter(name='document_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='文档id')
]
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
document_id = self.data.get('document_id')
if not QuerySet(Document).filter(id=document_id).exists():
raise AppApiException(500, "文档id不存在")
def one(self, with_valid=False):
if with_valid:
self.is_valid(raise_exception=True)
query_set = QuerySet(model=Document)
query_set = query_set.filter(**{'id': self.data.get("document_id")})
return native_search(query_set, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=True)
def edit(self, instance: Dict, with_valid=False):
if with_valid:
self.is_valid()
_document = QuerySet(Document).get(id=self.data.get("document_id"))
update_keys = ['name', 'is_active']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
_document.__setattr__(update_key, instance.get(update_key))
_document.save()
return self.one()
@transaction.atomic
def delete(self):
document_id = self.data.get("document_id")
QuerySet(model=Document).filter(id=document_id).delete()
# 删除段落
QuerySet(model=Paragraph).filter(document_id=document_id).delete()
# 删除问题
QuerySet(model=Problem).filter(document_id=document_id).delete()
# 删除向量库
ListenerManagement.delete_embedding_by_document_signal.send(document_id)
return True
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['id', 'name', 'char_length', 'user_id', 'paragraph_count', 'is_active'
'update_time', 'create_time'],
properties={
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
description="id", default="xx"),
'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
description="名称", default="测试数据集"),
'char_length': openapi.Schema(type=openapi.TYPE_INTEGER, title="字符数",
description="字符数", default=10),
'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="用户id", description="用户id"),
'paragraph_count': openapi.Schema(type=openapi.TYPE_INTEGER, title="文档数量",
description="文档数量", default=1),
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
description="是否可用", default=True),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
description="修改时间",
default="1970-01-01 00:00:00"),
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
description="创建时间",
default="1970-01-01 00:00:00"
)
}
)
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"),
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
}
)
class Create(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if not QuerySet(DataSet).filter(id=self.data.get('dataset_id')).exists():
raise AppApiException(10000, "数据集id不存在")
return True
def save(self, instance: Dict, with_valid=False, with_embedding=True, **kwargs):
if with_valid:
DocumentInstanceSerializer(data=instance).is_valid()
self.is_valid(raise_exception=True)
dataset_id = self.data.get('dataset_id')
document_model = Document(
**{'dataset_id': dataset_id,
'id': uuid.uuid1(),
'name': instance.get('name'),
'char_length': reduce(lambda x, y: x + y,
[len(p.get('content')) for p in instance.get('paragraphs', [])],
0)})
for paragraph in instance.get('paragraphs') if 'paragraphs' in instance else []:
ParagraphSerializers.Create(
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).save(paragraph,
with_valid=True,
with_embedding=False)
# 插入文档
document_model.save()
if with_embedding:
ListenerManagement.embedding_by_document_signal.send(str(document_model.id))
return DocumentSerializers.Operate(
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).one(
with_valid=True)
@staticmethod
def get_request_body_api():
return DocumentInstanceSerializer.get_request_body_api()
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='数据集id')
]
class Split(ApiMixin, serializers.Serializer):
file = serializers.ListField(required=True)
limit = serializers.IntegerField(required=False)
patterns = serializers.ListField(required=False,
child=serializers.CharField(required=True))
with_filter = serializers.BooleanField(required=False)
def is_valid(self, *, raise_exception=True):
super().is_valid()
files = self.data.get('file')
for f in files:
if f.size > 1024 * 1024 * 10:
raise AppApiException(500, "上传文件最大不能超过10m")
@staticmethod
def get_request_params_api():
return [
openapi.Parameter(name='file',
in_=openapi.IN_FORM,
type=openapi.TYPE_ARRAY,
items=openapi.Items(type=openapi.TYPE_FILE),
required=True, required=True,
description='数据集id')] description='上传文件'),
openapi.Parameter(name='limit',
in_=openapi.IN_FORM,
required=False,
type=openapi.TYPE_INTEGER, title="分段长度", description="分段长度"),
openapi.Parameter(name='patterns',
in_=openapi.IN_FORM,
required=False,
type=openapi.TYPE_ARRAY, items=openapi.Items(type=openapi.TYPE_STRING),
title="分段正则列表", description="分段正则列表"),
openapi.Parameter(name='with_filter',
in_=openapi.IN_FORM,
required=False,
type=openapi.TYPE_BOOLEAN, title="是否清除特殊字符", description="是否清除特殊字符"),
]
def parse(self):
file_list = self.data.get("file")
return list(map(lambda f: file_to_paragraph(f, self.data.get("patterns"), self.data.get("with_filter"),
self.data.get("limit")), file_list))
def file_to_paragraph(file, pattern_list: List, with_filter, limit: int):
    """
    Parse an uploaded file into a named list of paragraph sections.

    :param file: uploaded file object exposing ``read()`` and ``name``
    :param pattern_list: regex patterns used to split the text; when None or
        empty, a split model is inferred from the file extension instead
    :param with_filter: whether to strip special characters while splitting
    :param limit: maximum length of each produced section
    :return: ``{'name': ..., 'content': [...]}``; ``content`` is empty when
        the file is not valid UTF-8 text
    """
    data = file.read()
    # Bug fix: the original condition `pattern_list is None or len(pattern_list) > 0`
    # routed a *None* pattern list into SplitModel. Custom patterns should only be
    # used when the caller actually supplied some.
    if pattern_list is not None and len(pattern_list) > 0:
        split_model = SplitModel(pattern_list, with_filter, limit)
    else:
        split_model = get_split_model(file.name)
    try:
        content = data.decode('utf-8')
    except UnicodeDecodeError:
        # Binary / non-UTF-8 uploads produce an empty content list instead of failing.
        return {'name': file.name,
                'content': []}
    return {'name': file.name,
            'content': split_model.parse(content)
            }

View File

@ -0,0 +1,278 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file paragraph_serializers.py
@date2023/10/16 15:51
@desc:
"""
import uuid
from typing import Dict
from django.core import validators
from django.db import transaction
from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
from common.db.search import page_search
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from dataset.models import Paragraph, Problem
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer
class ParagraphSerializer(serializers.ModelSerializer):
    """Read-only model serializer exposing a Paragraph row, including hit/vote counters."""

    class Meta:
        model = Paragraph
        # Fields returned to the client.
        fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'document_id', 'title',
                  'create_time', 'update_time']
class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
    """
    Incoming paragraph payload: required content plus an optional title,
    problem list and active flag.
    """
    content = serializers.CharField(required=True, validators=[
        validators.MaxLengthValidator(limit_value=1024,
                                      message="段落在1-1024个字符之间"),
        validators.MinLengthValidator(limit_value=1,
                                      message="段落在1-1024个字符之间")
    ])
    title = serializers.CharField(required=False)
    problem_list = ProblemInstanceSerializer(required=False, many=True)
    is_active = serializers.BooleanField(required=False)

    @staticmethod
    def get_request_body_api():
        """OpenAPI request-body schema for a single paragraph payload."""
        # Assemble the property schemas first for readability.
        field_schemas = {
            'content': openapi.Schema(type=openapi.TYPE_STRING, title="分段内容", description="分段内容"),
            'title': openapi.Schema(type=openapi.TYPE_STRING, title="分段标题",
                                    description="分段标题"),
            'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
            'problem_list': openapi.Schema(type=openapi.TYPE_ARRAY, title='问题列表',
                                           description="问题列表",
                                           items=ProblemInstanceSerializer.get_request_body_api()),
        }
        return openapi.Schema(type=openapi.TYPE_OBJECT,
                              required=['content'],
                              properties=field_schemas)
class ParagraphSerializers(ApiMixin, serializers.Serializer):
    """Grouped serializers for paragraph CRUD: Operate / Create / Query."""

    class Operate(ApiMixin, serializers.Serializer):
        """Operations on a single paragraph addressed by dataset/document/paragraph ids."""
        # Paragraph primary key.
        paragraph_id = serializers.UUIDField(required=True)
        # Owning dataset primary key.
        dataset_id = serializers.UUIDField(required=True)
        # Owning document primary key (the original comment wrongly said dataset id).
        document_id = serializers.UUIDField(required=True)

        def is_valid(self, *, raise_exception=True):
            """Validate the ids and ensure the paragraph row exists."""
            super().is_valid(raise_exception=True)
            if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
                raise AppApiException(500, "段落id不存在")

        @transaction.atomic
        def edit(self, instance: Dict):
            """Update paragraph fields and synchronize its problem list.

            Entries in ``instance['problem_list']`` carrying an id are updates,
            entries without an id are inserts; existing problems absent from the
            list are deleted (only when at least one update entry was sent).
            Toggling ``is_active`` enables/disables the paragraph's embeddings.
            """
            self.is_valid()
            _paragraph = QuerySet(Paragraph).get(id=self.data.get("paragraph_id"))
            update_keys = ['title', 'content', 'is_active']
            for update_key in update_keys:
                # Only overwrite fields the caller actually provided.
                if update_key in instance and instance.get(update_key) is not None:
                    _paragraph.__setattr__(update_key, instance.get(update_key))
            if 'problem_list' in instance:
                update_problem_list = list(
                    filter(lambda row: 'id' in row and row.get('id') is not None, instance.get('problem_list')))
                create_problem_list = list(filter(lambda row: row.get('id') is None, instance.get('problem_list')))
                # Problems currently attached to this paragraph.
                problem_list = QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id"))
                # Reject ids sent by the client that do not belong to this paragraph.
                for update_problem in update_problem_list:
                    if not set([str(row.id) for row in problem_list]).__contains__(update_problem.get('id')):
                        raise AppApiException(500, update_problem.get('id') + '问题id不存在')
                # Problems to delete: existing rows missing from the update list.
                delete_problem_list = list(filter(
                    lambda row: not [str(update_row.get('id')) for update_row in update_problem_list].__contains__(
                        str(row.id)), problem_list)) if len(update_problem_list) > 0 else []
                # Delete removed problems.
                QuerySet(Problem).filter(id__in=[row.id for row in delete_problem_list]).delete() if len(
                    delete_problem_list) > 0 else None
                # Insert new problems.
                QuerySet(Problem).bulk_create(
                    [Problem(id=uuid.uuid1(), content=p.get('content'), paragraph_id=self.data.get('paragraph_id'),
                             dataset_id=self.data.get('dataset_id'), document_id=self.data.get('document_id')) for
                     p in create_problem_list]) if len(create_problem_list) else None
                # Bulk-update the content of the remaining problems.
                QuerySet(Problem).bulk_update(
                    [Problem(id=row.get('id'), content=row.get('content')) for row in update_problem_list],
                    ['content']) if len(
                    update_problem_list) > 0 else None

            _paragraph.save()
            if 'is_active' in instance and instance.get('is_active') is not None:
                # Enable or disable the paragraph's embeddings to match is_active.
                s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
                    'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
                s.send(self.data.get('paragraph_id'))
            return self.one()

        def get_problem_list(self):
            """Return the serialized problems attached to this paragraph."""
            return [ProblemSerializer(problem).data for problem in
                    QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id"))]

        def one(self, with_valid=False):
            """Return the paragraph row merged with its problem list."""
            if with_valid:
                self.is_valid(raise_exception=True)
            return {**ParagraphSerializer(QuerySet(model=Paragraph).get(id=self.data.get('paragraph_id'))).data,
                    'problem_list': self.get_problem_list()}

        def delete(self, with_valid=False):
            """Delete the paragraph, its problems and its embeddings."""
            if with_valid:
                self.is_valid(raise_exception=True)
            paragraph_id = self.data.get('paragraph_id')
            QuerySet(Paragraph).filter(id=paragraph_id).delete()
            QuerySet(Problem).filter(paragraph_id=paragraph_id).delete()
            ListenerManagement.delete_embedding_by_paragraph_signal.send(paragraph_id)

        @staticmethod
        def get_request_body_api():
            return ParagraphInstanceSerializer.get_request_body_api()

        @staticmethod
        def get_response_body_api():
            # NOTE(review): this returns the *request* schema; Query.get_response_body_api
            # documents the actual response row shape — confirm which is intended.
            return ParagraphInstanceSerializer.get_request_body_api()

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(type=openapi.TYPE_STRING, in_=openapi.IN_PATH, name='paragraph_id',
                                      description="段落id")]

    class Create(ApiMixin, serializers.Serializer):
        """Create a paragraph (and its problems) under a document."""
        dataset_id = serializers.UUIDField(required=True)
        document_id = serializers.UUIDField(required=True)

        def save(self, instance: Dict, with_valid=True, with_embedding=True):
            """Persist one paragraph and its problems; optionally trigger embedding.

            :param instance: paragraph payload (see ParagraphInstanceSerializer)
            :param with_valid: validate the payload and the ids first
            :param with_embedding: send the paragraph embedding signal after insert
            :return: the created paragraph row with its problem list
            """
            if with_valid:
                # Bug fix: the payload was validated against ParagraphSerializers,
                # which declares no fields, making validation a no-op; it must be
                # checked against ParagraphInstanceSerializer. Also, self.is_valid()
                # without raise_exception=True never raised on bad ids.
                ParagraphInstanceSerializer(data=instance).is_valid(raise_exception=True)
                self.is_valid(raise_exception=True)
            dataset_id = self.data.get("dataset_id")
            document_id = self.data.get('document_id')
            paragraph = Paragraph(id=uuid.uuid1(),
                                  document_id=document_id,
                                  content=instance.get("content"),
                                  dataset_id=dataset_id,
                                  title=instance.get("title", ''))
            # Insert the paragraph row.
            paragraph.save()
            problem_model_list = [Problem(id=uuid.uuid1(), content=problem.get('content'), paragraph_id=paragraph.id,
                                          document_id=document_id, dataset_id=dataset_id) for problem in (
                                      instance.get('problem_list') if 'problem_list' in instance else [])]
            # Insert the attached problems, if any.
            QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
            if with_embedding:
                ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id))
            return ParagraphSerializers.Operate(
                data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
                with_valid=True)

        @staticmethod
        def get_request_body_api():
            return ParagraphInstanceSerializer.get_request_body_api()

        @staticmethod
        def get_request_params_api():
            """Path parameters: owning dataset id and document id."""
            return [openapi.Parameter(name='dataset_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='数据集id'),
                    openapi.Parameter(name='document_id', in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description="文档id")
                    ]

    class Query(ApiMixin, serializers.Serializer):
        """List/page paragraphs of a document, optionally filtered by title."""
        dataset_id = serializers.UUIDField(required=True)
        document_id = serializers.UUIDField(required=True)
        title = serializers.CharField(required=False)

        def get_query_set(self):
            """Build the filtered paragraph queryset."""
            query_set = QuerySet(model=Paragraph)
            query_set = query_set.filter(
                **{'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get("document_id")})
            # Consistency fix: guard against an explicit None like the other
            # query serializers in this module do.
            if 'title' in self.data and self.data.get('title') is not None:
                query_set = query_set.filter(
                    **{'title__contains': self.data.get('title')})
            return query_set

        def list(self):
            """Return all matching paragraphs."""
            return list(map(lambda row: ParagraphSerializer(row).data, self.get_query_set()))

        def page(self, current_page, page_size):
            """Return one page of matching paragraphs."""
            query_set = self.get_query_set()
            return page_search(current_page, page_size, query_set, lambda row: ParagraphSerializer(row).data)

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='document_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='文档id'),
                    openapi.Parameter(name='title',
                                      in_=openapi.IN_QUERY,
                                      type=openapi.TYPE_STRING,
                                      required=False,
                                      description='标题')
                    ]

        @staticmethod
        def get_response_body_api():
            """Response schema for a single paragraph row."""
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'document_id', 'title',
                          'create_time', 'update_time'],
                properties={
                    'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                         description="id", default="xx"),
                    'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容",
                                              description="段落内容", default='段落内容'),
                    'title': openapi.Schema(type=openapi.TYPE_STRING, title="标题",
                                            description="标题", default="xxx的描述"),
                    'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
                                              default=1),
                    'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
                                               description="点赞数量", default=1),
                    'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
                                                  description="点踩数", default=1),
                    'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
                                                  description="文档id", default='xxx'),
                    'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
                                                description="是否可用", default=True),
                    'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                                  description="修改时间",
                                                  default="1970-01-01 00:00:00"),
                    'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                                  description="创建时间",
                                                  default="1970-01-01 00:00:00"
                                                  )
                }
            )

View File

@ -0,0 +1,222 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file problem_serializers.py
@date2023/10/23 13:55
@desc:
"""
import uuid
from typing import Dict
from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from dataset.models import Problem, Paragraph
from embedding.models import SourceType
from embedding.vector.pg_vector import PGVector
class ProblemSerializer(serializers.ModelSerializer):
    """Read-only model serializer exposing a Problem row, including hit/vote counters."""

    class Meta:
        model = Problem
        # Fields returned to the client.
        fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id', 'document_id',
                  'create_time', 'update_time']
class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
    """Incoming problem payload: required content plus an optional id (present when updating)."""
    # Problem id; supplied on update, omitted on create.
    id = serializers.CharField(required=False)
    # Problem text content.
    content = serializers.CharField(required=True)

    @staticmethod
    def get_request_body_api():
        """OpenAPI request-body schema for a single problem payload."""
        return openapi.Schema(type=openapi.TYPE_OBJECT,
                              required=["content"],
                              properties={
                                  'id': openapi.Schema(
                                      type=openapi.TYPE_STRING,
                                      title="问题id,修改的时候传递,创建的时候不传"),
                                  'content': openapi.Schema(
                                      type=openapi.TYPE_STRING, title="内容")
                              })
class ProblemSerializers(ApiMixin, serializers.Serializer):
    """Grouped serializers for problem CRUD: Create / Query / Operate."""

    class Create(ApiMixin, serializers.Serializer):
        """Create a problem under a paragraph."""
        dataset_id = serializers.UUIDField(required=True)
        document_id = serializers.UUIDField(required=True)
        paragraph_id = serializers.UUIDField(required=True)

        def save(self, instance: Dict, with_valid=True, with_embedding=True):
            """Persist one problem; optionally embed it immediately.

            :param instance: problem payload (see ProblemInstanceSerializer)
            :param with_valid: validate the ids and payload first
            :param with_embedding: send the problem embedding signal after insert
            :return: the created problem row
            """
            if with_valid:
                # Bug fix: self.is_valid() without raise_exception=True never
                # raised, so bad ids slipped through validation silently.
                self.is_valid(raise_exception=True)
                ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
            problem = Problem(id=uuid.uuid1(), paragraph_id=self.data.get('paragraph_id'),
                              document_id=self.data.get('document_id'), dataset_id=self.data.get('dataset_id'),
                              content=instance.get('content'))
            problem.save()
            if with_embedding:
                # Hand the full source descriptor to the embedding listener.
                ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
                                                                    'is_active': True,
                                                                    'source_type': SourceType.PROBLEM,
                                                                    'source_id': problem.id,
                                                                    'document_id': self.data.get('document_id'),
                                                                    'paragraph_id': self.data.get('paragraph_id'),
                                                                    'dataset_id': self.data.get('dataset_id')})
            return ProblemSerializers.Operate(
                data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'),
                      'paragraph_id': self.data.get('paragraph_id'), 'problem_id': problem.id}).one(with_valid=True)

        @staticmethod
        def get_request_body_api():
            return ProblemInstanceSerializer.get_request_body_api()

        @staticmethod
        def get_request_params_api():
            """Path parameters: dataset, document and paragraph ids."""
            return [openapi.Parameter(name='dataset_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='数据集id'),
                    openapi.Parameter(name='document_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='文档id'),
                    openapi.Parameter(name='paragraph_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='段落id')]

    class Query(ApiMixin, serializers.Serializer):
        """List the problems attached to a paragraph."""
        dataset_id = serializers.UUIDField(required=True)
        document_id = serializers.UUIDField(required=True)
        paragraph_id = serializers.UUIDField(required=True)

        def is_valid(self, *, raise_exception=True):
            """Validate the ids and ensure the paragraph exists."""
            super().is_valid(raise_exception=True)
            if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
                raise AppApiException(500, "段落id不存在")

        def get_query_set(self):
            """Build the problem queryset scoped to dataset/document/paragraph."""
            dataset_id = self.data.get('dataset_id')
            document_id = self.data.get('document_id')
            paragraph_id = self.data.get("paragraph_id")
            return QuerySet(Problem).filter(
                **{'paragraph_id': paragraph_id, 'dataset_id': dataset_id, 'document_id': document_id})

        def list(self, with_valid=False):
            """
            Return the problem list.

            :param with_valid: whether to validate first
            :return: serialized problem rows
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            query_set = self.get_query_set()
            return [ProblemSerializer(p).data for p in query_set]

        @staticmethod
        def get_request_params_api():
            """Path parameters: dataset, document and paragraph ids."""
            return [openapi.Parameter(name='dataset_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='数据集id'),
                    openapi.Parameter(name='document_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='文档id')
                , openapi.Parameter(name='paragraph_id',
                                    in_=openapi.IN_PATH,
                                    type=openapi.TYPE_STRING,
                                    required=True,
                                    description='段落id')]

    class Operate(ApiMixin, serializers.Serializer):
        """Operations on a single problem."""
        dataset_id = serializers.UUIDField(required=True)
        document_id = serializers.UUIDField(required=True)
        paragraph_id = serializers.UUIDField(required=True)
        problem_id = serializers.UUIDField(required=True)

        def delete(self, with_valid=False):
            """Delete the problem row and its embedding."""
            if with_valid:
                self.is_valid(raise_exception=True)
            QuerySet(Problem).filter(**{'id': self.data.get('problem_id')}).delete()
            # NOTE(review): the direct PGVector call duplicates what the
            # delete_embedding_by_source_signal listener presumably does, and
            # bypasses the configured VectorStore — confirm whether one of the
            # two calls should be dropped.
            PGVector().delete_by_source_id(self.data.get('problem_id'), SourceType.PROBLEM)
            ListenerManagement.delete_embedding_by_source_signal.send(self.data.get('problem_id'))
            return True

        def one(self, with_valid=False):
            """Return the problem payload (id + content)."""
            if with_valid:
                self.is_valid(raise_exception=True)
            # NOTE(review): serialized with ProblemInstanceSerializer (id/content
            # only) while get_response_body_api documents the full row — confirm
            # which shape the API should return.
            return ProblemInstanceSerializer(QuerySet(Problem).get(**{'id': self.data.get('problem_id')})).data

        @staticmethod
        def get_request_params_api():
            """Path parameters: dataset, document, paragraph and problem ids."""
            return [openapi.Parameter(name='dataset_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='数据集id'),
                    openapi.Parameter(name='document_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='文档id')
                , openapi.Parameter(name='paragraph_id',
                                    in_=openapi.IN_PATH,
                                    type=openapi.TYPE_STRING,
                                    required=True,
                                    description='段落id'),
                    openapi.Parameter(name='problem_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='问题id')
                    ]

        @staticmethod
        def get_response_body_api():
            """Response schema for a single problem row."""
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id',
                          'document_id',
                          'create_time', 'update_time'],
                properties={
                    'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                         description="id", default="xx"),
                    'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
                                              description="问题内容", default='问题内容'),
                    'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
                                              default=1),
                    'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
                                               description="点赞数量", default=1),
                    'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
                                                  description="点踩数", default=1),
                    'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
                                                  description="文档id", default='xxx'),
                    'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                                  description="修改时间",
                                                  default="1970-01-01 00:00:00"),
                    'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                                  description="创建时间",
                                                  default="1970-01-01 00:00:00"
                                                  )
                }
            )

View File

@ -0,0 +1,5 @@
-- One row per document, augmented with how many paragraphs it owns.
SELECT
    "document".* ,
    -- Correlated subquery: count of paragraphs linked to this document.
    (SELECT "count"("id") FROM "paragraph" WHERE document_id="document"."id") as "paragraph_count"
FROM
    "document" "document"

View File

@ -0,0 +1,10 @@
-- Problems joined with their per-paragraph usage statistics.
SELECT
    problem."id",
    problem."content",
    problem_paragraph_mapping.hit_num,
    problem_paragraph_mapping.star_num,
    problem_paragraph_mapping.trample_num,
    problem_paragraph_mapping.paragraph_id
FROM
    problem problem
    -- LEFT JOIN keeps problems that have no paragraph mapping yet
    -- (their mapping columns come back NULL).
    LEFT JOIN problem_paragraph_mapping problem_paragraph_mapping ON problem."id" = problem_paragraph_mapping.problem_id

View File

@ -7,5 +7,18 @@ urlpatterns = [
path('dataset', views.Dataset.as_view(), name="dataset"), path('dataset', views.Dataset.as_view(), name="dataset"),
path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"), path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"), path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document') path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
path('dataset/<str:dataset_id>/document/<str:document_id>', views.Document.Operate.as_view(),
name="document_operate"),
path('dataset/document/split', views.Document.Split.as_view(),
name="document_operate"),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph', views.Paragraph.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<int:current_page>/<int:page_size>',
views.Paragraph.Page.as_view(), name='paragraph_page'),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<paragraph_id>',
views.Paragraph.Operate.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem',
views.Problem.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem/<str:problem_id>',
views.Problem.Operate.as_view())
] ]

View File

@ -8,3 +8,5 @@
""" """
from .dataset import * from .dataset import *
from .document import * from .document import *
from .paragraph import *
from .problem import *

View File

@ -26,7 +26,7 @@ class Dataset(APIView):
@swagger_auto_schema(operation_summary="获取数据集列表", @swagger_auto_schema(operation_summary="获取数据集列表",
operation_id="获取数据集列表", operation_id="获取数据集列表",
manual_parameters=DataSetSerializers.Query.get_request_params_api(), manual_parameters=DataSetSerializers.Query.get_request_params_api(),
responses=get_api_response(DataSetSerializers.Query.get_response_body_api())) responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()))
@has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND) @has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
def get(self, request: Request): def get(self, request: Request):
d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)}) d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)})
@ -36,19 +36,21 @@ class Dataset(APIView):
@action(methods=['POST'], detail=False) @action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="创建数据集", @swagger_auto_schema(operation_summary="创建数据集",
operation_id="创建数据集", operation_id="创建数据集",
request_body=DataSetSerializers.Create.get_request_body_api()) request_body=DataSetSerializers.Create.get_request_body_api(),
responses=get_api_response(DataSetSerializers.Create.get_response_body_api()))
@has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND) @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
def post(self, request: Request): def post(self, request: Request):
s = DataSetSerializers.Create(data=request.data) s = DataSetSerializers.Create(data=request.data)
if s.is_valid(): s.is_valid(raise_exception=True)
s.save(request.user) return result.success(s.save(request.user))
return result.success("ok")
class Operate(APIView): class Operate(APIView):
authentication_classes = [TokenAuth] authentication_classes = [TokenAuth]
@action(methods="DELETE", detail=False) @action(methods="DELETE", detail=False)
@swagger_auto_schema(operation_summary="删除数据集", operation_id="删除数据集") @swagger_auto_schema(operation_summary="删除数据集", operation_id="删除数据集",
manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
responses=result.get_default_response())
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=keywords.get('dataset_id')), dynamic_tag=keywords.get('dataset_id')),
lambda r, k: Permission(group=Group.DATASET, operate=Operate.DELETE, lambda r, k: Permission(group=Group.DATASET, operate=Operate.DELETE,
@ -59,6 +61,7 @@ class Dataset(APIView):
@action(methods="GET", detail=False) @action(methods="GET", detail=False)
@swagger_auto_schema(operation_summary="查询数据集详情根据数据集id", operation_id="查询数据集详情根据数据集id", @swagger_auto_schema(operation_summary="查询数据集详情根据数据集id", operation_id="查询数据集详情根据数据集id",
manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
responses=get_api_response(DataSetSerializers.Operate.get_response_body_api())) responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()))
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE, @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=keywords.get('dataset_id'))) dynamic_tag=keywords.get('dataset_id')))
@ -67,6 +70,7 @@ class Dataset(APIView):
@action(methods="PUT", detail=False) @action(methods="PUT", detail=False)
@swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息", @swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息",
manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
request_body=DataSetSerializers.Operate.get_request_body_api(), request_body=DataSetSerializers.Operate.get_request_body_api(),
responses=get_api_response(DataSetSerializers.Operate.get_response_body_api())) responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()))
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
@ -84,8 +88,10 @@ class Dataset(APIView):
manual_parameters=get_page_request_params( manual_parameters=get_page_request_params(
DataSetSerializers.Query.get_request_params_api()), DataSetSerializers.Query.get_request_params_api()),
responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api())) responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api()))
@has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND) @has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
def get(self, request: Request, current_page, page_size): def get(self, request: Request, current_page, page_size):
d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)}) d = DataSetSerializers.Query(
data={'name': request.query_params.get('name', None), 'desc': request.query_params.get("desc", None),
'user_id': str(request.user.id)})
d.is_valid() d.is_valid()
return result.success(d.page(current_page, page_size)) return result.success(d.page(current_page, page_size))

View File

@ -9,13 +9,15 @@
from drf_yasg.utils import swagger_auto_schema from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.parsers import MultiPartParser
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.views import Request from rest_framework.views import Request
from common.auth import TokenAuth, has_permissions from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import Permission, Group, Operate, PermissionConstants from common.constants.permission_constants import Permission, Group, Operate, PermissionConstants
from common.response import result from common.response import result
from dataset.serializers.dataset_serializers import CreateDocumentSerializers from common.util.common import query_params_to_single_dict
from dataset.serializers.document_serializers import DocumentSerializers
class Document(APIView): class Document(APIView):
@ -24,28 +26,102 @@ class Document(APIView):
@action(methods=['POST'], detail=False) @action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="创建文档", @swagger_auto_schema(operation_summary="创建文档",
operation_id="创建文档", operation_id="创建文档",
request_body=CreateDocumentSerializers().get_request_body_api(), request_body=DocumentSerializers.Create.get_request_body_api(),
manual_parameters=CreateDocumentSerializers().get_request_params_api()) manual_parameters=DocumentSerializers.Create.get_request_params_api(),
@has_permissions(PermissionConstants.DATASET_CREATE) responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()))
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def post(self, request: Request, dataset_id: str): def post(self, request: Request, dataset_id: str):
d = CreateDocumentSerializers(data=request.data) return result.success(
if d.is_valid(dataset_id=dataset_id): DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(request.data, with_valid=True))
d.save(dataset_id)
return result.success("ok")
class DocumentDetails(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False) @action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取文档详情", @swagger_auto_schema(operation_summary="文档列表",
operation_id="获取文档详情", operation_id="文档列表",
request_body=CreateDocumentSerializers().get_request_body_api(), manual_parameters=DocumentSerializers.Query.get_request_params_api(),
manual_parameters=CreateDocumentSerializers().get_request_params_api()) responses=result.get_api_response(DocumentSerializers.Query.get_response_body_api()))
@has_permissions( @has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=k.get('dataset_id'))) lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=k.get('dataset_id')))
def get(self, request: Request, dataset_id: str): def get(self, request: Request, dataset_id: str):
d = CreateDocumentSerializers(data=request.data) d = DocumentSerializers.Query(
if d.is_valid(dataset_id=dataset_id): data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id})
d.save(dataset_id) d.is_valid(raise_exception=True)
return result.success("ok") return result.success(d.list())
class Operate(APIView):
    """Single-document endpoints: fetch, edit, delete."""
    authentication_classes = [TokenAuth]

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="获取文档详情",
                         operation_id="获取文档详情",
                         manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
                         responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()))
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
                                dynamic_tag=k.get('dataset_id')))
    def get(self, request: Request, dataset_id: str, document_id: str):
        """Return one document after validating the path ids."""
        operate = DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id})
        operate.is_valid(raise_exception=True)
        return result.success(operate.one())

    @action(methods=['PUT'], detail=False)
    @swagger_auto_schema(operation_summary="修改文档",
                         operation_id="修改文档",
                         manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
                         request_body=DocumentSerializers.Operate.get_request_body_api(),
                         responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api())
                         )
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def put(self, request: Request, dataset_id: str, document_id: str):
        """Update a document; validation is delegated via with_valid=True."""
        return result.success(
            DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).edit(
                request.data,
                with_valid=True))

    @action(methods=['DELETE'], detail=False)
    @swagger_auto_schema(operation_summary="删除文档",
                         operation_id="删除文档",
                         manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
                         responses=result.get_default_response())
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def delete(self, request: Request, dataset_id: str, document_id: str):
        """Delete a document after validating the path ids."""
        operate = DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id})
        operate.is_valid(raise_exception=True)
        return result.success(operate.delete())
class Split(APIView):
    """Upload files and split them into fragments via the Split serializer."""
    # NOTE(review): unlike every sibling view class there is no
    # `authentication_classes = [TokenAuth]` here, so this endpoint appears
    # to be unauthenticated — confirm this is intentional.
    parser_classes = [MultiPartParser]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="分段文档",
                         operation_id="分段文档",
                         manual_parameters=DocumentSerializers.Split.get_request_params_api())
    def post(self, request: Request):
        # 'file' may carry several uploads; 'patterns[]' carries the
        # caller-supplied split patterns.
        ds = DocumentSerializers.Split(
            data={'file': request.FILES.getlist('file'),
                  'patterns': request.data.getlist('patterns[]')})
        ds.is_valid(raise_exception=True)
        return result.success(ds.parse())
class Page(APIView):
    """Paginated document list for one dataset."""
    authentication_classes = [TokenAuth]

    @action(methods=['GET'], detail=False)
    # Fixed: summary/id previously said "获取数据集分页列表" (copied from the
    # dataset Page view) although this endpoint pages documents, and the
    # query parameters were not wrapped with the pagination parameters the
    # sibling Paragraph.Page view adds via result.get_page_request_params.
    @swagger_auto_schema(operation_summary="获取文档分页列表",
                         operation_id="获取文档分页列表",
                         manual_parameters=result.get_page_request_params(
                             DocumentSerializers.Query.get_request_params_api()),
                         responses=result.get_page_api_response(DocumentSerializers.Query.get_response_body_api()))
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
                                dynamic_tag=k.get('dataset_id')))
    def get(self, request: Request, dataset_id: str, current_page, page_size):
        """Validate the filter params, then return one page of documents."""
        d = DocumentSerializers.Query(
            data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id})
        d.is_valid(raise_exception=True)
        return result.success(d.page(current_page, page_size))

View File

@ -0,0 +1,115 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file paragraph_serializers.py
@date2023/10/16 15:51
@desc:
"""
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.views import APIView
from rest_framework.views import Request
from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import Permission, Group, Operate
from common.response import result
from common.util.common import query_params_to_single_dict
from dataset.serializers.paragraph_serializers import ParagraphSerializers
class Paragraph(APIView):
    """Paragraph endpoints: list/create, single-item CRUD, pagination."""
    authentication_classes = [TokenAuth]

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="段落列表",
                         operation_id="段落列表",
                         manual_parameters=ParagraphSerializers.Query.get_request_params_api(),
                         responses=result.get_api_array_response(ParagraphSerializers.Query.get_response_body_api()))
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
                                dynamic_tag=k.get('dataset_id')))
    def get(self, request: Request, dataset_id: str, document_id: str):
        """Non-paginated paragraph list for one document."""
        q = ParagraphSerializers.Query(
            data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
                  'document_id': document_id})
        q.is_valid(raise_exception=True)
        return result.success(q.list())

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="创建段落",
                         operation_id="创建段落",
                         manual_parameters=ParagraphSerializers.Create.get_request_params_api(),
                         request_body=ParagraphSerializers.Create.get_request_body_api(),
                         responses=result.get_api_response(ParagraphSerializers.Query.get_response_body_api()))
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def post(self, request: Request, dataset_id: str, document_id: str):
        """Create a paragraph under the document.

        NOTE(review): unlike the Problem view this does not pass
        with_valid=True to save() — confirm the serializer validates anyway.
        """
        return result.success(
            ParagraphSerializers.Create(data={'dataset_id': dataset_id, 'document_id': document_id}).save(request.data))

    class Operate(APIView):
        """Single-paragraph endpoints: edit, fetch, delete."""
        authentication_classes = [TokenAuth]

        # Fixed: was @action(methods=['UPDATE']) — siblings use ['PUT'] for
        # PUT handlers (see Document.Operate.put).
        @action(methods=['PUT'], detail=False)
        @swagger_auto_schema(operation_summary="修改段落数据",
                             operation_id="修改段落数据",
                             manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
                             request_body=ParagraphSerializers.Operate.get_request_body_api(),
                             responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()))
        @has_permissions(
            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                    dynamic_tag=k.get('dataset_id')))
        def put(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
            """Edit one paragraph after validating the path ids."""
            o = ParagraphSerializers.Operate(
                data={"paragraph_id": paragraph_id, 'dataset_id': dataset_id, 'document_id': document_id})
            o.is_valid(raise_exception=True)
            return result.success(o.edit(request.data))

        # Fixed: was @action(methods=['UPDATE']) on a GET handler.
        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(operation_summary="获取段落详情",
                             operation_id="获取段落详情",
                             manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
                             responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()))
        @has_permissions(
            lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
                                    dynamic_tag=k.get('dataset_id')))
        def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
            """Return one paragraph after validating the path ids."""
            o = ParagraphSerializers.Operate(
                data={"dataset_id": dataset_id, 'document_id': document_id, "paragraph_id": paragraph_id})
            o.is_valid(raise_exception=True)
            return result.success(o.one())

        @action(methods=['DELETE'], detail=False)
        @swagger_auto_schema(operation_summary="删除段落",
                             operation_id="删除段落",
                             manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
                             responses=result.get_default_response())
        @has_permissions(
            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                    dynamic_tag=k.get('dataset_id')))
        def delete(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
            """Delete one paragraph after validating the path ids."""
            o = ParagraphSerializers.Operate(
                data={"dataset_id": dataset_id, 'document_id': document_id, "paragraph_id": paragraph_id})
            o.is_valid(raise_exception=True)
            return result.success(o.delete())

    class Page(APIView):
        """Paginated paragraph list for one document."""
        authentication_classes = [TokenAuth]

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(operation_summary="分页获取段落列表",
                             operation_id="分页获取段落列表",
                             manual_parameters=result.get_page_request_params(
                                 ParagraphSerializers.Query.get_request_params_api()),
                             responses=result.get_page_api_response(ParagraphSerializers.Query.get_response_body_api()))
        @has_permissions(
            lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
                                    dynamic_tag=k.get('dataset_id')))
        def get(self, request: Request, dataset_id: str, document_id: str, current_page, page_size):
            """Validate the filter params, then return one page of paragraphs."""
            d = ParagraphSerializers.Query(
                data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
                      'document_id': document_id})
            d.is_valid(raise_exception=True)
            return result.success(d.page(current_page, page_size))

View File

@ -0,0 +1,65 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file problem.py
@date2023/10/23 13:54
@desc:
"""
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.views import APIView
from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import Permission, Group, Operate
from common.response import result
from dataset.serializers.problem_serializers import ProblemSerializers
class Problem(APIView):
    """Paragraph-scoped problem endpoints: create, list, delete."""
    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="添加关联问题",
                         operation_id="添加段落关联问题",
                         manual_parameters=ProblemSerializers.Create.get_request_params_api(),
                         request_body=ProblemSerializers.Create.get_request_body_api(),
                         responses=result.get_api_response(ProblemSerializers.Operate.get_response_body_api()))
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def post(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
        """Attach a new problem to the paragraph (serializer validates)."""
        return result.success(ProblemSerializers.Create(
            data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).save(
            request.data, with_valid=True))

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="获取段落问题列表",
                         operation_id="获取段落问题列表",
                         manual_parameters=ProblemSerializers.Query.get_request_params_api(),
                         responses=result.get_api_array_response(ProblemSerializers.Operate.get_response_body_api()))
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
                                dynamic_tag=k.get('dataset_id')))
    def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
        """List every problem attached to the paragraph."""
        return result.success(ProblemSerializers.Query(
            data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).list(
            with_valid=True))

    class Operate(APIView):
        """Single-problem endpoints."""
        authentication_classes = [TokenAuth]

        @action(methods=['DELETE'], detail=False)
        # Fixed: manual_parameters previously used
        # ProblemSerializers.Query.get_request_params_api(), which omits the
        # problem_id path parameter this endpoint requires; the Operate
        # serializer's parameter list includes it.
        @swagger_auto_schema(operation_summary="删除段落问题",
                             operation_id="删除段落问题",
                             manual_parameters=ProblemSerializers.Operate.get_request_params_api(),
                             responses=result.get_default_response())
        @has_permissions(
            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                    dynamic_tag=k.get('dataset_id')))
        def delete(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str):
            """Delete one problem; delete(with_valid=True) validates the ids."""
            o = ProblemSerializers.Operate(
                data={'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id,
                      'problem_id': problem_id})
            return result.success(o.delete(with_valid=True))

View File

@ -1,4 +1,4 @@
# Generated by Django 4.1.10 on 2023-10-09 06:33 # Generated by Django 4.1.10 on 2023-10-24 12:13
import common.field.vector_field import common.field.vector_field
from django.db import migrations, models from django.db import migrations, models
@ -20,8 +20,11 @@ class Migration(migrations.Migration):
('id', models.CharField(max_length=128, primary_key=True, serialize=False, verbose_name='主键id')), ('id', models.CharField(max_length=128, primary_key=True, serialize=False, verbose_name='主键id')),
('source_id', models.CharField(max_length=128, verbose_name='资源id')), ('source_id', models.CharField(max_length=128, verbose_name='资源id')),
('source_type', models.CharField(choices=[('0', '问题'), ('1', '段落')], default='0', max_length=1, verbose_name='资源类型')), ('source_type', models.CharField(choices=[('0', '问题'), ('1', '段落')], default='0', max_length=1, verbose_name='资源类型')),
('is_active', models.BooleanField(default=True, max_length=1, verbose_name='是否可用')),
('embedding', common.field.vector_field.VectorField(verbose_name='向量')), ('embedding', common.field.vector_field.VectorField(verbose_name='向量')),
('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='数据集关联')), ('dataset', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='文档关联')),
('document', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document', verbose_name='文档关联')),
('paragraph', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph', verbose_name='段落关联')),
], ],
options={ options={
'db_table': 'embedding', 'db_table': 'embedding',

View File

@ -9,7 +9,7 @@
from django.db import models from django.db import models
from common.field.vector_field import VectorField from common.field.vector_field import VectorField
from dataset.models.data_set import DataSet from dataset.models.data_set import Document, Paragraph, DataSet
class SourceType(models.TextChoices): class SourceType(models.TextChoices):
@ -26,7 +26,13 @@ class Embedding(models.Model):
source_type = models.CharField(verbose_name='资源类型', max_length=1, choices=SourceType.choices, source_type = models.CharField(verbose_name='资源类型', max_length=1, choices=SourceType.choices,
default=SourceType.PROBLEM) default=SourceType.PROBLEM)
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, verbose_name="数据集关联") is_active = models.BooleanField(verbose_name="是否可用", max_length=1, default=True)
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, verbose_name="文档关联", db_constraint=False)
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, verbose_name="文档关联", db_constraint=False)
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, verbose_name="段落关联", db_constraint=False)
embedding = VectorField(verbose_name="向量") embedding = VectorField(verbose_name="向量")

View File

@ -0,0 +1,117 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_vector.py
@date2023/10/18 19:16
@desc:
"""
from abc import ABC, abstractmethod
from typing import List, Dict
from langchain.embeddings import HuggingFaceEmbeddings
from common.config.embedding_config import EmbeddingModel
from embedding.models import SourceType
class BaseVectorStore(ABC):
    """Abstract contract every vector-store backend must implement."""

    # Process-wide flag: set once the backing store is known to exist so the
    # existence check runs at most once per process.
    vector_exists = False

    @abstractmethod
    def vector_is_create(self) -> bool:
        """Return True when the backing vector store already exists."""
        pass

    @abstractmethod
    def vector_create(self):
        """Create the backing vector store."""
        pass

    def save_pre_handler(self):
        """Ensure the vector store exists before any insert.

        :return: True
        """
        if not BaseVectorStore.vector_exists:
            if not self.vector_is_create():
                self.vector_create()
            BaseVectorStore.vector_exists = True
        return True

    def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
             is_active: bool,
             embedding=None):
        """Embed one text and persist it.

        :param text: raw text to vectorize
        :param source_type: kind of resource the text came from
        :param dataset_id: owning dataset id
        :param document_id: owning document id
        :param paragraph_id: owning paragraph id
        :param source_id: id of the source row
        :param is_active: whether the vector is enabled
        :param embedding: embedding model; defaults to the shared model
        :return: True
        """
        if embedding is None:
            embedding = EmbeddingModel.get_embedding_model()
        self.save_pre_handler()
        self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)
        # Fixed: the docstring promised a bool but the method returned None
        # (batch_save already returns True).
        return True

    def batch_save(self, data_list: List[Dict], embedding=None):
        """Embed and persist many rows.

        :param data_list: list of row dicts (text plus ownership ids)
        :param embedding: embedding model; defaults to the shared model
        :return: True
        """
        if embedding is None:
            embedding = EmbeddingModel.get_embedding_model()
        self.save_pre_handler()
        self._batch_save(data_list, embedding)
        return True

    @abstractmethod
    def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
              is_active: bool,
              embedding: HuggingFaceEmbeddings):
        """Backend hook: persist one embedded text."""
        pass

    @abstractmethod
    def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
        """Backend hook: persist many embedded texts."""
        pass

    @abstractmethod
    def search(self, query_text, dataset_id_list: list[str], is_active: bool, embedding: HuggingFaceEmbeddings):
        """Similarity search across the given datasets."""
        pass

    @abstractmethod
    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        """Patch all vectors belonging to a paragraph."""
        pass

    @abstractmethod
    def update_by_source_id(self, source_id: str, instance: Dict):
        """Patch all vectors belonging to a source row."""
        pass

    @abstractmethod
    def delete_by_dataset_id(self, dataset_id: str):
        """Delete all vectors of a dataset."""
        pass

    @abstractmethod
    def delete_by_document_id(self, document_id: str):
        """Delete all vectors of a document."""
        pass

    @abstractmethod
    def delete_by_source_id(self, source_id: str, source_type: str):
        """Delete the vectors of one source row."""
        pass

    @abstractmethod
    def delete_by_paragraph_id(self, paragraph_id: str):
        """Delete all vectors of a paragraph."""
        pass

View File

@ -0,0 +1,79 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file pg_vector.py
@date2023/10/19 15:28
@desc:
"""
import uuid
from typing import Dict, List
from django.db.models import QuerySet
from langchain.embeddings import HuggingFaceEmbeddings
from embedding.models import Embedding, SourceType
from embedding.vector.base_vector import BaseVectorStore
class PGVector(BaseVectorStore):
    """pgvector-backed store: rows live in the ``embedding`` table via the ORM."""

    def vector_is_create(self) -> bool:
        # The table is created by Django migrations at project start-up,
        # so it is always reported as existing.
        return True

    def vector_create(self):
        # Nothing to do: migrations own the schema.
        return True

    def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
              is_active: bool,
              embedding: HuggingFaceEmbeddings):
        """Embed one text and insert a single row."""
        text_embedding = embedding.embed_query(text)
        # Fixed: use a distinct name for the ORM row instead of shadowing the
        # `embedding` model parameter.
        record = Embedding(id=uuid.uuid1(),
                           dataset_id=dataset_id,
                           document_id=document_id,
                           is_active=is_active,
                           paragraph_id=paragraph_id,
                           source_id=source_id,
                           embedding=text_embedding,
                           source_type=source_type,
                           )
        record.save()
        return True

    def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
        """Embed and bulk-insert many rows; no-op on an empty list."""
        if text_list:
            # Hoisted the empty-list guard so embed_documents([]) is never
            # called; enumerate replaces range(0, len(...)).
            embeddings = embedding.embed_documents([row.get('text') for row in text_list])
            QuerySet(Embedding).bulk_create([
                Embedding(id=uuid.uuid1(),
                          document_id=row.get('document_id'),
                          paragraph_id=row.get('paragraph_id'),
                          dataset_id=row.get('dataset_id'),
                          is_active=row.get('is_active', True),
                          source_id=row.get('source_id'),
                          source_type=row.get('source_type'),
                          embedding=embeddings[index])
                for index, row in enumerate(text_list)])
        return True

    def search(self, query_text, dataset_id_list: list[str], is_active: bool, embedding: HuggingFaceEmbeddings):
        # TODO: similarity search is not implemented yet.
        pass

    def update_by_source_id(self, source_id: str, instance: Dict):
        """Patch every row belonging to the given source."""
        QuerySet(Embedding).filter(source_id=source_id).update(**instance)

    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        """Patch every row belonging to the given paragraph."""
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance)

    def delete_by_dataset_id(self, dataset_id: str):
        QuerySet(Embedding).filter(dataset_id=dataset_id).delete()
        # Fixed: return True for consistency with the other delete methods.
        return True

    def delete_by_document_id(self, document_id: str):
        QuerySet(Embedding).filter(document_id=document_id).delete()
        return True

    def delete_by_source_id(self, source_id: str, source_type: str):
        QuerySet(Embedding).filter(source_id=source_id, source_type=source_type).delete()
        return True

    def delete_by_paragraph_id(self, paragraph_id: str):
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()
        return True

View File

@ -1,4 +1,4 @@
# Generated by Django 4.1.10 on 2023-10-09 06:33 # Generated by Django 4.1.10 on 2023-10-24 12:13
import django.contrib.postgres.fields import django.contrib.postgres.fields
from django.db import migrations, models from django.db import migrations, models

View File

@ -13,7 +13,7 @@ import os
import re import re
from importlib import import_module from importlib import import_module
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
import torch.backends
import yaml import yaml
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@ -88,7 +88,14 @@ class Config(dict):
"EMAIL_HOST": "", "EMAIL_HOST": "",
"EMAIL_PORT": 465, "EMAIL_PORT": 465,
"EMAIL_HOST_USER": "", "EMAIL_HOST_USER": "",
"EMAIL_HOST_PASSWORD": "" "EMAIL_HOST_PASSWORD": "",
# 向量模型
"EMBEDDING_MODEL_NAME": "shibing624/text2vec-base-chinese",
"EMBEDDING_DEVICE": "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu",
"EMBEDDING_MODEL_PATH": os.path.join(PROJECT_DIR, 'models'),
# 向量库配置
"VECTOR_STORE_NAME": 'pg_vector'
} }
def get_db_setting(self) -> dict: def get_db_setting(self) -> dict:
@ -120,6 +127,8 @@ class ConfigManager:
def __init__(self, root_path=None): def __init__(self, root_path=None):
self.root_path = root_path self.root_path = root_path
self.config = self.config_class() self.config = self.config_class()
for key in self.config_class.defaults:
self.config[key] = self.config_class.defaults[key]
def from_mapping(self, *mapping, **kwargs): def from_mapping(self, *mapping, **kwargs):
"""Updates the config like :meth:`update` ignoring items with non-upper """Updates the config like :meth:`update` ignoring items with non-upper

View File

@ -100,6 +100,11 @@ LOGGING = {
'level': LOG_LEVEL, 'level': LOG_LEVEL,
'propagate': False, 'propagate': False,
}, },
'sqlalchemy': {
'handlers': ['console', 'file', 'syslog'],
'level': LOG_LEVEL,
'propagate': False,
},
'django.db.backends': { 'django.db.backends': {
'handlers': ['console', 'file', 'syslog'], 'handlers': ['console', 'file', 'syslog'],
'propagate': False, 'propagate': False,

View File

@ -14,3 +14,12 @@ from django.core.wsgi import get_wsgi_application
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'smartdoc.settings') os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'smartdoc.settings')
application = get_wsgi_application() application = get_wsgi_application()
def post_handler():
from common.event.listener_manage import ListenerManagement
ListenerManagement().run()
ListenerManagement.init_embedding_model_signal.send()
post_handler()

View File

@ -4,3 +4,4 @@ from django.apps import AppConfig
class UsersConfig(AppConfig): class UsersConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField' default_auto_field = 'django.db.models.BigAutoField'
name = 'users' name = 'users'

View File

@ -1,4 +1,4 @@
# Generated by Django 4.1.10 on 2023-10-09 06:33 # Generated by Django 4.1.10 on 2023-10-24 12:13
from django.db import migrations, models from django.db import migrations, models
import uuid import uuid

View File

@ -22,7 +22,7 @@ from common.constants.permission_constants import PermissionConstants, CompareCo
from common.response import result from common.response import result
from smartdoc.settings import JWT_AUTH from smartdoc.settings import JWT_AUTH
from users.models.user import User as UserModel from users.models.user import User as UserModel
from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, UserSerializer, CheckCodeSerializer, \ from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, CheckCodeSerializer, \
RePasswordSerializer, \ RePasswordSerializer, \
SendEmailSerializer, UserProfile SendEmailSerializer, UserProfile

View File

@ -17,6 +17,12 @@ psycopg2-binary = "2.9.7"
jieba = "^0.42.1" jieba = "^0.42.1"
diskcache = "^5.6.3" diskcache = "^5.6.3"
pillow = "9.5.0" pillow = "9.5.0"
filetype = "^1.2.0"
chardet = "^5.2.0"
torch = "^2.1.0"
sentence-transformers = "^2.2.2"
blinker = "^1.6.3"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]