fix: 修复qa知识库导入失败错误 (#536)

This commit is contained in:
shaohuzhang1 2024-05-24 17:59:02 +08:00 committed by GitHub
parent d5b0937015
commit e9a05b1255
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 110 additions and 65 deletions

View File

@ -9,6 +9,37 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
def get_row_value(row, title_row_index_dict, field):
    """Return the cell of *row* that belongs to *field*, or ``None``.

    :param row: indexable sequence of cells for one data row
    :param title_row_index_dict: mapping of field name -> column index
    :param field: field name to look up (e.g. ``'content'``)
    :return: ``row[index]`` when the field is mapped and the row is long
        enough to contain that column; otherwise ``None``
    """
    position = title_row_index_dict.get(field)
    # Unmapped field, or a row too short to reach the mapped column.
    if position is None or position >= len(row):
        return None
    return row[position]
def get_title_row_index_dict(title_row_list):
    """Build a mapping of field name -> column index from a header row.

    Positional defaults are chosen from the number of header columns:

    * 1 column  -> ``content`` is column 0
    * 2 columns -> ``title`` is column 0, ``content`` is column 1
    * otherwise -> ``title``, ``content``, ``problem_list`` are columns 0, 1, 2

    Any header cell whose text starts with one of the recognised prefixes
    ('分段标题' = title, '分段内容' = content, '问题' = problem list) then
    overrides the positional default for that field.

    :param title_row_list: sequence of header cells (first row of the file)
    :return: dict with some of the keys ``'title'``, ``'content'``,
        ``'problem_list'`` mapped to integer column indexes
    """
    column_count = len(title_row_list)
    if column_count == 1:
        title_row_index_dict = {'content': 0}
    elif column_count == 2:
        # The original repeated `== 1` here, which made this branch
        # unreachable and gave two-column files the three-column defaults.
        title_row_index_dict = {'title': 0, 'content': 1}
    else:
        title_row_index_dict = {'title': 0, 'content': 1, 'problem_list': 2}

    for index, title_cell in enumerate(title_row_list):
        # Coerce to str so non-string header cells (None, numbers) are
        # simply not matched instead of raising AttributeError.
        title_row = str(title_cell)
        if title_row.startswith('分段标题'):
            title_row_index_dict['title'] = index
        if title_row.startswith('分段内容'):
            title_row_index_dict['content'] = index
        if title_row.startswith('问题'):
            title_row_index_dict['problem_list'] = index
    return title_row_index_dict
class BaseParseQAHandle(ABC): class BaseParseQAHandle(ABC):
@abstractmethod @abstractmethod
def support(self, file, get_buffer): def support(self, file, get_buffer):

View File

@ -11,7 +11,7 @@ import io
from charset_normalizer import detect from charset_normalizer import detect
from common.handle.base_parse_qa_handle import BaseParseQAHandle from common.handle.base_parse_qa_handle import BaseParseQAHandle, get_title_row_index_dict, get_row_value
def read_csv_standard(file_path): def read_csv_standard(file_path):
@ -32,25 +32,28 @@ class CsvParseQAHandle(BaseParseQAHandle):
def handle(self, file, get_buffer): def handle(self, file, get_buffer):
buffer = get_buffer(file) buffer = get_buffer(file)
reader = csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=detect(buffer)['encoding']))
try: try:
title_row_list = reader.__next__() reader = csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=detect(buffer)['encoding']))
try:
title_row_list = reader.__next__()
except Exception as e:
return [{'name': file.name, 'paragraphs': []}]
if len(title_row_list) == 0:
return [{'name': file.name, 'paragraphs': []}]
title_row_index_dict = get_title_row_index_dict(title_row_list)
paragraph_list = []
for row in reader:
content = get_row_value(row, title_row_index_dict, 'content')
if content is None:
continue
problem = get_row_value(row, title_row_index_dict, 'problem_list')
problem = str(problem) if problem is not None else ''
problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0]
title = get_row_value(row, title_row_index_dict, 'title')
title = str(title) if title is not None else ''
paragraph_list.append({'title': title[0:255],
'content': content[0:4096],
'problem_list': problem_list})
return [{'name': file.name, 'paragraphs': paragraph_list}]
except Exception as e: except Exception as e:
return [] return [{'name': file.name, 'paragraphs': []}]
title_row_index_dict = {'title': 0, 'content': 1, 'problem_list': 2}
for index in range(len(title_row_list)):
title_row = title_row_list[index]
if title_row.startswith('分段标题'):
title_row_index_dict['title'] = index
if title_row.startswith('分段内容'):
title_row_index_dict['content'] = index
if title_row.startswith('问题'):
title_row_index_dict['problem_list'] = index
paragraph_list = []
for row in reader:
problem = row[title_row_index_dict.get('problem_list')]
problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0]
paragraph_list.append({'title': row[title_row_index_dict.get('title')][0:255],
'content': row[title_row_index_dict.get('content')][0:4096],
'problem_list': problem_list})
return [{'name': file.name, 'paragraphs': paragraph_list}]

