feat: 【知识库】docx支持图片上传 #69 (#267)

This commit is contained in:
shaohuzhang1 2024-04-26 18:03:02 +08:00 committed by GitHub
parent d34ebe5971
commit 1f916a5c3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 90 additions and 17 deletions

View File

@ -16,5 +16,5 @@ class BaseSplitHandle(ABC):
pass pass
@abstractmethod @abstractmethod
def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer): def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
pass pass

View File

@ -8,14 +8,17 @@
""" """
import io import io
import re import re
import traceback
import uuid
from typing import List from typing import List
from docx import Document from docx import Document, ImagePart
from docx.table import Table from docx.table import Table
from docx.text.paragraph import Paragraph from docx.text.paragraph import Paragraph
from common.handle.base_split_handle import BaseSplitHandle from common.handle.base_split_handle import BaseSplitHandle
from common.util.split_model import SplitModel from common.util.split_model import SplitModel
from dataset.models import Image
default_pattern_list = [re.compile('(?<=^)# .*|(?<=\\n)# .*'), default_pattern_list = [re.compile('(?<=^)# .*|(?<=\\n)# .*'),
re.compile('(?<=\\n)(?<!#)## (?!#).*|(?<=^)(?<!#)## (?!#).*'), re.compile('(?<=\\n)(?<!#)## (?!#).*|(?<=^)(?<!#)## (?!#).*'),
@ -25,28 +28,86 @@ default_pattern_list = [re.compile('(?<=^)# .*|(?<=\\n)# .*'),
re.compile("(?<=\\n)(?<!#)###### (?!#).*|(?<=^)(?<!#)###### (?!#).*")] re.compile("(?<=\\n)(?<!#)###### (?!#).*|(?<=^)(?<!#)###### (?!#).*")]
def image_to_mode(image, doc: Document, images_list, get_image_id):
for img_id in image.xpath('.//a:blip/@r:embed'): # 获取图片id
part = doc.part.related_parts[img_id] # 根据图片id获取对应的图片
if isinstance(part, ImagePart):
image_uuid = get_image_id(img_id)
if len([i for i in images_list if i.id == image_uuid]) == 0:
image = Image(id=image_uuid, image=part.blob, image_name=part.filename)
images_list.append(image)
return f'![](/api/image/{image_uuid})'
def get_paragraph_element_txt(paragraph_element, doc: Document, images_list, get_image_id):
try:
images = paragraph_element.xpath(".//pic:pic")
if len(images) > 0:
return "".join(
[item for item in [image_to_mode(image, doc, images_list, get_image_id) for image in images] if
item is not None])
elif paragraph_element.text is not None:
return paragraph_element.text
return ""
except Exception as e:
print(e)
return ""
def get_paragraph_txt(paragraph: Paragraph, doc: Document, images_list, get_image_id):
try:
return "".join([get_paragraph_element_txt(e, doc, images_list, get_image_id) for e in paragraph._element])
except Exception as e:
return ""
def get_cell_text(cell, doc: Document, images_list, get_image_id):
try:
return "".join(
[get_paragraph_txt(paragraph, doc, images_list, get_image_id) for paragraph in cell.paragraphs]).replace(
"\n", '</br>')
except Exception as e:
return ""
def get_image_id_func():
image_map = {}
def get_image_id(image_id):
_v = image_map.get(image_id)
if _v is None:
image_map[image_id] = uuid.uuid1()
return image_map.get(image_id)
return _v
return get_image_id
class DocSplitHandle(BaseSplitHandle): class DocSplitHandle(BaseSplitHandle):
@staticmethod @staticmethod
def paragraph_to_md(paragraph): def paragraph_to_md(paragraph: Paragraph, doc: Document, images_list, get_image_id):
try: try:
psn = paragraph.style.name psn = paragraph.style.name
if psn.startswith('Heading'): if psn.startswith('Heading'):
return "".join(["#" for i in range(int(psn.replace("Heading ", '')))]) + " " + paragraph.text return "".join(["#" for i in range(int(psn.replace("Heading ", '')))]) + " " + paragraph.text
except Exception as e: except Exception as e:
return paragraph.text return paragraph.text
return paragraph.text return get_paragraph_txt(paragraph, doc, images_list, get_image_id)
@staticmethod @staticmethod
def table_to_md(table): def table_to_md(table, doc: Document, images_list, get_image_id):
rows = table.rows rows = table.rows
# 创建 Markdown 格式的表格 # 创建 Markdown 格式的表格
md_table = '| ' + ' | '.join([cell.text.replace("\n", '</br>') for cell in rows[0].cells]) + ' |\n' md_table = '| ' + ' | '.join(
[get_cell_text(cell, doc, images_list, get_image_id) for cell in rows[0].cells]) + ' |\n'
md_table += '| ' + ' | '.join(['---' for i in range(len(rows[0].cells))]) + ' |\n' md_table += '| ' + ' | '.join(['---' for i in range(len(rows[0].cells))]) + ' |\n'
for row in rows[1:]: for row in rows[1:]:
md_table += '| ' + ' | '.join([cell.text.replace("\n", '</br>') for cell in row.cells]) + ' |\n' md_table += '| ' + ' | '.join(
[get_cell_text(cell, doc, images_list, get_image_id) for cell in row.cells]) + ' |\n'
return md_table return md_table
def to_md(self, doc): def to_md(self, doc, images_list, get_image_id):
elements = [] elements = []
for element in doc.element.body: for element in doc.element.body:
if element.tag.endswith('tbl'): if element.tag.endswith('tbl'):
@ -57,21 +118,29 @@ class DocSplitHandle(BaseSplitHandle):
# 处理段落 # 处理段落
paragraph = Paragraph(element, doc) paragraph = Paragraph(element, doc)
elements.append(paragraph) elements.append(paragraph)
return "\n".join( return "\n".join(
[self.paragraph_to_md(element) if isinstance(element, Paragraph) else self.table_to_md(element) for element [self.paragraph_to_md(element, doc, images_list, get_image_id) if isinstance(element,
Paragraph) else self.table_to_md(
element,
doc,
images_list, get_image_id)
for element
in elements]) in elements])
def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer): def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
try: try:
image_list = []
buffer = get_buffer(file) buffer = get_buffer(file)
doc = Document(io.BytesIO(buffer)) doc = Document(io.BytesIO(buffer))
content = self.to_md(doc) content = self.to_md(doc, image_list, get_image_id_func())
if len(image_list) > 0:
save_image(image_list)
if pattern_list is not None and len(pattern_list) > 0: if pattern_list is not None and len(pattern_list) > 0:
split_model = SplitModel(pattern_list, with_filter, limit) split_model = SplitModel(pattern_list, with_filter, limit)
else: else:
split_model = SplitModel(default_pattern_list, with_filter=with_filter, limit=limit) split_model = SplitModel(default_pattern_list, with_filter=with_filter, limit=limit)
except BaseException as e: except BaseException as e:
traceback.print_exception(e)
return {'name': file.name, return {'name': file.name,
'content': []} 'content': []}
return {'name': file.name, return {'name': file.name,

View File

@ -30,7 +30,7 @@ def number_to_text(pdf_document, page_number):
class PdfSplitHandle(BaseSplitHandle): class PdfSplitHandle(BaseSplitHandle):
def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer): def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer,save_image):
try: try:
buffer = get_buffer(file) buffer = get_buffer(file)
pdf_document = fitz.open(file.name, buffer) pdf_document = fitz.open(file.name, buffer)

