"""
|
||
Knowledge Base Service - 业务逻辑层
|
||
提供知识库管理的业务逻辑,协调数据访问和业务规则
|
||
"""
|
||
import logging
|
||
import mimetypes
|
||
import os
|
||
from typing import Optional, List, Dict, Any
|
||
from pathlib import Path
|
||
|
||
from agent.db_pool_manager import get_db_pool_manager
|
||
from repositories.ragflow_repository import RAGFlowRepository
|
||
from utils.settings import (
|
||
RAGFLOW_MAX_UPLOAD_SIZE,
|
||
RAGFLOW_ALLOWED_EXTENSIONS,
|
||
)
|
||
|
||
logger = logging.getLogger('app')
|
||
|
||
|
||


class FileValidationError(Exception):
    """Raised when file validation fails."""
    pass


class KnowledgeBaseService:
    """
    Knowledge base service.

    Provides the business logic for knowledge base management:
    - dataset CRUD
    - file upload and management
    - file validation
    """

    def __init__(self, repository: RAGFlowRepository):
        """
        Initialize the service.

        Args:
            repository: RAGFlow repository instance
        """
        self.repository = repository
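
    # A minimal usage sketch (hypothetical wiring; the RAGFlowRepository
    # constructor arguments are an assumption, not taken from this module):
    #
    #     repo = RAGFlowRepository()
    #     service = KnowledgeBaseService(repo)
    #     page = await service.list_datasets(user_id="u-123", page=1, page_size=20)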

    def _validate_file(self, filename: str, content: bytes) -> None:
        """
        Validate a file before upload.

        Args:
            filename: file name
            content: file content

        Raises:
            FileValidationError: raised when validation fails
        """
        # Check the filename
        if not filename or filename == "unknown":
            raise FileValidationError("Invalid filename")

        # Guard against path traversal
        if '..' in filename or '/' in filename or '\\' in filename:
            raise FileValidationError("Filename contains illegal characters")

        # Check the file extension (compared without the leading dot)
        ext = Path(filename).suffix.lower().lstrip('.')
        if ext not in RAGFLOW_ALLOWED_EXTENSIONS:
            allowed = ', '.join(RAGFLOW_ALLOWED_EXTENSIONS)
            raise FileValidationError(f"Unsupported file type: {ext}. Supported types: {allowed}")

        # Check the file size
        file_size = len(content)
        if file_size > RAGFLOW_MAX_UPLOAD_SIZE:
            size_mb = file_size / (1024 * 1024)
            max_mb = RAGFLOW_MAX_UPLOAD_SIZE / (1024 * 1024)
            raise FileValidationError(f"File too large: {size_mb:.1f}MB (max {max_mb:.0f}MB)")

        # Log the MIME type guessed from the filename. Note that the mimetypes
        # standard library only inspects the extension; nothing is rejected here.
        detected_mime, _ = mimetypes.guess_type(filename)
        logger.info(f"File {filename} detected as {detected_mime}")
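
    # Illustrative calls (the allowed extensions actually come from
    # RAGFLOW_ALLOWED_EXTENSIONS in settings, so "pdf" here is an assumption):
    #
    #     self._validate_file("notes.pdf", b"%PDF-1.7")   # passes if "pdf" is allowed
    #     self._validate_file("../etc/passwd", b"x")      # raises FileValidationError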

    # ============== Dataset management ==============

    async def _check_dataset_access(self, dataset_id: str, user_id: str) -> bool:
        """
        Check whether the user may access the given dataset.

        Args:
            dataset_id: dataset ID
            user_id: user ID

        Returns:
            True if the user has access
        """
        pool = get_db_pool_manager().pool
        async with pool.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    SELECT id FROM user_datasets
                    WHERE user_id = %s AND dataset_id = %s
                """, (user_id, dataset_id))
                return await cursor.fetchone() is not None
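
    # The user_datasets table is not defined in this module; a schema sketch
    # inferred from the queries in this file (column types are assumptions):
    #
    #     CREATE TABLE user_datasets (
    #         id           SERIAL PRIMARY KEY,
    #         user_id      TEXT NOT NULL,
    #         dataset_id   TEXT NOT NULL,
    #         dataset_name TEXT NOT NULL,
    #         owner        BOOLEAN DEFAULT FALSE,
    #         created_at   TIMESTAMPTZ DEFAULT now()
    #     );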

    async def list_datasets(
        self,
        user_id: str,
        page: int = 1,
        page_size: int = 20,
        search: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        List the user's datasets (filtered via the local database).

        Args:
            user_id: user ID
            page: page number
            page_size: items per page
            search: search keyword

        Returns:
            Dataset list and pagination info
        """
        logger.info(f"Listing datasets for user {user_id}: page={page}, page_size={page_size}, search={search}")

        pool = get_db_pool_manager().pool

        # Fetch the user's dataset IDs from the local database
        async with pool.connection() as conn:
            async with conn.cursor() as cursor:
                # Build the query conditions
                where_conditions = ["user_id = %s"]
                params = [user_id]

                if search:
                    where_conditions.append("dataset_name ILIKE %s")
                    params.append(f"%{search}%")

                where_clause = " AND ".join(where_conditions)

                # Get the total count
                await cursor.execute(f"""
                    SELECT COUNT(*) FROM user_datasets
                    WHERE {where_clause}
                """, params)
                total = (await cursor.fetchone())[0]

                # Get the current page
                offset = (page - 1) * page_size
                await cursor.execute(f"""
                    SELECT dataset_id, dataset_name, created_at
                    FROM user_datasets
                    WHERE {where_clause}
                    ORDER BY created_at DESC
                    LIMIT %s OFFSET %s
                """, params + [page_size, offset])

                user_datasets = await cursor.fetchall()

        if not user_datasets:
            return {
                "items": [],
                "total": 0,
                "page": page,
                "page_size": page_size
            }

        # Collect this page's dataset IDs, then fetch full records from RAGFlow
        dataset_ids = [row[0] for row in user_datasets]

        # Fetch the dataset records from RAGFlow
        ragflow_result = await self.repository.list_datasets(
            page=1,
            page_size=1000  # fetch everything, then filter locally
        )

        # Keep only the datasets that belong to this user. Note that "total"
        # comes from the local table, so if RAGFlow is missing a dataset,
        # "items" may hold fewer than page_size entries.
        user_dataset_ids_set = set(dataset_ids)
        items = [
            item for item in ragflow_result["items"]
            if item.get("dataset_id") in user_dataset_ids_set
        ]

        return {
            "items": items,
            "total": total,
            "page": page,
            "page_size": page_size
        }
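
    # Illustrative return shape (only "dataset_id" is known to be present on
    # each item from the filtering above; the other item fields are guesses):
    #
    #     {
    #         "items": [{"dataset_id": "ds-1", "name": "manuals"}],
    #         "total": 1,
    #         "page": 1,
    #         "page_size": 20,
    #     }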

    async def create_dataset(
        self,
        user_id: str,
        name: str,
        description: Optional[str] = None,
        chunk_method: str = "naive"
    ) -> Dict[str, Any]:
        """
        Create a dataset and associate it with the user.

        Args:
            user_id: user ID
            name: dataset name
            description: description
            chunk_method: chunking method

        Returns:
            The created dataset
        """
        logger.info(f"Creating dataset for user {user_id}: name={name}, chunk_method={chunk_method}")

        # Validate the chunking method
        valid_methods = [
            "naive", "manual", "qa", "table", "paper",
            "book", "laws", "presentation", "picture", "one", "email", "knowledge-graph"
        ]
        if chunk_method not in valid_methods:
            raise ValueError(f"Invalid chunk method: {chunk_method}. Supported methods: {', '.join(valid_methods)}")

        # Create the dataset in RAGFlow first
        result = await self.repository.create_dataset(
            name=name,
            description=description,
            chunk_method=chunk_method,
            permission="me"
        )

        # Record the association in the local database
        dataset_id = result.get("dataset_id")
        if dataset_id:
            pool = get_db_pool_manager().pool
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("""
                        INSERT INTO user_datasets (user_id, dataset_id, dataset_name, owner)
                        VALUES (%s, %s, %s, TRUE)
                    """, (user_id, dataset_id, name))
                    await conn.commit()
            logger.info(f"Dataset {dataset_id} associated with user {user_id}")

        return result

    async def get_dataset(self, dataset_id: str, user_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """
        Get dataset details.

        Args:
            dataset_id: dataset ID
            user_id: user ID (optional, used for access checks)

        Returns:
            Dataset details, or None if missing or access is denied
        """
        logger.info(f"Getting dataset: {dataset_id} for user: {user_id}")

        # If a user_id was supplied, check access first
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                return None

        return await self.repository.get_dataset(dataset_id)

    async def update_dataset(
        self,
        dataset_id: str,
        updates: Dict[str, Any],
        user_id: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Update a dataset.

        Args:
            dataset_id: dataset ID
            updates: fields to update
            user_id: user ID (optional, used for access checks)

        Returns:
            The updated dataset
        """
        logger.info(f"Updating dataset {dataset_id}: {updates}")

        # If a user_id was supplied, check access first
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                return None

        result = await self.repository.update_dataset(dataset_id, **updates)

        # If the name changed, keep the local database in sync
        if result and user_id and 'name' in updates:
            pool = get_db_pool_manager().pool
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("""
                        UPDATE user_datasets
                        SET dataset_name = %s
                        WHERE user_id = %s AND dataset_id = %s
                    """, (updates['name'], user_id, dataset_id))
                    await conn.commit()

        return result

    async def delete_dataset(self, dataset_id: str, user_id: Optional[str] = None) -> bool:
        """
        Delete a dataset.

        Args:
            dataset_id: dataset ID
            user_id: user ID (optional, used for access checks)

        Returns:
            True on success
        """
        logger.info(f"Deleting dataset: {dataset_id}")

        # If a user_id was supplied, check access first
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                return False

        # Delete from RAGFlow
        result = await self.repository.delete_datasets([dataset_id])

        # Remove the association from the local database
        if result and user_id:
            pool = get_db_pool_manager().pool
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("""
                        DELETE FROM user_datasets
                        WHERE user_id = %s AND dataset_id = %s
                    """, (user_id, dataset_id))
                    await conn.commit()
            logger.info(f"Dataset {dataset_id} unlinked from user {user_id}")

        return result

    # ============== File management ==============

    async def list_files(
        self,
        dataset_id: str,
        user_id: Optional[str] = None,
        page: int = 1,
        page_size: int = 20
    ) -> Dict[str, Any]:
        """
        List the files in a dataset.

        Args:
            dataset_id: dataset ID
            user_id: user ID (optional, used for access checks)
            page: page number
            page_size: items per page

        Returns:
            File list and pagination info
        """
        logger.info(f"Listing files for dataset {dataset_id}: page={page}, page_size={page_size}")

        # If a user_id was supplied, check access first
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                raise ValueError("Dataset not found or does not belong to you")

        return await self.repository.list_documents(
            dataset_id=dataset_id,
            page=page,
            page_size=page_size
        )

    async def upload_file(
        self,
        dataset_id: str,
        user_id: Optional[str] = None,
        file=None,
        chunk_size: int = 1024 * 1024  # 1MB chunks
    ) -> Dict[str, Any]:
        """
        Upload a file to a dataset.

        Args:
            dataset_id: dataset ID
            user_id: user ID (optional, used for access checks)
            file: FastAPI UploadFile object
            chunk_size: read chunk size in bytes

        Returns:
            The uploaded document
        """
        filename = file.filename or "unknown"

        logger.info(f"Uploading file {filename} to dataset {dataset_id}")

        # If a user_id was supplied, check access first
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                raise ValueError("Dataset not found or does not belong to you")

        # Read the upload in chunk_size pieces; the full content is still
        # buffered in memory for validation and the RAGFlow upload below
        chunks = []
        while True:
            chunk = await file.read(chunk_size)
            if not chunk:
                break
            chunks.append(chunk)
        content = b"".join(chunks)

        # Validate the file
        try:
            self._validate_file(filename, content)
        except FileValidationError as e:
            logger.warning(f"File validation failed: {e}")
            raise

        # Upload to RAGFlow
        result = await self.repository.upload_document(
            dataset_id=dataset_id,
            file_name=filename,
            file_content=content,
            display_name=filename
        )

        logger.info(f"File {filename} uploaded successfully")
        return result
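
    # A sketch of how a FastAPI route might call this (route path, dependency
    # names, and error mapping are assumptions, not part of this module):
    #
    #     @router.post("/datasets/{dataset_id}/files")
    #     async def upload(dataset_id: str, file: UploadFile, user=Depends(current_user)):
    #         try:
    #             return await service.upload_file(dataset_id, user.id, file)
    #         except FileValidationError as e:
    #             raise HTTPException(status_code=400, detail=str(e))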

    async def delete_file(self, dataset_id: str, document_id: str, user_id: Optional[str] = None) -> bool:
        """
        Delete a file.

        Args:
            dataset_id: dataset ID
            document_id: document ID
            user_id: user ID (optional, used for access checks)

        Returns:
            True on success
        """
        logger.info(f"Deleting file {document_id} from dataset {dataset_id}")

        # If a user_id was supplied, check access first
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                return False

        return await self.repository.delete_document(dataset_id, document_id)

    # ============== Chunk management ==============

    async def list_chunks(
        self,
        dataset_id: str,
        user_id: Optional[str] = None,
        document_id: Optional[str] = None,
        page: int = 1,
        page_size: int = 50
    ) -> Dict[str, Any]:
        """
        List chunks.

        Args:
            dataset_id: dataset ID
            user_id: user ID (optional, used for access checks)
            document_id: document ID (optional)
            page: page number
            page_size: items per page

        Returns:
            Chunk list and pagination info
        """
        logger.info(f"Listing chunks for dataset {dataset_id}, document {document_id}")

        # If a user_id was supplied, check access first
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                raise ValueError("Dataset not found or does not belong to you")

        return await self.repository.list_chunks(
            dataset_id=dataset_id,
            document_id=document_id,
            page=page,
            page_size=page_size
        )

    async def delete_chunk(
        self,
        dataset_id: str,
        document_id: str,
        chunk_id: str,
        user_id: Optional[str] = None
    ) -> bool:
        """
        Delete a chunk.

        Args:
            dataset_id: dataset ID
            document_id: document ID
            chunk_id: chunk ID
            user_id: user ID (optional, used for access checks)

        Returns:
            True on success
        """
        logger.info(f"Deleting chunk {chunk_id} from document {document_id}")

        # If a user_id was supplied, check access first
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                return False

        return await self.repository.delete_chunk(dataset_id, document_id, chunk_id)

    async def parse_document(
        self,
        dataset_id: str,
        document_id: str,
        user_id: Optional[str] = None
    ) -> dict:
        """
        Start parsing a document.

        Args:
            dataset_id: dataset ID
            document_id: document ID
            user_id: user ID (optional, used for access checks)

        Returns:
            Operation result
        """
        logger.info(f"Parsing document {document_id} in dataset {dataset_id}")

        # If a user_id was supplied, check access first
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                raise ValueError("Dataset not found or does not belong to you")

        success = await self.repository.parse_document(dataset_id, document_id)
        return {"success": success, "message": "Parse task started"}

    async def cancel_parse_document(
        self,
        dataset_id: str,
        document_id: str,
        user_id: Optional[str] = None
    ) -> dict:
        """
        Cancel parsing of a document.

        Args:
            dataset_id: dataset ID
            document_id: document ID
            user_id: user ID (optional, used for access checks)

        Returns:
            Operation result
        """
        logger.info(f"Cancelling parse for document {document_id} in dataset {dataset_id}")

        # If a user_id was supplied, check access first
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                raise ValueError("Dataset not found or does not belong to you")

        success = await self.repository.cancel_parse_document(dataset_id, document_id)
        return {"success": success, "message": "Parse task cancelled"}

    # ============== Bot-dataset association management ==============

    async def get_dataset_ids_by_bot(self, bot_id: str) -> list[str]:
        """
        Get the dataset IDs associated with a bot.

        Args:
            bot_id: Bot ID (the bot_id column in the agent_bots table)

        Returns:
            List of dataset IDs
        """
        logger.info(f"Getting dataset_ids for bot_id: {bot_id}")

        pool = get_db_pool_manager().pool

        async with pool.connection() as conn:
            async with conn.cursor() as cursor:
                # Read dataset_ids out of the bot's settings column
                await cursor.execute("""
                    SELECT settings
                    FROM agent_bots
                    WHERE bot_id = %s
                """, (bot_id,))

                row = await cursor.fetchone()
                if not row:
                    logger.warning(f"Bot not found: {bot_id}")
                    return []

                settings = row[0]

                # dataset_ids is stored in settings as a comma-separated string
                dataset_ids_str = settings.get('dataset_ids') if settings else None

                if not dataset_ids_str:
                    return []

                # Split a string on commas; accept an already-parsed list as-is
                if isinstance(dataset_ids_str, str):
                    dataset_ids = [ds_id.strip() for ds_id in dataset_ids_str.split(',') if ds_id.strip()]
                elif isinstance(dataset_ids_str, list):
                    dataset_ids = dataset_ids_str
                else:
                    dataset_ids = []

                logger.info(f"Found {len(dataset_ids)} datasets for bot {bot_id}")
                return dataset_ids
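
    # Illustrative settings payload (a guess at the stored shape; only the
    # dataset_ids key is read by this method):
    #
    #     {"dataset_ids": "ds-111,ds-222"}
    #
    # which get_dataset_ids_by_bot would return as ["ds-111", "ds-222"].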