feat: web数据集

This commit is contained in:
shaohuzhang1 2023-12-29 18:02:23 +08:00
parent 89a74dd862
commit 64c8cc6b39
13 changed files with 417 additions and 74 deletions

View File

@ -0,0 +1,20 @@
# Generated by Django 4.1.10 on 2023-12-28 15:16
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
    """Re-point application.ChatRecord.dataset at dataset.DataSet, nullable with SET_NULL."""

    # Needs the dataset meta/type columns and the initial application schema in place.
    dependencies = [
        ('dataset', '0002_dataset_meta_dataset_type_document_meta_and_more'),
        ('application', '0001_initial'),
    ]

    operations = [
        migrations.AlterField(
            model_name='chatrecord',
            name='dataset',
            # Nullable so chat records survive deletion of their dataset
            # (the FK is cleared to NULL instead of cascading).
            field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='dataset.dataset', verbose_name='数据集'),
        ),
    ]

View File

@ -18,6 +18,8 @@ from common.config.embedding_config import VectorStore, EmbeddingModel
from common.db.search import native_search, get_dynamics_model from common.db.search import native_search, get_dynamics_model
from common.event.common import poxy from common.event.common import poxy
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from common.util.fork import ForkManage
from common.util.lock import try_lock, un_lock
from dataset.models import Paragraph, Status, Document from dataset.models import Paragraph, Status, Document
from embedding.models import SourceType from embedding.models import SourceType
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
@ -26,6 +28,14 @@ max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb") max_kb = logging.getLogger("max_kb")
class SyncWebDatasetArgs:
    """Argument bundle for a web-dataset sync job dispatched over the signal bus."""

    def __init__(self, lock_key: str, url: str, selector: str, handler):
        """Capture the sync parameters.

        :param lock_key: suffix appended to the 'sync_web_dataset' lock name
                         (the dataset id as a string at the call site)
        :param url:      root URL of the site to crawl
        :param selector: space-separated CSS selectors, may be None/blank
        :param handler:  callback invoked for every fetched page
        """
        self.handler = handler
        self.selector = selector
        self.url = url
        self.lock_key = lock_key
class ListenerManagement: class ListenerManagement:
embedding_by_problem_signal = signal("embedding_by_problem") embedding_by_problem_signal = signal("embedding_by_problem")
embedding_by_paragraph_signal = signal("embedding_by_paragraph") embedding_by_paragraph_signal = signal("embedding_by_paragraph")
@ -38,6 +48,7 @@ class ListenerManagement:
enable_embedding_by_paragraph_signal = signal('enable_embedding_by_paragraph') enable_embedding_by_paragraph_signal = signal('enable_embedding_by_paragraph')
disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph') disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph')
init_embedding_model_signal = signal('init_embedding_model') init_embedding_model_signal = signal('init_embedding_model')
sync_web_dataset_signal = signal('sync_web_dataset')
@staticmethod @staticmethod
def embedding_by_problem(args): def embedding_by_problem(args):
@ -144,6 +155,18 @@ class ListenerManagement:
def enable_embedding_by_paragraph(paragraph_id): def enable_embedding_by_paragraph(paragraph_id):
VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True}) VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True})
@staticmethod
@poxy
def sync_web_dataset(args: SyncWebDatasetArgs):
    """Crawl the site described by ``args`` and feed each fetched page to ``args.handler``.

    A per-dataset lock ('sync_web_dataset' + lock_key) makes concurrent sync
    requests for the same dataset no-ops instead of double-crawling.
    Errors are logged to the max_kb_error logger and never raised.
    """
    lock_key = 'sync_web_dataset' + args.lock_key
    if try_lock(lock_key):
        try:
            # selector is optional at the API layer (allow_null/allow_blank), so a
            # missing selector must not crash .split(); an empty selector list is
            # the same form already used when Fork validates dataset meta.
            selector_list = args.selector.split(" ") if args.selector else []
            # Crawl depth 2, starting from an empty set of visited URLs.
            ForkManage(args.url, selector_list).fork(2, set(), args.handler)
        except Exception as e:
            logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
        finally:
            un_lock(lock_key)
@staticmethod @staticmethod
@poxy @poxy
def init_embedding_model(ags): def init_embedding_model(ags):
@ -175,3 +198,5 @@ class ListenerManagement:
ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph) ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
# 初始化向量化模型 # 初始化向量化模型
ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model) ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)
# 同步web站点知识库
ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)

View File