View File

@ -34,7 +34,7 @@ class TextSplitHandle(BaseSplitHandle):
return True return True
return False return False
def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer): def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
buffer = get_buffer(file) buffer = get_buffer(file)
if pattern_list is not None and len(pattern_list) > 0: if pattern_list is not None and len(pattern_list) > 0:
split_model = SplitModel(pattern_list, with_filter, limit) split_model = SplitModel(pattern_list, with_filter, limit)

View File

@ -33,7 +33,7 @@ from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from common.util.fork import Fork from common.util.fork import Fork
from common.util.split_model import get_split_model from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
@ -627,9 +627,13 @@ default_split_handle = TextSplitHandle()
split_handles = [DocSplitHandle(), PdfSplitHandle(), default_split_handle] split_handles = [DocSplitHandle(), PdfSplitHandle(), default_split_handle]
def save_image(image_list):
QuerySet(Image).bulk_create(image_list)
def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int): def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):
get_buffer = FileBufferHandle().get_buffer get_buffer = FileBufferHandle().get_buffer
for split_handle in split_handles: for split_handle in split_handles:
if split_handle.support(file, get_buffer): if split_handle.support(file, get_buffer):
return split_handle.handle(file, pattern_list, with_filter, limit, get_buffer) return split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image)
return default_split_handle.handle(file, pattern_list, with_filter, limit, get_buffer) return default_split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image)