"""
|
||
RAGFlow Repository - 数据访问层
|
||
封装 RAGFlow SDK 调用,提供统一的数据访问接口
|
||
"""
|
||
import logging
|
||
import asyncio
|
||
from typing import Optional, List, Dict, Any
|
||
from pathlib import Path
|
||
|
||
try:
|
||
from ragflow_sdk import RAGFlow
|
||
except ImportError:
|
||
RAGFlow = None
|
||
logging.warning("ragflow-sdk not installed")
|
||
|
||
from utils.settings import (
|
||
RAGFLOW_API_URL,
|
||
RAGFLOW_API_KEY,
|
||
RAGFLOW_CONNECTION_TIMEOUT
|
||
)
|
||
|
||
logger = logging.getLogger('app')
|
||
|
||
|
||
class RAGFlowRepository:
|
||
"""
|
||
RAGFlow 数据仓储类
|
||
|
||
封装 RAGFlow SDK 的所有调用,提供:
|
||
- 统一的错误处理
|
||
- 连接管理
|
||
- 数据转换
|
||
"""
|
||
|
||
def __init__(self, api_key: str = None, base_url: str = None):
|
||
"""
|
||
初始化 RAGFlow 客户端
|
||
|
||
Args:
|
||
api_key: RAGFlow API Key,默认从配置读取
|
||
base_url: RAGFlow 服务地址,默认从配置读取
|
||
"""
|
||
self.api_key = api_key or RAGFLOW_API_KEY
|
||
self.base_url = base_url or RAGFLOW_API_URL
|
||
self._client: Optional[Any] = None
|
||
self._lock = asyncio.Lock()
|
||
|
||
    async def _get_client(self):
        """
        Get the RAGFlow client instance (lazily initialized).

        Returns:
            The RAGFlow client.
        """
        if RAGFlow is None:
            raise RuntimeError("ragflow-sdk is not installed. Run: poetry install")

        if self._client is None:
            async with self._lock:
                # Double-checked locking: re-check once the lock is held
                if self._client is None:
                    try:
                        self._client = RAGFlow(
                            api_key=self.api_key,
                            base_url=self.base_url
                        )
                        logger.info(f"RAGFlow client initialized: {self.base_url}")
                    except Exception as e:
                        logger.error(f"Failed to initialize RAGFlow client: {e}")
                        raise

        return self._client

    async def create_dataset(
        self,
        name: str,
        description: str = None,
        chunk_method: str = "naive",
        permission: str = "me"
    ) -> Dict[str, Any]:
        """
        Create a dataset.

        Args:
            name: Dataset name.
            description: Dataset description.
            chunk_method: Chunking method (naive, manual, qa, table, paper, book, etc.).
            permission: Access permission ("me" or "team").

        Returns:
            Information about the created dataset.
        """
        client = await self._get_client()

        try:
            dataset = client.create_dataset(
                name=name,
                avatar=None,
                description=description,
                chunk_method=chunk_method,
                permission=permission
            )

            return {
                "dataset_id": getattr(dataset, 'id', None),
                "name": getattr(dataset, 'name', name),
                "description": getattr(dataset, 'description', description),
                "chunk_method": getattr(dataset, 'chunk_method', chunk_method),
                "permission": getattr(dataset, 'permission', permission),
                "created_at": getattr(dataset, 'created_at', None),
                "updated_at": getattr(dataset, 'updated_at', None),
            }
        except Exception as e:
            logger.error(f"Failed to create dataset: {e}")
            raise

    async def list_datasets(
        self,
        page: int = 1,
        page_size: int = 30,
        search: str = None
    ) -> Dict[str, Any]:
        """
        List datasets.

        Args:
            page: Page number.
            page_size: Number of items per page.
            search: Search keyword.

        Returns:
            The dataset list and pagination info.
        """
        client = await self._get_client()

        try:
            # RAGFlow SDK list_datasets call
            datasets = client.list_datasets(
                page=page,
                page_size=page_size
            )

            items = []
            for dataset in datasets:
                dataset_info = {
                    "dataset_id": getattr(dataset, 'id', None),
                    "name": getattr(dataset, 'name', None),
                    "description": getattr(dataset, 'description', None),
                    "chunk_method": getattr(dataset, 'chunk_method', None),
                    "avatar": getattr(dataset, 'avatar', None),
                    "permission": getattr(dataset, 'permission', None),
                    "created_at": getattr(dataset, 'created_at', None),
                    "updated_at": getattr(dataset, 'updated_at', None),
                    "metadata": getattr(dataset, 'metadata', {}),
                }

                # Keyword filter on name and description
                if search:
                    search_lower = search.lower()
                    if (search_lower not in (dataset_info.get('name') or '').lower() and
                            search_lower not in (dataset_info.get('description') or '').lower()):
                        continue

                items.append(dataset_info)

            return {
                "items": items,
                "total": len(items),  # RAGFlow may not report a grand total; use the returned count
                "page": page,
                "page_size": page_size
            }
        except Exception as e:
            logger.error(f"Failed to list datasets: {e}")
            raise

    async def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]:
        """
        Get dataset details.

        Args:
            dataset_id: Dataset ID.

        Returns:
            Dataset details, or None if the dataset does not exist.
        """
        client = await self._get_client()

        try:
            datasets = client.list_datasets(id=dataset_id)
            if datasets and len(datasets) > 0:
                dataset = datasets[0]
                return {
                    "dataset_id": getattr(dataset, 'id', dataset_id),
                    "name": getattr(dataset, 'name', None),
                    "description": getattr(dataset, 'description', None),
                    "chunk_method": getattr(dataset, 'chunk_method', None),
                    "permission": getattr(dataset, 'permission', None),
                    "created_at": getattr(dataset, 'created_at', None),
                    "updated_at": getattr(dataset, 'updated_at', None),
                }
            return None
        except Exception as e:
            logger.error(f"Failed to get dataset {dataset_id}: {e}")
            raise

    async def update_dataset(
        self,
        dataset_id: str,
        **updates
    ) -> Optional[Dict[str, Any]]:
        """
        Update a dataset.

        Args:
            dataset_id: Dataset ID.
            **updates: Fields to update.

        Returns:
            The updated dataset information, or None if the dataset does not exist.
        """
        client = await self._get_client()

        try:
            datasets = client.list_datasets(id=dataset_id)
            if datasets and len(datasets) > 0:
                dataset = datasets[0]
                # Apply the changes through the SDK update call
                dataset.update(updates)

                return {
                    "dataset_id": getattr(dataset, 'id', dataset_id),
                    "name": getattr(dataset, 'name', None),
                    "description": getattr(dataset, 'description', None),
                    "chunk_method": getattr(dataset, 'chunk_method', None),
                    "updated_at": getattr(dataset, 'updated_at', None),
                }
            return None
        except Exception as e:
            logger.error(f"Failed to update dataset {dataset_id}: {e}")
            raise

    async def delete_datasets(self, dataset_ids: List[str] = None) -> bool:
        """
        Delete datasets.

        Args:
            dataset_ids: IDs of the datasets to delete.

        Returns:
            Whether the deletion succeeded.
        """
        client = await self._get_client()

        try:
            if dataset_ids:
                client.delete_datasets(ids=dataset_ids)
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to delete datasets: {e}")
            raise

    async def upload_document(
        self,
        dataset_id: str,
        file_name: str,
        file_content: bytes,
        display_name: str = None
    ) -> Dict[str, Any]:
        """
        Upload a document to a dataset.

        Args:
            dataset_id: Dataset ID.
            file_name: File name.
            file_content: File content.
            display_name: Display name.

        Returns:
            Information about the uploaded document.
        """
        client = await self._get_client()

        try:
            # Fetch the target dataset
            datasets = client.list_datasets(id=dataset_id)
            if not datasets or len(datasets) == 0:
                raise ValueError(f"Dataset {dataset_id} not found")

            dataset = datasets[0]

            # Upload the document
            display_name = display_name or file_name
            dataset.upload_documents([{
                "display_name": display_name,
                "blob": file_content
            }])

            # Locate the document that was just uploaded
            documents = dataset.list_documents()
            for doc in documents:
                if getattr(doc, 'name', None) == display_name:
                    return {
                        "document_id": getattr(doc, 'id', None),
                        "name": display_name,
                        "dataset_id": dataset_id,
                        "size": len(file_content),
                        "status": "running",
                        "chunk_count": getattr(doc, 'chunk_count', 0),
                        "token_count": getattr(doc, 'token_count', 0),
                        "created_at": getattr(doc, 'created_at', None),
                    }

            return {
                "document_id": None,
                "name": display_name,
                "dataset_id": dataset_id,
                "size": len(file_content),
                "status": "uploaded",
            }
        except Exception as e:
            logger.error(f"Failed to upload document to {dataset_id}: {e}")
            raise

    async def list_documents(
        self,
        dataset_id: str,
        page: int = 1,
        page_size: int = 20
    ) -> Dict[str, Any]:
        """
        List documents in a dataset.

        Args:
            dataset_id: Dataset ID.
            page: Page number.
            page_size: Number of items per page.

        Returns:
            The document list and pagination info.
        """
        client = await self._get_client()

        try:
            datasets = client.list_datasets(id=dataset_id)
            if not datasets or len(datasets) == 0:
                return {"items": [], "total": 0, "page": page, "page_size": page_size}

            dataset = datasets[0]
            documents = dataset.list_documents(
                page=page,
                page_size=page_size
            )

            items = []
            for doc in documents:
                items.append({
                    "document_id": getattr(doc, 'id', None),
                    "name": getattr(doc, 'name', None),
                    "dataset_id": dataset_id,
                    "size": getattr(doc, 'size', 0),
                    "status": getattr(doc, 'run', 'unknown'),
                    "progress": getattr(doc, 'progress', 0),
                    "chunk_count": getattr(doc, 'chunk_count', 0),
                    "token_count": getattr(doc, 'token_count', 0),
                    "created_at": getattr(doc, 'created_at', None),
                    "updated_at": getattr(doc, 'updated_at', None),
                })

            return {
                "items": items,
                "total": len(items),
                "page": page,
                "page_size": page_size
            }
        except Exception as e:
            logger.error(f"Failed to list documents for {dataset_id}: {e}")
            raise

    async def delete_document(
        self,
        dataset_id: str,
        document_id: str
    ) -> bool:
        """
        Delete a document.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.

        Returns:
            Whether the deletion succeeded.
        """
        client = await self._get_client()

        try:
            datasets = client.list_datasets(id=dataset_id)
            if datasets and len(datasets) > 0:
                dataset = datasets[0]
                dataset.delete_documents(ids=[document_id])
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to delete document {document_id}: {e}")
            raise

    async def list_chunks(
        self,
        dataset_id: str,
        document_id: str = None,
        page: int = 1,
        page_size: int = 50
    ) -> Dict[str, Any]:
        """
        List chunks.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID (optional).
            page: Page number.
            page_size: Number of items per page.

        Returns:
            The chunk list and pagination info.
        """
        client = await self._get_client()

        try:
            datasets = client.list_datasets(id=dataset_id)
            if not datasets or len(datasets) == 0:
                return {"items": [], "total": 0, "page": page, "page_size": page_size}

            dataset = datasets[0]

            # If a document ID is given, fetch that document first
            if document_id:
                documents = dataset.list_documents(id=document_id)
                if documents and len(documents) > 0:
                    doc = documents[0]
                    chunks = doc.list_chunks(page=page, page_size=page_size)
                else:
                    chunks = []
            else:
                # Otherwise collect chunks from every document in the dataset
                chunks = []
                for doc in dataset.list_documents():
                    chunks.extend(doc.list_chunks(page=page, page_size=page_size))

            items = []
            for chunk in chunks:
                items.append({
                    "chunk_id": getattr(chunk, 'id', None),
                    "content": getattr(chunk, 'content', ''),
                    "document_id": getattr(chunk, 'document_id', None),
                    "dataset_id": dataset_id,
                    "position": getattr(chunk, 'position', 0),
                    "important_keywords": getattr(chunk, 'important_keywords', []),
                    "available": getattr(chunk, 'available', True),
                    "created_at": getattr(chunk, 'create_time', None),
                })

            return {
                "items": items,
                "total": len(items),
                "page": page,
                "page_size": page_size
            }
        except Exception as e:
            logger.error(f"Failed to list chunks for {dataset_id}: {e}")
            raise

    async def delete_chunk(
        self,
        dataset_id: str,
        document_id: str,
        chunk_id: str
    ) -> bool:
        """
        Delete a chunk.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.
            chunk_id: Chunk ID.

        Returns:
            Whether the deletion succeeded.
        """
        client = await self._get_client()

        try:
            datasets = client.list_datasets(id=dataset_id)
            if datasets and len(datasets) > 0:
                dataset = datasets[0]
                documents = dataset.list_documents(id=document_id)
                if documents and len(documents) > 0:
                    doc = documents[0]
                    doc.delete_chunks(chunk_ids=[chunk_id])
                    return True
            return False
        except Exception as e:
            logger.error(f"Failed to delete chunk {chunk_id}: {e}")
            raise

    async def parse_document(
        self,
        dataset_id: str,
        document_id: str
    ) -> bool:
        """
        Start parsing a document.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.

        Returns:
            Whether the operation succeeded.
        """
        client = await self._get_client()

        try:
            datasets = client.list_datasets(id=dataset_id)
            if not datasets or len(datasets) == 0:
                raise ValueError(f"Dataset {dataset_id} not found")

            dataset = datasets[0]
            dataset.async_parse_documents([document_id])
            return True
        except Exception as e:
            logger.error(f"Failed to parse document {document_id}: {e}")
            raise

    async def cancel_parse_document(
        self,
        dataset_id: str,
        document_id: str
    ) -> bool:
        """
        Cancel parsing a document.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.

        Returns:
            Whether the operation succeeded.
        """
        client = await self._get_client()

        try:
            datasets = client.list_datasets(id=dataset_id)
            if not datasets or len(datasets) == 0:
                raise ValueError(f"Dataset {dataset_id} not found")

            dataset = datasets[0]
            dataset.async_cancel_parse_documents([document_id])
            return True
        except Exception as e:
            logger.error(f"Failed to cancel parse document {document_id}: {e}")
            raise
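

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the repository API): shows the
# intended call flow — create a dataset, upload a document, trigger parsing,
# then list documents/chunks and clean up. The dataset name, file name, and
# file bytes below are hypothetical placeholders, and the sketch assumes a
# reachable RAGFlow instance configured via utils.settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    async def _demo() -> None:
        repo = RAGFlowRepository()

        # Create a throwaway dataset for the demo.
        dataset = await repo.create_dataset(name="demo-dataset", description="smoke test")
        dataset_id = dataset["dataset_id"]

        # Upload a tiny text document and kick off parsing.
        doc = await repo.upload_document(
            dataset_id=dataset_id,
            file_name="hello.txt",
            file_content=b"Hello, RAGFlow!",
        )
        if doc.get("document_id"):
            await repo.parse_document(dataset_id, doc["document_id"])

        # Inspect what the dataset now contains.
        documents = await repo.list_documents(dataset_id)
        chunks = await repo.list_chunks(dataset_id)
        logger.info("documents=%s chunks=%s", documents["total"], chunks["total"])

        # Remove the demo dataset.
        await repo.delete_datasets([dataset_id])

    asyncio.run(_demo())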