@ -1,10 +1,19 @@
import copy
import logging
import re import re
import traceback
from functools import reduce from functools import reduce
from typing import List, Set from typing import List, Set
import requests import requests
import html2text as ht import html2text as ht
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from urllib.parse import urljoin from urllib.parse import urljoin, urlparse, ParseResult
class ChildLink:
    """A discovered link: the target URL plus a snapshot of the <a> tag it came from."""

    def __init__(self, url, tag):
        self.url = url
        # Deep-copied so the stored tag is independent of the BeautifulSoup
        # tree it came from — presumably to survive later tree mutation; confirm.
        self.tag = copy.deepcopy(tag)
class ForkManage: class ForkManage:
@ -13,30 +22,34 @@ class ForkManage:
self.selector_list = selector_list self.selector_list = selector_list
def fork(self, level: int, exclude_link_url: Set[str], fork_handler): def fork(self, level: int, exclude_link_url: Set[str], fork_handler):
self.fork_child(self.base_url, self.selector_list, level, exclude_link_url, fork_handler) self.fork_child(ChildLink(self.base_url, None), self.selector_list, level, exclude_link_url, fork_handler)
@staticmethod @staticmethod
def fork_child(base_url: str, selector_list: List[str], level: int, exclude_link_url: Set[str], fork_handler): def fork_child(child_link: ChildLink, selector_list: List[str], level: int, exclude_link_url: Set[str],
fork_handler):
if level < 0: if level < 0:
return return
response = Fork(base_url, selector_list).fork() else:
fork_handler(base_url, response) child_url = child_link.url[:-1] if child_link.url.endswith('/') else child_link.url
exclude_link_url.add(child_url)
response = Fork(child_link.url, selector_list).fork()
fork_handler(child_link, response)
for child_link in response.child_link_list: for child_link in response.child_link_list:
if not exclude_link_url.__contains__(child_link): child_url = child_link.url[:-1] if child_link.url.endswith('/') else child_link.url
exclude_link_url.add(child_link) if not exclude_link_url.__contains__(child_url):
ForkManage.fork_child(child_link, selector_list, level - 1, exclude_link_url, fork_handler) ForkManage.fork_child(child_link, selector_list, level - 1, exclude_link_url, fork_handler)
class Fork: class Fork:
class Response: class Response:
def __init__(self, html_content: str, child_link_list: List[str], status, message: str): def __init__(self, content: str, child_link_list: List[ChildLink], status, message: str):
self.html_content = html_content self.content = content
self.child_link_list = child_link_list self.child_link_list = child_link_list
self.status = status self.status = status
self.message = message self.message = message
@staticmethod @staticmethod
def success(html_content: str, child_link_list: List[str]): def success(html_content: str, child_link_list: List[ChildLink]):
return Fork.Response(html_content, child_link_list, 200, '') return Fork.Response(html_content, child_link_list, 200, '')
@staticmethod @staticmethod
@ -45,13 +58,17 @@ class Fork:
def __init__(self, base_fork_url: str, selector_list: List[str]): def __init__(self, base_fork_url: str, selector_list: List[str]):
self.base_fork_url = urljoin(base_fork_url if base_fork_url.endswith("/") else base_fork_url + '/', '.') self.base_fork_url = urljoin(base_fork_url if base_fork_url.endswith("/") else base_fork_url + '/', '.')
self.base_fork_url = base_fork_url self.base_fork_url = self.base_fork_url[:-1]
self.selector_list = selector_list self.selector_list = selector_list
self.urlparse = urlparse(self.base_fork_url)
self.base_url = ParseResult(scheme=self.urlparse.scheme, netloc=self.urlparse.netloc, path='', params='',
query='',
fragment='').geturl()
def get_child_link_list(self, bf: BeautifulSoup): def get_child_link_list(self, bf: BeautifulSoup):
pattern = "^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*|" + self.base_fork_url pattern = "^((?!(http:|https:|tel:/|#|mailto:|javascript:))|" + self.base_fork_url + ").*"
link_list = bf.find_all(name='a', href=re.compile(pattern)) link_list = bf.find_all(name='a', href=re.compile(pattern))
result = [self.parse_href(link.get('href')) for link in link_list] result = [ChildLink(link.get('href'), link) for link in link_list]
return result return result
def get_content_html(self, bf: BeautifulSoup): def get_content_html(self, bf: BeautifulSoup):
@ -65,23 +82,34 @@ class Fork:
f = bf.find_all(**params) f = bf.find_all(**params)
return "\n".join([str(row) for row in f]) return "\n".join([str(row) for row in f])
def parse_href(self, href: str): @staticmethod
if href.startswith(self.base_fork_url[:-1] if self.base_fork_url.endswith('/') else self.base_fork_url): def reset_url(tag, field, base_fork_url):
return href field_value: str = tag[field]
if field_value.startswith("/"):
result = urlparse(base_fork_url)
result_url = ParseResult(scheme=result.scheme, netloc=result.netloc, path=field_value, params='', query='',
fragment='').geturl()
else: else:
return urljoin(self.base_fork_url + '/' + (href if href.endswith('/') else href + '/'), ".") result_url = urljoin(
base_fork_url + '/' + (field_value if field_value.endswith('/') else field_value + '/'),
".")
result_url = result_url[:-1] if result_url.endswith('/') else result_url
tag[field] = result_url
def reset_beautiful_soup(self, bf: BeautifulSoup): def reset_beautiful_soup(self, bf: BeautifulSoup):
href_list = bf.find_all(href=re.compile('^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*')) reset_config_list = [
for h in href_list: {
h['href'] = urljoin( 'field': 'href',
self.base_fork_url + '/' + (h['href'] if h['href'].endswith('/') else h['href'] + '/'), },
".")[:-1] {
src_list = bf.find_all(src=re.compile('^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*')) 'field': 'src',
for s in src_list: }
s['src'] = urljoin( ]
self.base_fork_url + '/' + (s['src'] if s['src'].endswith('/') else s['src'] + '/'), for reset_config in reset_config_list:
".")[:-1] field = reset_config.get('field')
tag_list = bf.find_all(**{field: re.compile('^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*')})
for tag in tag_list:
self.reset_url(tag, field, self.base_fork_url)
return bf return bf
@staticmethod @staticmethod
@ -92,11 +120,14 @@ class Fork:
def fork(self): def fork(self):
try: try:
logging.getLogger("max_kb").info(f'fork:{self.base_fork_url}')
response = requests.get(self.base_fork_url) response = requests.get(self.base_fork_url)
if response.status_code != 200: if response.status_code != 200:
raise Exception(response.status_code) logging.getLogger("max_kb").error(f"url: {self.base_fork_url} code:{response.status_code}")
return Fork.Response.error(f"url: {self.base_fork_url} code:{response.status_code}")
bf = self.get_beautiful_soup(response) bf = self.get_beautiful_soup(response)
except Exception as e: except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
return Fork.Response.error(str(e)) return Fork.Response.error(str(e))
bf = self.reset_beautiful_soup(bf) bf = self.reset_beautiful_soup(bf)
link_list = self.get_child_link_list(bf) link_list = self.get_child_link_list(bf)
@ -106,7 +137,6 @@ class Fork:
def handler(base_url, response: Fork.Response): def handler(base_url, response: Fork.Response):
print(base_url, response.status) print(base_url.url, base_url.tag.text if base_url.tag else None, response.content)
# ForkManage('https://bbs.fit2cloud.com/c/de/6', ['.md-content']).fork(3, set(), handler)
ForkManage('https://dataease.io/docs/v2/', ['.md-content']).fork(3, set(), handler)

