qwen_agent/utils/data_merger.py
2025-11-27 21:50:03 +08:00

440 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Data merging functions for combining processed file results.
"""
import os
import pickle
import logging
from typing import Dict, List, Optional, Tuple
import json
# Shared logger; the project logs under the name 'app'.
logger = logging.getLogger('app')

# NumPy is optional. NUMPY_SUPPORT records whether it imported, so embedding
# code can degrade gracefully instead of the module failing at import time.
try:
    import numpy as np
except ImportError:
    NUMPY_SUPPORT = False
    logger.warning("NumPy not available, some embedding features may be limited")
else:
    NUMPY_SUPPORT = True
def merge_documents_by_group(unique_id: str, group_name: str) -> Dict:
    """Merge all document.txt files in a group into a single document.

    Every sub-directory ``<stem>`` of
    ``projects/data/<unique_id>/processed/<group_name>/`` that holds a non-empty
    ``document.txt`` is a candidate page. Each page's content is stripped. Pages
    that are still non-empty are written in sorted ``<stem>`` order, each
    preceded by a ``# Page <stem>`` header line, with parts separated by blank
    lines. The output goes to
    ``projects/data/<unique_id>/dataset/<group_name>/document.txt``. The dataset
    directory is created even if nothing ends up being merged.

    Args:
        unique_id: Project identifier, used as a path component.
        group_name: Group directory name, used as a path component.

    Returns:
        A dict with these keys:
        ``success``;
        ``merged_document_path``;
        ``source_files`` (stems that contributed content, in merge order);
        ``total_pages`` (number of pages actually merged, i.e.
        ``len(source_files)``);
        ``total_characters`` (sum of the stripped content lengths);
        ``error`` (None on success, otherwise a message).
    """
    processed_group_dir = os.path.join("projects", "data", unique_id, "processed", group_name)
    dataset_group_dir = os.path.join("projects", "data", unique_id, "dataset", group_name)
    os.makedirs(dataset_group_dir, exist_ok=True)
    merged_document_path = os.path.join(dataset_group_dir, "document.txt")
    result = {
        "success": False,
        "merged_document_path": merged_document_path,
        "source_files": [],
        "total_pages": 0,
        "total_characters": 0,
        "error": None
    }
    try:
        # Find all non-empty <stem>/document.txt files in the processed directory
        document_files = []
        if os.path.exists(processed_group_dir):
            for item in os.listdir(processed_group_dir):
                item_path = os.path.join(processed_group_dir, item)
                if os.path.isdir(item_path):
                    document_path = os.path.join(item_path, "document.txt")
                    if os.path.exists(document_path) and os.path.getsize(document_path) > 0:
                        document_files.append((item, document_path))
        if not document_files:
            result["error"] = "No document files found to merge"
            return result

        # Merge all documents with page separators
        merged_content = []
        total_characters = 0
        for filename_stem, document_path in sorted(document_files):
            try:
                with open(document_path, 'r', encoding='utf-8') as f:
                    content = f.read().strip()
            except Exception as e:
                # Best effort: an unreadable page is logged and skipped.
                logger.error(f"Error reading document file {document_path}: {str(e)}")
                continue
            if content:
                merged_content.append(f"# Page {filename_stem}")
                merged_content.append(content)
                total_characters += len(content)
                result["source_files"].append(filename_stem)

        if merged_content:
            with open(merged_document_path, 'w', encoding='utf-8') as f:
                f.write('\n\n'.join(merged_content))
            # Count only pages that were actually merged. Unreadable or
            # whitespace-only files are in document_files but not in the output.
            result["total_pages"] = len(result["source_files"])
            result["total_characters"] = total_characters
            result["success"] = True
        else:
            result["error"] = "No valid content found in document files"
    except Exception as e:
        result["error"] = f"Document merging failed: {str(e)}"
        logger.error(f"Error merging documents for group {group_name}: {str(e)}")
    return result
def merge_paginations_by_group(unique_id: str, group_name: str) -> Dict:
    """Concatenate every per-file pagination.txt of a group into one file.

    Looks for a non-empty ``<stem>/pagination.txt`` in each sub-directory of
    ``projects/data/<unique_id>/processed/<group_name>/``. Lines are stripped
    and blank lines are dropped. The remaining lines are written one per line,
    in sorted ``<stem>`` order, to
    ``projects/data/<unique_id>/dataset/<group_name>/pagination.txt``.

    Returns:
        A dict with these keys:
        ``success``;
        ``merged_pagination_path``;
        ``source_files`` (every stem whose file was read without error);
        ``total_lines`` (number of lines written);
        ``error`` (None on success, otherwise a message).
    """
    processed_group_dir = os.path.join("projects", "data", unique_id, "processed", group_name)
    dataset_group_dir = os.path.join("projects", "data", unique_id, "dataset", group_name)
    os.makedirs(dataset_group_dir, exist_ok=True)
    merged_pagination_path = os.path.join(dataset_group_dir, "pagination.txt")
    result = {
        "success": False,
        "merged_pagination_path": merged_pagination_path,
        "source_files": [],
        "total_lines": 0,
        "error": None
    }
    try:
        candidates = []
        if os.path.exists(processed_group_dir):
            candidates = [
                (entry, os.path.join(processed_group_dir, entry, "pagination.txt"))
                for entry in os.listdir(processed_group_dir)
                if os.path.isdir(os.path.join(processed_group_dir, entry))
            ]
        pagination_files = sorted(
            (stem, path)
            for stem, path in candidates
            if os.path.exists(path) and os.path.getsize(path) > 0
        )
        if not pagination_files:
            result["error"] = "No pagination files found to merge"
            return result

        merged_lines = []
        for stem, path in pagination_files:
            try:
                with open(path, 'r', encoding='utf-8') as handle:
                    raw_lines = handle.readlines()
                # Trim each line and keep only the ones with content left.
                merged_lines.extend(text for text in map(str.strip, raw_lines) if text)
                result["source_files"].append(stem)
            except Exception as e:
                logger.error(f"Error reading pagination file {path}: {str(e)}")

        if not merged_lines:
            result["error"] = "No valid pagination data found"
        else:
            with open(merged_pagination_path, 'w', encoding='utf-8') as out:
                out.writelines(f"{text}\n" for text in merged_lines)
            result["total_lines"] = len(merged_lines)
            result["success"] = True
    except Exception as e:
        result["error"] = f"Pagination merging failed: {str(e)}"
        logger.error(f"Error merging paginations for group {group_name}: {str(e)}")
    return result
def _embedding_to_numpy(embeddings):
    """Return one file's embeddings as a NumPy array, or None if not convertible.

    ``numpy.ndarray`` values are returned unchanged. Objects exposing
    ``.numpy()`` (e.g. torch tensors) are converted. When they also have
    ``.detach()`` and ``.cpu()``, those are applied first, so that tensors which
    require grad or live on an accelerator do not make ``.numpy()`` raise.
    """
    import numpy as np

    if isinstance(embeddings, np.ndarray):
        return embeddings
    if hasattr(embeddings, 'detach') and hasattr(embeddings, 'cpu'):
        embeddings = embeddings.detach().cpu()
    if hasattr(embeddings, 'numpy'):
        return embeddings.numpy()
    return None


def _concat_embedding_entries(
    entries: List[Tuple[str, list, object]],
) -> Tuple[object, List[Tuple[str, list, object]], Optional[str]]:
    """Concatenate the embeddings of ``(stem, chunks, embeddings)`` entries row-wise.

    If torch is importable and every entry holds a ``torch.Tensor``, the result
    is ``torch.cat(..., dim=0)``. Otherwise each entry is converted to NumPy and
    merged with ``numpy.concatenate(..., axis=0)``. Entries that cannot be
    converted are dropped with a warning.

    Returns:
        ``(merged, kept, error)``. ``kept`` is the ordered list of entries whose
        embeddings ended up in ``merged``, so the caller can build a chunk list
        aligned with its rows. On failure ``merged`` is None, ``kept`` is empty
        and ``error`` holds the message to report. On success ``error`` is None.
        Exceptions raised by torch/NumPy themselves (e.g. mismatched widths)
        propagate to the caller.
    """
    try:
        import torch
    except ImportError:
        torch = None
    try:
        import numpy as np
    except ImportError:
        np = None

    if torch is not None and all(isinstance(emb, torch.Tensor) for _, _, emb in entries):
        return torch.cat([emb for _, _, emb in entries], dim=0), list(entries), None

    if np is None:
        if torch is None:
            return None, [], "Neither torch nor numpy available for merging embeddings"
        return None, [], "NumPy not available for merging embeddings"

    arrays = []
    kept = []
    for entry in entries:
        stem, _, emb = entry
        array = _embedding_to_numpy(emb)
        if array is None:
            # Drop the whole file, chunks included, so chunk i keeps matching row i.
            logger.warning(f"Warning: Cannot convert embedding to numpy from file {stem}")
            continue
        arrays.append(array)
        kept.append(entry)
    if not arrays:
        return None, [], "No valid embedding tensors could be merged"
    return np.concatenate(arrays, axis=0), kept, None


def merge_embeddings_by_group(unique_id: str, group_name: str) -> Dict:
    """Merge all embedding.pkl files in a group into one dataset pickle.

    Each sub-directory ``<stem>`` of
    ``projects/data/<unique_id>/processed/<group_name>/`` may hold a non-empty
    ``embedding.pkl``. These files are loaded in sorted ``<stem>`` order.

    A file is merged only when all of the following hold:
    - it is a dict with both ``chunks`` and ``embeddings``;
    - its chunk count equals its number of embedding rows;
    - its embeddings can be concatenated with the others.
    Any other file is skipped with a log message. This guarantees that
    ``chunks[i]`` of the merged output corresponds to embedding row ``i``.

    Dict chunks are tagged with ``source_file``/``source_group``. For
    ``model_path``, ``chunking_strategy`` and ``chunking_params``, the first
    merged file that provides the key wins; a warning is logged if a later file
    reports a different ``model_path``.

    The result is written to
    ``projects/data/<unique_id>/dataset/<group_name>/embedding.pkl``.

    NOTE(review): ``pickle.load`` can execute arbitrary code. These files must
    only ever be produced by this pipeline, never by untrusted uploads.

    Args:
        unique_id: Project identifier, used as a path component.
        group_name: Group directory name, used as a path component.

    Returns:
        A dict with these keys:
        ``success``;
        ``merged_embedding_path``;
        ``source_files`` (stems actually merged);
        ``total_chunks``;
        ``total_dimensions`` (``shape[1]`` of the merged embeddings);
        ``error`` (None on success, otherwise a message).
    """
    import time

    processed_group_dir = os.path.join("projects", "data", unique_id, "processed", group_name)
    dataset_group_dir = os.path.join("projects", "data", unique_id, "dataset", group_name)
    os.makedirs(dataset_group_dir, exist_ok=True)
    merged_embedding_path = os.path.join(dataset_group_dir, "embedding.pkl")
    result = {
        "success": False,
        "merged_embedding_path": merged_embedding_path,
        "source_files": [],
        "total_chunks": 0,
        "total_dimensions": 0,
        "error": None
    }
    try:
        # Find all non-empty <stem>/embedding.pkl files
        embedding_files = []
        if os.path.exists(processed_group_dir):
            for item in os.listdir(processed_group_dir):
                item_path = os.path.join(processed_group_dir, item)
                if os.path.isdir(item_path):
                    embedding_path = os.path.join(item_path, "embedding.pkl")
                    if os.path.exists(embedding_path) and os.path.getsize(embedding_path) > 0:
                        embedding_files.append((item, embedding_path))
        if not embedding_files:
            result["error"] = "No embedding files found to merge"
            return result

        # (stem, chunks, embeddings) for every usable file, in sorted order.
        entries = []
        # Model/chunking metadata: the first usable file providing a key wins.
        metadata = {}
        for filename_stem, embedding_path in sorted(embedding_files):
            try:
                with open(embedding_path, 'rb') as f:
                    embedding_data = pickle.load(f)
                if not (isinstance(embedding_data, dict) and 'chunks' in embedding_data):
                    logger.warning(f"Skipping embedding file {embedding_path}: no 'chunks' key")
                    continue
                chunks = embedding_data['chunks']
                if len(chunks) == 0:
                    continue  # Nothing to contribute
                if 'embeddings' not in embedding_data:
                    # Merging these chunks without vectors would shift every later row.
                    logger.warning(f"Skipping embedding file {embedding_path}: no 'embeddings' key")
                    continue
                embeddings = embedding_data['embeddings']
                if len(chunks) != len(embeddings):
                    logger.warning(
                        f"Skipping embedding file {embedding_path}: "
                        f"{len(chunks)} chunks vs {len(embeddings)} embedding rows"
                    )
                    continue
                for key in ('model_path', 'chunking_strategy', 'chunking_params'):
                    if key not in embedding_data:
                        continue
                    if key not in metadata:
                        metadata[key] = embedding_data[key]
                    elif key == 'model_path' and embedding_data[key] != metadata[key]:
                        logger.warning(
                            f"Embedding file {embedding_path} uses model "
                            f"{embedding_data[key]!r}, expected {metadata[key]!r}"
                        )
                # Tag dict chunks with their origin; plain-string chunks are kept as-is.
                for chunk in chunks:
                    if isinstance(chunk, dict):
                        chunk['source_file'] = filename_stem
                        chunk['source_group'] = group_name
                entries.append((filename_stem, chunks, embeddings))
            except Exception as e:
                logger.error(f"Error loading embedding file {embedding_path}: {str(e)}")
                continue

        if not entries:
            result["error"] = "No valid embedding data found"
            return result

        try:
            merged_embeddings, kept, merge_error = _concat_embedding_entries(entries)
            if merge_error is not None:
                result["error"] = merge_error
                return result
            dimensions = merged_embeddings.shape[1]
        except Exception as e:
            result["error"] = f"Failed to merge embedding tensors: {str(e)}"
            logger.error(f"Error merging embedding tensors: {str(e)}")
            return result

        # Rebuild chunks from the merged files only, so they line up with the rows.
        all_chunks = [chunk for _, chunks, _ in kept for chunk in chunks]
        result["source_files"] = [stem for stem, _, _ in kept]

        merged_embedding_data = {
            'chunks': all_chunks,
            'embeddings': merged_embeddings,
            'total_chunks': len(all_chunks),
            'dimensions': dimensions,
            'source_files': result["source_files"],
            'group_name': group_name,
            'merged_at': str(time.time()),
            'chunking_strategy': metadata.get('chunking_strategy', 'unknown'),
            'chunking_params': metadata.get('chunking_params', {}),
            'model_path': metadata.get('model_path', 'TaylorAI/gte-tiny'),
        }
        with open(merged_embedding_path, 'wb') as f:
            pickle.dump(merged_embedding_data, f)
        result["total_chunks"] = len(all_chunks)
        result["total_dimensions"] = dimensions
        result["success"] = True
    except Exception as e:
        result["error"] = f"Embedding merging failed: {str(e)}"
        logger.error(f"Error merging embeddings for group {group_name}: {str(e)}")
    return result
def merge_all_data_by_group(unique_id: str, group_name: str) -> Dict:
    """Run the document, pagination and embedding merges for one group.

    All three merges are attempted in that order, even if an earlier one fails.
    Each merge's own result dict is stored under ``document_merge``,
    ``pagination_merge`` and ``embedding_merge``. ``success`` is True only if
    every merge succeeded. Each failure appends a
    ``"<Kind> merge failed: <error>"`` message to ``errors``.
    """
    steps = (
        ("document_merge", "Document", merge_documents_by_group),
        ("pagination_merge", "Pagination", merge_paginations_by_group),
        ("embedding_merge", "Embedding", merge_embeddings_by_group),
    )
    errors = []
    merge_results = {
        "group_name": group_name,
        "unique_id": unique_id,
        "success": True,
        "document_merge": None,
        "pagination_merge": None,
        "embedding_merge": None,
        "errors": errors,
    }
    for result_key, label, merge_fn in steps:
        outcome = merge_fn(unique_id, group_name)
        merge_results[result_key] = outcome
        if not outcome["success"]:
            merge_results["success"] = False
            errors.append(f"{label} merge failed: {outcome['error']}")
    return merge_results
def get_group_merge_status(unique_id: str, group_name: str) -> Dict:
    """Report which merged artifacts exist for a group and their sizes.

    Inspects ``projects/data/<unique_id>/dataset/<group_name>/`` for
    ``document.txt``, ``pagination.txt`` and ``embedding.pkl``. For each one it
    records ``*_exists`` and ``*_size`` (bytes; 0 when absent).
    ``merge_complete`` is True only when all three exist and are non-empty.
    """
    dataset_group_dir = os.path.join("projects", "data", unique_id, "dataset", group_name)
    # (exists key, size key, file name) for each merged artifact, in report order.
    artifacts = (
        ("document_exists", "document_size", "document.txt"),
        ("pagination_exists", "pagination_size", "pagination.txt"),
        ("embedding_exists", "embedding_size", "embedding.pkl"),
    )
    dir_exists = os.path.exists(dataset_group_dir)

    status = {"group_name": group_name, "unique_id": unique_id, "dataset_dir_exists": dir_exists}
    for exists_key, size_key, _ in artifacts:
        status[exists_key] = False
        status[size_key] = 0
    status["merge_complete"] = False

    if not dir_exists:
        return status

    for exists_key, size_key, filename in artifacts:
        path = os.path.join(dataset_group_dir, filename)
        if os.path.exists(path):
            status[exists_key] = True
            status[size_key] = os.path.getsize(path)

    status["merge_complete"] = all(
        status[exists_key] and status[size_key] > 0
        for exists_key, size_key, _ in artifacts
    )
    return status
def cleanup_dataset_group(unique_id: str, group_name: str) -> bool:
    """Recursively delete ``projects/data/<unique_id>/dataset/<group_name>``.

    NOTE(review): both arguments are used as raw path components passed to
    ``shutil.rmtree``. If they can come from request input, confirm upstream
    validation rejects ``..`` and absolute paths.

    Returns:
        True if the directory was removed or did not exist. False if removal
        raised; the error is logged.
    """
    target_dir = os.path.join("projects", "data", unique_id, "dataset", group_name)
    try:
        if not os.path.exists(target_dir):
            return True  # Nothing to clean up
        import shutil
        shutil.rmtree(target_dir)
        logger.info(f"Cleaned up dataset group: {group_name}")
    except Exception as e:
        logger.error(f"Error cleaning up dataset group {group_name}: {str(e)}")
        return False
    return True