execute sql

This commit is contained in:
朱潮 2026-01-23 17:44:34 +08:00
parent 44b4295a87
commit f8a44e8d6d
2 changed files with 280 additions and 1 deletions

View File

@ -71,7 +71,7 @@ from utils.log_util.logger import init_with_fastapi
logger = logging.getLogger('app') logger = logging.getLogger('app')
# Import route modules # Import route modules
from routes import chat, files, projects, system, skill_manager from routes import chat, files, projects, system, skill_manager, database
@asynccontextmanager @asynccontextmanager
@ -174,6 +174,7 @@ app.include_router(files.router)
app.include_router(projects.router) app.include_router(projects.router)
app.include_router(system.router) app.include_router(system.router)
app.include_router(skill_manager.router) app.include_router(skill_manager.router)
app.include_router(database.router)
# 注册文件管理API路由 # 注册文件管理API路由
app.include_router(file_manager_router) app.include_router(file_manager_router)

278
routes/database.py Normal file
View File

@ -0,0 +1,278 @@
"""
数据库操作 API 路由
提供数据库迁移表结构变更等功能
"""
import logging
from typing import Optional
from fastapi import APIRouter, HTTPException, Header
from pydantic import BaseModel
from agent.db_pool_manager import get_db_pool_manager
from utils.settings import MASTERKEY
from utils.fastapi_utils import extract_api_key_from_auth
logger = logging.getLogger('app')
router = APIRouter()
def verify_database_auth(authorization: Optional[str]) -> None:
"""
验证数据库操作 API 的认证
直接使用 MASTERKEY 进行验证
Args:
authorization: Authorization header
Raises:
HTTPException: 认证失败时抛出 401/403 错误
"""
# 提取提供的 token
provided_token = extract_api_key_from_auth(authorization)
if not provided_token:
raise HTTPException(
status_code=401,
detail="Authorization header is required"
)
if provided_token != MASTERKEY:
raise HTTPException(
status_code=403,
detail="Invalid authorization token"
)
class DatabaseMigrationResponse(BaseModel):
"""数据库迁移响应"""
success: bool
message: str
steps_completed: list[str]
error: Optional[str] = None
class ExecuteSQLRequest(BaseModel):
"""执行 SQL 请求"""
sql: str
autocommit: bool = True
class ExecuteSQLResponse(BaseModel):
"""执行 SQL 响应"""
success: bool
rows_affected: Optional[int] = None
message: str
error: Optional[str] = None
@router.post("/api/v1/database/migrate-pgvector", response_model=DatabaseMigrationResponse)
async def migrate_pgvector(authorization: Optional[str] = Header(None)):
"""
执行 pgvector 扩展安装迁移
执行步骤
1. public.vector 表重命名为 public.vector_legacy
2. 创建 pgvector 扩展 (CREATE EXTENSION vector)
注意此操作会修改数据库结构请确保在执行前已做好备份
Authentication:
- Authorization header should contain: Bearer {MASTERKEY}
Returns:
DatabaseMigrationResponse: 迁移结果
"""
# 验证认证
verify_database_auth(authorization)
steps_completed = []
pool_manager = get_db_pool_manager()
try:
# 获取异步连接
pool = pool_manager.pool
# 步骤 1: 重命名 vector 表为 vector_legacy
async with pool.connection() as conn:
async with conn.cursor() as cursor:
# 检查 vector 表是否存在
await cursor.execute("""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_schema = 'public' AND table_name = 'vector'
)
""")
vector_exists = (await cursor.fetchone())[0]
if vector_exists:
# 检查 vector_legacy 是否已存在
await cursor.execute("""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_schema = 'public' AND table_name = 'vector_legacy'
)
""")
legacy_exists = (await cursor.fetchone())[0]
if legacy_exists:
steps_completed.append("vector_legacy 表已存在,跳过重命名")
else:
# 执行重命名
await cursor.execute("ALTER TABLE public.vector RENAME TO vector_legacy")
steps_completed.append("已将 public.vector 表重命名为 public.vector_legacy")
logger.info("Renamed public.vector to public.vector_legacy")
else:
steps_completed.append("public.vector 表不存在,跳过重命名")
# 提交事务
await conn.commit()
# 步骤 2: 创建 pgvector 扩展
async with pool.connection() as conn:
async with conn.cursor() as cursor:
# 检查 pgvector 扩展是否已安装
await cursor.execute("""
SELECT EXISTS (
SELECT 1 FROM pg_extension WHERE extname = 'vector'
)
""")
extension_exists = (await cursor.fetchone())[0]
if extension_exists:
steps_completed.append("pgvector 扩展已安装")
else:
# 创建 pgvector 扩展
await cursor.execute("CREATE EXTENSION vector")
steps_completed.append("已成功安装 pgvector 扩展")
logger.info("Created pgvector extension")
# 提交事务
await conn.commit()
return DatabaseMigrationResponse(
success=True,
message="pgvector 迁移完成",
steps_completed=steps_completed
)
except HTTPException:
raise
except Exception as e:
logger.error(f"pgvector 迁移失败: {e}")
return DatabaseMigrationResponse(
success=False,
message="pgvector 迁移失败",
steps_completed=steps_completed,
error=str(e)
)
@router.post("/api/v1/database/execute-sql", response_model=ExecuteSQLResponse)
async def execute_sql(request: ExecuteSQLRequest, authorization: Optional[str] = Header(None)):
"""
执行自定义 SQL 语句
注意此接口具有较高权限请谨慎使用
Authentication:
- Authorization header should contain: Bearer {MASTERKEY}
Args:
request: 包含 SQL 语句和是否自动提交的请求
Returns:
ExecuteSQLResponse: 执行结果
"""
# 验证认证
verify_database_auth(authorization)
pool_manager = get_db_pool_manager()
try:
pool = pool_manager.pool
async with pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(request.sql)
rows_affected = cursor.rowcount
if request.autocommit:
await conn.commit()
return ExecuteSQLResponse(
success=True,
rows_affected=rows_affected,
message=f"SQL 执行成功,影响行数: {rows_affected}"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"SQL 执行失败: {e}")
raise HTTPException(
status_code=500,
detail=f"SQL 执行失败: {str(e)}"
)
@router.get("/api/v1/database/check-pgvector")
async def check_pgvector(authorization: Optional[str] = Header(None)):
"""
检查 pgvector 扩展安装状态
Authentication:
- Authorization header should contain: Bearer {MASTERKEY}
Returns:
pgvector 扩展的状态信息
"""
# 验证认证
verify_database_auth(authorization)
pool_manager = get_db_pool_manager()
try:
pool = pool_manager.pool
async with pool.connection() as conn:
async with conn.cursor() as cursor:
# 检查扩展是否存在
await cursor.execute("""
SELECT
extname,
extversion
FROM pg_extension
WHERE extname = 'vector'
""")
extension_result = await cursor.fetchone()
# 检查 vector 相关表
await cursor.execute("""
SELECT
table_name,
table_type
FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name LIKE '%vector%'
ORDER BY table_name
""")
tables = await cursor.fetchall()
return {
"extension_installed": extension_result is not None,
"extension_version": extension_result[1] if extension_result else None,
"vector_tables": [
{"name": row[0], "type": row[1]}
for row in tables
]
}
except HTTPException:
raise
except Exception as e:
logger.error(f"检查 pgvector 状态失败: {e}")
raise HTTPException(
status_code=500,
detail=f"检查失败: {str(e)}"
)