feat: add endpoint to download source files with updated parameter handling

This commit is contained in:
CaptainB 2025-06-03 20:01:11 +08:00
parent 642920e916
commit 56a9e69912
6 changed files with 125 additions and 13 deletions

View File

@ -393,6 +393,11 @@ class PermissionConstants(Enum):
role_list=[RoleConstants.ADMIN, RoleConstants.USER], role_list=[RoleConstants.ADMIN, RoleConstants.USER],
parent_group=[WorkspaceGroup.KNOWLEDGE, UserGroup.KNOWLEDGE] parent_group=[WorkspaceGroup.KNOWLEDGE, UserGroup.KNOWLEDGE]
) )
KNOWLEDGE_DOCUMENT_DOWNLOAD_RAW = Permission(
group=Group.KNOWLEDGE_DOCUMENT, operate=Operate.EXPORT,
role_list=[RoleConstants.ADMIN, RoleConstants.USER],
parent_group=[WorkspaceGroup.KNOWLEDGE, UserGroup.KNOWLEDGE]
)
KNOWLEDGE_DOCUMENT_GENERATE = Permission( KNOWLEDGE_DOCUMENT_GENERATE = Permission(
group=Group.KNOWLEDGE_DOCUMENT, operate=Operate.GENERATE, group=Group.KNOWLEDGE_DOCUMENT, operate=Operate.GENERATE,
role_list=[RoleConstants.ADMIN, RoleConstants.USER], role_list=[RoleConstants.ADMIN, RoleConstants.USER],

View File

@ -503,3 +503,35 @@ class DocumentMigrateAPI(APIMixin):
@staticmethod @staticmethod
def get_request(): def get_request():
return DocumentMigrateSerializer return DocumentMigrateSerializer
class DocumentDownloadSourceAPI(APIMixin):
    """OpenAPI schema definition for the document source-file download endpoint."""

    @staticmethod
    def get_parameters():
        """Return the three required path parameters (workspace / knowledge / document ids)."""
        path_params = [
            ("workspace_id", "工作空间id"),
            ("knowledge_id", "知识库id"),
            ("document_id", "文档id"),
        ]
        # All parameters share the same shape: required string path segments.
        return [
            OpenApiParameter(
                name=param_name,
                description=param_description,
                type=OpenApiTypes.STR,
                location='path',
                required=True,
            )
            for param_name, param_description in path_params
        ]

    @staticmethod
    def get_response():
        """Return the serializer describing the default response envelope."""
        return DefaultResultSerializer

View File

@ -220,6 +220,10 @@ class FileSourceType(models.TextChoices):
KNOWLEDGE = "KNOWLEDGE" KNOWLEDGE = "KNOWLEDGE"
# 应用 跟随应用被删除而被删除 source_id 为应用id # 应用 跟随应用被删除而被删除 source_id 为应用id
APPLICATION = "APPLICATION" APPLICATION = "APPLICATION"
# Tool: deleted together with its tool; source_id is the tool id
TOOL = "TOOL"
# Document: source_id is the document id
DOCUMENT = "DOCUMENT"
# 临时30分钟 数据30分钟后被清理 source_id 为TEMPORARY_30_MINUTE # 临时30分钟 数据30分钟后被清理 source_id 为TEMPORARY_30_MINUTE
TEMPORARY_30_MINUTE = "TEMPORARY_30_MINUTE" TEMPORARY_30_MINUTE = "TEMPORARY_30_MINUTE"
# 临时120分钟 数据120分钟后被清理 source_id为TEMPORARY_100_MINUTE # 临时120分钟 数据120分钟后被清理 source_id为TEMPORARY_100_MINUTE

View File

@ -12,7 +12,7 @@ import uuid_utils.compat as uuid
from celery_once import AlreadyQueued from celery_once import AlreadyQueued
from django.core import validators from django.core import validators
from django.db import transaction, models from django.db import transaction, models
from django.db.models import QuerySet, Model from django.db.models import QuerySet
from django.db.models.functions import Substr, Reverse from django.db.models.functions import Substr, Reverse
from django.http import HttpResponse from django.http import HttpResponse
from django.utils.translation import gettext_lazy as _, gettext, get_language, to_locale from django.utils.translation import gettext_lazy as _, gettext, get_language, to_locale
@ -43,7 +43,7 @@ from common.utils.common import post, get_file_content, bulk_create_in_batches,
from common.utils.fork import Fork from common.utils.fork import Fork
from common.utils.split_model import get_split_model, flat_map from common.utils.split_model import get_split_model, flat_map
from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \ from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \
TaskType, File TaskType, File, FileSourceType
from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer, \ from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer, \
get_embedding_model_id_by_knowledge_id, MetaSerializer, write_image, zip_dir get_embedding_model_id_by_knowledge_id, MetaSerializer, write_image, zip_dir
from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer, \ from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer, \
@ -54,6 +54,7 @@ from knowledge.task.embedding import embedding_by_document, delete_embedding_by_
from knowledge.task.generate import generate_related_by_document_id from knowledge.task.generate import generate_related_by_document_id
from knowledge.task.sync import sync_web_document from knowledge.task.sync import sync_web_document
from maxkb.const import PROJECT_DIR from maxkb.const import PROJECT_DIR
from models_provider.models import Model
default_split_handle = TextSplitHandle() default_split_handle = TextSplitHandle()
split_handles = [ split_handles = [
@ -87,6 +88,7 @@ class BatchCancelInstanceSerializer(serializers.Serializer):
class DocumentInstanceSerializer(serializers.Serializer): class DocumentInstanceSerializer(serializers.Serializer):
name = serializers.CharField(required=True, label=_('document name'), max_length=128, min_length=1) name = serializers.CharField(required=True, label=_('document name'), max_length=128, min_length=1)
paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True) paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
source_file_id = serializers.UUIDField(required=True, label=_('source file id'))
class CancelInstanceSerializer(serializers.Serializer): class CancelInstanceSerializer(serializers.Serializer):
@ -545,6 +547,9 @@ class DocumentSerializers(serializers.Serializer):
response.write(zip_buffer.getvalue()) response.write(zip_buffer.getvalue())
return response return response
def download_source_file(self):
    """Return the document's original uploaded file as a download response.

    Looks up the ``File`` row stored with ``source_type=DOCUMENT`` and
    ``source_id=<document_id>`` (the linkage created by ``Batch.link_file``
    when the document was imported) and streams its bytes back as an
    attachment.

    Raises:
        AppApiException: 500 when no source file is linked to the document,
            instead of the previous stub behavior of silently returning None.
    """
    document_id = self.data.get('document_id')
    source_file = QuerySet(File).filter(
        source_id=document_id, source_type=FileSourceType.DOCUMENT
    ).first()
    if source_file is None:
        raise AppApiException(500, _('File does not exist'))
    # Generic binary content type; the browser decides based on the filename.
    response = HttpResponse(source_file.get_bytes(), content_type='application/octet-stream')
    response['Content-Disposition'] = f'attachment; filename="{source_file.file_name}"'
    return response
def one(self, with_valid=False): def one(self, with_valid=False):
if with_valid: if with_valid:
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
@ -626,8 +631,6 @@ class DocumentSerializers(serializers.Serializer):
embedding_model = QuerySet(Model).filter(id=embedding_model_id).first() embedding_model = QuerySet(Model).filter(id=embedding_model_id).first()
if embedding_model is None: if embedding_model is None:
raise AppApiException(500, _('Model does not exist')) raise AppApiException(500, _('Model does not exist'))
if embedding_model.permission_type == 'PRIVATE' and knowledge_user_id != embedding_model.user_id:
raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}")
document_id = self.data.get("document_id") document_id = self.data.get("document_id")
ListenerManagement.update_status( ListenerManagement.update_status(
QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.PENDING QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.PENDING
@ -859,6 +862,8 @@ class DocumentSerializers(serializers.Serializer):
for file in save_image_list: for file in save_image_list:
file_bytes = file.meta.pop('content') file_bytes = file.meta.pop('content')
file.meta['knowledge_id'] = self.data.get('knowledge_id') file.meta['knowledge_id'] = self.data.get('knowledge_id')
file.source_type = FileSourceType.KNOWLEDGE
file.source_id = self.data.get('knowledge_id')
file.save(file_bytes) file.save(file_bytes)
class Split(serializers.Serializer): class Split(serializers.Serializer):
@ -901,19 +906,39 @@ class DocumentSerializers(serializers.Serializer):
for file in save_image_list: for file in save_image_list:
file_bytes = file.meta.pop('content') file_bytes = file.meta.pop('content')
file.meta['knowledge_id'] = self.data.get('knowledge_id') file.meta['knowledge_id'] = self.data.get('knowledge_id')
file.source_type = FileSourceType.KNOWLEDGE
file.source_id = self.data.get('knowledge_id')
file.save(file_bytes) file.save(file_bytes)
def file_to_paragraph(self, file, pattern_list: List, with_filter: bool, limit: int): def file_to_paragraph(self, file, pattern_list: List, with_filter: bool, limit: int):
# 保存源文件
file_id = uuid.uuid7()
raw_file = File(
id=file_id,
file_name=file.name,
file_size=file.size,
source_type=FileSourceType.KNOWLEDGE,
source_id=self.data.get('knowledge_id'),
)
raw_file.save(file.read())
file.seek(0)
get_buffer = FileBufferHandle().get_buffer get_buffer = FileBufferHandle().get_buffer
for split_handle in split_handles: for split_handle in split_handles:
if split_handle.support(file, get_buffer): if split_handle.support(file, get_buffer):
result = split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, self.save_image) result = split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, self.save_image)
if isinstance(result, list): if isinstance(result, list):
for item in result:
item['source_file_id'] = file_id
return result return result
result['source_file_id'] = file_id
return [result] return [result]
result = default_split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, self.save_image) result = default_split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, self.save_image)
if isinstance(result, list): if isinstance(result, list):
for item in result:
item['source_file_id'] = file_id
return result return result
result['source_file_id'] = file_id
return [result] return [result]
class SplitPattern(serializers.Serializer): class SplitPattern(serializers.Serializer):
@ -937,14 +962,37 @@ class DocumentSerializers(serializers.Serializer):
] ]
class Batch(serializers.Serializer): class Batch(serializers.Serializer):
workspace_id = serializers.UUIDField(required=True, label=_('workspace id')) workspace_id = serializers.CharField(required=True, label=_('workspace id'))
knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id')) knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
@staticmethod @staticmethod
def post_embedding(document_list, knowledge_id): def link_file(source_file_id, document_id):
source_file = QuerySet(File).filter(id=source_file_id).first()
if source_file:
# 获取原始文件内容
file_content = source_file.get_bytes()
# 创建新文件对象,复制原始文件的重要属性
new_file = File(
id=uuid.uuid7(),
file_name=source_file.file_name,
file_size=source_file.file_size,
source_type=FileSourceType.DOCUMENT,
source_id=document_id, # 更新为当前知识库ID
meta=source_file.meta.copy() if source_file.meta else {}
)
# 保存文件内容和元数据
new_file.save(file_content)
@staticmethod
def post_embedding(document_list, knowledge_id, workspace_id):
for document_dict in document_list: for document_dict in document_list:
DocumentSerializers.Operate( DocumentSerializers.Operate(data={
data={'knowledge_id': knowledge_id, 'document_id': document_dict.get('id')}).refresh() 'knowledge_id': knowledge_id,
'document_id': document_dict.get('id'),
'workspace_id': workspace_id
}).refresh()
return document_list return document_list
@post(post_function=post_embedding) @post(post_function=post_embedding)
@ -953,15 +1001,21 @@ class DocumentSerializers(serializers.Serializer):
if with_valid: if with_valid:
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
DocumentInstanceSerializer(many=True, data=instance_list).is_valid(raise_exception=True) DocumentInstanceSerializer(many=True, data=instance_list).is_valid(raise_exception=True)
workspace_id = self.data.get("workspace_id")
knowledge_id = self.data.get("knowledge_id") knowledge_id = self.data.get("knowledge_id")
document_model_list = [] document_model_list = []
paragraph_model_list = [] paragraph_model_list = []
problem_paragraph_object_list = [] problem_paragraph_object_list = []
# 插入文档 # 插入文档
for document in instance_list: for document in instance_list:
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(knowledge_id, document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(
document) knowledge_id,
document_model_list.append(document_paragraph_dict_model.get('document')) document
)
# 保存文档和文件的关系
document_instance = document_paragraph_dict_model.get('document')
self.link_file(document['source_file_id'], document_instance.id)
document_model_list.append(document_instance)
for paragraph in document_paragraph_dict_model.get('paragraph_model_list'): for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
paragraph_model_list.append(paragraph) paragraph_model_list.append(paragraph)
for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'): for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
@ -992,7 +1046,7 @@ class DocumentSerializers(serializers.Serializer):
os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql') os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')
), ),
with_search_one=False with_search_one=False
), knowledge_id ), knowledge_id, workspace_id
@staticmethod @staticmethod
def _batch_sync(document_id_list: List[str]): def _batch_sync(document_id_list: List[str]):

