370 lines
12 KiB
Python
370 lines
12 KiB
Python
"""
|
||
Knowledge Base API 路由
|
||
通过 RAGFlow SDK 提供知识库管理功能
|
||
"""
|
||
import logging
|
||
from typing import Optional
|
||
from fastapi import APIRouter, HTTPException, Header, UploadFile, File, Query, Depends
|
||
from pydantic import BaseModel, Field
|
||
|
||
from agent.db_pool_manager import get_db_pool_manager
|
||
from utils.fastapi_utils import extract_api_key_from_auth
|
||
from repositories.ragflow_repository import RAGFlowRepository
|
||
from services.knowledge_base_service import KnowledgeBaseService
|
||
|
||
logger = logging.getLogger('app')
|
||
|
||
router = APIRouter()
|
||
|
||
# ============== 依赖注入 ==============
|
||
async def get_kb_service() -> KnowledgeBaseService:
    """Build the knowledge-base service injected into the route handlers.

    Returns:
        A new ``KnowledgeBaseService`` wrapping a new ``RAGFlowRepository``.
        Nothing is cached here, so every call constructs both objects.
    """
    repository = RAGFlowRepository()
    return KnowledgeBaseService(repository)
|
||
|
||
|
||
async def verify_user(authorization: Optional[str] = Header(None)) -> tuple:
    """Authenticate the caller from the ``Authorization`` header.

    The actual check is delegated to ``routes.bot_manager.verify_user_auth``
    (per the existing notes it consults the ``agent_user_tokens`` table; that
    code is not visible here).

    Args:
        authorization: Raw ``Authorization`` header value, or ``None`` when
            the header is absent.

    Returns:
        tuple[str, str]: ``(user_id, username)`` as reported by
        ``verify_user_auth``.

    Raises:
        HTTPException: 401 when ``verify_user_auth`` reports the caller as
            not valid.
    """
    # Imported at call time instead of module level — presumably to avoid a
    # circular import with routes.bot_manager; confirm before hoisting.
    from routes.bot_manager import verify_user_auth

    is_valid, user_id, username = await verify_user_auth(authorization)
    if is_valid:
        return user_id, username
    raise HTTPException(status_code=401, detail="Unauthorized")
