add rag flow

This commit is contained in:
朱潮 2026-02-10 18:59:10 +08:00
parent 742eaf0a1c
commit bd39a53507
15 changed files with 3499 additions and 17 deletions

View File

@@ -120,6 +120,21 @@ CREATE INDEX IF NOT EXISTS idx_bot_shares_bot_id ON bot_shares(bot_id);
CREATE INDEX IF NOT EXISTS idx_bot_shares_user_id ON bot_shares(user_id);
CREATE INDEX IF NOT EXISTS idx_bot_shares_shared_by ON bot_shares(shared_by);
-- 9. 创建 user_datasets 表(用户与 RAGFlow 数据集的关联表)
CREATE TABLE IF NOT EXISTS user_datasets (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES agent_user(id) ON DELETE CASCADE,
dataset_id VARCHAR(255) NOT NULL, -- RAGFlow 返回的 dataset_id
dataset_name VARCHAR(255), -- 冗余存储数据集名称,方便查询
owner BOOLEAN DEFAULT TRUE, -- 是否为所有者(预留分享功能)
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
UNIQUE(user_id, dataset_id)
);
-- user_datasets 索引
CREATE INDEX IF NOT EXISTS idx_user_datasets_user_id ON user_datasets(user_id);
CREATE INDEX IF NOT EXISTS idx_user_datasets_dataset_id ON user_datasets(dataset_id);
-- ===========================
-- 默认 Admin 账号
-- 用户名: admin

View File

@@ -71,7 +71,7 @@ from utils.log_util.logger import init_with_fastapi
logger = logging.getLogger('app')
# Import route modules
from routes import chat, files, projects, system, skill_manager, database, bot_manager
from routes import chat, files, projects, system, skill_manager, database, bot_manager, knowledge_base
@asynccontextmanager
@@ -125,7 +125,14 @@ async def lifespan(app: FastAPI):
except Exception as e:
logger.warning(f"Bot Manager table initialization failed (non-fatal): {e}")
# 6. 启动 checkpoint 清理调度器
# 6. 初始化 Knowledge Base 表
try:
await knowledge_base.init_knowledge_base_tables()
logger.info("Knowledge Base tables initialized")
except Exception as e:
logger.warning(f"Knowledge Base table initialization failed (non-fatal): {e}")
# 7. 启动 checkpoint 清理调度器
if CHECKPOINT_CLEANUP_ENABLED:
# 启动时立即执行一次清理
try:
@@ -187,6 +194,9 @@ app.include_router(bot_manager.router)
# 注册文件管理API路由
app.include_router(file_manager_router)
# 注册知识库API路由
app.include_router(knowledge_base.router, prefix="/api/v1/knowledge-base", tags=["knowledge-base"])
if __name__ == "__main__":
# 启动 FastAPI 应用

View File

@@ -2,12 +2,12 @@
{
"mcpServers": {
"rag_retrieve": {
"transport": "stdio",
"command": "python",
"args": [
"./mcp/rag_retrieve_server.py",
"{bot_id}"
]
"transport": "http",
"url": "http://100.77.70.35:9382/mcp/",
"headers": {
"api_key": "ragflow-MRqxnDnYZ1yp5kklDMIlKH4f1qezvXIngSMGPhu1AG8",
"X-Dataset-Ids": "{dataset_ids}"
}
}
}
}

874
mcp/rag_flow_server.py Normal file
View File

@@ -0,0 +1,874 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import random
import time
from collections import OrderedDict
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from functools import wraps
from typing import Any
import click
import httpx
import mcp.types as types
from mcp.server.lowlevel import Server
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import JSONResponse, Response
from starlette.routing import Mount, Route
from strenum import StrEnum
class LaunchMode(StrEnum):
SELF_HOST = "self-host"
HOST = "host"
class Transport(StrEnum):
SSE = "sse"
STREAMABLE_HTTP = "streamable-http"
BASE_URL = "http://127.0.0.1:9380"
HOST = "127.0.0.1"
PORT = "9382"
HOST_API_KEY = ""
MODE = ""
TRANSPORT_SSE_ENABLED = True
TRANSPORT_STREAMABLE_HTTP_ENABLED = True
JSON_RESPONSE = True
class RAGFlowConnector:
_MAX_DATASET_CACHE = 32
_CACHE_TTL = 300
_dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict() # "dataset_id" -> (metadata, expiry_ts)
_document_metadata_cache: OrderedDict[str, tuple[list[tuple[str, dict]], float | int]] = OrderedDict() # "dataset_id" -> ([(document_id, doc_metadata)], expiry_ts)
def __init__(self, base_url: str, version="v1"):
self.base_url = base_url
self.version = version
self.api_url = f"{self.base_url}/api/{self.version}"
self._async_client = None
async def _get_client(self):
if self._async_client is None:
self._async_client = httpx.AsyncClient(timeout=httpx.Timeout(60.0))
return self._async_client
async def close(self):
if self._async_client is not None:
await self._async_client.aclose()
self._async_client = None
async def _post(self, path, json=None, stream=False, files=None, api_key: str = ""):
if not api_key:
return None
client = await self._get_client()
res = await client.post(url=self.api_url + path, json=json, headers={"Authorization": f"Bearer {api_key}"})
return res
async def _get(self, path, params=None, api_key: str = ""):
if not api_key:
return None
client = await self._get_client()
res = await client.get(url=self.api_url + path, params=params, headers={"Authorization": f"Bearer {api_key}"})
return res
def _is_cache_valid(self, ts):
return time.time() < ts
def _get_expiry_timestamp(self):
offset = random.randint(-30, 30)
return time.time() + self._CACHE_TTL + offset
def _get_cached_dataset_metadata(self, dataset_id):
entry = self._dataset_metadata_cache.get(dataset_id)
if entry:
data, ts = entry
if self._is_cache_valid(ts):
self._dataset_metadata_cache.move_to_end(dataset_id)
return data
return None
def _set_cached_dataset_metadata(self, dataset_id, metadata):
self._dataset_metadata_cache[dataset_id] = (metadata, self._get_expiry_timestamp())
self._dataset_metadata_cache.move_to_end(dataset_id)
if len(self._dataset_metadata_cache) > self._MAX_DATASET_CACHE:
self._dataset_metadata_cache.popitem(last=False)
def _get_cached_document_metadata_by_dataset(self, dataset_id):
entry = self._document_metadata_cache.get(dataset_id)
if entry:
data_list, ts = entry
if self._is_cache_valid(ts):
self._document_metadata_cache.move_to_end(dataset_id)
return {doc_id: doc_meta for doc_id, doc_meta in data_list}
return None
def _set_cached_document_metadata_by_dataset(self, dataset_id, doc_id_meta_list):
self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
self._document_metadata_cache.move_to_end(dataset_id)
async def list_datasets(
self,
*,
api_key: str,
page: int = 1,
page_size: int = 1000,
orderby: str = "create_time",
desc: bool = True,
id: str | None = None,
name: str | None = None,
):
res = await self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}, api_key=api_key)
if not res or res.status_code != 200:
raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
res = res.json()
if res.get("code") == 0:
result_list = []
for data in res["data"]:
d = {"description": data["description"], "id": data["id"]}
result_list.append(json.dumps(d, ensure_ascii=False))
return "\n".join(result_list)
return ""
async def retrieval(
self,
*,
api_key: str,
dataset_ids,
document_ids=None,
question="",
page=1,
page_size=30,
similarity_threshold=0.2,
vector_similarity_weight=0.3,
top_k=1024,
rerank_id: str | None = None,
keyword: bool = False,
force_refresh: bool = False,
):
if document_ids is None:
document_ids = []
# If no dataset_ids provided or empty list, get all available dataset IDs
if not dataset_ids:
dataset_list_str = await self.list_datasets(api_key=api_key)
dataset_ids = []
# Parse the dataset list to extract IDs
if dataset_list_str:
for line in dataset_list_str.strip().split("\n"):
if line.strip():
try:
dataset_info = json.loads(line.strip())
dataset_ids.append(dataset_info["id"])
except (json.JSONDecodeError, KeyError):
# Skip malformed lines
continue
data_json = {
"page": page,
"page_size": page_size,
"similarity_threshold": similarity_threshold,
"vector_similarity_weight": vector_similarity_weight,
"top_k": top_k,
"rerank_id": rerank_id,
"keyword": keyword,
"question": question,
"dataset_ids": dataset_ids,
"document_ids": document_ids,
}
# Send a POST request to the RAGFlow backend via the shared async httpx client
res = await self._post("/retrieval", json=data_json, api_key=api_key)
if not res or res.status_code != 200:
raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])
res = res.json()
if res.get("code") == 0:
data = res["data"]
chunks = []
# Cache document metadata and dataset information
document_cache, dataset_cache = await self._get_document_metadata_cache(dataset_ids, api_key=api_key, force_refresh=force_refresh)
# Process chunks with enhanced field mapping including per-chunk metadata
for chunk_data in data.get("chunks", []):
enhanced_chunk = self._map_chunk_fields(chunk_data, dataset_cache, document_cache)
chunks.append(enhanced_chunk)
# Build structured response (no longer need response-level document_metadata)
response = {
"chunks": chunks,
"pagination": {
"page": data.get("page", page),
"page_size": data.get("page_size", page_size),
"total_chunks": data.get("total", len(chunks)),
"total_pages": (data.get("total", len(chunks)) + page_size - 1) // page_size,
},
"query_info": {
"question": question,
"similarity_threshold": similarity_threshold,
"vector_weight": vector_similarity_weight,
"keyword_search": keyword,
"dataset_count": len(dataset_ids),
},
}
return [types.TextContent(type="text", text=json.dumps(response, ensure_ascii=False))]
raise Exception([types.TextContent(type="text", text=res.get("message"))])
async def _get_document_metadata_cache(self, dataset_ids, *, api_key: str, force_refresh=False):
"""Cache document metadata for all documents in the specified datasets"""
document_cache = {}
dataset_cache = {}
try:
for dataset_id in dataset_ids:
dataset_meta = None if force_refresh else self._get_cached_dataset_metadata(dataset_id)
if not dataset_meta:
# First get dataset info for name
dataset_res = await self._get("/datasets", {"id": dataset_id, "page_size": 1}, api_key=api_key)
if dataset_res and dataset_res.status_code == 200:
dataset_data = dataset_res.json()
if dataset_data.get("code") == 0 and dataset_data.get("data"):
dataset_info = dataset_data["data"][0]
dataset_meta = {"name": dataset_info.get("name", "Unknown"), "description": dataset_info.get("description", "")}
self._set_cached_dataset_metadata(dataset_id, dataset_meta)
if dataset_meta:
dataset_cache[dataset_id] = dataset_meta
docs = None if force_refresh else self._get_cached_document_metadata_by_dataset(dataset_id)
if docs is None:
page = 1
page_size = 30
doc_id_meta_list = []
docs = {}
while page:
docs_res = await self._get(f"/datasets/{dataset_id}/documents?page={page}", api_key=api_key)
if not docs_res:
break
docs_data = docs_res.json()
if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
for doc in docs_data["data"]["docs"]:
doc_id = doc.get("id")
if not doc_id:
continue
doc_meta = {
"document_id": doc_id,
"name": doc.get("name", ""),
"location": doc.get("location", ""),
"type": doc.get("type", ""),
"size": doc.get("size"),
"chunk_count": doc.get("chunk_count"),
"create_date": doc.get("create_date", ""),
"update_date": doc.get("update_date", ""),
"token_count": doc.get("token_count"),
"thumbnail": doc.get("thumbnail", ""),
"dataset_id": doc.get("dataset_id", dataset_id),
"meta_fields": doc.get("meta_fields", {}),
}
doc_id_meta_list.append((doc_id, doc_meta))
docs[doc_id] = doc_meta
page += 1
if docs_data.get("data", {}).get("total", 0) - page * page_size <= 0:
page = None
self._set_cached_document_metadata_by_dataset(dataset_id, doc_id_meta_list)
if docs:
document_cache.update(docs)
except Exception as e:
# Gracefully handle metadata cache failures
logging.error(f"Problem building the document metadata cache: {str(e)}")
pass
return document_cache, dataset_cache
def _map_chunk_fields(self, chunk_data, dataset_cache, document_cache):
"""Preserve all original API fields and add per-chunk document metadata"""
# Start with ALL raw data from API (preserve everything like original version)
mapped = dict(chunk_data)
# Add dataset name enhancement
dataset_id = chunk_data.get("dataset_id") or chunk_data.get("kb_id")
if dataset_id and dataset_id in dataset_cache:
mapped["dataset_name"] = dataset_cache[dataset_id]["name"]
else:
mapped["dataset_name"] = "Unknown"
# Add document name convenience field
mapped["document_name"] = chunk_data.get("document_keyword", "")
# Add per-chunk document metadata
document_id = chunk_data.get("document_id")
if document_id and document_id in document_cache:
mapped["document_metadata"] = document_cache[document_id]
return mapped
class RAGFlowCtx:
def __init__(self, connector: RAGFlowConnector):
self.conn = connector
@asynccontextmanager
async def sse_lifespan(server: Server) -> AsyncIterator[dict]:
ctx = RAGFlowCtx(RAGFlowConnector(base_url=BASE_URL))
logging.info("Legacy SSE application started with StreamableHTTP session manager!")
try:
yield {"ragflow_ctx": ctx}
finally:
await ctx.conn.close()
logging.info("Legacy SSE application shutting down...")
app = Server("ragflow-mcp-server", lifespan=sse_lifespan)
AUTH_TOKEN_STATE_KEY = "ragflow_auth_token"
def _to_text(value: Any) -> str:
if isinstance(value, bytes):
return value.decode(errors="ignore")
return str(value)
def _extract_token_from_headers(headers: Any) -> str | None:
if not headers or not hasattr(headers, "get"):
return None
auth_keys = ("authorization", "Authorization", b"authorization", b"Authorization")
for key in auth_keys:
auth = headers.get(key)
if not auth:
continue
auth_text = _to_text(auth).strip()
if auth_text.lower().startswith("bearer "):
token = auth_text[7:].strip()
if token:
return token
api_key_keys = ("api_key", "x-api-key", "Api-Key", "X-API-Key", b"api_key", b"x-api-key", b"Api-Key", b"X-API-Key")
for key in api_key_keys:
token = headers.get(key)
if token:
token_text = _to_text(token).strip()
if token_text:
return token_text
return None
def _extract_dataset_ids_from_headers(headers: Any) -> list[str]:
"""Extract dataset_ids from request headers.
Supports multiple header formats:
- Comma-separated string: "dataset1,dataset2,dataset3"
- JSON array string: '["dataset1","dataset2","dataset3"]'
- Repeated headers (x-dataset-id)
Returns:
List of dataset IDs. Empty list if none found or invalid format.
"""
if not headers or not hasattr(headers, "get"):
return []
# Try various header key variations
header_keys = (
"x-dataset-ids", "X-Dataset-Ids", "X-DATASET-IDS",
"dataset_ids", "Dataset-Ids", "DATASET_IDS",
"x-datasets", "X-Datasets", "X-DATASETS",
b"x-dataset-ids", b"X-Dataset-Ids", b"X-DATASET-IDS",
b"dataset_ids", b"Dataset-Ids", b"DATASET_IDS",
b"x-datasets", b"X-Datasets", b"X-DATASETS",
)
for key in header_keys:
value = headers.get(key)
if not value:
continue
value_text = _to_text(value).strip()
if not value_text:
continue
# Try parsing as JSON array first
if value_text.startswith("["):
try:
dataset_ids = json.loads(value_text)
if isinstance(dataset_ids, list):
return [str(ds_id).strip() for ds_id in dataset_ids if str(ds_id).strip()]
except json.JSONDecodeError:
pass
# Try parsing as comma-separated string
dataset_ids = [ds_id.strip() for ds_id in value_text.split(",") if ds_id.strip()]
if dataset_ids:
return dataset_ids
# Try repeated header format (x-dataset-id)
single_header_keys = (
"x-dataset-id", "X-Dataset-Id", "X-DATASET-ID",
"dataset_id", "Dataset-Id", "DATASET_ID",
b"x-dataset-id", b"X-Dataset-Id", b"X-DATASET-ID",
b"dataset_id", b"Dataset-Id", b"DATASET_ID",
)
dataset_ids = []
for key in single_header_keys:
value = headers.get(key)
if value:
value_text = _to_text(value).strip()
if value_text:
dataset_ids.append(value_text)
return dataset_ids
def _extract_token_from_request(request: Any) -> str | None:
if request is None:
return None
state = getattr(request, "state", None)
if state is not None:
token = getattr(state, AUTH_TOKEN_STATE_KEY, None)
if token:
return token
token = _extract_token_from_headers(getattr(request, "headers", None))
if token and state is not None:
setattr(state, AUTH_TOKEN_STATE_KEY, token)
return token
def _extract_dataset_ids_from_request(request: Any) -> list[str]:
"""Extract dataset_ids from a request object.
First checks state for cached dataset_ids, then extracts from headers.
Returns:
List of dataset IDs. Empty list if none found.
"""
if request is None:
return []
state = getattr(request, "state", None)
if state is not None:
dataset_ids = getattr(state, "ragflow_dataset_ids", None)
if dataset_ids:
return dataset_ids
headers = getattr(request, "headers", None)
dataset_ids = _extract_dataset_ids_from_headers(headers)
if dataset_ids and state is not None:
setattr(state, "ragflow_dataset_ids", dataset_ids)
return dataset_ids
def with_api_key(required: bool = True):
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
ctx = app.request_context
ragflow_ctx = ctx.lifespan_context.get("ragflow_ctx")
if not ragflow_ctx:
raise ValueError("Get RAGFlow Context failed")
connector = ragflow_ctx.conn
api_key = HOST_API_KEY
request = getattr(ctx, "request", None)
if MODE == LaunchMode.HOST:
api_key = _extract_token_from_request(request) or ""
if required and not api_key:
raise ValueError("RAGFlow API key or Bearer token is required.")
return await func(*args, connector=connector, api_key=api_key, request=request, **kwargs)
return wrapper
return decorator
@app.list_tools()
@with_api_key(required=True)
async def list_tools(*, connector: RAGFlowConnector, api_key: str) -> list[types.Tool]:
dataset_description = await connector.list_datasets(api_key=api_key)
return [
types.Tool(
name="ragflow_retrieval",
description="Retrieve relevant chunks from the RAGFlow retrieve interface based on the question. You can optionally specify dataset_ids to search only specific datasets, or omit dataset_ids entirely to search across ALL available datasets. You can also optionally specify document_ids to search within specific documents. When dataset_ids is not provided or is empty, the system will automatically search across all available datasets. Below is the list of all available datasets, including their descriptions and IDs:"
+ dataset_description,
inputSchema={
"type": "object",
"properties": {
"dataset_ids": {"type": "array", "items": {"type": "string"}, "description": "Optional array of dataset IDs to search. If not provided or empty, all datasets will be searched."},
"document_ids": {"type": "array", "items": {"type": "string"}, "description": "Optional array of document IDs to search within."},
"question": {"type": "string", "description": "The question or query to search for."},
"page": {
"type": "integer",
"description": "Page number for pagination",
"default": 1,
"minimum": 1,
},
"page_size": {
"type": "integer",
"description": "Number of results to return per page (default: 10, max recommended: 50 to avoid token limits)",
"default": 10,
"minimum": 1,
"maximum": 100,
},
"similarity_threshold": {
"type": "number",
"description": "Minimum similarity threshold for results",
"default": 0.2,
"minimum": 0.0,
"maximum": 1.0,
},
"vector_similarity_weight": {
"type": "number",
"description": "Weight for vector similarity vs term similarity",
"default": 0.3,
"minimum": 0.0,
"maximum": 1.0,
},
"keyword": {
"type": "boolean",
"description": "Enable keyword-based search",
"default": False,
},
"top_k": {
"type": "integer",
"description": "Maximum results to consider before ranking",
"default": 1024,
"minimum": 1,
"maximum": 1024,
},
"rerank_id": {
"type": "string",
"description": "Optional reranking model identifier",
},
"force_refresh": {
"type": "boolean",
"description": "Set to true only if fresh dataset and document metadata is explicitly required. Otherwise, cached metadata is used (default: false).",
"default": False,
},
},
"required": ["question"],
},
),
]
@app.call_tool()
@with_api_key(required=True)
async def call_tool(
name: str,
arguments: dict,
*,
connector: RAGFlowConnector,
api_key: str,
request: Any = None,
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
if name == "ragflow_retrieval":
document_ids = arguments.get("document_ids", [])
dataset_ids = arguments.get("dataset_ids", [])
question = arguments.get("question", "")
page = arguments.get("page", 1)
page_size = arguments.get("page_size", 10)
similarity_threshold = arguments.get("similarity_threshold", 0.2)
vector_similarity_weight = arguments.get("vector_similarity_weight", 0.3)
keyword = arguments.get("keyword", False)
top_k = arguments.get("top_k", 1024)
rerank_id = arguments.get("rerank_id")
force_refresh = arguments.get("force_refresh", False)
# If no dataset_ids provided or empty list, try to extract from request headers
if not dataset_ids:
dataset_ids = _extract_dataset_ids_from_request(request)
# If still no dataset_ids, get all available dataset IDs
if not dataset_ids:
dataset_list_str = await connector.list_datasets(api_key=api_key)
dataset_ids = []
# Parse the dataset list to extract IDs
if dataset_list_str:
for line in dataset_list_str.strip().split("\n"):
if line.strip():
try:
dataset_info = json.loads(line.strip())
dataset_ids.append(dataset_info["id"])
except (json.JSONDecodeError, KeyError):
# Skip malformed lines
continue
return await connector.retrieval(
api_key=api_key,
dataset_ids=dataset_ids,
document_ids=document_ids,
question=question,
page=page,
page_size=page_size,
similarity_threshold=similarity_threshold,
vector_similarity_weight=vector_similarity_weight,
keyword=keyword,
top_k=top_k,
rerank_id=rerank_id,
force_refresh=force_refresh,
)
raise ValueError(f"Tool not found: {name}")
def create_starlette_app():
routes = []
middleware = None
if MODE == LaunchMode.HOST:
from starlette.types import ASGIApp, Receive, Scope, Send
class AuthMiddleware:
def __init__(self, app: ASGIApp):
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http":
await self.app(scope, receive, send)
return
path = scope["path"]
if path.startswith("/messages/") or path.startswith("/sse") or path.startswith("/mcp"):
headers = dict(scope["headers"])
token = _extract_token_from_headers(headers)
if not token:
response = JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
await response(scope, receive, send)
return
scope.setdefault("state", {})[AUTH_TOKEN_STATE_KEY] = token
# Extract and cache dataset_ids from headers
dataset_ids = _extract_dataset_ids_from_headers(headers)
if dataset_ids:
scope.setdefault("state", {})["ragflow_dataset_ids"] = dataset_ids
await self.app(scope, receive, send)
middleware = [Middleware(AuthMiddleware)]
# Add SSE routes if enabled
if TRANSPORT_SSE_ENABLED:
from mcp.server.sse import SseServerTransport
sse = SseServerTransport("/messages/")
async def handle_sse(request):
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
await app.run(streams[0], streams[1], app.create_initialization_options(experimental_capabilities={"headers": dict(request.headers)}))
return Response()
routes.extend(
[
Route("/sse", endpoint=handle_sse, methods=["GET"]),
Mount("/messages/", app=sse.handle_post_message),
]
)
# Add streamable HTTP route if enabled
streamablehttp_lifespan = None
if TRANSPORT_STREAMABLE_HTTP_ENABLED:
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from starlette.types import Receive, Scope, Send
session_manager = StreamableHTTPSessionManager(
app=app,
event_store=None,
json_response=JSON_RESPONSE,
stateless=True,
)
class StreamableHTTPEntry:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await session_manager.handle_request(scope, receive, send)
streamable_http_entry = StreamableHTTPEntry()
@asynccontextmanager
async def streamablehttp_lifespan(app: Starlette) -> AsyncIterator[None]:
async with session_manager.run():
logging.info("StreamableHTTP application started with StreamableHTTP session manager!")
try:
yield
finally:
logging.info("StreamableHTTP application shutting down...")
routes.extend(
[
Route("/mcp", endpoint=streamable_http_entry, methods=["GET", "POST", "DELETE"]),
Mount("/mcp", app=streamable_http_entry),
]
)
return Starlette(
debug=True,
routes=routes,
middleware=middleware,
lifespan=streamablehttp_lifespan,
)
@click.command()
@click.option("--base-url", type=str, default="http://127.0.0.1:9380", help="API base URL for RAGFlow backend")
@click.option("--host", type=str, default="127.0.0.1", help="Host to bind the RAGFlow MCP server")
@click.option("--port", type=int, default=9382, help="Port to bind the RAGFlow MCP server")
@click.option(
"--mode",
type=click.Choice(["self-host", "host"]),
default="self-host",
help=("Launch mode:\n self-host: run MCP for a single tenant (requires --api-key)\n host: multi-tenant mode, users must provide Authorization headers"),
)
@click.option("--api-key", type=str, default="", help="API key to use when in self-host mode")
@click.option(
"--transport-sse-enabled/--no-transport-sse-enabled",
default=True,
help="Enable or disable legacy SSE transport mode (default: enabled)",
)
@click.option(
"--transport-streamable-http-enabled/--no-transport-streamable-http-enabled",
default=True,
help="Enable or disable streamable-http transport mode (default: enabled)",
)
@click.option(
"--json-response/--no-json-response",
default=True,
help="Enable or disable JSON response mode for streamable-http (default: enabled)",
)
def main(base_url, host, port, mode, api_key, transport_sse_enabled, transport_streamable_http_enabled, json_response):
import os
import uvicorn
from dotenv import load_dotenv
load_dotenv()
def parse_bool_flag(key: str, default: bool) -> bool:
val = os.environ.get(key, str(default))
return str(val).strip().lower() in ("1", "true", "yes", "on")
global BASE_URL, HOST, PORT, MODE, HOST_API_KEY, TRANSPORT_SSE_ENABLED, TRANSPORT_STREAMABLE_HTTP_ENABLED, JSON_RESPONSE
BASE_URL = os.environ.get("RAGFLOW_MCP_BASE_URL", base_url)
HOST = os.environ.get("RAGFLOW_MCP_HOST", host)
PORT = os.environ.get("RAGFLOW_MCP_PORT", str(port))
MODE = os.environ.get("RAGFLOW_MCP_LAUNCH_MODE", mode)
HOST_API_KEY = os.environ.get("RAGFLOW_MCP_HOST_API_KEY", api_key)
TRANSPORT_SSE_ENABLED = parse_bool_flag("RAGFLOW_MCP_TRANSPORT_SSE_ENABLED", transport_sse_enabled)
TRANSPORT_STREAMABLE_HTTP_ENABLED = parse_bool_flag("RAGFLOW_MCP_TRANSPORT_STREAMABLE_ENABLED", transport_streamable_http_enabled)
JSON_RESPONSE = parse_bool_flag("RAGFLOW_MCP_JSON_RESPONSE", json_response)
if MODE == LaunchMode.SELF_HOST and not HOST_API_KEY:
raise click.UsageError("--api-key is required when --mode is 'self-host'")
if not TRANSPORT_STREAMABLE_HTTP_ENABLED and JSON_RESPONSE:
JSON_RESPONSE = False
print(
r"""
__ __ ____ ____ ____ _____ ______ _______ ____
| \/ |/ ___| _ \ / ___|| ____| _ \ \ / / ____| _ \
| |\/| | | | |_) | \___ \| _| | |_) \ \ / /| _| | |_) |
| | | | |___| __/ ___) | |___| _ < \ V / | |___| _ <
|_| |_|\____|_| |____/|_____|_| \_\ \_/ |_____|_| \_\
""",
flush=True,
)
print(f"MCP launch mode: {MODE}", flush=True)
print(f"MCP host: {HOST}", flush=True)
print(f"MCP port: {PORT}", flush=True)
print(f"MCP base_url: {BASE_URL}", flush=True)
if not any([TRANSPORT_SSE_ENABLED, TRANSPORT_STREAMABLE_HTTP_ENABLED]):
print("At least one transport should be enabled, enable streamable-http automatically", flush=True)
TRANSPORT_STREAMABLE_HTTP_ENABLED = True
if TRANSPORT_SSE_ENABLED:
print("SSE transport enabled: yes", flush=True)
print("SSE endpoint available at /sse", flush=True)
else:
print("SSE transport enabled: no", flush=True)
if TRANSPORT_STREAMABLE_HTTP_ENABLED:
print("Streamable HTTP transport enabled: yes", flush=True)
print("Streamable HTTP endpoint available at /mcp", flush=True)
if JSON_RESPONSE:
print("Streamable HTTP mode: JSON response enabled", flush=True)
else:
print("Streamable HTTP mode: SSE over HTTP enabled", flush=True)
else:
print("Streamable HTTP transport enabled: no", flush=True)
if JSON_RESPONSE:
print("Warning: --json-response ignored because streamable transport is disabled.", flush=True)
uvicorn.run(
create_starlette_app(),
host=HOST,
port=int(PORT),
)
if __name__ == "__main__":
"""
Launch examples:
1. Self-host mode with both SSE and Streamable HTTP (in JSON response mode) enabled (default):
uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
--base-url=http://127.0.0.1:9380 \
--mode=self-host --api-key=ragflow-xxxxx
2. Host mode (multi-tenant, clients must provide Authorization headers):
uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
--base-url=http://127.0.0.1:9380 \
--mode=host
3. Disable legacy SSE (only streamable HTTP will be active):
uv run mcp/server/server.py --no-transport-sse-enabled \
--mode=self-host --api-key=ragflow-xxxxx
4. Disable streamable HTTP (only legacy SSE will be active):
uv run mcp/server/server.py --no-transport-streamable-http-enabled \
--mode=self-host --api-key=ragflow-xxxxx
5. Use streamable HTTP with SSE-style events (disable JSON response):
uv run mcp/server/server.py --transport-streamable-http-enabled --no-json-response \
--mode=self-host --api-key=ragflow-xxxxx
6. Disable both transports (for testing):
uv run mcp/server/server.py --no-transport-sse-enabled --no-transport-streamable-http-enabled \
--mode=self-host --api-key=ragflow-xxxxx
"""
main()

View File

@@ -0,0 +1,869 @@
# 知识库模块功能实现计划
> **Enhanced on:** 2025-02-10
> **Sections enhanced:** 10
> **Research agents used:** FastAPI best practices, Vue 3 composables, UI/UX patterns, RAGFlow SDK, File upload security, Architecture strategy, Code simplicity, Security sentinel, Performance oracle
---
## Enhancement Summary
### Key Improvements
1. **安全加固** - 添加文件类型验证、大小限制、API Key 管理
2. **性能优化** - 流式文件上传、分页查询、连接池管理
3. **架构分层** - 引入服务层和仓储模式,提高可测试性
4. **代码简化** - 移除过度设计,遵循 YAGNI 原则
5. **用户体验** - 完善空状态、加载状态、错误处理
### New Considerations Discovered
- RAGFlow 部署使用 HTTP(非 HTTPS),需要评估安全风险
- 文件上传必须实现流式处理,避免内存溢出
- 切片查询必须分页,否则大数据量会 OOM
- API Key 应通过环境变量管理,不应硬编码
---
## 概述
在 qwen-client 项目上增加一个独立的知识库模块功能(与 bot 无关联),通过 RAGFlow SDK 实现知识库管理功能。
**架构设计:**
```
qwen-client (Vue 3) → qwen-agent (FastAPI) → RAGFlow (http://100.77.70.35:1080)
```
## 需求背景
用户需要一个独立的知识库管理系统,可以:
1. 创建和管理多个知识库(数据集)
2. 向知识库上传文件
3. 管理知识库内的文档切片
4. 后续可与 bot 关联进行 RAG 检索
---
## 技术方案
### 后端实现 (qwen-agent)
#### 1. 环境配置
**文件:** `/utils/settings.py`
```python
# ============================================================
# RAGFlow Knowledge Base Configuration
# ============================================================
# RAGFlow API 配置
RAGFLOW_API_URL = os.getenv("RAGFLOW_API_URL", "http://100.77.70.35:1080")
RAGFLOW_API_KEY = os.getenv("RAGFLOW_API_KEY", "") # 必须通过环境变量设置
# 文件上传配置
RAGFLOW_MAX_UPLOAD_SIZE = int(os.getenv("RAGFLOW_MAX_UPLOAD_SIZE", str(100 * 1024 * 1024))) # 100MB
RAGFLOW_ALLOWED_EXTENSIONS = os.getenv("RAGFLOW_ALLOWED_EXTENSIONS", "pdf,docx,txt,md,csv").split(",")
# 性能配置
RAGFLOW_CONNECTION_TIMEOUT = int(os.getenv("RAGFLOW_CONNECTION_TIMEOUT", "30")) # 30秒
RAGFLOW_MAX_CONCURRENT_UPLOADS = int(os.getenv("RAGFLOW_MAX_CONCURRENT_UPLOADS", "5"))
```
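部署时建议通过 `.env` 或进程环境变量注入上述配置,避免把 API Key 写进代码。下面是一份示意的环境变量配置(取值均为占位/默认值,需按实际环境替换):
```bash
# .env 示意(占位值)
RAGFLOW_API_URL=http://100.77.70.35:1080
RAGFLOW_API_KEY=ragflow-xxxxx              # 必填,勿提交到版本库
RAGFLOW_MAX_UPLOAD_SIZE=104857600          # 100MB
RAGFLOW_ALLOWED_EXTENSIONS=pdf,docx,txt,md,csv
RAGFLOW_CONNECTION_TIMEOUT=30
RAGFLOW_MAX_CONCURRENT_UPLOADS=5
```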
#### 2. 依赖安装
**文件:** `/pyproject.toml`
`[tool.poetry.dependencies]` 添加:
```toml
ragflow-sdk = "^0.23.0"
python-magic = "^0.4.27"
aiofiles = "^24.1.0"
```
执行:
```bash
poetry install
poetry export -f requirements.txt -o requirements.txt --without-hashes
```
#### 3. 项目结构
基于架构审查建议,采用分层设计:
```
qwen-agent/
├── routes/
│ └── knowledge_base.py # API 路由层
├── services/
│ └── knowledge_base_service.py # 业务逻辑层(新增)
├── repositories/
│ ├── __init__.py
│ └── ragflow_repository.py # RAGFlow 适配器(新增)
└── utils/
├── settings.py # 配置管理
└── file_validator.py # 文件验证工具(新增)
```
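`services/knowledge_base_service.py` 目前仅在计划中,下面给出一个最小示意(接口与实现均为假设,实际以 Phase 1 落地为准),用于说明路由层只做协议转换、服务层做业务校验、仓储层封装 RAGFlow SDK 的分工:
```python
"""knowledge_base_service.py 的最小示意(假设接口,仅说明分层职责)"""
from typing import Optional

from repositories.ragflow_repository import RAGFlowRepository


class KnowledgeBaseService:
    """业务逻辑层:做参数校验与数据整形,外部调用全部委托给仓储层"""

    def __init__(self, repo: RAGFlowRepository):
        self.repo = repo

    async def list_datasets(self, page: int = 1, page_size: int = 20,
                            search: Optional[str] = None) -> dict:
        # 分页与搜索参数直接透传给仓储层
        return await self.repo.list_datasets(page=page, page_size=page_size, search=search)

    async def create_dataset(self, name: str, description: Optional[str] = None,
                             chunk_method: str = "naive") -> dict:
        return await self.repo.create_dataset(
            name=name, description=description, chunk_method=chunk_method)

    async def get_dataset(self, dataset_id: str) -> Optional[dict]:
        return await self.repo.get_dataset(dataset_id)

    async def update_dataset(self, dataset_id: str, updates: dict) -> Optional[dict]:
        return await self.repo.update_dataset(dataset_id, **updates)

    async def delete_dataset(self, dataset_id: str) -> bool:
        return await self.repo.delete_datasets([dataset_id])

    # upload_file / list_files / delete_file 的文件校验逻辑见"安全考虑"一节的示意代码
```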
#### 4. API 路由设计
**文件:** `/routes/knowledge_base.py`
**路由前缀:** `/api/v1/knowledge-base`
| 端点 | 方法 | 功能 | 认证 | 优化 |
|------|------|------|------|------|
| `/datasets` | GET | 获取所有数据集列表(分页) | Admin Token | 缓存 |
| `/datasets` | POST | 创建新数据集 | Admin Token | - |
| `/datasets/{dataset_id}` | GET | 获取数据集详情 | Admin Token | 缓存 |
| `/datasets/{dataset_id}` | PATCH | 更新数据集(部分更新) | Admin Token | - |
| `/datasets/{dataset_id}` | DELETE | 删除数据集 | Admin Token | - |
| `/datasets/{dataset_id}/files` | GET | 获取数据集内文件列表(分页) | Admin Token | 缓存 |
| `/datasets/{dataset_id}/files` | POST | 上传文件到数据集(流式) | Admin Token | 限流 |
| `/datasets/{dataset_id}/files/{document_id}` | DELETE | 删除文件 | Admin Token | - |
| `/datasets/{dataset_id}/chunks` | GET | 获取数据集内切片列表(分页) | Admin Token | 游标分页 |
| `/datasets/{dataset_id}/chunks/{chunk_id}` | DELETE | 删除切片 | Admin Token | - |
**代码结构:**
```python
"""
Knowledge Base API 路由
通过 RAGFlow SDK 提供知识库管理功能
"""
import logging
import os
from typing import Optional, List
from fastapi import APIRouter, Depends, HTTPException, Header, UploadFile, File, Query, BackgroundTasks
from pydantic import BaseModel, Field
from pathlib import Path
from utils.settings import RAGFLOW_API_URL, RAGFLOW_API_KEY
from utils.fastapi_utils import extract_api_key_from_auth
from repositories.ragflow_repository import RAGFlowRepository
from services.knowledge_base_service import KnowledgeBaseService
logger = logging.getLogger('app')
router = APIRouter()
# ============== 依赖注入 ==============
async def get_kb_service() -> KnowledgeBaseService:
"""获取知识库服务实例"""
return KnowledgeBaseService(RAGFlowRepository())
async def verify_admin(authorization: Optional[str] = Header(None)):
"""验证管理员权限(复用现有认证)"""
from routes.bot_manager import verify_admin_auth
valid, username = await verify_admin_auth(authorization)
if not valid:
raise HTTPException(status_code=401, detail="Unauthorized")
return username
# ============== Pydantic Models ==============
class DatasetCreate(BaseModel):
"""创建数据集请求"""
name: str = Field(..., min_length=1, max_length=128, description="数据集名称")
description: Optional[str] = Field(None, max_length=500, description="描述信息")
chunk_method: str = Field(default="naive", description="分块方法")
# RAGFlow 支持的分块方法: naive, manual, qa, table, paper, book, laws, presentation, picture, one, email, knowledge-graph
class DatasetUpdate(BaseModel):
"""更新数据集请求(部分更新)"""
name: Optional[str] = Field(None, min_length=1, max_length=128)
description: Optional[str] = Field(None, max_length=500)
chunk_method: Optional[str] = None
class DatasetListResponse(BaseModel):
"""数据集列表响应(分页)"""
items: List[dict]
total: int
page: int
page_size: int
# ============== 数据集端点 ==============
@router.get("/datasets", response_model=DatasetListResponse)
async def list_datasets(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
username: str = Depends(verify_admin),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""获取数据集列表(支持分页和搜索)"""
return await kb_service.list_datasets(
page=page,
page_size=page_size,
search=search
)
@router.post("/datasets", status_code=201)
async def create_dataset(
data: DatasetCreate,
username: str = Depends(verify_admin),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""创建数据集"""
try:
dataset = await kb_service.create_dataset(
name=data.name,
description=data.description,
chunk_method=data.chunk_method
)
return dataset
except Exception as e:
logger.error(f"Failed to create dataset: {e}")
raise HTTPException(status_code=500, detail=f"创建数据集失败: {str(e)}")
@router.get("/datasets/{dataset_id}")
async def get_dataset(
dataset_id: str,
username: str = Depends(verify_admin),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""获取数据集详情"""
dataset = await kb_service.get_dataset(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail="数据集不存在")
return dataset
@router.patch("/datasets/{dataset_id}")
async def update_dataset(
dataset_id: str,
data: DatasetUpdate,
username: str = Depends(verify_admin),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""更新数据集(部分更新)"""
try:
dataset = await kb_service.update_dataset(dataset_id, data.model_dump(exclude_unset=True))
if not dataset:
raise HTTPException(status_code=404, detail="数据集不存在")
return dataset
except Exception as e:
logger.error(f"Failed to update dataset: {e}")
raise HTTPException(status_code=500, detail=f"更新数据集失败: {str(e)}")
@router.delete("/datasets/{dataset_id}", status_code=204)
async def delete_dataset(
dataset_id: str,
username: str = Depends(verify_admin),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""删除数据集"""
success = await kb_service.delete_dataset(dataset_id)
if not success:
raise HTTPException(status_code=404, detail="数据集不存在")
# ============== 文件端点 ==============
@router.get("/datasets/{dataset_id}/files")
async def list_dataset_files(
dataset_id: str,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
username: str = Depends(verify_admin),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""获取数据集内文件列表(分页)"""
return await kb_service.list_files(dataset_id, page=page, page_size=page_size)
@router.post("/datasets/{dataset_id}/files")
async def upload_file(
dataset_id: str,
file: UploadFile = File(...),
background_tasks: BackgroundTasks = None,
username: str = Depends(verify_admin),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""
上传文件到数据集(流式处理)
支持的文件类型: PDF, DOCX, TXT, MD, CSV
最大文件大小: 100MB
"""
# 文件验证在 service 层处理
try:
result = await kb_service.upload_file(dataset_id, file)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Failed to upload file: {e}")
raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
@router.delete("/datasets/{dataset_id}/files/{document_id}")
async def delete_file(
dataset_id: str,
document_id: str,
username: str = Depends(verify_admin),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""删除文件"""
success = await kb_service.delete_file(dataset_id, document_id)
if not success:
raise HTTPException(status_code=404, detail="文件不存在")
return {"success": True}
# ============== 切片端点(可选,延后实现)=============
# 根据简化建议,切片管理功能延后到明确需求时再实现
```
---
### 前端实现 (qwen-client)
#### 1. API 服务层
**文件:** `/src/api/index.js`
添加 `knowledgeBaseApi` 模块:
```javascript
// ============== Knowledge Base API ==============
const knowledgeBaseApi = {
// 数据集管理
getDatasets: async (params = {}) => {
const qs = new URLSearchParams(params).toString()
return request(`/api/v1/knowledge-base/datasets${qs ? '?' + qs : ''}`)
},
createDataset: async (data) => {
return request('/api/v1/knowledge-base/datasets', {
method: 'POST',
body: JSON.stringify(data)
})
},
updateDataset: async (datasetId, data) => {
return request(`/api/v1/knowledge-base/datasets/${datasetId}`, {
method: 'PATCH', // 使用 PATCH 支持部分更新
body: JSON.stringify(data)
})
},
deleteDataset: async (datasetId) => {
return request(`/api/v1/knowledge-base/datasets/${datasetId}`, {
method: 'DELETE'
})
},
// 文件管理
getDatasetFiles: async (datasetId, params = {}) => {
const qs = new URLSearchParams(params).toString()
return request(`/api/v1/knowledge-base/datasets/${datasetId}/files${qs ? '?' + qs : ''}`)
},
uploadFile: async (datasetId, file, onProgress) => {
const formData = new FormData()
formData.append('file', file)
// 支持上传进度回调
const xhr = new XMLHttpRequest()
return new Promise((resolve, reject) => {
xhr.upload.addEventListener('progress', (e) => {
if (onProgress && e.lengthComputable) {
onProgress(Math.round((e.loaded / e.total) * 100))
}
})
xhr.addEventListener('load', () => {
if (xhr.status >= 200 && xhr.status < 300) {
resolve(JSON.parse(xhr.responseText))
} else {
reject(new Error(xhr.statusText))
}
})
xhr.addEventListener('error', () => reject(new Error('上传失败')))
xhr.addEventListener('abort', () => reject(new Error('上传已取消')))
xhr.open('POST', `${API_BASE}/api/v1/knowledge-base/datasets/${datasetId}/files`)
xhr.setRequestHeader('Authorization', `Bearer ${localStorage.getItem('admin_token') || 'dummy-token'}`)
xhr.send(formData)
})
},
deleteFile: async (datasetId, documentId) => {
return request(`/api/v1/knowledge-base/datasets/${datasetId}/files/${documentId}`, {
method: 'DELETE'
})
}
}
```
#### 2. 简化的状态管理
**基于代码简洁性审查建议,直接在组件中管理状态,而不是创建独立的 composable**
```vue
<!-- KnowledgeBaseView.vue -->
<script setup>
import { ref, onMounted } from 'vue'
import { knowledgeBaseApi } from '@/api'
import DatasetList from '@/components/knowledge-base/DatasetList.vue'
import FileList from '@/components/knowledge-base/FileList.vue'
import DatasetFormModal from '@/components/knowledge-base/DatasetFormModal.vue'
// 状态
const datasets = ref([])
const currentDataset = ref(null)
const files = ref([])
const isLoading = ref(false)
const error = ref(null)
// 分页
const page = ref(1)
const pageSize = ref(20)
const total = ref(0)
// 加载数据集
const loadDatasets = async () => {
isLoading.value = true
error.value = null
try {
const response = await knowledgeBaseApi.getDatasets({
page: page.value,
page_size: pageSize.value
})
datasets.value = response.items || []
total.value = response.total
} catch (err) {
error.value = err.message
} finally {
isLoading.value = false
}
}
// 选择数据集
const selectDataset = async (dataset) => {
currentDataset.value = dataset
await loadFiles(dataset.dataset_id)
}
// 加载文件
const loadFiles = async (datasetId) => {
isLoading.value = true
try {
const response = await knowledgeBaseApi.getDatasetFiles(datasetId)
files.value = response.items || []
} finally {
isLoading.value = false
}
}
// 创建数据集
const createDataset = async (data) => {
await knowledgeBaseApi.createDataset(data)
await loadDatasets()
}
// 删除数据集
const deleteDataset = async (datasetId) => {
await knowledgeBaseApi.deleteDataset(datasetId)
if (currentDataset.value?.dataset_id === datasetId) {
currentDataset.value = null
files.value = []
}
await loadDatasets()
}
onMounted(() => {
loadDatasets()
})
</script>
<template>
<div class="knowledge-base-view">
<!-- 数据集列表 -->
<DatasetList
:datasets="datasets"
:loading="isLoading"
:current="currentDataset"
@select="selectDataset"
@create="createDataset"
@delete="deleteDataset"
/>
<!-- 文件列表(选中数据集后显示) -->
<FileList
v-if="currentDataset"
:dataset="currentDataset"
:files="files"
@upload="handleFileUpload"
@delete="handleFileDelete"
/>
</div>
</template>
```
#### 3. 路由配置
**文件:** `/src/router/index.js`
添加知识库路由:
```javascript
{
path: '/knowledge-base',
name: 'knowledge-base',
component: () => import('@/views/KnowledgeBaseView.vue'),
meta: { requiresAuth: true, title: '知识库管理' }
}
```
#### 4. 视图组件
**文件:** `/src/views/KnowledgeBaseView.vue`
主视图组件,包含:
- 数据集列表(左侧或顶部)
- 文件列表(选中数据集后显示)
- 上传文件按钮
- 创建数据集按钮
**子组件(简化后的结构):**
| 组件 | 文件 | 功能 |
|------|------|------|
| `DatasetList.vue` | `/src/components/knowledge-base/DatasetList.vue` | 数据集列表展示 + 创建/删除 |
| `DatasetFormModal.vue` | `/src/components/knowledge-base/DatasetFormModal.vue` | 创建/编辑数据集弹窗(合并) |
| `FileList.vue` | `/src/components/knowledge-base/FileList.vue` | 文件列表展示 + 上传 |
| `FileUploadModal.vue` | `/src/components/knowledge-base/FileUploadModal.vue` | 文件上传弹窗 |
**目录结构:**
```
src/components/knowledge-base/
├── DatasetList.vue # 数据集列表(含创建按钮)
├── DatasetFormModal.vue # 创建/编辑数据集表单
├── FileList.vue # 文件列表(含上传按钮)
└── FileUploadModal.vue # 文件上传弹窗
```
#### 5. 导航菜单
**文件:** `/src/views/AdminView.vue`
在导航菜单中添加知识库入口:
```vue
<Button
variant="ghost"
@click="currentView = 'knowledge-base'"
>
<Database :size="20" />
<span>知识库管理</span>
</Button>
```
---
## 实现阶段
### Phase 1: 后端基础 (qwen-agent) - 核心功能
- [ ] 添加 `ragflow-sdk` 依赖到 `pyproject.toml`
- [ ] 在 `utils/settings.py` 添加 RAGFlow 配置(环境变量)
- [ ] 创建 `repositories/ragflow_repository.py` - RAGFlow SDK 适配器
- [ ] 创建 `services/knowledge_base_service.py` - 业务逻辑层
- [ ] 创建 `routes/knowledge_base.py` - API 路由
- [ ] 在 `fastapi_app.py` 注册路由
- [ ] 测试 API 端点
### Phase 2: 前端 API 层 (qwen-client)
- [ ] 在 `src/api/index.js` 添加 `knowledgeBaseApi`
- [ ] 添加知识库路由到 `src/router/index.js`
- [ ] 在 AdminView 添加导航入口
### Phase 3: 前端 UI 组件 - 最小实现
- [ ] 创建 `src/components/knowledge-base/` 目录
- [ ] 实现 `KnowledgeBaseView.vue` 主视图
- [ ] 实现 `DatasetList.vue` 组件
- [ ] 实现 `DatasetFormModal.vue` 组件
- [ ] 实现 `FileList.vue` 组件
- [ ] 实现 `FileUploadModal.vue` 组件
### Phase 4: 切片管理 (延后实现)
根据 YAGNI 原则,切片管理功能延后到有明确需求时再实现:
- [ ] 后端实现切片列表/删除端点
- [ ] 前端实现 `ChunkList.vue` 组件
- [ ] 切片搜索功能
### Phase 5: 测试与优化
- [ ] 端到端测试
- [ ] 错误处理优化
- [ ] 加载状态优化
- [ ] 添加性能监控
---
## 数据模型
### 数据集 (Dataset)
```typescript
interface Dataset {
dataset_id: string
name: string
description?: string
chunk_method: string
chunk_count?: number
document_count?: number
created_at: string
updated_at: string
}
```
### 文件 (Document)
```typescript
interface Document {
document_id: string
dataset_id: string
name: string
size: number
status: 'running' | 'success' | 'failed'
progress: number // 0-100
chunk_count?: number
token_count?: number
created_at: string
updated_at: string
}
```
### 切片 (Chunk) - 延后实现
```typescript
interface Chunk {
chunk_id: string
document_id: string
dataset_id: string
content: string
position: number
important_keywords?: string[]
available: boolean
created_at: string
}
```
---
## API 端点规范
### 1. 获取数据集列表(分页)
```
GET /api/v1/knowledge-base/datasets?page=1&page_size=20&search=keyword
Authorization: Bearer {admin_token}
Response:
{
"items": [
{
"dataset_id": "uuid",
"name": "产品手册",
"description": "公司产品相关文档",
"chunk_method": "naive",
"document_count": 5,
"chunk_count": 120,
"created_at": "2025-01-01T00:00:00Z",
"updated_at": "2025-01-01T00:00:00Z"
}
],
"total": 1,
"page": 1,
"page_size": 20
}
```
### 2. 创建数据集
```
POST /api/v1/knowledge-base/datasets
Authorization: Bearer {admin_token}
Content-Type: application/json
{
"name": "产品手册",
"description": "公司产品相关文档",
"chunk_method": "naive"
}
Response:
{
"dataset_id": "uuid",
"name": "产品手册",
...
}
```
### 3. 上传文件(流式)
```
POST /api/v1/knowledge-base/datasets/{dataset_id}/files
Authorization: Bearer {admin_token}
Content-Type: multipart/form-data
file: <binary>
Response (异步):
{
"document_id": "uuid",
"name": "document.pdf",
"status": "running",
"progress": 0,
...
}
```
---
## 安全考虑
### Research Insights
**文件上传安全:**
- 实现文件类型白名单验证(扩展名 + MIME 类型 + 魔数)
- 限制文件大小(最大 100MB)
- 使用 UUID 重命名文件,防止路径遍历
- 清理文件名中的危险字符(校验逻辑的示意代码见本节末尾)
**API 认证:**
- 复用现有的 `verify_admin_auth` 函数
- 所有端点需要有效的 Admin Token
- 集成现有的 RBAC 系统
**输入验证:**
- 使用 Pydantic Field 进行输入验证
- 限制字符串长度
- 验证分页参数范围
**配置安全:**
- API Key 必须通过环境变量设置
- 不在代码中硬编码敏感信息
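作为补充,下面给出计划中 `utils/file_validator.py` 的一个最小示意(函数名与实现均为假设,依赖 `python-magic` 做魔数检测,实际以 Phase 1 落地为准):
```python
"""utils/file_validator.py 的示意实现(假设接口)"""
import re
import uuid
from pathlib import Path

import magic  # python-magic:基于 libmagic 的文件魔数检测

from utils.settings import RAGFLOW_ALLOWED_EXTENSIONS, RAGFLOW_MAX_UPLOAD_SIZE

# 扩展名 -> 允许的 MIME 类型白名单
_ALLOWED_MIME = {
    "pdf": {"application/pdf"},
    "docx": {"application/vnd.openxmlformats-officedocument.wordprocessingml.document"},
    "txt": {"text/plain"},
    "md": {"text/plain", "text/markdown"},
    "csv": {"text/csv", "text/plain"},
}


def sanitize_filename(filename: str) -> str:
    """去除路径部分与危险字符,并加 UUID 前缀,防止路径遍历与重名覆盖"""
    name = Path(filename).name
    name = re.sub(r"[^\w.\-]", "_", name)
    return f"{uuid.uuid4().hex}_{name}"


def validate_upload(filename: str, content: bytes) -> str:
    """校验扩展名、大小与魔数;通过则返回清理后的文件名,否则抛 ValueError"""
    ext = Path(filename).suffix.lstrip(".").lower()
    if ext not in RAGFLOW_ALLOWED_EXTENSIONS:
        raise ValueError(f"不支持的文件类型: .{ext}")
    if len(content) > RAGFLOW_MAX_UPLOAD_SIZE:
        raise ValueError("文件超过大小限制")
    mime = magic.from_buffer(content[:4096], mime=True)
    if ext in _ALLOWED_MIME and mime not in _ALLOWED_MIME[ext]:
        raise ValueError(f"文件内容与扩展名不匹配: {mime}")
    return sanitize_filename(filename)
```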
### 安全配置清单
| 措施 | 优先级 | 状态 |
|------|--------|------|
| 文件类型验证 | 高 | 待实现 |
| 文件大小限制 | 高 | 待实现 |
| API Key 环境变量 | 高 | 已规划 |
| 路径遍历防护 | 高 | 待实现 |
| 文件名清理 | 中 | 待实现 |
| 病毒扫描 | 中 | 可选 |
---
## 性能优化
### Research Insights
**文件上传优化:**
- 使用流式处理,避免一次性读取大文件到内存
- 实现并发限制(最多 5 个并发上传,示意代码见本节末尾)
- 添加上传进度回调
**查询优化:**
- 实现分页机制,避免返回大量数据
- 使用游标分页优化深分页性能
- 对数据集列表添加缓存
**连接池:**
- 使用异步 HTTP 客户端连接池
- 设置合理的超时时间
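下面是流式读取与并发限制的一个草图(假设放在 service 层;`RAGFLOW_MAX_*` 配置取自上文 `utils/settings.py`;由于 ragflow-sdk 上传接口接收 bytes,这里只做到"边读边限"的程度):
```python
"""流式读取上传文件 + 并发闸门的示意(假设实现)"""
import asyncio
import tempfile

from fastapi import UploadFile

from utils.settings import RAGFLOW_MAX_UPLOAD_SIZE, RAGFLOW_MAX_CONCURRENT_UPLOADS

# 进程内并发上传闸门,超出的请求排队等待
_upload_semaphore = asyncio.Semaphore(RAGFLOW_MAX_CONCURRENT_UPLOADS)
_CHUNK_SIZE = 1024 * 1024  # 每次读取 1MB


async def read_upload_streaming(file: UploadFile) -> bytes:
    """分块读取并即时校验大小:超限立即中断,而不是读完整个文件再报错"""
    async with _upload_semaphore:
        total = 0
        # 大文件溢出到磁盘临时文件,内存占用保持有界
        with tempfile.SpooledTemporaryFile(max_size=8 * 1024 * 1024) as buf:
            while chunk := await file.read(_CHUNK_SIZE):
                total += len(chunk)
                if total > RAGFLOW_MAX_UPLOAD_SIZE:
                    raise ValueError("文件超过大小限制")
                buf.write(chunk)
            buf.seek(0)
            return buf.read()  # 最终仍需整体 bytes 交给 ragflow-sdk
```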
### 性能优化清单
| 优化项 | 优先级 | 预期效果 |
|--------|--------|----------|
| 流式文件上传 | 高 | 避免 OOM |
| 分页查询 | 高 | 响应时间 < 100ms |
| 数据集缓存 | 中 | 减少外部 API 调用 |
| 连接池 | 中 | 提高并发能力 |
| 限流 | 中 | 防止资源耗尽 |
---
## 用户体验优化
### Research Insights
**空状态设计:**
- 为空列表提供友好的提示和操作引导
- 区分不同场景的空状态(首次使用、搜索无结果等)
**加载状态:**
- 使用骨架屏替代传统 loading 指示器
- 显示加载进度(特别是文件上传)
**错误处理:**
- 使用人类可读的错误消息
- 提供具体的修复建议
- 区分不同类型的错误(网络、验证、服务器)
**文件上传 UX:**
- 支持拖拽上传
- 显示上传进度
- 支持批量上传
- 显示文件大小和类型验证
---
## 配置清单
### 环境变量
| 变量名 | 默认值 | 说明 | 必填 |
|--------|--------|------|------|
| `RAGFLOW_API_URL` | `http://100.77.70.35:1080` | RAGFlow API 地址 | 是 |
| `RAGFLOW_API_KEY` | - | RAGFlow API Key | 是 |
| `RAGFLOW_MAX_UPLOAD_SIZE` | `104857600` | 最大上传文件大小(字节) | 否 |
| `RAGFLOW_ALLOWED_EXTENSIONS` | `pdf,docx,txt,md,csv` | 允许的文件扩展名 | 否 |
| `RAGFLOW_CONNECTION_TIMEOUT` | `30` | 连接超时(秒) | 否 |
| `RAGFLOW_MAX_CONCURRENT_UPLOADS` | `5` | 最大并发上传数 | 否 |
### 依赖包
```toml
[tool.poetry.dependencies]
ragflow-sdk = "^0.23.0"
python-magic = "^0.4.27"
aiofiles = "^24.1.0"
```
---
## 参考资料
- **RAGFlow 官方文档:** https://ragflow.com.cn/docs
- **RAGFlow HTTP API:** https://ragflow.io/docs/http_api_reference
- **RAGFlow GitHub:** https://github.com/infiniflow/ragflow
- **RAGFlow Python SDK:** https://github.com/infiniflow/ragflow/blob/main/docs/references/python_api_reference.md
- **qwen-client API 层:** `/src/api/index.js`
- **qwen-agent 路由示例:** `/routes/bot_manager.py`
**研究来源:**
- [Vue.js Official Composables Guide](https://vuejs.org/guide/reusability/composables.html)
- [FastAPI Official Documentation](https://fastapi.tiangolo.com/)
- [UX Best Practices for File Uploader - Uploadcare](https://uploadcare.com/blog/file-uploader-ux-best-practices/)
- [Empty State UX Examples - Pencil & Paper](https://www.pencilandpaper.io/articles/empty-states)
- [Error-Message Guidelines - Nielsen Norman Group](https://www.nngroup.com/articles/error-message-guidelines/)
---
## 后续扩展
1. **与 Bot 关联:** 在 Bot 设置中选择知识库
2. **RAG 检索:** 实现基于知识库的问答功能
3. **批量操作:** 批量上传、删除文件
4. **知识库搜索:** 在知识库内搜索内容
5. **访问统计:** 查看知识库使用情况
6. **切片管理:** 前端切片查看和编辑(延后实现)

40
poetry.lock generated
View File

@@ -280,6 +280,26 @@ files = [
{file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"},
]
[[package]]
name = "beartype"
version = "0.22.9"
description = "Unbearably fast near-real-time pure-Python runtime-static type-checker."
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "beartype-0.22.9-py3-none-any.whl", hash = "sha256:d16c9bbc61ea14637596c5f6fbff2ee99cbe3573e46a716401734ef50c3060c2"},
{file = "beartype-0.22.9.tar.gz", hash = "sha256:8f82b54aa723a2848a56008d18875f91c1db02c32ef6a62319a002e3e25a975f"},
]
[package.extras]
dev = ["autoapi (>=0.9.0)", "celery", "click", "coverage (>=5.5)", "docutils (>=0.22.0)", "equinox ; sys_platform == \"linux\" and python_version < \"3.15.0\"", "fastmcp ; python_version < \"3.14.0\"", "jax[cpu] ; sys_platform == \"linux\" and python_version < \"3.15.0\"", "jaxtyping ; sys_platform == \"linux\"", "langchain ; python_version < \"3.14.0\" and sys_platform != \"darwin\" and platform_python_implementation != \"PyPy\"", "mypy (>=0.800) ; platform_python_implementation != \"PyPy\"", "nuitka (>=1.2.6) ; sys_platform == \"linux\" and python_version < \"3.14.0\"", "numba ; python_version < \"3.14.0\"", "numpy ; python_version < \"3.15.0\" and sys_platform != \"darwin\" and platform_python_implementation != \"PyPy\"", "pandera (>=0.26.0) ; python_version < \"3.14.0\"", "poetry", "polars ; python_version < \"3.14.0\"", "pydata-sphinx-theme (<=0.7.2)", "pygments", "pyinstaller", "pyright (>=1.1.370)", "pytest (>=6.2.0)", "redis", "rich-click", "setuptools", "sphinx", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)", "sqlalchemy", "torch ; sys_platform == \"linux\" and python_version < \"3.14.0\"", "tox (>=3.20.1)", "typer", "typing-extensions (>=3.10.0.0)", "xarray ; python_version < \"3.15.0\""]
doc-ghp = ["mkdocs-material[imaging] (>=9.6.0)", "mkdocstrings-python (>=1.16.0)", "mkdocstrings-python-xref (>=1.16.0)"]
doc-rtd = ["autoapi (>=0.9.0)", "pydata-sphinx-theme (<=0.7.2)", "setuptools", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)"]
test = ["celery", "click", "coverage (>=5.5)", "docutils (>=0.22.0)", "equinox ; sys_platform == \"linux\" and python_version < \"3.15.0\"", "fastmcp ; python_version < \"3.14.0\"", "jax[cpu] ; sys_platform == \"linux\" and python_version < \"3.15.0\"", "jaxtyping ; sys_platform == \"linux\"", "langchain ; python_version < \"3.14.0\" and sys_platform != \"darwin\" and platform_python_implementation != \"PyPy\"", "mypy (>=0.800) ; platform_python_implementation != \"PyPy\"", "nuitka (>=1.2.6) ; sys_platform == \"linux\" and python_version < \"3.14.0\"", "numba ; python_version < \"3.14.0\"", "numpy ; python_version < \"3.15.0\" and sys_platform != \"darwin\" and platform_python_implementation != \"PyPy\"", "pandera (>=0.26.0) ; python_version < \"3.14.0\"", "poetry", "polars ; python_version < \"3.14.0\"", "pygments", "pyinstaller", "pyright (>=1.1.370)", "pytest (>=6.2.0)", "redis", "rich-click", "sphinx", "sqlalchemy", "torch ; sys_platform == \"linux\" and python_version < \"3.14.0\"", "tox (>=3.20.1)", "typer", "typing-extensions (>=3.10.0.0)", "xarray ; python_version < \"3.15.0\""]
test-tox = ["celery", "click", "docutils (>=0.22.0)", "equinox ; sys_platform == \"linux\" and python_version < \"3.15.0\"", "fastmcp ; python_version < \"3.14.0\"", "jax[cpu] ; sys_platform == \"linux\" and python_version < \"3.15.0\"", "jaxtyping ; sys_platform == \"linux\"", "langchain ; python_version < \"3.14.0\" and sys_platform != \"darwin\" and platform_python_implementation != \"PyPy\"", "mypy (>=0.800) ; platform_python_implementation != \"PyPy\"", "nuitka (>=1.2.6) ; sys_platform == \"linux\" and python_version < \"3.14.0\"", "numba ; python_version < \"3.14.0\"", "numpy ; python_version < \"3.15.0\" and sys_platform != \"darwin\" and platform_python_implementation != \"PyPy\"", "pandera (>=0.26.0) ; python_version < \"3.14.0\"", "poetry", "polars ; python_version < \"3.14.0\"", "pygments", "pyinstaller", "pyright (>=1.1.370)", "pytest (>=6.2.0)", "redis", "rich-click", "sphinx", "sqlalchemy", "torch ; sys_platform == \"linux\" and python_version < \"3.14.0\"", "typer", "typing-extensions (>=3.10.0.0)", "xarray ; python_version < \"3.15.0\""]
test-tox-coverage = ["coverage (>=5.5)"]
[[package]]
name = "beautifulsoup4"
version = "4.14.3"
@@ -3888,6 +3908,22 @@ urllib3 = ">=1.26.14,<3"
fastembed = ["fastembed (>=0.7,<0.8)"]
fastembed-gpu = ["fastembed-gpu (>=0.7,<0.8)"]
[[package]]
name = "ragflow-sdk"
version = "0.23.1"
description = "Python client sdk of [RAGFlow](https://github.com/infiniflow/ragflow). RAGFlow is an open-source RAG (Retrieval-Augmented Generation) engine based on deep document understanding."
optional = false
python-versions = "<3.15,>=3.12"
groups = ["main"]
files = [
{file = "ragflow_sdk-0.23.1-py3-none-any.whl", hash = "sha256:8bb2827f2696f84fc5cdbf980e2a74e2b18c712c07d08eca26ea52e13e2a4c51"},
{file = "ragflow_sdk-0.23.1.tar.gz", hash = "sha256:dc358001bc8cad243e9aa879056c3f65bd7d687a9bff9863f6c79eaa4f43db09"},
]
[package.dependencies]
beartype = ">=0.20.0,<1.0.0"
requests = ">=2.30.0,<3.0.0"
[[package]]
name = "referencing"
version = "0.37.0"
@@ -6048,5 +6084,5 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = ">=3.12,<4.0"
content-hash = "abce2b9aba5a46841df8e6e4e4f12523ff9c4cd34dab7d180490ae36b2dee16e"
python-versions = ">=3.12,<3.15"
content-hash = "d930570328aea9211c1563538968847d8ba638963025e63a246559307e1d33ed"

View File

@@ -6,7 +6,7 @@ authors = [
{name = "朱潮",email = "zhuchaowe@users.noreply.github.com"}
]
readme = "README.md"
requires-python = ">=3.12,<4.0"
requires-python = ">=3.12,<3.15"
dependencies = [
"fastapi==0.116.1",
"uvicorn==0.35.0",
@@ -36,6 +36,8 @@ dependencies = [
"psycopg2-binary (>=2.9.11,<3.0.0)",
"json-repair (>=0.29.0,<0.30.0)",
"tiktoken (>=0.5.0,<1.0.0)",
"ragflow-sdk (>=0.23.0,<0.24.0)",
"httpx (>=0.28.1,<0.29.0)",
]
[tool.poetry.requires-plugins]

6
repositories/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
"""
Repositories package for data access layer
"""
from .ragflow_repository import RAGFlowRepository
__all__ = ['RAGFlowRepository']

View File

@@ -0,0 +1,559 @@
"""
RAGFlow Repository - 数据访问层
封装 RAGFlow SDK 调用,提供统一的数据访问接口
"""
import logging
import asyncio
from typing import Optional, List, Dict, Any
from pathlib import Path
try:
from ragflow_sdk import RAGFlow
except ImportError:
RAGFlow = None
logging.warning("ragflow-sdk not installed")
from utils.settings import (
RAGFLOW_API_URL,
RAGFLOW_API_KEY,
RAGFLOW_CONNECTION_TIMEOUT
)
logger = logging.getLogger('app')
class RAGFlowRepository:
"""
RAGFlow 数据仓储类
封装 RAGFlow SDK 的所有调用,提供:
- 统一的错误处理
- 连接管理
- 数据转换
"""
def __init__(self, api_key: str = None, base_url: str = None):
"""
初始化 RAGFlow 客户端
Args:
api_key: RAGFlow API Key(默认从配置读取)
base_url: RAGFlow 服务地址(默认从配置读取)
"""
self.api_key = api_key or RAGFLOW_API_KEY
self.base_url = base_url or RAGFLOW_API_URL
self._client: Optional[Any] = None
self._lock = asyncio.Lock()
async def _get_client(self):
"""
获取 RAGFlow 客户端实例(懒加载)
Returns:
RAGFlow 客户端
"""
if RAGFlow is None:
raise RuntimeError("ragflow-sdk is not installed. Run: poetry install")
if self._client is None:
async with self._lock:
# 双重检查
if self._client is None:
try:
self._client = RAGFlow(
api_key=self.api_key,
base_url=self.base_url
)
logger.info(f"RAGFlow client initialized: {self.base_url}")
except Exception as e:
logger.error(f"Failed to initialize RAGFlow client: {e}")
raise
return self._client
async def create_dataset(
self,
name: str,
description: str = None,
chunk_method: str = "naive",
permission: str = "me"
) -> Dict[str, Any]:
"""
创建数据集
Args:
name: 数据集名称
description: 描述信息
chunk_method: 分块方法 (naive, manual, qa, table, paper, book, etc.)
permission: 权限(me 或 team)
Returns:
创建的数据集信息
"""
client = await self._get_client()
try:
dataset = client.create_dataset(
name=name,
avatar=None,
description=description,
chunk_method=chunk_method,
permission=permission
)
return {
"dataset_id": getattr(dataset, 'id', None),
"name": getattr(dataset, 'name', name),
"description": getattr(dataset, 'description', description),
"chunk_method": getattr(dataset, 'chunk_method', chunk_method),
"permission": getattr(dataset, 'permission', permission),
"created_at": getattr(dataset, 'created_at', None),
"updated_at": getattr(dataset, 'updated_at', None),
}
except Exception as e:
logger.error(f"Failed to create dataset: {e}")
raise
async def list_datasets(
self,
page: int = 1,
page_size: int = 30,
search: str = None
) -> Dict[str, Any]:
"""
获取数据集列表
Args:
page: 页码
page_size: 每页数量
search: 搜索关键词
Returns:
数据集列表和分页信息
"""
client = await self._get_client()
try:
# RAGFlow SDK 的 list_datasets 方法
datasets = client.list_datasets(
page=page,
page_size=page_size
)
items = []
for dataset in datasets:
dataset_info = {
"dataset_id": getattr(dataset, 'id', None),
"name": getattr(dataset, 'name', None),
"description": getattr(dataset, 'description', None),
"chunk_method": getattr(dataset, 'chunk_method', None),
"avatar": getattr(dataset, 'avatar', None),
"permission": getattr(dataset, 'permission', None),
"created_at": getattr(dataset, 'created_at', None),
"updated_at": getattr(dataset, 'updated_at', None),
"metadata": getattr(dataset, 'metadata', {}),
}
# 搜索过滤
if search:
search_lower = search.lower()
if (search_lower not in (dataset_info.get('name') or '').lower() and
search_lower not in (dataset_info.get('description') or '').lower()):
continue
items.append(dataset_info)
return {
"items": items,
"total": len(items), # RAGFlow 可能不返回总数,使用实际返回数量
"page": page,
"page_size": page_size
}
except Exception as e:
logger.error(f"Failed to list datasets: {e}")
raise
async def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]:
"""
获取数据集详情
Args:
dataset_id: 数据集 ID
Returns:
数据集详情不存在返回 None
"""
client = await self._get_client()
try:
datasets = client.list_datasets(id=dataset_id)
if datasets and len(datasets) > 0:
dataset = datasets[0]
return {
"dataset_id": getattr(dataset, 'id', dataset_id),
"name": getattr(dataset, 'name', None),
"description": getattr(dataset, 'description', None),
"chunk_method": getattr(dataset, 'chunk_method', None),
"permission": getattr(dataset, 'permission', None),
"created_at": getattr(dataset, 'created_at', None),
"updated_at": getattr(dataset, 'updated_at', None),
}
return None
except Exception as e:
logger.error(f"Failed to get dataset {dataset_id}: {e}")
raise
async def update_dataset(
self,
dataset_id: str,
**updates
) -> Optional[Dict[str, Any]]:
"""
更新数据集
Args:
dataset_id: 数据集 ID
**updates: 要更新的字段
Returns:
更新后的数据集信息
"""
client = await self._get_client()
try:
datasets = client.list_datasets(id=dataset_id)
if datasets and len(datasets) > 0:
dataset = datasets[0]
# 调用 update 方法
dataset.update(updates)
return {
"dataset_id": getattr(dataset, 'id', dataset_id),
"name": getattr(dataset, 'name', None),
"description": getattr(dataset, 'description', None),
"chunk_method": getattr(dataset, 'chunk_method', None),
"updated_at": getattr(dataset, 'updated_at', None),
}
return None
except Exception as e:
logger.error(f"Failed to update dataset {dataset_id}: {e}")
raise
async def delete_datasets(self, dataset_ids: List[str] = None) -> bool:
"""
删除数据集
Args:
dataset_ids: 要删除的数据集 ID 列表
Returns:
是否成功
"""
client = await self._get_client()
try:
if dataset_ids:
client.delete_datasets(ids=dataset_ids)
return True
except Exception as e:
logger.error(f"Failed to delete datasets: {e}")
raise
async def upload_document(
self,
dataset_id: str,
file_name: str,
file_content: bytes,
display_name: str = None
) -> Dict[str, Any]:
"""
上传文档到数据集
Args:
dataset_id: 数据集 ID
file_name: 文件名
file_content: 文件内容
display_name: 显示名称
Returns:
上传的文档信息
"""
client = await self._get_client()
try:
# 获取数据集
datasets = client.list_datasets(id=dataset_id)
if not datasets or len(datasets) == 0:
raise ValueError(f"Dataset {dataset_id} not found")
dataset = datasets[0]
# 上传文档
display_name = display_name or file_name
dataset.upload_documents([{
"display_name": display_name,
"blob": file_content
}])
# 查找刚上传的文档
documents = dataset.list_documents()
for doc in documents:
if getattr(doc, 'name', None) == display_name:
return {
"document_id": getattr(doc, 'id', None),
"name": display_name,
"dataset_id": dataset_id,
"size": len(file_content),
"status": "running",
"chunk_count": getattr(doc, 'chunk_count', 0),
"token_count": getattr(doc, 'token_count', 0),
"created_at": getattr(doc, 'created_at', None),
}
return {
"document_id": None,
"name": display_name,
"dataset_id": dataset_id,
"size": len(file_content),
"status": "uploaded",
}
except Exception as e:
logger.error(f"Failed to upload document to {dataset_id}: {e}")
raise
async def list_documents(
self,
dataset_id: str,
page: int = 1,
page_size: int = 20
) -> Dict[str, Any]:
"""
获取数据集中的文档列表
Args:
dataset_id: 数据集 ID
page: 页码
page_size: 每页数量
Returns:
文档列表和分页信息
"""
client = await self._get_client()
try:
datasets = client.list_datasets(id=dataset_id)
if not datasets or len(datasets) == 0:
return {"items": [], "total": 0, "page": page, "page_size": page_size}
dataset = datasets[0]
documents = dataset.list_documents(
page=page,
page_size=page_size
)
items = []
for doc in documents:
items.append({
"document_id": getattr(doc, 'id', None),
"name": getattr(doc, 'name', None),
"dataset_id": dataset_id,
"size": getattr(doc, 'size', 0),
"status": getattr(doc, 'run', 'unknown'),
"progress": getattr(doc, 'progress', 0),
"chunk_count": getattr(doc, 'chunk_count', 0),
"token_count": getattr(doc, 'token_count', 0),
"created_at": getattr(doc, 'created_at', None),
"updated_at": getattr(doc, 'updated_at', None),
})
return {
"items": items,
"total": len(items),
"page": page,
"page_size": page_size
}
except Exception as e:
logger.error(f"Failed to list documents for {dataset_id}: {e}")
raise
async def delete_document(
self,
dataset_id: str,
document_id: str
) -> bool:
"""
删除文档
Args:
dataset_id: 数据集 ID
document_id: 文档 ID
Returns:
是否成功
"""
client = await self._get_client()
try:
datasets = client.list_datasets(id=dataset_id)
if datasets and len(datasets) > 0:
dataset = datasets[0]
dataset.delete_documents(ids=[document_id])
return True
except Exception as e:
logger.error(f"Failed to delete document {document_id}: {e}")
raise
async def list_chunks(
self,
dataset_id: str,
document_id: str = None,
page: int = 1,
page_size: int = 50
) -> Dict[str, Any]:
"""
获取切片列表
Args:
dataset_id: 数据集 ID
document_id: 文档 ID可选
page: 页码
page_size: 每页数量
Returns:
切片列表和分页信息
"""
client = await self._get_client()
try:
datasets = client.list_datasets(id=dataset_id)
if not datasets or len(datasets) == 0:
return {"items": [], "total": 0, "page": page, "page_size": page_size}
dataset = datasets[0]
# 如果指定了文档 ID先获取文档
if document_id:
documents = dataset.list_documents(id=document_id)
if documents and len(documents) > 0:
doc = documents[0]
chunks = doc.list_chunks(page=page, page_size=page_size)
else:
chunks = []
else:
# 获取所有文档的所有切片
chunks = []
for doc in dataset.list_documents():
chunks.extend(doc.list_chunks(page=page, page_size=page_size))
items = []
for chunk in chunks:
items.append({
"chunk_id": getattr(chunk, 'id', None),
"content": getattr(chunk, 'content', ''),
"document_id": getattr(chunk, 'document_id', None),
"dataset_id": dataset_id,
"position": getattr(chunk, 'position', 0),
"important_keywords": getattr(chunk, 'important_keywords', []),
"available": getattr(chunk, 'available', True),
"created_at": getattr(chunk, 'create_time', None),
})
return {
"items": items,
"total": len(items),
"page": page,
"page_size": page_size
}
except Exception as e:
logger.error(f"Failed to list chunks for {dataset_id}: {e}")
raise
async def delete_chunk(
self,
dataset_id: str,
document_id: str,
chunk_id: str
) -> bool:
"""
删除切片
Args:
dataset_id: 数据集 ID
document_id: 文档 ID
chunk_id: 切片 ID
Returns:
是否成功
"""
client = await self._get_client()
try:
datasets = client.list_datasets(id=dataset_id)
if datasets and len(datasets) > 0:
dataset = datasets[0]
documents = dataset.list_documents(id=document_id)
if documents and len(documents) > 0:
doc = documents[0]
doc.delete_chunks(chunk_ids=[chunk_id])
return True
except Exception as e:
logger.error(f"Failed to delete chunk {chunk_id}: {e}")
raise
async def parse_document(
self,
dataset_id: str,
document_id: str
) -> bool:
"""
开始解析文档
Args:
dataset_id: 数据集 ID
document_id: 文档 ID
Returns:
是否成功
"""
client = await self._get_client()
try:
datasets = client.list_datasets(id=dataset_id)
if not datasets or len(datasets) == 0:
raise ValueError(f"Dataset {dataset_id} not found")
dataset = datasets[0]
dataset.async_parse_documents([document_id])
return True
except Exception as e:
logger.error(f"Failed to parse document {document_id}: {e}")
raise
async def cancel_parse_document(
self,
dataset_id: str,
document_id: str
) -> bool:
"""
取消解析文档
Args:
dataset_id: 数据集 ID
document_id: 文档 ID
Returns:
是否成功
"""
client = await self._get_client()
try:
datasets = client.list_datasets(id=dataset_id)
if not datasets or len(datasets) == 0:
raise ValueError(f"Dataset {dataset_id} not found")
dataset = datasets[0]
dataset.async_cancel_parse_documents([document_id])
return True
except Exception as e:
logger.error(f"Failed to cancel parse document {document_id}: {e}")
raise
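
A minimal usage sketch of this repository, assuming a reachable RAGFlow instance behind the configured RAGFLOW_API_URL / RAGFLOW_API_KEY; the dataset name and file content below are illustrative only, not part of the commit:

# sketch: create a dataset, upload a small document, then trigger parsing
import asyncio

from repositories.ragflow_repository import RAGFlowRepository


async def demo():
    repo = RAGFlowRepository()  # falls back to RAGFLOW_API_URL / RAGFLOW_API_KEY from settings
    dataset = await repo.create_dataset(name="demo-kb", description="sketch only")
    dataset_id = dataset["dataset_id"]

    doc = await repo.upload_document(
        dataset_id=dataset_id,
        file_name="note.txt",
        file_content="hello ragflow".encode("utf-8"),
    )
    if doc["document_id"]:
        await repo.parse_document(dataset_id, doc["document_id"])

    print(await repo.list_documents(dataset_id))


if __name__ == "__main__":
    asyncio.run(demo())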


@ -425,7 +425,7 @@ class BotSettingsUpdate(BaseModel):
avatar_url: Optional[str] = None
description: Optional[str] = None
suggestions: Optional[List[str]] = None
dataset_ids: Optional[str] = None
dataset_ids: Optional[List[str]] = None # 改为数组类型,支持多选知识库
system_prompt: Optional[str] = None
enable_memori: Optional[bool] = None
enable_thinking: Optional[bool] = None
@ -452,7 +452,7 @@ class BotSettingsResponse(BaseModel):
avatar_url: Optional[str]
description: Optional[str]
suggestions: Optional[List[str]]
dataset_ids: Optional[str]
dataset_ids: Optional[List[str]] # 改为数组类型
system_prompt: Optional[str]
enable_memori: bool
enable_thinking: bool
@ -1544,6 +1544,13 @@ async def get_bot_settings(bot_uuid: str, authorization: Optional[str] = Header(
api_key=mask_api_key(model_row[5])
)
# 处理 dataset_ids将字符串转换为数组
dataset_ids = settings.get('dataset_ids')
if dataset_ids and isinstance(dataset_ids, str):
dataset_ids = [id.strip() for id in dataset_ids.split(',') if id.strip()]
elif not dataset_ids:
dataset_ids = None
return BotSettingsResponse(
bot_id=str(bot_id),
model_id=model_id,
@ -1552,7 +1559,7 @@ async def get_bot_settings(bot_uuid: str, authorization: Optional[str] = Header(
avatar_url=settings.get('avatar_url'),
description=settings.get('description'),
suggestions=settings.get('suggestions'),
dataset_ids=settings.get('dataset_ids'),
dataset_ids=dataset_ids,
system_prompt=settings.get('system_prompt'),
enable_memori=settings.get('enable_memori', False),
enable_thinking=settings.get('enable_thinking', False),
@ -1623,7 +1630,8 @@ async def update_bot_settings(
if request.suggestions is not None:
update_json['suggestions'] = request.suggestions
if request.dataset_ids is not None:
update_json['dataset_ids'] = request.dataset_ids
# 将数组转换为逗号分隔的字符串存储
update_json['dataset_ids'] = ','.join(request.dataset_ids) if request.dataset_ids else None
if request.system_prompt is not None:
update_json['system_prompt'] = request.system_prompt
if request.enable_memori is not None:
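
Since dataset_ids now arrives as a list in the API but is still persisted as a comma-separated string inside the bot settings JSON, here is a small sketch of that round-trip convention; the helper names are hypothetical, not part of the codebase:

# sketch: the list <-> comma-separated-string convention used for settings['dataset_ids']
def dataset_ids_to_storage(ids):
    """List -> comma-separated string (or None when empty)."""
    return ",".join(ids) if ids else None


def dataset_ids_from_storage(value):
    """Stored value -> list; tolerates None, str, or an already-parsed list."""
    if not value:
        return None
    if isinstance(value, list):
        return value
    return [part.strip() for part in value.split(",") if part.strip()]


assert dataset_ids_from_storage(dataset_ids_to_storage(["a", "b"])) == ["a", "b"]
assert dataset_ids_to_storage([]) is None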

routes/knowledge_base.py Normal file

@ -0,0 +1,369 @@
"""
Knowledge Base API 路由
通过 RAGFlow SDK 提供知识库管理功能
"""
import logging
from typing import Optional
from fastapi import APIRouter, HTTPException, Header, UploadFile, File, Query, Depends
from pydantic import BaseModel, Field
from agent.db_pool_manager import get_db_pool_manager
from utils.fastapi_utils import extract_api_key_from_auth
from repositories.ragflow_repository import RAGFlowRepository
from services.knowledge_base_service import KnowledgeBaseService
logger = logging.getLogger('app')
router = APIRouter()
# ============== 依赖注入 ==============
async def get_kb_service() -> KnowledgeBaseService:
"""获取知识库服务实例"""
return KnowledgeBaseService(RAGFlowRepository())
async def verify_user(authorization: Optional[str] = Header(None)) -> tuple:
"""
验证用户权限检查 agent_user_tokens
Returns:
tuple[str, str]: (user_id, username)
"""
from routes.bot_manager import verify_user_auth
valid, user_id, username = await verify_user_auth(authorization)
if not valid:
raise HTTPException(status_code=401, detail="Unauthorized")
return user_id, username
# ============== 数据库表初始化 ==============
async def init_knowledge_base_tables():
"""
初始化知识库相关的数据库表
"""
pool = get_db_pool_manager().pool
async with pool.connection() as conn:
async with conn.cursor() as cursor:
# 检查 user_datasets 表是否已存在
await cursor.execute("""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = 'user_datasets'
)
""")
table_exists = (await cursor.fetchone())[0]
if not table_exists:
logger.info("Creating user_datasets table")
await cursor.execute("""
CREATE TABLE user_datasets (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES agent_user(id) ON DELETE CASCADE,
dataset_id VARCHAR(255) NOT NULL,
dataset_name VARCHAR(255),
owner BOOLEAN DEFAULT TRUE,
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
UNIQUE(user_id, dataset_id)
)
""")
await cursor.execute("CREATE INDEX idx_user_datasets_user_id ON user_datasets(user_id)")
await cursor.execute("CREATE INDEX idx_user_datasets_dataset_id ON user_datasets(dataset_id)")
logger.info("user_datasets table created successfully")
await conn.commit()
logger.info("Knowledge base tables initialized successfully")
# ============== Pydantic Models ==============
class DatasetCreate(BaseModel):
"""创建数据集请求"""
name: str = Field(..., min_length=1, max_length=128, description="数据集名称")
description: Optional[str] = Field(None, max_length=500, description="描述信息")
chunk_method: str = Field(
default="naive",
description="分块方法: naive, manual, qa, table, paper, book, laws, presentation, picture, one, email, knowledge-graph"
)
class DatasetUpdate(BaseModel):
"""更新数据集请求(部分更新)"""
name: Optional[str] = Field(None, min_length=1, max_length=128)
description: Optional[str] = Field(None, max_length=500)
chunk_method: Optional[str] = None
class DatasetListResponse(BaseModel):
"""数据集列表响应(分页)"""
items: list
total: int
page: int
page_size: int
class FileListResponse(BaseModel):
"""文件列表响应(分页)"""
items: list
total: int
page: int
page_size: int
class ChunkListResponse(BaseModel):
"""切片列表响应(分页)"""
items: list
total: int
page: int
page_size: int
# ============== 数据集端点 ==============
@router.get("/datasets", response_model=DatasetListResponse)
async def list_datasets(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
user_info: tuple = Depends(verify_user),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""获取当前用户的数据集列表(支持分页和搜索)"""
user_id, username = user_info
return await kb_service.list_datasets(
user_id=user_id,
page=page,
page_size=page_size,
search=search
)
@router.post("/datasets", status_code=201)
async def create_dataset(
data: DatasetCreate,
user_info: tuple = Depends(verify_user),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""创建数据集并关联到当前用户"""
try:
user_id, username = user_info
dataset = await kb_service.create_dataset(
user_id=user_id,
name=data.name,
description=data.description,
chunk_method=data.chunk_method
)
return dataset
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Failed to create dataset: {e}")
raise HTTPException(status_code=500, detail=f"创建数据集失败: {str(e)}")
@router.get("/datasets/{dataset_id}")
async def get_dataset(
dataset_id: str,
user_info: tuple = Depends(verify_user),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""获取数据集详情(仅限自己的数据集)"""
user_id, username = user_info
dataset = await kb_service.get_dataset(dataset_id, user_id=user_id)
if not dataset:
raise HTTPException(status_code=404, detail="数据集不存在")
return dataset
@router.patch("/datasets/{dataset_id}")
async def update_dataset(
dataset_id: str,
data: DatasetUpdate,
user_info: tuple = Depends(verify_user),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""更新数据集(部分更新)"""
try:
user_id, username = user_info
# 只传递非 None 的字段
updates = data.model_dump(exclude_unset=True)
if not updates:
raise HTTPException(status_code=400, detail="没有提供要更新的字段")
dataset = await kb_service.update_dataset(dataset_id, updates, user_id=user_id)
if not dataset:
raise HTTPException(status_code=404, detail="数据集不存在")
return dataset
except HTTPException:
# re-raise the 400/404 raised above instead of collapsing them into a 500
raise
except Exception as e:
logger.error(f"Failed to update dataset: {e}")
raise HTTPException(status_code=500, detail=f"更新数据集失败: {str(e)}")
@router.delete("/datasets/{dataset_id}")
async def delete_dataset(
dataset_id: str,
user_info: tuple = Depends(verify_user),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""删除数据集"""
user_id, username = user_info
success = await kb_service.delete_dataset(dataset_id, user_id=user_id)
if not success:
raise HTTPException(status_code=404, detail="数据集不存在")
return {"success": True, "message": "数据集已删除"}
# ============== 文件端点 ==============
@router.get("/datasets/{dataset_id}/files", response_model=FileListResponse)
async def list_dataset_files(
dataset_id: str,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
user_info: tuple = Depends(verify_user),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""获取数据集内文件列表(分页,仅限自己的数据集)"""
user_id, username = user_info
return await kb_service.list_files(dataset_id, user_id=user_id, page=page, page_size=page_size)
@router.post("/datasets/{dataset_id}/files")
async def upload_file(
dataset_id: str,
file: UploadFile = File(...),
user_info: tuple = Depends(verify_user),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""
上传文件到数据集流式处理
支持的文件类型: PDF, DOCX, TXT, MD, CSV
最大文件大小: 100MB
"""
try:
user_id, username = user_info
result = await kb_service.upload_file(dataset_id, user_id=user_id, file=file)
return result
except ValueError as e:
# ownership and file-validation failures (FileValidationError is a ValueError) are client errors
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Failed to upload file: {e}")
raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
@router.delete("/datasets/{dataset_id}/files/{document_id}")
async def delete_file(
dataset_id: str,
document_id: str,
user_info: tuple = Depends(verify_user),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""删除文件"""
user_id, username = user_info
success = await kb_service.delete_file(dataset_id, document_id, user_id=user_id)
if not success:
raise HTTPException(status_code=404, detail="文件不存在")
return {"success": True}
# ============== 切片端点 ==============
@router.get("/datasets/{dataset_id}/chunks", response_model=ChunkListResponse)
async def list_chunks(
dataset_id: str,
document_id: Optional[str] = Query(None, description="文档 ID可选"),
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=200),
user_info: tuple = Depends(verify_user),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""获取数据集内切片列表(分页,仅限自己的数据集)"""
user_id, username = user_info
return await kb_service.list_chunks(
user_id=user_id,
dataset_id=dataset_id,
document_id=document_id,
page=page,
page_size=page_size
)
@router.delete("/datasets/{dataset_id}/chunks/{chunk_id}")
async def delete_chunk(
dataset_id: str,
chunk_id: str,
document_id: str = Query(..., description="文档 ID"),
user_info: tuple = Depends(verify_user),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""删除切片"""
user_id, username = user_info
success = await kb_service.delete_chunk(dataset_id, document_id, chunk_id, user_id=user_id)
if not success:
raise HTTPException(status_code=404, detail="切片不存在")
return {"success": True}
# ============== 文档解析端点 ==============
@router.post("/datasets/{dataset_id}/documents/{document_id}/parse")
async def parse_document(
dataset_id: str,
document_id: str,
user_info: tuple = Depends(verify_user),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""开始解析文档"""
try:
user_id, username = user_info
result = await kb_service.parse_document(dataset_id, document_id, user_id=user_id)
return result
except ValueError as e:
# dataset missing or not owned by this user
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.error(f"Failed to parse document: {e}")
raise HTTPException(status_code=500, detail=f"启动解析失败: {str(e)}")
@router.post("/datasets/{dataset_id}/documents/{document_id}/cancel-parse")
async def cancel_parse_document(
dataset_id: str,
document_id: str,
user_info: tuple = Depends(verify_user),
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""取消解析文档"""
try:
user_id, username = user_info
result = await kb_service.cancel_parse_document(dataset_id, document_id, user_id=user_id)
return result
except ValueError as e:
# dataset missing or not owned by this user
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.error(f"Failed to cancel parse: {e}")
raise HTTPException(status_code=500, detail=f"取消解析失败: {str(e)}")
# ============== Bot 数据集关联端点 ==============
@router.get("/bots/{bot_id}/datasets")
async def get_bot_datasets(
bot_id: str,
kb_service: KnowledgeBaseService = Depends(get_kb_service)
):
"""
获取 bot 关联的数据集 ID 列表
用于 MCP 服务器通过 bot_id 获取对应的数据集 IDs
"""
try:
dataset_ids = await kb_service.get_dataset_ids_by_bot(bot_id)
return {"dataset_ids": dataset_ids}
except Exception as e:
logger.error(f"Failed to get datasets for bot {bot_id}: {e}")
raise HTTPException(status_code=500, detail=f"获取数据集失败: {str(e)}")
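
A client-side sketch of the new endpoints, assuming the service runs on localhost:8001; the bot id and token are placeholders. The unauthenticated bots/{bot_id}/datasets lookup is the one the MCP server uses to resolve a bot's dataset ids:

# sketch: resolve a bot's dataset ids, then list the current user's datasets
import asyncio
import httpx

API_BASE = "http://localhost:8001"   # placeholder
TOKEN = "<user-token>"               # placeholder
BOT_ID = "<bot-id>"                  # placeholder


async def main():
    headers = {"Authorization": f"Bearer {TOKEN}"}
    async with httpx.AsyncClient(base_url=API_BASE) as client:
        # no auth required: used by the MCP server to map bot_id -> dataset ids
        r = await client.get(f"/api/v1/knowledge-base/bots/{BOT_ID}/datasets")
        print(r.json())  # {"dataset_ids": [...]}

        # authenticated: the current user's datasets, paginated
        r = await client.get(
            "/api/v1/knowledge-base/datasets",
            params={"page": 1, "page_size": 20},
            headers=headers,
        )
        print(r.json())


asyncio.run(main())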

services/__init__.py Normal file

@ -0,0 +1,6 @@
"""
Services package for business logic layer
"""
from .knowledge_base_service import KnowledgeBaseService
__all__ = ['KnowledgeBaseService']


@ -0,0 +1,627 @@
"""
Knowledge Base Service - 业务逻辑层
提供知识库管理的业务逻辑协调数据访问和业务规则
"""
import logging
import mimetypes
import os
from typing import Optional, List, Dict, Any
from pathlib import Path
from agent.db_pool_manager import get_db_pool_manager
from repositories.ragflow_repository import RAGFlowRepository
from utils.settings import (
RAGFLOW_MAX_UPLOAD_SIZE,
RAGFLOW_ALLOWED_EXTENSIONS,
)
logger = logging.getLogger('app')
class FileValidationError(ValueError):
"""文件验证错误(a ValueError subclass, so API routes can map it to HTTP 400)"""
pass
class KnowledgeBaseService:
"""
知识库服务类
提供知识库管理的业务逻辑
- 数据集 CRUD
- 文件上传和管理
- 文件验证
"""
def __init__(self, repository: RAGFlowRepository):
"""
初始化服务
Args:
repository: RAGFlow 数据仓储实例
"""
self.repository = repository
def _validate_file(self, filename: str, content: bytes) -> None:
"""
验证文件
Args:
filename: 文件名
content: 文件内容
Raises:
FileValidationError: 验证失败时抛出
"""
# 检查文件名
if not filename or filename == "unknown":
raise FileValidationError("无效的文件名")
# 检查路径遍历
if '..' in filename or '/' in filename or '\\' in filename:
raise FileValidationError("文件名包含非法字符")
# 检查文件扩展名(去掉点号进行比较)
ext = Path(filename).suffix.lower().lstrip('.')
if ext not in RAGFLOW_ALLOWED_EXTENSIONS:
allowed = ', '.join(RAGFLOW_ALLOWED_EXTENSIONS)
raise FileValidationError(f"不支持的文件类型: {ext}。支持的类型: {allowed}")
# 检查文件大小
file_size = len(content)
if file_size > RAGFLOW_MAX_UPLOAD_SIZE:
size_mb = file_size / (1024 * 1024)
max_mb = RAGFLOW_MAX_UPLOAD_SIZE / (1024 * 1024)
raise FileValidationError(f"文件过大: {size_mb:.1f}MB (最大 {max_mb}MB)")
# 验证 MIME 类型(使用 mimetypes 标准库)
detected_mime, _ = mimetypes.guess_type(filename)
logger.info(f"File {filename} detected as {detected_mime}")
# ============== 数据集管理 ==============
async def _check_dataset_access(self, dataset_id: str, user_id: str) -> bool:
"""
检查用户是否有权访问该数据集
Args:
dataset_id: 数据集 ID
user_id: 用户 ID
Returns:
是否有权限
"""
pool = get_db_pool_manager().pool
async with pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute("""
SELECT id FROM user_datasets
WHERE user_id = %s AND dataset_id = %s
""", (user_id, dataset_id))
return await cursor.fetchone() is not None
async def list_datasets(
self,
user_id: str,
page: int = 1,
page_size: int = 20,
search: str = None
) -> Dict[str, Any]:
"""
获取用户的数据集列表从本地数据库过滤
Args:
user_id: 用户 ID
page: 页码
page_size: 每页数量
search: 搜索关键词
Returns:
数据集列表和分页信息
"""
logger.info(f"Listing datasets for user {user_id}: page={page}, page_size={page_size}, search={search}")
pool = get_db_pool_manager().pool
# 从本地数据库获取用户的数据集 ID 列表
async with pool.connection() as conn:
async with conn.cursor() as cursor:
# 构建查询条件
where_conditions = ["user_id = %s"]
params = [user_id]
if search:
where_conditions.append("dataset_name ILIKE %s")
params.append(f"%{search}%")
where_clause = " AND ".join(where_conditions)
# 获取总数
await cursor.execute(f"""
SELECT COUNT(*) FROM user_datasets
WHERE {where_clause}
""", params)
total = (await cursor.fetchone())[0]
# 获取分页数据
offset = (page - 1) * page_size
await cursor.execute(f"""
SELECT dataset_id, dataset_name, created_at
FROM user_datasets
WHERE {where_clause}
ORDER BY created_at DESC
LIMIT %s OFFSET %s
""", params + [page_size, offset])
user_datasets = await cursor.fetchall()
if not user_datasets:
return {
"items": [],
"total": 0,
"page": page,
"page_size": page_size
}
# 获取数据集 ID 列表,从 RAGFlow 获取详情
dataset_ids = [row[0] for row in user_datasets]
dataset_names = {row[0]: row[1] for row in user_datasets}
# 从 RAGFlow 获取完整的数据集信息
ragflow_result = await self.repository.list_datasets(
page=1,
page_size=1000 # 获取所有数据集,然后在本地过滤
)
# 过滤出属于该用户的数据集
user_dataset_ids_set = set(dataset_ids)
items = []
for item in ragflow_result["items"]:
if item.get("dataset_id") in user_dataset_ids_set:
items.append(item)
return {
"items": items,
"total": total,
"page": page,
"page_size": page_size
}
async def create_dataset(
self,
user_id: str,
name: str,
description: str = None,
chunk_method: str = "naive"
) -> Dict[str, Any]:
"""
创建数据集并关联到用户
Args:
user_id: 用户 ID
name: 数据集名称
description: 描述信息
chunk_method: 分块方法
Returns:
创建的数据集信息
"""
logger.info(f"Creating dataset for user {user_id}: name={name}, chunk_method={chunk_method}")
# 验证分块方法
valid_methods = [
"naive", "manual", "qa", "table", "paper",
"book", "laws", "presentation", "picture", "one", "email", "knowledge-graph"
]
if chunk_method not in valid_methods:
raise ValueError(f"无效的分块方法: {chunk_method}。支持的方法: {', '.join(valid_methods)}")
# 先在 RAGFlow 创建数据集
result = await self.repository.create_dataset(
name=name,
description=description,
chunk_method=chunk_method,
permission="me"
)
# 记录到本地数据库
dataset_id = result.get("dataset_id")
if dataset_id:
pool = get_db_pool_manager().pool
async with pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute("""
INSERT INTO user_datasets (user_id, dataset_id, dataset_name, owner)
VALUES (%s, %s, %s, TRUE)
""", (user_id, dataset_id, name))
await conn.commit()
logger.info(f"Dataset {dataset_id} associated with user {user_id}")
return result
async def get_dataset(self, dataset_id: str, user_id: str = None) -> Optional[Dict[str, Any]]:
"""
获取数据集详情
Args:
dataset_id: 数据集 ID
user_id: 用户 ID可选用于权限验证
Returns:
数据集详情不存在或无权限返回 None
"""
logger.info(f"Getting dataset: {dataset_id} for user: {user_id}")
# 如果提供了 user_id先检查权限
if user_id:
has_access = await self._check_dataset_access(dataset_id, user_id)
if not has_access:
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
return None
return await self.repository.get_dataset(dataset_id)
async def update_dataset(
self,
dataset_id: str,
updates: Dict[str, Any],
user_id: str = None
) -> Optional[Dict[str, Any]]:
"""
更新数据集
Args:
dataset_id: 数据集 ID
updates: 要更新的字段
user_id: 用户 ID可选用于权限验证
Returns:
更新后的数据集信息
"""
logger.info(f"Updating dataset {dataset_id}: {updates}")
# 如果提供了 user_id先检查权限
if user_id:
has_access = await self._check_dataset_access(dataset_id, user_id)
if not has_access:
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
return None
result = await self.repository.update_dataset(dataset_id, **updates)
# 如果更新了名称,同步更新本地数据库
if result and user_id and 'name' in updates:
pool = get_db_pool_manager().pool
async with pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute("""
UPDATE user_datasets
SET dataset_name = %s
WHERE user_id = %s AND dataset_id = %s
""", (updates['name'], user_id, dataset_id))
await conn.commit()
return result
async def delete_dataset(self, dataset_id: str, user_id: str = None) -> bool:
"""
删除数据集
Args:
dataset_id: 数据集 ID
user_id: 用户 ID可选用于权限验证
Returns:
是否成功
"""
logger.info(f"Deleting dataset: {dataset_id}")
# 如果提供了 user_id先检查权限
if user_id:
has_access = await self._check_dataset_access(dataset_id, user_id)
if not has_access:
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
return False
# 从 RAGFlow 删除
result = await self.repository.delete_datasets([dataset_id])
# 从本地数据库删除关联记录
if result and user_id:
pool = get_db_pool_manager().pool
async with pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute("""
DELETE FROM user_datasets
WHERE user_id = %s AND dataset_id = %s
""", (user_id, dataset_id))
await conn.commit()
logger.info(f"Dataset {dataset_id} unlinked from user {user_id}")
return result
# ============== 文件管理 ==============
async def list_files(
self,
dataset_id: str,
user_id: str = None,
page: int = 1,
page_size: int = 20
) -> Dict[str, Any]:
"""
获取数据集中的文件列表
Args:
dataset_id: 数据集 ID
user_id: 用户 ID可选用于权限验证
page: 页码
page_size: 每页数量
Returns:
文件列表和分页信息
"""
logger.info(f"Listing files for dataset {dataset_id}: page={page}, page_size={page_size}")
# 如果提供了 user_id先检查权限
if user_id:
has_access = await self._check_dataset_access(dataset_id, user_id)
if not has_access:
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
raise ValueError("Dataset not found or does not belong to you")
return await self.repository.list_documents(
dataset_id=dataset_id,
page=page,
page_size=page_size
)
async def upload_file(
self,
dataset_id: str,
user_id: str = None,
file=None,
chunk_size: int = 1024 * 1024 # 1MB chunks
) -> Dict[str, Any]:
"""
上传文件到数据集流式处理
Args:
dataset_id: 数据集 ID
user_id: 用户 ID可选用于权限验证
file: FastAPI UploadFile 对象
chunk_size: 分块大小
Returns:
上传的文档信息
"""
filename = file.filename or "unknown"
logger.info(f"Uploading file {filename} to dataset {dataset_id}")
# 如果提供了 user_id先检查权限
if user_id:
has_access = await self._check_dataset_access(dataset_id, user_id)
if not has_access:
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
raise ValueError("Dataset not found or does not belong to you")
# 流式读取文件内容
content = await file.read()
# 验证文件
try:
self._validate_file(filename, content)
except FileValidationError as e:
logger.warning(f"File validation failed: {e}")
raise
# 上传到 RAGFlow
result = await self.repository.upload_document(
dataset_id=dataset_id,
file_name=filename,
file_content=content,
display_name=filename
)
logger.info(f"File {filename} uploaded successfully")
return result
async def delete_file(self, dataset_id: str, document_id: str, user_id: str = None) -> bool:
"""
删除文件
Args:
dataset_id: 数据集 ID
document_id: 文档 ID
user_id: 用户 ID可选用于权限验证
Returns:
是否成功
"""
logger.info(f"Deleting file {document_id} from dataset {dataset_id}")
# 如果提供了 user_id先检查权限
if user_id:
has_access = await self._check_dataset_access(dataset_id, user_id)
if not has_access:
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
return False
return await self.repository.delete_document(dataset_id, document_id)
# ============== 切片管理 ==============
async def list_chunks(
self,
dataset_id: str,
user_id: str = None,
document_id: str = None,
page: int = 1,
page_size: int = 50
) -> Dict[str, Any]:
"""
获取切片列表
Args:
dataset_id: 数据集 ID
user_id: 用户 ID可选用于权限验证
document_id: 文档 ID可选
page: 页码
page_size: 每页数量
Returns:
切片列表和分页信息
"""
logger.info(f"Listing chunks for dataset {dataset_id}, document {document_id}")
# 如果提供了 user_id先检查权限
if user_id:
has_access = await self._check_dataset_access(dataset_id, user_id)
if not has_access:
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
raise ValueError("Dataset not found or does not belong to you")
return await self.repository.list_chunks(
dataset_id=dataset_id,
document_id=document_id,
page=page,
page_size=page_size
)
async def delete_chunk(
self,
dataset_id: str,
document_id: str,
chunk_id: str,
user_id: str = None
) -> bool:
"""
删除切片
Args:
dataset_id: 数据集 ID
document_id: 文档 ID
chunk_id: 切片 ID
user_id: 用户 ID可选用于权限验证
Returns:
是否成功
"""
logger.info(f"Deleting chunk {chunk_id} from document {document_id}")
# 如果提供了 user_id先检查权限
if user_id:
has_access = await self._check_dataset_access(dataset_id, user_id)
if not has_access:
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
return False
return await self.repository.delete_chunk(dataset_id, document_id, chunk_id)
async def parse_document(
self,
dataset_id: str,
document_id: str,
user_id: str = None
) -> dict:
"""
开始解析文档
Args:
dataset_id: 数据集 ID
document_id: 文档 ID
user_id: 用户 ID可选用于权限验证
Returns:
操作结果
"""
logger.info(f"Parsing document {document_id} in dataset {dataset_id}")
# 如果提供了 user_id先检查权限
if user_id:
has_access = await self._check_dataset_access(dataset_id, user_id)
if not has_access:
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
raise ValueError("Dataset not found or does not belong to you")
success = await self.repository.parse_document(dataset_id, document_id)
return {"success": success, "message": "解析任务已启动"}
async def cancel_parse_document(
self,
dataset_id: str,
document_id: str,
user_id: str = None
) -> dict:
"""
取消解析文档
Args:
dataset_id: 数据集 ID
document_id: 文档 ID
user_id: 用户 ID可选用于权限验证
Returns:
操作结果
"""
logger.info(f"Cancelling parse for document {document_id} in dataset {dataset_id}")
# 如果提供了 user_id先检查权限
if user_id:
has_access = await self._check_dataset_access(dataset_id, user_id)
if not has_access:
logger.warning(f"User {user_id} has no access to dataset {dataset_id}")
raise ValueError("Dataset not found or does not belong to you")
success = await self.repository.cancel_parse_document(dataset_id, document_id)
return {"success": success, "message": "解析任务已取消"}
# ============== Bot 数据集关联管理 ==============
async def get_dataset_ids_by_bot(self, bot_id: str) -> list[str]:
"""
根据 bot_id 获取关联的数据集 ID 列表
Args:
bot_id: Bot ID (agent_bots 表中的 bot_id 字段)
Returns:
数据集 ID 列表
"""
logger.info(f"Getting dataset_ids for bot_id: {bot_id}")
pool = get_db_pool_manager().pool
async with pool.connection() as conn:
async with conn.cursor() as cursor:
# 查询 bot 的 settings 字段中的 dataset_ids
await cursor.execute("""
SELECT settings
FROM agent_bots
WHERE bot_id = %s
""", (bot_id,))
row = await cursor.fetchone()
if not row:
logger.warning(f"Bot not found: {bot_id}")
return []
settings = row[0]
# dataset_ids 在 settings 中存储为逗号分隔的字符串
dataset_ids_str = settings.get('dataset_ids') if settings else None
if not dataset_ids_str:
return []
# 如果是字符串,按逗号分割
if isinstance(dataset_ids_str, str):
dataset_ids = [ds_id.strip() for ds_id in dataset_ids_str.split(',') if ds_id.strip()]
elif isinstance(dataset_ids_str, list):
dataset_ids = dataset_ids_str
else:
dataset_ids = []
logger.info(f"Found {len(dataset_ids)} datasets for bot {bot_id}")
return dataset_ids
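
The service only needs a repository instance; passing user_id=None skips the user_datasets ownership check, so a quick script can drive it directly. A sketch under that assumption; the dataset id is a placeholder and _validate_file is called here only to illustrate the validation rules:

# sketch: exercise the service without the per-user ownership check
import asyncio

from repositories.ragflow_repository import RAGFlowRepository
from services.knowledge_base_service import KnowledgeBaseService, FileValidationError


async def demo(dataset_id: str):
    service = KnowledgeBaseService(RAGFlowRepository())

    # validation rejects disallowed extensions before anything reaches RAGFlow
    try:
        service._validate_file("payload.exe", b"x")
    except FileValidationError as e:
        print(f"rejected as expected: {e}")

    # user_id=None -> no user_datasets lookup, direct pass-through to RAGFlow
    files = await service.list_files(dataset_id, user_id=None, page=1, page_size=10)
    print(files["total"], "documents")


asyncio.run(demo("<dataset-id>"))  # placeholder id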

test_knowledge_base.sh Executable file

@ -0,0 +1,85 @@
#!/bin/bash
# 知识库 API 测试脚本
API_BASE="http://localhost:8001"
TOKEN="a21c99620a8ef61d69563afe05ccce89"
DATASET_ID="3c3c671205c911f1a37efedd444ada7f"
echo "=========================================="
echo "知识库 API 测试"
echo "=========================================="
# 1. 获取数据集列表
echo ""
echo "1. 获取数据集列表"
echo "GET /api/v1/knowledge-base/datasets"
curl --silent --request GET \
"$API_BASE/api/v1/knowledge-base/datasets" \
--header "authorization: Bearer $TOKEN" \
--header 'content-type: application/json' | python3 -m json.tool
echo ""
# 2. 获取数据集详情
echo "2. 获取数据集详情"
echo "GET /api/v1/knowledge-base/datasets/{dataset_id}"
curl --silent --request GET \
"$API_BASE/api/v1/knowledge-base/datasets/$DATASET_ID" \
--header "authorization: Bearer $TOKEN" | python3 -m json.tool
echo ""
# 3. 创建数据集
echo "3. 创建数据集"
echo "POST /api/v1/knowledge-base/datasets"
curl --silent --request POST \
"$API_BASE/api/v1/knowledge-base/datasets" \
--header "authorization: Bearer $TOKEN" \
--header 'content-type: application/json' \
--data '{
"name": "API测试知识库",
"description": "通过API创建的测试知识库",
"chunk_method": "naive"
}' | python3 -m json.tool
echo ""
# 4. 获取文件列表
echo "4. 获取文件列表"
echo "GET /api/v1/knowledge-base/datasets/{dataset_id}/files"
curl --silent --request GET \
"$API_BASE/api/v1/knowledge-base/datasets/$DATASET_ID/files" \
--header "authorization: Bearer $TOKEN" | python3 -m json.tool
echo ""
# 5. 上传文件
echo "5. 上传文件"
echo "POST /api/v1/knowledge-base/datasets/{dataset_id}/files"
echo "测试文档内容,用于文件上传测试。" > /tmp/test_doc.txt
curl --silent --request POST \
"$API_BASE/api/v1/knowledge-base/datasets/$DATASET_ID/files" \
--header "authorization: Bearer $TOKEN" \
-F "file=@/tmp/test_doc.txt" | python3 -m json.tool
echo ""
# 6. 获取切片列表
echo "6. 获取切片列表"
echo "GET /api/v1/knowledge-base/datasets/{dataset_id}/chunks"
curl --silent --request GET \
"$API_BASE/api/v1/knowledge-base/datasets/$DATASET_ID/chunks" \
--header "authorization: Bearer $TOKEN" | python3 -m json.tool
echo ""
# 7. 更新数据集
echo "7. 更新数据集"
echo "PATCH /api/v1/knowledge-base/datasets/{dataset_id}"
curl --silent --request PATCH \
"$API_BASE/api/v1/knowledge-base/datasets/$DATASET_ID" \
--header "authorization: Bearer $TOKEN" \
--header 'content-type: application/json' \
--data '{
"name": "更新后的知识库名称",
"description": "更新后的描述"
}' | python3 -m json.tool
echo ""
echo "=========================================="
echo "测试完成"
echo "=========================================="


@ -49,8 +49,8 @@ MCP_SSE_READ_TIMEOUT = int(os.getenv("MCP_SSE_READ_TIMEOUT", 300)) # SSE 读取
# PostgreSQL 连接字符串
# 格式: postgresql://user:password@host:port/database
#CHECKPOINT_DB_URL = os.getenv("CHECKPOINT_DB_URL", "postgresql://postgres:AeEGDB0b7Z5GK0E2tblt@dev-circleo-pg.celp3nik7oaq.ap-northeast-1.rds.amazonaws.com:5432/gptbase")
CHECKPOINT_DB_URL = os.getenv("CHECKPOINT_DB_URL", "postgresql://postgres:E5ACJo6zJub4QS@192.168.102.5:5432/agent_db")
CHECKPOINT_DB_URL = os.getenv("CHECKPOINT_DB_URL", "postgresql://postgres:AeEGDB0b7Z5GK0E2tblt@dev-circleo-pg.celp3nik7oaq.ap-northeast-1.rds.amazonaws.com:5432/gptbase")
#CHECKPOINT_DB_URL = os.getenv("CHECKPOINT_DB_URL", "postgresql://postgres:E5ACJo6zJub4QS@192.168.102.5:5432/agent_db")
# 连接池大小
# 同时可以持有的最大连接数
@ -81,3 +81,19 @@ MEM0_ENABLED = os.getenv("MEM0_ENABLED", "true") == "true"
MEM0_SEMANTIC_SEARCH_TOP_K = int(os.getenv("MEM0_SEMANTIC_SEARCH_TOP_K", "20"))
os.environ["OPENAI_API_KEY"] = "your_api_key"
# ============================================================
# RAGFlow Knowledge Base Configuration
# ============================================================
# RAGFlow API 配置
RAGFLOW_API_URL = os.getenv("RAGFLOW_API_URL", "http://100.77.70.35:1080")
RAGFLOW_API_KEY = os.getenv("RAGFLOW_API_KEY", "ragflow-MRqxnDnYZ1yp5kklDMIlKH4f1qezvXIngSMGPhu1AG8")
# 文件上传配置
RAGFLOW_MAX_UPLOAD_SIZE = int(os.getenv("RAGFLOW_MAX_UPLOAD_SIZE", str(100 * 1024 * 1024))) # 100MB
RAGFLOW_ALLOWED_EXTENSIONS = os.getenv("RAGFLOW_ALLOWED_EXTENSIONS", "pdf,docx,txt,md,csv").split(",")
# 性能配置
RAGFLOW_CONNECTION_TIMEOUT = int(os.getenv("RAGFLOW_CONNECTION_TIMEOUT", "30")) # 30秒
RAGFLOW_MAX_CONCURRENT_UPLOADS = int(os.getenv("RAGFLOW_MAX_CONCURRENT_UPLOADS", "5"))
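
These settings are all read from environment variables with the defaults shown above; a sketch of overriding them from code, assuming utils.settings has not been imported yet (the host and key values are placeholders):

# sketch: override RAGFlow settings via environment variables before importing utils.settings
import os

os.environ.setdefault("RAGFLOW_API_URL", "http://ragflow.internal:1080")   # placeholder host
os.environ.setdefault("RAGFLOW_API_KEY", "ragflow-<your-key>")             # placeholder key
os.environ.setdefault("RAGFLOW_MAX_UPLOAD_SIZE", str(50 * 1024 * 1024))    # tighten to 50MB
os.environ.setdefault("RAGFLOW_ALLOWED_EXTENSIONS", "pdf,md,txt")

from utils.settings import RAGFLOW_API_URL, RAGFLOW_MAX_UPLOAD_SIZE

print(RAGFLOW_API_URL, RAGFLOW_MAX_UPLOAD_SIZE)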