From f8a44e8d6d0b40c82fc92cdd11cb482287019b92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?= Date: Fri, 23 Jan 2026 17:44:34 +0800 Subject: [PATCH] execute sql --- fastapi_app.py | 3 +- routes/database.py | 278 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 280 insertions(+), 1 deletion(-) create mode 100644 routes/database.py diff --git a/fastapi_app.py b/fastapi_app.py index aadea2d..5e88bb7 100644 --- a/fastapi_app.py +++ b/fastapi_app.py @@ -71,7 +71,7 @@ from utils.log_util.logger import init_with_fastapi logger = logging.getLogger('app') # Import route modules -from routes import chat, files, projects, system, skill_manager +from routes import chat, files, projects, system, skill_manager, database @asynccontextmanager @@ -174,6 +174,7 @@ app.include_router(files.router) app.include_router(projects.router) app.include_router(system.router) app.include_router(skill_manager.router) +app.include_router(database.router) # 注册文件管理API路由 app.include_router(file_manager_router) diff --git a/routes/database.py b/routes/database.py new file mode 100644 index 0000000..ab8eeda --- /dev/null +++ b/routes/database.py @@ -0,0 +1,278 @@ +""" +数据库操作 API 路由 +提供数据库迁移、表结构变更等功能 +""" +import logging +from typing import Optional +from fastapi import APIRouter, HTTPException, Header +from pydantic import BaseModel + +from agent.db_pool_manager import get_db_pool_manager +from utils.settings import MASTERKEY +from utils.fastapi_utils import extract_api_key_from_auth + +logger = logging.getLogger('app') + +router = APIRouter() + + +def verify_database_auth(authorization: Optional[str]) -> None: + """ + 验证数据库操作 API 的认证 + + 直接使用 MASTERKEY 进行验证 + + Args: + authorization: Authorization header 值 + + Raises: + HTTPException: 认证失败时抛出 401/403 错误 + """ + # 提取提供的 token + provided_token = extract_api_key_from_auth(authorization) + + if not provided_token: + raise HTTPException( + status_code=401, + detail="Authorization header is required" + ) + + if provided_token != MASTERKEY: + raise HTTPException( + status_code=403, + detail="Invalid authorization token" + ) + + +class DatabaseMigrationResponse(BaseModel): + """数据库迁移响应""" + success: bool + message: str + steps_completed: list[str] + error: Optional[str] = None + + +class ExecuteSQLRequest(BaseModel): + """执行 SQL 请求""" + sql: str + autocommit: bool = True + + +class ExecuteSQLResponse(BaseModel): + """执行 SQL 响应""" + success: bool + rows_affected: Optional[int] = None + message: str + error: Optional[str] = None + + +@router.post("/api/v1/database/migrate-pgvector", response_model=DatabaseMigrationResponse) +async def migrate_pgvector(authorization: Optional[str] = Header(None)): + """ + 执行 pgvector 扩展安装迁移 + + 执行步骤: + 1. 将 public.vector 表重命名为 public.vector_legacy + 2. 创建 pgvector 扩展 (CREATE EXTENSION vector) + + 注意:此操作会修改数据库结构,请确保在执行前已做好备份。 + + Authentication: + - Authorization header should contain: Bearer {MASTERKEY} + + Returns: + DatabaseMigrationResponse: 迁移结果 + """ + # 验证认证 + verify_database_auth(authorization) + + steps_completed = [] + pool_manager = get_db_pool_manager() + + try: + # 获取异步连接 + pool = pool_manager.pool + + # 步骤 1: 重命名 vector 表为 vector_legacy + async with pool.connection() as conn: + async with conn.cursor() as cursor: + # 检查 vector 表是否存在 + await cursor.execute(""" + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_schema = 'public' AND table_name = 'vector' + ) + """) + vector_exists = (await cursor.fetchone())[0] + + if vector_exists: + # 检查 vector_legacy 是否已存在 + await cursor.execute(""" + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_schema = 'public' AND table_name = 'vector_legacy' + ) + """) + legacy_exists = (await cursor.fetchone())[0] + + if legacy_exists: + steps_completed.append("vector_legacy 表已存在,跳过重命名") + else: + # 执行重命名 + await cursor.execute("ALTER TABLE public.vector RENAME TO vector_legacy") + steps_completed.append("已将 public.vector 表重命名为 public.vector_legacy") + logger.info("Renamed public.vector to public.vector_legacy") + else: + steps_completed.append("public.vector 表不存在,跳过重命名") + + # 提交事务 + await conn.commit() + + # 步骤 2: 创建 pgvector 扩展 + async with pool.connection() as conn: + async with conn.cursor() as cursor: + # 检查 pgvector 扩展是否已安装 + await cursor.execute(""" + SELECT EXISTS ( + SELECT 1 FROM pg_extension WHERE extname = 'vector' + ) + """) + extension_exists = (await cursor.fetchone())[0] + + if extension_exists: + steps_completed.append("pgvector 扩展已安装") + else: + # 创建 pgvector 扩展 + await cursor.execute("CREATE EXTENSION vector") + steps_completed.append("已成功安装 pgvector 扩展") + logger.info("Created pgvector extension") + + # 提交事务 + await conn.commit() + + return DatabaseMigrationResponse( + success=True, + message="pgvector 迁移完成", + steps_completed=steps_completed + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"pgvector 迁移失败: {e}") + return DatabaseMigrationResponse( + success=False, + message="pgvector 迁移失败", + steps_completed=steps_completed, + error=str(e) + ) + + +@router.post("/api/v1/database/execute-sql", response_model=ExecuteSQLResponse) +async def execute_sql(request: ExecuteSQLRequest, authorization: Optional[str] = Header(None)): + """ + 执行自定义 SQL 语句 + + 注意:此接口具有较高权限,请谨慎使用。 + + Authentication: + - Authorization header should contain: Bearer {MASTERKEY} + + Args: + request: 包含 SQL 语句和是否自动提交的请求 + + Returns: + ExecuteSQLResponse: 执行结果 + """ + # 验证认证 + verify_database_auth(authorization) + + pool_manager = get_db_pool_manager() + + try: + pool = pool_manager.pool + + async with pool.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute(request.sql) + rows_affected = cursor.rowcount + + if request.autocommit: + await conn.commit() + + return ExecuteSQLResponse( + success=True, + rows_affected=rows_affected, + message=f"SQL 执行成功,影响行数: {rows_affected}" + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"SQL 执行失败: {e}") + raise HTTPException( + status_code=500, + detail=f"SQL 执行失败: {str(e)}" + ) + + +@router.get("/api/v1/database/check-pgvector") +async def check_pgvector(authorization: Optional[str] = Header(None)): + """ + 检查 pgvector 扩展安装状态 + + Authentication: + - Authorization header should contain: Bearer {MASTERKEY} + + Returns: + pgvector 扩展的状态信息 + """ + # 验证认证 + verify_database_auth(authorization) + + pool_manager = get_db_pool_manager() + + try: + pool = pool_manager.pool + + async with pool.connection() as conn: + async with conn.cursor() as cursor: + # 检查扩展是否存在 + await cursor.execute(""" + SELECT + extname, + extversion + FROM pg_extension + WHERE extname = 'vector' + """) + extension_result = await cursor.fetchone() + + # 检查 vector 相关表 + await cursor.execute(""" + SELECT + table_name, + table_type + FROM information_schema.tables + WHERE table_schema = 'public' + AND table_name LIKE '%vector%' + ORDER BY table_name + """) + tables = await cursor.fetchall() + + return { + "extension_installed": extension_result is not None, + "extension_version": extension_result[1] if extension_result else None, + "vector_tables": [ + {"name": row[0], "type": row[1]} + for row in tables + ] + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"检查 pgvector 状态失败: {e}") + raise HTTPException( + status_code=500, + detail=f"检查失败: {str(e)}" + )