|
||
|
||
|
||
# ============== 数据库表初始化 ==============
|
||
|
||
async def init_knowledge_base_tables():
    """Create the knowledge-base tables when they are not present yet.

    Only one table is managed here: ``user_datasets``, which links a row of
    ``agent_user`` to a dataset ID, together with two single-column indexes.
    The work runs on a connection taken from the shared DB pool and is
    committed before returning.

    The DDL uses ``IF NOT EXISTS`` so that a table or index that appears
    between the existence check and the ``CREATE`` (for example when several
    workers start together) no longer aborts start-up with a duplicate-object
    error.
    """
    pool = get_db_pool_manager().pool

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            # to_regclass resolves the unqualified name through the current
            # search_path, i.e. exactly the way later unqualified queries
            # against user_datasets will. The previous information_schema
            # lookup matched a same-named table in *any* visible schema and
            # could therefore skip creation by mistake.
            await cursor.execute(
                "SELECT to_regclass('user_datasets') IS NOT NULL"
            )
            # Positional access assumes the cursor yields tuple-like rows
            # (unchanged from the original code).
            table_exists = (await cursor.fetchone())[0]

            if not table_exists:
                logger.info("Creating user_datasets table")

                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS user_datasets (
                        id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                        user_id UUID NOT NULL REFERENCES agent_user(id) ON DELETE CASCADE,
                        dataset_id VARCHAR(255) NOT NULL,
                        dataset_name VARCHAR(255),
                        owner BOOLEAN DEFAULT TRUE,
                        created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                        UNIQUE(user_id, dataset_id)
                    )
                """)

                await cursor.execute(
                    "CREATE INDEX IF NOT EXISTS idx_user_datasets_user_id "
                    "ON user_datasets(user_id)"
                )
                await cursor.execute(
                    "CREATE INDEX IF NOT EXISTS idx_user_datasets_dataset_id "
                    "ON user_datasets(dataset_id)"
                )

                logger.info("user_datasets table created successfully")

        await conn.commit()
        logger.info("Knowledge base tables initialized successfully")
|
||
|
||
|
||
# ============== Pydantic Models ==============
|
||
|
||
class DatasetCreate(BaseModel):
    """Request body for creating a dataset."""

    # Required dataset name, 1 to 128 characters.
    name: str = Field(
        ...,
        min_length=1,
        max_length=128,
        description="数据集名称",
    )
    # Optional free-text description, at most 500 characters.
    description: Optional[str] = Field(
        default=None,
        max_length=500,
        description="描述信息",
    )
    # Chunking method name, "naive" when omitted. The allowed values are only
    # listed in the description text; this model does not enforce them.
    chunk_method: str = Field(
        default="naive",
        description="分块方法: naive, manual, qa, table, paper, book, laws, presentation, picture, one, email, knowledge-graph",
    )
|
||
|
||
|
||
class DatasetUpdate(BaseModel):
    """Request body for a partial dataset update; every field is optional."""

    # Replacement name, 1 to 128 characters when supplied.
    name: Optional[str] = Field(default=None, min_length=1, max_length=128)
    # Replacement description, at most 500 characters when supplied.
    description: Optional[str] = Field(default=None, max_length=500)
    # Replacement chunking method; not validated by this model.
    chunk_method: Optional[str] = None
|
||
|
||
|
||
class DatasetListResponse(BaseModel):
    """Paginated response envelope for the dataset listing endpoint."""

    # Entries of the current page; the element type is not constrained here.
    items: list
    # Item count reported by the service (presumably across all pages — confirm).
    total: int
    # Page number this response corresponds to.
    page: int
    # Page size this response corresponds to.
    page_size: int
|
||
|
||
|
||
class FileListResponse(BaseModel):
    """Paginated response envelope for the dataset file listing endpoint."""

    # Entries of the current page; the element type is not constrained here.
    items: list
    # Item count reported by the service (presumably across all pages — confirm).
    total: int
    # Page number this response corresponds to.
    page: int
    # Page size this response corresponds to.
    page_size: int
|
||
|
||
|
||
class ChunkListResponse(BaseModel):
    """Paginated response envelope for the chunk listing endpoint."""

    # Entries of the current page; the element type is not constrained here.
    items: list
    # Item count reported by the service (presumably across all pages — confirm).
    total: int
    # Page number this response corresponds to.
    page: int
    # Page size this response corresponds to.
    page_size: int
|
||
|
||
|
||
# ============== 数据集端点 ==============
|
||
|
||
@router.get("/datasets", response_model=DatasetListResponse)
async def list_datasets(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service),
):
    """List the current user's datasets, with pagination and optional search."""
    # user_info is the (user_id, username) pair produced by verify_user;
    # only the ID is needed to scope the listing.
    user_id, _username = user_info
    listing = await kb_service.list_datasets(
        user_id=user_id,
        page=page,
        page_size=page_size,
        search=search,
    )
    return listing
|
||
|
||
|
||
@router.post("/datasets", status_code=201)
async def create_dataset(
    data: DatasetCreate,
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service),
):
    """Create a dataset and associate it with the current user.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised while creating the
            dataset, 500 for any other exception.
    """
    try:
        user_id, _username = user_info
        return await kb_service.create_dataset(
            user_id=user_id,
            name=data.name,
            description=data.description,
            chunk_method=data.chunk_method,
        )
    except ValueError as exc:
        # ValueError is treated as a client-side problem and its message is
        # returned verbatim.
        raise HTTPException(status_code=400, detail=str(exc))
    except Exception as exc:
        logger.error(f"Failed to create dataset: {exc}")
        raise HTTPException(status_code=500, detail=f"创建数据集失败: {str(exc)}")
