qwen_agent/services/knowledge_base_service.py
2026-02-10 18:59:10 +08:00

628 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Knowledge Base Service - 业务逻辑层
提供知识库管理的业务逻辑,协调数据访问和业务规则
"""
import json
import logging
import mimetypes
import os
from pathlib import Path
from typing import Optional, List, Dict, Any

from agent.db_pool_manager import get_db_pool_manager
from repositories.ragflow_repository import RAGFlowRepository
from utils.settings import (
    RAGFLOW_MAX_UPLOAD_SIZE,
    RAGFLOW_ALLOWED_EXTENSIONS,
)
# Module-wide logger; every message in this file is emitted under the 'app' logger name.
logger = logging.getLogger('app')
class FileValidationError(Exception):
    """Raised when an uploaded file fails name, extension, or size validation."""
class KnowledgeBaseService:
    """
    Knowledge base service.

    Business-logic layer between the API handlers and the RAGFlow repository:
    - dataset CRUD, with user <-> dataset links kept in the local
      ``user_datasets`` table
    - file upload and management
    - file validation
    - chunk management and document parsing
    - bot -> dataset lookups

    Access control convention used by every method that takes ``user_id``:
    ``None`` means "trusted caller, skip the check"; any other value
    (including an empty string) must have a row in ``user_datasets`` for the
    dataset.

    NOTE(review): the access check only looks for an association row and does
    not consult the ``owner`` column, so any linked user can update or delete
    a dataset. Confirm that this is the intended policy.
    """

    # Chunking strategies accepted by create_dataset(). Kept as an ordered
    # tuple because the order is echoed back in the validation error message.
    _VALID_CHUNK_METHODS = (
        "naive", "manual", "qa", "table", "paper",
        "book", "laws", "presentation", "picture", "one", "email", "knowledge-graph",
    )

    # Read size used by upload_file() when the caller passes a non-positive chunk_size.
    _DEFAULT_UPLOAD_CHUNK_SIZE = 1024 * 1024

    def __init__(self, repository: RAGFlowRepository):
        """
        Initialise the service.

        Args:
            repository: RAGFlow repository instance used for all remote calls.
        """
        self.repository = repository

    # ============== File validation ==============

    def _validate_filename(self, filename: str) -> None:
        """
        Validate the name and extension of an uploaded file.

        Args:
            filename: Client-supplied file name.

        Raises:
            FileValidationError: If the name is empty / the "unknown"
                placeholder, contains path-traversal characters, or has an
                extension that is not in RAGFLOW_ALLOWED_EXTENSIONS.
        """
        if not filename or filename == "unknown":
            raise FileValidationError("无效的文件名")

        # Reject anything that could escape the target directory.
        if '..' in filename or '/' in filename or '\\' in filename:
            raise FileValidationError("文件名包含非法字符")

        # Extensions are compared lower-cased and without the leading dot.
        ext = Path(filename).suffix.lower().lstrip('.')
        if ext not in RAGFLOW_ALLOWED_EXTENSIONS:
            allowed = ', '.join(RAGFLOW_ALLOWED_EXTENSIONS)
            raise FileValidationError(f"不支持的文件类型: {ext}。支持的类型: {allowed}")

    def _validate_file(self, filename: str, content: bytes) -> None:
        """
        Validate an uploaded file (name, extension and size).

        The MIME type guessed from the file name is only logged; it is not
        used to accept or reject the file.

        Args:
            filename: Client-supplied file name.
            content: Complete file content.

        Raises:
            FileValidationError: If any check fails.
        """
        self._validate_filename(filename)

        file_size = len(content)
        if file_size > RAGFLOW_MAX_UPLOAD_SIZE:
            size_mb = file_size / (1024 * 1024)
            max_mb = RAGFLOW_MAX_UPLOAD_SIZE / (1024 * 1024)
            raise FileValidationError(f"文件过大: {size_mb:.1f}MB (最大 {max_mb}MB)")

        detected_mime, _ = mimetypes.guess_type(filename)
        logger.info(f"File {filename} detected as {detected_mime}")

    async def _read_upload(self, file, chunk_size: int) -> bytes:
        """
        Read an upload in chunks, aborting once the size limit is exceeded.

        Reading incrementally keeps an oversized upload from being buffered
        in memory in full before it is rejected.

        Args:
            file: Object exposing an async ``read(size)`` method
                (FastAPI ``UploadFile``).
            chunk_size: Number of bytes requested per read; non-positive
                values fall back to ``_DEFAULT_UPLOAD_CHUNK_SIZE``.

        Returns:
            The complete file content.

        Raises:
            FileValidationError: If more than RAGFLOW_MAX_UPLOAD_SIZE bytes
                are received.
        """
        if not chunk_size or chunk_size <= 0:
            chunk_size = self._DEFAULT_UPLOAD_CHUNK_SIZE

        parts = []
        received = 0
        while True:
            chunk = await file.read(chunk_size)
            if not chunk:
                break
            received += len(chunk)
            if received > RAGFLOW_MAX_UPLOAD_SIZE:
                max_mb = RAGFLOW_MAX_UPLOAD_SIZE / (1024 * 1024)
                raise FileValidationError(f"文件过大: 超过上限 (最大 {max_mb}MB)")
            parts.append(chunk)
        return b"".join(parts)

    # ============== Access control ==============

    async def _check_dataset_access(self, dataset_id: str, user_id: str) -> bool:
        """
        Check whether a user is linked to a dataset.

        Args:
            dataset_id: Dataset ID.
            user_id: User ID.

        Returns:
            True if ``user_datasets`` contains a row for this user/dataset pair.
        """
        pool = get_db_pool_manager().pool
        async with pool.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    SELECT id FROM user_datasets
                    WHERE user_id = %s AND dataset_id = %s
                """, (user_id, dataset_id))
                row = await cursor.fetchone()
        return row is not None

    async def _user_may_access(self, dataset_id: str, user_id: Optional[str]) -> bool:
        """
        Apply the optional per-user access check shared by the public methods.

        Args:
            dataset_id: Dataset ID.
            user_id: User ID, or None to skip the check entirely.

        Returns:
            True if no check was requested or the user is linked to the
            dataset; False (after logging a warning) otherwise.
        """
        # Only None disables the check: an empty string must not bypass it.
        if user_id is None:
            return True
        if await self._check_dataset_access(dataset_id, user_id):
            return True
        logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
        return False

    @staticmethod
    def _escape_like(term: str) -> str:
        """Escape LIKE/ILIKE metacharacters so ``term`` is matched literally."""
        # Backslash is PostgreSQL's default LIKE escape character, so it has
        # to be doubled first.
        return term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    # ============== Dataset management ==============

    async def list_datasets(
        self,
        user_id: str,
        page: int = 1,
        page_size: int = 20,
        search: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        List the datasets linked to a user (paged from the local database).

        Paging, counting and searching happen on ``user_datasets``; the
        details of the datasets on the requested page are then taken from
        RAGFlow and returned in the local ``created_at DESC`` order.

        Args:
            user_id: User ID.
            page: 1-based page number (values below 1 are treated as 1).
            page_size: Rows per page (negative values are treated as 0).
            search: Optional substring matched case-insensitively and
                literally against the dataset name.

        Returns:
            Dict with ``items``, ``total``, ``page`` and ``page_size``.
        """
        logger.info(f"Listing datasets for user {user_id}: page={page}, page_size={page_size}, search={search}")

        # A negative OFFSET/LIMIT would be rejected by the database.
        page = max(page, 1)
        page_size = max(page_size, 0)
        offset = (page - 1) * page_size

        # The WHERE clause is assembled from fixed fragments only; every
        # user-supplied value goes through a bound parameter.
        where_conditions = ["user_id = %s"]
        params: List[Any] = [user_id]
        if search:
            where_conditions.append("dataset_name ILIKE %s")
            params.append(f"%{self._escape_like(search)}%")
        where_clause = " AND ".join(where_conditions)

        pool = get_db_pool_manager().pool
        async with pool.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(f"""
                    SELECT COUNT(*) FROM user_datasets
                    WHERE {where_clause}
                """, params)
                total = (await cursor.fetchone())[0]

                await cursor.execute(f"""
                    SELECT dataset_id, dataset_name, created_at
                    FROM user_datasets
                    WHERE {where_clause}
                    ORDER BY created_at DESC
                    LIMIT %s OFFSET %s
                """, params + [page_size, offset])
                user_datasets = await cursor.fetchall()

        if not user_datasets:
            # An empty page (e.g. one past the end) still reports the real total.
            return {
                "items": [],
                "total": total,
                "page": page,
                "page_size": page_size
            }

        # IDs on this page, de-duplicated, in local created_at DESC order.
        dataset_ids = list(dict.fromkeys(row[0] for row in user_datasets))

        # NOTE(review): this fetches at most 1000 datasets from RAGFlow and
        # filters locally, so datasets beyond that window will be missing from
        # the result. A repository call that filters by ID would remove the cap.
        ragflow_result = await self.repository.list_datasets(
            page=1,
            page_size=1000
        )
        details_by_id = {item.get("dataset_id"): item for item in ragflow_result["items"]}
        items = [details_by_id[ds_id] for ds_id in dataset_ids if ds_id in details_by_id]

        return {
            "items": items,
            "total": total,
            "page": page,
            "page_size": page_size
        }

    async def create_dataset(
        self,
        user_id: str,
        name: str,
        description: Optional[str] = None,
        chunk_method: str = "naive"
    ) -> Dict[str, Any]:
        """
        Create a dataset in RAGFlow and link it to the user.

        If the local link cannot be stored, the freshly created RAGFlow
        dataset is deleted again (best effort) so that no unreachable dataset
        is left behind, and the original error is re-raised.

        Args:
            user_id: User ID.
            name: Dataset name.
            description: Optional description.
            chunk_method: Chunking method; must be one of
                ``_VALID_CHUNK_METHODS``.

        Returns:
            The dataset information returned by the repository.

        Raises:
            ValueError: If ``chunk_method`` is not supported.
        """
        logger.info(f"Creating dataset for user {user_id}: name={name}, chunk_method={chunk_method}")

        if chunk_method not in self._VALID_CHUNK_METHODS:
            raise ValueError(
                f"无效的分块方法: {chunk_method}。支持的方法: {', '.join(self._VALID_CHUNK_METHODS)}"
            )

        # Create remotely first; the local row needs the generated dataset ID.
        result = await self.repository.create_dataset(
            name=name,
            description=description,
            chunk_method=chunk_method,
            permission="me"
        )

        dataset_id = result.get("dataset_id")
        if not dataset_id:
            logger.warning(f"RAGFlow returned no dataset_id for dataset '{name}'; not linked to user {user_id}")
            return result

        try:
            pool = get_db_pool_manager().pool
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("""
                        INSERT INTO user_datasets (user_id, dataset_id, dataset_name, owner)
                        VALUES (%s, %s, %s, TRUE)
                    """, (user_id, dataset_id, name))
                    await conn.commit()
        except Exception:
            logger.exception(f"Failed to link dataset {dataset_id} to user {user_id}; rolling back RAGFlow dataset")
            try:
                await self.repository.delete_datasets([dataset_id])
            except Exception:
                # Compensation is best effort; the original error is what matters.
                logger.exception(f"Failed to roll back RAGFlow dataset {dataset_id}")
            raise

        logger.info(f"Dataset {dataset_id} associated with user {user_id}")
        return result

    async def get_dataset(self, dataset_id: str, user_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """
        Get dataset details.

        Args:
            dataset_id: Dataset ID.
            user_id: Optional user ID used for the access check.

        Returns:
            Dataset details, or None if the user has no access (the
            repository result is returned unchanged otherwise).
        """
        logger.info(f"Getting dataset: {dataset_id} for user: {user_id}")

        if not await self._user_may_access(dataset_id, user_id):
            return None
        return await self.repository.get_dataset(dataset_id)

    async def update_dataset(
        self,
        dataset_id: str,
        updates: Dict[str, Any],
        user_id: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Update a dataset.

        When the name changes, the cached ``dataset_name`` is updated on every
        ``user_datasets`` row of the dataset so that name searches stay
        correct for all linked users.

        Args:
            dataset_id: Dataset ID.
            updates: Fields to update, forwarded as keyword arguments to the
                repository.
            user_id: Optional user ID used for the access check.

        Returns:
            The updated dataset information, or None if the user has no access.
        """
        logger.info(f"Updating dataset {dataset_id}: {updates}")

        if not await self._user_may_access(dataset_id, user_id):
            return None

        result = await self.repository.update_dataset(dataset_id, **updates)

        if result and 'name' in updates:
            pool = get_db_pool_manager().pool
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("""
                        UPDATE user_datasets
                        SET dataset_name = %s
                        WHERE dataset_id = %s
                    """, (updates['name'], dataset_id))
                    await conn.commit()
        return result

    async def delete_dataset(self, dataset_id: str, user_id: Optional[str] = None) -> bool:
        """
        Delete a dataset.

        After a successful remote delete, every ``user_datasets`` row of the
        dataset is removed, because the dataset no longer exists for anyone.

        Args:
            dataset_id: Dataset ID.
            user_id: Optional user ID used for the access check.

        Returns:
            The repository's delete result, or False if the user has no access.
        """
        logger.info(f"Deleting dataset: {dataset_id}")

        if not await self._user_may_access(dataset_id, user_id):
            return False

        result = await self.repository.delete_datasets([dataset_id])

        if result:
            pool = get_db_pool_manager().pool
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("""
                        DELETE FROM user_datasets
                        WHERE dataset_id = %s
                    """, (dataset_id,))
                    await conn.commit()
            logger.info(f"Dataset {dataset_id} unlinked from all users")
        return result

    # ============== File management ==============

    async def list_files(
        self,
        dataset_id: str,
        user_id: Optional[str] = None,
        page: int = 1,
        page_size: int = 20
    ) -> Dict[str, Any]:
        """
        List the files of a dataset.

        Args:
            dataset_id: Dataset ID.
            user_id: Optional user ID used for the access check.
            page: Page number.
            page_size: Rows per page.

        Returns:
            File list and paging information as returned by the repository.

        Raises:
            ValueError: If the user has no access to the dataset.
        """
        logger.info(f"Listing files for dataset {dataset_id}: page={page}, page_size={page_size}")

        if not await self._user_may_access(dataset_id, user_id):
            raise ValueError("Dataset not found or does not belong to you")

        return await self.repository.list_documents(
            dataset_id=dataset_id,
            page=page,
            page_size=page_size
        )

    async def upload_file(
        self,
        dataset_id: str,
        user_id: Optional[str] = None,
        file=None,
        chunk_size: int = 1024 * 1024  # 1MB chunks
    ) -> Dict[str, Any]:
        """
        Upload a file to a dataset.

        The file name is validated before any content is read, and the
        content is then read in ``chunk_size`` pieces so that an oversized
        upload is rejected without being buffered in full.

        Args:
            dataset_id: Dataset ID.
            user_id: Optional user ID used for the access check.
            file: FastAPI ``UploadFile`` object (required despite the default).
            chunk_size: Bytes requested per read.

        Returns:
            The uploaded document information returned by the repository.

        Raises:
            FileValidationError: If no file is given or validation fails.
            ValueError: If the user has no access to the dataset.
        """
        if file is None:
            raise FileValidationError("未提供上传文件")

        filename = file.filename or "unknown"
        logger.info(f"Uploading file {filename} to dataset {dataset_id}")

        if not await self._user_may_access(dataset_id, user_id):
            raise ValueError("Dataset not found or does not belong to you")

        try:
            # Cheap checks first so a bad name never costs a full read.
            self._validate_filename(filename)
            content = await self._read_upload(file, chunk_size)
            # Full validation on the final content (also logs the MIME guess).
            self._validate_file(filename, content)
        except FileValidationError as e:
            logger.warning(f"File validation failed: {e}")
            raise

        result = await self.repository.upload_document(
            dataset_id=dataset_id,
            file_name=filename,
            file_content=content,
            display_name=filename
        )
        logger.info(f"File {filename} uploaded successfully")
        return result

    async def delete_file(self, dataset_id: str, document_id: str, user_id: Optional[str] = None) -> bool:
        """
        Delete a file from a dataset.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.
            user_id: Optional user ID used for the access check.

        Returns:
            The repository's delete result, or False if the user has no access.
        """
        logger.info(f"Deleting file {document_id} from dataset {dataset_id}")

        if not await self._user_may_access(dataset_id, user_id):
            return False
        return await self.repository.delete_document(dataset_id, document_id)

    # ============== Chunk management ==============

    async def list_chunks(
        self,
        dataset_id: str,
        user_id: Optional[str] = None,
        document_id: Optional[str] = None,
        page: int = 1,
        page_size: int = 50
    ) -> Dict[str, Any]:
        """
        List chunks of a dataset, optionally restricted to one document.

        Args:
            dataset_id: Dataset ID.
            user_id: Optional user ID used for the access check.
            document_id: Optional document ID.
            page: Page number.
            page_size: Rows per page.

        Returns:
            Chunk list and paging information as returned by the repository.

        Raises:
            ValueError: If the user has no access to the dataset.
        """
        logger.info(f"Listing chunks for dataset {dataset_id}, document {document_id}")

        if not await self._user_may_access(dataset_id, user_id):
            raise ValueError("Dataset not found or does not belong to you")

        return await self.repository.list_chunks(
            dataset_id=dataset_id,
            document_id=document_id,
            page=page,
            page_size=page_size
        )

    async def delete_chunk(
        self,
        dataset_id: str,
        document_id: str,
        chunk_id: str,
        user_id: Optional[str] = None
    ) -> bool:
        """
        Delete a chunk.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.
            chunk_id: Chunk ID.
            user_id: Optional user ID used for the access check.

        Returns:
            The repository's delete result, or False if the user has no access.
        """
        logger.info(f"Deleting chunk {chunk_id} from document {document_id}")

        if not await self._user_may_access(dataset_id, user_id):
            return False
        return await self.repository.delete_chunk(dataset_id, document_id, chunk_id)

    async def parse_document(
        self,
        dataset_id: str,
        document_id: str,
        user_id: Optional[str] = None
    ) -> dict:
        """
        Start parsing a document.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.
            user_id: Optional user ID used for the access check.

        Returns:
            Dict with ``success`` (the repository result) and a ``message``
            that matches the outcome.

        Raises:
            ValueError: If the user has no access to the dataset.
        """
        logger.info(f"Parsing document {document_id} in dataset {dataset_id}")

        if not await self._user_may_access(dataset_id, user_id):
            raise ValueError("Dataset not found or does not belong to you")

        success = await self.repository.parse_document(dataset_id, document_id)
        # Do not claim the task started when the repository reports failure.
        message = "解析任务已启动" if success else "解析任务启动失败"
        return {"success": success, "message": message}

    async def cancel_parse_document(
        self,
        dataset_id: str,
        document_id: str,
        user_id: Optional[str] = None
    ) -> dict:
        """
        Cancel parsing of a document.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.
            user_id: Optional user ID used for the access check.

        Returns:
            Dict with ``success`` (the repository result) and a ``message``
            that matches the outcome.

        Raises:
            ValueError: If the user has no access to the dataset.
        """
        logger.info(f"Cancelling parse for document {document_id} in dataset {dataset_id}")

        if not await self._user_may_access(dataset_id, user_id):
            raise ValueError("Dataset not found or does not belong to you")

        success = await self.repository.cancel_parse_document(dataset_id, document_id)
        # Do not claim the task was cancelled when the repository reports failure.
        message = "解析任务已取消" if success else "解析任务取消失败"
        return {"success": success, "message": message}

    # ============== Bot <-> dataset association ==============

    async def get_dataset_ids_by_bot(self, bot_id: str) -> list[str]:
        """
        Get the dataset IDs linked to a bot.

        The IDs are read from the ``dataset_ids`` entry of the bot's
        ``settings`` column, which may hold a comma-separated string or a
        list. Entries are stripped, and empty ones are dropped.

        Args:
            bot_id: Bot ID (``bot_id`` column of ``agent_bots``).

        Returns:
            List of dataset IDs; empty if the bot does not exist or has no
            usable ``dataset_ids`` setting.
        """
        logger.info(f"Getting dataset_ids for bot_id: {bot_id}")

        pool = get_db_pool_manager().pool
        async with pool.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    SELECT settings
                    FROM agent_bots
                    WHERE bot_id = %s
                """, (bot_id,))
                row = await cursor.fetchone()

        if not row:
            logger.warning(f"Bot not found: {bot_id}")
            return []

        settings = row[0]
        # Depending on the column type the driver may hand back raw JSON text
        # instead of a decoded mapping.
        if isinstance(settings, (str, bytes, bytearray)):
            try:
                settings = json.loads(settings)
            except ValueError:
                logger.warning(f"Bot {bot_id} has unparseable settings")
                return []

        if not settings or not hasattr(settings, "get"):
            return []

        raw_ids = settings.get('dataset_ids')
        if not raw_ids:
            return []

        if isinstance(raw_ids, str):
            candidates = raw_ids.split(',')
        elif isinstance(raw_ids, (list, tuple)):
            candidates = raw_ids
        else:
            candidates = []

        dataset_ids = []
        for candidate in candidates:
            if candidate is None:
                continue
            ds_id = str(candidate).strip()
            if ds_id:
                dataset_ids.append(ds_id)

        logger.info(f"Found {len(dataset_ids)} datasets for bot {bot_id}")
        return dataset_ids