View File

@ -277,11 +277,11 @@ def filter_special_char(content: str):
class SplitModel: class SplitModel:
def __init__(self, content_level_pattern, with_filter=True, limit=1024): def __init__(self, content_level_pattern, with_filter=True, limit=4096):
self.content_level_pattern = content_level_pattern self.content_level_pattern = content_level_pattern
self.with_filter = with_filter self.with_filter = with_filter
if limit is None or limit > 1024: if limit is None or limit > 4096:
limit = 1024 limit = 4096
if limit < 50: if limit < 50:
limit = 50 limit = 50
self.limit = limit self.limit = limit
@ -337,13 +337,12 @@ class SplitModel:
default_split_pattern = { default_split_pattern = {
'md': [re.compile("^# .*"), re.compile('(?<!#)## (?!#).*'), re.compile("(?<!#)### (?!#).*"), 'md': [re.compile("^# .*"), re.compile('(?<!#)## (?!#).*'), re.compile("(?<!#)### (?!#).*"),
re.compile("(?<!#)####(?!#).*"), re.compile("(?<!#)#####(?!#).*"), re.compile("(?<!#)####(?!#).*"), re.compile("(?<!#)#####(?!#).*"),
re.compile("(?<!#)######(?!#).*"), re.compile("(?<!#)######(?!#).*")],
re.compile("(?<! )- .*")],
'default': [re.compile("(?<!\n)\n\n.+")] 'default': [re.compile("(?<!\n)\n\n.+")]
} }
def get_split_model(filename: str, with_filter: bool, limit: int): def get_split_model(filename: str, with_filter: bool = False, limit: int = 4096):
""" """
根据文件名称获取分段模型 根据文件名称获取分段模型
:param limit: 每段大小 :param limit: 每段大小