|
||
|
||
|
||
@router.get("/datasets/{dataset_id}")
async def get_dataset(
    dataset_id: str,
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service),
):
    """Return the details of one dataset owned by the current user.

    Raises:
        HTTPException: 404 when the service returns a falsy result for this
            dataset ID and user.
    """
    # The user ID is forwarded so the service can scope the lookup to the caller.
    user_id, _username = user_info
    found = await kb_service.get_dataset(dataset_id, user_id=user_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="数据集不存在")
|
||
|
||
|
||
@router.patch("/datasets/{dataset_id}")
async def update_dataset(
    dataset_id: str,
    data: DatasetUpdate,
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """Apply a partial update to a dataset.

    Raises:
        HTTPException: 400 when the request body sets no field, 404 when the
            service returns a falsy result for this dataset and user, 500
            when the service call fails with an unexpected exception.
    """
    user_id, username = user_info

    # exclude_unset keeps exactly the fields the client sent, so an explicit
    # null is forwarded to the service as None while omitted fields are dropped.
    updates = data.model_dump(exclude_unset=True)
    if not updates:
        # Raised outside the try block: previously the generic handler below
        # caught this and turned the 400 into a 500.
        raise HTTPException(status_code=400, detail="没有提供要更新的字段")

    try:
        dataset = await kb_service.update_dataset(dataset_id, updates, user_id=user_id)
    except HTTPException:
        # Deliberate HTTP errors must reach the client with their own status.
        raise
    except Exception as e:
        logger.error(f"Failed to update dataset: {e}")
        raise HTTPException(status_code=500, detail=f"更新数据集失败: {str(e)}")

    if not dataset:
        # Also outside the try block, so the 404 is no longer rewritten to 500.
        raise HTTPException(status_code=404, detail="数据集不存在")
    return dataset
|
||
|
||
|
||
@router.delete("/datasets/{dataset_id}")
async def delete_dataset(
    dataset_id: str,
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service),
):
    """Delete a dataset on behalf of the current user.

    Raises:
        HTTPException: 404 when the service reports that nothing was deleted.
    """
    user_id, _username = user_info
    deleted = await kb_service.delete_dataset(dataset_id, user_id=user_id)
    if deleted:
        return {"success": True, "message": "数据集已删除"}
    raise HTTPException(status_code=404, detail="数据集不存在")
|
||
|
||
|
||
# ============== 文件端点 ==============
|
||
|
||
@router.get("/datasets/{dataset_id}/files", response_model=FileListResponse)
async def list_dataset_files(
    dataset_id: str,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service),
):
    """List the files of one of the current user's datasets, page by page."""
    # The user ID is forwarded so the service can scope the listing to the caller.
    user_id, _username = user_info
    file_page = await kb_service.list_files(
        dataset_id,
        user_id=user_id,
        page=page,
        page_size=page_size,
    )
    return file_page
|
||
|
||
|
||
@router.post("/datasets/{dataset_id}/files")
async def upload_file(
    dataset_id: str,
    file: UploadFile = File(...),
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service),
):
    """Upload a file into a dataset.

    The existing notes describe streamed handling, supported types PDF, DOCX,
    TXT, MD and CSV, and a 100MB size limit. None of that is enforced in this
    function; any such checks would live in the service.

    Raises:
        HTTPException: 400 for a ``ValueError`` whose message contains
            "File validation failed" or "not belong to you"; 500 for every
            other failure.
    """
    try:
        user_id, _username = user_info
        return await kb_service.upload_file(dataset_id, user_id=user_id, file=file)
    except ValueError as exc:
        reason = str(exc)
        # Only these two message fragments mark the error as the client's fault.
        client_markers = ("File validation failed", "not belong to you")
        if any(marker in reason for marker in client_markers):
            raise HTTPException(status_code=400, detail=reason)
        logger.error(f"Failed to upload file: {exc}")
        raise HTTPException(status_code=500, detail=f"上传失败: {reason}")
    except Exception as exc:
        logger.error(f"Failed to upload file: {exc}")
        raise HTTPException(status_code=500, detail=f"上传失败: {str(exc)}")
