diff --git a/utils/dataset_manager.py b/utils/dataset_manager.py
index 1aa0b87..dd15fa7 100644
--- a/utils/dataset_manager.py
+++ b/utils/dataset_manager.py
@@ -16,6 +16,9 @@ from utils.file_utils import (
     load_processed_files_log, save_processed_files_log,
     remove_file_or_directory
 )
+from utils.excel_csv_processor import (
+    is_excel_file, is_csv_file, process_excel_file, process_csv_file
+)
 
 
 async def download_dataset_files(unique_id: str, files: Dict[str, List[str]]) -> Dict[str, List[str]]:
@@ -43,13 +46,13 @@ async def download_dataset_files(unique_id: str, files: Dict[str, List[str]]) ->
         with zipfile.ZipFile(zip_path, 'r') as zip_ref:
             zip_ref.extractall(extract_dir)
 
-        # Find all extracted txt and md files
+        # Find all extracted txt, md, xlsx, xls, and csv files
         for root, dirs, files in os.walk(extract_dir):
             for file in files:
-                if file.lower().endswith(('.txt', '.md')):
+                if file.lower().endswith(('.txt', '.md', '.xlsx', '.xls', '.csv')):
                     extracted_files.append(os.path.join(root, file))
 
-        print(f"Extracted {len(extracted_files)} txt/md files from {zip_path}")
+        print(f"Extracted {len(extracted_files)} txt/md/xlsx/csv files from {zip_path}")
         return extracted_files
 
     except Exception as e:
@@ -86,6 +89,7 @@ async def download_dataset_files(unique_id: str, files: Dict[str, List[str]]) ->
 
         # Read and combine all files for this key
         combined_content = []
+        pagination_lines = []  # Collect pagination lines from all files
         all_processed_files = []
 
         for file_path in file_list:
@@ -151,14 +155,41 @@ async def download_dataset_files(unique_id: str, files: Dict[str, List[str]]) ->
             # Process all files (extracted from zip or single file)
             for process_file_path in files_to_process:
                 try:
-                    with open(process_file_path, 'r', encoding='utf-8') as f:
-                        content = f.read().strip()
+                    base_filename = os.path.basename(process_file_path)
 
-                    if content:
-                        # Add file content with page separator
-                        base_filename = os.path.basename(process_file_path)
-                        combined_content.append(f"# Page {base_filename}")
-                        combined_content.append(content)
+                    # Check if it's an Excel file
+                    if is_excel_file(process_file_path):
+                        print(f"Processing Excel file: {base_filename}")
+                        document_content, excel_pagination_lines = process_excel_file(process_file_path)
+
+                        if document_content:
+                            combined_content.append(f"# Page {base_filename}")
+                            combined_content.append(document_content)
+
+                        # Collect pagination lines from Excel files
+                        pagination_lines.extend(excel_pagination_lines)
+
+                    # Check if it's a CSV file
+                    elif is_csv_file(process_file_path):
+                        print(f"Processing CSV file: {base_filename}")
+                        document_content, csv_pagination_lines = process_csv_file(process_file_path)
+
+                        if document_content:
+                            combined_content.append(f"# Page {base_filename}")
+                            combined_content.append(document_content)
+
+                        # Collect pagination lines from CSV files
+                        pagination_lines.extend(csv_pagination_lines)
+
+                    # Handle text files (original logic)
+                    else:
+                        with open(process_file_path, 'r', encoding='utf-8') as f:
+                            content = f.read().strip()
+
+                        if content:
+                            # Add file content with page separator
+                            combined_content.append(f"# Page {base_filename}")
+                            combined_content.append(content)
 
                 except Exception as e:
                     print(f"Failed to read file content from {process_file_path}: {str(e)}")
@@ -186,12 +217,24 @@ async def download_dataset_files(unique_id: str, files: Dict[str, List[str]]) ->
                 try:
                     import sys
                     sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'embedding'))
-                    from embedding import split_document_by_pages, embed_document
+                    from embedding import embed_document
 
-                    # Generate pagination
-                    print(f" Generating pagination for {key}")
-                    pages = split_document_by_pages(str(document_file), str(pagination_file))
-                    print(f" Generated {len(pages)} pages")
+                    # Generate pagination file from collected pagination lines
+                    # For Excel/CSV files, use the pagination format we collected
+                    # For text files, fall back to the original pagination generation
+                    if pagination_lines:
+                        print(f" Writing pagination data from Excel/CSV files for {key}")
+                        with open(pagination_file, 'w', encoding='utf-8') as f:
+                            for line in pagination_lines:
+                                if line.strip():
+                                    f.write(f"{line}\n")
+                        print(f" Generated {len(pagination_lines)} pagination lines")
+                    else:
+                        # For text-only files, use the original pagination generation
+                        from embedding import split_document_by_pages
+                        print(f" Generating pagination from text files for {key}")
+                        pages = split_document_by_pages(str(document_file), str(pagination_file))
+                        print(f" Generated {len(pages)} pages")
 
                     # Generate embeddings
                     print(f" Generating embeddings for {key}")
diff --git a/utils/excel_csv_processor.py b/utils/excel_csv_processor.py
new file mode 100644
index 0000000..d04bee8
--- /dev/null
+++ b/utils/excel_csv_processor.py
@@ -0,0 +1,253 @@
+#!/usr/bin/env python3
+"""
+Excel and CSV file processor for converting data to document.txt and pagination.txt formats.
+"""
+
+import os
+import pandas as pd
+from typing import List, Dict, Any, Tuple
+
+
+def read_excel_sheets(file_path: str) -> Dict[str, List[Dict[str, Any]]]:
+    """
+    Read all sheets from an Excel file.
+
+    Args:
+        file_path: Path to the Excel file
+
+    Returns:
+        Dict: Keys are sheet names, values are the row data of each sheet
+    """
+    try:
+        # Read all sheets
+        excel_file = pd.ExcelFile(file_path)
+        sheets_data = {}
+
+        for sheet_name in excel_file.sheet_names:
+            try:
+                # Read the data of each sheet
+                df = pd.read_excel(file_path, sheet_name=sheet_name)
+
+                # Convert to a list of dicts, skipping NaN values
+                sheet_data = []
+                for _, row in df.iterrows():
+                    # Convert NaN values to empty strings
+                    row_dict = {}
+                    for col in df.columns:
+                        value = row[col]
+                        if pd.isna(value):
+                            value = ""
+                        elif isinstance(value, (int, float)):
+                            value = str(value)
+                        else:
+                            value = str(value).strip()
+                        row_dict[str(col)] = value
+
+                    # Only keep non-empty rows
+                    if any(v.strip() for v in row_dict.values()):
+                        sheet_data.append(row_dict)
+
+                sheets_data[sheet_name] = sheet_data
+                print(f"Read Excel sheet '{sheet_name}': {len(sheet_data)} rows")
+
+            except Exception as e:
+                print(f"Failed to read Excel sheet '{sheet_name}': {str(e)}")
+                continue
+
+        return sheets_data
+
+    except Exception as e:
+        print(f"Failed to read Excel file: {str(e)}")
+        return {}
+
+
+def read_csv_file(file_path: str, encoding: str = 'utf-8') -> List[Dict[str, Any]]:
+    """
+    Read a CSV file.
+
+    Args:
+        file_path: Path to the CSV file
+        encoding: File encoding
+
+    Returns:
+        List: CSV row data
+    """
+    try:
+        # Try a series of candidate encodings
+        encodings_to_try = [encoding, 'utf-8', 'gbk', 'gb2312', 'utf-8-sig']
+
+        for enc in encodings_to_try:
+            try:
+                df = pd.read_csv(file_path, encoding=enc)
+                break
+            except UnicodeDecodeError:
+                continue
+        else:
+            # If every encoding fails, fall back to UTF-8 and ignore decode errors
+            # (encoding_errors is the read_csv kwarg for this; requires pandas >= 1.3)
+            df = pd.read_csv(file_path, encoding='utf-8', encoding_errors='ignore')
+
+        # Convert to a list of dicts, skipping NaN values
+        csv_data = []
+        for _, row in df.iterrows():
+            # Convert NaN values to empty strings
+            row_dict = {}
+            for col in df.columns:
+                value = row[col]
+                if pd.isna(value):
+                    value = ""
+                elif isinstance(value, (int, float)):
+                    value = str(value)
+                else:
+                    value = str(value).strip()
+                row_dict[str(col)] = value
+
+            # Only keep non-empty rows
+            if any(v.strip() for v in row_dict.values()):
+                csv_data.append(row_dict)
+
+        print(f"Read CSV file: {len(csv_data)} rows")
+        return csv_data
+
+    except Exception as e:
+        print(f"Failed to read CSV file: {str(e)}")
+        return []
+
+
+def convert_to_markdown_format(data_list: List[Dict[str, Any]], sheet_name: str = "") -> str:
+    """
+    Convert row data to the target markdown format.
+
+    Args:
+        data_list: List of row data
+        sheet_name: Sheet name (optional)
+
+    Returns:
+        str: Markdown-formatted text
+    """
+    if not data_list:
+        return ""
+
+    markdown_content = []
+
+    # Add the sheet title
+    if sheet_name:
+        markdown_content.append(f"# Sheet: {sheet_name}")
+        markdown_content.append("")
+
+    for i, row_data in enumerate(data_list, 1):
+        # Render each row as markdown
+        row_markdown = []
+
+        for key, value in row_data.items():
+            if value and value.strip():  # Only include non-empty values
+                row_markdown.append(f"{key}: {value.strip()}")
+
+        if row_markdown:
+            markdown_content.extend(row_markdown)
+
+            # Add a separator between rows, except after the last one
+            if i < len(data_list):
+                markdown_content.append("---")
+                markdown_content.append("")
+
+    return "\n".join(markdown_content)
+
+
+def convert_to_pagination_format(data_list: List[Dict[str, Any]]) -> List[str]:
+    """
+    Convert row data to key:value;key:value format.
+
+    Args:
+        data_list: List of row data
+
+    Returns:
+        List: Pagination-formatted lines
+    """
+    if not data_list:
+        return []
+
+    pagination_lines = []
+
+    for row_data in data_list:
+        # Render each row in pagination format
+        row_pairs = []
+
+        for key, value in row_data.items():
+            if value and value.strip():  # Only include non-empty values
+                # Strip semicolons and newlines from values to avoid breaking the format
+                clean_value = str(value).replace(';', ',').replace('\n', ' ').strip()
+                if clean_value:
+                    row_pairs.append(f"{key}:{clean_value}")
+
+        if row_pairs:
+            pagination_line = ";".join(row_pairs)
+            pagination_lines.append(pagination_line)
+
+    return pagination_lines
+
+
+def process_excel_file(file_path: str) -> Tuple[str, List[str]]:
+    """
+    Process an Excel file and produce document.txt and pagination.txt content.
+
+    Args:
+        file_path: Path to the Excel file
+
+    Returns:
+        Tuple: (document_content, pagination_lines)
+    """
+    sheets_data = read_excel_sheets(file_path)
+
+    document_content_parts = []
+    pagination_lines = []
+
+    # Process each sheet
+    for sheet_name, sheet_data in sheets_data.items():
+        if sheet_data:
+            # Generate markdown-formatted document content
+            markdown_content = convert_to_markdown_format(sheet_data, sheet_name)
+            if markdown_content:
+                document_content_parts.append(markdown_content)
+
+            # Generate pagination-formatted content
+            sheet_pagination_lines = convert_to_pagination_format(sheet_data)
+            pagination_lines.extend(sheet_pagination_lines)
+
+    # Join the document content of all sheets
+    document_content = "\n\n".join(document_content_parts)
+
+    return document_content, pagination_lines
+
+
+def process_csv_file(file_path: str) -> Tuple[str, List[str]]:
+    """
+    Process a CSV file and produce document.txt and pagination.txt content.
+
+    Args:
+        file_path: Path to the CSV file
+
+    Returns:
+        Tuple: (document_content, pagination_lines)
+    """
+    csv_data = read_csv_file(file_path)
+
+    if not csv_data:
+        return "", []
+
+    # Generate markdown-formatted document content
+    document_content = convert_to_markdown_format(csv_data)
+
+    # Generate pagination-formatted content
+    pagination_lines = convert_to_pagination_format(csv_data)
+
+    return document_content, pagination_lines
+
+
+def is_excel_file(file_path: str) -> bool:
+    """Check whether a file is an Excel file."""
+    return file_path.lower().endswith(('.xlsx', '.xls'))
+
+
+def is_csv_file(file_path: str) -> bool:
+    """Check whether a file is a CSV file."""
+    return file_path.lower().endswith('.csv')
\ No newline at end of file
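For reference, a minimal usage sketch of the new module; the file path and row values below are hypothetical, invented only to illustrate the two output formats:

```python
# Hypothetical example: a CSV row {"name": "Widget", "price": "9.5"} becomes
#   document content : "name: Widget\nprice: 9.5"   (markdown-style key: value lines)
#   pagination line  : "name:Widget;price:9.5"      (one key:value;key:value line per row)
from utils.excel_csv_processor import is_csv_file, process_csv_file

path = "data/orders.csv"  # hypothetical path, not part of the repository
if is_csv_file(path):
    document_content, pagination_lines = process_csv_file(path)
    print(document_content)       # combined markdown, appended to document.txt
    print(len(pagination_lines))  # one line per non-empty row, written to pagination.txt
```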
diff --git a/utils/file_utils.py b/utils/file_utils.py
index 5eccc4a..d7e47a5 100644
--- a/utils/file_utils.py
+++ b/utils/file_utils.py
@@ -59,13 +59,13 @@ def extract_zip_file(zip_path: str, extract_dir: str) -> List[str]:
         with zipfile.ZipFile(zip_path, 'r') as zip_ref:
             zip_ref.extractall(extract_dir)
 
-        # Find all extracted txt and md files
+        # Find all extracted txt, md, xlsx, xls, and csv files
         for root, dirs, files in os.walk(extract_dir):
             for file in files:
-                if file.lower().endswith(('.txt', '.md')):
+                if file.lower().endswith(('.txt', '.md', '.xlsx', '.xls', '.csv')):
                     extracted_files.append(os.path.join(root, file))
 
-        print(f"Extracted {len(extracted_files)} txt/md files from {zip_path}")
+        print(f"Extracted {len(extracted_files)} txt/md/xlsx/csv files from {zip_path}")
         return extracted_files
 
     except Exception as e:
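The encoding handling in read_csv_file tries each candidate encoding in turn and only falls back to lossy decoding when all of them fail. A standalone sketch of the same pattern, assuming a hypothetical sample.csv:

```python
# Standalone sketch of the try-each-encoding fallback used in read_csv_file.
# "sample.csv" and the candidate list are illustrative, not from the codebase.
import pandas as pd

def read_csv_with_fallback(path: str) -> pd.DataFrame:
    for enc in ('utf-8', 'gbk', 'gb2312', 'utf-8-sig'):
        try:
            return pd.read_csv(path, encoding=enc)
        except UnicodeDecodeError:
            continue  # try the next candidate encoding
    # Last resort: decode as UTF-8 and drop undecodable bytes (pandas >= 1.3)
    return pd.read_csv(path, encoding='utf-8', encoding_errors='ignore')

df = read_csv_with_fallback("sample.csv")
print(df.shape)
```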