add rag flow
parent 742eaf0a1c
commit bd39a53507
@@ -120,6 +120,21 @@ CREATE INDEX IF NOT EXISTS idx_bot_shares_bot_id ON bot_shares(bot_id);
 CREATE INDEX IF NOT EXISTS idx_bot_shares_user_id ON bot_shares(user_id);
 CREATE INDEX IF NOT EXISTS idx_bot_shares_shared_by ON bot_shares(shared_by);
 
+-- 9. Create the user_datasets table (maps users to RAGFlow datasets)
+CREATE TABLE IF NOT EXISTS user_datasets (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    user_id UUID NOT NULL REFERENCES agent_user(id) ON DELETE CASCADE,
+    dataset_id VARCHAR(255) NOT NULL,   -- dataset_id returned by RAGFlow
+    dataset_name VARCHAR(255),          -- dataset name stored redundantly for easier queries
+    owner BOOLEAN DEFAULT TRUE,         -- whether this user is the owner (reserved for future sharing)
+    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
+    UNIQUE(user_id, dataset_id)
+);
+
+-- user_datasets indexes
+CREATE INDEX IF NOT EXISTS idx_user_datasets_user_id ON user_datasets(user_id);
+CREATE INDEX IF NOT EXISTS idx_user_datasets_dataset_id ON user_datasets(dataset_id);
+
 -- ===========================
 -- Default admin account
 -- Username: admin
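The new table only stores the user-to-dataset mapping; the dataset itself lives in RAGFlow. A minimal sketch of how a row might be written once RAGFlow returns a `dataset_id`, assuming an asyncpg connection (the helper name is hypothetical and not part of this commit):

```python
# Illustrative only: record a RAGFlow dataset for a user in user_datasets.
import asyncpg


async def link_dataset_to_user(conn: asyncpg.Connection, user_id: str, dataset_id: str, dataset_name: str) -> None:
    # UNIQUE(user_id, dataset_id) makes the insert idempotent via upsert.
    await conn.execute(
        """
        INSERT INTO user_datasets (user_id, dataset_id, dataset_name)
        VALUES ($1, $2, $3)
        ON CONFLICT (user_id, dataset_id) DO UPDATE SET dataset_name = EXCLUDED.dataset_name
        """,
        user_id, dataset_id, dataset_name,
    )
```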
@@ -71,7 +71,7 @@ from utils.log_util.logger import init_with_fastapi
 logger = logging.getLogger('app')
 
 # Import route modules
-from routes import chat, files, projects, system, skill_manager, database, bot_manager
+from routes import chat, files, projects, system, skill_manager, database, bot_manager, knowledge_base
 
 
 @asynccontextmanager
@@ -125,7 +125,14 @@ async def lifespan(app: FastAPI):
     except Exception as e:
         logger.warning(f"Bot Manager table initialization failed (non-fatal): {e}")
 
-    # 6. Start the checkpoint cleanup scheduler
+    # 6. Initialize Knowledge Base tables
+    try:
+        await knowledge_base.init_knowledge_base_tables()
+        logger.info("Knowledge Base tables initialized")
+    except Exception as e:
+        logger.warning(f"Knowledge Base table initialization failed (non-fatal): {e}")
+
+    # 7. Start the checkpoint cleanup scheduler
     if CHECKPOINT_CLEANUP_ENABLED:
         # Run a cleanup pass immediately on startup
         try:
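`init_knowledge_base_tables()` lives in `routes/knowledge_base.py`, which is not part of this diff; the real function takes no arguments and presumably obtains its connection internally. A purely illustrative sketch of what such an initializer might do, reusing the DDL from the SQL hunk above and assuming an asyncpg pool:

```python
# Illustrative sketch only; the real implementation is in routes/knowledge_base.py.
import asyncpg

USER_DATASETS_DDL = """
CREATE TABLE IF NOT EXISTS user_datasets (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    user_id UUID NOT NULL REFERENCES agent_user(id) ON DELETE CASCADE,
    dataset_id VARCHAR(255) NOT NULL,
    dataset_name VARCHAR(255),
    owner BOOLEAN DEFAULT TRUE,
    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
    UNIQUE(user_id, dataset_id)
);
"""


async def init_knowledge_base_tables(pool: asyncpg.Pool) -> None:
    # Idempotent: CREATE TABLE IF NOT EXISTS keeps repeated startups safe.
    async with pool.acquire() as conn:
        await conn.execute(USER_DATASETS_DDL)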
@@ -187,6 +194,9 @@ app.include_router(bot_manager.router)
 # Register file manager API routes
 app.include_router(file_manager_router)
 
+# Register knowledge base API routes
+app.include_router(knowledge_base.router, prefix="/api/v1/knowledge-base", tags=["knowledge-base"])
+
 
 if __name__ == "__main__":
     # Start the FastAPI application
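With the router mounted under `/api/v1/knowledge-base`, the dataset endpoints described in `plans/knowledge-base-module.md` become reachable. A minimal client-side sketch (host, port, and token are placeholders):

```python
# Minimal sketch: list datasets through the newly registered router.
import httpx

resp = httpx.get(
    "http://localhost:8000/api/v1/knowledge-base/datasets",  # host/port are illustrative
    headers={"Authorization": "Bearer <admin_token>"},
    params={"page": 1, "page_size": 20},
)
resp.raise_for_status()
print(resp.json()["items"])
```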
@@ -2,12 +2,12 @@
 {
   "mcpServers": {
     "rag_retrieve": {
-      "transport": "stdio",
-      "command": "python",
-      "args": [
-        "./mcp/rag_retrieve_server.py",
-        "{bot_id}"
-      ]
+      "transport": "http",
+      "url": "http://100.77.70.35:9382/mcp/",
+      "headers": {
+        "api_key": "ragflow-MRqxnDnYZ1yp5kklDMIlKH4f1qezvXIngSMGPhu1AG8",
+        "X-Dataset-Ids": "{dataset_ids}"
+      }
     }
   }
 }
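The `{dataset_ids}` placeholder in `X-Dataset-Ids` has to be filled per user before the MCP client connects; the server accepts either a comma-separated list or a JSON array (see `_extract_dataset_ids_from_headers` in `mcp/rag_flow_server.py`). A small sketch of how the template might be rendered (the function and file name are assumptions):

```python
# Illustrative: fill the {dataset_ids} placeholder with a user's RAGFlow datasets.
import json


def render_mcp_config(template_path: str, dataset_ids: list[str]) -> dict:
    with open(template_path, "r", encoding="utf-8") as f:
        raw = f.read()
    # The server-side header parser accepts a comma-separated string.
    rendered = raw.replace("{dataset_ids}", ",".join(dataset_ids))
    return json.loads(rendered)


config = render_mcp_config("mcp_servers.json", ["ds_123", "ds_456"])  # file name is an assumption
print(config["mcpServers"]["rag_retrieve"]["headers"]["X-Dataset-Ids"])
```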
mcp/rag_flow_server.py (new file, 874 lines)
@@ -0,0 +1,874 @@
#
#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

import json
import logging
import random
import time
from collections import OrderedDict
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from functools import wraps
from typing import Any

import click
import httpx
import mcp.types as types
from mcp.server.lowlevel import Server
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import JSONResponse, Response
from starlette.routing import Mount, Route
from strenum import StrEnum


class LaunchMode(StrEnum):
    SELF_HOST = "self-host"
    HOST = "host"


class Transport(StrEnum):
    SSE = "sse"
    STREAMABLE_HTTP = "streamable-http"


BASE_URL = "http://127.0.0.1:9380"
HOST = "127.0.0.1"
PORT = "9382"
HOST_API_KEY = ""
MODE = ""
TRANSPORT_SSE_ENABLED = True
TRANSPORT_STREAMABLE_HTTP_ENABLED = True
JSON_RESPONSE = True


class RAGFlowConnector:
    _MAX_DATASET_CACHE = 32
    _CACHE_TTL = 300

    _dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict()  # "dataset_id" -> (metadata, expiry_ts)
    _document_metadata_cache: OrderedDict[str, tuple[list[tuple[str, dict]], float | int]] = OrderedDict()  # "dataset_id" -> ([(document_id, doc_metadata)], expiry_ts)

    def __init__(self, base_url: str, version="v1"):
        self.base_url = base_url
        self.version = version
        self.api_url = f"{self.base_url}/api/{self.version}"
        self._async_client = None

    async def _get_client(self):
        if self._async_client is None:
            self._async_client = httpx.AsyncClient(timeout=httpx.Timeout(60.0))
        return self._async_client

    async def close(self):
        if self._async_client is not None:
            await self._async_client.aclose()
            self._async_client = None

    async def _post(self, path, json=None, stream=False, files=None, api_key: str = ""):
        if not api_key:
            return None
        client = await self._get_client()
        res = await client.post(url=self.api_url + path, json=json, headers={"Authorization": f"Bearer {api_key}"})
        return res

    async def _get(self, path, params=None, api_key: str = ""):
        if not api_key:
            return None
        client = await self._get_client()
        res = await client.get(url=self.api_url + path, params=params, headers={"Authorization": f"Bearer {api_key}"})
        return res

    def _is_cache_valid(self, ts):
        return time.time() < ts

    def _get_expiry_timestamp(self):
        offset = random.randint(-30, 30)
        return time.time() + self._CACHE_TTL + offset

    def _get_cached_dataset_metadata(self, dataset_id):
        entry = self._dataset_metadata_cache.get(dataset_id)
        if entry:
            data, ts = entry
            if self._is_cache_valid(ts):
                self._dataset_metadata_cache.move_to_end(dataset_id)
                return data
        return None

    def _set_cached_dataset_metadata(self, dataset_id, metadata):
        self._dataset_metadata_cache[dataset_id] = (metadata, self._get_expiry_timestamp())
        self._dataset_metadata_cache.move_to_end(dataset_id)
        if len(self._dataset_metadata_cache) > self._MAX_DATASET_CACHE:
            self._dataset_metadata_cache.popitem(last=False)

    def _get_cached_document_metadata_by_dataset(self, dataset_id):
        entry = self._document_metadata_cache.get(dataset_id)
        if entry:
            data_list, ts = entry
            if self._is_cache_valid(ts):
                self._document_metadata_cache.move_to_end(dataset_id)
                return {doc_id: doc_meta for doc_id, doc_meta in data_list}
        return None

    def _set_cached_document_metadata_by_dataset(self, dataset_id, doc_id_meta_list):
        self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
        self._document_metadata_cache.move_to_end(dataset_id)

    async def list_datasets(
        self,
        *,
        api_key: str,
        page: int = 1,
        page_size: int = 1000,
        orderby: str = "create_time",
        desc: bool = True,
        id: str | None = None,
        name: str | None = None,
    ):
        res = await self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}, api_key=api_key)
        if not res or res.status_code != 200:
            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])

        res = res.json()
        if res.get("code") == 0:
            result_list = []
            for data in res["data"]:
                d = {"description": data["description"], "id": data["id"]}
                result_list.append(json.dumps(d, ensure_ascii=False))
            return "\n".join(result_list)
        return ""

    async def retrieval(
        self,
        *,
        api_key: str,
        dataset_ids,
        document_ids=None,
        question="",
        page=1,
        page_size=30,
        similarity_threshold=0.2,
        vector_similarity_weight=0.3,
        top_k=1024,
        rerank_id: str | None = None,
        keyword: bool = False,
        force_refresh: bool = False,
    ):
        if document_ids is None:
            document_ids = []

        # If no dataset_ids provided or empty list, get all available dataset IDs
        if not dataset_ids:
            dataset_list_str = await self.list_datasets(api_key=api_key)
            dataset_ids = []

            # Parse the dataset list to extract IDs
            if dataset_list_str:
                for line in dataset_list_str.strip().split("\n"):
                    if line.strip():
                        try:
                            dataset_info = json.loads(line.strip())
                            dataset_ids.append(dataset_info["id"])
                        except (json.JSONDecodeError, KeyError):
                            # Skip malformed lines
                            continue

        data_json = {
            "page": page,
            "page_size": page_size,
            "similarity_threshold": similarity_threshold,
            "vector_similarity_weight": vector_similarity_weight,
            "top_k": top_k,
            "rerank_id": rerank_id,
            "keyword": keyword,
            "question": question,
            "dataset_ids": dataset_ids,
            "document_ids": document_ids,
        }
        # POST the retrieval request to the RAGFlow backend (via the shared httpx client)
        res = await self._post("/retrieval", json=data_json, api_key=api_key)
        if not res or res.status_code != 200:
            raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])

        res = res.json()
        if res.get("code") == 0:
            data = res["data"]
            chunks = []

            # Cache document metadata and dataset information
            document_cache, dataset_cache = await self._get_document_metadata_cache(dataset_ids, api_key=api_key, force_refresh=force_refresh)

            # Process chunks with enhanced field mapping including per-chunk metadata
            for chunk_data in data.get("chunks", []):
                enhanced_chunk = self._map_chunk_fields(chunk_data, dataset_cache, document_cache)
                chunks.append(enhanced_chunk)

            # Build structured response (no longer need response-level document_metadata)
            response = {
                "chunks": chunks,
                "pagination": {
                    "page": data.get("page", page),
                    "page_size": data.get("page_size", page_size),
                    "total_chunks": data.get("total", len(chunks)),
                    "total_pages": (data.get("total", len(chunks)) + page_size - 1) // page_size,
                },
                "query_info": {
                    "question": question,
                    "similarity_threshold": similarity_threshold,
                    "vector_weight": vector_similarity_weight,
                    "keyword_search": keyword,
                    "dataset_count": len(dataset_ids),
                },
            }

            return [types.TextContent(type="text", text=json.dumps(response, ensure_ascii=False))]

        raise Exception([types.TextContent(type="text", text=res.get("message"))])
    async def _get_document_metadata_cache(self, dataset_ids, *, api_key: str, force_refresh=False):
        """Cache document metadata for all documents in the specified datasets"""
        document_cache = {}
        dataset_cache = {}

        try:
            for dataset_id in dataset_ids:
                dataset_meta = None if force_refresh else self._get_cached_dataset_metadata(dataset_id)
                if not dataset_meta:
                    # First get dataset info for name
                    dataset_res = await self._get("/datasets", {"id": dataset_id, "page_size": 1}, api_key=api_key)
                    if dataset_res and dataset_res.status_code == 200:
                        dataset_data = dataset_res.json()
                        if dataset_data.get("code") == 0 and dataset_data.get("data"):
                            dataset_info = dataset_data["data"][0]
                            dataset_meta = {"name": dataset_info.get("name", "Unknown"), "description": dataset_info.get("description", "")}
                            self._set_cached_dataset_metadata(dataset_id, dataset_meta)
                if dataset_meta:
                    dataset_cache[dataset_id] = dataset_meta

                docs = None if force_refresh else self._get_cached_document_metadata_by_dataset(dataset_id)
                if docs is None:
                    page = 1
                    page_size = 30
                    doc_id_meta_list = []
                    docs = {}
                    while page:
                        docs_res = await self._get(f"/datasets/{dataset_id}/documents?page={page}", api_key=api_key)
                        if not docs_res:
                            break
                        docs_data = docs_res.json()
                        if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
                            for doc in docs_data["data"]["docs"]:
                                doc_id = doc.get("id")
                                if not doc_id:
                                    continue
                                doc_meta = {
                                    "document_id": doc_id,
                                    "name": doc.get("name", ""),
                                    "location": doc.get("location", ""),
                                    "type": doc.get("type", ""),
                                    "size": doc.get("size"),
                                    "chunk_count": doc.get("chunk_count"),
                                    "create_date": doc.get("create_date", ""),
                                    "update_date": doc.get("update_date", ""),
                                    "token_count": doc.get("token_count"),
                                    "thumbnail": doc.get("thumbnail", ""),
                                    "dataset_id": doc.get("dataset_id", dataset_id),
                                    "meta_fields": doc.get("meta_fields", {}),
                                }
                                doc_id_meta_list.append((doc_id, doc_meta))
                                docs[doc_id] = doc_meta

                        page += 1
                        if docs_data.get("data", {}).get("total", 0) - page * page_size <= 0:
                            page = None

                    self._set_cached_document_metadata_by_dataset(dataset_id, doc_id_meta_list)
                if docs:
                    document_cache.update(docs)

        except Exception as e:
            # Gracefully handle metadata cache failures
            logging.error(f"Problem building the document metadata cache: {str(e)}")
            pass

        return document_cache, dataset_cache

    def _map_chunk_fields(self, chunk_data, dataset_cache, document_cache):
        """Preserve all original API fields and add per-chunk document metadata"""
        # Start with ALL raw data from API (preserve everything like original version)
        mapped = dict(chunk_data)

        # Add dataset name enhancement
        dataset_id = chunk_data.get("dataset_id") or chunk_data.get("kb_id")
        if dataset_id and dataset_id in dataset_cache:
            mapped["dataset_name"] = dataset_cache[dataset_id]["name"]
        else:
            mapped["dataset_name"] = "Unknown"

        # Add document name convenience field
        mapped["document_name"] = chunk_data.get("document_keyword", "")

        # Add per-chunk document metadata
        document_id = chunk_data.get("document_id")
        if document_id and document_id in document_cache:
            mapped["document_metadata"] = document_cache[document_id]

        return mapped


class RAGFlowCtx:
    def __init__(self, connector: RAGFlowConnector):
        self.conn = connector


@asynccontextmanager
async def sse_lifespan(server: Server) -> AsyncIterator[dict]:
    ctx = RAGFlowCtx(RAGFlowConnector(base_url=BASE_URL))

    logging.info("Legacy SSE application started with StreamableHTTP session manager!")
    try:
        yield {"ragflow_ctx": ctx}
    finally:
        await ctx.conn.close()
        logging.info("Legacy SSE application shutting down...")


app = Server("ragflow-mcp-server", lifespan=sse_lifespan)
AUTH_TOKEN_STATE_KEY = "ragflow_auth_token"


def _to_text(value: Any) -> str:
    if isinstance(value, bytes):
        return value.decode(errors="ignore")
    return str(value)


def _extract_token_from_headers(headers: Any) -> str | None:
    if not headers or not hasattr(headers, "get"):
        return None

    auth_keys = ("authorization", "Authorization", b"authorization", b"Authorization")
    for key in auth_keys:
        auth = headers.get(key)
        if not auth:
            continue
        auth_text = _to_text(auth).strip()
        if auth_text.lower().startswith("bearer "):
            token = auth_text[7:].strip()
            if token:
                return token

    api_key_keys = ("api_key", "x-api-key", "Api-Key", "X-API-Key", b"api_key", b"x-api-key", b"Api-Key", b"X-API-Key")
    for key in api_key_keys:
        token = headers.get(key)
        if token:
            token_text = _to_text(token).strip()
            if token_text:
                return token_text

    return None


def _extract_dataset_ids_from_headers(headers: Any) -> list[str]:
    """Extract dataset_ids from request headers.

    Supports multiple header formats:
    - Comma-separated string: "dataset1,dataset2,dataset3"
    - JSON array string: '["dataset1","dataset2","dataset3"]'
    - Repeated headers (x-dataset-id)

    Returns:
        List of dataset IDs. Empty list if none found or invalid format.
    """
    if not headers or not hasattr(headers, "get"):
        return []

    # Try various header key variations
    header_keys = (
        "x-dataset-ids", "X-Dataset-Ids", "X-DATASET-IDS",
        "dataset_ids", "Dataset-Ids", "DATASET_IDS",
        "x-datasets", "X-Datasets", "X-DATASETS",
        b"x-dataset-ids", b"X-Dataset-Ids", b"X-DATASET-IDS",
        b"dataset_ids", b"Dataset-Ids", b"DATASET_IDS",
        b"x-datasets", b"X-Datasets", b"X-DATASETS",
    )

    for key in header_keys:
        value = headers.get(key)
        if not value:
            continue

        value_text = _to_text(value).strip()
        if not value_text:
            continue

        # Try parsing as JSON array first
        if value_text.startswith("["):
            try:
                dataset_ids = json.loads(value_text)
                if isinstance(dataset_ids, list):
                    return [str(ds_id).strip() for ds_id in dataset_ids if str(ds_id).strip()]
            except json.JSONDecodeError:
                pass

        # Try parsing as comma-separated string
        dataset_ids = [ds_id.strip() for ds_id in value_text.split(",") if ds_id.strip()]
        if dataset_ids:
            return dataset_ids

    # Try repeated header format (x-dataset-id)
    single_header_keys = (
        "x-dataset-id", "X-Dataset-Id", "X-DATASET-ID",
        "dataset_id", "Dataset-Id", "DATASET_ID",
        b"x-dataset-id", b"X-Dataset-Id", b"X-DATASET-ID",
        b"dataset_id", b"Dataset-Id", b"DATASET_ID",
    )

    dataset_ids = []
    for key in single_header_keys:
        value = headers.get(key)
        if value:
            value_text = _to_text(value).strip()
            if value_text:
                dataset_ids.append(value_text)

    return dataset_ids
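

# A few illustrative calls (added for clarity; not part of the original file) showing the
# header formats the parser above accepts:
#   _extract_dataset_ids_from_headers({"x-dataset-ids": "ds_a,ds_b"})        -> ["ds_a", "ds_b"]
#   _extract_dataset_ids_from_headers({"x-dataset-ids": '["ds_a","ds_b"]'})  -> ["ds_a", "ds_b"]
#   _extract_dataset_ids_from_headers({"x-dataset-id": "ds_a"})              -> ["ds_a"]
#   _extract_dataset_ids_from_headers({})                                    -> []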


def _extract_token_from_request(request: Any) -> str | None:
    if request is None:
        return None

    state = getattr(request, "state", None)
    if state is not None:
        token = getattr(state, AUTH_TOKEN_STATE_KEY, None)
        if token:
            return token

    token = _extract_token_from_headers(getattr(request, "headers", None))
    if token and state is not None:
        setattr(state, AUTH_TOKEN_STATE_KEY, token)

    return token


def _extract_dataset_ids_from_request(request: Any) -> list[str]:
    """Extract dataset_ids from a request object.

    First checks state for cached dataset_ids, then extracts from headers.

    Returns:
        List of dataset IDs. Empty list if none found.
    """
    if request is None:
        return []

    state = getattr(request, "state", None)
    if state is not None:
        dataset_ids = getattr(state, "ragflow_dataset_ids", None)
        if dataset_ids:
            return dataset_ids

    headers = getattr(request, "headers", None)
    dataset_ids = _extract_dataset_ids_from_headers(headers)

    if dataset_ids and state is not None:
        setattr(state, "ragflow_dataset_ids", dataset_ids)

    return dataset_ids


def with_api_key(required: bool = True):
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            ctx = app.request_context
            ragflow_ctx = ctx.lifespan_context.get("ragflow_ctx")
            if not ragflow_ctx:
                raise ValueError("Get RAGFlow Context failed")

            connector = ragflow_ctx.conn
            api_key = HOST_API_KEY
            request = getattr(ctx, "request", None)

            if MODE == LaunchMode.HOST:
                api_key = _extract_token_from_request(request) or ""
            if required and not api_key:
                raise ValueError("RAGFlow API key or Bearer token is required.")

            return await func(*args, connector=connector, api_key=api_key, request=request, **kwargs)

        return wrapper

    return decorator


@app.list_tools()
@with_api_key(required=True)
async def list_tools(*, connector: RAGFlowConnector, api_key: str, request: Any = None) -> list[types.Tool]:
    dataset_description = await connector.list_datasets(api_key=api_key)

    return [
        types.Tool(
            name="ragflow_retrieval",
            description="Retrieve relevant chunks from the RAGFlow retrieve interface based on the question. You can optionally specify dataset_ids to search only specific datasets, or omit dataset_ids entirely to search across ALL available datasets. You can also optionally specify document_ids to search within specific documents. When dataset_ids is not provided or is empty, the system will automatically search across all available datasets. Below is the list of all available datasets, including their descriptions and IDs:"
            + dataset_description,
            inputSchema={
                "type": "object",
                "properties": {
                    "dataset_ids": {"type": "array", "items": {"type": "string"}, "description": "Optional array of dataset IDs to search. If not provided or empty, all datasets will be searched."},
                    "document_ids": {"type": "array", "items": {"type": "string"}, "description": "Optional array of document IDs to search within."},
                    "question": {"type": "string", "description": "The question or query to search for."},
                    "page": {
                        "type": "integer",
                        "description": "Page number for pagination",
                        "default": 1,
                        "minimum": 1,
                    },
                    "page_size": {
                        "type": "integer",
                        "description": "Number of results to return per page (default: 10, max recommended: 50 to avoid token limits)",
                        "default": 10,
                        "minimum": 1,
                        "maximum": 100,
                    },
                    "similarity_threshold": {
                        "type": "number",
                        "description": "Minimum similarity threshold for results",
                        "default": 0.2,
                        "minimum": 0.0,
                        "maximum": 1.0,
                    },
                    "vector_similarity_weight": {
                        "type": "number",
                        "description": "Weight for vector similarity vs term similarity",
                        "default": 0.3,
                        "minimum": 0.0,
                        "maximum": 1.0,
                    },
                    "keyword": {
                        "type": "boolean",
                        "description": "Enable keyword-based search",
                        "default": False,
                    },
                    "top_k": {
                        "type": "integer",
                        "description": "Maximum results to consider before ranking",
                        "default": 1024,
                        "minimum": 1,
                        "maximum": 1024,
                    },
                    "rerank_id": {
                        "type": "string",
                        "description": "Optional reranking model identifier",
                    },
                    "force_refresh": {
                        "type": "boolean",
                        "description": "Set to true only if fresh dataset and document metadata is explicitly required. Otherwise, cached metadata is used (default: false).",
                        "default": False,
                    },
                },
                "required": ["question"],
            },
        ),
    ]


@app.call_tool()
@with_api_key(required=True)
async def call_tool(
    name: str,
    arguments: dict,
    *,
    connector: RAGFlowConnector,
    api_key: str,
    request: Any = None,
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
    if name == "ragflow_retrieval":
        document_ids = arguments.get("document_ids", [])
        dataset_ids = arguments.get("dataset_ids", [])
        question = arguments.get("question", "")
        page = arguments.get("page", 1)
        page_size = arguments.get("page_size", 10)
        similarity_threshold = arguments.get("similarity_threshold", 0.2)
        vector_similarity_weight = arguments.get("vector_similarity_weight", 0.3)
        keyword = arguments.get("keyword", False)
        top_k = arguments.get("top_k", 1024)
        rerank_id = arguments.get("rerank_id")
        force_refresh = arguments.get("force_refresh", False)

        # If no dataset_ids provided or empty list, try to extract from request headers
        if not dataset_ids:
            dataset_ids = _extract_dataset_ids_from_request(request)

        # If still no dataset_ids, get all available dataset IDs
        if not dataset_ids:
            dataset_list_str = await connector.list_datasets(api_key=api_key)
            dataset_ids = []

            # Parse the dataset list to extract IDs
            if dataset_list_str:
                for line in dataset_list_str.strip().split("\n"):
                    if line.strip():
                        try:
                            dataset_info = json.loads(line.strip())
                            dataset_ids.append(dataset_info["id"])
                        except (json.JSONDecodeError, KeyError):
                            # Skip malformed lines
                            continue

        return await connector.retrieval(
            api_key=api_key,
            dataset_ids=dataset_ids,
            document_ids=document_ids,
            question=question,
            page=page,
            page_size=page_size,
            similarity_threshold=similarity_threshold,
            vector_similarity_weight=vector_similarity_weight,
            keyword=keyword,
            top_k=top_k,
            rerank_id=rerank_id,
            force_refresh=force_refresh,
        )
    raise ValueError(f"Tool not found: {name}")
def create_starlette_app():
    routes = []
    middleware = None
    if MODE == LaunchMode.HOST:
        from starlette.types import ASGIApp, Receive, Scope, Send

        class AuthMiddleware:
            def __init__(self, app: ASGIApp):
                self.app = app

            async def __call__(self, scope: Scope, receive: Receive, send: Send):
                if scope["type"] != "http":
                    await self.app(scope, receive, send)
                    return

                path = scope["path"]
                if path.startswith("/messages/") or path.startswith("/sse") or path.startswith("/mcp"):
                    headers = dict(scope["headers"])
                    token = _extract_token_from_headers(headers)

                    if not token:
                        response = JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
                        await response(scope, receive, send)
                        return
                    scope.setdefault("state", {})[AUTH_TOKEN_STATE_KEY] = token

                    # Extract and cache dataset_ids from headers
                    dataset_ids = _extract_dataset_ids_from_headers(headers)
                    if dataset_ids:
                        scope.setdefault("state", {})["ragflow_dataset_ids"] = dataset_ids

                await self.app(scope, receive, send)

        middleware = [Middleware(AuthMiddleware)]

    # Add SSE routes if enabled
    if TRANSPORT_SSE_ENABLED:
        from mcp.server.sse import SseServerTransport

        sse = SseServerTransport("/messages/")

        async def handle_sse(request):
            async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
                await app.run(streams[0], streams[1], app.create_initialization_options(experimental_capabilities={"headers": dict(request.headers)}))
            return Response()

        routes.extend(
            [
                Route("/sse", endpoint=handle_sse, methods=["GET"]),
                Mount("/messages/", app=sse.handle_post_message),
            ]
        )

    # Add streamable HTTP route if enabled
    streamablehttp_lifespan = None
    if TRANSPORT_STREAMABLE_HTTP_ENABLED:
        from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
        from starlette.types import Receive, Scope, Send

        session_manager = StreamableHTTPSessionManager(
            app=app,
            event_store=None,
            json_response=JSON_RESPONSE,
            stateless=True,
        )

        class StreamableHTTPEntry:
            async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
                await session_manager.handle_request(scope, receive, send)

        streamable_http_entry = StreamableHTTPEntry()

        @asynccontextmanager
        async def streamablehttp_lifespan(app: Starlette) -> AsyncIterator[None]:
            async with session_manager.run():
                logging.info("StreamableHTTP application started with StreamableHTTP session manager!")
                try:
                    yield
                finally:
                    logging.info("StreamableHTTP application shutting down...")

        routes.extend(
            [
                Route("/mcp", endpoint=streamable_http_entry, methods=["GET", "POST", "DELETE"]),
                Mount("/mcp", app=streamable_http_entry),
            ]
        )

    return Starlette(
        debug=True,
        routes=routes,
        middleware=middleware,
        lifespan=streamablehttp_lifespan,
    )


@click.command()
@click.option("--base-url", type=str, default="http://127.0.0.1:9380", help="API base URL for RAGFlow backend")
@click.option("--host", type=str, default="127.0.0.1", help="Host to bind the RAGFlow MCP server")
@click.option("--port", type=int, default=9382, help="Port to bind the RAGFlow MCP server")
@click.option(
    "--mode",
    type=click.Choice(["self-host", "host"]),
    default="self-host",
    help=("Launch mode:\n self-host: run MCP for a single tenant (requires --api-key)\n host: multi-tenant mode, users must provide Authorization headers"),
)
@click.option("--api-key", type=str, default="", help="API key to use when in self-host mode")
@click.option(
    "--transport-sse-enabled/--no-transport-sse-enabled",
    default=True,
    help="Enable or disable legacy SSE transport mode (default: enabled)",
)
@click.option(
    "--transport-streamable-http-enabled/--no-transport-streamable-http-enabled",
    default=True,
    help="Enable or disable streamable-http transport mode (default: enabled)",
)
@click.option(
    "--json-response/--no-json-response",
    default=True,
    help="Enable or disable JSON response mode for streamable-http (default: enabled)",
)
def main(base_url, host, port, mode, api_key, transport_sse_enabled, transport_streamable_http_enabled, json_response):
    import os

    import uvicorn
    from dotenv import load_dotenv

    load_dotenv()

    def parse_bool_flag(key: str, default: bool) -> bool:
        val = os.environ.get(key, str(default))
        return str(val).strip().lower() in ("1", "true", "yes", "on")

    global BASE_URL, HOST, PORT, MODE, HOST_API_KEY, TRANSPORT_SSE_ENABLED, TRANSPORT_STREAMABLE_HTTP_ENABLED, JSON_RESPONSE
    BASE_URL = os.environ.get("RAGFLOW_MCP_BASE_URL", base_url)
    HOST = os.environ.get("RAGFLOW_MCP_HOST", host)
    PORT = os.environ.get("RAGFLOW_MCP_PORT", str(port))
    MODE = os.environ.get("RAGFLOW_MCP_LAUNCH_MODE", mode)
    HOST_API_KEY = os.environ.get("RAGFLOW_MCP_HOST_API_KEY", api_key)
    TRANSPORT_SSE_ENABLED = parse_bool_flag("RAGFLOW_MCP_TRANSPORT_SSE_ENABLED", transport_sse_enabled)
    TRANSPORT_STREAMABLE_HTTP_ENABLED = parse_bool_flag("RAGFLOW_MCP_TRANSPORT_STREAMABLE_ENABLED", transport_streamable_http_enabled)
    JSON_RESPONSE = parse_bool_flag("RAGFLOW_MCP_JSON_RESPONSE", json_response)

    if MODE == LaunchMode.SELF_HOST and not HOST_API_KEY:
        raise click.UsageError("--api-key is required when --mode is 'self-host'")

    if not TRANSPORT_STREAMABLE_HTTP_ENABLED and JSON_RESPONSE:
        JSON_RESPONSE = False

    print(
        r"""
  __  __   ____ ____     ____  _____ ______     _______ ____
 |  \/  |/ ___|  _ \   / ___|| ____|  _ \ \   / / ____|  _ \
 | |\/| | |   | |_) |  \___ \|  _| | |_) \ \ / /|  _| | |_) |
 | |  | | |___|  __/    ___) | |___|  _ < \ V / | |___|  _ <
 |_|  |_|\____|_|      |____/|_____|_| \_\ \_/  |_____|_| \_\
        """,
        flush=True,
    )
    print(f"MCP launch mode: {MODE}", flush=True)
    print(f"MCP host: {HOST}", flush=True)
    print(f"MCP port: {PORT}", flush=True)
    print(f"MCP base_url: {BASE_URL}", flush=True)

    if not any([TRANSPORT_SSE_ENABLED, TRANSPORT_STREAMABLE_HTTP_ENABLED]):
        print("At least one transport must be enabled; enabling streamable-http automatically", flush=True)
        TRANSPORT_STREAMABLE_HTTP_ENABLED = True

    if TRANSPORT_SSE_ENABLED:
        print("SSE transport enabled: yes", flush=True)
        print("SSE endpoint available at /sse", flush=True)
    else:
        print("SSE transport enabled: no", flush=True)

    if TRANSPORT_STREAMABLE_HTTP_ENABLED:
        print("Streamable HTTP transport enabled: yes", flush=True)
        print("Streamable HTTP endpoint available at /mcp", flush=True)
        if JSON_RESPONSE:
            print("Streamable HTTP mode: JSON response enabled", flush=True)
        else:
            print("Streamable HTTP mode: SSE over HTTP enabled", flush=True)
    else:
        print("Streamable HTTP transport enabled: no", flush=True)
        if JSON_RESPONSE:
            print("Warning: --json-response ignored because streamable transport is disabled.", flush=True)

    uvicorn.run(
        create_starlette_app(),
        host=HOST,
        port=int(PORT),
    )


if __name__ == "__main__":
    """
    Launch examples:

    1. Self-host mode with both SSE and Streamable HTTP (in JSON response mode) enabled (default):
       uv run mcp/rag_flow_server.py --host=127.0.0.1 --port=9382 \
           --base-url=http://127.0.0.1:9380 \
           --mode=self-host --api-key=ragflow-xxxxx

    2. Host mode (multi-tenant, clients must provide Authorization headers):
       uv run mcp/rag_flow_server.py --host=127.0.0.1 --port=9382 \
           --base-url=http://127.0.0.1:9380 \
           --mode=host

    3. Disable legacy SSE (only streamable HTTP will be active):
       uv run mcp/rag_flow_server.py --no-transport-sse-enabled \
           --mode=self-host --api-key=ragflow-xxxxx

    4. Disable streamable HTTP (only legacy SSE will be active):
       uv run mcp/rag_flow_server.py --no-transport-streamable-http-enabled \
           --mode=self-host --api-key=ragflow-xxxxx

    5. Use streamable HTTP with SSE-style events (disable JSON response):
       uv run mcp/rag_flow_server.py --transport-streamable-http-enabled --no-json-response \
           --mode=self-host --api-key=ragflow-xxxxx

    6. Disable both transports (for testing):
       uv run mcp/rag_flow_server.py --no-transport-sse-enabled --no-transport-streamable-http-enabled \
           --mode=self-host --api-key=ragflow-xxxxx
    """
    main()
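A minimal sketch of a client exercising the `/mcp` endpoint in host mode, passing the per-user RAGFlow key and dataset scope through the headers the server inspects. It assumes the Python `mcp` SDK's streamable HTTP client; the URL, key, and dataset ids are placeholders:

```python
# Illustrative client for the streamable HTTP transport (not part of this commit).
import asyncio

from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client

HEADERS = {
    "api_key": "ragflow-xxxxx",        # or "Authorization": "Bearer ragflow-xxxxx"
    "X-Dataset-Ids": "ds_123,ds_456",  # comma-separated form accepted by the server
}


async def main() -> None:
    async with streamablehttp_client("http://100.77.70.35:9382/mcp/", headers=HEADERS) as (read, write, _):
        async with ClientSession(read, write) as session:
            await session.initialize()
            result = await session.call_tool("ragflow_retrieval", {"question": "What is RAGFlow?"})
            print(result.content[0].text)


asyncio.run(main())
```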
plans/knowledge-base-module.md (new file, 869 lines)
@@ -0,0 +1,869 @@
# Knowledge Base Module Implementation Plan

> **Enhanced on:** 2025-02-10
> **Sections enhanced:** 10
> **Research agents used:** FastAPI best practices, Vue 3 composables, UI/UX patterns, RAGFlow SDK, File upload security, Architecture strategy, Code simplicity, Security sentinel, Performance oracle

---

## Enhancement Summary

### Key Improvements

1. **Security hardening** - file type validation, size limits, API key management
2. **Performance** - streaming file uploads, paginated queries, connection pool management
3. **Layered architecture** - introduce a service layer and repository pattern for better testability
4. **Simplification** - remove over-engineering and follow the YAGNI principle
5. **User experience** - better empty states, loading states, and error handling

### New Considerations Discovered

- The RAGFlow deployment uses HTTP (not HTTPS); the security risk needs to be assessed
- File uploads must be streamed to avoid exhausting memory
- Chunk queries must be paginated, otherwise large datasets will cause OOM
- API keys should be managed through environment variables, never hardcoded

---

## Overview

Add a standalone knowledge base module (not tied to any bot) to the qwen-client project, with knowledge base management implemented on top of the RAGFlow SDK.

**Architecture:**
```
qwen-client (Vue 3) → qwen-agent (FastAPI) → RAGFlow (http://100.77.70.35:1080)
```

## Background

Users need a standalone knowledge base management system that can:
1. Create and manage multiple knowledge bases (datasets)
2. Upload files into a knowledge base
3. Manage the document chunks inside a knowledge base
4. Later be linked to bots for RAG retrieval

---

## Technical Design

### Backend (qwen-agent)

#### 1. Environment configuration

**File:** `/utils/settings.py`

```python
# ============================================================
# RAGFlow Knowledge Base Configuration
# ============================================================

# RAGFlow API configuration
RAGFLOW_API_URL = os.getenv("RAGFLOW_API_URL", "http://100.77.70.35:1080")
RAGFLOW_API_KEY = os.getenv("RAGFLOW_API_KEY", "")  # must be provided via environment variable

# File upload configuration
RAGFLOW_MAX_UPLOAD_SIZE = int(os.getenv("RAGFLOW_MAX_UPLOAD_SIZE", str(100 * 1024 * 1024)))  # 100MB
RAGFLOW_ALLOWED_EXTENSIONS = os.getenv("RAGFLOW_ALLOWED_EXTENSIONS", "pdf,docx,txt,md,csv").split(",")

# Performance configuration
RAGFLOW_CONNECTION_TIMEOUT = int(os.getenv("RAGFLOW_CONNECTION_TIMEOUT", "30"))  # 30 seconds
RAGFLOW_MAX_CONCURRENT_UPLOADS = int(os.getenv("RAGFLOW_MAX_CONCURRENT_UPLOADS", "5"))
```

#### 2. Dependencies

**File:** `/pyproject.toml`

Add to `[tool.poetry.dependencies]`:
```toml
ragflow-sdk = "^0.1.0"
python-magic = "^0.4.27"
aiofiles = "^24.1.0"
```

Then run:
```bash
poetry install
poetry export -f requirements.txt -o requirements.txt --without-hashes
```

#### 3. Project structure

Following the architecture review, use a layered design:

```
qwen-agent/
├── routes/
│   └── knowledge_base.py          # API route layer
├── services/
│   └── knowledge_base_service.py  # business logic layer (new)
├── repositories/
│   ├── __init__.py
│   └── ragflow_repository.py      # RAGFlow adapter (new)
└── utils/
    ├── settings.py                # configuration
    └── file_validator.py          # file validation helpers (new)
```

#### 4. API routes

**File:** `/routes/knowledge_base.py`

**Route prefix:** `/api/v1/knowledge-base`

| Endpoint | Method | Purpose | Auth | Optimization |
|------|------|------|------|------|
| `/datasets` | GET | List all datasets (paginated) | Admin Token | cache |
| `/datasets` | POST | Create a dataset | Admin Token | - |
| `/datasets/{dataset_id}` | GET | Get dataset details | Admin Token | cache |
| `/datasets/{dataset_id}` | PATCH | Update a dataset (partial update) | Admin Token | - |
| `/datasets/{dataset_id}` | DELETE | Delete a dataset | Admin Token | - |
| `/datasets/{dataset_id}/files` | GET | List files in a dataset (paginated) | Admin Token | cache |
| `/datasets/{dataset_id}/files` | POST | Upload a file to a dataset (streamed) | Admin Token | rate limit |
| `/datasets/{dataset_id}/files/{document_id}` | DELETE | Delete a file | Admin Token | - |
| `/datasets/{dataset_id}/chunks` | GET | List chunks in a dataset (paginated) | Admin Token | cursor pagination |
| `/datasets/{dataset_id}/chunks/{chunk_id}` | DELETE | Delete a chunk | Admin Token | - |

**Code outline:**

```python
"""
Knowledge Base API routes
Provides knowledge base management on top of the RAGFlow SDK
"""
import logging
import os
from typing import Optional, List
from fastapi import APIRouter, Depends, HTTPException, Header, UploadFile, File, Query, BackgroundTasks
from pydantic import BaseModel, Field
from pathlib import Path

from utils.settings import RAGFLOW_API_URL, RAGFLOW_API_KEY
from utils.fastapi_utils import extract_api_key_from_auth
from repositories.ragflow_repository import RAGFlowRepository
from services.knowledge_base_service import KnowledgeBaseService

logger = logging.getLogger('app')

router = APIRouter()

# ============== Dependency injection ==============
async def get_kb_service() -> KnowledgeBaseService:
    """Return a knowledge base service instance"""
    return KnowledgeBaseService(RAGFlowRepository())

async def verify_admin(authorization: Optional[str] = Header(None)):
    """Verify admin permissions (reuses the existing auth)"""
    from routes.bot_manager import verify_admin_auth
    valid, username = await verify_admin_auth(authorization)
    if not valid:
        raise HTTPException(status_code=401, detail="Unauthorized")
    return username

# ============== Pydantic Models ==============

class DatasetCreate(BaseModel):
    """Create-dataset request"""
    name: str = Field(..., min_length=1, max_length=128, description="数据集名称")
    description: Optional[str] = Field(None, max_length=500, description="描述信息")
    chunk_method: str = Field(default="naive", description="分块方法")
    # Chunking methods supported by RAGFlow: naive, manual, qa, table, paper, book, laws, presentation, picture, one, email, knowledge-graph

class DatasetUpdate(BaseModel):
    """Update-dataset request (partial update)"""
    name: Optional[str] = Field(None, min_length=1, max_length=128)
    description: Optional[str] = Field(None, max_length=500)
    chunk_method: Optional[str] = None

class DatasetListResponse(BaseModel):
    """Paginated dataset list response"""
    items: List[dict]
    total: int
    page: int
    page_size: int

# ============== Dataset endpoints ==============

@router.get("/datasets", response_model=DatasetListResponse)
async def list_datasets(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    username: str = Depends(verify_admin),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """List datasets (supports pagination and search)"""
    return await kb_service.list_datasets(
        page=page,
        page_size=page_size,
        search=search
    )

@router.post("/datasets", status_code=201)
async def create_dataset(
    data: DatasetCreate,
    username: str = Depends(verify_admin),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """Create a dataset"""
    try:
        dataset = await kb_service.create_dataset(
            name=data.name,
            description=data.description,
            chunk_method=data.chunk_method
        )
        return dataset
    except Exception as e:
        logger.error(f"Failed to create dataset: {e}")
        raise HTTPException(status_code=500, detail=f"创建数据集失败: {str(e)}")

@router.get("/datasets/{dataset_id}")
async def get_dataset(
    dataset_id: str,
    username: str = Depends(verify_admin),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """Get dataset details"""
    dataset = await kb_service.get_dataset(dataset_id)
    if not dataset:
        raise HTTPException(status_code=404, detail="数据集不存在")
    return dataset

@router.patch("/datasets/{dataset_id}")
async def update_dataset(
    dataset_id: str,
    data: DatasetUpdate,
    username: str = Depends(verify_admin),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """Update a dataset (partial update)"""
    try:
        dataset = await kb_service.update_dataset(dataset_id, data.model_dump(exclude_unset=True))
        if not dataset:
            raise HTTPException(status_code=404, detail="数据集不存在")
        return dataset
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to update dataset: {e}")
        raise HTTPException(status_code=500, detail=f"更新数据集失败: {str(e)}")

@router.delete("/datasets/{dataset_id}", status_code=204)
async def delete_dataset(
    dataset_id: str,
    username: str = Depends(verify_admin),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """Delete a dataset"""
    success = await kb_service.delete_dataset(dataset_id)
    if not success:
        raise HTTPException(status_code=404, detail="数据集不存在")

# ============== File endpoints ==============

@router.get("/datasets/{dataset_id}/files")
async def list_dataset_files(
    dataset_id: str,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    username: str = Depends(verify_admin),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """List files in a dataset (paginated)"""
    return await kb_service.list_files(dataset_id, page=page, page_size=page_size)

@router.post("/datasets/{dataset_id}/files")
async def upload_file(
    dataset_id: str,
    file: UploadFile = File(...),
    background_tasks: BackgroundTasks = None,
    username: str = Depends(verify_admin),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """
    Upload a file to a dataset (streamed)

    Supported file types: PDF, DOCX, TXT, MD, CSV
    Maximum file size: 100MB
    """
    # File validation happens in the service layer
    try:
        result = await kb_service.upload_file(dataset_id, file)
        return result
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Failed to upload file: {e}")
        raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")

@router.delete("/datasets/{dataset_id}/files/{document_id}")
async def delete_file(
    dataset_id: str,
    document_id: str,
    username: str = Depends(verify_admin),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """Delete a file"""
    success = await kb_service.delete_file(dataset_id, document_id)
    if not success:
        raise HTTPException(status_code=404, detail="文件不存在")
    return {"success": True}

# ============== Chunk endpoints (optional, deferred) ==============
# Per the simplification review, chunk management is deferred until there is a concrete need
```
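
The route layer above depends on `KnowledgeBaseService` and `RAGFlowRepository`, which the plan lists as new files but does not spell out. A minimal sketch of the two layers, talking to RAGFlow's HTTP API over httpx (endpoint paths follow the `/api/v1/datasets` convention used by `mcp/rag_flow_server.py`; the method names and response shaping are assumptions):

```python
# Sketch only: the real services/knowledge_base_service.py and
# repositories/ragflow_repository.py may differ.
import httpx

from utils.settings import RAGFLOW_API_URL, RAGFLOW_API_KEY, RAGFLOW_CONNECTION_TIMEOUT


class RAGFlowRepository:
    """Thin async adapter around RAGFlow's HTTP API."""

    def __init__(self):
        self._client = httpx.AsyncClient(
            base_url=f"{RAGFLOW_API_URL}/api/v1",
            headers={"Authorization": f"Bearer {RAGFLOW_API_KEY}"},
            timeout=RAGFLOW_CONNECTION_TIMEOUT,
        )

    async def list_datasets(self, page: int, page_size: int, name: str | None = None) -> dict:
        params = {"page": page, "page_size": page_size}
        if name:
            params["name"] = name
        resp = await self._client.get("/datasets", params=params)
        resp.raise_for_status()
        return resp.json()


class KnowledgeBaseService:
    """Business logic layer; keeps routes free of RAGFlow specifics."""

    def __init__(self, repo: RAGFlowRepository):
        self._repo = repo

    async def list_datasets(self, page: int, page_size: int, search: str | None = None) -> dict:
        raw = await self._repo.list_datasets(page=page, page_size=page_size, name=search)
        items = raw.get("data", []) or []
        return {"items": items, "total": len(items), "page": page, "page_size": page_size}
```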

---

### Frontend (qwen-client)

#### 1. API service layer

**File:** `/src/api/index.js`

Add a `knowledgeBaseApi` module:

```javascript
// ============== Knowledge Base API ==============
const knowledgeBaseApi = {
  // Dataset management
  getDatasets: async (params = {}) => {
    const qs = new URLSearchParams(params).toString()
    return request(`/api/v1/knowledge-base/datasets${qs ? '?' + qs : ''}`)
  },

  createDataset: async (data) => {
    return request('/api/v1/knowledge-base/datasets', {
      method: 'POST',
      body: JSON.stringify(data)
    })
  },

  updateDataset: async (datasetId, data) => {
    return request(`/api/v1/knowledge-base/datasets/${datasetId}`, {
      method: 'PATCH', // PATCH allows partial updates
      body: JSON.stringify(data)
    })
  },

  deleteDataset: async (datasetId) => {
    return request(`/api/v1/knowledge-base/datasets/${datasetId}`, {
      method: 'DELETE'
    })
  },

  // File management
  getDatasetFiles: async (datasetId, params = {}) => {
    const qs = new URLSearchParams(params).toString()
    return request(`/api/v1/knowledge-base/datasets/${datasetId}/files${qs ? '?' + qs : ''}`)
  },

  uploadFile: async (datasetId, file, onProgress) => {
    const formData = new FormData()
    formData.append('file', file)

    // Supports an upload progress callback
    const xhr = new XMLHttpRequest()

    return new Promise((resolve, reject) => {
      xhr.upload.addEventListener('progress', (e) => {
        if (onProgress && e.lengthComputable) {
          onProgress(Math.round((e.loaded / e.total) * 100))
        }
      })

      xhr.addEventListener('load', () => {
        if (xhr.status >= 200 && xhr.status < 300) {
          resolve(JSON.parse(xhr.responseText))
        } else {
          reject(new Error(xhr.statusText))
        }
      })

      xhr.addEventListener('error', () => reject(new Error('上传失败')))
      xhr.addEventListener('abort', () => reject(new Error('上传已取消')))

      xhr.open('POST', `${API_BASE}/api/v1/knowledge-base/datasets/${datasetId}/files`)
      xhr.setRequestHeader('Authorization', `Bearer ${localStorage.getItem('admin_token') || 'dummy-token'}`)
      xhr.send(formData)
    })
  },

  deleteFile: async (datasetId, documentId) => {
    return request(`/api/v1/knowledge-base/datasets/${datasetId}/files/${documentId}`, {
      method: 'DELETE'
    })
  }
}
```

#### 2. Simplified state management

**Per the code simplicity review, manage state directly in the component instead of creating a standalone composable.**

```vue
<!-- KnowledgeBaseView.vue -->
<script setup>
import { ref, onMounted } from 'vue'
import { knowledgeBaseApi } from '@/api'
import DatasetList from '@/components/knowledge-base/DatasetList.vue'
import FileList from '@/components/knowledge-base/FileList.vue'
import DatasetFormModal from '@/components/knowledge-base/DatasetFormModal.vue'

// State
const datasets = ref([])
const currentDataset = ref(null)
const files = ref([])
const isLoading = ref(false)
const error = ref(null)

// Pagination
const page = ref(1)
const pageSize = ref(20)
const total = ref(0)

// Load datasets
const loadDatasets = async () => {
  isLoading.value = true
  error.value = null
  try {
    const response = await knowledgeBaseApi.getDatasets({
      page: page.value,
      page_size: pageSize.value
    })
    datasets.value = response.items || []
    total.value = response.total
  } catch (err) {
    error.value = err.message
  } finally {
    isLoading.value = false
  }
}

// Select a dataset
const selectDataset = async (dataset) => {
  currentDataset.value = dataset
  await loadFiles(dataset.dataset_id)
}

// Load files
const loadFiles = async (datasetId) => {
  isLoading.value = true
  try {
    const response = await knowledgeBaseApi.getDatasetFiles(datasetId)
    files.value = response.items || []
  } finally {
    isLoading.value = false
  }
}

// Create a dataset
const createDataset = async (data) => {
  await knowledgeBaseApi.createDataset(data)
  await loadDatasets()
}

// Delete a dataset
const deleteDataset = async (datasetId) => {
  await knowledgeBaseApi.deleteDataset(datasetId)
  if (currentDataset.value?.dataset_id === datasetId) {
    currentDataset.value = null
    files.value = []
  }
  await loadDatasets()
}

onMounted(() => {
  loadDatasets()
})
</script>

<template>
  <div class="knowledge-base-view">
    <!-- Dataset list -->
    <DatasetList
      :datasets="datasets"
      :loading="isLoading"
      :current="currentDataset"
      @select="selectDataset"
      @create="createDataset"
      @delete="deleteDataset"
    />

    <!-- File list (shown once a dataset is selected) -->
    <FileList
      v-if="currentDataset"
      :dataset="currentDataset"
      :files="files"
      @upload="handleFileUpload"
      @delete="handleFileDelete"
    />
  </div>
</template>
```

#### 3. Router configuration

**File:** `/src/router/index.js`

Add the knowledge base route:

```javascript
{
  path: '/knowledge-base',
  name: 'knowledge-base',
  component: () => import('@/views/KnowledgeBaseView.vue'),
  meta: { requiresAuth: true, title: '知识库管理' }
}
```

#### 4. View components

**File:** `/src/views/KnowledgeBaseView.vue`

The main view, containing:
- Dataset list (left side or top)
- File list (shown once a dataset is selected)
- Upload file button
- Create dataset button

**Child components (simplified structure):**

| Component | File | Purpose |
|------|------|------|
| `DatasetList.vue` | `/src/components/knowledge-base/DatasetList.vue` | Dataset list display + create/delete |
| `DatasetFormModal.vue` | `/src/components/knowledge-base/DatasetFormModal.vue` | Create/edit dataset modal (merged) |
| `FileList.vue` | `/src/components/knowledge-base/FileList.vue` | File list display + upload |
| `FileUploadModal.vue` | `/src/components/knowledge-base/FileUploadModal.vue` | File upload modal |

**Directory layout:**
```
src/components/knowledge-base/
├── DatasetList.vue       # dataset list (includes the create button)
├── DatasetFormModal.vue  # create/edit dataset form
├── FileList.vue          # file list (includes the upload button)
└── FileUploadModal.vue   # file upload modal
```

#### 5. 导航菜单
|
||||
|
||||
**文件:** `/src/views/AdminView.vue`
|
||||
|
||||
在导航菜单中添加知识库入口:
|
||||
|
||||
```vue
|
||||
<Button
|
||||
variant="ghost"
|
||||
@click="currentView = 'knowledge-base'"
|
||||
>
|
||||
<Database :size="20" />
|
||||
<span>知识库管理</span>
|
||||
</Button>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 实现阶段
|
||||
|
||||
### Phase 1: 后端基础 (qwen-agent) - 核心功能
|
||||
|
||||
- [ ] 添加 `ragflow-sdk` 依赖到 `pyproject.toml`
|
||||
- [ ] 在 `utils/settings.py` 添加 RAGFlow 配置(环境变量)
|
||||
- [ ] 创建 `repositories/ragflow_repository.py` - RAGFlow SDK 适配器
|
||||
- [ ] 创建 `services/knowledge_base_service.py` - 业务逻辑层
|
||||
- [ ] 创建 `routes/knowledge_base.py` - API 路由
|
||||
- [ ] 在 `fastapi_app.py` 注册路由
|
||||
- [ ] 测试 API 端点
|
||||
|
||||
### Phase 2: 前端 API 层 (qwen-client)
|
||||
|
||||
- [ ] 在 `src/api/index.js` 添加 `knowledgeBaseApi`
|
||||
- [ ] 添加知识库路由到 `src/router/index.js`
|
||||
- [ ] 在 AdminView 添加导航入口
|
||||
|
||||
### Phase 3: 前端 UI 组件 - 最小实现
|
||||
|
||||
- [ ] 创建 `src/components/knowledge-base/` 目录
|
||||
- [ ] 实现 `KnowledgeBaseView.vue` 主视图
|
||||
- [ ] 实现 `DatasetList.vue` 组件
|
||||
- [ ] 实现 `DatasetFormModal.vue` 组件
|
||||
- [ ] 实现 `FileList.vue` 组件
|
||||
- [ ] 实现 `FileUploadModal.vue` 组件
|
||||
|
||||
### Phase 4: 切片管理 (延后实现)
|
||||
|
||||
根据 YAGNI 原则,切片管理功能延后到有明确需求时再实现:
|
||||
- [ ] 后端实现切片列表/删除端点
|
||||
- [ ] 前端实现 `ChunkList.vue` 组件
|
||||
- [ ] 切片搜索功能
|
||||
|
||||
### Phase 5: 测试与优化
|
||||
|
||||
- [ ] 端到端测试
|
||||
- [ ] 错误处理优化
|
||||
- [ ] 加载状态优化
|
||||
- [ ] 添加性能监控
|
||||
|
||||
---
|
||||
|
||||
## 数据模型
|
||||
|
||||
### 数据集 (Dataset)
|
||||
|
||||
```typescript
|
||||
interface Dataset {
|
||||
dataset_id: string
|
||||
name: string
|
||||
description?: string
|
||||
chunk_method: string
|
||||
chunk_count?: number
|
||||
document_count?: number
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
```
|
||||
|
||||
### 文件 (Document)
|
||||
|
||||
```typescript
|
||||
interface Document {
|
||||
document_id: string
|
||||
dataset_id: string
|
||||
name: string
|
||||
size: number
|
||||
status: 'running' | 'success' | 'failed'
|
||||
progress: number // 0-100
|
||||
chunk_count?: number
|
||||
token_count?: number
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
```
|
||||
|
||||
### 切片 (Chunk) - 延后实现
|
||||
|
||||
```typescript
|
||||
interface Chunk {
|
||||
chunk_id: string
|
||||
document_id: string
|
||||
dataset_id: string
|
||||
content: string
|
||||
position: number
|
||||
important_keywords?: string[]
|
||||
available: boolean
|
||||
created_at: string
|
||||
}
|
||||
```
|
||||
|
||||
---

## API Endpoint Specification

### 1. List datasets (paginated)

```
GET /api/v1/knowledge-base/datasets?page=1&page_size=20&search=keyword
Authorization: Bearer {admin_token}

Response:
{
  "items": [
    {
      "dataset_id": "uuid",
      "name": "产品手册",
      "description": "公司产品相关文档",
      "chunk_method": "naive",
      "document_count": 5,
      "chunk_count": 120,
      "created_at": "2025-01-01T00:00:00Z",
      "updated_at": "2025-01-01T00:00:00Z"
    }
  ],
  "total": 1,
  "page": 1,
  "page_size": 20
}
```

### 2. Create a dataset

```
POST /api/v1/knowledge-base/datasets
Authorization: Bearer {admin_token}
Content-Type: application/json

{
  "name": "产品手册",
  "description": "公司产品相关文档",
  "chunk_method": "naive"
}

Response:
{
  "dataset_id": "uuid",
  "name": "产品手册",
  ...
}
```

### 3. Upload a file (streaming)

```
POST /api/v1/knowledge-base/datasets/{dataset_id}/files
Authorization: Bearer {admin_token}
Content-Type: multipart/form-data

file: <binary>

Response (asynchronous):
{
  "document_id": "uuid",
  "name": "document.pdf",
  "status": "running",
  "progress": 0,
  ...
}
```

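The three endpoints above can be exercised from Python with `httpx` (which this commit adds as a dependency). A minimal sketch, assuming a server on `localhost:8001`, a valid token, and a local `document.pdf` - all placeholders:

```python
import asyncio

import httpx

BASE = "http://localhost:8001/api/v1/knowledge-base"  # placeholder host/port
TOKEN = "<admin-or-user-token>"                        # placeholder token


async def main() -> None:
    headers = {"Authorization": f"Bearer {TOKEN}"}
    async with httpx.AsyncClient(headers=headers, timeout=30) as client:
        # 1. List datasets (paginated)
        page = await client.get(f"{BASE}/datasets", params={"page": 1, "page_size": 20})
        page.raise_for_status()

        # 2. Create a dataset
        created = await client.post(
            f"{BASE}/datasets",
            json={"name": "产品手册", "description": "公司产品相关文档", "chunk_method": "naive"},
        )
        created.raise_for_status()
        dataset_id = created.json()["dataset_id"]

        # 3. Upload a file (multipart/form-data)
        with open("document.pdf", "rb") as f:
            upload = await client.post(f"{BASE}/datasets/{dataset_id}/files", files={"file": f})
        print(page.json(), created.json(), upload.json())


asyncio.run(main())
```
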
---

## Security Considerations

### Research Insights

**File upload security** (a validation sketch follows the list):
- Validate file types against a whitelist (extension + MIME type + magic number)
- Limit file size (100 MB max)
- Rename files with a UUID to prevent path traversal
- Strip dangerous characters from file names

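A minimal sketch of the whitelist check (extension + size + path-traversal guard), using the limits from the configuration section below. The magic-number check mentioned above would need an extra dependency such as `python-magic` and is left out; the committed service only logs the `mimetypes` guess:

```python
import mimetypes
from pathlib import Path

ALLOWED_EXTENSIONS = {"pdf", "docx", "txt", "md", "csv"}
MAX_UPLOAD_SIZE = 100 * 1024 * 1024  # 100 MB


def validate_upload(filename: str, content: bytes) -> str | None:
    """Raise ValueError if the upload violates the whitelist rules; return the MIME guess."""
    if not filename or any(bad in filename for bad in ("..", "/", "\\")):
        raise ValueError("Illegal file name")
    ext = Path(filename).suffix.lower().lstrip(".")
    if ext not in ALLOWED_EXTENSIONS:
        raise ValueError(f"Unsupported file type: {ext}")
    if len(content) > MAX_UPLOAD_SIZE:
        raise ValueError("File exceeds the 100 MB limit")
    # The MIME guess is advisory only; extension and size are the hard gates.
    guessed, _ = mimetypes.guess_type(filename)
    return guessed
```
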

**API authentication:**
- Reuse the existing `verify_admin_auth` function
- Every endpoint requires a valid admin token
- Integrate with the existing RBAC system

**Input validation:**
- Validate inputs with Pydantic `Field`
- Limit string lengths
- Validate pagination parameter ranges

**Configuration security:**
- The API key must be set via an environment variable
- No sensitive values hard-coded in the source

### Security Configuration Checklist

| Measure | Priority | Status |
|------|--------|------|
| File type validation | High | To be implemented |
| File size limit | High | To be implemented |
| API key via environment variable | High | Planned |
| Path traversal protection | High | To be implemented |
| File name sanitization | Medium | To be implemented |
| Virus scanning | Medium | Optional |

---

## Performance Optimization

### Research Insights

**File upload optimization** (see the concurrency sketch below):
- Stream uploads instead of reading large files into memory at once
- Cap concurrency (at most 5 concurrent uploads)
- Report upload progress via callbacks

**Query optimization:**
- Paginate instead of returning large result sets
- Use cursor pagination to speed up deep pages
- Cache the dataset list

**Connection pooling:**
- Use an async HTTP client connection pool
- Set sensible timeouts

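A minimal sketch of the concurrency cap, assuming an `asyncio.Semaphore` around the upload call; the 5-slot limit mirrors `RAGFLOW_MAX_CONCURRENT_UPLOADS` from the configuration section, and the helper names are illustrative:

```python
import asyncio

MAX_CONCURRENT_UPLOADS = 5
_upload_slots = asyncio.Semaphore(MAX_CONCURRENT_UPLOADS)


async def upload_with_limit(upload_coro_factory):
    """Run one upload while holding a slot; callers queue when all slots are busy."""
    async with _upload_slots:
        return await upload_coro_factory()


async def upload_all(factories):
    # Launch everything; the semaphore keeps at most 5 uploads in flight.
    return await asyncio.gather(*(upload_with_limit(f) for f in factories))
```
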

### Performance Optimization Checklist

| Optimization | Priority | Expected effect |
|--------|--------|----------|
| Streaming file upload | High | Avoids OOM |
| Paginated queries | High | Response time < 100 ms |
| Dataset list cache (sketch below) | Medium | Fewer external API calls |
| Connection pool | Medium | Higher concurrency |
| Rate limiting | Medium | Prevents resource exhaustion |

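For the dataset-list cache row above, a tiny in-process TTL cache is usually enough. A hedged sketch (the committed code does not cache yet; class and TTL are illustrative):

```python
import time


class TTLCache:
    """Very small in-process cache; values expire ttl seconds after they are set."""

    def __init__(self, ttl: float = 30.0):
        self.ttl = ttl
        self._store: dict[str, tuple[float, object]] = {}

    def get(self, key: str):
        hit = self._store.get(key)
        if hit is None:
            return None
        expires_at, value = hit
        if time.monotonic() > expires_at:
            self._store.pop(key, None)  # expired: drop and report a miss
            return None
        return value

    def set(self, key: str, value: object) -> None:
        self._store[key] = (time.monotonic() + self.ttl, value)
```
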
---

## User Experience Optimization

### Research Insights

**Empty states:**
- Give empty lists a friendly message and a clear next action
- Distinguish empty-state scenarios (first use, no search results, etc.)

**Loading states:**
- Prefer skeleton screens over plain loading spinners
- Show progress, especially for file uploads

**Error handling:**
- Use human-readable error messages
- Offer concrete suggestions for fixing the problem
- Distinguish error types (network, validation, server)

**File upload UX:**
- Support drag-and-drop upload
- Show upload progress
- Support batch upload
- Surface file size and type validation feedback

---

## Configuration

### Environment Variables

A loading sketch for these variables follows the table.

| Variable | Default | Description | Required |
|--------|--------|------|------|
| `RAGFLOW_API_URL` | `http://100.77.70.35:1080` | RAGFlow API address | Yes |
| `RAGFLOW_API_KEY` | - | RAGFlow API key | Yes |
| `RAGFLOW_MAX_UPLOAD_SIZE` | `104857600` | Maximum upload size in bytes | No |
| `RAGFLOW_ALLOWED_EXTENSIONS` | `pdf,docx,txt,md,csv` | Allowed file extensions | No |
| `RAGFLOW_CONNECTION_TIMEOUT` | `30` | Connection timeout in seconds | No |
| `RAGFLOW_MAX_CONCURRENT_UPLOADS` | `5` | Maximum concurrent uploads | No |

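These variables would typically be read once in `utils/settings.py`. A minimal loading sketch; the constant names are taken from the imports used elsewhere in this commit, the parsing details are assumptions:

```python
import os

# Defaults match the table above; empty RAGFLOW_API_KEY means "not configured".
RAGFLOW_API_URL = os.getenv("RAGFLOW_API_URL", "http://100.77.70.35:1080")
RAGFLOW_API_KEY = os.getenv("RAGFLOW_API_KEY", "")
RAGFLOW_MAX_UPLOAD_SIZE = int(os.getenv("RAGFLOW_MAX_UPLOAD_SIZE", "104857600"))
RAGFLOW_ALLOWED_EXTENSIONS = {
    ext.strip().lower()
    for ext in os.getenv("RAGFLOW_ALLOWED_EXTENSIONS", "pdf,docx,txt,md,csv").split(",")
    if ext.strip()
}
RAGFLOW_CONNECTION_TIMEOUT = int(os.getenv("RAGFLOW_CONNECTION_TIMEOUT", "30"))
RAGFLOW_MAX_CONCURRENT_UPLOADS = int(os.getenv("RAGFLOW_MAX_CONCURRENT_UPLOADS", "5"))
```
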

### Dependencies

```toml
[tool.poetry.dependencies]
ragflow-sdk = "^0.23.0"
python-magic = "^0.4.27"
aiofiles = "^24.1.0"
```

---

## References

- **RAGFlow docs (Chinese):** https://ragflow.com.cn/docs
- **RAGFlow HTTP API:** https://ragflow.io/docs/http_api_reference
- **RAGFlow GitHub:** https://github.com/infiniflow/ragflow
- **RAGFlow Python SDK:** https://github.com/infiniflow/ragflow/blob/main/docs/references/python_api_reference.md
- **qwen-client API layer:** `/src/api/index.js`
- **qwen-agent route example:** `/routes/bot_manager.py`

**Research sources:**
- [Vue.js Official Composables Guide](https://vuejs.org/guide/reusability/composables.html)
- [FastAPI Official Documentation](https://fastapi.tiangolo.com/)
- [UX Best Practices for File Uploader - Uploadcare](https://uploadcare.com/blog/file-uploader-ux-best-practices/)
- [Empty State UX Examples - Pencil & Paper](https://www.pencilandpaper.io/articles/empty-states)
- [Error-Message Guidelines - Nielsen Norman Group](https://www.nngroup.com/articles/error-message-guidelines/)

---

## Future Extensions

1. **Bot association:** pick knowledge bases in the bot settings
2. **RAG retrieval:** knowledge-base-backed question answering
3. **Batch operations:** batch upload and delete of files
4. **Knowledge base search:** search content inside a knowledge base
5. **Usage statistics:** see how a knowledge base is used
6. **Chunk management:** chunk viewing and editing in the frontend (deferred)
40
poetry.lock
generated
@ -280,6 +280,26 @@ files = [
    {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"},
]

[[package]]
name = "beartype"
version = "0.22.9"
description = "Unbearably fast near-real-time pure-Python runtime-static type-checker."
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
    {file = "beartype-0.22.9-py3-none-any.whl", hash = "sha256:d16c9bbc61ea14637596c5f6fbff2ee99cbe3573e46a716401734ef50c3060c2"},
    {file = "beartype-0.22.9.tar.gz", hash = "sha256:8f82b54aa723a2848a56008d18875f91c1db02c32ef6a62319a002e3e25a975f"},
]

[package.extras]
||||
dev = ["autoapi (>=0.9.0)", "celery", "click", "coverage (>=5.5)", "docutils (>=0.22.0)", "equinox ; sys_platform == \"linux\" and python_version < \"3.15.0\"", "fastmcp ; python_version < \"3.14.0\"", "jax[cpu] ; sys_platform == \"linux\" and python_version < \"3.15.0\"", "jaxtyping ; sys_platform == \"linux\"", "langchain ; python_version < \"3.14.0\" and sys_platform != \"darwin\" and platform_python_implementation != \"PyPy\"", "mypy (>=0.800) ; platform_python_implementation != \"PyPy\"", "nuitka (>=1.2.6) ; sys_platform == \"linux\" and python_version < \"3.14.0\"", "numba ; python_version < \"3.14.0\"", "numpy ; python_version < \"3.15.0\" and sys_platform != \"darwin\" and platform_python_implementation != \"PyPy\"", "pandera (>=0.26.0) ; python_version < \"3.14.0\"", "poetry", "polars ; python_version < \"3.14.0\"", "pydata-sphinx-theme (<=0.7.2)", "pygments", "pyinstaller", "pyright (>=1.1.370)", "pytest (>=6.2.0)", "redis", "rich-click", "setuptools", "sphinx", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)", "sqlalchemy", "torch ; sys_platform == \"linux\" and python_version < \"3.14.0\"", "tox (>=3.20.1)", "typer", "typing-extensions (>=3.10.0.0)", "xarray ; python_version < \"3.15.0\""]
|
||||
doc-ghp = ["mkdocs-material[imaging] (>=9.6.0)", "mkdocstrings-python (>=1.16.0)", "mkdocstrings-python-xref (>=1.16.0)"]
|
||||
doc-rtd = ["autoapi (>=0.9.0)", "pydata-sphinx-theme (<=0.7.2)", "setuptools", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)"]
|
||||
test = ["celery", "click", "coverage (>=5.5)", "docutils (>=0.22.0)", "equinox ; sys_platform == \"linux\" and python_version < \"3.15.0\"", "fastmcp ; python_version < \"3.14.0\"", "jax[cpu] ; sys_platform == \"linux\" and python_version < \"3.15.0\"", "jaxtyping ; sys_platform == \"linux\"", "langchain ; python_version < \"3.14.0\" and sys_platform != \"darwin\" and platform_python_implementation != \"PyPy\"", "mypy (>=0.800) ; platform_python_implementation != \"PyPy\"", "nuitka (>=1.2.6) ; sys_platform == \"linux\" and python_version < \"3.14.0\"", "numba ; python_version < \"3.14.0\"", "numpy ; python_version < \"3.15.0\" and sys_platform != \"darwin\" and platform_python_implementation != \"PyPy\"", "pandera (>=0.26.0) ; python_version < \"3.14.0\"", "poetry", "polars ; python_version < \"3.14.0\"", "pygments", "pyinstaller", "pyright (>=1.1.370)", "pytest (>=6.2.0)", "redis", "rich-click", "sphinx", "sqlalchemy", "torch ; sys_platform == \"linux\" and python_version < \"3.14.0\"", "tox (>=3.20.1)", "typer", "typing-extensions (>=3.10.0.0)", "xarray ; python_version < \"3.15.0\""]
|
||||
test-tox = ["celery", "click", "docutils (>=0.22.0)", "equinox ; sys_platform == \"linux\" and python_version < \"3.15.0\"", "fastmcp ; python_version < \"3.14.0\"", "jax[cpu] ; sys_platform == \"linux\" and python_version < \"3.15.0\"", "jaxtyping ; sys_platform == \"linux\"", "langchain ; python_version < \"3.14.0\" and sys_platform != \"darwin\" and platform_python_implementation != \"PyPy\"", "mypy (>=0.800) ; platform_python_implementation != \"PyPy\"", "nuitka (>=1.2.6) ; sys_platform == \"linux\" and python_version < \"3.14.0\"", "numba ; python_version < \"3.14.0\"", "numpy ; python_version < \"3.15.0\" and sys_platform != \"darwin\" and platform_python_implementation != \"PyPy\"", "pandera (>=0.26.0) ; python_version < \"3.14.0\"", "poetry", "polars ; python_version < \"3.14.0\"", "pygments", "pyinstaller", "pyright (>=1.1.370)", "pytest (>=6.2.0)", "redis", "rich-click", "sphinx", "sqlalchemy", "torch ; sys_platform == \"linux\" and python_version < \"3.14.0\"", "typer", "typing-extensions (>=3.10.0.0)", "xarray ; python_version < \"3.15.0\""]
|
||||
test-tox-coverage = ["coverage (>=5.5)"]
|
||||
|
||||
[[package]]
name = "beautifulsoup4"
version = "4.14.3"
@ -3888,6 +3908,22 @@ urllib3 = ">=1.26.14,<3"
fastembed = ["fastembed (>=0.7,<0.8)"]
fastembed-gpu = ["fastembed-gpu (>=0.7,<0.8)"]

[[package]]
name = "ragflow-sdk"
version = "0.23.1"
description = "Python client sdk of [RAGFlow](https://github.com/infiniflow/ragflow). RAGFlow is an open-source RAG (Retrieval-Augmented Generation) engine based on deep document understanding."
optional = false
python-versions = "<3.15,>=3.12"
groups = ["main"]
files = [
    {file = "ragflow_sdk-0.23.1-py3-none-any.whl", hash = "sha256:8bb2827f2696f84fc5cdbf980e2a74e2b18c712c07d08eca26ea52e13e2a4c51"},
    {file = "ragflow_sdk-0.23.1.tar.gz", hash = "sha256:dc358001bc8cad243e9aa879056c3f65bd7d687a9bff9863f6c79eaa4f43db09"},
]

[package.dependencies]
beartype = ">=0.20.0,<1.0.0"
requests = ">=2.30.0,<3.0.0"

[[package]]
name = "referencing"
version = "0.37.0"
@ -6048,5 +6084,5 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt

[metadata]
lock-version = "2.1"
python-versions = ">=3.12,<4.0"
content-hash = "abce2b9aba5a46841df8e6e4e4f12523ff9c4cd34dab7d180490ae36b2dee16e"
python-versions = ">=3.12,<3.15"
content-hash = "d930570328aea9211c1563538968847d8ba638963025e63a246559307e1d33ed"

@ -6,7 +6,7 @@ authors = [
    {name = "朱潮",email = "zhuchaowe@users.noreply.github.com"}
]
readme = "README.md"
requires-python = ">=3.12,<4.0"
requires-python = ">=3.12,<3.15"
dependencies = [
    "fastapi==0.116.1",
    "uvicorn==0.35.0",
@ -36,6 +36,8 @@ dependencies = [
    "psycopg2-binary (>=2.9.11,<3.0.0)",
    "json-repair (>=0.29.0,<0.30.0)",
    "tiktoken (>=0.5.0,<1.0.0)",
    "ragflow-sdk (>=0.23.0,<0.24.0)",
    "httpx (>=0.28.1,<0.29.0)",
]

[tool.poetry.requires-plugins]

6
repositories/__init__.py
Normal file
@ -0,0 +1,6 @@
"""
Repositories package for data access layer
"""
from .ragflow_repository import RAGFlowRepository

__all__ = ['RAGFlowRepository']
559
repositories/ragflow_repository.py
Normal file
@ -0,0 +1,559 @@
|
||||
"""
|
||||
RAGFlow Repository - 数据访问层
|
||||
封装 RAGFlow SDK 调用,提供统一的数据访问接口
|
||||
"""
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from ragflow_sdk import RAGFlow
|
||||
except ImportError:
|
||||
RAGFlow = None
|
||||
logging.warning("ragflow-sdk not installed")
|
||||
|
||||
from utils.settings import (
|
||||
RAGFLOW_API_URL,
|
||||
RAGFLOW_API_KEY,
|
||||
RAGFLOW_CONNECTION_TIMEOUT
|
||||
)
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
|
||||
class RAGFlowRepository:
|
||||
"""
|
||||
RAGFlow 数据仓储类
|
||||
|
||||
封装 RAGFlow SDK 的所有调用,提供:
|
||||
- 统一的错误处理
|
||||
- 连接管理
|
||||
- 数据转换
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str = None, base_url: str = None):
|
||||
"""
|
||||
初始化 RAGFlow 客户端
|
||||
|
||||
Args:
|
||||
api_key: RAGFlow API Key,默认从配置读取
|
||||
base_url: RAGFlow 服务地址,默认从配置读取
|
||||
"""
|
||||
self.api_key = api_key or RAGFLOW_API_KEY
|
||||
self.base_url = base_url or RAGFLOW_API_URL
|
||||
self._client: Optional[Any] = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _get_client(self):
|
||||
"""
|
||||
获取 RAGFlow 客户端实例(懒加载)
|
||||
|
||||
Returns:
|
||||
RAGFlow 客户端
|
||||
"""
|
||||
if RAGFlow is None:
|
||||
raise RuntimeError("ragflow-sdk is not installed. Run: poetry install")
|
||||
|
||||
if self._client is None:
|
||||
async with self._lock:
|
||||
# 双重检查
|
||||
if self._client is None:
|
||||
try:
|
||||
self._client = RAGFlow(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
logger.info(f"RAGFlow client initialized: {self.base_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize RAGFlow client: {e}")
|
||||
raise
|
||||
|
||||
return self._client
|
||||
|
||||
async def create_dataset(
|
||||
self,
|
||||
name: str,
|
||||
description: str = None,
|
||||
chunk_method: str = "naive",
|
||||
permission: str = "me"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
创建数据集
|
||||
|
||||
Args:
|
||||
name: 数据集名称
|
||||
description: 描述信息
|
||||
chunk_method: 分块方法 (naive, manual, qa, table, paper, book, etc.)
|
||||
permission: 权限 (me 或 team)
|
||||
|
||||
Returns:
|
||||
创建的数据集信息
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
dataset = client.create_dataset(
|
||||
name=name,
|
||||
avatar=None,
|
||||
description=description,
|
||||
chunk_method=chunk_method,
|
||||
permission=permission
|
||||
)
|
||||
|
||||
return {
|
||||
"dataset_id": getattr(dataset, 'id', None),
|
||||
"name": getattr(dataset, 'name', name),
|
||||
"description": getattr(dataset, 'description', description),
|
||||
"chunk_method": getattr(dataset, 'chunk_method', chunk_method),
|
||||
"permission": getattr(dataset, 'permission', permission),
|
||||
"created_at": getattr(dataset, 'created_at', None),
|
||||
"updated_at": getattr(dataset, 'updated_at', None),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create dataset: {e}")
|
||||
raise
|
||||
|
||||
async def list_datasets(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 30,
|
||||
search: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取数据集列表
|
||||
|
||||
Args:
|
||||
page: 页码
|
||||
page_size: 每页数量
|
||||
search: 搜索关键词
|
||||
|
||||
Returns:
|
||||
数据集列表和分页信息
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
# RAGFlow SDK 的 list_datasets 方法
|
||||
datasets = client.list_datasets(
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
items = []
|
||||
for dataset in datasets:
|
||||
dataset_info = {
|
||||
"dataset_id": getattr(dataset, 'id', None),
|
||||
"name": getattr(dataset, 'name', None),
|
||||
"description": getattr(dataset, 'description', None),
|
||||
"chunk_method": getattr(dataset, 'chunk_method', None),
|
||||
"avatar": getattr(dataset, 'avatar', None),
|
||||
"permission": getattr(dataset, 'permission', None),
|
||||
"created_at": getattr(dataset, 'created_at', None),
|
||||
"updated_at": getattr(dataset, 'updated_at', None),
|
||||
"metadata": getattr(dataset, 'metadata', {}),
|
||||
}
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
if (search_lower not in (dataset_info.get('name') or '').lower() and
|
||||
search_lower not in (dataset_info.get('description') or '').lower()):
|
||||
continue
|
||||
|
||||
items.append(dataset_info)
|
||||
|
||||
return {
|
||||
"items": items,
|
||||
"total": len(items), # RAGFlow 可能不返回总数,使用实际返回数量
|
||||
"page": page,
|
||||
"page_size": page_size
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list datasets: {e}")
|
||||
raise
|
||||
|
||||
async def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取数据集详情
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
|
||||
Returns:
|
||||
数据集详情,不存在返回 None
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
datasets = client.list_datasets(id=dataset_id)
|
||||
if datasets and len(datasets) > 0:
|
||||
dataset = datasets[0]
|
||||
return {
|
||||
"dataset_id": getattr(dataset, 'id', dataset_id),
|
||||
"name": getattr(dataset, 'name', None),
|
||||
"description": getattr(dataset, 'description', None),
|
||||
"chunk_method": getattr(dataset, 'chunk_method', None),
|
||||
"permission": getattr(dataset, 'permission', None),
|
||||
"created_at": getattr(dataset, 'created_at', None),
|
||||
"updated_at": getattr(dataset, 'updated_at', None),
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get dataset {dataset_id}: {e}")
|
||||
raise
|
||||
|
||||
async def update_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
**updates
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
更新数据集
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
**updates: 要更新的字段
|
||||
|
||||
Returns:
|
||||
更新后的数据集信息
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
datasets = client.list_datasets(id=dataset_id)
|
||||
if datasets and len(datasets) > 0:
|
||||
dataset = datasets[0]
|
||||
# 调用 update 方法
|
||||
dataset.update(updates)
|
||||
|
||||
return {
|
||||
"dataset_id": getattr(dataset, 'id', dataset_id),
|
||||
"name": getattr(dataset, 'name', None),
|
||||
"description": getattr(dataset, 'description', None),
|
||||
"chunk_method": getattr(dataset, 'chunk_method', None),
|
||||
"updated_at": getattr(dataset, 'updated_at', None),
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update dataset {dataset_id}: {e}")
|
||||
raise
|
||||
|
||||
async def delete_datasets(self, dataset_ids: List[str] = None) -> bool:
|
||||
"""
|
||||
删除数据集
|
||||
|
||||
Args:
|
||||
dataset_ids: 要删除的数据集 ID 列表
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
if dataset_ids:
|
||||
client.delete_datasets(ids=dataset_ids)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete datasets: {e}")
|
||||
raise
|
||||
|
||||
async def upload_document(
|
||||
self,
|
||||
dataset_id: str,
|
||||
file_name: str,
|
||||
file_content: bytes,
|
||||
display_name: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
上传文档到数据集
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
file_name: 文件名
|
||||
file_content: 文件内容
|
||||
display_name: 显示名称
|
||||
|
||||
Returns:
|
||||
上传的文档信息
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
# 获取数据集
|
||||
datasets = client.list_datasets(id=dataset_id)
|
||||
if not datasets or len(datasets) == 0:
|
||||
raise ValueError(f"Dataset {dataset_id} not found")
|
||||
|
||||
dataset = datasets[0]
|
||||
|
||||
# 上传文档
|
||||
display_name = display_name or file_name
|
||||
dataset.upload_documents([{
|
||||
"display_name": display_name,
|
||||
"blob": file_content
|
||||
}])
|
||||
|
||||
# 查找刚上传的文档
|
||||
documents = dataset.list_documents()
|
||||
for doc in documents:
|
||||
if getattr(doc, 'name', None) == display_name:
|
||||
return {
|
||||
"document_id": getattr(doc, 'id', None),
|
||||
"name": display_name,
|
||||
"dataset_id": dataset_id,
|
||||
"size": len(file_content),
|
||||
"status": "running",
|
||||
"chunk_count": getattr(doc, 'chunk_count', 0),
|
||||
"token_count": getattr(doc, 'token_count', 0),
|
||||
"created_at": getattr(doc, 'created_at', None),
|
||||
}
|
||||
|
||||
return {
|
||||
"document_id": None,
|
||||
"name": display_name,
|
||||
"dataset_id": dataset_id,
|
||||
"size": len(file_content),
|
||||
"status": "uploaded",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload document to {dataset_id}: {e}")
|
||||
raise
|
||||
|
||||
async def list_documents(
|
||||
self,
|
||||
dataset_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取数据集中的文档列表
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
page: 页码
|
||||
page_size: 每页数量
|
||||
|
||||
Returns:
|
||||
文档列表和分页信息
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
datasets = client.list_datasets(id=dataset_id)
|
||||
if not datasets or len(datasets) == 0:
|
||||
return {"items": [], "total": 0, "page": page, "page_size": page_size}
|
||||
|
||||
dataset = datasets[0]
|
||||
documents = dataset.list_documents(
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
items = []
|
||||
for doc in documents:
|
||||
items.append({
|
||||
"document_id": getattr(doc, 'id', None),
|
||||
"name": getattr(doc, 'name', None),
|
||||
"dataset_id": dataset_id,
|
||||
"size": getattr(doc, 'size', 0),
|
||||
"status": getattr(doc, 'run', 'unknown'),
|
||||
"progress": getattr(doc, 'progress', 0),
|
||||
"chunk_count": getattr(doc, 'chunk_count', 0),
|
||||
"token_count": getattr(doc, 'token_count', 0),
|
||||
"created_at": getattr(doc, 'created_at', None),
|
||||
"updated_at": getattr(doc, 'updated_at', None),
|
||||
})
|
||||
|
||||
return {
|
||||
"items": items,
|
||||
"total": len(items),
|
||||
"page": page,
|
||||
"page_size": page_size
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list documents for {dataset_id}: {e}")
|
||||
raise
|
||||
|
||||
async def delete_document(
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
删除文档
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
document_id: 文档 ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
datasets = client.list_datasets(id=dataset_id)
|
||||
if datasets and len(datasets) > 0:
|
||||
dataset = datasets[0]
|
||||
dataset.delete_documents(ids=[document_id])
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete document {document_id}: {e}")
|
||||
raise
|
||||
|
||||
async def list_chunks(
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_id: str = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取切片列表
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
document_id: 文档 ID(可选)
|
||||
page: 页码
|
||||
page_size: 每页数量
|
||||
|
||||
Returns:
|
||||
切片列表和分页信息
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
datasets = client.list_datasets(id=dataset_id)
|
||||
if not datasets or len(datasets) == 0:
|
||||
return {"items": [], "total": 0, "page": page, "page_size": page_size}
|
||||
|
||||
dataset = datasets[0]
|
||||
|
||||
# 如果指定了文档 ID,先获取文档
|
||||
if document_id:
|
||||
documents = dataset.list_documents(id=document_id)
|
||||
if documents and len(documents) > 0:
|
||||
doc = documents[0]
|
||||
chunks = doc.list_chunks(page=page, page_size=page_size)
|
||||
else:
|
||||
chunks = []
|
||||
else:
|
||||
# 获取所有文档的所有切片
|
||||
chunks = []
|
||||
for doc in dataset.list_documents():
|
||||
chunks.extend(doc.list_chunks(page=page, page_size=page_size))
|
||||
|
||||
items = []
|
||||
for chunk in chunks:
|
||||
items.append({
|
||||
"chunk_id": getattr(chunk, 'id', None),
|
||||
"content": getattr(chunk, 'content', ''),
|
||||
"document_id": getattr(chunk, 'document_id', None),
|
||||
"dataset_id": dataset_id,
|
||||
"position": getattr(chunk, 'position', 0),
|
||||
"important_keywords": getattr(chunk, 'important_keywords', []),
|
||||
"available": getattr(chunk, 'available', True),
|
||||
"created_at": getattr(chunk, 'create_time', None),
|
||||
})
|
||||
|
||||
return {
|
||||
"items": items,
|
||||
"total": len(items),
|
||||
"page": page,
|
||||
"page_size": page_size
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list chunks for {dataset_id}: {e}")
|
||||
raise
|
||||
|
||||
async def delete_chunk(
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
chunk_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
删除切片
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
document_id: 文档 ID
|
||||
chunk_id: 切片 ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
datasets = client.list_datasets(id=dataset_id)
|
||||
if datasets and len(datasets) > 0:
|
||||
dataset = datasets[0]
|
||||
documents = dataset.list_documents(id=document_id)
|
||||
if documents and len(documents) > 0:
|
||||
doc = documents[0]
|
||||
doc.delete_chunks(chunk_ids=[chunk_id])
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete chunk {chunk_id}: {e}")
|
||||
raise
|
||||
|
||||
async def parse_document(
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
开始解析文档
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
document_id: 文档 ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
datasets = client.list_datasets(id=dataset_id)
|
||||
if not datasets or len(datasets) == 0:
|
||||
raise ValueError(f"Dataset {dataset_id} not found")
|
||||
|
||||
dataset = datasets[0]
|
||||
dataset.async_parse_documents([document_id])
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse document {document_id}: {e}")
|
||||
raise
|
||||
|
||||
async def cancel_parse_document(
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
取消解析文档
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
document_id: 文档 ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
datasets = client.list_datasets(id=dataset_id)
|
||||
if not datasets or len(datasets) == 0:
|
||||
raise ValueError(f"Dataset {dataset_id} not found")
|
||||
|
||||
dataset = datasets[0]
|
||||
dataset.async_cancel_parse_documents([document_id])
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cancel parse document {document_id}: {e}")
|
||||
raise
|
||||
@ -425,7 +425,7 @@ class BotSettingsUpdate(BaseModel):
    avatar_url: Optional[str] = None
    description: Optional[str] = None
    suggestions: Optional[List[str]] = None
    dataset_ids: Optional[str] = None
    dataset_ids: Optional[List[str]] = None  # Changed to a list so multiple knowledge bases can be selected
    system_prompt: Optional[str] = None
    enable_memori: Optional[bool] = None
    enable_thinking: Optional[bool] = None
@ -452,7 +452,7 @@ class BotSettingsResponse(BaseModel):
    avatar_url: Optional[str]
    description: Optional[str]
    suggestions: Optional[List[str]]
    dataset_ids: Optional[str]
    dataset_ids: Optional[List[str]]  # Changed to a list
    system_prompt: Optional[str]
    enable_memori: bool
    enable_thinking: bool
@ -1544,6 +1544,13 @@ async def get_bot_settings(bot_uuid: str, authorization: Optional[str] = Header(
            api_key=mask_api_key(model_row[5])
        )

    # Normalize dataset_ids: convert the stored string into a list
    dataset_ids = settings.get('dataset_ids')
    if dataset_ids and isinstance(dataset_ids, str):
        dataset_ids = [id.strip() for id in dataset_ids.split(',') if id.strip()]
    elif not dataset_ids:
        dataset_ids = None

    return BotSettingsResponse(
        bot_id=str(bot_id),
        model_id=model_id,
@ -1552,7 +1559,7 @@ async def get_bot_settings(bot_uuid: str, authorization: Optional[str] = Header(
        avatar_url=settings.get('avatar_url'),
        description=settings.get('description'),
        suggestions=settings.get('suggestions'),
        dataset_ids=settings.get('dataset_ids'),
        dataset_ids=dataset_ids,
        system_prompt=settings.get('system_prompt'),
        enable_memori=settings.get('enable_memori', False),
        enable_thinking=settings.get('enable_thinking', False),
@ -1623,7 +1630,8 @@ async def update_bot_settings(
    if request.suggestions is not None:
        update_json['suggestions'] = request.suggestions
    if request.dataset_ids is not None:
        update_json['dataset_ids'] = request.dataset_ids
        # Store the list as a comma-separated string
        update_json['dataset_ids'] = ','.join(request.dataset_ids) if request.dataset_ids else None
    if request.system_prompt is not None:
        update_json['system_prompt'] = request.system_prompt
    if request.enable_memori is not None:

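The hunks above convert `dataset_ids` between the list used by the API layer and the comma-separated string stored in the bot's `settings`. The round trip boils down to the following (helper names are illustrative, not part of the committed code):

```python
def to_stored(dataset_ids: list[str] | None) -> str | None:
    # List -> stored string, as done in update_bot_settings
    return ",".join(dataset_ids) if dataset_ids else None


def from_stored(raw: str | None) -> list[str] | None:
    # Stored string -> list, as done in get_bot_settings
    if not raw:
        return None
    return [part.strip() for part in raw.split(",") if part.strip()]
```
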
369
routes/knowledge_base.py
Normal file
@ -0,0 +1,369 @@
|
||||
"""
|
||||
Knowledge Base API 路由
|
||||
通过 RAGFlow SDK 提供知识库管理功能
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException, Header, UploadFile, File, Query, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from agent.db_pool_manager import get_db_pool_manager
|
||||
from utils.fastapi_utils import extract_api_key_from_auth
|
||||
from repositories.ragflow_repository import RAGFlowRepository
|
||||
from services.knowledge_base_service import KnowledgeBaseService
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ============== 依赖注入 ==============
|
||||
async def get_kb_service() -> KnowledgeBaseService:
|
||||
"""获取知识库服务实例"""
|
||||
return KnowledgeBaseService(RAGFlowRepository())
|
||||
|
||||
|
||||
async def verify_user(authorization: Optional[str] = Header(None)) -> tuple:
|
||||
"""
|
||||
验证用户权限(检查 agent_user_tokens 表)
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: (user_id, username)
|
||||
"""
|
||||
from routes.bot_manager import verify_user_auth
|
||||
|
||||
valid, user_id, username = await verify_user_auth(authorization)
|
||||
if not valid:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
return user_id, username
|
||||
|
||||
|
||||
# ============== 数据库表初始化 ==============
|
||||
|
||||
async def init_knowledge_base_tables():
|
||||
"""
|
||||
初始化知识库相关的数据库表
|
||||
"""
|
||||
pool = get_db_pool_manager().pool
|
||||
|
||||
async with pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
# 检查 user_datasets 表是否已存在
|
||||
await cursor.execute("""
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_name = 'user_datasets'
|
||||
)
|
||||
""")
|
||||
table_exists = (await cursor.fetchone())[0]
|
||||
|
||||
if not table_exists:
|
||||
logger.info("Creating user_datasets table")
|
||||
|
||||
await cursor.execute("""
|
||||
CREATE TABLE user_datasets (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL REFERENCES agent_user(id) ON DELETE CASCADE,
|
||||
dataset_id VARCHAR(255) NOT NULL,
|
||||
dataset_name VARCHAR(255),
|
||||
owner BOOLEAN DEFAULT TRUE,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
|
||||
UNIQUE(user_id, dataset_id)
|
||||
)
|
||||
""")
|
||||
|
||||
await cursor.execute("CREATE INDEX idx_user_datasets_user_id ON user_datasets(user_id)")
|
||||
await cursor.execute("CREATE INDEX idx_user_datasets_dataset_id ON user_datasets(dataset_id)")
|
||||
|
||||
logger.info("user_datasets table created successfully")
|
||||
|
||||
await conn.commit()
|
||||
logger.info("Knowledge base tables initialized successfully")
|
||||
|
||||
|
||||
# ============== Pydantic Models ==============
|
||||
|
||||
class DatasetCreate(BaseModel):
|
||||
"""创建数据集请求"""
|
||||
name: str = Field(..., min_length=1, max_length=128, description="数据集名称")
|
||||
description: Optional[str] = Field(None, max_length=500, description="描述信息")
|
||||
chunk_method: str = Field(
|
||||
default="naive",
|
||||
description="分块方法: naive, manual, qa, table, paper, book, laws, presentation, picture, one, email, knowledge-graph"
|
||||
)
|
||||
|
||||
|
||||
class DatasetUpdate(BaseModel):
|
||||
"""更新数据集请求(部分更新)"""
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=128)
|
||||
description: Optional[str] = Field(None, max_length=500)
|
||||
chunk_method: Optional[str] = None
|
||||
|
||||
|
||||
class DatasetListResponse(BaseModel):
|
||||
"""数据集列表响应(分页)"""
|
||||
items: list
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
class FileListResponse(BaseModel):
|
||||
"""文件列表响应(分页)"""
|
||||
items: list
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
class ChunkListResponse(BaseModel):
|
||||
"""切片列表响应(分页)"""
|
||||
items: list
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
# ============== 数据集端点 ==============
|
||||
|
||||
@router.get("/datasets", response_model=DatasetListResponse)
|
||||
async def list_datasets(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
user_info: tuple = Depends(verify_user),
|
||||
kb_service: KnowledgeBaseService = Depends(get_kb_service)
|
||||
):
|
||||
"""获取当前用户的数据集列表(支持分页和搜索)"""
|
||||
user_id, username = user_info
|
||||
return await kb_service.list_datasets(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
search=search
|
||||
)
|
||||
|
||||
|
||||
@router.post("/datasets", status_code=201)
|
||||
async def create_dataset(
|
||||
data: DatasetCreate,
|
||||
user_info: tuple = Depends(verify_user),
|
||||
kb_service: KnowledgeBaseService = Depends(get_kb_service)
|
||||
):
|
||||
"""创建数据集并关联到当前用户"""
|
||||
try:
|
||||
user_id, username = user_info
|
||||
dataset = await kb_service.create_dataset(
|
||||
user_id=user_id,
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
chunk_method=data.chunk_method
|
||||
)
|
||||
return dataset
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create dataset: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"创建数据集失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/datasets/{dataset_id}")
|
||||
async def get_dataset(
|
||||
dataset_id: str,
|
||||
user_info: tuple = Depends(verify_user),
|
||||
kb_service: KnowledgeBaseService = Depends(get_kb_service)
|
||||
):
|
||||
"""获取数据集详情(仅限自己的数据集)"""
|
||||
user_id, username = user_info
|
||||
dataset = await kb_service.get_dataset(dataset_id, user_id=user_id)
|
||||
if not dataset:
|
||||
raise HTTPException(status_code=404, detail="数据集不存在")
|
||||
return dataset
|
||||
|
||||
|
||||
@router.patch("/datasets/{dataset_id}")
async def update_dataset(
    dataset_id: str,
    data: DatasetUpdate,
    user_info: tuple = Depends(verify_user),
    kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
    """Update a dataset (partial update)."""
    try:
        user_id, username = user_info
        # Only pass fields that were explicitly provided
        updates = data.model_dump(exclude_unset=True)
        if not updates:
            raise HTTPException(status_code=400, detail="没有提供要更新的字段")

        dataset = await kb_service.update_dataset(dataset_id, updates, user_id=user_id)
        if not dataset:
            raise HTTPException(status_code=404, detail="数据集不存在")
        return dataset
    except HTTPException:
        # Re-raise client errors (400/404) instead of converting them into 500s
        raise
    except Exception as e:
        logger.error(f"Failed to update dataset: {e}")
        raise HTTPException(status_code=500, detail=f"更新数据集失败: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/datasets/{dataset_id}")
|
||||
async def delete_dataset(
|
||||
dataset_id: str,
|
||||
user_info: tuple = Depends(verify_user),
|
||||
kb_service: KnowledgeBaseService = Depends(get_kb_service)
|
||||
):
|
||||
"""删除数据集"""
|
||||
user_id, username = user_info
|
||||
success = await kb_service.delete_dataset(dataset_id, user_id=user_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="数据集不存在")
|
||||
return {"success": True, "message": "数据集已删除"}
|
||||
|
||||
|
||||
# ============== 文件端点 ==============
|
||||
|
||||
@router.get("/datasets/{dataset_id}/files", response_model=FileListResponse)
|
||||
async def list_dataset_files(
|
||||
dataset_id: str,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
user_info: tuple = Depends(verify_user),
|
||||
kb_service: KnowledgeBaseService = Depends(get_kb_service)
|
||||
):
|
||||
"""获取数据集内文件列表(分页,仅限自己的数据集)"""
|
||||
user_id, username = user_info
|
||||
return await kb_service.list_files(dataset_id, user_id=user_id, page=page, page_size=page_size)
|
||||
|
||||
|
||||
@router.post("/datasets/{dataset_id}/files")
|
||||
async def upload_file(
|
||||
dataset_id: str,
|
||||
file: UploadFile = File(...),
|
||||
user_info: tuple = Depends(verify_user),
|
||||
kb_service: KnowledgeBaseService = Depends(get_kb_service)
|
||||
):
|
||||
"""
|
||||
上传文件到数据集(流式处理)
|
||||
|
||||
支持的文件类型: PDF, DOCX, TXT, MD, CSV
|
||||
最大文件大小: 100MB
|
||||
"""
|
||||
try:
|
||||
user_id, username = user_info
|
||||
result = await kb_service.upload_file(dataset_id, user_id=user_id, file=file)
|
||||
return result
|
||||
except ValueError as e:
|
||||
if "File validation failed" in str(e) or "not belong to you" in str(e):
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
logger.error(f"Failed to upload file: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload file: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/datasets/{dataset_id}/files/{document_id}")
|
||||
async def delete_file(
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
user_info: tuple = Depends(verify_user),
|
||||
kb_service: KnowledgeBaseService = Depends(get_kb_service)
|
||||
):
|
||||
"""删除文件"""
|
||||
user_id, username = user_info
|
||||
success = await kb_service.delete_file(dataset_id, document_id, user_id=user_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
return {"success": True}
|
||||
|
||||
|
||||
# ============== 切片端点 ==============
|
||||
|
||||
@router.get("/datasets/{dataset_id}/chunks", response_model=ChunkListResponse)
|
||||
async def list_chunks(
|
||||
dataset_id: str,
|
||||
document_id: Optional[str] = Query(None, description="文档 ID(可选)"),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
user_info: tuple = Depends(verify_user),
|
||||
kb_service: KnowledgeBaseService = Depends(get_kb_service)
|
||||
):
|
||||
"""获取数据集内切片列表(分页,仅限自己的数据集)"""
|
||||
user_id, username = user_info
|
||||
return await kb_service.list_chunks(
|
||||
user_id=user_id,
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/datasets/{dataset_id}/chunks/{chunk_id}")
|
||||
async def delete_chunk(
|
||||
dataset_id: str,
|
||||
chunk_id: str,
|
||||
document_id: str = Query(..., description="文档 ID"),
|
||||
user_info: tuple = Depends(verify_user),
|
||||
kb_service: KnowledgeBaseService = Depends(get_kb_service)
|
||||
):
|
||||
"""删除切片"""
|
||||
user_id, username = user_info
|
||||
success = await kb_service.delete_chunk(dataset_id, document_id, chunk_id, user_id=user_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="切片不存在")
|
||||
return {"success": True}
|
||||
|
||||
|
||||
# ============== 文档解析端点 ==============
|
||||
|
||||
@router.post("/datasets/{dataset_id}/documents/{document_id}/parse")
|
||||
async def parse_document(
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
user_info: tuple = Depends(verify_user),
|
||||
kb_service: KnowledgeBaseService = Depends(get_kb_service)
|
||||
):
|
||||
"""开始解析文档"""
|
||||
try:
|
||||
user_id, username = user_info
|
||||
result = await kb_service.parse_document(dataset_id, document_id, user_id=user_id)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse document: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"启动解析失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/datasets/{dataset_id}/documents/{document_id}/cancel-parse")
|
||||
async def cancel_parse_document(
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
user_info: tuple = Depends(verify_user),
|
||||
kb_service: KnowledgeBaseService = Depends(get_kb_service)
|
||||
):
|
||||
"""取消解析文档"""
|
||||
try:
|
||||
user_id, username = user_info
|
||||
result = await kb_service.cancel_parse_document(dataset_id, document_id, user_id=user_id)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cancel parse: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"取消解析失败: {str(e)}")
|
||||
|
||||
|
||||
# ============== Bot 数据集关联端点 ==============
|
||||
|
||||
@router.get("/bots/{bot_id}/datasets")
|
||||
async def get_bot_datasets(
|
||||
bot_id: str,
|
||||
kb_service: KnowledgeBaseService = Depends(get_kb_service)
|
||||
):
|
||||
"""
|
||||
获取 bot 关联的数据集 ID 列表
|
||||
|
||||
用于 MCP 服务器通过 bot_id 获取对应的数据集 IDs
|
||||
"""
|
||||
try:
|
||||
dataset_ids = await kb_service.get_dataset_ids_by_bot(bot_id)
|
||||
return {"dataset_ids": dataset_ids}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get datasets for bot {bot_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取数据集失败: {str(e)}")
|
||||
6
services/__init__.py
Normal file
@ -0,0 +1,6 @@
"""
Services package for business logic layer
"""
from .knowledge_base_service import KnowledgeBaseService

__all__ = ['KnowledgeBaseService']
627
services/knowledge_base_service.py
Normal file
@ -0,0 +1,627 @@
|
||||
"""
|
||||
Knowledge Base Service - 业务逻辑层
|
||||
提供知识库管理的业务逻辑,协调数据访问和业务规则
|
||||
"""
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.db_pool_manager import get_db_pool_manager
|
||||
from repositories.ragflow_repository import RAGFlowRepository
|
||||
from utils.settings import (
|
||||
RAGFLOW_MAX_UPLOAD_SIZE,
|
||||
RAGFLOW_ALLOWED_EXTENSIONS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
|
||||
class FileValidationError(Exception):
|
||||
"""文件验证错误"""
|
||||
pass
|
||||
|
||||
|
||||
class KnowledgeBaseService:
|
||||
"""
|
||||
知识库服务类
|
||||
|
||||
提供知识库管理的业务逻辑:
|
||||
- 数据集 CRUD
|
||||
- 文件上传和管理
|
||||
- 文件验证
|
||||
"""
|
||||
|
||||
def __init__(self, repository: RAGFlowRepository):
|
||||
"""
|
||||
初始化服务
|
||||
|
||||
Args:
|
||||
repository: RAGFlow 数据仓储实例
|
||||
"""
|
||||
self.repository = repository
|
||||
|
||||
def _validate_file(self, filename: str, content: bytes) -> None:
|
||||
"""
|
||||
验证文件
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
content: 文件内容
|
||||
|
||||
Raises:
|
||||
FileValidationError: 验证失败时抛出
|
||||
"""
|
||||
# 检查文件名
|
||||
if not filename or filename == "unknown":
|
||||
raise FileValidationError("无效的文件名")
|
||||
|
||||
# 检查路径遍历
|
||||
if '..' in filename or '/' in filename or '\\' in filename:
|
||||
raise FileValidationError("文件名包含非法字符")
|
||||
|
||||
# 检查文件扩展名(去掉点号进行比较)
|
||||
ext = Path(filename).suffix.lower().lstrip('.')
|
||||
if ext not in RAGFLOW_ALLOWED_EXTENSIONS:
|
||||
allowed = ', '.join(RAGFLOW_ALLOWED_EXTENSIONS)
|
||||
raise FileValidationError(f"不支持的文件类型: {ext}。支持的类型: {allowed}")
|
||||
|
||||
# 检查文件大小
|
||||
file_size = len(content)
|
||||
if file_size > RAGFLOW_MAX_UPLOAD_SIZE:
|
||||
size_mb = file_size / (1024 * 1024)
|
||||
max_mb = RAGFLOW_MAX_UPLOAD_SIZE / (1024 * 1024)
|
||||
raise FileValidationError(f"文件过大: {size_mb:.1f}MB (最大 {max_mb}MB)")
|
||||
|
||||
# 验证 MIME 类型(使用 mimetypes 标准库)
|
||||
detected_mime, _ = mimetypes.guess_type(filename)
|
||||
logger.info(f"File {filename} detected as {detected_mime}")
|
||||
|
||||
# ============== 数据集管理 ==============
|
||||
|
||||
async def _check_dataset_access(self, dataset_id: str, user_id: str) -> bool:
|
||||
"""
|
||||
检查用户是否有权访问该数据集
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
是否有权限
|
||||
"""
|
||||
pool = get_db_pool_manager().pool
|
||||
async with pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("""
|
||||
SELECT id FROM user_datasets
|
||||
WHERE user_id = %s AND dataset_id = %s
|
||||
""", (user_id, dataset_id))
|
||||
return await cursor.fetchone() is not None
|
||||
|
||||
async def list_datasets(
|
||||
self,
|
||||
user_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
search: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取用户的数据集列表(从本地数据库过滤)
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
page: 页码
|
||||
page_size: 每页数量
|
||||
search: 搜索关键词
|
||||
|
||||
Returns:
|
||||
数据集列表和分页信息
|
||||
"""
|
||||
logger.info(f"Listing datasets for user {user_id}: page={page}, page_size={page_size}, search={search}")
|
||||
|
||||
pool = get_db_pool_manager().pool
|
||||
|
||||
# 从本地数据库获取用户的数据集 ID 列表
|
||||
async with pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
# 构建查询条件
|
||||
where_conditions = ["user_id = %s"]
|
||||
params = [user_id]
|
||||
|
||||
if search:
|
||||
where_conditions.append("dataset_name ILIKE %s")
|
||||
params.append(f"%{search}%")
|
||||
|
||||
where_clause = " AND ".join(where_conditions)
|
||||
|
||||
# 获取总数
|
||||
await cursor.execute(f"""
|
||||
SELECT COUNT(*) FROM user_datasets
|
||||
WHERE {where_clause}
|
||||
""", params)
|
||||
total = (await cursor.fetchone())[0]
|
||||
|
||||
# 获取分页数据
|
||||
offset = (page - 1) * page_size
|
||||
await cursor.execute(f"""
|
||||
SELECT dataset_id, dataset_name, created_at
|
||||
FROM user_datasets
|
||||
WHERE {where_clause}
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
""", params + [page_size, offset])
|
||||
|
||||
user_datasets = await cursor.fetchall()
|
||||
|
||||
if not user_datasets:
|
||||
return {
|
||||
"items": [],
|
||||
"total": 0,
|
||||
"page": page,
|
||||
"page_size": page_size
|
||||
}
|
||||
|
||||
# 获取数据集 ID 列表,从 RAGFlow 获取详情
|
||||
dataset_ids = [row[0] for row in user_datasets]
|
||||
dataset_names = {row[0]: row[1] for row in user_datasets}
|
||||
|
||||
# 从 RAGFlow 获取完整的数据集信息
|
||||
ragflow_result = await self.repository.list_datasets(
|
||||
page=1,
|
||||
page_size=1000 # 获取所有数据集,然后在本地过滤
|
||||
)
|
||||
|
||||
# 过滤出属于该用户的数据集
|
||||
user_dataset_ids_set = set(dataset_ids)
|
||||
items = []
|
||||
for item in ragflow_result["items"]:
|
||||
if item.get("dataset_id") in user_dataset_ids_set:
|
||||
items.append(item)
|
||||
|
||||
return {
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size
|
||||
}
|
||||
|
||||
async def create_dataset(
|
||||
self,
|
||||
user_id: str,
|
||||
name: str,
|
||||
description: str = None,
|
||||
chunk_method: str = "naive"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
创建数据集并关联到用户
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
name: 数据集名称
|
||||
description: 描述信息
|
||||
chunk_method: 分块方法
|
||||
|
||||
Returns:
|
||||
创建的数据集信息
|
||||
"""
|
||||
logger.info(f"Creating dataset for user {user_id}: name={name}, chunk_method={chunk_method}")
|
||||
|
||||
# 验证分块方法
|
||||
valid_methods = [
|
||||
"naive", "manual", "qa", "table", "paper",
|
||||
"book", "laws", "presentation", "picture", "one", "email", "knowledge-graph"
|
||||
]
|
||||
if chunk_method not in valid_methods:
|
||||
raise ValueError(f"无效的分块方法: {chunk_method}。支持的方法: {', '.join(valid_methods)}")
|
||||
|
||||
# 先在 RAGFlow 创建数据集
|
||||
result = await self.repository.create_dataset(
|
||||
name=name,
|
||||
description=description,
|
||||
chunk_method=chunk_method,
|
||||
permission="me"
|
||||
)
|
||||
|
||||
# 记录到本地数据库
|
||||
dataset_id = result.get("dataset_id")
|
||||
if dataset_id:
|
||||
pool = get_db_pool_manager().pool
|
||||
async with pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("""
|
||||
INSERT INTO user_datasets (user_id, dataset_id, dataset_name, owner)
|
||||
VALUES (%s, %s, %s, TRUE)
|
||||
""", (user_id, dataset_id, name))
|
||||
await conn.commit()
|
||||
logger.info(f"Dataset {dataset_id} associated with user {user_id}")
|
||||
|
||||
return result
|
||||
|
||||
async def get_dataset(self, dataset_id: str, user_id: str = None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取数据集详情
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
user_id: 用户 ID(可选,用于权限验证)
|
||||
|
||||
Returns:
|
||||
数据集详情,不存在或无权限返回 None
|
||||
"""
|
||||
logger.info(f"Getting dataset: {dataset_id} for user: {user_id}")
|
||||
|
||||
# 如果提供了 user_id,先检查权限
|
||||
if user_id:
|
||||
has_access = await self._check_dataset_access(dataset_id, user_id)
|
||||
if not has_access:
|
||||
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
|
||||
return None
|
||||
|
||||
return await self.repository.get_dataset(dataset_id)
|
||||
|
||||
async def update_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
updates: Dict[str, Any],
|
||||
user_id: str = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
更新数据集
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
updates: 要更新的字段
|
||||
user_id: 用户 ID(可选,用于权限验证)
|
||||
|
||||
Returns:
|
||||
更新后的数据集信息
|
||||
"""
|
||||
logger.info(f"Updating dataset {dataset_id}: {updates}")
|
||||
|
||||
# 如果提供了 user_id,先检查权限
|
||||
if user_id:
|
||||
has_access = await self._check_dataset_access(dataset_id, user_id)
|
||||
if not has_access:
|
||||
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
|
||||
return None
|
||||
|
||||
result = await self.repository.update_dataset(dataset_id, **updates)
|
||||
|
||||
# 如果更新了名称,同步更新本地数据库
|
||||
if result and user_id and 'name' in updates:
|
||||
pool = get_db_pool_manager().pool
|
||||
async with pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("""
|
||||
UPDATE user_datasets
|
||||
SET dataset_name = %s
|
||||
WHERE user_id = %s AND dataset_id = %s
|
||||
""", (updates['name'], user_id, dataset_id))
|
||||
await conn.commit()
|
||||
|
||||
return result
|
||||
|
||||
async def delete_dataset(self, dataset_id: str, user_id: str = None) -> bool:
|
||||
"""
|
||||
删除数据集
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
user_id: 用户 ID(可选,用于权限验证)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
logger.info(f"Deleting dataset: {dataset_id}")
|
||||
|
||||
# 如果提供了 user_id,先检查权限
|
||||
if user_id:
|
||||
has_access = await self._check_dataset_access(dataset_id, user_id)
|
||||
if not has_access:
|
||||
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
|
||||
return False
|
||||
|
||||
# 从 RAGFlow 删除
|
||||
result = await self.repository.delete_datasets([dataset_id])
|
||||
|
||||
# 从本地数据库删除关联记录
|
||||
if result and user_id:
|
||||
pool = get_db_pool_manager().pool
|
||||
async with pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("""
|
||||
DELETE FROM user_datasets
|
||||
WHERE user_id = %s AND dataset_id = %s
|
||||
""", (user_id, dataset_id))
|
||||
await conn.commit()
|
||||
logger.info(f"Dataset {dataset_id} unlinked from user {user_id}")
|
||||
|
||||
return result
|
||||
|
||||
# ============== 文件管理 ==============
|
||||
|
||||
async def list_files(
|
||||
self,
|
||||
dataset_id: str,
|
||||
user_id: str = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取数据集中的文件列表
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
user_id: 用户 ID(可选,用于权限验证)
|
||||
page: 页码
|
||||
page_size: 每页数量
|
||||
|
||||
Returns:
|
||||
文件列表和分页信息
|
||||
"""
|
||||
logger.info(f"Listing files for dataset {dataset_id}: page={page}, page_size={page_size}")
|
||||
|
||||
# 如果提供了 user_id,先检查权限
|
||||
if user_id:
|
||||
has_access = await self._check_dataset_access(dataset_id, user_id)
|
||||
if not has_access:
|
||||
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
|
||||
raise ValueError("Dataset not found or does not belong to you")
|
||||
|
||||
return await self.repository.list_documents(
|
||||
dataset_id=dataset_id,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
async def upload_file(
|
||||
self,
|
||||
dataset_id: str,
|
||||
user_id: str = None,
|
||||
file=None,
|
||||
chunk_size: int = 1024 * 1024 # 1MB chunks
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
上传文件到数据集(流式处理)
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
user_id: 用户 ID(可选,用于权限验证)
|
||||
file: FastAPI UploadFile 对象
|
||||
chunk_size: 分块大小
|
||||
|
||||
Returns:
|
||||
上传的文档信息
|
||||
"""
|
||||
filename = file.filename or "unknown"
|
||||
|
||||
logger.info(f"Uploading file {filename} to dataset {dataset_id}")
|
||||
|
||||
# 如果提供了 user_id,先检查权限
|
||||
if user_id:
|
||||
has_access = await self._check_dataset_access(dataset_id, user_id)
|
||||
if not has_access:
|
||||
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
|
||||
raise ValueError("Dataset not found or does not belong to you")
|
||||
|
||||
# Read the whole upload into memory; the size limit is only enforced afterwards by _validate_file
|
||||
content = await file.read()
|
||||
|
||||
# 验证文件
|
||||
try:
|
||||
self._validate_file(filename, content)
|
||||
except FileValidationError as e:
|
||||
logger.warning(f"File validation failed: {e}")
|
||||
raise
|
||||
|
||||
# 上传到 RAGFlow
|
||||
result = await self.repository.upload_document(
|
||||
dataset_id=dataset_id,
|
||||
file_name=filename,
|
||||
file_content=content,
|
||||
display_name=filename
|
||||
)
|
||||
|
||||
logger.info(f"File {filename} uploaded successfully")
|
||||
return result
|
||||
|
||||
async def delete_file(self, dataset_id: str, document_id: str, user_id: str = None) -> bool:
|
||||
"""
|
||||
删除文件
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
document_id: 文档 ID
|
||||
user_id: 用户 ID(可选,用于权限验证)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
logger.info(f"Deleting file {document_id} from dataset {dataset_id}")
|
||||
|
||||
# 如果提供了 user_id,先检查权限
|
||||
if user_id:
|
||||
has_access = await self._check_dataset_access(dataset_id, user_id)
|
||||
if not has_access:
|
||||
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
|
||||
return False
|
||||
|
||||
return await self.repository.delete_document(dataset_id, document_id)
|
||||
|
||||
# ============== 切片管理 ==============
|
||||
|
||||
async def list_chunks(
|
||||
self,
|
||||
dataset_id: str,
|
||||
user_id: str = None,
|
||||
document_id: str = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取切片列表
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
user_id: 用户 ID(可选,用于权限验证)
|
||||
document_id: 文档 ID(可选)
|
||||
page: 页码
|
||||
page_size: 每页数量
|
||||
|
||||
Returns:
|
||||
切片列表和分页信息
|
||||
"""
|
||||
logger.info(f"Listing chunks for dataset {dataset_id}, document {document_id}")
|
||||
|
||||
# 如果提供了 user_id,先检查权限
|
||||
if user_id:
|
||||
has_access = await self._check_dataset_access(dataset_id, user_id)
|
||||
if not has_access:
|
||||
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
|
||||
raise ValueError("Dataset not found or does not belong to you")
|
||||
|
||||
return await self.repository.list_chunks(
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
async def delete_chunk(
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
chunk_id: str,
|
||||
user_id: str = None
|
||||
) -> bool:
|
||||
"""
|
||||
删除切片
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
document_id: 文档 ID
|
||||
chunk_id: 切片 ID
|
||||
user_id: 用户 ID(可选,用于权限验证)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
logger.info(f"Deleting chunk {chunk_id} from document {document_id}")
|
||||
|
||||
# 如果提供了 user_id,先检查权限
|
||||
if user_id:
|
||||
has_access = await self._check_dataset_access(dataset_id, user_id)
|
||||
if not has_access:
|
||||
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
|
||||
return False
|
||||
|
||||
return await self.repository.delete_chunk(dataset_id, document_id, chunk_id)
|
||||
|
||||
async def parse_document(
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
user_id: str = None
|
||||
) -> dict:
|
||||
"""
|
||||
开始解析文档
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
document_id: 文档 ID
|
||||
user_id: 用户 ID(可选,用于权限验证)
|
||||
|
||||
Returns:
|
||||
操作结果
|
||||
"""
|
||||
logger.info(f"Parsing document {document_id} in dataset {dataset_id}")
|
||||
|
||||
# 如果提供了 user_id,先检查权限
|
||||
if user_id:
|
||||
has_access = await self._check_dataset_access(dataset_id, user_id)
|
||||
if not has_access:
|
||||
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
|
||||
raise ValueError("Dataset not found or does not belong to you")
|
||||
|
||||
success = await self.repository.parse_document(dataset_id, document_id)
|
||||
return {"success": success, "message": "解析任务已启动"}
|
||||
|
||||
async def cancel_parse_document(
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
user_id: str = None
|
||||
) -> dict:
|
||||
"""
|
||||
取消解析文档
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集 ID
|
||||
document_id: 文档 ID
|
||||
user_id: 用户 ID(可选,用于权限验证)
|
||||
|
||||
Returns:
|
||||
操作结果
|
||||
"""
|
||||
logger.info(f"Cancelling parse for document {document_id} in dataset {dataset_id}")
|
||||
|
||||
# 如果提供了 user_id,先检查权限
|
||||
if user_id:
|
||||
has_access = await self._check_dataset_access(dataset_id, user_id)
|
||||
if not has_access:
|
||||
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
|
||||
raise ValueError("Dataset not found or does not belong to you")
|
||||
|
||||
success = await self.repository.cancel_parse_document(dataset_id, document_id)
|
||||
return {"success": success, "message": "解析任务已取消"}
|
||||
|
||||
# ============== Bot 数据集关联管理 ==============
|
||||
|
||||
async def get_dataset_ids_by_bot(self, bot_id: str) -> list[str]:
|
||||
"""
|
||||
根据 bot_id 获取关联的数据集 ID 列表
|
||||
|
||||
Args:
|
||||
bot_id: Bot ID (agent_bots 表中的 bot_id 字段)
|
||||
|
||||
Returns:
|
||||
数据集 ID 列表
|
||||
"""
|
||||
logger.info(f"Getting dataset_ids for bot_id: {bot_id}")
|
||||
|
||||
pool = get_db_pool_manager().pool
|
||||
|
||||
async with pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
# 查询 bot 的 settings 字段中的 dataset_ids
|
||||
await cursor.execute("""
|
||||
SELECT settings
|
||||
FROM agent_bots
|
||||
WHERE bot_id = %s
|
||||
""", (bot_id,))
|
||||
|
||||
row = await cursor.fetchone()
|
||||
if not row:
|
||||
logger.warning(f"Bot not found: {bot_id}")
|
||||
return []
|
||||
|
||||
settings = row[0]
|
||||
|
||||
# dataset_ids 在 settings 中存储为逗号分隔的字符串
|
||||
dataset_ids_str = settings.get('dataset_ids') if settings else None
|
||||
|
||||
if not dataset_ids_str:
|
||||
return []
|
||||
|
||||
# 如果是字符串,按逗号分割
|
||||
if isinstance(dataset_ids_str, str):
|
||||
dataset_ids = [ds_id.strip() for ds_id in dataset_ids_str.split(',') if ds_id.strip()]
|
||||
elif isinstance(dataset_ids_str, list):
|
||||
dataset_ids = dataset_ids_str
|
||||
else:
|
||||
dataset_ids = []
|
||||
|
||||
logger.info(f"Found {len(dataset_ids)} datasets for bot {bot_id}")
|
||||
return dataset_ids
|
||||
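For orientation, here is a minimal usage sketch of get_dataset_ids_by_bot. The class name KnowledgeBaseService and its no-argument constructor are assumptions (the class declaration is not part of this hunk), the bot_id value is a placeholder, and the application's database pool is assumed to be initialized before the call.

import asyncio

async def main():
    service = KnowledgeBaseService()  # assumed class name and constructor
    dataset_ids = await service.get_dataset_ids_by_bot("bot_123")  # placeholder bot_id
    if not dataset_ids:
        print("Bot has no linked datasets")
        return
    # Join back into a comma-separated string, mirroring how the IDs
    # are stored in agent_bots.settings
    print(",".join(dataset_ids))

asyncio.run(main())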
85
test_knowledge_base.sh
Executable file
@ -0,0 +1,85 @@
#!/bin/bash
# Knowledge base API test script

API_BASE="http://localhost:8001"
TOKEN="a21c99620a8ef61d69563afe05ccce89"
DATASET_ID="3c3c671205c911f1a37efedd444ada7f"

echo "=========================================="
echo "Knowledge base API tests"
echo "=========================================="

# 1. List datasets
echo ""
echo "1. List datasets"
echo "GET /api/v1/knowledge-base/datasets"
curl --silent --request GET \
  "$API_BASE/api/v1/knowledge-base/datasets" \
  --header "authorization: Bearer $TOKEN" \
  --header 'content-type: application/json' | python3 -m json.tool
echo ""

# 2. Get dataset details
echo "2. Get dataset details"
echo "GET /api/v1/knowledge-base/datasets/{dataset_id}"
curl --silent --request GET \
  "$API_BASE/api/v1/knowledge-base/datasets/$DATASET_ID" \
  --header "authorization: Bearer $TOKEN" | python3 -m json.tool
echo ""

# 3. Create a dataset
echo "3. Create a dataset"
echo "POST /api/v1/knowledge-base/datasets"
curl --silent --request POST \
  "$API_BASE/api/v1/knowledge-base/datasets" \
  --header "authorization: Bearer $TOKEN" \
  --header 'content-type: application/json' \
  --data '{
    "name": "API test knowledge base",
    "description": "Test knowledge base created via the API",
    "chunk_method": "naive"
  }' | python3 -m json.tool
echo ""

# 4. List files
echo "4. List files"
echo "GET /api/v1/knowledge-base/datasets/{dataset_id}/files"
curl --silent --request GET \
  "$API_BASE/api/v1/knowledge-base/datasets/$DATASET_ID/files" \
  --header "authorization: Bearer $TOKEN" | python3 -m json.tool
echo ""

# 5. Upload a file
echo "5. Upload a file"
echo "POST /api/v1/knowledge-base/datasets/{dataset_id}/files"
echo "Test document content for the file upload test." > /tmp/test_doc.txt
curl --silent --request POST \
  "$API_BASE/api/v1/knowledge-base/datasets/$DATASET_ID/files" \
  --header "authorization: Bearer $TOKEN" \
  -F "file=@/tmp/test_doc.txt" | python3 -m json.tool
echo ""

# 6. List chunks
echo "6. List chunks"
echo "GET /api/v1/knowledge-base/datasets/{dataset_id}/chunks"
curl --silent --request GET \
  "$API_BASE/api/v1/knowledge-base/datasets/$DATASET_ID/chunks" \
  --header "authorization: Bearer $TOKEN" | python3 -m json.tool
echo ""

# 7. Update a dataset
echo "7. Update a dataset"
echo "PATCH /api/v1/knowledge-base/datasets/{dataset_id}"
curl --silent --request PATCH \
  "$API_BASE/api/v1/knowledge-base/datasets/$DATASET_ID" \
  --header "authorization: Bearer $TOKEN" \
  --header 'content-type: application/json' \
  --data '{
    "name": "Updated knowledge base name",
    "description": "Updated description"
  }' | python3 -m json.tool
echo ""

echo "=========================================="
echo "Tests complete"
echo "=========================================="
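The script above drives the endpoints with curl; as a point of comparison, a roughly equivalent programmatic call for step 1 might look like the sketch below. It assumes the httpx package is available; API_BASE and TOKEN mirror the variables at the top of the script.

import httpx

API_BASE = "http://localhost:8001"
TOKEN = "a21c99620a8ef61d69563afe05ccce89"

def list_datasets() -> dict:
    # Same request as the first curl call: GET the dataset list with a Bearer token
    resp = httpx.get(
        f"{API_BASE}/api/v1/knowledge-base/datasets",
        headers={"authorization": f"Bearer {TOKEN}"},
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()

if __name__ == "__main__":
    print(list_datasets())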
@ -49,8 +49,8 @@ MCP_SSE_READ_TIMEOUT = int(os.getenv("MCP_SSE_READ_TIMEOUT", 300)) # SSE read

# PostgreSQL connection string
# Format: postgresql://user:password@host:port/database
#CHECKPOINT_DB_URL = os.getenv("CHECKPOINT_DB_URL", "postgresql://postgres:AeEGDB0b7Z5GK0E2tblt@dev-circleo-pg.celp3nik7oaq.ap-northeast-1.rds.amazonaws.com:5432/gptbase")
CHECKPOINT_DB_URL = os.getenv("CHECKPOINT_DB_URL", "postgresql://postgres:E5ACJo6zJub4QS@192.168.102.5:5432/agent_db")
CHECKPOINT_DB_URL = os.getenv("CHECKPOINT_DB_URL", "postgresql://postgres:AeEGDB0b7Z5GK0E2tblt@dev-circleo-pg.celp3nik7oaq.ap-northeast-1.rds.amazonaws.com:5432/gptbase")
#CHECKPOINT_DB_URL = os.getenv("CHECKPOINT_DB_URL", "postgresql://postgres:E5ACJo6zJub4QS@192.168.102.5:5432/agent_db")

# Connection pool size
# Maximum number of connections held at the same time
@ -81,3 +81,19 @@ MEM0_ENABLED = os.getenv("MEM0_ENABLED", "true") == "true"
MEM0_SEMANTIC_SEARCH_TOP_K = int(os.getenv("MEM0_SEMANTIC_SEARCH_TOP_K", "20"))

os.environ["OPENAI_API_KEY"] = "your_api_key"

# ============================================================
# RAGFlow Knowledge Base Configuration
# ============================================================

# RAGFlow API configuration
RAGFLOW_API_URL = os.getenv("RAGFLOW_API_URL", "http://100.77.70.35:1080")
RAGFLOW_API_KEY = os.getenv("RAGFLOW_API_KEY", "ragflow-MRqxnDnYZ1yp5kklDMIlKH4f1qezvXIngSMGPhu1AG8")

# File upload configuration
RAGFLOW_MAX_UPLOAD_SIZE = int(os.getenv("RAGFLOW_MAX_UPLOAD_SIZE", str(100 * 1024 * 1024)))  # 100MB
RAGFLOW_ALLOWED_EXTENSIONS = os.getenv("RAGFLOW_ALLOWED_EXTENSIONS", "pdf,docx,txt,md,csv").split(",")

# Performance configuration
RAGFLOW_CONNECTION_TIMEOUT = int(os.getenv("RAGFLOW_CONNECTION_TIMEOUT", "30"))  # 30 seconds
RAGFLOW_MAX_CONCURRENT_UPLOADS = int(os.getenv("RAGFLOW_MAX_CONCURRENT_UPLOADS", "5"))
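To illustrate how the upload limits above are meant to be consumed, here is a hedged sketch of a validation helper in the spirit of the service's _validate_file check. The actual implementation is not part of this diff, so the function name validate_upload and the local FileValidationError definition are illustrative only.

import os

class FileValidationError(Exception):
    # Illustrative stand-in; the real exception class lives elsewhere in the codebase
    pass

RAGFLOW_MAX_UPLOAD_SIZE = int(os.getenv("RAGFLOW_MAX_UPLOAD_SIZE", str(100 * 1024 * 1024)))
RAGFLOW_ALLOWED_EXTENSIONS = os.getenv("RAGFLOW_ALLOWED_EXTENSIONS", "pdf,docx,txt,md,csv").split(",")

def validate_upload(filename: str, content: bytes) -> None:
    # Reject files whose extension is not on the configured whitelist
    ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
    if ext not in {e.strip().lower() for e in RAGFLOW_ALLOWED_EXTENSIONS}:
        raise FileValidationError(f"Extension .{ext} is not allowed")
    # Reject files larger than the configured size limit
    if len(content) > RAGFLOW_MAX_UPLOAD_SIZE:
        raise FileValidationError("File exceeds RAGFLOW_MAX_UPLOAD_SIZE")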