add file process
This commit is contained in:
parent e1c2df763e
commit e21c3cb44e

fastapi_app.py (584 changed lines)
@@ -1,5 +1,8 @@
import json
import os
import aiofiles
import aiohttp
import hashlib
from typing import AsyncGenerator, Dict, List, Optional, Union

import uvicorn

@@ -41,19 +44,345 @@ def get_content_from_messages(messages: List[dict]) -> str:

from file_loaded_agent_manager import get_global_agent_manager, init_global_agent_manager
from gbase_agent import update_agent_llm
from zip_project_handler import zip_handler


def get_zip_url_from_unique_id(unique_id: str) -> Optional[str]:
    """Read zip_url from unique_map.json"""
async def download_file(url: str, destination_path: str) -> bool:
    """Download file from URL to destination path"""
    try:
        with open('unique_map.json', 'r', encoding='utf-8') as f:
            unique_map = json.load(f)
            return unique_map.get(unique_id)
        os.makedirs(os.path.dirname(destination_path), exist_ok=True)
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                if response.status == 200:
                    async with aiofiles.open(destination_path, 'wb') as f:
                        async for chunk in response.content.iter_chunked(8192):
                            await f.write(chunk)
                    return True
                else:
                    print(f"Failed to download file from {url}, status: {response.status}")
                    return False
    except Exception as e:
        print(f"Error reading unique_map.json: {e}")
        print(f"Error downloading file from {url}: {str(e)}")
        return False
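
A minimal sketch of how the new download_file helper might be driven, assuming the module layout of this commit; the URL, project id and path below are illustrative, not part of the change:

```python
# Hypothetical usage of download_file (URL and destination are illustrative).
import asyncio

from fastapi_app import download_file  # assumes this module name

async def main() -> None:
    ok = await download_file(
        "https://example.com/spec.txt",         # hypothetical source URL
        "projects/demo-id/files/spec.txt",      # destination under a project dir
    )
    print("downloaded" if ok else "download failed")

asyncio.run(main())
```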


def get_file_hash(file_path: str) -> str:
    """Generate MD5 hash for a file path/URL"""
    return hashlib.md5(file_path.encode('utf-8')).hexdigest()

def load_processed_files_log(unique_id: str) -> Dict[str, Dict]:
    """Load processed files log for a project"""
    log_file = os.path.join("projects", unique_id, "processed_files.json")
    if os.path.exists(log_file):
        try:
            with open(log_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            print(f"Error loading processed files log: {str(e)}")
    return {}

def save_processed_files_log(unique_id: str, processed_log: Dict[str, Dict]):
    """Save processed files log for a project"""
    log_file = os.path.join("projects", unique_id, "processed_files.json")
    try:
        os.makedirs(os.path.dirname(log_file), exist_ok=True)
        with open(log_file, 'w', encoding='utf-8') as f:
            json.dump(processed_log, f, ensure_ascii=False, indent=2)
    except Exception as e:
        print(f"Error saving processed files log: {str(e)}")

def remove_file_or_directory(path: str):
    """Remove file or directory if it exists"""
    if os.path.exists(path):
        try:
            if os.path.isdir(path):
                import shutil
                shutil.rmtree(path)
                print(f"Removed directory: {path}")
            else:
                os.remove(path)
                print(f"Removed file: {path}")
            return True
        except Exception as e:
            print(f"Error removing {path}: {str(e)}")
    return False

def remove_dataset_directory(unique_id: str, filename_without_ext: str):
    """Remove the entire dataset directory for a specific file"""
    dataset_dir = os.path.join("projects", unique_id, "dataset", filename_without_ext)
    if remove_file_or_directory(dataset_dir):
        print(f"Removed dataset directory: {dataset_dir}")
        return True
    return False

def get_document_preview(document_path: str, max_lines: int = 10) -> str:
    """Get preview of document content (first max_lines lines)"""
    try:
        with open(document_path, 'r', encoding='utf-8') as f:
            lines = []
            for i, line in enumerate(f):
                if i >= max_lines:
                    break
                lines.append(line.rstrip())
            return '\n'.join(lines)
    except Exception as e:
        print(f"Error reading document preview from {document_path}: {str(e)}")
        return f"Error reading document: {str(e)}"

def generate_dataset_structure(unique_id: str) -> str:
    """Generate dataset directory structure as a string"""
    dataset_dir = os.path.join("projects", unique_id, "dataset")
    structure_lines = []

    def build_tree(path: str, prefix: str = "", is_last: bool = True):
        try:
            items = sorted(os.listdir(path))
            items = [item for item in items if not item.startswith('.')]  # Hide hidden files

            for i, item in enumerate(items):
                item_path = os.path.join(path, item)
                is_dir = os.path.isdir(item_path)

                # Determine tree symbols
                if i == len(items) - 1:
                    current_prefix = "└── " if is_last else "├── "
                    next_prefix = "    " if is_last else "│   "
                else:
                    current_prefix = "├── "
                    next_prefix = "│   "

                line = prefix + current_prefix + item
                if is_dir:
                    line += "/"
                structure_lines.append(line)

                # Recursively process subdirectories
                if is_dir:
                    build_tree(item_path, prefix + next_prefix, i == len(items) - 1)

        except Exception as e:
            print(f"Error building tree for {path}: {str(e)}")

    structure_lines.append("dataset/")
    if os.path.exists(dataset_dir):
        build_tree(dataset_dir)
    else:
        structure_lines.append(" (empty)")

    return '\n'.join(structure_lines)

def generate_project_readme(unique_id: str) -> str:
    """Generate README.md content for a project"""
    project_dir = os.path.join("projects", unique_id)
    dataset_dir = os.path.join(project_dir, "dataset")

    readme_content = f"""# Project: {unique_id}

## Dataset Structure

```
{generate_dataset_structure(unique_id)}
```

## Files Description

"""

    if not os.path.exists(dataset_dir):
        readme_content += "No dataset files available.\n"
    else:
        # Get all document directories
        doc_dirs = []
        try:
            for item in sorted(os.listdir(dataset_dir)):
                item_path = os.path.join(dataset_dir, item)
                if os.path.isdir(item_path):
                    doc_dirs.append(item)
        except Exception as e:
            print(f"Error listing dataset directories: {str(e)}")

        if not doc_dirs:
            readme_content += "No document directories found.\n"
        else:
            for doc_dir in doc_dirs:
                doc_path = os.path.join(dataset_dir, doc_dir)
                document_file = os.path.join(doc_path, "document.txt")
                pagination_file = os.path.join(doc_path, "pagination.txt")
                embeddings_file = os.path.join(doc_path, "document_embeddings.pkl")

                readme_content += f"### {doc_dir}\n\n"
                readme_content += f"**Files:**\n"
                readme_content += f"- `document.txt`"
                if os.path.exists(document_file):
                    readme_content += " ✓"
                readme_content += "\n"

                readme_content += f"- `pagination.txt`"
                if os.path.exists(pagination_file):
                    readme_content += " ✓"
                readme_content += "\n"

                readme_content += f"- `document_embeddings.pkl`"
                if os.path.exists(embeddings_file):
                    readme_content += " ✓"
                readme_content += "\n\n"

                # Add document preview
                if os.path.exists(document_file):
                    readme_content += f"**Content Preview (first 10 lines):**\n\n```\n"
                    preview = get_document_preview(document_file, 10)
                    readme_content += preview
                    readme_content += "\n```\n\n"
                else:
                    readme_content += f"**Content Preview:** Not available\n\n"

    readme_content += f"""---
*Generated on {__import__('datetime').datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*
"""

    return readme_content

def save_project_readme(unique_id: str):
    """Generate and save README.md for a project"""
    try:
        readme_content = generate_project_readme(unique_id)
        readme_path = os.path.join("projects", unique_id, "README.md")

        with open(readme_path, 'w', encoding='utf-8') as f:
            f.write(readme_content)

        print(f"Generated README.md for project {unique_id}")
        return readme_path
    except Exception as e:
        print(f"Error generating README for project {unique_id}: {str(e)}")
        return None

async def download_dataset_files(unique_id: str, files: List[str]) -> List[str]:
    """Download or copy dataset files to projects/{unique_id}/files directory with processing state management"""
    if not files:
        return []

    # Load existing processed files log
    processed_log = load_processed_files_log(unique_id)
    files_dir = os.path.join("projects", unique_id, "files")

    # Convert files list to a set for easy comparison
    new_files_hashes = {get_file_hash(file_path): file_path for file_path in files}
    existing_files_hashes = set(processed_log.keys())

    # Files to process (new or modified)
    files_to_process = []
    # Files to remove (no longer in the list)
    files_to_remove = existing_files_hashes - set(new_files_hashes.keys())

    processed_files = []

    # Remove files that are no longer in the list
    for file_hash in files_to_remove:
        file_info = processed_log[file_hash]

        # Remove local file in files directory
        if 'local_path' in file_info:
            remove_file_or_directory(file_info['local_path'])

        # Remove the entire dataset directory for this file
        if 'filename' in file_info:
            filename_without_ext = os.path.splitext(file_info['filename'])[0]
            remove_dataset_directory(unique_id, filename_without_ext)

        # Also remove any specific dataset path if exists (fallback)
        if 'dataset_path' in file_info:
            remove_file_or_directory(file_info['dataset_path'])

        # Remove from log
        del processed_log[file_hash]
        print(f"Removed file from processing: {file_info.get('original_path', 'unknown')}")

    # Process new files
    for file_path in files:
        file_hash = get_file_hash(file_path)

        # Check if file was already processed
        if file_hash in processed_log:
            file_info = processed_log[file_hash]
            if 'local_path' in file_info and os.path.exists(file_info['local_path']):
                processed_files.append(file_info['local_path'])
                print(f"Skipped already processed file: {file_path}")
                continue

        # Extract filename from URL or path
        filename = file_path.split("/")[-1]
        if not filename:
            filename = f"file_{len(processed_files)}"

        destination_path = os.path.join(files_dir, filename)

        # Check if it's a URL (remote file) or local file
        success = False
        if file_path.startswith(('http://', 'https://')):
            # Download remote file
            success = await download_file(file_path, destination_path)
        else:
            # Copy local file
            try:
                import shutil
                os.makedirs(files_dir, exist_ok=True)
                shutil.copy2(file_path, destination_path)
                success = True
                print(f"Copied local file: {file_path} -> {destination_path}")
            except Exception as e:
                print(f"Failed to copy local file {file_path}: {str(e)}")

        if success:
            processed_files.append(destination_path)
            # Update processed log
            processed_log[file_hash] = {
                'original_path': file_path,
                'local_path': destination_path,
                'filename': filename,
                'processed_at': str(__import__('datetime').datetime.now()),
                'file_type': 'remote' if file_path.startswith(('http://', 'https://')) else 'local'
            }
            print(f"Successfully processed file: {file_path}")
        else:
            print(f"Failed to process file: {file_path}")

    # After downloading/copying files, organize them into dataset structure
    if processed_files:
        try:
            from organize_dataset_files import organize_single_project_files

            # Update dataset paths in the log after organization
            old_processed_log = processed_log.copy()
            organize_single_project_files(unique_id, skip_processed=True)

            # Try to update dataset paths in the log
            for file_hash, file_info in old_processed_log.items():
                if 'local_path' in file_info and os.path.exists(file_info['local_path']):
                    # Construct expected dataset path based on known structure
                    filename_without_ext = os.path.splitext(file_info['filename'])[0]
                    dataset_path = os.path.join("projects", unique_id, "dataset", filename_without_ext, "document.txt")
                    if os.path.exists(dataset_path):
                        processed_log[file_hash]['dataset_path'] = dataset_path

            print(f"Organized files for project {unique_id} into dataset structure (skipping already processed files)")
        except Exception as e:
            print(f"Failed to organize files for project {unique_id}: {str(e)}")

    # Save the updated processed log
    save_processed_files_log(unique_id, processed_log)

    # Generate README.md after processing files
    try:
        save_project_readme(unique_id)
    except Exception as e:
        print(f"Failed to generate README for project {unique_id}: {str(e)}")

    return processed_files


# Global assistant manager configuration
max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "20"))

@@ -80,11 +409,17 @@ class Message(BaseModel):
    content: str


class DatasetRequest(BaseModel):
    system_prompt: Optional[str] = None
    mcp_settings: Optional[List[Dict]] = None
    files: Optional[List[str]] = None
    unique_id: Optional[str] = None


class ChatRequest(BaseModel):
    messages: List[Message]
    model: str = "qwen3-next"
    model_server: str = ""
    zip_url: Optional[str] = None
    unique_id: Optional[str] = None
    stream: Optional[bool] = False

@@ -170,13 +505,102 @@ async def generate_stream_response(agent, messages, request) -> AsyncGenerator[s
        yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"


class FileProcessRequest(BaseModel):
    unique_id: str
    files: Optional[List[str]] = None
    system_prompt: Optional[str] = None
    mcp_settings: Optional[List[Dict]] = None

    class Config:
        extra = 'allow'


class FileProcessResponse(BaseModel):
    success: bool
    message: str
    unique_id: str
    processed_files: List[str]


@app.post("/api/v1/files/process")
async def process_files(request: FileProcessRequest, authorization: Optional[str] = Header(None)):
    """
    Process dataset files for a given unique_id

    Args:
        request: FileProcessRequest containing unique_id, files, system_prompt, and mcp_settings
        authorization: Authorization header containing API key (Bearer <API_KEY>)

    Returns:
        FileProcessResponse: Processing result with file list
    """
    try:

        unique_id = request.unique_id
        if not unique_id:
            raise HTTPException(status_code=400, detail="unique_id is required")

        # Process files: only use request.files
        processed_files = []
        if request.files:
            # Use the files from the request
            processed_files = await download_dataset_files(unique_id, request.files)
            print(f"Processed {len(processed_files)} dataset files for unique_id: {unique_id}")
        else:
            print(f"No files provided in request for unique_id: {unique_id}")

        # Resolve the project directory from unique_id
        project_dir = os.path.join("projects", unique_id)
        if not os.path.exists(project_dir):
            raise HTTPException(status_code=400, detail=f"Project directory not found for unique_id: {unique_id}")

        # Collect all document.txt files under the project directory
        document_files = []
        for root, dirs, files in os.walk(project_dir):
            for file in files:
                if file == "document.txt":
                    document_files.append(os.path.join(root, file))

        # Merge all processed files
        all_files = document_files + processed_files

        if not all_files:
            print(f"警告: 项目目录 {project_dir} 中未找到任何 document.txt 文件")

        # Save system_prompt and mcp_settings to the project directory (if provided)
        if request.system_prompt:
            system_prompt_file = os.path.join(project_dir, "system_prompt.md")
            with open(system_prompt_file, 'w', encoding='utf-8') as f:
                f.write(request.system_prompt)
            print(f"Saved system_prompt for unique_id: {unique_id}")

        if request.mcp_settings:
            mcp_settings_file = os.path.join(project_dir, "mcp_settings.json")
            with open(mcp_settings_file, 'w', encoding='utf-8') as f:
                json.dump(request.mcp_settings, f, ensure_ascii=False, indent=2)
            print(f"Saved mcp_settings for unique_id: {unique_id}")

        return FileProcessResponse(
            success=True,
            message=f"Successfully processed {len(all_files)} files",
            unique_id=unique_id,
            processed_files=all_files
        )

    except HTTPException:
        raise
    except Exception as e:
        print(f"Error processing files: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
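
A sketch of how a client might call the new POST /api/v1/files/process endpoint, using aiohttp (which this commit adds as a dependency); host, port, token and payload values are illustrative:

```python
# Hypothetical client call for POST /api/v1/files/process.
import asyncio
import aiohttp

async def trigger_processing() -> None:
    payload = {
        "unique_id": "demo-id",                       # illustrative project id
        "files": ["https://example.com/spec.txt"],    # remote or local paths
        "system_prompt": "You are a helpful assistant.",
        "mcp_settings": [],
    }
    headers = {"Authorization": "Bearer <API_KEY>"}
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://localhost:8000/api/v1/files/process",
            json=payload,
            headers=headers,
        ) as resp:
            print(resp.status, await resp.json())

asyncio.run(trigger_processing())
```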


@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)):
    """
    Chat completions API similar to OpenAI, supports both streaming and non-streaming

    Args:
        request: ChatRequest containing messages, model, zip_url, etc.
        request: ChatRequest containing messages, model, dataset with unique_id, system_prompt, mcp_settings, and files
        authorization: Authorization header containing API key (Bearer <API_KEY>)

    Returns:

@@ -192,39 +616,23 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
    else:
        api_key = authorization

    # Get zip_url and unique_id from the top-level request
    zip_url = request.zip_url
    # Get unique_id
    unique_id = request.unique_id
    if not unique_id:
        raise HTTPException(status_code=400, detail="unique_id is required")

    # If unique_id is provided, read zip_url from unique_map.json
    if unique_id:
        zip_url = get_zip_url_from_unique_id(unique_id)
        if not zip_url:
            raise HTTPException(status_code=400, detail=f"No zip_url found for unique_id: {unique_id}")

    if not zip_url:
        raise HTTPException(status_code=400, detail="zip_url is required")

    # Load project data from the ZIP URL
    print(f"从ZIP URL加载项目: {zip_url}")
    project_dir = zip_handler.get_project_from_zip(zip_url, unique_id if unique_id else None)
    if not project_dir:
        raise HTTPException(status_code=400, detail=f"Failed to load project from ZIP URL: {zip_url}")

    # Collect all document.txt files under the project directory
    document_files = zip_handler.collect_document_files(project_dir)

    if not document_files:
        print(f"警告: 项目目录 {project_dir} 中未找到任何 document.txt 文件")
    # Resolve the project directory from unique_id
    project_dir = os.path.join("projects", unique_id)
    if not os.path.exists(project_dir):
        raise HTTPException(status_code=400, detail=f"Project directory not found for unique_id: {unique_id}")

    # Collect extra parameters as generate_cfg
    exclude_fields = {'messages', 'model', 'model_server', 'zip_url', 'unique_id', 'stream'}
    exclude_fields = {'messages', 'model', 'model_server', 'unique_id', 'stream'}
    generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}

    # Get or create a file-preloaded assistant instance from the global manager
    # Get or create an assistant instance from the global manager (config loading is handled inside agent_manager)
    agent = await agent_manager.get_or_create_agent(
        zip_url=zip_url,
        files=document_files,
        unique_id=unique_id,
        project_dir=project_dir,
        model_name=request.model,
        api_key=api_key,

@@ -322,17 +730,13 @@ async def system_status():

@app.post("/system/cleanup-cache")
async def cleanup_cache():
    """Clear the ZIP file cache and assistant cache"""
    """Clear the assistant cache"""
    try:
        # Clear the ZIP file cache
        zip_handler.cleanup_cache()

        # Clear the assistant instance cache
        cleared_count = agent_manager.clear_cache()

        return {
            "message": "缓存清理成功",
            "cleared_zip_files": True,
            "cleared_agent_instances": cleared_count
        }
    except Exception as e:

@@ -356,11 +760,9 @@ async def cleanup_agent_cache():
async def get_cached_projects():
    """Get information about all cached projects"""
    try:
        cached_urls = agent_manager.list_cached_zip_urls()
        cache_stats = agent_manager.get_cache_stats()

        return {
            "cached_projects": cached_urls,
            "cache_stats": cache_stats
        }
    except Exception as e:

@@ -368,18 +770,108 @@ async def get_cached_projects():


@app.post("/system/remove-project-cache")
async def remove_project_cache(zip_url: str):
async def remove_project_cache(unique_id: str):
    """Remove the cache for a specific project"""
    try:
        success = agent_manager.remove_cache_by_url(zip_url)
        success = agent_manager.remove_cache_by_unique_id(unique_id)
        if success:
            return {"message": f"项目缓存移除成功: {zip_url}"}
            return {"message": f"项目缓存移除成功: {unique_id}"}
        else:
            return {"message": f"未找到项目缓存: {zip_url}", "removed": False}
            return {"message": f"未找到项目缓存: {unique_id}", "removed": False}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}")


@app.get("/api/v1/files/{unique_id}/status")
async def get_files_processing_status(unique_id: str):
    """Get the file-processing status of a project"""
    try:
        # Load processed files log
        processed_log = load_processed_files_log(unique_id)

        # Get project directory info
        project_dir = os.path.join("projects", unique_id)
        project_exists = os.path.exists(project_dir)

        # Collect document.txt files
        document_files = []
        if project_exists:
            for root, dirs, files in os.walk(project_dir):
                for file in files:
                    if file == "document.txt":
                        document_files.append(os.path.join(root, file))

        return {
            "unique_id": unique_id,
            "project_exists": project_exists,
            "processed_files_count": len(processed_log),
            "processed_files": processed_log,
            "document_files_count": len(document_files),
            "document_files": document_files,
            "log_file_exists": os.path.exists(os.path.join("projects", unique_id, "processed_files.json"))
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取文件处理状态失败: {str(e)}")
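
For orientation, a response from GET /api/v1/files/{unique_id}/status might look roughly like this (all values are illustrative; processed_files entries follow the processed_files.json layout shown earlier):

```json
{
  "unique_id": "demo-id",
  "project_exists": true,
  "processed_files_count": 1,
  "processed_files": {
    "9a0364b9e99bb480dd25e1f0284c8555": {
      "original_path": "https://example.com/spec.txt",
      "local_path": "projects/demo-id/files/spec.txt",
      "filename": "spec.txt",
      "processed_at": "2024-01-01 12:00:00.000000",
      "file_type": "remote"
    }
  },
  "document_files_count": 1,
  "document_files": ["projects/demo-id/dataset/spec/document.txt"],
  "log_file_exists": true
}
```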


@app.post("/api/v1/files/{unique_id}/reset")
async def reset_files_processing(unique_id: str):
    """Reset a project's file-processing state: delete the processing log and all files"""
    try:
        project_dir = os.path.join("projects", unique_id)
        log_file = os.path.join("projects", unique_id, "processed_files.json")

        # Load processed log to know what files to remove
        processed_log = load_processed_files_log(unique_id)

        removed_files = []
        # Remove all processed files and their dataset directories
        for file_hash, file_info in processed_log.items():
            # Remove local file in files directory
            if 'local_path' in file_info:
                if remove_file_or_directory(file_info['local_path']):
                    removed_files.append(file_info['local_path'])

            # Remove the entire dataset directory for this file
            if 'filename' in file_info:
                filename_without_ext = os.path.splitext(file_info['filename'])[0]
                dataset_dir = os.path.join("projects", unique_id, "dataset", filename_without_ext)
                if remove_file_or_directory(dataset_dir):
                    removed_files.append(dataset_dir)

            # Also remove any specific dataset path if exists (fallback)
            if 'dataset_path' in file_info:
                if remove_file_or_directory(file_info['dataset_path']):
                    removed_files.append(file_info['dataset_path'])

        # Remove the log file
        if remove_file_or_directory(log_file):
            removed_files.append(log_file)

        # Remove the entire files directory
        files_dir = os.path.join(project_dir, "files")
        if remove_file_or_directory(files_dir):
            removed_files.append(files_dir)

        # Also remove the entire dataset directory (clean up any remaining files)
        dataset_dir = os.path.join(project_dir, "dataset")
        if remove_file_or_directory(dataset_dir):
            removed_files.append(dataset_dir)

        # Remove README.md if exists
        readme_file = os.path.join(project_dir, "README.md")
        if remove_file_or_directory(readme_file):
            removed_files.append(readme_file)

        return {
            "message": f"文件处理状态重置成功: {unique_id}",
            "removed_files_count": len(removed_files),
            "removed_files": removed_files
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"重置文件处理状态失败: {str(e)}")


if __name__ == "__main__":

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""File-preloaded assistant manager - manages assistant instance caching keyed by ZIP URL"""
"""File-preloaded assistant manager - manages assistant instance caching keyed by unique_id"""

import hashlib
import time

@@ -27,20 +27,19 @@ from gbase_agent import init_agent_service_with_files, update_agent_llm
class FileLoadedAgentManager:
    """File-preloaded assistant manager

    Caches assistant instances keyed by ZIP URL to avoid repeated creation and file parsing
    Caches assistant instances keyed by unique_id to avoid repeated creation and file parsing
    """

    def __init__(self, max_cached_agents: int = 20):
        self.agents: Dict[str, Assistant] = {}  # {zip_url_hash: assistant_instance}
        self.zip_urls: Dict[str, str] = {}  # {zip_url_hash: original_zip_url}
        self.agents: Dict[str, Assistant] = {}  # {unique_id: assistant_instance}
        self.unique_ids: Dict[str, str] = {}  # {cache_key: unique_id}
        self.access_times: Dict[str, float] = {}  # LRU access-time tracking
        self.creation_times: Dict[str, float] = {}  # creation-time record
        self.file_counts: Dict[str, int] = {}  # number of cached files
        self.max_cached_agents = max_cached_agents

    def _get_zip_url_hash(self, zip_url: str) -> str:
        """Use the hash of the ZIP URL as the cache key"""
        return hashlib.md5(zip_url.encode('utf-8')).hexdigest()[:16]
    def _get_cache_key(self, unique_id: str) -> str:
        """Use the hash of the unique_id as the cache key"""
        return hashlib.md5(unique_id.encode('utf-8')).hexdigest()[:16]

    def _update_access_time(self, cache_key: str):
        """Update the access time (LRU management)"""

@@ -60,10 +59,9 @@ class FileLoadedAgentManager:
        for cache_key in keys_to_remove:
            try:
                del self.agents[cache_key]
                del self.zip_urls[cache_key]
                del self.unique_ids[cache_key]
                del self.access_times[cache_key]
                del self.creation_times[cache_key]
                del self.file_counts[cache_key]
                removed_count += 1
                logger.info(f"清理过期的助手实例缓存: {cache_key}")
            except KeyError:

@@ -73,8 +71,7 @@ class FileLoadedAgentManager:
        logger.info(f"已清理 {removed_count} 个过期的助手实例缓存")

    async def get_or_create_agent(self,
                                  zip_url: str,
                                  files: List[str],
                                  unique_id: str,
                                  project_dir: str,
                                  model_name: str = "qwen3-next",
                                  api_key: Optional[str] = None,

@@ -83,7 +80,7 @@ class FileLoadedAgentManager:
        """Get or create a file-preloaded assistant instance

        Args:
            zip_url: URL of the ZIP file
            unique_id: unique identifier of the project
            files: list of file paths to preload
            project_dir: project directory path, used to read system_prompt.md and mcp_settings.json
            model_name: model name

@@ -97,12 +94,16 @@ class FileLoadedAgentManager:
        import os
        import json

        # Read system_prompt.md and mcp_settings.json from the project directory
        # Read system_prompt: prefer the project directory, then fall back to the global config
        # Fall back to the global config
        system_prompt_template = ""
        system_prompt_path = os.path.join(project_dir, "system_prompt.md")
        if os.path.exists(system_prompt_path):
            with open(system_prompt_path, "r", encoding="utf-8") as f:
                system_prompt_template = f.read().strip()
        # Try to read from the project directory
        system_prompt_file = os.path.join(project_dir, "system_prompt.md")
        if not os.path.exists(system_prompt_file):
            system_prompt_file = "./system_prompt.md"

        with open(system_prompt_file, "r", encoding="utf-8") as f:
            system_prompt_template = f.read().strip()

        readme = ""
        readme_path = os.path.join(project_dir, "README.md")

@@ -110,55 +111,69 @@ class FileLoadedAgentManager:
            with open(readme_path, "r", encoding="utf-8") as f:
                readme = f.read().strip()
        dataset_dir = os.path.join(project_dir, "dataset")

        system_prompt = system_prompt_template.replace("{dataset_dir}", str(dataset_dir)).replace("{readme}", str(readme))

        final_system_prompt = system_prompt_template.replace("{dataset_dir}", str(dataset_dir)).replace("{readme}", str(readme))
        logger.info(f"Loaded global system_prompt for unique_id: {unique_id}")
        if not final_system_prompt:
            logger.info(f"No system_prompt found for unique_id: {unique_id}")

        mcp_settings = {}
        mcp_settings_path = os.path.join(project_dir, "mcp_settings.json")
        if os.path.exists(mcp_settings_path):
            with open(mcp_settings_path, "r", encoding="utf-8") as f:
                mcp_settings = json.load(f)
        # Read mcp_settings: prefer the project directory, then fall back to the global config
        final_mcp_settings = None

        cache_key = self._get_zip_url_hash(zip_url)
        # Try to read from the project directory
        mcp_settings_file = os.path.join(project_dir, "mcp_settings.json")
        if os.path.exists(mcp_settings_file):
            with open(mcp_settings_file, 'r', encoding='utf-8') as f:
                final_mcp_settings = json.load(f)
            logger.info(f"Loaded mcp_settings from project directory for unique_id: {unique_id}")
        else:
            # Fall back to the global config
            mcp_settings_path = "./mcp/mcp_settings.json"
            if os.path.exists(mcp_settings_path):
                with open(mcp_settings_path, "r", encoding="utf-8") as f:
                    final_mcp_settings = json.load(f)
                logger.info(f"Loaded global mcp_settings for unique_id: {unique_id}")
            else:
                final_mcp_settings = []
                logger.info(f"No mcp_settings found for unique_id: {unique_id}")

        if final_mcp_settings is None:
            final_mcp_settings = []

        cache_key = self._get_cache_key(unique_id)

        # Check whether an assistant instance already exists
        if cache_key in self.agents:
            self._update_access_time(cache_key)
            agent = self.agents[cache_key]

            # Dynamically update the LLM config (if parameters changed)
            update_agent_llm(agent, model_name, api_key, model_server, generate_cfg)
            # Dynamically update the LLM config and system settings (if parameters changed)
            update_agent_llm(agent, model_name, api_key, model_server, generate_cfg, final_system_prompt, final_mcp_settings)

            # If a system_prompt was read from the project directory, update the agent's system message
            if system_prompt:
                agent.system_message = system_prompt

            logger.info(f"复用现有的助手实例缓存: {cache_key} (文件数: {len(files)})")
            logger.info(f"复用现有的助手实例缓存: {cache_key} (unique_id: {unique_id})")
            return agent

        # Clean up expired instances
        self._cleanup_old_agents()

        # Create a new assistant instance and preload the files
        logger.info(f"创建新的助手实例缓存: {cache_key}, 预加载文件数: {len(files)}")
        logger.info(f"创建新的助手实例缓存: {cache_key}, unique_id: {unique_id}")
        current_time = time.time()

        agent = init_agent_service_with_files(
            files=files,
            model_name=model_name,
            api_key=api_key,
            model_server=model_server,
            generate_cfg=generate_cfg,
            system_prompt=system_prompt,
            mcp=mcp_settings
            system_prompt=final_system_prompt,
            mcp=final_mcp_settings
        )

        # Cache the instance
        self.agents[cache_key] = agent
        self.zip_urls[cache_key] = zip_url
        self.unique_ids[cache_key] = unique_id
        self.access_times[cache_key] = current_time
        self.creation_times[cache_key] = current_time
        self.file_counts[cache_key] = len(files)

        logger.info(f"助手实例缓存创建完成: {cache_key}")
        return agent

@@ -174,8 +189,7 @@ class FileLoadedAgentManager:

        for cache_key, agent in self.agents.items():
            stats["agents"][cache_key] = {
                "zip_url": self.zip_urls.get(cache_key, "unknown"),
                "file_count": self.file_counts.get(cache_key, 0),
                "unique_id": self.unique_ids.get(cache_key, "unknown"),
                "created_at": self.creation_times.get(cache_key, 0),
                "last_accessed": self.access_times.get(cache_key, 0),
                "age_seconds": int(current_time - self.creation_times.get(cache_key, current_time)),

@@ -193,40 +207,34 @@ class FileLoadedAgentManager:
        cache_count = len(self.agents)

        self.agents.clear()
        self.zip_urls.clear()
        self.unique_ids.clear()
        self.access_times.clear()
        self.creation_times.clear()
        self.file_counts.clear()

        logger.info(f"已清空所有助手实例缓存,共清理 {cache_count} 个实例")
        return cache_count

    def remove_cache_by_url(self, zip_url: str) -> bool:
        """Remove a specific cache entry by ZIP URL
    def remove_cache_by_unique_id(self, unique_id: str) -> bool:
        """Remove a specific cache entry by unique_id

        Args:
            zip_url: ZIP file URL
            unique_id: unique identifier of the project

        Returns:
            bool: whether removal succeeded
        """
        cache_key = self._get_zip_url_hash(zip_url)
        cache_key = self._get_cache_key(unique_id)

        if cache_key in self.agents:
            del self.agents[cache_key]
            del self.zip_urls[cache_key]
            del self.unique_ids[cache_key]
            del self.access_times[cache_key]
            del self.creation_times[cache_key]
            del self.file_counts[cache_key]

            logger.info(f"已移除特定 ZIP URL 的助手实例缓存: {zip_url}")
            logger.info(f"已移除特定 unique_id 的助手实例缓存: {unique_id}")
            return True

        return False

    def list_cached_zip_urls(self) -> List[str]:
        """List all cached ZIP URLs"""
        return list(self.zip_urls.values())


# Global file-preloaded assistant manager instance

@@ -103,7 +103,7 @@ def init_agent_service_universal():
    return init_agent_service_with_files(files=None)


def init_agent_service_with_files(files: Optional[List[str]] = None, rag_cfg: Optional[Dict] = None,
def init_agent_service_with_files(rag_cfg: Optional[Dict] = None,
                                  model_name: str = "qwen3-next", api_key: Optional[str] = None,
                                  model_server: Optional[str] = None, generate_cfg: Optional[Dict] = None,
                                  system_prompt: Optional[str] = None, mcp: Optional[List[Dict]] = None):

@@ -160,8 +160,8 @@ def init_agent_service_with_files(files: Optional[List[str]] = None, rag_cfg: Op
    return bot


def update_agent_llm(agent, model_name: str, api_key: str = None, model_server: str = None,generate_cfg: Dict = None):
    """Dynamically update the assistant instance's LLM, supporting parameters passed in from the API"""
def update_agent_llm(agent, model_name: str, api_key: str = None, model_server: str = None, generate_cfg: Dict = None, system_prompt: str = None, mcp_settings: List[Dict] = None):
    """Dynamically update the assistant instance's LLM and configuration, supporting parameters passed in from the API"""

    # Get the base configuration
    llm_config = {

@@ -181,6 +181,14 @@ def update_agent_llm(agent, model_name: str, api_key: str = None, model_server:
    # Set the LLM dynamically
    agent.llm = llm_instance

    # Update the system message (if provided)
    if system_prompt:
        agent.system_message = system_prompt

    # Update the MCP settings (if provided)
    if mcp_settings:
        agent.function_list = mcp_settings

    return agent


@@ -5,10 +5,10 @@
        "command": "mcp-ripgrep",
        "args": []
    },
    "json-reader": {
    "semantic_search": {
        "command": "python",
        "args": [
            "./mcp/json_reader_server.py"
            "./mcp/semantic_search_server.py"
        ]
    },
    "multi-keyword-search": {

@@ -125,13 +125,20 @@ def multi_keyword_search(keywords: List[str], file_paths: List[str],
        try:
            # Resolve relative paths
            if not os.path.isabs(file_path):
                # Strip the projects/ prefix (if present)
                clean_path = file_path
                if clean_path.startswith('projects/'):
                    clean_path = clean_path[9:]  # strip the 'projects/' prefix
                elif clean_path.startswith('./projects/'):
                    clean_path = clean_path[11:]  # strip the './projects/' prefix

                # Try to find the file in the project directory
                full_path = os.path.join(project_data_dir, file_path.lstrip('./'))
                full_path = os.path.join(project_data_dir, clean_path.lstrip('./'))
                if os.path.exists(full_path):
                    valid_paths.append(full_path)
                else:
                    # If the direct path doesn't exist, try a recursive search
                    found = find_file_in_project(file_path, project_data_dir)
                    found = find_file_in_project(clean_path, project_data_dir)
                    if found:
                        valid_paths.append(found)
                    else:

@@ -87,11 +87,18 @@ def semantic_search(query: str, embeddings_file: str, top_k: int = 20) -> Dict[s
    try:
        # Resolve relative paths
        if not os.path.isabs(embeddings_file):
            # Strip the projects/ prefix (if present)
            clean_path = embeddings_file
            if clean_path.startswith('projects/'):
                clean_path = clean_path[9:]  # strip the 'projects/' prefix
            elif clean_path.startswith('./projects/'):
                clean_path = clean_path[11:]  # strip the './projects/' prefix

            # Try to find the file in the project directory
            full_path = os.path.join(project_data_dir, embeddings_file.lstrip('./'))
            full_path = os.path.join(project_data_dir, clean_path.lstrip('./'))
            if not os.path.exists(full_path):
                # If the direct path doesn't exist, try a recursive search
                found = find_file_in_project(embeddings_file, project_data_dir)
                found = find_file_in_project(clean_path, project_data_dir)
                if found:
                    embeddings_file = found
                else:

organize_dataset_files.py (new file, 182 lines)

@@ -0,0 +1,182 @@
#!/usr/bin/env python3
import os
import shutil
from pathlib import Path

def is_file_already_processed(target_file: Path, pagination_file: Path, embeddings_file: Path) -> bool:
    """Check if a file has already been processed (document.txt, pagination.txt, and embeddings exist)"""
    if not target_file.exists():
        return False

    # Check if pagination and embeddings files exist and are not empty
    if pagination_file.exists() and embeddings_file.exists():
        # Check file sizes to ensure they're not empty
        if pagination_file.stat().st_size > 0 and embeddings_file.stat().st_size > 0:
            return True

    return False

def organize_single_project_files(unique_id: str, skip_processed=True):
    """Organize files for a single project from projects/{unique_id}/files to projects/{unique_id}/dataset/{file_name}/document.txt"""

    project_dir = Path("projects") / unique_id

    if not project_dir.exists():
        print(f"Project directory not found: {project_dir}")
        return

    print(f"Organizing files for project: {unique_id} (skip_processed={skip_processed})")

    files_dir = project_dir / "files"
    dataset_dir = project_dir / "dataset"

    # Check if files directory exists and has files
    if not files_dir.exists():
        print(f" No files directory found, skipping...")
        return

    files = list(files_dir.glob("*"))
    if not files:
        print(f" Files directory is empty, skipping...")
        return

    # Create dataset directory if it doesn't exist
    dataset_dir.mkdir(exist_ok=True)

    # Copy each file to its own directory
    for file_path in files:
        if file_path.is_file():
            # Get filename without extension as directory name
            file_name_without_ext = file_path.stem
            target_dir = dataset_dir / file_name_without_ext
            target_file = target_dir / "document.txt"
            pagination_file = target_dir / "pagination.txt"
            embeddings_file = target_dir / "document_embeddings.pkl"

            # Check if file is already processed
            if skip_processed and is_file_already_processed(target_file, pagination_file, embeddings_file):
                print(f" Skipping already processed file: {file_path.name}")
                continue

            print(f" Copying {file_path.name} -> {target_file.relative_to(project_dir)}")

            # Create target directory
            target_dir.mkdir(exist_ok=True)

            # Copy and rename file
            shutil.copy2(str(file_path), str(target_file))

    print(f" Files remain in original location (copied to dataset structure)")

    # Process each document.txt file: split pages and generate embeddings
    if not skip_processed:
        import sys
        sys.path.append(os.path.join(os.path.dirname(__file__), 'embedding'))

        from embedding import split_document_by_pages, embed_document

        for file_path in files:
            if file_path.is_file():
                file_name_without_ext = file_path.stem
                target_dir = dataset_dir / file_name_without_ext
                document_file = target_dir / "document.txt"
                pagination_file = target_dir / "pagination.txt"
                embeddings_file = target_dir / "document_embeddings.pkl"

                # Skip if already processed
                if is_file_already_processed(document_file, pagination_file, embeddings_file):
                    print(f" Skipping document processing for already processed file: {file_path.name}")
                    continue

                # Split document by pages
                print(f" Splitting pages for {document_file.name}")
                try:
                    pages = split_document_by_pages(str(document_file), str(pagination_file))
                    print(f" Generated {len(pages)} pages")
                except Exception as e:
                    print(f" Failed to split pages: {e}")
                    continue

                # Generate embeddings
                print(f" Generating embeddings for {document_file.name}")
                try:
                    # Set local model path for faster processing
                    local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
                    if not os.path.exists(local_model_path):
                        local_model_path = None  # Fallback to HuggingFace model

                    embedding_data = embed_document(
                        str(document_file),
                        str(embeddings_file),
                        chunking_strategy='smart',
                        model_path=local_model_path,
                        max_chunk_size=800,
                        overlap=100
                    )

                    if embedding_data:
                        print(f" Generated embeddings for {len(embedding_data['chunks'])} chunks")
                    else:
                        print(f" Failed to generate embeddings")
                except Exception as e:
                    print(f" Failed to generate embeddings: {e}")

        print(f" Document processing completed for project {unique_id}")
    else:
        print(f" Skipping document processing (skip_processed=True)")


def organize_dataset_files():
    """Move files from projects/{unique_id}/files to projects/{unique_id}/dataset/{file_name}/document.txt"""

    projects_dir = Path("projects")

    if not projects_dir.exists():
        print("Projects directory not found")
        return

    # Get all project directories (exclude cache and other non-project dirs)
    project_dirs = [d for d in projects_dir.iterdir()
                    if d.is_dir() and d.name != "_cache" and not d.name.startswith(".")]

    for project_dir in project_dirs:
        print(f"\nProcessing project: {project_dir.name}")

        files_dir = project_dir / "files"
        dataset_dir = project_dir / "dataset"

        # Check if files directory exists and has files
        if not files_dir.exists():
            print(f" No files directory found, skipping...")
            continue

        files = list(files_dir.glob("*"))
        if not files:
            print(f" Files directory is empty, skipping...")
            continue

        # Create dataset directory if it doesn't exist
        dataset_dir.mkdir(exist_ok=True)

        # Move each file to its own directory
        for file_path in files:
            if file_path.is_file():
                # Get filename without extension as directory name
                file_name_without_ext = file_path.stem
                target_dir = dataset_dir / file_name_without_ext
                target_file = target_dir / "document.txt"

                print(f" Copying {file_path.name} -> {target_file.relative_to(project_dir)}")

                # Create target directory
                target_dir.mkdir(exist_ok=True)

                # Copy and rename file
                shutil.copy2(str(file_path), str(target_file))

        print(f" Files remain in original location (copied to dataset structure)")

    print("\nFile organization complete!")

if __name__ == "__main__":
    organize_dataset_files()
poetry.lock (generated, 14 changed lines)
@@ -1,5 +1,17 @@
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.

[[package]]
name = "aiofiles"
version = "25.1.0"
description = "File support for asyncio."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
    {file = "aiofiles-25.1.0-py3-none-any.whl", hash = "sha256:abe311e527c862958650f9438e859c1fa7568a141b22abcd015e120e86a85695"},
    {file = "aiofiles-25.1.0.tar.gz", hash = "sha256:a8d728f0a29de45dc521f18f07297428d56992a742f0cd2701ba86e44d23d5b2"},
]

[[package]]
name = "aiohappyeyeballs"
version = "2.6.1"

@@ -3949,4 +3961,4 @@ propcache = ">=0.2.1"
[metadata]
lock-version = "2.1"
python-versions = "3.12.0"
content-hash = "e3237e20c799a13a0854f786a34a7c7cb2d5020902281c92ebff2e497492edd1"
content-hash = "06c3b78c8107692eb5944b144ae4df02862fa5e4e8a198f6ccfa07c6743a49cf"

@@ -18,6 +18,8 @@ dependencies = [
    "transformers",
    "sentence-transformers",
    "numpy<2",
    "aiohttp",
    "aiofiles",
]


@@ -1,5 +0,0 @@
{
    "b743ccc3-13be-43ea-8ec9-4ce9c86103b3": [
        "public/all_hp_product_spec_book2506.txt"
    ]
}