qwen_agent/routes/knowledge_base.py
2026-02-10 18:59:10 +08:00

370 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Knowledge Base API 路由
通过 RAGFlow SDK 提供知识库管理功能
"""
import logging
from typing import Optional
from fastapi import APIRouter, HTTPException, Header, UploadFile, File, Query, Depends
from pydantic import BaseModel, Field
from agent.db_pool_manager import get_db_pool_manager
from utils.fastapi_utils import extract_api_key_from_auth
from repositories.ragflow_repository import RAGFlowRepository
from services.knowledge_base_service import KnowledgeBaseService
logger = logging.getLogger('app')
router = APIRouter()
# ============== 依赖注入 ==============
async def get_kb_service() -> KnowledgeBaseService:
    """Dependency factory: build a KnowledgeBaseService over a fresh RAGFlow repository."""
    repository = RAGFlowRepository()
    return KnowledgeBaseService(repository)
async def verify_user(authorization: Optional[str] = Header(None)) -> tuple:
    """
    Authenticate the caller (checks the agent_user_tokens table).

    Returns:
        tuple[str, str]: (user_id, username)

    Raises:
        HTTPException: 401 when the token is missing or invalid.
    """
    # Imported lazily to avoid a circular import with routes.bot_manager.
    from routes.bot_manager import verify_user_auth

    is_valid, uid, uname = await verify_user_auth(authorization)
    if not is_valid:
        raise HTTPException(status_code=401, detail="Unauthorized")
    return uid, uname
# ============== 数据库表初始化 ==============
async def init_knowledge_base_tables():
    """
    Create the knowledge-base tables on startup (idempotent).

    Manages the user_datasets join table, which maps an agent_user to the
    RAGFlow dataset ids they own or can access.
    """
    pool = get_db_pool_manager().pool
    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            # Probe information_schema rather than CREATE TABLE IF NOT EXISTS
            # so the index creation below runs only on first creation.
            await cursor.execute("""
                SELECT EXISTS (
                    SELECT FROM information_schema.tables
                    WHERE table_name = 'user_datasets'
                )
            """)
            already_present = (await cursor.fetchone())[0]
            if not already_present:
                logger.info("Creating user_datasets table")
                await cursor.execute("""
                    CREATE TABLE user_datasets (
                        id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                        user_id UUID NOT NULL REFERENCES agent_user(id) ON DELETE CASCADE,
                        dataset_id VARCHAR(255) NOT NULL,
                        dataset_name VARCHAR(255),
                        owner BOOLEAN DEFAULT TRUE,
                        created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                        UNIQUE(user_id, dataset_id)
                    )
                """)
                await cursor.execute("CREATE INDEX idx_user_datasets_user_id ON user_datasets(user_id)")
                await cursor.execute("CREATE INDEX idx_user_datasets_dataset_id ON user_datasets(dataset_id)")
                logger.info("user_datasets table created successfully")
        await conn.commit()
    logger.info("Knowledge base tables initialized successfully")
# ============== Pydantic Models ==============
class DatasetCreate(BaseModel):
    """Request body for creating a dataset."""
    name: str = Field(..., min_length=1, max_length=128, description="数据集名称")
    description: Optional[str] = Field(None, max_length=500, description="描述信息")
    # chunk_method selects the RAGFlow parsing/chunking template; "naive" is
    # the general-purpose default.
    chunk_method: str = Field(
        default="naive",
        description="分块方法: naive, manual, qa, table, paper, book, laws, presentation, picture, one, email, knowledge-graph"
    )
class DatasetUpdate(BaseModel):
    """Request body for partially updating a dataset; omitted fields are untouched."""
    name: Optional[str] = Field(None, min_length=1, max_length=128)
    description: Optional[str] = Field(None, max_length=500)
    chunk_method: Optional[str] = None
class DatasetListResponse(BaseModel):
    """Paginated dataset list response."""
    # Untyped list: element schema is produced by the service layer — TODO confirm.
    items: list
    total: int
    page: int
    page_size: int
class FileListResponse(BaseModel):
    """Paginated file list response."""
    # Untyped list: element schema is produced by the service layer — TODO confirm.
    items: list
    total: int
    page: int
    page_size: int
class ChunkListResponse(BaseModel):
    """Paginated chunk list response."""
    # Untyped list: element schema is produced by the service layer — TODO confirm.
    items: list
    total: int
    page: int
    page_size: int
# ============== 数据集端点 ==============
@router.get("/datasets", response_model=DatasetListResponse)
async def list_datasets(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """List the current user's datasets, with pagination and optional keyword search."""
    uid, _ = user_info
    return await kb_service.list_datasets(
        user_id=uid, page=page, page_size=page_size, search=search
    )
