#!/usr/bin/env python3 """ Data merging functions for combining processed file results. """ import os import pickle import logging from typing import Dict, List, Optional, Tuple import json # Configure logger logger = logging.getLogger('app') # Try to import numpy, but handle if missing try: import numpy as np NUMPY_SUPPORT = True except ImportError: logger.warning("NumPy not available, some embedding features may be limited") NUMPY_SUPPORT = False def merge_documents_by_group(unique_id: str, group_name: str) -> Dict: """Merge all document.txt files in a group into a single document.""" processed_group_dir = os.path.join("projects", "data", unique_id, "processed", group_name) dataset_group_dir = os.path.join("projects", "data", unique_id, "dataset", group_name) os.makedirs(dataset_group_dir, exist_ok=True) merged_document_path = os.path.join(dataset_group_dir, "document.txt") result = { "success": False, "merged_document_path": merged_document_path, "source_files": [], "total_pages": 0, "total_characters": 0, "error": None } try: # Find all document.txt files in the processed directory document_files = [] if os.path.exists(processed_group_dir): for item in os.listdir(processed_group_dir): item_path = os.path.join(processed_group_dir, item) if os.path.isdir(item_path): document_path = os.path.join(item_path, "document.txt") if os.path.exists(document_path) and os.path.getsize(document_path) > 0: document_files.append((item, document_path)) if not document_files: result["error"] = "No document files found to merge" return result # Merge all documents with page separators merged_content = [] total_characters = 0 for filename_stem, document_path in sorted(document_files): try: with open(document_path, 'r', encoding='utf-8') as f: content = f.read().strip() if content: merged_content.append(f"# Page {filename_stem}") merged_content.append(content) total_characters += len(content) result["source_files"].append(filename_stem) except Exception as e: logger.error(f"Error reading document file {document_path}: {str(e)}") continue if merged_content: # Write merged document with open(merged_document_path, 'w', encoding='utf-8') as f: f.write('\n\n'.join(merged_content)) result["total_pages"] = len(document_files) result["total_characters"] = total_characters result["success"] = True else: result["error"] = "No valid content found in document files" except Exception as e: result["error"] = f"Document merging failed: {str(e)}" logger.error(f"Error merging documents for group {group_name}: {str(e)}") return result def merge_paginations_by_group(unique_id: str, group_name: str) -> Dict: """Merge all pagination.txt files in a group.""" processed_group_dir = os.path.join("projects", "data", unique_id, "processed", group_name) dataset_group_dir = os.path.join("projects", "data", unique_id, "dataset", group_name) os.makedirs(dataset_group_dir, exist_ok=True) merged_pagination_path = os.path.join(dataset_group_dir, "pagination.txt") result = { "success": False, "merged_pagination_path": merged_pagination_path, "source_files": [], "total_lines": 0, "error": None } try: # Find all pagination.txt files pagination_files = [] if os.path.exists(processed_group_dir): for item in os.listdir(processed_group_dir): item_path = os.path.join(processed_group_dir, item) if os.path.isdir(item_path): pagination_path = os.path.join(item_path, "pagination.txt") if os.path.exists(pagination_path) and os.path.getsize(pagination_path) > 0: pagination_files.append((item, pagination_path)) if not pagination_files: result["error"] = "No pagination files found to merge" return result # Merge all pagination files merged_lines = [] for filename_stem, pagination_path in sorted(pagination_files): try: with open(pagination_path, 'r', encoding='utf-8') as f: lines = f.readlines() for line in lines: line = line.strip() if line: merged_lines.append(line) result["source_files"].append(filename_stem) except Exception as e: logger.error(f"Error reading pagination file {pagination_path}: {str(e)}") continue if merged_lines: # Write merged pagination with open(merged_pagination_path, 'w', encoding='utf-8') as f: for line in merged_lines: f.write(f"{line}\n") result["total_lines"] = len(merged_lines) result["success"] = True else: result["error"] = "No valid pagination data found" except Exception as e: result["error"] = f"Pagination merging failed: {str(e)}" logger.error(f"Error merging paginations for group {group_name}: {str(e)}") return result def merge_embeddings_by_group(unique_id: str, group_name: str) -> Dict: """Merge all embedding.pkl files in a group.""" processed_group_dir = os.path.join("projects", "data", unique_id, "processed", group_name) dataset_group_dir = os.path.join("projects", "data", unique_id, "dataset", group_name) os.makedirs(dataset_group_dir, exist_ok=True) merged_embedding_path = os.path.join(dataset_group_dir, "embedding.pkl") result = { "success": False, "merged_embedding_path": merged_embedding_path, "source_files": [], "total_chunks": 0, "total_dimensions": 0, "error": None } try: # Find all embedding.pkl files embedding_files = [] if os.path.exists(processed_group_dir): for item in os.listdir(processed_group_dir): item_path = os.path.join(processed_group_dir, item) if os.path.isdir(item_path): embedding_path = os.path.join(item_path, "embedding.pkl") if os.path.exists(embedding_path) and os.path.getsize(embedding_path) > 0: embedding_files.append((item, embedding_path)) if not embedding_files: result["error"] = "No embedding files found to merge" return result # Load and merge all embedding data all_chunks = [] all_embeddings = [] # 修复:收集所有embeddings向量 total_chunks = 0 dimensions = 0 chunking_strategy = 'unknown' chunking_params = {} model_path = 'TaylorAI/gte-tiny' for filename_stem, embedding_path in sorted(embedding_files): try: with open(embedding_path, 'rb') as f: embedding_data = pickle.load(f) if isinstance(embedding_data, dict) and 'chunks' in embedding_data: chunks = embedding_data['chunks'] # 获取embeddings向量(关键修复) if 'embeddings' in embedding_data: embeddings = embedding_data['embeddings'] all_embeddings.append(embeddings) # 从第一个文件获取模型信息 if 'model_path' in embedding_data: model_path = embedding_data['model_path'] if 'chunking_strategy' in embedding_data: chunking_strategy = embedding_data['chunking_strategy'] if 'chunking_params' in embedding_data: chunking_params = embedding_data['chunking_params'] # Add source file metadata to each chunk for chunk in chunks: if isinstance(chunk, dict): chunk['source_file'] = filename_stem chunk['source_group'] = group_name elif isinstance(chunk, str): # 如果chunk是字符串,保持原样 pass all_chunks.extend(chunks) total_chunks += len(chunks) result["source_files"].append(filename_stem) except Exception as e: logger.error(f"Error loading embedding file {embedding_path}: {str(e)}") continue if all_chunks and all_embeddings: # 合并所有embeddings向量 try: # 尝试使用torch合并张量 import torch if all(isinstance(emb, torch.Tensor) for emb in all_embeddings): merged_embeddings = torch.cat(all_embeddings, dim=0) dimensions = merged_embeddings.shape[1] else: # 如果不是tensor类型,尝试转换为numpy import numpy as np if NUMPY_SUPPORT: np_embeddings = [] for emb in all_embeddings: if hasattr(emb, 'numpy'): np_embeddings.append(emb.numpy()) elif isinstance(emb, np.ndarray): np_embeddings.append(emb) else: # 如果无法转换,跳过这个文件 logger.warning(f"Warning: Cannot convert embedding to numpy from file {filename_stem}") continue if np_embeddings: merged_embeddings = np.concatenate(np_embeddings, axis=0) dimensions = merged_embeddings.shape[1] else: result["error"] = "No valid embedding tensors could be merged" return result else: result["error"] = "NumPy not available for merging embeddings" return result except ImportError: # 如果没有torch,尝试使用numpy if NUMPY_SUPPORT: import numpy as np np_embeddings = [] for emb in all_embeddings: if hasattr(emb, 'numpy'): np_embeddings.append(emb.numpy()) elif isinstance(emb, np.ndarray): np_embeddings.append(emb) else: logger.warning(f"Warning: Cannot convert embedding to numpy from file {filename_stem}") continue if np_embeddings: merged_embeddings = np.concatenate(np_embeddings, axis=0) dimensions = merged_embeddings.shape[1] else: result["error"] = "No valid embedding tensors could be merged" return result else: result["error"] = "Neither torch nor numpy available for merging embeddings" return result except Exception as e: result["error"] = f"Failed to merge embedding tensors: {str(e)}" logger.error(f"Error merging embedding tensors: {str(e)}") return result # Create merged embedding data structure merged_embedding_data = { 'chunks': all_chunks, 'embeddings': merged_embeddings, # 关键修复:添加embeddings键 'total_chunks': total_chunks, 'dimensions': dimensions, 'source_files': result["source_files"], 'group_name': group_name, 'merged_at': str(__import__('time').time()), 'chunking_strategy': chunking_strategy, 'chunking_params': chunking_params, 'model_path': model_path } # Save merged embeddings with open(merged_embedding_path, 'wb') as f: pickle.dump(merged_embedding_data, f) result["total_chunks"] = total_chunks result["total_dimensions"] = dimensions result["success"] = True else: result["error"] = "No valid embedding data found" except Exception as e: result["error"] = f"Embedding merging failed: {str(e)}" logger.error(f"Error merging embeddings for group {group_name}: {str(e)}") return result def merge_all_data_by_group(unique_id: str, group_name: str) -> Dict: """Merge documents, paginations, and embeddings for a group.""" merge_results = { "group_name": group_name, "unique_id": unique_id, "success": True, "document_merge": None, "pagination_merge": None, "embedding_merge": None, "errors": [] } # Merge documents document_result = merge_documents_by_group(unique_id, group_name) merge_results["document_merge"] = document_result if not document_result["success"]: merge_results["success"] = False merge_results["errors"].append(f"Document merge failed: {document_result['error']}") # Merge paginations pagination_result = merge_paginations_by_group(unique_id, group_name) merge_results["pagination_merge"] = pagination_result if not pagination_result["success"]: merge_results["success"] = False merge_results["errors"].append(f"Pagination merge failed: {pagination_result['error']}") # Merge embeddings embedding_result = merge_embeddings_by_group(unique_id, group_name) merge_results["embedding_merge"] = embedding_result if not embedding_result["success"]: merge_results["success"] = False merge_results["errors"].append(f"Embedding merge failed: {embedding_result['error']}") return merge_results def get_group_merge_status(unique_id: str, group_name: str) -> Dict: """Get the status of merged data for a group.""" dataset_group_dir = os.path.join("projects", "data", unique_id, "dataset", group_name) status = { "group_name": group_name, "unique_id": unique_id, "dataset_dir_exists": os.path.exists(dataset_group_dir), "document_exists": False, "document_size": 0, "pagination_exists": False, "pagination_size": 0, "embedding_exists": False, "embedding_size": 0, "merge_complete": False } if os.path.exists(dataset_group_dir): document_path = os.path.join(dataset_group_dir, "document.txt") pagination_path = os.path.join(dataset_group_dir, "pagination.txt") embedding_path = os.path.join(dataset_group_dir, "embedding.pkl") if os.path.exists(document_path): status["document_exists"] = True status["document_size"] = os.path.getsize(document_path) if os.path.exists(pagination_path): status["pagination_exists"] = True status["pagination_size"] = os.path.getsize(pagination_path) if os.path.exists(embedding_path): status["embedding_exists"] = True status["embedding_size"] = os.path.getsize(embedding_path) # Check if all files exist and are not empty if (status["document_exists"] and status["document_size"] > 0 and status["pagination_exists"] and status["pagination_size"] > 0 and status["embedding_exists"] and status["embedding_size"] > 0): status["merge_complete"] = True return status def cleanup_dataset_group(unique_id: str, group_name: str) -> bool: """Clean up merged dataset files for a group.""" dataset_group_dir = os.path.join("projects", "data", unique_id, "dataset", group_name) try: if os.path.exists(dataset_group_dir): import shutil shutil.rmtree(dataset_group_dir) logger.info(f"Cleaned up dataset group: {group_name}") return True else: return True # Nothing to clean up except Exception as e: logger.error(f"Error cleaning up dataset group {group_name}: {str(e)}") return False