新增folder功能,删除unique_id (Add folder feature, remove unique_id)
This commit is contained in: parent d00601af23, commit 09690a101e
@@ -224,7 +224,10 @@ async def process_files_async_endpoint(request: QueueTaskRequest, authorization:

    # 估算处理时间(基于文件数量)
    estimated_time = 0
    if request.files:
    if request.upload_folder:
        # 对于upload_folder,无法预先估算文件数量,使用默认时间
        estimated_time = 120 # 默认2分钟
    elif request.files:
        total_files = sum(len(file_list) for file_list in request.files.values())
        estimated_time = max(30, total_files * 10) # 每个文件预估10秒,最少30秒

@@ -246,17 +249,25 @@ async def process_files_async_endpoint(request: QueueTaskRequest, authorization:

    # 提交异步任务
    task = process_files_async(
        unique_id=dataset_id,
        dataset_id=dataset_id,
        files=request.files,
        system_prompt=request.system_prompt,
        mcp_settings=request.mcp_settings,
        upload_folder=request.upload_folder,
        task_id=task_id
    )

    # 构建更详细的消息
    message = f"文件处理任务已提交到队列,项目ID: {dataset_id}"
    if request.upload_folder:
        group_count = len(request.upload_folder)
        message += f",将从 {group_count} 个上传文件夹自动扫描文件"
    elif request.files:
        total_files = sum(len(file_list) for file_list in request.files.values())
        message += f",包含 {total_files} 个文件"

    return QueueTaskResponse(
        success=True,
        message=f"文件处理任务已提交到队列,项目ID: {dataset_id}",
        unique_id=dataset_id,
        message=message,
        dataset_id=dataset_id,
        task_id=task_id, # 使用我们自己的task_id
        task_status="pending",
        estimated_processing_time=estimated_time

@@ -319,7 +330,7 @@ async def process_files_incremental_endpoint(request: IncrementalTaskRequest, au

    return QueueTaskResponse(
        success=True,
        message=f"增量文件处理任务已提交到队列 - 添加 {total_add_files} 个文件,删除 {total_remove_files} 个文件,项目ID: {dataset_id}",
        unique_id=dataset_id,
        dataset_id=dataset_id,
        task_id=task_id,
        task_status="pending",
        estimated_processing_time=estimated_time

@@ -618,7 +629,7 @@ async def get_project_tasks(dataset_id: str):

async def cleanup_project_async_endpoint(dataset_id: str, remove_all: bool = False):
    """异步清理项目文件"""
    try:
        task = cleanup_project_async(unique_id=dataset_id, remove_all=remove_all)
        task = cleanup_project_async(dataset_id=dataset_id, remove_all=remove_all)

        return {
            "success": True,

@@ -976,19 +987,33 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st


@app.post("/api/v1/upload")
async def upload_file(file: UploadFile = File(...)):
async def upload_file(file: UploadFile = File(...), folder: Optional[str] = None):
    """
    文件上传API接口,上传文件到 ./projects/uploads 目录
    文件上传API接口,上传文件到 ./projects/uploads/ 目录下

    可以指定自定义文件夹名,如果不指定则使用日期文件夹

    Args:
        file: 上传的文件
        folder: 可选的自定义文件夹名

    Returns:
        dict: 包含文件路径和文件名的响应
        dict: 包含文件路径和文件夹信息的响应
    """
    try:
        # 确保上传目录存在
        upload_dir = os.path.join("projects", "uploads")
        # 确定上传文件夹
        if folder:
            # 使用指定的自定义文件夹
            target_folder = folder
            # 安全性检查:防止路径遍历攻击
            target_folder = os.path.basename(target_folder)
        else:
            # 获取当前日期并格式化为年月日
            current_date = datetime.now()
            target_folder = current_date.strftime("%Y%m%d")

        # 创建上传目录
        upload_dir = os.path.join("projects", "uploads", target_folder)
        os.makedirs(upload_dir, exist_ok=True)

        # 生成唯一文件名

@@ -1003,9 +1028,8 @@ async def upload_file(file: UploadFile = File(...)):
        return {
            "success": True,
            "message": "文件上传成功",
            "filename": unique_filename,
            "original_filename": file.filename,
            "file_path": file_path
            "file_path": file_path,
            "folder": target_folder
        }

    except Exception as e:
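A minimal client-side sketch of how the new upload folder parameter and the upload_folder task field might be exercised together. The base URL, the token, and the route of the queue endpoint are assumptions (the processing route is not shown in this diff); only /api/v1/upload and the response keys appear above.

# Hypothetical client sketch (base URL, token, and the processing route are placeholders).
import requests

BASE_URL = "http://localhost:8000"
TOKEN = "your-api-token"  # assumption: the endpoint above accepts an Authorization header

# 1. Upload a file into a named folder instead of the default date folder.
with open("report.pdf", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/api/v1/upload",
        params={"folder": "my_project1"},      # plain query parameter in the new signature
        files={"file": ("report.pdf", f)},
    )
print(resp.json())  # expected keys: filename, original_filename, file_path, folder

# 2. Submit an async processing task that scans the uploaded folder.
payload = {
    "dataset_id": "demo-dataset",
    "upload_folder": {"group1": "my_project1"},
}
resp = requests.post(
    f"{BASE_URL}/api/v1/files/process",        # assumed route for process_files_async_endpoint
    json=payload,
    headers={"Authorization": f"Bearer {TOKEN}"},
)
print(resp.json())  # QueueTaskResponse: task_id, task_status="pending", estimated_processing_time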
@@ -6,6 +6,8 @@ Queue tasks for file processing integration.
import os
import json
import time
import hashlib
import shutil
from typing import Dict, List, Optional, Any

from task_queue.config import huey

@@ -15,45 +17,134 @@ from utils import download_dataset_files, save_processed_files_log, load_process
from utils.dataset_manager import remove_dataset_directory_by_key


def scan_upload_folder(upload_dir: str) -> List[str]:
    """
    扫描上传文件夹中的所有支持格式的文件

    Args:
        upload_dir: 上传文件夹路径

    Returns:
        List[str]: 支持的文件路径列表
    """
    supported_extensions = {
        # 文本文件
        '.txt', '.md', '.rtf',
        # 文档文件
        '.doc', '.docx', '.pdf', '.odt',
        # 表格文件
        '.xls', '.xlsx', '.csv', '.ods',
        # 演示文件
        '.ppt', '.pptx', '.odp',
        # 电子书
        '.epub', '.mobi',
        # 网页文件
        '.html', '.htm',
        # 配置文件
        '.json', '.xml', '.yaml', '.yml',
        # 代码文件
        '.py', '.js', '.java', '.cpp', '.c', '.go', '.rs',
        # 压缩文件
        '.zip', '.rar', '.7z', '.tar', '.gz'
    }

    scanned_files = []

    if not os.path.exists(upload_dir):
        return scanned_files

    for root, dirs, files in os.walk(upload_dir):
        for file in files:
            # 跳过隐藏文件和系统文件
            if file.startswith('.') or file.startswith('~'):
                continue

            file_path = os.path.join(root, file)
            file_extension = os.path.splitext(file)[1].lower()

            # 检查文件扩展名是否支持
            if file_extension in supported_extensions:
                scanned_files.append(file_path)
            else:
                # 对于没有扩展名的文件,也尝试处理(可能是文本文件)
                if not file_extension:
                    try:
                        # 尝试读取文件头部来判断是否为文本文件
                        with open(file_path, 'r', encoding='utf-8') as f:
                            f.read(1024) # 读取前1KB
                        scanned_files.append(file_path)
                    except (UnicodeDecodeError, PermissionError):
                        # 不是文本文件或无法读取,跳过
                        pass

    return scanned_files
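A short, hypothetical invocation of the scan_upload_folder helper defined above; the folder name is only an example.

# Hypothetical usage: list supported files under one upload folder.
import os

upload_dir = os.path.join("projects", "uploads", "my_project1")  # example folder
found = scan_upload_folder(upload_dir)
print(f"found {len(found)} supported files")
for path in found[:5]:
    print(" -", path)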
@huey.task()
def process_files_async(
    unique_id: str,
    dataset_id: str,
    files: Optional[Dict[str, List[str]]] = None,
    system_prompt: Optional[str] = None,
    mcp_settings: Optional[List[Dict]] = None,
    upload_folder: Optional[Dict[str, str]] = None,
    task_id: Optional[str] = None
) -> Dict[str, Any]:
    """
    异步处理文件任务 - 与现有files/process API兼容

    Args:
        unique_id: 项目唯一ID
        dataset_id: 项目唯一ID
        files: 按key分组的文件路径字典
        system_prompt: 系统提示词
        mcp_settings: MCP设置
        upload_folder: 上传文件夹字典,按组名组织文件夹,例如 {'group1': 'my_project1', 'group2': 'my_project2'}
        task_id: 任务ID(用于状态跟踪)

    Returns:
        处理结果字典
    """
    try:
        print(f"开始异步处理文件任务,项目ID: {unique_id}")
        print(f"开始异步处理文件任务,项目ID: {dataset_id}")

        # 如果有task_id,设置初始状态
        if task_id:
            task_status_store.set_status(
                task_id=task_id,
                unique_id=unique_id,
                unique_id=dataset_id,
                status="running"
            )

        # 确保项目目录存在
        project_dir = os.path.join("projects", "data", unique_id)
        project_dir = os.path.join("projects", "data", dataset_id)
        if not os.path.exists(project_dir):
            os.makedirs(project_dir, exist_ok=True)

        # 处理文件:使用按key分组格式
        processed_files_by_key = {}

        # 如果提供了upload_folder,扫描这些文件夹中的文件
        if upload_folder and not files:
            scanned_files_by_group = {}
            total_scanned_files = 0

            for group_name, folder_name in upload_folder.items():
                # 安全性检查:防止路径遍历攻击
                safe_folder_name = os.path.basename(folder_name)
                upload_dir = os.path.join("projects", "uploads", safe_folder_name)

                if os.path.exists(upload_dir):
                    scanned_files = scan_upload_folder(upload_dir)
                    if scanned_files:
                        scanned_files_by_group[group_name] = scanned_files
                        total_scanned_files += len(scanned_files)
                        print(f"从上传文件夹 '{safe_folder_name}' (组: {group_name}) 扫描到 {len(scanned_files)} 个文件")
                    else:
                        print(f"上传文件夹 '{safe_folder_name}' (组: {group_name}) 中没有找到支持的文件")
                else:
                    print(f"上传文件夹不存在: {upload_dir} (组: {group_name})")

            if scanned_files_by_group:
                files = scanned_files_by_group
                print(f"总共从 {len(scanned_files_by_group)} 个组扫描到 {total_scanned_files} 个文件")
            else:
                print(f"所有上传文件夹中都没有找到支持的文件")

        if files:
            # 使用请求中的文件(按key分组)
            # 由于这是异步任务,需要同步调用

@@ -64,11 +155,11 @@ def process_files_async(
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

            processed_files_by_key = loop.run_until_complete(download_dataset_files(unique_id, files))
            processed_files_by_key = loop.run_until_complete(download_dataset_files(dataset_id, files))
            total_files = sum(len(files_list) for files_list in processed_files_by_key.values())
            print(f"异步处理了 {total_files} 个数据集文件,涉及 {len(processed_files_by_key)} 个key,项目ID: {unique_id}")
            print(f"异步处理了 {total_files} 个数据集文件,涉及 {len(processed_files_by_key)} 个key,项目ID: {dataset_id}")
        else:
            print(f"请求中未提供文件,项目ID: {unique_id}")
            print(f"请求中未提供文件,项目ID: {dataset_id}")

        # 收集项目目录下所有的 document.txt 文件
        document_files = []

@@ -77,33 +168,20 @@ def process_files_async(
                if file == "document.txt":
                    document_files.append(os.path.join(root, file))

        # 保存system_prompt和mcp_settings到项目目录(如果提供)
        if system_prompt:
            system_prompt_file = os.path.join(project_dir, "system_prompt.md")
            with open(system_prompt_file, 'w', encoding='utf-8') as f:
                f.write(system_prompt)
            print(f"已保存system_prompt,项目ID: {unique_id}")

        if mcp_settings:
            mcp_settings_file = os.path.join(project_dir, "mcp_settings.json")
            with open(mcp_settings_file, 'w', encoding='utf-8') as f:
                json.dump(mcp_settings, f, ensure_ascii=False, indent=2)
            print(f"已保存mcp_settings,项目ID: {unique_id}")

        # 生成项目README.md文件
        try:
            from utils.project_manager import save_project_readme
            save_project_readme(unique_id)
            print(f"已生成README.md文件,项目ID: {unique_id}")
            save_project_readme(dataset_id)
            print(f"已生成README.md文件,项目ID: {dataset_id}")
        except Exception as e:
            print(f"生成README.md失败,项目ID: {unique_id}, 错误: {str(e)}")
            print(f"生成README.md失败,项目ID: {dataset_id}, 错误: {str(e)}")
            # 不影响主要处理流程,继续执行

        # 构建结果文件列表
        result_files = []
        for key in processed_files_by_key.keys():
            # 添加对应的dataset document.txt路径
            document_path = os.path.join("projects", "data", unique_id, "dataset", key, "document.txt")
            document_path = os.path.join("projects", "data", dataset_id, "dataset", key, "document.txt")
            if os.path.exists(document_path):
                result_files.append(document_path)

@@ -116,7 +194,7 @@ def process_files_async(
        result = {
            "status": "success",
            "message": f"成功异步处理了 {len(result_files)} 个文档文件,涉及 {len(processed_files_by_key)} 个key",
            "unique_id": unique_id,
            "dataset_id": dataset_id,
            "processed_files": result_files,
            "processed_files_by_key": processed_files_by_key,
            "document_files": document_files,

@@ -132,7 +210,7 @@ def process_files_async(
                result=result
            )

        print(f"异步文件处理任务完成: {unique_id}")
        print(f"异步文件处理任务完成: {dataset_id}")
        return result

    except Exception as e:

@@ -150,7 +228,7 @@ def process_files_async(
        return {
            "status": "error",
            "message": error_msg,
            "unique_id": unique_id,
            "dataset_id": dataset_id,
            "error": str(e)
        }

@@ -220,8 +298,23 @@ def process_files_incremental_async(
            # 删除特定文件
            for file_path in file_list:
                print(f"删除特定文件: {key}/{file_path}")

                # 实际删除文件
                filename = os.path.basename(file_path)

                # 删除原始文件
                source_file = os.path.join("projects", "data", dataset_id, "files", key, filename)
                if os.path.exists(source_file):
                    os.remove(source_file)
                    removed_files.append(f"file:{key}/{filename}")

                # 删除处理后的文件目录
                processed_dir = os.path.join("projects", "data", dataset_id, "processed", key, filename)
                if os.path.exists(processed_dir):
                    shutil.rmtree(processed_dir)
                    removed_files.append(f"processed:{key}/{filename}")

                # 计算文件hash以在日志中查找
                import hashlib
                file_hash = hashlib.md5(file_path.encode('utf-8')).hexdigest()

                # 从处理日志中移除

@@ -241,7 +334,7 @@ def process_files_incremental_async(
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

            processed_files_by_key = loop.run_until_complete(download_dataset_files(dataset_id, files_to_add))
            processed_files_by_key = loop.run_until_complete(download_dataset_files(dataset_id, files_to_add, incremental_mode=True))
            total_added_files = sum(len(files_list) for files_list in processed_files_by_key.values())
            print(f"异步处理了 {total_added_files} 个数据集文件,涉及 {len(processed_files_by_key)} 个key,项目ID: {dataset_id}")

@@ -347,23 +440,23 @@ def process_files_incremental_async(

@huey.task()
def cleanup_project_async(
    unique_id: str,
    dataset_id: str,
    remove_all: bool = False
) -> Dict[str, Any]:
    """
    异步清理项目文件

    Args:
        unique_id: 项目唯一ID
        dataset_id: 项目唯一ID
        remove_all: 是否删除整个项目目录

    Returns:
        清理结果字典
    """
    try:
        print(f"开始异步清理项目,项目ID: {unique_id}")
        print(f"开始异步清理项目,项目ID: {dataset_id}")

        project_dir = os.path.join("projects", "data", unique_id)
        project_dir = os.path.join("projects", "data", dataset_id)
        removed_items = []

        if remove_all and os.path.exists(project_dir):

@@ -373,7 +466,7 @@ def cleanup_project_async(
            result = {
                "status": "success",
                "message": f"已删除整个项目目录: {project_dir}",
                "unique_id": unique_id,
                "dataset_id": dataset_id,
                "removed_items": removed_items,
                "action": "remove_all"
            }

@@ -386,13 +479,13 @@ def cleanup_project_async(

            result = {
                "status": "success",
                "message": f"已清理项目处理日志,项目ID: {unique_id}",
                "unique_id": unique_id,
                "message": f"已清理项目处理日志,项目ID: {dataset_id}",
                "dataset_id": dataset_id,
                "removed_items": removed_items,
                "action": "cleanup_logs"
            }

        print(f"异步清理任务完成: {unique_id}")
        print(f"异步清理任务完成: {dataset_id}")
        return result

    except Exception as e:

@@ -401,6 +494,6 @@ def cleanup_project_async(
        return {
            "status": "error",
            "message": error_msg,
            "unique_id": unique_id,
            "dataset_id": dataset_id,
            "error": str(e)
        }
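A hedged sketch of enqueuing these Huey tasks directly from Python rather than through the HTTP endpoints. The import path is an assumption, as is the keyword-only dataset_id signature (the commit title says unique_id was removed); a running Huey worker and a configured queue are required for get() to return.

# Hypothetical direct enqueue (requires a running Huey consumer).
from task_queue.tasks import process_files_async, cleanup_project_async  # assumed module path

task = process_files_async(
    dataset_id="demo-dataset",                 # assumes the final signature keeps only dataset_id
    upload_folder={"group1": "my_project1"},
    task_id="manual-test-1",
)
result = task.get(blocking=True, timeout=300)  # Huey result handle; blocks until the worker finishes
print(result["status"], result["message"])

cleanup = cleanup_project_async(dataset_id="demo-dataset", remove_all=False)
print(cleanup.get(blocking=True))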
@@ -215,15 +215,31 @@ def create_error_response(message: str, error_type: str = "error", **kwargs) ->

class QueueTaskRequest(BaseModel):
    """队列任务请求模型"""
    unique_id: str
    dataset_id: str
    files: Optional[Dict[str, List[str]]] = Field(default=None, description="Files organized by key groups. Each key maps to a list of file paths (supports zip files)")
    system_prompt: Optional[str] = None
    mcp_settings: Optional[List[Dict]] = None
    upload_folder: Optional[Dict[str, str]] = Field(default=None, description="Upload folders organized by group names. Each key maps to a folder name. Example: {'group1': 'my_project1', 'group2': 'my_project2'}")
    priority: Optional[int] = Field(default=0, description="Task priority (higher number = higher priority)")
    delay: Optional[int] = Field(default=0, description="Delay execution by N seconds")

    model_config = ConfigDict(extra='allow')

    @field_validator('upload_folder', mode='before')
    @classmethod
    def validate_upload_folder(cls, v):
        """Validate upload_folder dict format"""
        if v is None:
            return None
        if isinstance(v, dict):
            # 验证字典格式
            for key, value in v.items():
                if not isinstance(key, str):
                    raise ValueError(f"Key in upload_folder dict must be string, got {type(key)}")
                if not isinstance(value, str):
                    raise ValueError(f"Value in upload_folder dict must be string (folder name), got {type(value)} for key '{key}'")
            return v
        else:
            raise ValueError(f"upload_folder must be a dict with group names as keys and folder names as values, got {type(v)}")

    @field_validator('files', mode='before')
    @classmethod
    def validate_files(cls, v):

@@ -300,7 +316,7 @@ class QueueTaskResponse(BaseModel):
    """队列任务响应模型"""
    success: bool
    message: str
    unique_id: str
    dataset_id: str
    task_id: Optional[str] = None
    task_status: Optional[str] = None
    estimated_processing_time: Optional[int] = None # seconds
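An illustrative construction of the updated QueueTaskRequest; values are placeholders, it assumes QueueTaskRequest is imported from wherever the model actually lives, and it assumes unique_id is no longer a required field after this commit.

# Illustrative use of the updated request model (placeholder values).
payload = {
    "dataset_id": "demo-dataset",
    "upload_folder": {"group1": "my_project1", "group2": "my_project2"},
    "priority": 1,
    "delay": 0,
}
request = QueueTaskRequest(**payload)   # validator accepts a str -> str dict
print(request.upload_folder)

# A malformed upload_folder (non-string value) is rejected by the validator:
try:
    QueueTaskRequest(dataset_id="demo-dataset", upload_folder={"group1": 123})
except Exception as exc:                # pydantic wraps the ValueError in a ValidationError
    print(type(exc).__name__)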
@@ -193,8 +193,12 @@ def merge_embeddings_by_group(unique_id: str, group_name: str) -> Dict:

    # Load and merge all embedding data
    all_chunks = []
    all_embeddings = [] # 修复:收集所有embeddings向量
    total_chunks = 0
    dimensions = 0
    chunking_strategy = 'unknown'
    chunking_params = {}
    model_path = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'

    for filename_stem, embedding_path in sorted(embedding_files):
        try:

@@ -204,35 +208,110 @@ def merge_embeddings_by_group(unique_id: str, group_name: str) -> Dict:
            if isinstance(embedding_data, dict) and 'chunks' in embedding_data:
                chunks = embedding_data['chunks']

                # 获取embeddings向量(关键修复)
                if 'embeddings' in embedding_data:
                    embeddings = embedding_data['embeddings']
                    all_embeddings.append(embeddings)

                # 从第一个文件获取模型信息
                if 'model_path' in embedding_data:
                    model_path = embedding_data['model_path']
                if 'chunking_strategy' in embedding_data:
                    chunking_strategy = embedding_data['chunking_strategy']
                if 'chunking_params' in embedding_data:
                    chunking_params = embedding_data['chunking_params']

                # Add source file metadata to each chunk
                for chunk in chunks:
                    if isinstance(chunk, dict):
                        chunk['source_file'] = filename_stem
                        chunk['source_group'] = group_name
                    elif isinstance(chunk, str):
                        # 如果chunk是字符串,保持原样
                        pass

                all_chunks.extend(chunks)
                total_chunks += len(chunks)

                # Get dimensions from first chunk if available
                if dimensions == 0 and chunks and isinstance(chunks[0], dict):
                    if 'embedding' in chunks[0] and hasattr(chunks[0]['embedding'], 'shape'):
                        dimensions = chunks[0]['embedding'].shape[0]

            result["source_files"].append(filename_stem)

        except Exception as e:
            print(f"Error loading embedding file {embedding_path}: {str(e)}")
            continue

    if all_chunks:
    if all_chunks and all_embeddings:
        # 合并所有embeddings向量
        try:
            # 尝试使用torch合并张量
            import torch
            if all(isinstance(emb, torch.Tensor) for emb in all_embeddings):
                merged_embeddings = torch.cat(all_embeddings, dim=0)
                dimensions = merged_embeddings.shape[1]
            else:
                # 如果不是tensor类型,尝试转换为numpy
                import numpy as np
                if NUMPY_SUPPORT:
                    np_embeddings = []
                    for emb in all_embeddings:
                        if hasattr(emb, 'numpy'):
                            np_embeddings.append(emb.numpy())
                        elif isinstance(emb, np.ndarray):
                            np_embeddings.append(emb)
                        else:
                            # 如果无法转换,跳过这个文件
                            print(f"Warning: Cannot convert embedding to numpy from file {filename_stem}")
                            continue

                    if np_embeddings:
                        merged_embeddings = np.concatenate(np_embeddings, axis=0)
                        dimensions = merged_embeddings.shape[1]
                    else:
                        result["error"] = "No valid embedding tensors could be merged"
                        return result
                else:
                    result["error"] = "NumPy not available for merging embeddings"
                    return result

        except ImportError:
            # 如果没有torch,尝试使用numpy
            if NUMPY_SUPPORT:
                import numpy as np
                np_embeddings = []
                for emb in all_embeddings:
                    if hasattr(emb, 'numpy'):
                        np_embeddings.append(emb.numpy())
                    elif isinstance(emb, np.ndarray):
                        np_embeddings.append(emb)
                    else:
                        print(f"Warning: Cannot convert embedding to numpy from file {filename_stem}")
                        continue

                if np_embeddings:
                    merged_embeddings = np.concatenate(np_embeddings, axis=0)
                    dimensions = merged_embeddings.shape[1]
                else:
                    result["error"] = "No valid embedding tensors could be merged"
                    return result
            else:
                result["error"] = "Neither torch nor numpy available for merging embeddings"
                return result
        except Exception as e:
            result["error"] = f"Failed to merge embedding tensors: {str(e)}"
            print(f"Error merging embedding tensors: {str(e)}")
            return result

        # Create merged embedding data structure
        merged_embedding_data = {
            'chunks': all_chunks,
            'embeddings': merged_embeddings, # 关键修复:添加embeddings键
            'total_chunks': total_chunks,
            'dimensions': dimensions,
            'source_files': result["source_files"],
            'group_name': group_name,
            'merged_at': str(os.path.getmtime(merged_embedding_path) if os.path.exists(merged_embedding_path) else 0)
            'merged_at': str(__import__('time').time()),
            'chunking_strategy': chunking_strategy,
            'chunking_params': chunking_params,
            'model_path': model_path
        }

        # Save merged embeddings
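A standalone sketch of the tensor-merging fallback used in merge_embeddings_by_group above: concatenate with torch when every piece is a tensor, otherwise fall back to NumPy. It mirrors the logic of the diff but is not the project function itself; the helper name and sample arrays are made up.

# Standalone sketch of the embedding-merge fallback (hypothetical helper, not project code).
import numpy as np

def merge_embedding_arrays(pieces):
    # Prefer torch.cat when torch is installed and every piece is a tensor.
    try:
        import torch
        if all(isinstance(p, torch.Tensor) for p in pieces):
            return torch.cat(pieces, dim=0)
    except ImportError:
        pass  # torch not installed; fall through to NumPy

    # Otherwise convert what we can to NumPy arrays and concatenate there.
    np_pieces = []
    for p in pieces:
        if hasattr(p, "numpy"):            # e.g. a torch tensor mixed in
            np_pieces.append(p.numpy())
        elif isinstance(p, np.ndarray):
            np_pieces.append(p)
    if not np_pieces:
        raise ValueError("no embeddings could be converted for merging")
    return np.concatenate(np_pieces, axis=0)

merged = merge_embedding_arrays([np.ones((2, 4)), np.zeros((3, 4))])
print(merged.shape)  # (5, 4)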
@@ -21,34 +21,42 @@ from utils.data_merger import (
)


async def download_dataset_files(unique_id: str, files: Dict[str, List[str]]) -> Dict[str, List[str]]:
async def download_dataset_files(unique_id: str, files: Dict[str, List[str]], incremental_mode: bool = False) -> Dict[str, List[str]]:
    """
    Process dataset files with new architecture:
    1. Sync files to group directories
    2. Process each file individually
    3. Merge results by group
    4. Clean up orphaned files
    4. Clean up orphaned files (only in non-incremental mode)

    Args:
        unique_id: Project ID
        files: Dictionary of files to process, grouped by key
        incremental_mode: If True, preserve existing files and only process new ones
    """
    if not files:
        return {}

    print(f"Starting new file processing for project: {unique_id}")
    print(f"Starting {'incremental' if incremental_mode else 'full'} file processing for project: {unique_id}")

    # Ensure project directories exist
    ensure_directories(unique_id)

    # Step 1: Sync files to group directories
    print("Step 1: Syncing files to group directories...")
    synced_files, failed_files = sync_files_to_group(unique_id, files)
    synced_files, failed_files = sync_files_to_group(unique_id, files, incremental_mode)

    # Step 2: Detect changes and cleanup orphaned files
    # Step 2: Detect changes and cleanup orphaned files (only in non-incremental mode)
    from utils.file_manager import detect_file_changes
    changes = detect_file_changes(unique_id, files)
    changes = detect_file_changes(unique_id, files, incremental_mode)

    if any(changes["removed"].values()):
    # Only cleanup orphaned files in non-incremental mode or when files are explicitly removed
    if not incremental_mode and any(changes["removed"].values()):
        print("Step 2: Cleaning up orphaned files...")
        removed_files = cleanup_orphaned_files(unique_id, changes)
        print(f"Removed orphaned files: {removed_files}")
    elif incremental_mode:
        print("Step 2: Skipping cleanup in incremental mode to preserve existing files")

    # Step 3: Process individual files
    print("Step 3: Processing individual files...")
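A hedged sketch of driving download_dataset_files with the new incremental_mode flag from a synchronous caller; the project ID and file path are placeholders, and the project layout is assumed to already exist.

# Hypothetical call with the new incremental_mode flag (placeholder IDs and paths).
import asyncio

async def add_files_without_touching_existing_ones():
    new_files = {"group1": ["projects/uploads/my_project1/new_report.pdf"]}
    # incremental_mode=True: existing group files are preserved and orphan cleanup is skipped
    return await download_dataset_files("demo-dataset", new_files, incremental_mode=True)

processed = asyncio.run(add_files_without_touching_existing_ones())
print(processed)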
@@ -30,8 +30,16 @@ def get_existing_files(unique_id: str) -> Dict[str, Set[str]]:
    return existing_files


def detect_file_changes(unique_id: str, new_files: Dict[str, List[str]]) -> Dict:
    """Detect file changes: added, removed, and existing files."""
def detect_file_changes(unique_id: str, new_files: Dict[str, List[str]], incremental_mode: bool = False) -> Dict:
    """
    Detect file changes: added, removed, and existing files.

    Args:
        unique_id: Project ID
        new_files: Dictionary of files to process, grouped by key
        incremental_mode: If True, only detect removed files when files_to_remove is explicitly provided
                          This prevents accidental deletion of existing files during incremental additions
    """
    existing_files = get_existing_files(unique_id)

    changes = {

@@ -63,6 +71,8 @@ def detect_file_changes(unique_id: str, new_files: Dict[str, List[str]]) -> Dict
            changes["removed"][group] = set()

    # Detect removed files (files that exist but not in new request)
    # Skip this step in incremental mode to preserve existing files
    if not incremental_mode:
        for group, existing_filenames in existing_files.items():
            if group in new_files_sets:
                # Group exists in new request, check for individual file removals

@@ -76,10 +86,15 @@ def detect_file_changes(unique_id: str, new_files: Dict[str, List[str]]) -> Dict
    return changes


def sync_files_to_group(unique_id: str, files: Dict[str, List[str]]) -> Tuple[Dict, Dict]:
def sync_files_to_group(unique_id: str, files: Dict[str, List[str]], incremental_mode: bool = False) -> Tuple[Dict, Dict]:
    """
    Sync files to group directories and return sync results.

    Args:
        unique_id: Project ID
        files: Dictionary of files to sync, grouped by key
        incremental_mode: If True, preserve existing files and only add new ones

    Returns:
        Tuple of (synced_files, failed_files)
    """

@@ -90,7 +105,7 @@ def sync_files_to_group(unique_id: str, files: Dict[str, List[str]]) -> Tuple[Di
    os.makedirs(files_dir, exist_ok=True)

    # Detect changes first
    changes = detect_file_changes(unique_id, files)
    changes = detect_file_changes(unique_id, files, incremental_mode)

    synced_files = {}
    failed_files = {}
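A small illustrative contrast of detect_file_changes in full versus incremental mode; the project ID, file names, and the printed sets are made up.

# Illustrative contrast of the two modes (made-up names; output shapes are examples).
request_files = {"group1": ["projects/uploads/my_project1/new_report.pdf"]}

full_changes = detect_file_changes("demo-dataset", request_files)                        # may report removals
incr_changes = detect_file_changes("demo-dataset", request_files, incremental_mode=True)  # removals skipped

print(full_changes["removed"])   # e.g. {"group1": {"old_report.pdf"}}
print(incr_changes["removed"])   # e.g. {"group1": set()} -- existing files preserved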