|
||
|
||
|
||
@router.delete("/datasets/{dataset_id}/files/{document_id}")
async def delete_file(
    dataset_id: str,
    document_id: str,
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service),
):
    """Delete one file (document) from a dataset.

    Raises:
        HTTPException: 404 when the service reports that nothing was deleted.
    """
    user_id, _username = user_info
    removed = await kb_service.delete_file(dataset_id, document_id, user_id=user_id)
    if removed:
        return {"success": True}
    raise HTTPException(status_code=404, detail="文件不存在")
|
||
|
||
|
||
# ============== 切片端点 ==============
|
||
|
||
@router.get("/datasets/{dataset_id}/chunks", response_model=ChunkListResponse)
async def list_chunks(
    dataset_id: str,
    document_id: Optional[str] = Query(None, description="文档 ID(可选)"),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service),
):
    """List the chunks of one of the current user's datasets, page by page.

    ``document_id`` is optional and is passed through to the service as given.
    """
    user_id, _username = user_info
    chunk_page = await kb_service.list_chunks(
        user_id=user_id,
        dataset_id=dataset_id,
        document_id=document_id,
        page=page,
        page_size=page_size,
    )
    return chunk_page
|
||
|
||
|
||
@router.delete("/datasets/{dataset_id}/chunks/{chunk_id}")
async def delete_chunk(
    dataset_id: str,
    chunk_id: str,
    document_id: str = Query(..., description="文档 ID"),
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service),
):
    """Delete one chunk of a document in a dataset.

    ``document_id`` is a required query parameter, while the dataset and
    chunk IDs come from the path.

    Raises:
        HTTPException: 404 when the service reports that nothing was deleted.
    """
    user_id, _username = user_info
    removed = await kb_service.delete_chunk(
        dataset_id,
        document_id,
        chunk_id,
        user_id=user_id,
    )
    if removed:
        return {"success": True}
    raise HTTPException(status_code=404, detail="切片不存在")
|
||
|
||
|
||
# ============== 文档解析端点 ==============
|
||
|
||
@router.post("/datasets/{dataset_id}/documents/{document_id}/parse")
async def parse_document(
    dataset_id: str,
    document_id: str,
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service),
):
    """Start parsing a document and return the service's result unchanged.

    Raises:
        HTTPException: 500 when any exception is raised while starting the parse.
    """
    try:
        user_id, _username = user_info
        return await kb_service.parse_document(dataset_id, document_id, user_id=user_id)
    except Exception as exc:
        logger.error(f"Failed to parse document: {exc}")
        raise HTTPException(status_code=500, detail=f"启动解析失败: {str(exc)}")
|
||
|
||
|
||
@router.post("/datasets/{dataset_id}/documents/{document_id}/cancel-parse")
async def cancel_parse_document(
    dataset_id: str,
    document_id: str,
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service),
):
    """Cancel parsing of a document and return the service's result unchanged.

    Raises:
        HTTPException: 500 when any exception is raised while cancelling.
    """
    try:
        user_id, _username = user_info
        return await kb_service.cancel_parse_document(
            dataset_id,
            document_id,
            user_id=user_id,
        )
    except Exception as exc:
        logger.error(f"Failed to cancel parse: {exc}")
        raise HTTPException(status_code=500, detail=f"取消解析失败: {str(exc)}")
|
||
|
||
|
||
# ============== Bot 数据集关联端点 ==============
|
||
|
||
@router.get("/bots/{bot_id}/datasets")
async def get_bot_datasets(
    bot_id: str,
    kb_service: KnowledgeBaseService = Depends(get_kb_service),
):
    """Return the dataset IDs associated with a bot.

    Intended for the MCP server, which resolves a bot ID to its dataset IDs.

    Raises:
        HTTPException: 500 when the lookup raises any exception.
    """
    # NOTE(review): unlike the other routes in this module, this one has no
    # verify_user dependency, so it is reachable without user authentication
    # at this layer — confirm that this is intended.
    try:
        ids = await kb_service.get_dataset_ids_by_bot(bot_id)
        return {"dataset_ids": ids}
    except Exception as exc:
        logger.error(f"Failed to get datasets for bot {bot_id}: {exc}")
        raise HTTPException(status_code=500, detail=f"获取数据集失败: {str(exc)}")
|