diff --git a/routes/bot_manager.py b/routes/bot_manager.py index b488db2..55f0126 100644 --- a/routes/bot_manager.py +++ b/routes/bot_manager.py @@ -78,6 +78,173 @@ def verify_auth(authorization: Optional[str]) -> None: ) +# ============== 用户认证辅助函数 ============== + +def hash_password(password: str) -> str: + """ + 使用 SHA256 哈希密码 + + Args: + password: 明文密码 + + Returns: + str: 哈希后的密码 + """ + return hashlib.sha256(password.encode()).hexdigest() + + +async def verify_user_auth(authorization: Optional[str]) -> tuple[bool, Optional[str], Optional[str]]: + """ + 验证用户认证 + + Args: + authorization: Authorization header 值 + + Returns: + tuple[bool, Optional[str], Optional[str]]: (是否有效, 用户ID, 用户名) + """ + provided_token = extract_api_key_from_auth(authorization) + if not provided_token: + return False, None, None + + pool = get_db_pool_manager().pool + + async with pool.connection() as conn: + async with conn.cursor() as cursor: + # 检查 token 是否有效且未过期 + await cursor.execute(""" + SELECT u.id, u.username, t.expires_at + FROM agent_user_tokens t + JOIN agent_user u ON t.user_id = u.id + WHERE t.token = %s + AND t.expires_at > NOW() + AND u.is_active = TRUE + """, (provided_token,)) + row = await cursor.fetchone() + + if not row: + return False, None, None + + return True, str(row[0]), row[1] + + +async def get_user_id_from_token(authorization: Optional[str]) -> Optional[str]: + """ + 从 token 获取用户 ID + + Args: + authorization: Authorization header 值 + + Returns: + Optional[str]: 用户 ID,无效时返回 None + """ + valid, user_id, _ = await verify_user_auth(authorization) + return user_id if valid else None + + +# ============== 权限检查辅助函数 ============== + +async def check_bot_access(bot_id: str, user_id: str, required_permission: str) -> bool: + """ + 检查用户对 Bot 的访问权限 + + Args: + bot_id: Bot UUID + user_id: 用户 UUID + required_permission: 需要的权限 ('read', 'write', 'share', 'delete') + + Returns: + bool: 是否有权限 + """ + pool = get_db_pool_manager().pool + + async with pool.connection() as conn: + async with conn.cursor() as cursor: + # 检查是否是所有者 + await cursor.execute(""" + SELECT id FROM agent_bots + WHERE id = %s AND owner_id = %s + """, (bot_id, user_id)) + if await cursor.fetchone(): + return True + + # 检查是否在分享列表中 + await cursor.execute(""" + SELECT role FROM bot_shares + WHERE bot_id = %s AND user_id = %s + """, (bot_id, user_id)) + row = await cursor.fetchone() + + if not row: + return False + + role = row[0] + + # 权限矩阵 + permissions = { + 'viewer': ['read'], + 'editor': ['read', 'write'], + 'owner': ['read', 'write', 'share', 'delete'] + } + + return required_permission in permissions.get(role, []) + + +async def is_bot_owner(bot_id: str, user_id: str) -> bool: + """ + 检查用户是否是 Bot 的所有者 + + Args: + bot_id: Bot UUID + user_id: 用户 UUID + + Returns: + bool: 是否是所有者 + """ + pool = get_db_pool_manager().pool + + async with pool.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute(""" + SELECT id FROM agent_bots + WHERE id = %s AND owner_id = %s + """, (bot_id, user_id)) + return await cursor.fetchone() is not None + + +async def get_user_bot_role(bot_id: str, user_id: str) -> Optional[str]: + """ + 获取用户在 Bot 中的角色 + + Args: + bot_id: Bot UUID + user_id: 用户 UUID + + Returns: + Optional[str]: 'owner', 'editor', 'viewer' 或 None + """ + pool = get_db_pool_manager().pool + + async with pool.connection() as conn: + async with conn.cursor() as cursor: + # 检查是否是所有者 + await cursor.execute(""" + SELECT id FROM agent_bots + WHERE id = %s AND owner_id = %s + """, (bot_id, user_id)) + if await cursor.fetchone(): + return 'owner' + + # 检查分享角色 + await cursor.execute(""" + SELECT role FROM bot_shares + WHERE bot_id = %s AND user_id = %s + """, (bot_id, user_id)) + row = await cursor.fetchone() + + return row[0] if row else None + + # ============== Pydantic Models ============== # --- Admin 登录相关 --- @@ -100,6 +267,56 @@ class AdminVerifyResponse(BaseModel): username: Optional[str] = None +# --- 用户认证相关 --- +class UserRegisterRequest(BaseModel): + """用户注册请求""" + username: str + email: Optional[str] = None + password: str + invitation_code: str # 邀请制注册需要邀请码 + + +class UserLoginRequest(BaseModel): + """用户登录请求""" + username: str + password: str + + +class UserLoginResponse(BaseModel): + """用户登录响应""" + token: str + user_id: str + username: str + email: Optional[str] = None + is_admin: bool = False + expires_at: str + + +class UserVerifyResponse(BaseModel): + """用户验证响应""" + valid: bool + user_id: Optional[str] = None + username: Optional[str] = None + is_admin: bool = False + + +class UserInfoResponse(BaseModel): + """用户信息响应""" + id: str + username: str + email: Optional[str] = None + is_admin: bool = False + created_at: str + last_login: Optional[str] = None + + +class UserSearchResponse(BaseModel): + """用户搜索响应""" + id: str + username: str + email: Optional[str] = None + + # --- 模型相关 --- class ModelCreate(BaseModel): """创建模型请求""" @@ -151,6 +368,11 @@ class BotResponse(BaseModel): id: str name: str bot_id: str + is_owner: bool = False + is_shared: bool = False + owner: Optional[dict] = None # {id, username} + role: Optional[str] = None # 'viewer', 'editor', None for owner + shared_at: Optional[str] = None created_at: str updated_at: str @@ -245,6 +467,32 @@ class MCPServerResponse(BaseModel): updated_at: str +# --- 分享相关 --- +class BotShareCreate(BaseModel): + """创建分享请求""" + user_ids: List[str] + role: str = "viewer" # 'viewer' or 'editor' + + +class BotShareResponse(BaseModel): + """分享响应""" + id: str + bot_id: str + user_id: str + username: str + email: Optional[str] = None + role: str + shared_at: str + shared_by: Optional[str] = None + + +class BotSharesListResponse(BaseModel): + """分享列表响应""" + bot_id: str + shares: List[BotShareResponse] + + + # --- 通用响应 --- class SuccessResponse(BaseModel): """通用成功响应""" @@ -254,6 +502,216 @@ class SuccessResponse(BaseModel): # ============== 数据库表初始化 ============== +async def migrate_bot_owner_and_shares(): + """ + 迁移 agent_bots 表添加 owner_id 字段,并创建相关表 + """ + pool = get_db_pool_manager().pool + + async with pool.connection() as conn: + async with conn.cursor() as cursor: + # 1. 首先创建 agent_user 表 + await cursor.execute(""" + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = 'agent_user' + ) + """) + user_table_exists = (await cursor.fetchone())[0] + + if not user_table_exists: + logger.info("Creating agent_user table") + + # 创建 agent_user 表 + await cursor.execute(""" + CREATE TABLE agent_user ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + username VARCHAR(255) UNIQUE NOT NULL, + email VARCHAR(255) UNIQUE, + password_hash VARCHAR(255) NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + last_login TIMESTAMP WITH TIME ZONE, + is_active BOOLEAN DEFAULT TRUE, + is_admin BOOLEAN DEFAULT FALSE + ) + """) + + # 创建索引 + await cursor.execute("CREATE INDEX idx_agent_user_username ON agent_user(username)") + await cursor.execute("CREATE INDEX idx_agent_user_email ON agent_user(email)") + await cursor.execute("CREATE INDEX idx_agent_user_is_active ON agent_user(is_active)") + + logger.info("agent_user table created successfully") + else: + # 为已存在的表添加 is_admin 字段 + await cursor.execute(""" + SELECT column_name + FROM information_schema.columns + WHERE table_name = 'agent_user' AND column_name = 'is_admin' + """) + has_admin_column = await cursor.fetchone() + if not has_admin_column: + logger.info("Adding is_admin column to agent_user table") + await cursor.execute(""" + ALTER TABLE agent_user + ADD COLUMN is_admin BOOLEAN DEFAULT FALSE + """) + logger.info("is_admin column added successfully") + + # 2. 创建 bot_shares 表 + await cursor.execute(""" + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = 'bot_shares' + ) + """) + shares_table_exists = (await cursor.fetchone())[0] + + if not shares_table_exists: + logger.info("Creating bot_shares table") + + await cursor.execute(""" + CREATE TABLE bot_shares ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + bot_id UUID NOT NULL REFERENCES agent_bots(id) ON DELETE CASCADE, + user_id UUID NOT NULL REFERENCES agent_user(id) ON DELETE CASCADE, + shared_by UUID NOT NULL REFERENCES agent_user(id) ON DELETE SET NULL, + role VARCHAR(50) DEFAULT 'viewer' CHECK (role IN ('viewer', 'editor')), + shared_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + UNIQUE(bot_id, user_id) + ) + """) + + await cursor.execute("CREATE INDEX idx_bot_shares_bot_id ON bot_shares(bot_id)") + await cursor.execute("CREATE INDEX idx_bot_shares_user_id ON bot_shares(user_id)") + await cursor.execute("CREATE INDEX idx_bot_shares_shared_by ON bot_shares(shared_by)") + + logger.info("bot_shares table created successfully") + + # 4. 创建 agent_user_tokens 表 + await cursor.execute(""" + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = 'agent_user_tokens' + ) + """) + tokens_table_exists = (await cursor.fetchone())[0] + + if not tokens_table_exists: + logger.info("Creating agent_user_tokens table") + + await cursor.execute(""" + CREATE TABLE agent_user_tokens ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES agent_user(id) ON DELETE CASCADE, + token VARCHAR(255) NOT NULL UNIQUE, + expires_at TIMESTAMP WITH TIME ZONE NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() + ) + """) + + await cursor.execute("CREATE INDEX idx_agent_user_tokens_token ON agent_user_tokens(token)") + await cursor.execute("CREATE INDEX idx_agent_user_tokens_user_id ON agent_user_tokens(user_id)") + await cursor.execute("CREATE INDEX idx_agent_user_tokens_expires ON agent_user_tokens(expires_at)") + + logger.info("agent_user_tokens table created successfully") + + # 5. 检查 agent_bots 表是否有 owner_id 字段 + await cursor.execute(""" + SELECT column_name + FROM information_schema.columns + WHERE table_name = 'agent_bots' AND column_name = 'owner_id' + """) + has_owner_column = await cursor.fetchone() + + if not has_owner_column: + logger.info("Adding owner_id column to agent_bots table") + + # 首先创建或更新默认 admin 用户(密码:admin123) + default_admin_id = '00000000-0000-0000-0000-000000000001' + default_admin_password_hash = hash_password('admin123') + + # 先���除可能存在的旧 admin 用户(避免用户名冲突) + await cursor.execute("DELETE FROM agent_user WHERE username = 'admin' AND id != %s", (default_admin_id,)) + + # 创建或更新 admin 用户 + await cursor.execute(""" + INSERT INTO agent_user (id, username, email, password_hash, is_active, is_admin) + VALUES (%s, 'admin', 'admin@local', %s, TRUE, TRUE) + ON CONFLICT (id) DO UPDATE SET password_hash = EXCLUDED.password_hash, is_admin = EXCLUDED.is_admin + """, (default_admin_id, default_admin_password_hash)) + + logger.info(f"Default admin user created/updated with password 'admin123'") + + # 添加 owner_id 字段,允许 NULL 以便迁移 + await cursor.execute(""" + ALTER TABLE agent_bots + ADD COLUMN owner_id UUID REFERENCES agent_user(id) ON DELETE SET NULL + """) + + # 创建索引 + await cursor.execute("CREATE INDEX idx_agent_bots_owner_id ON agent_bots(owner_id)") + + # 将现有的 bots 分配给默认 admin 用户 + await cursor.execute(""" + UPDATE agent_bots + SET owner_id = %s + WHERE owner_id IS NULL + """, (default_admin_id,)) + + logger.info("Existing bots assigned to default admin user") + + # 现在将 owner_id 改为 NOT NULL + await cursor.execute(""" + ALTER TABLE agent_bots + ALTER COLUMN owner_id SET NOT NULL + """) + + # 为了防止数据丢失,将 ON DELETE SET NULL 改为 ON DELETE RESTRICT + # 需要先删除约束再重新添加 + await cursor.execute(""" + ALTER TABLE agent_bots + DROP CONSTRAINT agent_bots_owner_id_fkey + """) + await cursor.execute(""" + ALTER TABLE agent_bots + ADD CONSTRAINT agent_bots_owner_id_fkey + FOREIGN KEY (owner_id) REFERENCES agent_user(id) ON DELETE RESTRICT + """) + + logger.info("owner_id column added and set to NOT NULL") + + # 确保默认 admin 用户存在(总是执行) + default_admin_id = '00000000-0000-0000-0000-000000000001' + default_admin_password_hash = hash_password('admin123') + + # 检查 admin 用户是否已存在 + await cursor.execute("SELECT id, password_hash FROM agent_user WHERE username = 'admin'") + admin_row = await cursor.fetchone() + + if admin_row: + existing_id, existing_hash = admin_row + # 如果密码是旧的 PLACEHOLDER,则更新 + if existing_hash == 'PLACEHOLDER' or existing_hash == '8c6976e5b5410415bde908bd4dee15dfb167a9c873fc4bb8a81f6f2ab448a918': + await cursor.execute(""" + UPDATE agent_user + SET password_hash = %s, is_admin = TRUE + WHERE username = 'admin' + """, (default_admin_password_hash,)) + logger.info("Updated existing admin user with new password 'admin123'") + else: + # 创建新的 admin 用户 + await cursor.execute(""" + INSERT INTO agent_user (id, username, email, password_hash, is_active, is_admin) + VALUES (%s, 'admin', 'admin@local', %s, TRUE, TRUE) + """, (default_admin_id, default_admin_password_hash)) + logger.info("Created new admin user with password 'admin123'") + + await conn.commit() + logger.info("Bot owner and shares migration completed successfully") + + async def migrate_bot_settings_to_jsonb(): """ 迁移 agent_bot_settings 表数据到 agent_bots.settings JSONB 字段 @@ -333,7 +791,10 @@ async def init_bot_manager_tables(): pool = get_db_pool_manager().pool # 首先执行迁移(如果需要) + # 1. Bot settings 迁移 await migrate_bot_settings_to_jsonb() + # 2. User 和 shares 迁移 + await migrate_bot_owner_and_shares() # SQL 表创建语句 tables_sql = [ @@ -694,7 +1155,7 @@ async def set_default_model(model_id: str, authorization: Optional[str] = Header @router.get("/api/v1/bots", response_model=List[BotResponse]) async def get_bots(authorization: Optional[str] = Header(None)): """ - 获取所有 Bot + 获取所有 Bot(拥有的和分享给我的) Args: authorization: Bearer token @@ -702,29 +1163,74 @@ async def get_bots(authorization: Optional[str] = Header(None)): Returns: List[BotResponse]: Bot 列表 """ - verify_auth(authorization) + # 支持管理员认证和用户认证 + admin_valid, admin_username = await verify_admin_auth(authorization) + user_valid, user_id, user_username = await verify_user_auth(authorization) + + if not admin_valid and not user_valid: + raise HTTPException( + status_code=401, + detail="Unauthorized" + ) pool = get_db_pool_manager().pool async with pool.connection() as conn: async with conn.cursor() as cursor: - await cursor.execute(""" - SELECT id, name, bot_id, created_at, updated_at - FROM agent_bots - ORDER BY created_at DESC - """) - rows = await cursor.fetchall() + if admin_valid: + # 管理员可以看到所有 Bot + await cursor.execute(""" + SELECT b.id, b.name, b.bot_id, b.created_at, b.updated_at, + u.id as owner_id, u.username as owner_username + FROM agent_bots b + LEFT JOIN agent_user u ON b.owner_id = u.id + ORDER BY b.created_at DESC + """) + rows = await cursor.fetchall() - return [ - BotResponse( - id=str(row[0]), - name=row[1], - bot_id=row[2], - created_at=datetime_to_str(row[3]), - updated_at=datetime_to_str(row[4]) - ) - for row in rows - ] + return [ + BotResponse( + id=str(row[0]), + name=row[1], + bot_id=row[2], + is_owner=True, + is_shared=False, + owner={"id": str(row[5]), "username": row[6]} if row[5] else None, + role=None, + created_at=datetime_to_str(row[3]), + updated_at=datetime_to_str(row[4]) + ) + for row in rows + ] + else: + # 用户只能看到拥有的 Bot 和分享给自己的 Bot + await cursor.execute(""" + SELECT b.id, b.name, b.bot_id, b.created_at, b.updated_at, + u.id as owner_id, u.username as owner_username, + s.role, s.shared_at + FROM agent_bots b + LEFT JOIN agent_user u ON b.owner_id = u.id + LEFT JOIN bot_shares s ON b.id = s.bot_id AND s.user_id = %s + WHERE b.owner_id = %s OR s.user_id = %s + ORDER BY b.created_at DESC + """, (user_id, user_id, user_id)) + rows = await cursor.fetchall() + + return [ + BotResponse( + id=str(row[0]), + name=row[1], + bot_id=row[2], + is_owner=(row[5] == user_id), + is_shared=(row[5] != user_id and row[7] is not None), + owner={"id": str(row[5]), "username": row[6]} if row[5] else None, + role=row[7] if row[7] else None, + shared_at=datetime_to_str(row[8]) if row[8] else None, + created_at=datetime_to_str(row[3]), + updated_at=datetime_to_str(row[4]) + ) + for row in rows + ] @router.post("/api/v1/bots", response_model=BotResponse) @@ -739,21 +1245,32 @@ async def create_bot(request: BotCreate, authorization: Optional[str] = Header(N Returns: BotResponse: 创建的 Bot 信息 """ - verify_auth(authorization) + # 支持管理员认证和用户认证 + admin_valid, admin_username = await verify_admin_auth(authorization) + user_valid, user_id, user_username = await verify_user_auth(authorization) + + if not admin_valid and not user_valid: + raise HTTPException( + status_code=401, + detail="Unauthorized" + ) pool = get_db_pool_manager().pool # 自动生成 bot_id bot_id = str(uuid.uuid4()) + # 使用用户 ID 或默认 admin ID + owner_id = user_id if user_valid else None + try: async with pool.connection() as conn: async with conn.cursor() as cursor: await cursor.execute(""" - INSERT INTO agent_bots (name, bot_id) - VALUES (%s, %s) - RETURNING id, created_at, updated_at - """, (request.name, bot_id)) + INSERT INTO agent_bots (name, bot_id, owner_id) + VALUES (%s, %s, %s) + RETURNING id, created_at, updated_at, owner_id + """, (request.name, bot_id, owner_id)) row = await cursor.fetchone() await conn.commit() @@ -761,6 +1278,9 @@ async def create_bot(request: BotCreate, authorization: Optional[str] = Header(N id=str(row[0]), name=request.name, bot_id=bot_id, + is_owner=True, + is_shared=False, + owner={"id": str(owner_id), "username": user_username} if owner_id else None, created_at=datetime_to_str(row[1]), updated_at=datetime_to_str(row[2]) ) @@ -777,7 +1297,7 @@ async def update_bot( authorization: Optional[str] = Header(None) ): """ - 更新 Bot + 更新 Bot(仅所有者可以更新) Args: bot_uuid: Bot 内部 UUID @@ -787,7 +1307,23 @@ async def update_bot( Returns: BotResponse: 更新后的 Bot 信息 """ - verify_auth(authorization) + # 支持管理员认证和用户认证 + admin_valid, admin_username = await verify_admin_auth(authorization) + user_valid, user_id, user_username = await verify_user_auth(authorization) + + if not admin_valid and not user_valid: + raise HTTPException( + status_code=401, + detail="Unauthorized" + ) + + # 非管理员需要检查所有权 + if user_valid: + if not await is_bot_owner(bot_uuid, user_id): + raise HTTPException( + status_code=403, + detail="Only bot owner can update" + ) pool = get_db_pool_manager().pool @@ -827,6 +1363,8 @@ async def update_bot( id=str(row[0]), name=row[1], bot_id=row[2], + is_owner=True, + is_shared=False, created_at=datetime_to_str(row[3]), updated_at=datetime_to_str(row[4]) ) @@ -835,7 +1373,7 @@ async def update_bot( @router.delete("/api/v1/bots/{bot_uuid}", response_model=SuccessResponse) async def delete_bot(bot_uuid: str, authorization: Optional[str] = Header(None)): """ - 删除 Bot(级联删除相关设置、会话、MCP 配置等) + 删除 Bot(仅所有者可以删除,级联删除相关设置、会话、MCP 配置等) Args: bot_uuid: Bot 内部 UUID @@ -844,7 +1382,23 @@ async def delete_bot(bot_uuid: str, authorization: Optional[str] = Header(None)) Returns: SuccessResponse: 删除结果 """ - verify_auth(authorization) + # 支持管理员认证和用户认证 + admin_valid, admin_username = await verify_admin_auth(authorization) + user_valid, user_id, user_username = await verify_user_auth(authorization) + + if not admin_valid and not user_valid: + raise HTTPException( + status_code=401, + detail="Unauthorized" + ) + + # 非管理员需要检查所有权 + if user_valid: + if not await is_bot_owner(bot_uuid, user_id): + raise HTTPException( + status_code=403, + detail="Only bot owner can delete" + ) pool = get_db_pool_manager().pool @@ -866,7 +1420,7 @@ async def delete_bot(bot_uuid: str, authorization: Optional[str] = Header(None)) @router.get("/api/v1/bots/{bot_uuid}/settings", response_model=BotSettingsResponse) async def get_bot_settings(bot_uuid: str, authorization: Optional[str] = Header(None)): """ - 获取 Bot 设置 + 获取 Bot 设置(所有者和 editor 可以查看) Args: bot_uuid: Bot 内部 UUID @@ -875,7 +1429,23 @@ async def get_bot_settings(bot_uuid: str, authorization: Optional[str] = Header( Returns: BotSettingsResponse: Bot 设置信息 """ - verify_auth(authorization) + # 支持管理员认证和用户认证 + admin_valid, admin_username = await verify_admin_auth(authorization) + user_valid, user_id, user_username = await verify_user_auth(authorization) + + if not admin_valid and not user_valid: + raise HTTPException( + status_code=401, + detail="Unauthorized" + ) + + # 用户需要检查是否有 read 权限 + if user_valid: + if not await check_bot_access(bot_uuid, user_id, 'read'): + raise HTTPException( + status_code=403, + detail="You don't have access to this bot" + ) pool = get_db_pool_manager().pool @@ -939,7 +1509,7 @@ async def update_bot_settings( authorization: Optional[str] = Header(None) ): """ - 更新 Bot 设置 + 更新 Bot 设置(仅所有者和 editor 可以更新) Args: bot_uuid: Bot 内部 UUID @@ -949,7 +1519,23 @@ async def update_bot_settings( Returns: SuccessResponse: 更新结果 """ - verify_auth(authorization) + # 支持管理员认证和用户认证 + admin_valid, admin_username = await verify_admin_auth(authorization) + user_valid, user_id, user_username = await verify_user_auth(authorization) + + if not admin_valid and not user_valid: + raise HTTPException( + status_code=401, + detail="Unauthorized" + ) + + # 用户需要检查是否有 write 权限 + if user_valid: + if not await check_bot_access(bot_uuid, user_id, 'write'): + raise HTTPException( + status_code=403, + detail="You don't have permission to modify this bot" + ) pool = get_db_pool_manager().pool @@ -1409,3 +1995,507 @@ async def admin_logout(authorization: Optional[str] = Header(None)): await conn.commit() return SuccessResponse(success=True, message="Logged out successfully") + + +# ============== 用户认证 API ============== + +@router.post("/api/v1/auth/register", response_model=UserLoginResponse) +async def user_register(request: UserRegisterRequest): + """ + 用户注册(需��邀请码) + + Args: + request: 注册请求(用户名、邮箱、密码、邀请码) + + Returns: + UserLoginResponse: 注册成功返回 token 和用户信息 + """ + pool = get_db_pool_manager().pool + + async with pool.connection() as conn: + async with conn.cursor() as cursor: + # 1. 验证邀请码(固定邀请码) + if request.invitation_code != "WELCOME2026": + raise HTTPException( + status_code=400, + detail="邀请码无效" + ) + + # 2. 检查用户名是否已存在 + await cursor.execute(""" + SELECT id FROM agent_user WHERE username = %s + """, (request.username,)) + if await cursor.fetchone(): + raise HTTPException( + status_code=400, + detail="用户名已存在" + ) + + # 3. 检查邮箱是否已存在(如果提供) + if request.email: + await cursor.execute(""" + SELECT id FROM agent_user WHERE email = %s + """, (request.email,)) + if await cursor.fetchone(): + raise HTTPException( + status_code=400, + detail="邮箱已被注册" + ) + + # 4. 创建用户 + password_hash = hash_password(request.password) + await cursor.execute(""" + INSERT INTO agent_user (username, email, password_hash) + VALUES (%s, %s, %s) + RETURNING id, created_at + """, (request.username, request.email, password_hash)) + user_id, created_at = await cursor.fetchone() + + # 6. 生成 token + token = secrets.token_urlsafe(32) + expires_at = datetime.now() + timedelta(hours=TOKEN_EXPIRE_HOURS) + + await cursor.execute(""" + INSERT INTO agent_user_tokens (user_id, token, expires_at) + VALUES (%s, %s, %s) + """, (user_id, token, expires_at)) + + await conn.commit() + + return UserLoginResponse( + token=token, + user_id=str(user_id), + username=request.username, + email=request.email, + is_admin=False, + expires_at=expires_at.isoformat() + ) + + +@router.post("/api/v1/auth/login", response_model=UserLoginResponse) +async def user_login(request: UserLoginRequest): + """ + 用户登录 + + Args: + request: 登录请求(用户名、密码) + + Returns: + UserLoginResponse: 登录成功返回 token 和用户信息 + """ + pool = get_db_pool_manager().pool + + async with pool.connection() as conn: + async with conn.cursor() as cursor: + # 1. 验证用户名和密码 + password_hash = hash_password(request.password) + await cursor.execute(""" + SELECT id, username, email, is_active, is_admin + FROM agent_user + WHERE username = %s AND password_hash = %s + """, (request.username, password_hash)) + row = await cursor.fetchone() + + if not row: + raise HTTPException( + status_code=401, + detail="用户名或密码错误" + ) + + user_id, username, email, is_active, is_admin = row + + if not is_active: + raise HTTPException( + status_code=403, + detail="账户已被禁用" + ) + + # 2. 更新最后登录时间 + await cursor.execute(""" + UPDATE agent_user + SET last_login = NOW() + WHERE id = %s + """, (user_id,)) + + # 3. 清理旧 token + await cursor.execute(""" + DELETE FROM agent_user_tokens + WHERE user_id = %s + """, (user_id,)) + + # 4. 生成新 token + token = secrets.token_urlsafe(32) + expires_at = datetime.now() + timedelta(hours=TOKEN_EXPIRE_HOURS) + + await cursor.execute(""" + INSERT INTO agent_user_tokens (user_id, token, expires_at) + VALUES (%s, %s, %s) + """, (user_id, token, expires_at)) + + await conn.commit() + + return UserLoginResponse( + token=token, + user_id=str(user_id), + username=username, + email=email, + is_admin=is_admin or False, + expires_at=expires_at.isoformat() + ) + + +@router.post("/api/v1/auth/verify", response_model=UserVerifyResponse) +async def user_verify(authorization: Optional[str] = Header(None)): + """ + 验证用户 token 是否有效 + + Args: + authorization: Bearer token + + Returns: + UserVerifyResponse: 验证结果 + """ + valid, user_id, username = await verify_user_auth(authorization) + + is_admin_flag = False + if valid and user_id: + pool = get_db_pool_manager().pool + async with pool.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute(""" + SELECT is_admin FROM agent_user WHERE id = %s + """, (user_id,)) + row = await cursor.fetchone() + if row: + is_admin_flag = row[0] or False + + return UserVerifyResponse( + valid=valid, + user_id=user_id, + username=username, + is_admin=is_admin_flag + ) + + +@router.get("/api/v1/users/me", response_model=UserInfoResponse) +async def get_current_user(authorization: Optional[str] = Header(None)): + """ + 获取当前用户信息 + + Args: + authorization: Bearer token + + Returns: + UserInfoResponse: 用户信息 + """ + valid, user_id, username = await verify_user_auth(authorization) + + if not valid: + raise HTTPException( + status_code=401, + detail="Invalid token" + ) + + pool = get_db_pool_manager().pool + + async with pool.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute(""" + SELECT id, username, email, is_admin, created_at, last_login + FROM agent_user + WHERE id = %s + """, (user_id,)) + row = await cursor.fetchone() + + if not row: + raise HTTPException( + status_code=404, + detail="User not found" + ) + + user_id, username, email, is_admin, created_at, last_login = row + + return UserInfoResponse( + id=str(user_id), + username=username, + email=email, + is_admin=is_admin or False, + created_at=created_at.isoformat() if created_at else "", + last_login=last_login.isoformat() if last_login else None + ) + + +@router.get("/api/v1/users", response_model=List[UserSearchResponse]) +async def search_users( + q: str = "", + authorization: Optional[str] = Header(None) +): + """ + 搜索用户(用于分享) + + Args: + q: 搜索关键词(用户名或邮箱) + authorization: Bearer token + + Returns: + List[UserSearchResponse]: 用户列表 + """ + # 支持管理员认证和用户认证 + admin_valid, _ = await verify_admin_auth(authorization) + user_valid, user_id, _ = await verify_user_auth(authorization) + + if not admin_valid and not user_valid: + raise HTTPException( + status_code=401, + detail="Unauthorized" + ) + + if not q: + return [] + + pool = get_db_pool_manager().pool + + async with pool.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute(""" + SELECT id, username, email + FROM agent_user + WHERE is_active = TRUE + AND (username ILIKE %s OR email ILIKE %s) + AND id != %s + ORDER BY username + LIMIT 20 + """, (f"%{q}%", f"%{q}%", user_id)) + rows = await cursor.fetchall() + + return [ + UserSearchResponse( + id=str(row[0]), + username=row[1], + email=row[2] + ) + for row in rows + ] + + +@router.post("/api/v1/auth/logout", response_model=SuccessResponse) +async def user_logout(authorization: Optional[str] = Header(None)): + """ + 用户登出(删除 token) + + Args: + authorization: Bearer token + + Returns: + SuccessResponse: 登出结果 + """ + provided_token = extract_api_key_from_auth(authorization) + if not provided_token: + raise HTTPException( + status_code=401, + detail="Authorization header is required" + ) + + pool = get_db_pool_manager().pool + + async with pool.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute("DELETE FROM agent_user_tokens WHERE token = %s", (provided_token,)) + await conn.commit() + + return SuccessResponse(success=True, message="Logged out successfully") + + +# ============== 分享管理 API ============== + +@router.get("/api/v1/bots/{bot_uuid}/shares", response_model=BotSharesListResponse) +async def get_bot_shares( + bot_uuid: str, + authorization: Optional[str] = Header(None) +): + """ + 获取 Bot 的分享列表 + + Args: + bot_uuid: Bot UUID + authorization: Bearer token + + Returns: + BotSharesListResponse: 分享列表 + """ + valid, user_id, _ = await verify_user_auth(authorization) + + if not valid: + raise HTTPException( + status_code=401, + detail="Unauthorized" + ) + + # 验证用户是 Bot 所有者 + if not await is_bot_owner(bot_uuid, user_id): + raise HTTPException( + status_code=403, + detail="Only bot owner can view shares" + ) + + pool = get_db_pool_manager().pool + + async with pool.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute(""" + SELECT s.id, s.bot_id, s.user_id, u.username, u.email, s.role, s.shared_at, su.username + FROM bot_shares s + JOIN agent_user u ON s.user_id = u.id + LEFT JOIN agent_user su ON s.shared_by = su.id + WHERE s.bot_id = %s + ORDER BY s.shared_at DESC + """, (bot_uuid,)) + rows = await cursor.fetchall() + + shares = [ + BotShareResponse( + id=str(row[0]), + bot_id=str(row[1]), + user_id=str(row[2]), + username=row[3], + email=row[4], + role=row[5], + shared_at=row[6].isoformat() if row[6] else "", + shared_by=row[7] if row[7] is None else str(row[7]) + ) + for row in rows + ] + + return BotSharesListResponse( + bot_id=bot_uuid, + shares=shares + ) + + +@router.post("/api/v1/bots/{bot_uuid}/shares", response_model=SuccessResponse) +async def add_bot_share( + bot_uuid: str, + request: BotShareCreate, + authorization: Optional[str] = Header(None) +): + """ + 添加 Bot 分享 + + Args: + bot_uuid: Bot UUID + request: 分享请求 + authorization: Bearer token + + Returns: + SuccessResponse: 操作结果 + """ + valid, user_id, _ = await verify_user_auth(authorization) + + if not valid: + raise HTTPException( + status_code=401, + detail="Unauthorized" + ) + + # 验证用户是 Bot 所有者 + if not await is_bot_owner(bot_uuid, user_id): + raise HTTPException( + status_code=403, + detail="Only bot owner can share" + ) + + pool = get_db_pool_manager().pool + + async with pool.connection() as conn: + async with conn.cursor() as cursor: + # 验证目标用户存在 + for target_user_id in request.user_ids: + await cursor.execute(""" + SELECT id FROM agent_user WHERE id = %s AND is_active = TRUE + """, (target_user_id,)) + if not await cursor.fetchone(): + raise HTTPException( + status_code=400, + detail=f"User {target_user_id} not found" + ) + + # 添加分享 + added_count = 0 + for target_user_id in request.user_ids: + try: + await cursor.execute(""" + INSERT INTO bot_shares (bot_id, user_id, shared_by, role) + VALUES (%s, %s, %s, %s) + ON CONFLICT (bot_id, user_id) DO UPDATE SET + role = EXCLUDED.role, + shared_by = EXCLUDED.shared_by + """, (bot_uuid, target_user_id, user_id, request.role)) + added_count += 1 + except Exception: + pass # 忽略重复的 + + await conn.commit() + + return SuccessResponse( + success=True, + message=f"Shared with {added_count} user(s)" + ) + + +@router.delete("/api/v1/bots/{bot_uuid}/shares/{user_id}", response_model=SuccessResponse) +async def remove_bot_share( + bot_uuid: str, + user_id: str, + authorization: Optional[str] = Header(None) +): + """ + 移除 Bot 分享 + + Args: + bot_uuid: Bot UUID + user_id: 要移除的用户 ID + authorization: Bearer token + + Returns: + SuccessResponse: 操作结果 + """ + valid, current_user_id, _ = await verify_user_auth(authorization) + + if not valid: + raise HTTPException( + status_code=401, + detail="Unauthorized" + ) + + # 验证用户是 Bot 所有者 + if not await is_bot_owner(bot_uuid, current_user_id): + raise HTTPException( + status_code=403, + detail="Only bot owner can remove shares" + ) + + pool = get_db_pool_manager().pool + + async with pool.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute(""" + DELETE FROM bot_shares + WHERE bot_id = %s AND user_id = %s + RETURNING id + """, (bot_uuid, user_id)) + row = await cursor.fetchone() + + if not row: + raise HTTPException( + status_code=404, + detail="Share not found" + ) + + await conn.commit() + + return SuccessResponse( + success=True, + message="Share removed successfully" + ) + +