View File

@ -0,0 +1,38 @@
# Generated by Django 4.1.10 on 2023-12-28 15:16
from django.db import migrations, models
class Migration(migrations.Migration):
    """Add meta/type columns to DataSet and Document; relax the dataset name field."""

    dependencies = [
        ('dataset', '0001_initial'),
    ]

    operations = [
        # Free-form metadata (e.g. source_url/selector for web datasets).
        migrations.AddField(
            model_name='dataset',
            name='meta',
            field=models.JSONField(default=dict, verbose_name='元数据'),
        ),
        # '0' = generic dataset, '1' = web-site dataset.
        migrations.AddField(
            model_name='dataset',
            name='type',
            field=models.CharField(choices=[('0', '通用类型'), ('1', 'web站点类型')], default='0', max_length=1, verbose_name='类型'),
        ),
        migrations.AddField(
            model_name='document',
            name='meta',
            field=models.JSONField(default=dict, verbose_name='元数据'),
        ),
        migrations.AddField(
            model_name='document',
            name='type',
            field=models.CharField(choices=[('0', '通用类型'), ('1', 'web站点类型')], default='0', max_length=1, verbose_name='类型'),
        ),
        migrations.AlterField(
            model_name='dataset',
            name='name',
            field=models.CharField(max_length=150, verbose_name='数据集名称'),
        ),
    ]

View File

@ -0,0 +1,18 @@
# Generated by Django 4.1.10 on 2023-12-29 17:49
from django.db import migrations, models
class Migration(migrations.Migration):
    """Grow Paragraph.content from 1024 to 4096 chars (matches the new split-model limit)."""

    dependencies = [
        ('dataset', '0002_dataset_meta_dataset_type_document_meta_and_more'),
    ]

    operations = [
        migrations.AlterField(
            model_name='paragraph',
            name='content',
            field=models.CharField(max_length=4096, verbose_name='段落内容'),
        ),
    ]

View File

@ -21,6 +21,12 @@ class Status(models.TextChoices):
error = 2, '导入失败' error = 2, '导入失败'
class Type(models.TextChoices):
    """Origin type shared by DataSet and Document; persisted as a 1-char string ('0'/'1')."""
    # generic (manually built) dataset/document
    base = 0, '通用类型'
    # crawled from a web site
    web = 1, 'web站点类型'
class DataSet(AppModelMixin): class DataSet(AppModelMixin):
""" """
数据集表 数据集表
@ -29,6 +35,10 @@ class DataSet(AppModelMixin):
name = models.CharField(max_length=150, verbose_name="数据集名称") name = models.CharField(max_length=150, verbose_name="数据集名称")
desc = models.CharField(max_length=256, verbose_name="数据库描述") desc = models.CharField(max_length=256, verbose_name="数据库描述")
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户") user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户")
type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
default=Type.base)
meta = models.JSONField(verbose_name="元数据", default=dict)
class Meta: class Meta:
db_table = "dataset" db_table = "dataset"
@ -46,6 +56,11 @@ class Document(AppModelMixin):
default=Status.embedding) default=Status.embedding)
is_active = models.BooleanField(default=True) is_active = models.BooleanField(default=True)
type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
default=Type.base)
meta = models.JSONField(verbose_name="元数据", default=dict)
class Meta: class Meta:
db_table = "document" db_table = "document"
@ -57,7 +72,7 @@ class Paragraph(AppModelMixin):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False) document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False)
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING) dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
content = models.CharField(max_length=1024, verbose_name="段落内容") content = models.CharField(max_length=4096, verbose_name="段落内容")
title = models.CharField(max_length=256, verbose_name="标题", default="") title = models.CharField(max_length=256, verbose_name="标题", default="")
hit_num = models.IntegerField(verbose_name="命中数量", default=0) hit_num = models.IntegerField(verbose_name="命中数量", default=0)
star_num = models.IntegerField(verbose_name="点赞数", default=0) star_num = models.IntegerField(verbose_name="点赞数", default=0)

View File

