""" Knowledge Base API 路由 通过 RAGFlow SDK 提供知识库管理功能 """ import logging from typing import Optional from fastapi import APIRouter, HTTPException, Header, UploadFile, File, Query, Depends from pydantic import BaseModel, Field from agent.db_pool_manager import get_db_pool_manager from utils.fastapi_utils import extract_api_key_from_auth from repositories.ragflow_repository import RAGFlowRepository from services.knowledge_base_service import KnowledgeBaseService logger = logging.getLogger('app') router = APIRouter() # ============== 依赖注入 ============== async def get_kb_service() -> KnowledgeBaseService: """获取知识库服务实例""" return KnowledgeBaseService(RAGFlowRepository()) async def verify_user(authorization: Optional[str] = Header(None)) -> tuple: """ 验证用户权限(检查 agent_user_tokens 表) Returns: tuple[str, str]: (user_id, username) """ from routes.bot_manager import verify_user_auth valid, user_id, username = await verify_user_auth(authorization) if not valid: raise HTTPException(status_code=401, detail="Unauthorized") return user_id, username # ============== 数据库表初始化 ============== async def init_knowledge_base_tables(): """ 初始化知识库相关的数据库表 """ pool = get_db_pool_manager().pool async with pool.connection() as conn: async with conn.cursor() as cursor: # 检查 user_datasets 表是否已存在 await cursor.execute(""" SELECT EXISTS ( SELECT FROM information_schema.tables WHERE table_name = 'user_datasets' ) """) table_exists = (await cursor.fetchone())[0] if not table_exists: logger.info("Creating user_datasets table") await cursor.execute(""" CREATE TABLE user_datasets ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), user_id UUID NOT NULL REFERENCES agent_user(id) ON DELETE CASCADE, dataset_id VARCHAR(255) NOT NULL, dataset_name VARCHAR(255), owner BOOLEAN DEFAULT TRUE, created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), UNIQUE(user_id, dataset_id) ) """) await cursor.execute("CREATE INDEX idx_user_datasets_user_id ON user_datasets(user_id)") await cursor.execute("CREATE INDEX idx_user_datasets_dataset_id ON user_datasets(dataset_id)") logger.info("user_datasets table created successfully") await conn.commit() logger.info("Knowledge base tables initialized successfully") # ============== Pydantic Models ============== class DatasetCreate(BaseModel): """创建数据集请求""" name: str = Field(..., min_length=1, max_length=128, description="数据集名称") description: Optional[str] = Field(None, max_length=500, description="描述信息") chunk_method: str = Field( default="naive", description="分块方法: naive, manual, qa, table, paper, book, laws, presentation, picture, one, email, knowledge-graph" ) class DatasetUpdate(BaseModel): """更新数据集请求(部分更新)""" name: Optional[str] = Field(None, min_length=1, max_length=128) description: Optional[str] = Field(None, max_length=500) chunk_method: Optional[str] = None class DatasetListResponse(BaseModel): """数据集列表响应(分页)""" items: list total: int page: int page_size: int class FileListResponse(BaseModel): """文件列表响应(分页)""" items: list total: int page: int page_size: int class ChunkListResponse(BaseModel): """切片列表响应(分页)""" items: list total: int page: int page_size: int # ============== 数据集端点 ============== @router.get("/datasets", response_model=DatasetListResponse) async def list_datasets( page: int = Query(1, ge=1, description="页码"), page_size: int = Query(20, ge=1, le=100, description="每页数量"), search: Optional[str] = Query(None, description="搜索关键词"), user_info: tuple = Depends(verify_user), kb_service: KnowledgeBaseService = Depends(get_kb_service) ): """获取当前用户的数据集列表(支持分页和搜索)""" user_id, username = user_info return await kb_service.list_datasets( user_id=user_id, page=page, page_size=page_size, search=search ) @router.post("/datasets", status_code=201) async def create_dataset( data: DatasetCreate, user_info: tuple = Depends(verify_user), kb_service: KnowledgeBaseService = Depends(get_kb_service) ): """创建数据集并关联到当前用户""" try: user_id, username = user_info dataset = await kb_service.create_dataset( user_id=user_id, name=data.name, description=data.description, chunk_method=data.chunk_method ) return dataset except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except Exception as e: logger.error(f"Failed to create dataset: {e}") raise HTTPException(status_code=500, detail=f"创建数据集失败: {str(e)}") @router.get("/datasets/{dataset_id}") async def get_dataset( dataset_id: str, user_info: tuple = Depends(verify_user), kb_service: KnowledgeBaseService = Depends(get_kb_service) ): """获取数据集详情(仅限自己的数据集)""" user_id, username = user_info dataset = await kb_service.get_dataset(dataset_id, user_id=user_id) if not dataset: raise HTTPException(status_code=404, detail="数据集不存在") return dataset @router.patch("/datasets/{dataset_id}") async def update_dataset( dataset_id: str, data: DatasetUpdate, user_info: tuple = Depends(verify_user), kb_service: KnowledgeBaseService = Depends(get_kb_service) ): """更新数据集(部分更新)""" try: user_id, username = user_info # 只传递非 None 的字段 updates = data.model_dump(exclude_unset=True) if not updates: raise HTTPException(status_code=400, detail="没有提供要更新的字段") dataset = await kb_service.update_dataset(dataset_id, updates, user_id=user_id) if not dataset: raise HTTPException(status_code=404, detail="数据集不存在") return dataset except Exception as e: logger.error(f"Failed to update dataset: {e}") raise HTTPException(status_code=500, detail=f"更新数据集失败: {str(e)}") @router.delete("/datasets/{dataset_id}") async def delete_dataset( dataset_id: str, user_info: tuple = Depends(verify_user), kb_service: KnowledgeBaseService = Depends(get_kb_service) ): """删除数据集""" user_id, username = user_info success = await kb_service.delete_dataset(dataset_id, user_id=user_id) if not success: raise HTTPException(status_code=404, detail="数据集不存在") return {"success": True, "message": "数据集已删除"} # ============== 文件端点 ============== @router.get("/datasets/{dataset_id}/files", response_model=FileListResponse) async def list_dataset_files( dataset_id: str, page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100), user_info: tuple = Depends(verify_user), kb_service: KnowledgeBaseService = Depends(get_kb_service) ): """获取数据集内文件列表(分页,仅限自己的数据集)""" user_id, username = user_info return await kb_service.list_files(dataset_id, user_id=user_id, page=page, page_size=page_size) @router.post("/datasets/{dataset_id}/files") async def upload_file( dataset_id: str, file: UploadFile = File(...), user_info: tuple = Depends(verify_user), kb_service: KnowledgeBaseService = Depends(get_kb_service) ): """ 上传文件到数据集(流式处理) 支持的文件类型: PDF, DOCX, TXT, MD, CSV 最大文件大小: 100MB """ try: user_id, username = user_info result = await kb_service.upload_file(dataset_id, user_id=user_id, file=file) return result except ValueError as e: if "File validation failed" in str(e) or "not belong to you" in str(e): raise HTTPException(status_code=400, detail=str(e)) logger.error(f"Failed to upload file: {e}") raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}") except Exception as e: logger.error(f"Failed to upload file: {e}") raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}") @router.delete("/datasets/{dataset_id}/files/{document_id}") async def delete_file( dataset_id: str, document_id: str, user_info: tuple = Depends(verify_user), kb_service: KnowledgeBaseService = Depends(get_kb_service) ): """删除文件""" user_id, username = user_info success = await kb_service.delete_file(dataset_id, document_id, user_id=user_id) if not success: raise HTTPException(status_code=404, detail="文件不存在") return {"success": True} # ============== 切片端点 ============== @router.get("/datasets/{dataset_id}/chunks", response_model=ChunkListResponse) async def list_chunks( dataset_id: str, document_id: Optional[str] = Query(None, description="文档 ID(可选)"), page: int = Query(1, ge=1), page_size: int = Query(50, ge=1, le=200), user_info: tuple = Depends(verify_user), kb_service: KnowledgeBaseService = Depends(get_kb_service) ): """获取数据集内切片列表(分页,仅限自己的数据集)""" user_id, username = user_info return await kb_service.list_chunks( user_id=user_id, dataset_id=dataset_id, document_id=document_id, page=page, page_size=page_size ) @router.delete("/datasets/{dataset_id}/chunks/{chunk_id}") async def delete_chunk( dataset_id: str, chunk_id: str, document_id: str = Query(..., description="文档 ID"), user_info: tuple = Depends(verify_user), kb_service: KnowledgeBaseService = Depends(get_kb_service) ): """删除切片""" user_id, username = user_info success = await kb_service.delete_chunk(dataset_id, document_id, chunk_id, user_id=user_id) if not success: raise HTTPException(status_code=404, detail="切片不存在") return {"success": True} # ============== 文档解析端点 ============== @router.post("/datasets/{dataset_id}/documents/{document_id}/parse") async def parse_document( dataset_id: str, document_id: str, user_info: tuple = Depends(verify_user), kb_service: KnowledgeBaseService = Depends(get_kb_service) ): """开始解析文档""" try: user_id, username = user_info result = await kb_service.parse_document(dataset_id, document_id, user_id=user_id) return result except Exception as e: logger.error(f"Failed to parse document: {e}") raise HTTPException(status_code=500, detail=f"启动解析失败: {str(e)}") @router.post("/datasets/{dataset_id}/documents/{document_id}/cancel-parse") async def cancel_parse_document( dataset_id: str, document_id: str, user_info: tuple = Depends(verify_user), kb_service: KnowledgeBaseService = Depends(get_kb_service) ): """取消解析文档""" try: user_id, username = user_info result = await kb_service.cancel_parse_document(dataset_id, document_id, user_id=user_id) return result except Exception as e: logger.error(f"Failed to cancel parse: {e}") raise HTTPException(status_code=500, detail=f"取消解析失败: {str(e)}") # ============== Bot 数据集关联端点 ============== @router.get("/bots/{bot_id}/datasets") async def get_bot_datasets( bot_id: str, kb_service: KnowledgeBaseService = Depends(get_kb_service) ): """ 获取 bot 关联的数据集 ID 列表 用于 MCP 服务器通过 bot_id 获取对应的数据集 IDs """ try: dataset_ids = await kb_service.get_dataset_ids_by_bot(bot_id) return {"dataset_ids": dataset_ids} except Exception as e: logger.error(f"Failed to get datasets for bot {bot_id}: {e}") raise HTTPException(status_code=500, detail=f"获取数据集失败: {str(e)}")