# qwen_agent/repositories/ragflow_repository.py
"""
RAGFlow Repository - 数据访问层
封装 RAGFlow SDK 调用,提供统一的数据访问接口
"""
import logging
import asyncio
from typing import Optional, List, Dict, Any
from pathlib import Path
try:
from ragflow_sdk import RAGFlow
except ImportError:
RAGFlow = None
logging.warning("ragflow-sdk not installed")
from utils.settings import (
RAGFLOW_API_URL,
RAGFLOW_API_KEY,
RAGFLOW_CONNECTION_TIMEOUT
)
logger = logging.getLogger('app')
class RAGFlowRepository:
    """
    RAGFlow data repository.

    Wraps all RAGFlow SDK calls and provides:
    - unified error handling
    - connection management
    - data conversion
    """

    def __init__(self, api_key: str = None, base_url: str = None):
        """
        Initialize the RAGFlow client configuration.

        Args:
            api_key: RAGFlow API key; defaults to the configured value.
            base_url: RAGFlow service URL; defaults to the configured value.
        """
        self.api_key = api_key or RAGFLOW_API_KEY
        self.base_url = base_url or RAGFLOW_API_URL
        self._client: Optional[Any] = None
        self._lock = asyncio.Lock()

    async def _get_client(self):
        """
        Get the RAGFlow client instance (lazily initialized).

        Returns:
            The RAGFlow client.
        """
        if RAGFlow is None:
            raise RuntimeError("ragflow-sdk is not installed. Run: poetry install")
        if self._client is None:
            async with self._lock:
                # Double-checked locking: another coroutine may have
                # initialized the client while we waited for the lock.
                if self._client is None:
                    try:
                        self._client = RAGFlow(
                            api_key=self.api_key,
                            base_url=self.base_url
                        )
                        logger.info(f"RAGFlow client initialized: {self.base_url}")
                    except Exception as e:
                        logger.error(f"Failed to initialize RAGFlow client: {e}")
                        raise
        return self._client

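    # Note: the RAGFlow SDK is synchronous, so the calls below run directly on
    # the event loop and block it for the duration of each HTTP request. A
    # minimal sketch of one way to offload them (this helper is illustrative,
    # is not part of the original design, and assumes Python 3.9+ for
    # asyncio.to_thread):
    async def _run_sync(self, func, *args, **kwargs):
        """Run a blocking SDK call in a worker thread, e.g. await self._run_sync(client.list_datasets, page=1)."""
        return await asyncio.to_thread(func, *args, **kwargs)
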
    async def create_dataset(
        self,
        name: str,
        description: str = None,
        chunk_method: str = "naive",
        permission: str = "me"
    ) -> Dict[str, Any]:
        """
        Create a dataset.

        Args:
            name: Dataset name.
            description: Description.
            chunk_method: Chunking method (naive, manual, qa, table, paper, book, etc.).
            permission: Permission ("me" or "team").

        Returns:
            Information about the created dataset.
        """
        client = await self._get_client()
        try:
            dataset = client.create_dataset(
                name=name,
                avatar=None,
                description=description,
                chunk_method=chunk_method,
                permission=permission
            )
            return {
                "dataset_id": getattr(dataset, 'id', None),
                "name": getattr(dataset, 'name', name),
                "description": getattr(dataset, 'description', description),
                "chunk_method": getattr(dataset, 'chunk_method', chunk_method),
                "permission": getattr(dataset, 'permission', permission),
                "created_at": getattr(dataset, 'created_at', None),
                "updated_at": getattr(dataset, 'updated_at', None),
            }
        except Exception as e:
            logger.error(f"Failed to create dataset: {e}")
            raise

    async def list_datasets(
        self,
        page: int = 1,
        page_size: int = 30,
        search: str = None
    ) -> Dict[str, Any]:
        """
        List datasets.

        Args:
            page: Page number.
            page_size: Items per page.
            search: Search keyword.

        Returns:
            The dataset list and pagination info.
        """
        client = await self._get_client()
        try:
            # RAGFlow SDK list_datasets call
            datasets = client.list_datasets(
                page=page,
                page_size=page_size
            )
            items = []
            for dataset in datasets:
                dataset_info = {
                    "dataset_id": getattr(dataset, 'id', None),
                    "name": getattr(dataset, 'name', None),
                    "description": getattr(dataset, 'description', None),
                    "chunk_method": getattr(dataset, 'chunk_method', None),
                    "avatar": getattr(dataset, 'avatar', None),
                    "permission": getattr(dataset, 'permission', None),
                    "created_at": getattr(dataset, 'created_at', None),
                    "updated_at": getattr(dataset, 'updated_at', None),
                    "metadata": getattr(dataset, 'metadata', {}),
                }
                # Client-side search filter on name and description
                if search:
                    search_lower = search.lower()
                    if (search_lower not in (dataset_info.get('name') or '').lower() and
                            search_lower not in (dataset_info.get('description') or '').lower()):
                        continue
                items.append(dataset_info)
            return {
                "items": items,
                "total": len(items),  # RAGFlow may not return a total count; use the number actually returned
                "page": page,
                "page_size": page_size
            }
        except Exception as e:
            logger.error(f"Failed to list datasets: {e}")
            raise

    async def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]:
        """
        Get dataset details.

        Args:
            dataset_id: Dataset ID.

        Returns:
            The dataset details, or None if it does not exist.
        """
        client = await self._get_client()
        try:
            datasets = client.list_datasets(id=dataset_id)
            if datasets and len(datasets) > 0:
                dataset = datasets[0]
                return {
                    "dataset_id": getattr(dataset, 'id', dataset_id),
                    "name": getattr(dataset, 'name', None),
                    "description": getattr(dataset, 'description', None),
                    "chunk_method": getattr(dataset, 'chunk_method', None),
                    "permission": getattr(dataset, 'permission', None),
                    "created_at": getattr(dataset, 'created_at', None),
                    "updated_at": getattr(dataset, 'updated_at', None),
                }
            return None
        except Exception as e:
            logger.error(f"Failed to get dataset {dataset_id}: {e}")
            raise

    async def update_dataset(
        self,
        dataset_id: str,
        **updates
    ) -> Optional[Dict[str, Any]]:
        """
        Update a dataset.

        Args:
            dataset_id: Dataset ID.
            **updates: Fields to update.

        Returns:
            The updated dataset info, or None if the dataset does not exist.
        """
        client = await self._get_client()
        try:
            datasets = client.list_datasets(id=dataset_id)
            if datasets and len(datasets) > 0:
                dataset = datasets[0]
                # Apply the update via the SDK
                dataset.update(updates)
                return {
                    "dataset_id": getattr(dataset, 'id', dataset_id),
                    "name": getattr(dataset, 'name', None),
                    "description": getattr(dataset, 'description', None),
                    "chunk_method": getattr(dataset, 'chunk_method', None),
                    "updated_at": getattr(dataset, 'updated_at', None),
                }
            return None
        except Exception as e:
            logger.error(f"Failed to update dataset {dataset_id}: {e}")
            raise

    async def delete_datasets(self, dataset_ids: List[str] = None) -> bool:
        """
        Delete datasets.

        Args:
            dataset_ids: IDs of the datasets to delete.

        Returns:
            Whether the deletion succeeded (an empty list is treated as a no-op).
        """
        client = await self._get_client()
        try:
            if dataset_ids:
                client.delete_datasets(ids=dataset_ids)
            return True
        except Exception as e:
            logger.error(f"Failed to delete datasets: {e}")
            raise

    async def upload_document(
        self,
        dataset_id: str,
        file_name: str,
        file_content: bytes,
        display_name: str = None
    ) -> Dict[str, Any]:
        """
        Upload a document to a dataset.

        Args:
            dataset_id: Dataset ID.
            file_name: File name.
            file_content: File content.
            display_name: Display name.

        Returns:
            Information about the uploaded document.
        """
        client = await self._get_client()
        try:
            # Look up the target dataset
            datasets = client.list_datasets(id=dataset_id)
            if not datasets or len(datasets) == 0:
                raise ValueError(f"Dataset {dataset_id} not found")
            dataset = datasets[0]
            # Upload the document
            display_name = display_name or file_name
            dataset.upload_documents([{
                "display_name": display_name,
                "blob": file_content
            }])
            # Find the document we just uploaded by matching its display name
            documents = dataset.list_documents()
            for doc in documents:
                if getattr(doc, 'name', None) == display_name:
                    return {
                        "document_id": getattr(doc, 'id', None),
                        "name": display_name,
                        "dataset_id": dataset_id,
                        "size": len(file_content),
                        "status": "running",
                        "chunk_count": getattr(doc, 'chunk_count', 0),
                        "token_count": getattr(doc, 'token_count', 0),
                        "created_at": getattr(doc, 'created_at', None),
                    }
            return {
                "document_id": None,
                "name": display_name,
                "dataset_id": dataset_id,
                "size": len(file_content),
                "status": "uploaded",
            }
        except Exception as e:
            logger.error(f"Failed to upload document to {dataset_id}: {e}")
            raise

    async def list_documents(
        self,
        dataset_id: str,
        page: int = 1,
        page_size: int = 20
    ) -> Dict[str, Any]:
        """
        List the documents in a dataset.

        Args:
            dataset_id: Dataset ID.
            page: Page number.
            page_size: Items per page.

        Returns:
            The document list and pagination info.
        """
        client = await self._get_client()
        try:
            datasets = client.list_datasets(id=dataset_id)
            if not datasets or len(datasets) == 0:
                return {"items": [], "total": 0, "page": page, "page_size": page_size}
            dataset = datasets[0]
            documents = dataset.list_documents(
                page=page,
                page_size=page_size
            )
            items = []
            for doc in documents:
                items.append({
                    "document_id": getattr(doc, 'id', None),
                    "name": getattr(doc, 'name', None),
                    "dataset_id": dataset_id,
                    "size": getattr(doc, 'size', 0),
                    "status": getattr(doc, 'run', 'unknown'),
                    "progress": getattr(doc, 'progress', 0),
                    "chunk_count": getattr(doc, 'chunk_count', 0),
                    "token_count": getattr(doc, 'token_count', 0),
                    "created_at": getattr(doc, 'created_at', None),
                    "updated_at": getattr(doc, 'updated_at', None),
                })
            return {
                "items": items,
                "total": len(items),
                "page": page,
                "page_size": page_size
            }
        except Exception as e:
            logger.error(f"Failed to list documents for {dataset_id}: {e}")
            raise

    async def delete_document(
        self,
        dataset_id: str,
        document_id: str
    ) -> bool:
        """
        Delete a document.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.

        Returns:
            Whether the deletion succeeded.
        """
        client = await self._get_client()
        try:
            datasets = client.list_datasets(id=dataset_id)
            if datasets and len(datasets) > 0:
                dataset = datasets[0]
                dataset.delete_documents(ids=[document_id])
            return True
        except Exception as e:
            logger.error(f"Failed to delete document {document_id}: {e}")
            raise

    async def list_chunks(
        self,
        dataset_id: str,
        document_id: str = None,
        page: int = 1,
        page_size: int = 50
    ) -> Dict[str, Any]:
        """
        List chunks.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID (optional).
            page: Page number.
            page_size: Items per page.

        Returns:
            The chunk list and pagination info.
        """
        client = await self._get_client()
        try:
            datasets = client.list_datasets(id=dataset_id)
            if not datasets or len(datasets) == 0:
                return {"items": [], "total": 0, "page": page, "page_size": page_size}
            dataset = datasets[0]
            # If a document ID is given, fetch that document first
            if document_id:
                documents = dataset.list_documents(id=document_id)
                if documents and len(documents) > 0:
                    doc = documents[0]
                    chunks = doc.list_chunks(page=page, page_size=page_size)
                else:
                    chunks = []
            else:
                # Collect chunks from every document; note that in this branch
                # pagination is applied per document, not across the dataset
                chunks = []
                for doc in dataset.list_documents():
                    chunks.extend(doc.list_chunks(page=page, page_size=page_size))
            items = []
            for chunk in chunks:
                items.append({
                    "chunk_id": getattr(chunk, 'id', None),
                    "content": getattr(chunk, 'content', ''),
                    "document_id": getattr(chunk, 'document_id', None),
                    "dataset_id": dataset_id,
                    "position": getattr(chunk, 'position', 0),
                    "important_keywords": getattr(chunk, 'important_keywords', []),
                    "available": getattr(chunk, 'available', True),
                    "created_at": getattr(chunk, 'create_time', None),
                })
            return {
                "items": items,
                "total": len(items),
                "page": page,
                "page_size": page_size
            }
        except Exception as e:
            logger.error(f"Failed to list chunks for {dataset_id}: {e}")
            raise

    async def delete_chunk(
        self,
        dataset_id: str,
        document_id: str,
        chunk_id: str
    ) -> bool:
        """
        Delete a chunk.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.
            chunk_id: Chunk ID.

        Returns:
            Whether the deletion succeeded.
        """
        client = await self._get_client()
        try:
            datasets = client.list_datasets(id=dataset_id)
            if datasets and len(datasets) > 0:
                dataset = datasets[0]
                documents = dataset.list_documents(id=document_id)
                if documents and len(documents) > 0:
                    doc = documents[0]
                    doc.delete_chunks(chunk_ids=[chunk_id])
            return True
        except Exception as e:
            logger.error(f"Failed to delete chunk {chunk_id}: {e}")
            raise

    async def parse_document(
        self,
        dataset_id: str,
        document_id: str
    ) -> bool:
        """
        Start parsing a document.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.

        Returns:
            Whether the request succeeded.
        """
        client = await self._get_client()
        try:
            datasets = client.list_datasets(id=dataset_id)
            if not datasets or len(datasets) == 0:
                raise ValueError(f"Dataset {dataset_id} not found")
            dataset = datasets[0]
            dataset.async_parse_documents([document_id])
            return True
        except Exception as e:
            logger.error(f"Failed to parse document {document_id}: {e}")
            raise

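    # Parsing runs asynchronously on the RAGFlow side. A minimal polling sketch
    # built only on list_documents() above; the in-progress status strings are
    # assumptions about the values RAGFlow reports via the document `run`
    # field, not confirmed SDK constants, and only the first page of documents
    # is scanned. Illustrative only:
    async def wait_for_parsing(
        self,
        dataset_id: str,
        document_id: str,
        interval: float = 2.0,
        timeout: float = 300.0
    ) -> str:
        """Poll document status until parsing leaves the running state or the timeout elapses."""
        elapsed = 0.0
        while elapsed < timeout:
            result = await self.list_documents(dataset_id, page=1, page_size=100)
            for item in result["items"]:
                if item["document_id"] == document_id:
                    status = item["status"]
                    # Assumed in-progress markers; anything else is treated as terminal
                    if status not in ("RUNNING", "UNSTART", "unknown"):
                        return status
            await asyncio.sleep(interval)
            elapsed += interval
        return "timeout"
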
    async def cancel_parse_document(
        self,
        dataset_id: str,
        document_id: str
    ) -> bool:
        """
        Cancel parsing a document.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.

        Returns:
            Whether the request succeeded.
        """
        client = await self._get_client()
        try:
            datasets = client.list_datasets(id=dataset_id)
            if not datasets or len(datasets) == 0:
                raise ValueError(f"Dataset {dataset_id} not found")
            dataset = datasets[0]
            dataset.async_cancel_parse_documents([document_id])
            return True
        except Exception as e:
            logger.error(f"Failed to cancel parse document {document_id}: {e}")
            raise
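

# Usage sketch (illustrative only): exercises the repository against a running
# RAGFlow instance using the credentials configured in utils.settings. The
# dataset name below is made up for the example.
async def _demo() -> None:
    repo = RAGFlowRepository()
    dataset = await repo.create_dataset(name="demo-kb", description="demo dataset")
    print("created:", dataset["dataset_id"])
    print("datasets:", await repo.list_datasets(page=1, page_size=10))


if __name__ == "__main__":
    asyncio.run(_demo())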