@router.post("/datasets", status_code=201)
async def create_dataset(
    data: DatasetCreate,
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """Create a dataset in RAGFlow and associate it with the current user."""
    try:
        uid, _ = user_info
        return await kb_service.create_dataset(
            user_id=uid,
            name=data.name,
            description=data.description,
            chunk_method=data.chunk_method
        )
    except ValueError as e:
        # Validation problems from the service layer surface as 400.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Failed to create dataset: {e}")
        raise HTTPException(status_code=500, detail=f"创建数据集失败: {str(e)}")
@router.get("/datasets/{dataset_id}")
async def get_dataset(
    dataset_id: str,
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """Fetch a single dataset's details (owner-scoped)."""
    uid, _ = user_info
    found = await kb_service.get_dataset(dataset_id, user_id=uid)
    if not found:
        raise HTTPException(status_code=404, detail="数据集不存在")
    return found
@router.patch("/datasets/{dataset_id}")
async def update_dataset(
    dataset_id: str,
    data: DatasetUpdate,
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """
    Partially update a dataset (owner-scoped).

    Raises:
        HTTPException: 400 when no fields were supplied, 404 when the
            dataset does not exist, 500 on unexpected backend errors.
    """
    try:
        user_id, username = user_info
        # Forward only the fields the client actually supplied.
        updates = data.model_dump(exclude_unset=True)
        if not updates:
            raise HTTPException(status_code=400, detail="没有提供要更新的字段")
        dataset = await kb_service.update_dataset(dataset_id, updates, user_id=user_id)
        if not dataset:
            raise HTTPException(status_code=404, detail="数据集不存在")
        return dataset
    except HTTPException:
        # Bug fix: previously the blanket handler below caught the 400/404
        # raised above and rewrote them into 500 responses. Re-raise intact.
        raise
    except Exception as e:
        logger.error(f"Failed to update dataset: {e}")
        raise HTTPException(status_code=500, detail=f"更新数据集失败: {str(e)}")
@router.delete("/datasets/{dataset_id}")
async def delete_dataset(
    dataset_id: str,
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """Delete a dataset owned by the current user."""
    uid, _ = user_info
    removed = await kb_service.delete_dataset(dataset_id, user_id=uid)
    if not removed:
        raise HTTPException(status_code=404, detail="数据集不存在")
    return {"success": True, "message": "数据集已删除"}
# ============== 文件端点 ==============
@router.get("/datasets/{dataset_id}/files", response_model=FileListResponse)
async def list_dataset_files(
    dataset_id: str,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """List files in a dataset, paginated (owner-scoped)."""
    uid, _ = user_info
    return await kb_service.list_files(
        dataset_id, user_id=uid, page=page, page_size=page_size
    )
@router.post("/datasets/{dataset_id}/files")
async def upload_file(
    dataset_id: str,
    file: UploadFile = File(...),
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """
    Stream an uploaded file into a dataset.

    Supported file types: PDF, DOCX, TXT, MD, CSV.
    Maximum file size: 100MB.
    """
    try:
        uid, _ = user_info
        return await kb_service.upload_file(dataset_id, user_id=uid, file=file)
    except ValueError as e:
        message = str(e)
        # Client-side problems (invalid file, foreign dataset) map to 400;
        # any other ValueError falls through to the 500 path below.
        if "File validation failed" in message or "not belong to you" in message:
            raise HTTPException(status_code=400, detail=message)
        logger.error(f"Failed to upload file: {e}")
        raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
    except Exception as e:
        logger.error(f"Failed to upload file: {e}")
        raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
@router.delete("/datasets/{dataset_id}/files/{document_id}")
async def delete_file(
    dataset_id: str,
    document_id: str,
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """Delete a document from a dataset (owner-scoped)."""
    uid, _ = user_info
    removed = await kb_service.delete_file(dataset_id, document_id, user_id=uid)
    if not removed:
        raise HTTPException(status_code=404, detail="文件不存在")
    return {"success": True}
# ============== 切片端点 ==============
@router.get("/datasets/{dataset_id}/chunks", response_model=ChunkListResponse)
async def list_chunks(
    dataset_id: str,
    document_id: Optional[str] = Query(None, description="文档 ID可选"),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """List chunks in a dataset, paginated, optionally filtered by document (owner-scoped)."""
    uid, _ = user_info
    return await kb_service.list_chunks(
        user_id=uid,
        dataset_id=dataset_id,
        document_id=document_id,
        page=page,
        page_size=page_size
    )
@router.delete("/datasets/{dataset_id}/chunks/{chunk_id}")
async def delete_chunk(
    dataset_id: str,
    chunk_id: str,
    document_id: str = Query(..., description="文档 ID"),
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """Delete a single chunk from a document (owner-scoped)."""
    uid, _ = user_info
    removed = await kb_service.delete_chunk(dataset_id, document_id, chunk_id, user_id=uid)
    if not removed:
        raise HTTPException(status_code=404, detail="切片不存在")
    return {"success": True}
# ============== 文档解析端点 ==============
@router.post("/datasets/{dataset_id}/documents/{document_id}/parse")
async def parse_document(
    dataset_id: str,
    document_id: str,
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """Kick off parsing for a document (owner-scoped)."""
    uid, _ = user_info
    try:
        return await kb_service.parse_document(dataset_id, document_id, user_id=uid)
    except Exception as e:
        logger.error(f"Failed to parse document: {e}")
        raise HTTPException(status_code=500, detail=f"启动解析失败: {str(e)}")
@router.post("/datasets/{dataset_id}/documents/{document_id}/cancel-parse")
async def cancel_parse_document(
    dataset_id: str,
    document_id: str,
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """Cancel an in-progress document parse (owner-scoped)."""
    uid, _ = user_info
    try:
        return await kb_service.cancel_parse_document(dataset_id, document_id, user_id=uid)
    except Exception as e:
        logger.error(f"Failed to cancel parse: {e}")
        raise HTTPException(status_code=500, detail=f"取消解析失败: {str(e)}")
# ============== Bot 数据集关联端点 ==============
@router.get("/bots/{bot_id}/datasets")
async def get_bot_datasets(
    bot_id: str,
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """
    Return the dataset ids linked to a bot.

    NOTE(review): intentionally unauthenticated — consumed by the MCP server,
    which resolves dataset ids from a bot_id. Confirm this exposure is intended.
    """
    try:
        ids = await kb_service.get_dataset_ids_by_bot(bot_id)
    except Exception as e:
        logger.error(f"Failed to get datasets for bot {bot_id}: {e}")
        raise HTTPException(status_code=500, detail=f"获取数据集失败: {str(e)}")
    return {"dataset_ids": ids}