@ -6,11 +6,13 @@
@date2023/9/21 16:14 @date2023/9/21 16:14
@desc: @desc:
""" """
import logging
import os.path import os.path
import traceback
import uuid import uuid
from functools import reduce from functools import reduce
from itertools import groupby
from typing import Dict from typing import Dict
from urllib.parse import urlparse
from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.fields import ArrayField
from django.core import validators from django.core import validators
@ -23,17 +25,18 @@ from application.models import ApplicationDatasetMapping
from common.config.embedding_config import VectorStore, EmbeddingModel from common.config.embedding_config import VectorStore, EmbeddingModel
from common.db.search import get_dynamics_model, native_page_search, native_search from common.db.search import get_dynamics_model, native_page_search, native_search
from common.db.sql_execute import select_list from common.db.sql_execute import select_list
from common.event.listener_manage import ListenerManagement from common.event.listener_manage import ListenerManagement, SyncWebDatasetArgs
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin from common.mixins.api_mixin import ApiMixin
from common.util.common import post from common.util.common import post
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from dataset.models.data_set import DataSet, Document, Paragraph, Problem from common.util.fork import ChildLink, Fork, ForkManage
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
from dataset.serializers.common_serializers import list_paragraph from dataset.serializers.common_serializers import list_paragraph
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from setting.models import AuthOperate from setting.models import AuthOperate
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
from users.models import User
""" """
# __exact 精确等于 like aaa # __exact 精确等于 like aaa
@ -187,30 +190,105 @@ class DataSetSerializers(serializers.ModelSerializer):
return DataSetSerializers.Operate.get_response_body_api() return DataSetSerializers.Operate.get_response_body_api()
class Create(ApiMixin, serializers.Serializer): class Create(ApiMixin, serializers.Serializer):
""" user_id = serializers.UUIDField(required=True)
创建序列化对象
"""
name = serializers.CharField(required=True,
validators=[
validators.MaxLengthValidator(limit_value=20,
message="知识库名称在1-20个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="知识库名称在1-20个字符之间")
])
desc = serializers.CharField(required=True, class CreateBaseSerializers(ApiMixin, serializers.Serializer):
validators=[ """
validators.MaxLengthValidator(limit_value=256, 创建通用数据集序列化对象
message="知识库名称在1-256个字符之间"), """
validators.MinLengthValidator(limit_value=1, name = serializers.CharField(required=True,
message="知识库名称在1-256个字符之间") validators=[
]) validators.MaxLengthValidator(limit_value=20,
message="知识库名称在1-20个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="知识库名称在1-20个字符之间")
])
documents = DocumentInstanceSerializer(required=False, many=True) desc = serializers.CharField(required=True,
validators=[
validators.MaxLengthValidator(limit_value=256,
message="知识库名称在1-256个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="知识库名称在1-256个字符之间")
])
def is_valid(self, *, raise_exception=False): documents = DocumentInstanceSerializer(required=False, many=True)
super().is_valid(raise_exception=True)
return True def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
return True
class CreateWebSerializers(serializers.Serializer):
    """Request serializer for creating a web-site dataset (name/desc/url + optional selector)."""
    name = serializers.CharField(required=True,
                                 validators=[
                                     validators.MaxLengthValidator(limit_value=20,
                                                                   message="知识库名称在1-20个字符之间"),
                                     validators.MinLengthValidator(limit_value=1,
                                                                   message="知识库名称在1-20个字符之间")
                                 ])
    desc = serializers.CharField(required=True,
                                 validators=[
                                     validators.MaxLengthValidator(limit_value=256,
                                                                   message="知识库名称在1-256个字符之间"),
                                     validators.MinLengthValidator(limit_value=1,
                                                                   message="知识库名称在1-256个字符之间")
                                 ])
    # Root URL of the site to crawl.
    url = serializers.CharField(required=True)
    # Space-separated CSS selectors; optional — the crawl must tolerate None/blank.
    selector = serializers.CharField(required=False, allow_null=True, allow_blank=True)

    def is_valid(self, *, raise_exception=False):
        # Always raises on invalid data, regardless of the raise_exception argument.
        super().is_valid(raise_exception=True)
        return True

    @staticmethod
    def get_response_body_api():
        """OpenAPI schema of the created-dataset response body (for swagger docs)."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['id', 'name', 'desc', 'user_id', 'char_length', 'document_count',
                      'update_time', 'create_time', 'document_list'],
            properties={
                'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                     description="id", default="xx"),
                'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
                                       description="名称", default="测试知识库"),
                'desc': openapi.Schema(type=openapi.TYPE_STRING, title="描述",
                                       description="描述", default="测试知识库描述"),
                'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户id",
                                          description="所属用户id", default="user_xxxx"),
                'char_length': openapi.Schema(type=openapi.TYPE_STRING, title="字符数",
                                              description="字符数", default=10),
                'document_count': openapi.Schema(type=openapi.TYPE_STRING, title="文档数量",
                                                 description="文档数量", default=1),
                'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                              description="修改时间",
                                              default="1970-01-01 00:00:00"),
                'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                              description="创建时间",
                                              default="1970-01-01 00:00:00"
                                              ),
                'document_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档列表",
                                                description="文档列表",
                                                items=DocumentSerializers.Operate.get_response_body_api())
            }
        )

    @staticmethod
    def get_request_body_api():
        """OpenAPI schema of the create-web-dataset request body (for swagger docs)."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['name', 'desc', 'url'],
            properties={
                'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
                'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
                'url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url", description="web站点url"),
                'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
            }
        )
@staticmethod @staticmethod
def post_embedding_dataset(document_list, dataset_id): def post_embedding_dataset(document_list, dataset_id):
@ -220,16 +298,21 @@ class DataSetSerializers(serializers.ModelSerializer):
@post(post_function=post_embedding_dataset) @post(post_function=post_embedding_dataset)
@transaction.atomic @transaction.atomic
def save(self, user: User): def save(self, instance: Dict, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
self.CreateBaseSerializers(data=instance).is_valid()
dataset_id = uuid.uuid1() dataset_id = uuid.uuid1()
user_id = self.data.get('user_id')
dataset = DataSet( dataset = DataSet(
**{'id': dataset_id, 'name': self.data.get("name"), 'desc': self.data.get('desc'), 'user': user}) **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id})
document_model_list = [] document_model_list = []
paragraph_model_list = [] paragraph_model_list = []
problem_model_list = [] problem_model_list = []
# 插入文档 # 插入文档
for document in self.data.get('documents') if 'documents' in self.data else []: for document in instance.get('documents') if 'documents' in instance else []:
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id, document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
document) document)
document_model_list.append(document_paragraph_dict_model.get('document')) document_model_list.append(document_paragraph_dict_model.get('document'))
@ -252,6 +335,47 @@ class DataSetSerializers(serializers.ModelSerializer):
'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list( 'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(
with_valid=True)}, dataset_id with_valid=True)}, dataset_id
@staticmethod
def get_last_url_path(url):
parsed_url = urlparse(url)
if parsed_url.path is None or len(parsed_url.path) == 0:
return url
else:
return parsed_url.path.split("/")[-1]
@staticmethod
def get_save_handler(dataset_id, selector):
    """Build the per-page callback used while crawling a web dataset.

    The returned ``handler`` persists each successfully fetched page as a
    Document (content split into paragraphs via the 'web.md' split model)
    under ``dataset_id``. Non-200 fetches are skipped; save errors are
    logged and swallowed so one bad page does not abort the crawl.
    """
    def handler(child_link: ChildLink, response: Fork.Response):
        if response.status == 200:
            try:
                # Document name: the link's anchor text when non-blank, else its URL.
                document_name = child_link.tag.text if child_link.tag is not None and len(
                    child_link.tag.text.strip()) > 0 else child_link.url
                paragraphs = get_split_model('web.md').parse(response.content)
                # Record the crawl provenance in meta and tag the document as web-sourced.
                DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
                    {'name': document_name, 'paragraphs': paragraphs,
                     'meta': {'source_url': child_link.url, 'selector': selector},
                     'type': Type.web}, with_valid=True)
            except Exception as e:
                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
    return handler
