#!/usr/bin/env python3
"""
Data merging functions for combining processed file results.
"""

import os
import pickle
import time
from typing import Dict

# Try to import numpy, but handle if missing
try:
    import numpy as np
    NUMPY_SUPPORT = True
except ImportError:
    print("NumPy not available, some embedding features may be limited")
    NUMPY_SUPPORT = False


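# Expected on-disk layout (inferred from the path handling below; the per-file
# directories under "processed" are presumably produced by an upstream
# processing step):
#
#   projects/data/<unique_id>/processed/<group_name>/<file_stem>/document.txt
#   projects/data/<unique_id>/processed/<group_name>/<file_stem>/pagination.txt
#   projects/data/<unique_id>/processed/<group_name>/<file_stem>/embedding.pkl
#
# The merge_* functions below combine these per-file outputs into a single
# document.txt, pagination.txt, and embedding.pkl under
# projects/data/<unique_id>/dataset/<group_name>/.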
def merge_documents_by_group(unique_id: str, group_name: str) -> Dict:
    """Merge all document.txt files in a group into a single document."""

    processed_group_dir = os.path.join("projects", "data", unique_id, "processed", group_name)
    dataset_group_dir = os.path.join("projects", "data", unique_id, "dataset", group_name)
    os.makedirs(dataset_group_dir, exist_ok=True)

    merged_document_path = os.path.join(dataset_group_dir, "document.txt")

    result = {
        "success": False,
        "merged_document_path": merged_document_path,
        "source_files": [],
        "total_pages": 0,
        "total_characters": 0,
        "error": None
    }

    try:
        # Find all document.txt files in the processed directory
        document_files = []
        if os.path.exists(processed_group_dir):
            for item in os.listdir(processed_group_dir):
                item_path = os.path.join(processed_group_dir, item)
                if os.path.isdir(item_path):
                    document_path = os.path.join(item_path, "document.txt")
                    if os.path.exists(document_path) and os.path.getsize(document_path) > 0:
                        document_files.append((item, document_path))

        if not document_files:
            result["error"] = "No document files found to merge"
            return result

        # Merge all documents with page separators
        merged_content = []
        total_characters = 0

        for filename_stem, document_path in sorted(document_files):
            try:
                with open(document_path, 'r', encoding='utf-8') as f:
                    content = f.read().strip()

                if content:
                    merged_content.append(f"# Page {filename_stem}")
                    merged_content.append(content)
                    total_characters += len(content)
                    result["source_files"].append(filename_stem)

            except Exception as e:
                print(f"Error reading document file {document_path}: {str(e)}")
                continue

        if merged_content:
            # Write merged document
            with open(merged_document_path, 'w', encoding='utf-8') as f:
                f.write('\n\n'.join(merged_content))

            result["total_pages"] = len(document_files)
            result["total_characters"] = total_characters
            result["success"] = True

        else:
            result["error"] = "No valid content found in document files"

    except Exception as e:
        result["error"] = f"Document merging failed: {str(e)}"
        print(f"Error merging documents for group {group_name}: {str(e)}")

    return result


def merge_paginations_by_group(unique_id: str, group_name: str) -> Dict:
    """Merge all pagination.txt files in a group."""

    processed_group_dir = os.path.join("projects", "data", unique_id, "processed", group_name)
    dataset_group_dir = os.path.join("projects", "data", unique_id, "dataset", group_name)
    os.makedirs(dataset_group_dir, exist_ok=True)

    merged_pagination_path = os.path.join(dataset_group_dir, "pagination.txt")

    result = {
        "success": False,
        "merged_pagination_path": merged_pagination_path,
        "source_files": [],
        "total_lines": 0,
        "error": None
    }

    try:
        # Find all pagination.txt files
        pagination_files = []
        if os.path.exists(processed_group_dir):
            for item in os.listdir(processed_group_dir):
                item_path = os.path.join(processed_group_dir, item)
                if os.path.isdir(item_path):
                    pagination_path = os.path.join(item_path, "pagination.txt")
                    if os.path.exists(pagination_path) and os.path.getsize(pagination_path) > 0:
                        pagination_files.append((item, pagination_path))

        if not pagination_files:
            result["error"] = "No pagination files found to merge"
            return result

        # Merge all pagination files
        merged_lines = []

        for filename_stem, pagination_path in sorted(pagination_files):
            try:
                with open(pagination_path, 'r', encoding='utf-8') as f:
                    lines = f.readlines()

                for line in lines:
                    line = line.strip()
                    if line:
                        merged_lines.append(line)

                result["source_files"].append(filename_stem)

            except Exception as e:
                print(f"Error reading pagination file {pagination_path}: {str(e)}")
                continue

        if merged_lines:
            # Write merged pagination
            with open(merged_pagination_path, 'w', encoding='utf-8') as f:
                for line in merged_lines:
                    f.write(f"{line}\n")

            result["total_lines"] = len(merged_lines)
            result["success"] = True

        else:
            result["error"] = "No valid pagination data found"

    except Exception as e:
        result["error"] = f"Pagination merging failed: {str(e)}"
        print(f"Error merging paginations for group {group_name}: {str(e)}")

    return result


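# Per-file embedding.pkl structure assumed by the merger below (inferred from
# the keys it reads; 'chunks' is required, 'embeddings' is needed for merged
# vectors to be produced, and the remaining keys are optional metadata):
#
#   {
#       'chunks': [...],                 # chunk dicts or plain strings
#       'embeddings': <Tensor/ndarray>,  # one row per chunk
#       'model_path': str,
#       'chunking_strategy': str,
#       'chunking_params': dict,
#   }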
def merge_embeddings_by_group(unique_id: str, group_name: str) -> Dict:
    """Merge all embedding.pkl files in a group."""

    processed_group_dir = os.path.join("projects", "data", unique_id, "processed", group_name)
    dataset_group_dir = os.path.join("projects", "data", unique_id, "dataset", group_name)
    os.makedirs(dataset_group_dir, exist_ok=True)

    merged_embedding_path = os.path.join(dataset_group_dir, "embedding.pkl")

    result = {
        "success": False,
        "merged_embedding_path": merged_embedding_path,
        "source_files": [],
        "total_chunks": 0,
        "total_dimensions": 0,
        "error": None
    }

    try:
        # Find all embedding.pkl files
        embedding_files = []
        if os.path.exists(processed_group_dir):
            for item in os.listdir(processed_group_dir):
                item_path = os.path.join(processed_group_dir, item)
                if os.path.isdir(item_path):
                    embedding_path = os.path.join(item_path, "embedding.pkl")
                    if os.path.exists(embedding_path) and os.path.getsize(embedding_path) > 0:
                        embedding_files.append((item, embedding_path))

        if not embedding_files:
            result["error"] = "No embedding files found to merge"
            return result

        # Load and merge all embedding data
        all_chunks = []
        all_embeddings = []  # Fix: collect the embedding vectors from every file
        total_chunks = 0
        dimensions = 0
        chunking_strategy = 'unknown'
        chunking_params = {}
        model_path = 'TaylorAI/gte-tiny'

        for filename_stem, embedding_path in sorted(embedding_files):
            try:
                with open(embedding_path, 'rb') as f:
                    embedding_data = pickle.load(f)

                if isinstance(embedding_data, dict) and 'chunks' in embedding_data:
                    chunks = embedding_data['chunks']

                    # Collect the embedding vectors (key fix)
                    if 'embeddings' in embedding_data:
                        embeddings = embedding_data['embeddings']
                        all_embeddings.append(embeddings)

                    # Record model/chunking metadata (the value from the last file read wins)
                    if 'model_path' in embedding_data:
                        model_path = embedding_data['model_path']
                    if 'chunking_strategy' in embedding_data:
                        chunking_strategy = embedding_data['chunking_strategy']
                    if 'chunking_params' in embedding_data:
                        chunking_params = embedding_data['chunking_params']

                    # Add source file metadata to each chunk
                    for chunk in chunks:
                        if isinstance(chunk, dict):
                            chunk['source_file'] = filename_stem
                            chunk['source_group'] = group_name
                        elif isinstance(chunk, str):
                            # Plain-string chunks are kept as-is
                            pass

                    all_chunks.extend(chunks)
                    total_chunks += len(chunks)

                    result["source_files"].append(filename_stem)

            except Exception as e:
                print(f"Error loading embedding file {embedding_path}: {str(e)}")
                continue

        if all_chunks and all_embeddings:
            # Merge all embedding vectors into a single tensor/array
            try:
                # Prefer torch if every per-file embedding is a tensor
                import torch
                if all(isinstance(emb, torch.Tensor) for emb in all_embeddings):
                    merged_embeddings = torch.cat(all_embeddings, dim=0)
                    dimensions = merged_embeddings.shape[1]
                else:
                    # Mixed types: fall back to numpy concatenation
                    if NUMPY_SUPPORT:
                        np_embeddings = []
                        for emb in all_embeddings:
                            if hasattr(emb, 'numpy'):
                                np_embeddings.append(emb.numpy())
                            elif isinstance(emb, np.ndarray):
                                np_embeddings.append(emb)
                            else:
                                # Skip anything that cannot be converted to a numpy array
                                print(f"Warning: cannot convert embedding of type {type(emb).__name__} to numpy, skipping it")
                                continue

                        if np_embeddings:
                            merged_embeddings = np.concatenate(np_embeddings, axis=0)
                            dimensions = merged_embeddings.shape[1]
                        else:
                            result["error"] = "No valid embedding tensors could be merged"
                            return result
                    else:
                        result["error"] = "NumPy not available for merging embeddings"
                        return result

            except ImportError:
                # torch is not installed: merge with numpy instead
                if NUMPY_SUPPORT:
                    np_embeddings = []
                    for emb in all_embeddings:
                        if hasattr(emb, 'numpy'):
                            np_embeddings.append(emb.numpy())
                        elif isinstance(emb, np.ndarray):
                            np_embeddings.append(emb)
                        else:
                            print(f"Warning: cannot convert embedding of type {type(emb).__name__} to numpy, skipping it")
                            continue

                    if np_embeddings:
                        merged_embeddings = np.concatenate(np_embeddings, axis=0)
                        dimensions = merged_embeddings.shape[1]
                    else:
                        result["error"] = "No valid embedding tensors could be merged"
                        return result
                else:
                    result["error"] = "Neither torch nor numpy available for merging embeddings"
                    return result
            except Exception as e:
                result["error"] = f"Failed to merge embedding tensors: {str(e)}"
                print(f"Error merging embedding tensors: {str(e)}")
                return result

            # Create merged embedding data structure
            merged_embedding_data = {
                'chunks': all_chunks,
                'embeddings': merged_embeddings,  # Key fix: keep the merged vectors alongside the chunks
                'total_chunks': total_chunks,
                'dimensions': dimensions,
                'source_files': result["source_files"],
                'group_name': group_name,
                'merged_at': str(time.time()),
                'chunking_strategy': chunking_strategy,
                'chunking_params': chunking_params,
                'model_path': model_path
            }

            # Save merged embeddings
            with open(merged_embedding_path, 'wb') as f:
                pickle.dump(merged_embedding_data, f)

            result["total_chunks"] = total_chunks
            result["total_dimensions"] = dimensions
            result["success"] = True

        else:
            result["error"] = "No valid embedding data found"

    except Exception as e:
        result["error"] = f"Embedding merging failed: {str(e)}"
        print(f"Error merging embeddings for group {group_name}: {str(e)}")

    return result


def merge_all_data_by_group(unique_id: str, group_name: str) -> Dict:
    """Merge documents, paginations, and embeddings for a group."""

    merge_results = {
        "group_name": group_name,
        "unique_id": unique_id,
        "success": True,
        "document_merge": None,
        "pagination_merge": None,
        "embedding_merge": None,
        "errors": []
    }

    # Merge documents
    document_result = merge_documents_by_group(unique_id, group_name)
    merge_results["document_merge"] = document_result

    if not document_result["success"]:
        merge_results["success"] = False
        merge_results["errors"].append(f"Document merge failed: {document_result['error']}")

    # Merge paginations
    pagination_result = merge_paginations_by_group(unique_id, group_name)
    merge_results["pagination_merge"] = pagination_result

    if not pagination_result["success"]:
        merge_results["success"] = False
        merge_results["errors"].append(f"Pagination merge failed: {pagination_result['error']}")

    # Merge embeddings
    embedding_result = merge_embeddings_by_group(unique_id, group_name)
    merge_results["embedding_merge"] = embedding_result

    if not embedding_result["success"]:
        merge_results["success"] = False
        merge_results["errors"].append(f"Embedding merge failed: {embedding_result['error']}")

    return merge_results


def get_group_merge_status(unique_id: str, group_name: str) -> Dict:
    """Get the status of merged data for a group."""

    dataset_group_dir = os.path.join("projects", "data", unique_id, "dataset", group_name)

    status = {
        "group_name": group_name,
        "unique_id": unique_id,
        "dataset_dir_exists": os.path.exists(dataset_group_dir),
        "document_exists": False,
        "document_size": 0,
        "pagination_exists": False,
        "pagination_size": 0,
        "embedding_exists": False,
        "embedding_size": 0,
        "merge_complete": False
    }

    if os.path.exists(dataset_group_dir):
        document_path = os.path.join(dataset_group_dir, "document.txt")
        pagination_path = os.path.join(dataset_group_dir, "pagination.txt")
        embedding_path = os.path.join(dataset_group_dir, "embedding.pkl")

        if os.path.exists(document_path):
            status["document_exists"] = True
            status["document_size"] = os.path.getsize(document_path)

        if os.path.exists(pagination_path):
            status["pagination_exists"] = True
            status["pagination_size"] = os.path.getsize(pagination_path)

        if os.path.exists(embedding_path):
            status["embedding_exists"] = True
            status["embedding_size"] = os.path.getsize(embedding_path)

        # Check if all files exist and are not empty
        if (status["document_exists"] and status["document_size"] > 0 and
                status["pagination_exists"] and status["pagination_size"] > 0 and
                status["embedding_exists"] and status["embedding_size"] > 0):
            status["merge_complete"] = True

    return status


def cleanup_dataset_group(unique_id: str, group_name: str) -> bool:
    """Clean up merged dataset files for a group."""

    dataset_group_dir = os.path.join("projects", "data", unique_id, "dataset", group_name)

    try:
        if os.path.exists(dataset_group_dir):
            import shutil
            shutil.rmtree(dataset_group_dir)
            print(f"Cleaned up dataset group: {group_name}")
            return True
        else:
            return True  # Nothing to clean up

    except Exception as e:
        print(f"Error cleaning up dataset group {group_name}: {str(e)}")
        return False
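

if __name__ == "__main__":
    # Minimal usage sketch; "demo_project" and "default" are placeholder values,
    # not names used elsewhere in the project.
    results = merge_all_data_by_group("demo_project", "default")
    print(f"Merge succeeded: {results['success']}")
    for err in results["errors"]:
        print(f"  - {err}")

    status = get_group_merge_status("demo_project", "default")
    print(f"Merge complete on disk: {status['merge_complete']}")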