This commit is contained in:
朱潮 2026-02-12 00:41:03 +08:00
parent 4a8fffaf7d
commit 2fbf249a8d

View File

@ -27,61 +27,6 @@ TOKEN_EXPIRE_HOURS = 24
# ============== 认证函数 ============== # ============== 认证函数 ==============
async def verify_admin_auth(authorization: Optional[str]) -> tuple[bool, Optional[str], Optional[str]]:
"""
验证管理员认证
Args:
authorization: Authorization header
Returns:
tuple[bool, Optional[str], Optional[str]]: (是否有效, 用户名, 用户ID)
"""
provided_token = extract_api_key_from_auth(authorization)
if not provided_token:
return False, None, None
pool = get_db_pool_manager().pool
async with pool.connection() as conn:
async with conn.cursor() as cursor:
# 先检查 admin token 表
await cursor.execute("""
SELECT username, expires_at
FROM agent_admin_tokens
WHERE token = %s
AND expires_at > NOW()
""", (provided_token,))
row = await cursor.fetchone()
if row:
# admin token 有效,返回 admin 用户信息
username = row[0]
# 获取 admin 用户在 agent_user 表中的 ID
await cursor.execute("""
SELECT id FROM agent_user WHERE username = %s
""", (username,))
user_row = await cursor.fetchone()
user_id = str(user_row[0]) if user_row else None
return True, username, user_id
# 如果 admin token 无效,再检查普通用户 token
await cursor.execute("""
SELECT u.id, u.username, u.is_admin, t.expires_at
FROM agent_user_tokens t
JOIN agent_user u ON t.user_id = u.id
WHERE t.token = %s
AND t.expires_at > NOW()
AND u.is_active = TRUE
""", (provided_token,))
user_row = await cursor.fetchone()
if user_row:
return True, user_row[1], str(user_row[0])
return False, None, None
def verify_auth(authorization: Optional[str]) -> None: def verify_auth(authorization: Optional[str]) -> None:
""" """
验证请求认证 验证请求认证
@ -168,7 +113,7 @@ async def get_user_id_from_token(authorization: Optional[str]) -> Optional[str]:
async def is_admin_user(authorization: Optional[str]) -> bool: async def is_admin_user(authorization: Optional[str]) -> bool:
""" """
检查当前请求是否来自管理员admin token is_admin=True 的用户 检查当前请求是否来自管理员is_admin=True 的用户
Args: Args:
authorization: Authorization header authorization: Authorization header
@ -176,10 +121,6 @@ async def is_admin_user(authorization: Optional[str]) -> bool:
Returns: Returns:
bool: 是否是管理员 bool: 是否是管理员
""" """
admin_valid, _, admin_user_id = await verify_admin_auth(authorization)
if admin_valid:
return True
user_valid, user_id, _ = await verify_user_auth(authorization) user_valid, user_id, _ = await verify_user_auth(authorization)
if not user_valid or not user_id: if not user_valid or not user_id:
return False return False
@ -1240,21 +1181,27 @@ async def get_bots(authorization: Optional[str] = Header(None)):
Returns: Returns:
List[BotResponse]: Bot 列表 List[BotResponse]: Bot 列表
""" """
# 支持管理员认证和用户认证
admin_valid, admin_username, admin_user_id = await verify_admin_auth(authorization)
user_valid, user_id, user_username = await verify_user_auth(authorization) user_valid, user_id, user_username = await verify_user_auth(authorization)
if not admin_valid and not user_valid: if not user_valid:
raise HTTPException( raise HTTPException(
status_code=401, status_code=401,
detail="Unauthorized" detail="Unauthorized"
) )
# 检查是否是管理员
pool = get_db_pool_manager().pool pool = get_db_pool_manager().pool
async with pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute("""
SELECT is_admin FROM agent_user WHERE id = %s
""", (user_id,))
row = await cursor.fetchone()
is_admin = row and row[0]
async with pool.connection() as conn: async with pool.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
if admin_valid: if is_admin:
# 管理员可以看到所有 Bot # 管理员可以看到所有 Bot
await cursor.execute(""" await cursor.execute("""
SELECT b.id, b.name, b.bot_id, b.created_at, b.updated_at, b.settings, SELECT b.id, b.name, b.bot_id, b.created_at, b.updated_at, b.settings,
@ -1283,15 +1230,18 @@ async def get_bots(authorization: Optional[str] = Header(None)):
] ]
else: else:
# 用户只能看到拥有的 Bot 和分享给自己的 Bot且未过期 # 用户只能看到拥有的 Bot 和分享给自己的 Bot且未过期
# 使用子查询确保正确过滤,避免 LEFT JOIN 导致的 NULL 值比较问题
await cursor.execute(""" await cursor.execute("""
SELECT b.id, b.name, b.bot_id, b.created_at, b.updated_at, b.settings, SELECT DISTINCT b.id, b.name, b.bot_id, b.created_at, b.updated_at, b.settings,
u.id as owner_id, u.username as owner_username, u.id as owner_id, u.username as owner_username,
s.role, s.shared_at, s.expires_at s.role, s.shared_at, s.expires_at
FROM agent_bots b FROM agent_bots b
LEFT JOIN agent_user u ON b.owner_id = u.id LEFT JOIN agent_user u ON b.owner_id = u.id
LEFT JOIN bot_shares s ON b.id = s.bot_id AND s.user_id = %s LEFT JOIN bot_shares s ON b.id = s.bot_id AND s.user_id = %s
WHERE b.owner_id = %s WHERE b.owner_id = %s
OR (s.user_id = %s AND (s.expires_at IS NULL OR s.expires_at > NOW())) OR (s.user_id IS NOT NULL
AND s.user_id = %s
AND (s.expires_at IS NULL OR s.expires_at > NOW()))
ORDER BY b.created_at DESC ORDER BY b.created_at DESC
""", (user_id, user_id, user_id)) """, (user_id, user_id, user_id))
rows = await cursor.fetchall() rows = await cursor.fetchall()
@ -1301,12 +1251,12 @@ async def get_bots(authorization: Optional[str] = Header(None)):
id=str(row[0]), # 使用 UUID 主键 id=str(row[0]), # 使用 UUID 主键
name=row[1], name=row[1],
bot_id=str(row[0]), # bot_id 也指向主键 id bot_id=str(row[0]), # bot_id 也指向主键 id
is_owner=(str(row[6]) == user_id if row[6] else False), is_owner=(row[6] is not None and str(row[6]) == user_id),
is_shared=(str(row[6]) != user_id and row[8] is not None) if row[6] else False, is_shared=(row[6] is not None and str(row[6]) != user_id and row[8] is not None),
owner={"id": str(row[6]), "username": row[7]} if row[6] else None, owner={"id": str(row[6]), "username": row[7]} if row[6] is not None else None,
role=row[8] if row[8] else None, role=row[8] if row[8] is not None else None,
shared_at=datetime_to_str(row[9]) if row[9] else None, shared_at=datetime_to_str(row[9]) if row[9] is not None else None,
expires_at=row[10].isoformat() if row[10] else None, expires_at=row[10].isoformat() if row[10] is not None else None,
description=row[5].get('description') if row[5] else None, description=row[5].get('description') if row[5] else None,
avatar_url=row[5].get('avatar_url') if row[5] else None, avatar_url=row[5].get('avatar_url') if row[5] else None,
created_at=datetime_to_str(row[3]), created_at=datetime_to_str(row[3]),
@ -1328,11 +1278,9 @@ async def create_bot(request: BotCreate, authorization: Optional[str] = Header(N
Returns: Returns:
BotResponse: 创建的 Bot 信息 BotResponse: 创建的 Bot 信息
""" """
# 支持管理员认证和用户认证
admin_valid, admin_username, admin_user_id = await verify_admin_auth(authorization)
user_valid, user_id, user_username = await verify_user_auth(authorization) user_valid, user_id, user_username = await verify_user_auth(authorization)
if not admin_valid and not user_valid: if not user_valid:
raise HTTPException( raise HTTPException(
status_code=401, status_code=401,
detail="Unauthorized" detail="Unauthorized"
@ -1343,8 +1291,8 @@ async def create_bot(request: BotCreate, authorization: Optional[str] = Header(N
# 自动生成 bot_id # 自动生成 bot_id
bot_id = str(uuid.uuid4()) bot_id = str(uuid.uuid4())
# 使用用户 ID 或默认 admin ID # 使用用户 ID
owner_id = user_id if user_valid else None owner_id = user_id
try: try:
async with pool.connection() as conn: async with pool.connection() as conn:
@ -1390,18 +1338,19 @@ async def update_bot(
Returns: Returns:
BotResponse: 更新后的 Bot 信息 BotResponse: 更新后的 Bot 信息
""" """
# 支持管理员认证和用户认证
admin_valid, admin_username, admin_user_id = await verify_admin_auth(authorization)
user_valid, user_id, user_username = await verify_user_auth(authorization) user_valid, user_id, user_username = await verify_user_auth(authorization)
if not admin_valid and not user_valid: if not user_valid:
raise HTTPException( raise HTTPException(
status_code=401, status_code=401,
detail="Unauthorized" detail="Unauthorized"
) )
# 检查是否是管理员
is_admin = await is_admin_user(authorization)
# 非管理员需要检查所有权 # 非管理员需要检查所有权
if user_valid: if not is_admin:
if not await is_bot_owner(bot_uuid, user_id): if not await is_bot_owner(bot_uuid, user_id):
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
@ -1465,18 +1414,19 @@ async def delete_bot(bot_uuid: str, authorization: Optional[str] = Header(None))
Returns: Returns:
SuccessResponse: 删除结果 SuccessResponse: 删除结果
""" """
# 支持管理员认证和用户认证
admin_valid, admin_username, admin_user_id = await verify_admin_auth(authorization)
user_valid, user_id, user_username = await verify_user_auth(authorization) user_valid, user_id, user_username = await verify_user_auth(authorization)
if not admin_valid and not user_valid: if not user_valid:
raise HTTPException( raise HTTPException(
status_code=401, status_code=401,
detail="Unauthorized" detail="Unauthorized"
) )
# 检查是否是管理员
is_admin = await is_admin_user(authorization)
# 非管理员需要检查所有权 # 非管理员需要检查所有权
if user_valid: if not is_admin:
if not await is_bot_owner(bot_uuid, user_id): if not await is_bot_owner(bot_uuid, user_id):
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
@ -1512,21 +1462,19 @@ async def get_bot_settings(bot_uuid: str, authorization: Optional[str] = Header(
Returns: Returns:
BotSettingsResponse: Bot 设置信息 BotSettingsResponse: Bot 设置信息
""" """
# 支持管理员认证和用户认证
admin_valid, admin_username, admin_user_id = await verify_admin_auth(authorization)
user_valid, user_id, user_username = await verify_user_auth(authorization) user_valid, user_id, user_username = await verify_user_auth(authorization)
if not admin_valid and not user_valid: if not user_valid:
raise HTTPException( raise HTTPException(
status_code=401, status_code=401,
detail="Unauthorized" detail="Unauthorized"
) )
# 获取实际的用户ID优先使用 admin 的 user_id # 检查是否是管理员
actual_user_id = admin_user_id if admin_user_id else user_id is_admin = await is_admin_user(authorization)
# 如果是普通用户(非 admin检查是否有 read 权限 # 如果是普通用户(非 admin检查是否有 read 权限
if user_valid and not admin_user_id: if not is_admin:
if not await check_bot_access(bot_uuid, user_id, 'read'): if not await check_bot_access(bot_uuid, user_id, 'read'):
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
@ -1611,18 +1559,19 @@ async def update_bot_settings(
Returns: Returns:
SuccessResponse: 更新结果 SuccessResponse: 更新结果
""" """
# 支持管理员认证和用户认证
admin_valid, admin_username, admin_user_id = await verify_admin_auth(authorization)
user_valid, user_id, user_username = await verify_user_auth(authorization) user_valid, user_id, user_username = await verify_user_auth(authorization)
if not admin_valid and not user_valid: if not user_valid:
raise HTTPException( raise HTTPException(
status_code=401, status_code=401,
detail="Unauthorized" detail="Unauthorized"
) )
# 检查是否是管理员
is_admin = await is_admin_user(authorization)
# 用户需要检查是否有 write 权限 # 用户需要检查是否有 write 权限
if user_valid: if not is_admin:
if not await check_bot_access(bot_uuid, user_id, 'write'): if not await check_bot_access(bot_uuid, user_id, 'write'):
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
@ -2051,11 +2000,12 @@ async def admin_verify(authorization: Optional[str] = Header(None)):
Returns: Returns:
AdminVerifyResponse: 验证结果 AdminVerifyResponse: 验证结果
""" """
valid, username, admin_user_id = await verify_admin_auth(authorization) is_admin = await is_admin_user(authorization)
user_valid, _, username = await verify_user_auth(authorization)
return AdminVerifyResponse( return AdminVerifyResponse(
valid=valid, valid=is_admin,
username=username username=username if is_admin else None
) )
@ -2330,11 +2280,9 @@ async def search_users(
Returns: Returns:
List[UserSearchResponse]: 用户列表 List[UserSearchResponse]: 用户列表
""" """
# 支持管理员认证<E8AEA4><E8AF81>用户认证
admin_valid, _, admin_user_id = await verify_admin_auth(authorization)
user_valid, user_id, _ = await verify_user_auth(authorization) user_valid, user_id, _ = await verify_user_auth(authorization)
if not admin_valid and not user_valid: if not user_valid:
raise HTTPException( raise HTTPException(
status_code=401, status_code=401,
detail="Unauthorized" detail="Unauthorized"