View File

@ -9,7 +9,7 @@
import xlrd import xlrd
from common.handle.base_parse_qa_handle import BaseParseQAHandle from common.handle.base_parse_qa_handle import BaseParseQAHandle, get_title_row_index_dict, get_row_value
def handle_sheet(file_name, sheet): def handle_sheet(file_name, sheet):
@ -17,22 +17,22 @@ def handle_sheet(file_name, sheet):
try: try:
title_row_list = next(rows) title_row_list = next(rows)
except Exception as e: except Exception as e:
return None return {'name': file_name, 'paragraphs': []}
title_row_index_dict = {'title': 0, 'content': 1, 'problem_list': 2} if len(title_row_list) == 0:
for index in range(len(title_row_list)): return {'name': file_name, 'paragraphs': []}
title_row = str(title_row_list[index]) title_row_index_dict = get_title_row_index_dict(title_row_list)
if title_row.startswith('分段标题'):
title_row_index_dict['title'] = index
if title_row.startswith('分段内容'):
title_row_index_dict['content'] = index
if title_row.startswith('问题'):
title_row_index_dict['problem_list'] = index
paragraph_list = [] paragraph_list = []
for row in rows: for row in rows:
problem = str(row[title_row_index_dict.get('problem_list')]) content = get_row_value(row, title_row_index_dict, 'content')
if content is None:
continue
problem = get_row_value(row, title_row_index_dict, 'problem_list')
problem = str(problem) if problem is not None else ''
problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0] problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0]
paragraph_list.append({'title': str(row[title_row_index_dict.get('title')])[0:255], title = get_row_value(row, title_row_index_dict, 'title')
'content': str(row[title_row_index_dict.get('content')])[0:4096], title = str(title) if title is not None else ''
paragraph_list.append({'title': title[0:255],
'content': content[0:4096],
'problem_list': problem_list}) 'problem_list': problem_list})
return {'name': file_name, 'paragraphs': paragraph_list} return {'name': file_name, 'paragraphs': paragraph_list}
@ -40,16 +40,21 @@ def handle_sheet(file_name, sheet):
class XlsParseQAHandle(BaseParseQAHandle): class XlsParseQAHandle(BaseParseQAHandle):
def support(self, file, get_buffer): def support(self, file, get_buffer):
file_name: str = file.name.lower() file_name: str = file.name.lower()
if file_name.endswith(".xls"): buffer = get_buffer(file)
if file_name.endswith(".xls") and xlrd.inspect_format(content=buffer):
return True return True
return False return False
def handle(self, file, get_buffer): def handle(self, file, get_buffer):
buffer = get_buffer(file) buffer = get_buffer(file)
workbook = xlrd.open_workbook(file_contents=buffer) try:
worksheets = workbook.sheets() workbook = xlrd.open_workbook(file_contents=buffer)
worksheets_size = len(worksheets) worksheets = workbook.sheets()
return [row for row in worksheets_size = len(worksheets)
[handle_sheet(file.name, sheet) if worksheets_size == 1 and sheet.name == 'Sheet1' else handle_sheet( return [row for row in
sheet.name, sheet) for sheet [handle_sheet(file.name,
in worksheets] if row is not None] sheet) if worksheets_size == 1 and sheet.name == 'Sheet1' else handle_sheet(
sheet.name, sheet) for sheet
in worksheets] if row is not None]
except Exception as e:
return [{'name': file.name, 'paragraphs': []}]

