"""Knowledge Base Service - business logic layer.

Provides the business logic for knowledge-base management, coordinating
data access (local user/dataset link table + RAGFlow) and business rules.
"""
import logging
import mimetypes
import os
from typing import Optional, List, Dict, Any
from pathlib import Path

from agent.db_pool_manager import get_db_pool_manager
from repositories.ragflow_repository import RAGFlowRepository
from utils.settings import (
    RAGFLOW_MAX_UPLOAD_SIZE,
    RAGFLOW_ALLOWED_EXTENSIONS,
)

logger = logging.getLogger('app')


class FileValidationError(Exception):
    """Raised when an uploaded file fails validation."""
    pass


class KnowledgeBaseService:
    """Knowledge-base service.

    Provides the business logic for knowledge-base management:
    - dataset CRUD
    - file upload and management
    - file validation

    Access control: ownership is tracked locally in the ``user_datasets``
    table; RAGFlow is the system of record for dataset/document content.
    """

    def __init__(self, repository: RAGFlowRepository):
        """Initialize the service.

        Args:
            repository: RAGFlow data-repository instance.
        """
        self.repository = repository

    def _validate_file(self, filename: str, content: bytes) -> None:
        """Validate an uploaded file's name, extension and size.

        Args:
            filename: Name of the uploaded file.
            content: Raw file content (already fully read into memory).

        Raises:
            FileValidationError: If any validation check fails.
        """
        # Reject missing names and the "unknown" sentinel set by upload_file.
        if not filename or filename == "unknown":
            raise FileValidationError("无效的文件名")

        # Guard against path traversal via the client-supplied file name.
        if '..' in filename or '/' in filename or '\\' in filename:
            raise FileValidationError("文件名包含非法字符")

        # Extension whitelist check (compared without the leading dot).
        ext = Path(filename).suffix.lower().lstrip('.')
        if ext not in RAGFLOW_ALLOWED_EXTENSIONS:
            allowed = ', '.join(RAGFLOW_ALLOWED_EXTENSIONS)
            raise FileValidationError(f"不支持的文件类型: {ext}。支持的类型: {allowed}")

        # Size limit check.
        file_size = len(content)
        if file_size > RAGFLOW_MAX_UPLOAD_SIZE:
            size_mb = file_size / (1024 * 1024)
            max_mb = RAGFLOW_MAX_UPLOAD_SIZE / (1024 * 1024)
            raise FileValidationError(f"文件过大: {size_mb:.1f}MB (最大 {max_mb}MB)")

        # NOTE: the MIME type is only guessed from the file name and logged;
        # it is NOT enforced. Content sniffing would be needed for a real
        # MIME check.
        detected_mime, _ = mimetypes.guess_type(filename)
        # Fixed: the original logged the literal placeholder "(unknown)".
        logger.info(f"File {filename} detected as {detected_mime}")

    # ============== Dataset management ==============

    async def _check_dataset_access(self, dataset_id: str, user_id: str) -> bool:
        """Check whether the user may access the given dataset.

        Args:
            dataset_id: Dataset ID.
            user_id: User ID.

        Returns:
            True if a user/dataset link exists in the local database.
        """
        pool = get_db_pool_manager().pool
        async with pool.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    SELECT id FROM user_datasets
                    WHERE user_id = %s AND dataset_id = %s
                """, (user_id, dataset_id))
                return await cursor.fetchone() is not None

    async def list_datasets(
        self,
        user_id: str,
        page: int = 1,
        page_size: int = 20,
        search: Optional[str] = None
    ) -> Dict[str, Any]:
        """List the user's datasets (filtered via the local database).

        Args:
            user_id: User ID.
            page: Page number (1-based).
            page_size: Page size.
            search: Optional search keyword (matched against dataset_name).

        Returns:
            Dict with ``items``, ``total``, ``page`` and ``page_size``.
        """
        logger.info(f"Listing datasets for user {user_id}: page={page}, page_size={page_size}, search={search}")

        pool = get_db_pool_manager().pool

        # Fetch the user's dataset IDs (paginated) from the local database.
        async with pool.connection() as conn:
            async with conn.cursor() as cursor:
                # Build the WHERE clause; only column values are parameterized,
                # the clause text itself is built from fixed fragments.
                where_conditions = ["user_id = %s"]
                params = [user_id]

                if search:
                    where_conditions.append("dataset_name ILIKE %s")
                    params.append(f"%{search}%")

                where_clause = " AND ".join(where_conditions)

                # Total number of matching datasets (for pagination).
                await cursor.execute(f"""
                    SELECT COUNT(*) FROM user_datasets
                    WHERE {where_clause}
                """, params)
                total = (await cursor.fetchone())[0]

                # Page of dataset rows.
                offset = (page - 1) * page_size
                await cursor.execute(f"""
                    SELECT dataset_id, dataset_name, created_at
                    FROM user_datasets
                    WHERE {where_clause}
                    ORDER BY created_at DESC
                    LIMIT %s OFFSET %s
                """, params + [page_size, offset])
                user_datasets = await cursor.fetchall()

        if not user_datasets:
            # Fixed: previously this hard-coded total=0, which misreported
            # the total when the requested page was past the last page.
            return {
                "items": [],
                "total": total,
                "page": page,
                "page_size": page_size
            }

        # IDs on this page; names kept for potential enrichment.
        dataset_ids = [row[0] for row in user_datasets]
        dataset_names = {row[0]: row[1] for row in user_datasets}

        # Fetch full dataset details from RAGFlow.
        # NOTE: capped at 1000 datasets; beyond that, items may be missing
        # even though `total` counts them.
        ragflow_result = await self.repository.list_datasets(
            page=1,
            page_size=1000
        )

        # Keep only the datasets belonging to this user's current page.
        user_dataset_ids_set = set(dataset_ids)
        items = [
            item for item in ragflow_result["items"]
            if item.get("dataset_id") in user_dataset_ids_set
        ]

        return {
            "items": items,
            "total": total,
            "page": page,
            "page_size": page_size
        }

    async def create_dataset(
        self,
        user_id: str,
        name: str,
        description: Optional[str] = None,
        chunk_method: str = "naive"
    ) -> Dict[str, Any]:
        """Create a dataset in RAGFlow and link it to the user.

        Args:
            user_id: User ID.
            name: Dataset name.
            description: Optional description.
            chunk_method: Chunking method (must be one of RAGFlow's methods).

        Returns:
            The created dataset's info as returned by the repository.

        Raises:
            ValueError: If ``chunk_method`` is not a supported method.
        """
        logger.info(f"Creating dataset for user {user_id}: name={name}, chunk_method={chunk_method}")

        # Validate the chunking method against RAGFlow's supported set.
        valid_methods = [
            "naive", "manual", "qa", "table", "paper", "book", "laws",
            "presentation", "picture", "one", "email", "knowledge-graph"
        ]
        if chunk_method not in valid_methods:
            raise ValueError(f"无效的分块方法: {chunk_method}。支持的方法: {', '.join(valid_methods)}")

        # Create the dataset in RAGFlow first.
        result = await self.repository.create_dataset(
            name=name,
            description=description,
            chunk_method=chunk_method,
            permission="me"
        )

        # Then record the ownership link in the local database.
        dataset_id = result.get("dataset_id")
        if dataset_id:
            pool = get_db_pool_manager().pool
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("""
                        INSERT INTO user_datasets (user_id, dataset_id, dataset_name, owner)
                        VALUES (%s, %s, %s, TRUE)
                    """, (user_id, dataset_id, name))
                    await conn.commit()
            logger.info(f"Dataset {dataset_id} associated with user {user_id}")

        return result

    async def get_dataset(self, dataset_id: str, user_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Get dataset details.

        Args:
            dataset_id: Dataset ID.
            user_id: Optional user ID; when given, access is verified first.

        Returns:
            Dataset details, or None if missing or access is denied.
        """
        logger.info(f"Getting dataset: {dataset_id} for user: {user_id}")

        # Verify access when a user context is provided.
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                return None

        return await self.repository.get_dataset(dataset_id)

    async def update_dataset(
        self,
        dataset_id: str,
        updates: Dict[str, Any],
        user_id: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Update a dataset.

        Args:
            dataset_id: Dataset ID.
            updates: Fields to update (passed through to the repository).
            user_id: Optional user ID; when given, access is verified first.

        Returns:
            The updated dataset info, or None if access is denied.
        """
        logger.info(f"Updating dataset {dataset_id}: {updates}")

        # Verify access when a user context is provided.
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                return None

        result = await self.repository.update_dataset(dataset_id, **updates)

        # Keep the local name column in sync when the name changed.
        if result and user_id and 'name' in updates:
            pool = get_db_pool_manager().pool
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("""
                        UPDATE user_datasets SET dataset_name = %s
                        WHERE user_id = %s AND dataset_id = %s
                    """, (updates['name'], user_id, dataset_id))
                    await conn.commit()

        return result

    async def delete_dataset(self, dataset_id: str, user_id: Optional[str] = None) -> bool:
        """Delete a dataset.

        Args:
            dataset_id: Dataset ID.
            user_id: Optional user ID; when given, access is verified first.

        Returns:
            True on success, False if access is denied or deletion failed.
        """
        logger.info(f"Deleting dataset: {dataset_id}")

        # Verify access when a user context is provided.
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                return False

        # Delete from RAGFlow first.
        result = await self.repository.delete_datasets([dataset_id])

        # Then remove the local ownership link.
        if result and user_id:
            pool = get_db_pool_manager().pool
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("""
                        DELETE FROM user_datasets
                        WHERE user_id = %s AND dataset_id = %s
                    """, (user_id, dataset_id))
                    await conn.commit()
            logger.info(f"Dataset {dataset_id} unlinked from user {user_id}")

        return result

    # ============== File management ==============

    async def list_files(
        self,
        dataset_id: str,
        user_id: Optional[str] = None,
        page: int = 1,
        page_size: int = 20
    ) -> Dict[str, Any]:
        """List the files in a dataset.

        Args:
            dataset_id: Dataset ID.
            user_id: Optional user ID; when given, access is verified first.
            page: Page number.
            page_size: Page size.

        Returns:
            File list and pagination info from the repository.

        Raises:
            ValueError: If the user has no access to the dataset.
        """
        logger.info(f"Listing files for dataset {dataset_id}: page={page}, page_size={page_size}")

        # Verify access when a user context is provided.
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                raise ValueError("Dataset not found or does not belong to you")

        return await self.repository.list_documents(
            dataset_id=dataset_id,
            page=page,
            page_size=page_size
        )

    async def upload_file(
        self,
        dataset_id: str,
        user_id: Optional[str] = None,
        file=None,
        chunk_size: int = 1024 * 1024  # kept for interface compatibility; unused
    ) -> Dict[str, Any]:
        """Upload a file to a dataset.

        NOTE: despite the ``chunk_size`` parameter, the file is currently
        read into memory in one go (``await file.read()``), not streamed.
        The size is bounded by RAGFLOW_MAX_UPLOAD_SIZE in _validate_file.

        Args:
            dataset_id: Dataset ID.
            user_id: Optional user ID; when given, access is verified first.
            file: FastAPI UploadFile object.
            chunk_size: Unused; retained for backward compatibility.

        Returns:
            Uploaded document info from the repository.

        Raises:
            ValueError: If the user has no access to the dataset.
            FileValidationError: If the file fails validation.
        """
        filename = file.filename or "unknown"
        # Fixed: the original logged the literal placeholder "(unknown)".
        logger.info(f"Uploading file {filename} to dataset {dataset_id}")

        # Verify access when a user context is provided.
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                raise ValueError("Dataset not found or does not belong to you")

        # Read the whole file content (not streamed; see docstring NOTE).
        content = await file.read()

        # Validate before handing off to RAGFlow.
        try:
            self._validate_file(filename, content)
        except FileValidationError as e:
            logger.warning(f"File validation failed: {e}")
            raise

        # Upload to RAGFlow.
        result = await self.repository.upload_document(
            dataset_id=dataset_id,
            file_name=filename,
            file_content=content,
            display_name=filename
        )

        logger.info(f"File {filename} uploaded successfully")
        return result

    async def delete_file(self, dataset_id: str, document_id: str, user_id: Optional[str] = None) -> bool:
        """Delete a file from a dataset.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.
            user_id: Optional user ID; when given, access is verified first.

        Returns:
            True on success, False if access is denied.
        """
        logger.info(f"Deleting file {document_id} from dataset {dataset_id}")

        # Verify access when a user context is provided.
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                return False

        return await self.repository.delete_document(dataset_id, document_id)

    # ============== Chunk management ==============

    async def list_chunks(
        self,
        dataset_id: str,
        user_id: Optional[str] = None,
        document_id: Optional[str] = None,
        page: int = 1,
        page_size: int = 50
    ) -> Dict[str, Any]:
        """List chunks in a dataset (optionally scoped to one document).

        Args:
            dataset_id: Dataset ID.
            user_id: Optional user ID; when given, access is verified first.
            document_id: Optional document ID to filter by.
            page: Page number.
            page_size: Page size.

        Returns:
            Chunk list and pagination info from the repository.

        Raises:
            ValueError: If the user has no access to the dataset.
        """
        logger.info(f"Listing chunks for dataset {dataset_id}, document {document_id}")

        # Verify access when a user context is provided.
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                raise ValueError("Dataset not found or does not belong to you")

        return await self.repository.list_chunks(
            dataset_id=dataset_id,
            document_id=document_id,
            page=page,
            page_size=page_size
        )

    async def delete_chunk(
        self,
        dataset_id: str,
        document_id: str,
        chunk_id: str,
        user_id: Optional[str] = None
    ) -> bool:
        """Delete a chunk.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.
            chunk_id: Chunk ID.
            user_id: Optional user ID; when given, access is verified first.

        Returns:
            True on success, False if access is denied.
        """
        logger.info(f"Deleting chunk {chunk_id} from document {document_id}")

        # Verify access when a user context is provided.
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                return False

        return await self.repository.delete_chunk(dataset_id, document_id, chunk_id)

    async def parse_document(
        self,
        dataset_id: str,
        document_id: str,
        user_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Start parsing a document.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.
            user_id: Optional user ID; when given, access is verified first.

        Returns:
            Dict with ``success`` flag and a status message.

        Raises:
            ValueError: If the user has no access to the dataset.
        """
        logger.info(f"Parsing document {document_id} in dataset {dataset_id}")

        # Verify access when a user context is provided.
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                raise ValueError("Dataset not found or does not belong to you")

        success = await self.repository.parse_document(dataset_id, document_id)
        return {"success": success, "message": "解析任务已启动"}

    async def cancel_parse_document(
        self,
        dataset_id: str,
        document_id: str,
        user_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Cancel parsing of a document.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.
            user_id: Optional user ID; when given, access is verified first.

        Returns:
            Dict with ``success`` flag and a status message.

        Raises:
            ValueError: If the user has no access to the dataset.
        """
        logger.info(f"Cancelling parse for document {document_id} in dataset {dataset_id}")

        # Verify access when a user context is provided.
        if user_id:
            has_access = await self._check_dataset_access(dataset_id, user_id)
            if not has_access:
                logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
                raise ValueError("Dataset not found or does not belong to you")

        success = await self.repository.cancel_parse_document(dataset_id, document_id)
        return {"success": success, "message": "解析任务已取消"}

    # ============== Bot/dataset link management ==============

    async def get_dataset_ids_by_bot(self, bot_id: str) -> list[str]:
        """Get the dataset IDs linked to a bot.

        Args:
            bot_id: Bot ID (the ``bot_id`` column of the agent_bots table).

        Returns:
            List of dataset IDs; empty if the bot is missing or has none.
        """
        logger.info(f"Getting dataset_ids for bot_id: {bot_id}")

        pool = get_db_pool_manager().pool
        async with pool.connection() as conn:
            async with conn.cursor() as cursor:
                # dataset_ids live inside the bot's settings payload.
                await cursor.execute("""
                    SELECT settings FROM agent_bots WHERE bot_id = %s
                """, (bot_id,))
                row = await cursor.fetchone()

                if not row:
                    logger.warning(f"Bot not found: {bot_id}")
                    return []

                # NOTE(review): assumes `settings` deserializes to a dict
                # (e.g. a JSONB column) — confirm against the schema.
                settings = row[0]
                # dataset_ids is stored as a comma-separated string, but a
                # list is also tolerated defensively.
                dataset_ids_str = settings.get('dataset_ids') if settings else None

                if not dataset_ids_str:
                    return []

                if isinstance(dataset_ids_str, str):
                    dataset_ids = [ds_id.strip() for ds_id in dataset_ids_str.split(',') if ds_id.strip()]
                elif isinstance(dataset_ids_str, list):
                    dataset_ids = dataset_ids_str
                else:
                    dataset_ids = []

                logger.info(f"Found {len(dataset_ids)} datasets for bot {bot_id}")
                return dataset_ids