def save_web(self, instance: Dict, with_valid=True):
    """Create a web-site dataset and kick off its crawl.

    Validates ``instance`` against CreateWebSerializers, persists the DataSet
    (type=web, crawl source recorded in meta), then emits the
    sync_web_dataset signal so the crawl/import runs via the listener
    (decorated with @poxy — presumably out-of-band; confirm poxy semantics).
    Returns the serialized dataset with an empty document_list, since
    documents are only created as the crawl progresses.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
        self.CreateWebSerializers(data=instance).is_valid(raise_exception=True)
    user_id = self.data.get('user_id')
    dataset_id = uuid.uuid1()
    dataset = DataSet(
        **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
           'type': Type.web, 'meta': {'source_url': instance.get('url'), 'selector': instance.get('selector')}})
    dataset.save()
    # The dataset id doubles as the sync lock key.
    ListenerManagement.sync_web_dataset_signal.send(
        SyncWebDatasetArgs(str(dataset_id), instance.get('url'), instance.get('selector'),
                           self.get_save_handler(dataset_id, instance.get('selector'))))
    return {**DataSetSerializers(dataset).data,
            'document_list': []}
@staticmethod @staticmethod
def get_response_body_api(): def get_response_body_api():
return openapi.Schema( return openapi.Schema(
@ -298,12 +422,43 @@ class DataSetSerializers(serializers.ModelSerializer):
} }
) )
class Edit(serializers.Serializer): class MetaSerializer(serializers.Serializer):
class WebMeta(serializers.Serializer):
    """Meta schema for web datasets: {source_url, selector?}."""
    source_url = serializers.CharField(required=True)
    selector = serializers.CharField(required=False, allow_null=True, allow_blank=True)

    def is_valid(self, *, raise_exception=False):
        # Always raises on invalid field data, regardless of the argument.
        super().is_valid(raise_exception=True)
        source_url = self.data.get('source_url')
        # NOTE: validation performs a live HTTP fetch of source_url (Fork issues
        # requests.get) and rejects the edit when the fetch fails.
        response = Fork(source_url, []).fork()
        if response.status == 500:
            raise AppApiException(500, response.message)
class BaseMeta(serializers.Serializer):
    """Meta schema for generic datasets: no fields, accepts an empty payload."""
    def is_valid(self, *, raise_exception=False):
        # Always raises on invalid data, regardless of the argument.
        super().is_valid(raise_exception=True)
class Edit(serializers.Serializer):
name = serializers.CharField(required=False) name = serializers.CharField(required=False)
desc = serializers.CharField(required=False) desc = serializers.CharField(required=False)
meta = serializers.DictField(required=False)
application_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True)) application_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
@staticmethod
def get_dataset_meta_valid_map():
    """Map each dataset Type to the serializer class validating its ``meta`` payload."""
    meta_serializers = DataSetSerializers.MetaSerializer
    return {
        Type.base: meta_serializers.BaseMeta,
        Type.web: meta_serializers.WebMeta
    }
def is_valid(self, *, dataset: DataSet = None):
    """Validate the edit payload; a supplied ``meta`` is checked against the dataset's type.

    :param dataset: the dataset being edited — required when 'meta' is present,
                    since its ``type`` selects the meta schema (base vs web).
    """
    super().is_valid(raise_exception=True)
    if 'meta' in self.data and self.data.get('meta') is not None:
        dataset_meta_valid_map = self.get_dataset_meta_valid_map()
        # Pick the meta-schema serializer matching the dataset type.
        valid_class = dataset_meta_valid_map.get(dataset.type)
        valid_class(data=self.data.get('meta')).is_valid(raise_exception=True)
class HitTest(ApiMixin, serializers.Serializer): class HitTest(ApiMixin, serializers.Serializer):
id = serializers.CharField(required=True) id = serializers.CharField(required=True)
user_id = serializers.UUIDField(required=False) user_id = serializers.UUIDField(required=False)
@ -392,12 +547,14 @@ class DataSetSerializers(serializers.ModelSerializer):
:return: :return:
""" """
self.is_valid() self.is_valid()
DataSetSerializers.Edit(data=dataset).is_valid(raise_exception=True)
_dataset = QuerySet(DataSet).get(id=self.data.get("id")) _dataset = QuerySet(DataSet).get(id=self.data.get("id"))
DataSetSerializers.Edit(data=dataset).is_valid(dataset=_dataset)
if "name" in dataset: if "name" in dataset:
_dataset.name = dataset.get("name") _dataset.name = dataset.get("name")
if 'desc' in dataset: if 'desc' in dataset:
_dataset.desc = dataset.get("desc") _dataset.desc = dataset.get("desc")
if 'meta' in dataset:
_dataset.meta = dataset.get('meta')
if 'application_id_list' in dataset and dataset.get('application_id_list') is not None: if 'application_id_list' in dataset and dataset.get('application_id_list') is not None:
application_id_list = dataset.get('application_id_list') application_id_list = dataset.get('application_id_list')
# 当前用户可修改关联的知识库列表 # 当前用户可修改关联的知识库列表
@ -429,6 +586,8 @@ class DataSetSerializers(serializers.ModelSerializer):
properties={ properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"), 'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
'meta': openapi.Schema(type=openapi.TYPE_OBJECT, title="知识库元数据",
description="知识库元数据->web:{source_url:xxx,selector:'xxx'},base:{}"),
'application_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="应用id列表", 'application_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="应用id列表",
description="应用id列表", description="应用id列表",
items=openapi.Schema(type=openapi.TYPE_STRING)) items=openapi.Schema(type=openapi.TYPE_STRING))

View File

@ -24,7 +24,7 @@ from common.mixins.api_mixin import ApiMixin
from common.util.common import post from common.util.common import post
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from common.util.split_model import SplitModel, get_split_model from common.util.split_model import SplitModel, get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
@ -243,7 +243,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
'name': instance.get('name'), 'name': instance.get('name'),
'char_length': reduce(lambda x, y: x + y, 'char_length': reduce(lambda x, y: x + y,
[len(p.get('content')) for p in instance.get('paragraphs', [])], [len(p.get('content')) for p in instance.get('paragraphs', [])],
0)}) 0),
'meta': instance.get('meta') if instance.get('meta') is not None else {},
'type': instance.get('type') if instance.get('type') is not None else Type.base})
paragraph_model_dict_list = [ParagraphSerializers.Create( paragraph_model_dict_list = [ParagraphSerializers.Create(
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).get_paragraph_problem_model( data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).get_paragraph_problem_model(

View File

@ -37,7 +37,7 @@ class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
段落实例对象 段落实例对象
""" """
content = serializers.CharField(required=True, validators=[ content = serializers.CharField(required=True, validators=[
validators.MaxLengthValidator(limit_value=1024, validators.MaxLengthValidator(limit_value=4096,
message="段落在1-1024个字符之间"), message="段落在1-1024个字符之间"),
validators.MinLengthValidator(limit_value=1, validators.MinLengthValidator(limit_value=1,
message="段落在1-1024个字符之间"), message="段落在1-1024个字符之间"),

View File

@ -5,6 +5,7 @@ from . import views
app_name = "dataset" app_name = "dataset"
urlpatterns = [ urlpatterns = [
path('dataset', views.Dataset.as_view(), name="dataset"), path('dataset', views.Dataset.as_view(), name="dataset"),
path('dataset/web', views.Dataset.CreateWebDataset.as_view()),
path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"), path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
path('dataset/<str:dataset_id>/application', views.Dataset.Application.as_view()), path('dataset/<str:dataset_id>/application', views.Dataset.Application.as_view()),
path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"), path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),

View File

@ -23,6 +23,21 @@ from dataset.serializers.dataset_serializers import DataSetSerializers
class Dataset(APIView): class Dataset(APIView):
authentication_classes = [TokenAuth] authentication_classes = [TokenAuth]
class CreateWebDataset(APIView):
    """POST /dataset/web — create a dataset whose documents are crawled from a web site."""
    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="创建web站点知识库",
                         operation_id="创建web站点知识库",
                         request_body=DataSetSerializers.Create.CreateWebSerializers.get_request_body_api(),
                         responses=get_api_response(
                             DataSetSerializers.Create.CreateWebSerializers.get_response_body_api()),
                         tags=["知识库"]
                         )
    @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
    def post(self, request: Request):
        # The authenticated user id is injected server-side; the body carries
        # name/desc/url/selector and is validated inside save_web.
        return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save_web(request.data))
class Application(APIView): class Application(APIView):
authentication_classes = [TokenAuth] authentication_classes = [TokenAuth]
@ -58,9 +73,7 @@ class Dataset(APIView):
) )
@has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND) @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
def post(self, request: Request): def post(self, request: Request):
s = DataSetSerializers.Create(data=request.data) return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save(request.data))
s.is_valid(raise_exception=True)
return result.success(s.save(request.user))
class HitTest(APIView): class HitTest(APIView):
authentication_classes = [TokenAuth] authentication_classes = [TokenAuth]

View File

@ -0,0 +1,23 @@
# Generated by Django 4.1.10 on 2023-12-28 15:16
from django.db import migrations, models
class Migration(migrations.Migration):
    """Tighten TeamMemberPermission: enum-backed auth_target_type and UUID target."""

    dependencies = [
        ('setting', '0001_initial'),
    ]

    operations = [
        # Authorization target kind: dataset or application.
        migrations.AlterField(
            model_name='teammemberpermission',
            name='auth_target_type',
            field=models.CharField(choices=[('DATASET', '数据集'), ('APPLICATION', '应用')], default='DATASET', max_length=128, verbose_name='授权目标'),
        ),
        # The id of the targeted dataset/application.
        migrations.AlterField(
            model_name='teammemberpermission',
            name='target',
            field=models.UUIDField(verbose_name='数据集/应用id'),
        ),
    ]