View File

@ -10,30 +10,32 @@ import io
import openpyxl import openpyxl
from common.handle.base_parse_qa_handle import BaseParseQAHandle from common.handle.base_parse_qa_handle import BaseParseQAHandle, get_title_row_index_dict, get_row_value
def handle_sheet(file_name, sheet): def handle_sheet(file_name, sheet):
rows = sheet.rows rows = sheet.rows
try: try:
title_row_list = next(rows) title_row_list = next(rows)
title_row_list = [row.value for row in title_row_list]
except Exception as e: except Exception as e:
return None return {'name': file_name, 'paragraphs': []}
title_row_index_dict = {} if len(title_row_list) == 0:
for index in range(len(title_row_list)): return {'name': file_name, 'paragraphs': []}
title_row = str(title_row_list[index].value) title_row_index_dict = get_title_row_index_dict(title_row_list)
if title_row.startswith('分段标题'):
title_row_index_dict['title'] = index
if title_row.startswith('分段内容'):
title_row_index_dict['content'] = index
if title_row.startswith('问题'):
title_row_index_dict['problem_list'] = index
paragraph_list = [] paragraph_list = []
for row in rows: for row in rows:
problem = str(row[title_row_index_dict.get('problem_list')].value) content = get_row_value(row, title_row_index_dict, 'content')
if content is None:
continue
problem = get_row_value(row, title_row_index_dict, 'problem_list')
problem = str(problem.value) if problem is not None else ''
problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0] problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0]
paragraph_list.append({'title': str(row[title_row_index_dict.get('title')].value)[0:255], title = get_row_value(row, title_row_index_dict, 'title')
'content': str(row[title_row_index_dict.get('content')].value)[0:4096], title = str(title.value) if title is not None else ''
content = content.value
paragraph_list.append({'title': title[0:255],
'content': content[0:4096],
'problem_list': problem_list}) 'problem_list': problem_list})
return {'name': file_name, 'paragraphs': paragraph_list} return {'name': file_name, 'paragraphs': paragraph_list}
@ -47,10 +49,14 @@ class XlsxParseQAHandle(BaseParseQAHandle):
def handle(self, file, get_buffer): def handle(self, file, get_buffer):
buffer = get_buffer(file) buffer = get_buffer(file)
workbook = openpyxl.load_workbook(io.BytesIO(buffer)) try:
worksheets = workbook.worksheets workbook = openpyxl.load_workbook(io.BytesIO(buffer))
worksheets_size = len(worksheets) worksheets = workbook.worksheets
return [row for row in worksheets_size = len(worksheets)
[handle_sheet(file.name, sheet) if worksheets_size == 1 and sheet.title == 'Sheet1' else handle_sheet( return [row for row in
sheet.title, sheet) for sheet [handle_sheet(file.name,
in worksheets] if row is not None] sheet) if worksheets_size == 1 and sheet.title == 'Sheet1' else handle_sheet(
sheet.title, sheet) for sheet
in worksheets] if row is not None]
except Exception as e:
return [{'name': file.name, 'paragraphs': []}]

View File

@ -523,8 +523,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
@staticmethod @staticmethod
def parse_qa_file(file): def parse_qa_file(file):
get_buffer = FileBufferHandle().get_buffer
for parse_qa_handle in parse_qa_handle_list: for parse_qa_handle in parse_qa_handle_list:
get_buffer = FileBufferHandle().get_buffer
if parse_qa_handle.support(file, get_buffer): if parse_qa_handle.support(file, get_buffer):
return parse_qa_handle.handle(file, get_buffer) return parse_qa_handle.handle(file, get_buffer)
raise AppApiException(500, '不支持的文件格式') raise AppApiException(500, '不支持的文件格式')