View File

@ -38,6 +38,7 @@ urlpatterns = [
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/batch_cancel_task', views.DocumentView.BatchCancelTask.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/batch_cancel_task', views.DocumentView.BatchCancelTask.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/export', views.DocumentView.Export.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/export', views.DocumentView.Export.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/export_zip', views.DocumentView.ExportZip.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/export_zip', views.DocumentView.ExportZip.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/download_source_file', views.DocumentView.DownloadSourceFile.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph', views.ParagraphView.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph', views.ParagraphView.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/batch_delete', views.ParagraphView.BatchDelete.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/batch_delete', views.ParagraphView.BatchDelete.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/batch_generate_related', views.ParagraphView.BatchGenerateRelated.as_view()), path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/batch_generate_related', views.ParagraphView.BatchGenerateRelated.as_view()),

View File

@ -12,7 +12,7 @@ from knowledge.api.document import DocumentSplitAPI, DocumentBatchAPI, DocumentB
DocumentReadAPI, DocumentEditAPI, DocumentDeleteAPI, TableDocumentCreateAPI, QaDocumentCreateAPI, \ DocumentReadAPI, DocumentEditAPI, DocumentDeleteAPI, TableDocumentCreateAPI, QaDocumentCreateAPI, \
WebDocumentCreateAPI, CancelTaskAPI, BatchCancelTaskAPI, SyncWebAPI, RefreshAPI, BatchEditHitHandlingAPI, \ WebDocumentCreateAPI, CancelTaskAPI, BatchCancelTaskAPI, SyncWebAPI, RefreshAPI, BatchEditHitHandlingAPI, \
DocumentTreeReadAPI, DocumentSplitPatternAPI, BatchRefreshAPI, BatchGenerateRelatedAPI, TemplateExportAPI, \ DocumentTreeReadAPI, DocumentSplitPatternAPI, BatchRefreshAPI, BatchGenerateRelatedAPI, TemplateExportAPI, \
DocumentExportAPI, DocumentMigrateAPI DocumentExportAPI, DocumentMigrateAPI, DocumentDownloadSourceAPI
from knowledge.serializers.document import DocumentSerializers from knowledge.serializers.document import DocumentSerializers
@ -417,6 +417,22 @@ class DocumentView(APIView):
'workspace_id': workspace_id, 'document_id': document_id, 'knowledge_id': knowledge_id 'workspace_id': workspace_id, 'document_id': document_id, 'knowledge_id': knowledge_id
}).export_zip() }).export_zip()
class DownloadSourceFile(APIView):
    """GET endpoint that returns the original uploaded source file of a document."""
    authentication_classes = [TokenAuth]

    @extend_schema(
        summary=_('Download source file'),
        operation_id=_('Download source file'),  # type: ignore
        parameters=DocumentDownloadSourceAPI.get_parameters(),
        responses=DocumentDownloadSourceAPI.get_response(),
        tags=[_('Knowledge Base/Documentation')]  # type: ignore
    )
    # Requires the dedicated EXPORT-style download permission scoped to the workspace.
    @has_permissions(PermissionConstants.KNOWLEDGE_DOCUMENT_DOWNLOAD_RAW.get_workspace_permission())
    def get(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
        # Delegates to the serializer layer, which resolves the linked File row.
        return DocumentSerializers.Operate(data={
            'workspace_id': workspace_id, 'document_id': document_id, 'knowledge_id': knowledge_id
        }).download_source_file()
class Migrate(APIView): class Migrate(APIView):
authentication_classes = [TokenAuth] authentication_classes = [TokenAuth]