3014 lines
95 KiB
Python
3014 lines
95 KiB
Python
"""
|
||
Bot Manager API 路由
|
||
提供模型配置、Bot 管理、设置管理、MCP 服务器等功能的 API
|
||
"""
|
||
import hashlib
import json
import logging
import os
import secrets
import uuid
from datetime import datetime, timedelta
from typing import Optional, List

from fastapi import APIRouter, HTTPException, Header
from pydantic import BaseModel

from agent.db_pool_manager import get_db_pool_manager
from utils.fastapi_utils import extract_api_key_from_auth
|
||
|
||
# Module logger; note it uses the shared 'app' logger name rather than __name__.
logger = logging.getLogger('app')

# Router on which every endpoint in this module is registered via @router.* decorators.
router = APIRouter()
|
||
|
||
# ============== Admin configuration ==============

# Built-in admin credentials. They are read from the environment so production
# deployments do not have to ship the hard-coded fallbacks below, which are
# only suitable for local development.
ADMIN_USERNAME = os.environ.get("ADMIN_USERNAME", "admin")
ADMIN_PASSWORD = os.environ.get("ADMIN_PASSWORD", "Admin123")

# Login token lifetime in hours (presumably applied when admin tokens are
# issued — TODO confirm against the login handler).
TOKEN_EXPIRE_HOURS = int(os.environ.get("ADMIN_TOKEN_EXPIRE_HOURS", "24"))
|
||
|
||
|
||
# ============== Authentication functions ==============

async def verify_admin_auth(authorization: Optional[str]) -> tuple[bool, Optional[str]]:
    """
    Validate an admin bearer token against the ``agent_admin_tokens`` table.

    Args:
        authorization: Raw Authorization header value (may be None).

    Returns:
        tuple[bool, Optional[str]]: ``(True, username)`` when the token exists
        and its ``expires_at`` lies in the future, otherwise ``(False, None)``.
    """
    token = extract_api_key_from_auth(authorization)
    if not token:
        return False, None

    db_pool = get_db_pool_manager().pool
    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            # Expired tokens are filtered out by the database clock.
            await cur.execute("""
                SELECT username, expires_at
                FROM agent_admin_tokens
                WHERE token = %s
                AND expires_at > NOW()
            """, (token,))
            match = await cur.fetchone()

    return (True, match[0]) if match else (False, None)
|
||
|
||
|
||
def verify_auth(authorization: Optional[str]) -> None:
    """
    Require the request to carry a bearer token.

    NOTE(review): this only checks that a token can be extracted from the
    header; it does not look it up in ``agent_admin_tokens`` or
    ``agent_user_tokens``, so any non-empty token passes. Endpoints that need
    real validation should use ``verify_admin_auth`` / ``verify_user_auth``.

    Args:
        authorization: Raw Authorization header value.

    Raises:
        HTTPException: 401 when no token is present.
    """
    if extract_api_key_from_auth(authorization):
        return
    raise HTTPException(
        status_code=401,
        detail="Authorization header is required"
    )
|
||
|
||
|
||
# ============== User authentication helpers ==============

def hash_password(password: str) -> str:
    """
    Hash a password with SHA-256.

    NOTE(review): an unsalted fast hash is weak for password storage. Moving to
    a salted KDF would require migrating the hashes already stored in
    ``agent_user.password_hash``, so the algorithm is intentionally unchanged.

    Args:
        password: Plain-text password (encoded as UTF-8 before hashing).

    Returns:
        str: Lowercase hexadecimal SHA-256 digest (64 characters).
    """
    hasher = hashlib.sha256()
    hasher.update(password.encode())
    return hasher.hexdigest()
|
||
|
||
|
||
async def verify_user_auth(authorization: Optional[str]) -> tuple[bool, Optional[str], Optional[str]]:
    """
    Validate a regular user's bearer token.

    The token must exist in ``agent_user_tokens``, must not be expired, and
    must belong to a user whose ``is_active`` flag is TRUE.

    Args:
        authorization: Raw Authorization header value (may be None).

    Returns:
        tuple[bool, Optional[str], Optional[str]]: ``(True, user_id, username)``
        on success (user_id converted to str), otherwise ``(False, None, None)``.
    """
    rejected = (False, None, None)

    token = extract_api_key_from_auth(authorization)
    if not token:
        return rejected

    db_pool = get_db_pool_manager().pool
    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute("""
                SELECT u.id, u.username, t.expires_at
                FROM agent_user_tokens t
                JOIN agent_user u ON t.user_id = u.id
                WHERE t.token = %s
                AND t.expires_at > NOW()
                AND u.is_active = TRUE
            """, (token,))
            match = await cur.fetchone()

    if not match:
        return rejected
    user_uuid, username = match[0], match[1]
    return True, str(user_uuid), username
|
||
|
||
|
||
async def get_user_id_from_token(authorization: Optional[str]) -> Optional[str]:
    """
    Resolve the user ID behind a user bearer token.

    Args:
        authorization: Raw Authorization header value.

    Returns:
        Optional[str]: The user ID when ``verify_user_auth`` accepts the token,
        otherwise None.
    """
    is_valid, user_id, _username = await verify_user_auth(authorization)
    if not is_valid:
        return None
    return user_id
|
||
|
||
|
||
# ============== Permission check helpers ==============

async def is_admin_user(authorization: Optional[str]) -> bool:
    """
    Tell whether the request comes from an administrator.

    An administrator is either the holder of a valid admin token
    (``agent_admin_tokens``) or a user accepted by ``verify_user_auth`` whose
    ``agent_user.is_admin`` flag is TRUE.

    Args:
        authorization: Raw Authorization header value.

    Returns:
        bool: True for administrators, False otherwise. Always a real bool,
        even when the user row is missing or ``is_admin`` is NULL.
    """
    admin_valid, _ = await verify_admin_auth(authorization)
    if admin_valid:
        return True

    user_valid, user_id, _ = await verify_user_auth(authorization)
    if not user_valid or not user_id:
        return False

    pool = get_db_pool_manager().pool
    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            await cursor.execute(
                "SELECT is_admin FROM agent_user WHERE id = %s",
                (user_id,),
            )
            row = await cursor.fetchone()

    # `row and row[0]` alone would return None for a missing row or a NULL
    # column; coerce so callers always receive the declared bool.
    return bool(row and row[0])
|
||
|
||
|
||
async def check_bot_access(bot_id: str, user_id: str, required_permission: str) -> bool:
    """
    Check whether a user holds a given permission on a bot.

    The bot owner (``agent_bots.owner_id``) has every permission. Any other
    user needs a ``bot_shares`` row that has not expired; the share's role then
    maps to permissions: viewer -> read; editor -> read, write.

    Args:
        bot_id: Bot UUID.
        user_id: User UUID.
        required_permission: One of 'read', 'write', 'share', 'delete'.

    Returns:
        bool: True if the permission is granted.
    """
    pool = get_db_pool_manager().pool

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            # The owner has every permission.
            await cursor.execute("""
                SELECT id FROM agent_bots
                WHERE id = %s AND owner_id = %s
            """, (bot_id, user_id))
            if await cursor.fetchone():
                return True

            # Otherwise look for a share that has not expired. The expiry is
            # checked by the database (like the token checks in this module),
            # which avoids comparing naive and timezone-aware datetimes here.
            await cursor.execute("""
                SELECT role FROM bot_shares
                WHERE bot_id = %s AND user_id = %s
                AND (expires_at IS NULL OR expires_at > NOW())
            """, (bot_id, user_id))
            row = await cursor.fetchone()

    if not row:
        return False

    # Role -> permission matrix. bot_shares' CHECK constraint only allows
    # 'viewer' and 'editor'; the 'owner' entry is kept as a defensive default.
    permissions = {
        'viewer': ('read',),
        'editor': ('read', 'write'),
        'owner': ('read', 'write', 'share', 'delete'),
    }
    return required_permission in permissions.get(row[0], ())
|
||
|
||
|
||
async def is_bot_owner(bot_id: str, user_id: str) -> bool:
    """
    Tell whether ``user_id`` is the owner recorded in ``agent_bots.owner_id``.

    Args:
        bot_id: Bot UUID.
        user_id: User UUID.

    Returns:
        bool: True if the bot exists and is owned by the user.
    """
    db_pool = get_db_pool_manager().pool
    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute("""
                SELECT id FROM agent_bots
                WHERE id = %s AND owner_id = %s
            """, (bot_id, user_id))
            owned_row = await cur.fetchone()
    return owned_row is not None
|
||
|
||
|
||
async def get_user_bot_role(bot_id: str, user_id: str) -> Optional[str]:
    """
    Determine the user's role on a bot.

    Expired shares are ignored so the result agrees with ``check_bot_access``,
    which denies access once a share has expired.

    Args:
        bot_id: Bot UUID.
        user_id: User UUID.

    Returns:
        Optional[str]: 'owner' if the user owns the bot, otherwise the role of
        a current (non-expired) share such as 'editor' or 'viewer', or None
        when the user has no current access.
    """
    pool = get_db_pool_manager().pool

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            # Ownership takes precedence over any share.
            await cursor.execute("""
                SELECT id FROM agent_bots
                WHERE id = %s AND owner_id = %s
            """, (bot_id, user_id))
            if await cursor.fetchone():
                return 'owner'

            # Role granted by a share that is still valid.
            await cursor.execute("""
                SELECT role FROM bot_shares
                WHERE bot_id = %s AND user_id = %s
                AND (expires_at IS NULL OR expires_at > NOW())
            """, (bot_id, user_id))
            row = await cursor.fetchone()

    return row[0] if row else None
|
||
|
||
|
||
# ============== Pydantic Models ==============

# --- Admin login ---
class AdminLoginRequest(BaseModel):
    """Admin login request body."""
    username: str
    password: str


class AdminLoginResponse(BaseModel):
    """Admin login response."""
    token: str
    username: str
    expires_at: str  # token expiry as a string (presumably ISO 8601 — TODO confirm in the login handler)


class AdminVerifyResponse(BaseModel):
    """Admin token verification response."""
    valid: bool
    username: Optional[str] = None
|
||
|
||
|
||
# --- User authentication ---
class UserRegisterRequest(BaseModel):
    """User registration request body."""
    username: str
    email: Optional[str] = None
    password: str
    invitation_code: str  # registration is invitation-only, so a code is required


class UserLoginRequest(BaseModel):
    """User login request body."""
    username: str
    password: str


class UserLoginResponse(BaseModel):
    """User login response."""
    token: str
    user_id: str
    username: str
    email: Optional[str] = None
    is_admin: bool = False
    expires_at: str


class UserVerifyResponse(BaseModel):
    """User token verification response."""
    valid: bool
    user_id: Optional[str] = None
    username: Optional[str] = None
    is_admin: bool = False


class UserInfoResponse(BaseModel):
    """User profile information response."""
    id: str
    username: str
    email: Optional[str] = None
    is_admin: bool = False
    created_at: str
    last_login: Optional[str] = None


class UserSearchResponse(BaseModel):
    """User search response."""
    id: str
    username: str
    email: Optional[str] = None
|
||
|
||
|
||
# --- Models ---
class ModelCreate(BaseModel):
    """Request body for creating a model configuration."""
    name: str
    provider: str
    model: str
    server: Optional[str] = None
    api_key: Optional[str] = None
    is_default: bool = False


class ModelUpdate(BaseModel):
    """Request body for updating a model configuration; fields left as None are not changed."""
    name: Optional[str] = None
    provider: Optional[str] = None
    model: Optional[str] = None
    server: Optional[str] = None
    api_key: Optional[str] = None
    is_default: Optional[bool] = None


class ModelResponse(BaseModel):
    """Model configuration response."""
    id: str
    name: str
    provider: str
    model: str
    server: Optional[str]
    api_key: Optional[str]  # masked for display
    is_default: bool
    created_at: str
    updated_at: str
|
||
|
||
|
||
# --- Bots ---
class BotCreate(BaseModel):
    """Request body for creating a bot."""
    name: str


class BotUpdate(BaseModel):
    """Request body for updating a bot."""
    name: Optional[str] = None
    bot_id: Optional[str] = None


class BotResponse(BaseModel):
    """Bot response, including ownership/sharing information for the caller."""
    id: str
    name: str
    bot_id: str
    is_owner: bool = False
    is_shared: bool = False
    owner: Optional[dict] = None  # {id, username}
    role: Optional[str] = None  # 'viewer', 'editor', None for owner
    shared_at: Optional[str] = None
    expires_at: Optional[str] = None  # share expiry time
    description: Optional[str] = None  # extracted from settings
    avatar_url: Optional[str] = None  # extracted from settings
    created_at: str
    updated_at: str
|
||
|
||
|
||
# --- Bot settings ---
class BotSettingsUpdate(BaseModel):
    """Request body for updating bot settings (None presumably means "leave unchanged" — TODO confirm in the handler)."""
    model_id: Optional[str] = None
    language: Optional[str] = None
    robot_type: Optional[str] = None
    avatar_url: Optional[str] = None
    description: Optional[str] = None
    suggestions: Optional[List[str]] = None
    dataset_ids: Optional[str] = None
    system_prompt: Optional[str] = None
    enable_memori: Optional[bool] = None
    enable_thinking: Optional[bool] = None
    tool_response: Optional[bool] = None
    skills: Optional[str] = None


class ModelInfo(BaseModel):
    """Model details embedded in a bot settings response."""
    id: str
    name: str
    provider: str
    model: str
    server: Optional[str]
    api_key: Optional[str]  # masked for display


class BotSettingsResponse(BaseModel):
    """Bot settings response."""
    bot_id: str
    model_id: Optional[str]
    model: Optional[ModelInfo]  # details of the associated model
    language: str
    robot_type: Optional[str]
    avatar_url: Optional[str]
    description: Optional[str]
    suggestions: Optional[List[str]]
    dataset_ids: Optional[str]
    system_prompt: Optional[str]
    enable_memori: bool
    enable_thinking: bool
    tool_response: bool
    skills: Optional[str]
    updated_at: str
|
||
|
||
|
||
# --- Chat sessions ---
class SessionCreate(BaseModel):
    """Request body for creating a chat session."""
    title: Optional[str] = None


class SessionResponse(BaseModel):
    """Chat session response."""
    id: str
    bot_id: str
    title: Optional[str]
    created_at: str
    updated_at: str
|
||
|
||
|
||
# --- MCP servers ---
class MCPServerCreate(BaseModel):
    """Request body for creating an MCP server entry."""
    name: str
    type: str
    config: dict
    enabled: bool = True


class MCPServerUpdate(BaseModel):
    """Request body for updating an MCP server entry."""
    name: Optional[str] = None
    type: Optional[str] = None
    config: Optional[dict] = None
    enabled: Optional[bool] = None


class MCPServerResponse(BaseModel):
    """MCP server entry response."""
    id: str
    bot_id: str
    name: str
    type: str
    config: dict
    enabled: bool
    created_at: str
    updated_at: str
|
||
|
||
|
||
# --- Sharing ---
class BotShareCreate(BaseModel):
    """Request body for sharing a bot with one or more users."""
    user_ids: List[str]
    role: str = "viewer"  # 'viewer' or 'editor' (the bot_shares CHECK constraint allows only these)
    expires_at: Optional[str] = None  # ISO 8601 expiry time; None means the share never expires


class BotShareResponse(BaseModel):
    """Single share entry."""
    id: str
    bot_id: str
    user_id: str
    username: str
    email: Optional[str] = None
    role: str
    shared_at: str
    shared_by: Optional[str] = None
    expires_at: Optional[str] = None  # expiry time


class BotSharesListResponse(BaseModel):
    """List of shares for one bot."""
    bot_id: str
    shares: List[BotShareResponse]
|
||
|
||
|
||
|
||
# --- Generic responses ---
class SuccessResponse(BaseModel):
    """Generic success response carrying a human-readable message."""
    success: bool
    message: str
|
||
|
||
|
||
# ============== Database table initialization ==============

async def migrate_bot_owner_and_shares():
    """
    Add user ownership and sharing support to the schema.

    Steps (each guarded by an information_schema check so it is skipped when
    already applied):

    1. Create ``agent_user`` (or add its ``is_admin`` column).
    2. Create ``bot_shares`` (or add its ``expires_at`` column).
    3. Create ``agent_user_tokens``.
    4. Add ``agent_bots.owner_id``; existing bots are assigned to the default
       admin user, then the column is made NOT NULL with ON DELETE RESTRICT.
    5. Always make sure a default ``admin`` user exists.

    Everything is committed with a single ``conn.commit()`` at the end.

    NOTE(review): ``agent_bots`` must already exist — ``bot_shares`` references
    it and step 4 alters it.
    NOTE(review): the information_schema lookups do not filter on
    ``table_schema``; a same-named table in another schema would count as present.
    """
    pool = get_db_pool_manager().pool

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            # 1. Create agent_user first (bot_shares and agent_user_tokens reference it)
            await cursor.execute("""
                SELECT EXISTS (
                    SELECT FROM information_schema.tables
                    WHERE table_name = 'agent_user'
                )
            """)
            user_table_exists = (await cursor.fetchone())[0]

            if not user_table_exists:
                logger.info("Creating agent_user table")

                # Create the agent_user table
                await cursor.execute("""
                    CREATE TABLE agent_user (
                        id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                        username VARCHAR(255) UNIQUE NOT NULL,
                        email VARCHAR(255) UNIQUE,
                        password_hash VARCHAR(255) NOT NULL,
                        created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                        updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                        last_login TIMESTAMP WITH TIME ZONE,
                        is_active BOOLEAN DEFAULT TRUE,
                        is_admin BOOLEAN DEFAULT FALSE
                    )
                """)

                # Create indexes
                await cursor.execute("CREATE INDEX idx_agent_user_username ON agent_user(username)")
                await cursor.execute("CREATE INDEX idx_agent_user_email ON agent_user(email)")
                await cursor.execute("CREATE INDEX idx_agent_user_is_active ON agent_user(is_active)")

                logger.info("agent_user table created successfully")
            else:
                # Existing table: add the is_admin column if it is missing
                await cursor.execute("""
                    SELECT column_name
                    FROM information_schema.columns
                    WHERE table_name = 'agent_user' AND column_name = 'is_admin'
                """)
                has_admin_column = await cursor.fetchone()
                if not has_admin_column:
                    logger.info("Adding is_admin column to agent_user table")
                    await cursor.execute("""
                        ALTER TABLE agent_user
                        ADD COLUMN is_admin BOOLEAN DEFAULT FALSE
                    """)
                    logger.info("is_admin column added successfully")

            # 2. Create the bot_shares table
            await cursor.execute("""
                SELECT EXISTS (
                    SELECT FROM information_schema.tables
                    WHERE table_name = 'bot_shares'
                )
            """)
            shares_table_exists = (await cursor.fetchone())[0]

            if not shares_table_exists:
                logger.info("Creating bot_shares table")

                # NOTE(review): shared_by is declared NOT NULL but with
                # ON DELETE SET NULL, so deleting a user who created shares
                # would violate the NOT NULL constraint — confirm the intent.
                await cursor.execute("""
                    CREATE TABLE bot_shares (
                        id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                        bot_id UUID NOT NULL REFERENCES agent_bots(id) ON DELETE CASCADE,
                        user_id UUID NOT NULL REFERENCES agent_user(id) ON DELETE CASCADE,
                        shared_by UUID NOT NULL REFERENCES agent_user(id) ON DELETE SET NULL,
                        role VARCHAR(50) DEFAULT 'viewer' CHECK (role IN ('viewer', 'editor')),
                        shared_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                        expires_at TIMESTAMP WITH TIME ZONE,
                        UNIQUE(bot_id, user_id)
                    )
                """)

                await cursor.execute("CREATE INDEX idx_bot_shares_bot_id ON bot_shares(bot_id)")
                await cursor.execute("CREATE INDEX idx_bot_shares_user_id ON bot_shares(user_id)")
                await cursor.execute("CREATE INDEX idx_bot_shares_shared_by ON bot_shares(shared_by)")

                logger.info("bot_shares table created successfully")
            else:
                # Existing table: add the expires_at column if it is missing
                await cursor.execute("""
                    SELECT column_name
                    FROM information_schema.columns
                    WHERE table_name = 'bot_shares' AND column_name = 'expires_at'
                """)
                has_expires_column = await cursor.fetchone()
                if not has_expires_column:
                    logger.info("Adding expires_at column to bot_shares table")
                    await cursor.execute("""
                        ALTER TABLE bot_shares
                        ADD COLUMN expires_at TIMESTAMP WITH TIME ZONE
                    """)
                    logger.info("expires_at column added successfully")

            # 3. Create the agent_user_tokens table
            await cursor.execute("""
                SELECT EXISTS (
                    SELECT FROM information_schema.tables
                    WHERE table_name = 'agent_user_tokens'
                )
            """)
            tokens_table_exists = (await cursor.fetchone())[0]

            if not tokens_table_exists:
                logger.info("Creating agent_user_tokens table")

                await cursor.execute("""
                    CREATE TABLE agent_user_tokens (
                        id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                        user_id UUID NOT NULL REFERENCES agent_user(id) ON DELETE CASCADE,
                        token VARCHAR(255) NOT NULL UNIQUE,
                        expires_at TIMESTAMP WITH TIME ZONE NOT NULL,
                        created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
                    )
                """)

                await cursor.execute("CREATE INDEX idx_agent_user_tokens_token ON agent_user_tokens(token)")
                await cursor.execute("CREATE INDEX idx_agent_user_tokens_user_id ON agent_user_tokens(user_id)")
                await cursor.execute("CREATE INDEX idx_agent_user_tokens_expires ON agent_user_tokens(expires_at)")

                logger.info("agent_user_tokens table created successfully")

            # 4. Add agent_bots.owner_id if it does not exist yet
            await cursor.execute("""
                SELECT column_name
                FROM information_schema.columns
                WHERE table_name = 'agent_bots' AND column_name = 'owner_id'
            """)
            has_owner_column = await cursor.fetchone()

            if not has_owner_column:
                logger.info("Adding owner_id column to agent_bots table")

                # First create or update the default admin user (password: admin123)
                # NOTE(review): hard-coded default password, also written to the
                # log below; consider forcing a change on first login.
                default_admin_id = '00000000-0000-0000-0000-000000000001'
                default_admin_password_hash = hash_password('admin123')

                # Remove any other user named 'admin' so the fixed-ID admin
                # inserted below does not hit the username UNIQUE constraint.
                # NOTE(review): destructive — deletes that user's row.
                await cursor.execute("DELETE FROM agent_user WHERE username = 'admin' AND id != %s", (default_admin_id,))

                # Create the admin user, or reset its password/is_admin if the fixed ID exists
                await cursor.execute("""
                    INSERT INTO agent_user (id, username, email, password_hash, is_active, is_admin)
                    VALUES (%s, 'admin', 'admin@local', %s, TRUE, TRUE)
                    ON CONFLICT (id) DO UPDATE SET password_hash = EXCLUDED.password_hash, is_admin = EXCLUDED.is_admin
                """, (default_admin_id, default_admin_password_hash))

                logger.info(f"Default admin user created/updated with password 'admin123'")

                # Add owner_id as nullable first so existing rows can be back-filled
                await cursor.execute("""
                    ALTER TABLE agent_bots
                    ADD COLUMN owner_id UUID REFERENCES agent_user(id) ON DELETE SET NULL
                """)

                # Create index
                await cursor.execute("CREATE INDEX idx_agent_bots_owner_id ON agent_bots(owner_id)")

                # Assign all existing bots to the default admin user
                await cursor.execute("""
                    UPDATE agent_bots
                    SET owner_id = %s
                    WHERE owner_id IS NULL
                """, (default_admin_id,))

                logger.info("Existing bots assigned to default admin user")

                # Every row now has an owner, so owner_id can become NOT NULL
                await cursor.execute("""
                    ALTER TABLE agent_bots
                    ALTER COLUMN owner_id SET NOT NULL
                """)

                # To prevent data loss, replace ON DELETE SET NULL with
                # ON DELETE RESTRICT. The constraint has to be dropped and
                # re-added; the name relies on PostgreSQL's default
                # "<table>_<column>_fkey" naming for the FK added above.
                await cursor.execute("""
                    ALTER TABLE agent_bots
                    DROP CONSTRAINT agent_bots_owner_id_fkey
                """)
                await cursor.execute("""
                    ALTER TABLE agent_bots
                    ADD CONSTRAINT agent_bots_owner_id_fkey
                    FOREIGN KEY (owner_id) REFERENCES agent_user(id) ON DELETE RESTRICT
                """)

                logger.info("owner_id column added and set to NOT NULL")

            # 5. Ensure the default admin user exists (always runs)
            default_admin_id = '00000000-0000-0000-0000-000000000001'
            default_admin_password_hash = hash_password('admin123')

            # Look up any existing 'admin' user
            await cursor.execute("SELECT id, password_hash FROM agent_user WHERE username = 'admin'")
            admin_row = await cursor.fetchone()

            if admin_row:
                existing_id, existing_hash = admin_row
                # Reset the password only if it is the legacy 'PLACEHOLDER' value
                # or the SHA-256 of "admin" (presumably an earlier default).
                if existing_hash == 'PLACEHOLDER' or existing_hash == '8c6976e5b5410415bde908bd4dee15dfb167a9c873fc4bb8a81f6f2ab448a918':
                    await cursor.execute("""
                        UPDATE agent_user
                        SET password_hash = %s, is_admin = TRUE
                        WHERE username = 'admin'
                    """, (default_admin_password_hash,))
                    logger.info("Updated existing admin user with new password 'admin123'")
            else:
                # No admin yet: create one with the fixed ID
                await cursor.execute("""
                    INSERT INTO agent_user (id, username, email, password_hash, is_active, is_admin)
                    VALUES (%s, 'admin', 'admin@local', %s, TRUE, TRUE)
                """, (default_admin_id, default_admin_password_hash))
                logger.info("Created new admin user with password 'admin123'")

            await conn.commit()
            logger.info("Bot owner and shares migration completed successfully")
|
||
|
||
|
||
async def migrate_bot_settings_to_jsonb():
    """
    Migrate the legacy ``agent_bot_settings`` table into the
    ``agent_bots.settings`` JSONB column (backward-compatible migration).

    Runs only when ``agent_bots`` has no ``settings`` column yet. It then adds
    the column with a default document and, if the legacy table exists,
    copies each bot's row into the JSONB column and drops the legacy table.

    NOTE(review): the ALTER TABLE requires ``agent_bots`` to exist already; on
    an empty database this fails unless the table is created beforehand.
    """
    pool = get_db_pool_manager().pool

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            # 1. Does agent_bots already have a settings column?
            await cursor.execute("""
                SELECT column_name
                FROM information_schema.columns
                WHERE table_name = 'agent_bots' AND column_name = 'settings'
            """)
            has_settings_column = await cursor.fetchone()

            if not has_settings_column:
                logger.info("Migrating agent_bots table: adding settings column")

                # Add the settings column with the same default document used
                # by the CREATE TABLE statement for agent_bots
                await cursor.execute("""
                    ALTER TABLE agent_bots
                    ADD COLUMN settings JSONB DEFAULT '{
                        "language": "zh",
                        "enable_memori": false,
                        "enable_thinking": false,
                        "tool_response": false
                    }'::jsonb
                """)

                # 2. Does the legacy agent_bot_settings table exist?
                await cursor.execute("""
                    SELECT EXISTS (
                        SELECT FROM information_schema.tables
                        WHERE table_name = 'agent_bot_settings'
                    )
                """)
                old_table_exists = (await cursor.fetchone())[0]

                if old_table_exists:
                    logger.info("Migrating data from agent_bot_settings to agent_bots.settings")

                    # 3. Copy legacy rows into the JSONB column (enable_thinking
                    # has no legacy counterpart and is always set to false)
                    await cursor.execute("""
                        UPDATE agent_bots b
                        SET settings = jsonb_build_object(
                            'model_id', s.model_id,
                            'language', COALESCE(s.language, 'zh'),
                            'robot_type', s.robot_type,
                            'dataset_ids', s.dataset_ids,
                            'system_prompt', s.system_prompt,
                            'enable_memori', COALESCE(s.enable_memori, false),
                            'enable_thinking', false,
                            'tool_response', COALESCE(s.tool_response, false),
                            'skills', s.skills
                        )
                        FROM agent_bot_settings s
                        WHERE b.id = s.bot_id
                    """)

                    logger.info("Data migration completed, dropping old table")

                    # 4. Drop the legacy agent_bot_settings table
                    await cursor.execute("DROP TABLE IF EXISTS agent_bot_settings CASCADE")

                await conn.commit()
                logger.info("Bot settings migration completed successfully")
            else:
                logger.info("Settings column already exists, skipping migration")
|
||
|
||
|
||
async def init_bot_manager_tables():
    """
    Create and migrate every Bot Manager database table.

    The base tables are created first with ``CREATE ... IF NOT EXISTS`` and
    committed, and only then are the migrations run. Both migrations depend on
    ``agent_bots`` existing: ``migrate_bot_settings_to_jsonb`` runs
    ``ALTER TABLE agent_bots``, and ``migrate_bot_owner_and_shares`` creates
    ``bot_shares`` with ``REFERENCES agent_bots`` and alters ``agent_bots``.
    Running the migrations first therefore failed on a fresh database. On an
    existing database the IF NOT EXISTS statements are no-ops, so the
    migrations behave exactly as before.
    """
    pool = get_db_pool_manager().pool

    # Base table DDL (all idempotent)
    tables_sql = [
        # admin_tokens table (stores admin login tokens)
        """
        CREATE TABLE IF NOT EXISTS agent_admin_tokens (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            username VARCHAR(255) NOT NULL,
            token VARCHAR(255) NOT NULL UNIQUE,
            expires_at TIMESTAMP WITH TIME ZONE NOT NULL,
            created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
        )
        """,
        # admin_tokens indexes
        "CREATE INDEX IF NOT EXISTS idx_agent_admin_tokens_token ON agent_admin_tokens(token)",
        "CREATE INDEX IF NOT EXISTS idx_agent_admin_tokens_expires ON agent_admin_tokens(expires_at)",

        # models table
        """
        CREATE TABLE IF NOT EXISTS agent_models (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            name VARCHAR(255) NOT NULL,
            provider VARCHAR(100) NOT NULL,
            model VARCHAR(255) NOT NULL,
            server VARCHAR(500),
            api_key VARCHAR(500),
            is_default BOOLEAN DEFAULT FALSE,
            created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
            updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
        )
        """,
        # models index
        "CREATE INDEX IF NOT EXISTS idx_agent_models_is_default ON agent_models(is_default)",

        # bots table (settings merged into a JSONB column)
        """
        CREATE TABLE IF NOT EXISTS agent_bots (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            name VARCHAR(255) NOT NULL,
            bot_id VARCHAR(255) NOT NULL UNIQUE,
            settings JSONB DEFAULT '{
                "language": "zh",
                "enable_memori": false,
                "enable_thinking": false,
                "tool_response": false
            }'::jsonb,
            created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
            updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
        )
        """,
        # bots index
        "CREATE INDEX IF NOT EXISTS idx_agent_bots_bot_id ON agent_bots(bot_id)",

        # mcp_servers table
        """
        CREATE TABLE IF NOT EXISTS agent_mcp_servers (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            bot_id UUID REFERENCES agent_bots(id) ON DELETE CASCADE,
            name VARCHAR(255) NOT NULL,
            type VARCHAR(50) NOT NULL,
            config JSONB NOT NULL,
            enabled BOOLEAN DEFAULT TRUE,
            created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
            updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
        )
        """,
        # mcp_servers indexes
        "CREATE INDEX IF NOT EXISTS idx_agent_mcp_servers_bot_id ON agent_mcp_servers(bot_id)",
        "CREATE INDEX IF NOT EXISTS idx_agent_mcp_servers_enabled ON agent_mcp_servers(enabled)",

        # chat_sessions table
        """
        CREATE TABLE IF NOT EXISTS agent_chat_sessions (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            bot_id UUID REFERENCES agent_bots(id) ON DELETE CASCADE,
            title VARCHAR(500),
            created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
            updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
        )
        """,
        # chat_sessions indexes
        "CREATE INDEX IF NOT EXISTS idx_agent_chat_sessions_bot_id ON agent_chat_sessions(bot_id)",
        "CREATE INDEX IF NOT EXISTS idx_agent_chat_sessions_created ON agent_chat_sessions(created_at DESC)",
    ]

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            for sql in tables_sql:
                await cursor.execute(sql)
            await conn.commit()

    # Migrations run only once agent_bots is guaranteed to exist.
    # 1. Bot settings -> agent_bots.settings JSONB
    await migrate_bot_settings_to_jsonb()
    # 2. Users, shares, tokens and bot ownership
    await migrate_bot_owner_and_shares()

    logger.info("Bot Manager tables initialized successfully")
|
||
|
||
|
||
# ============== Helper functions ==============

def mask_api_key(api_key: Optional[str]) -> Optional[str]:
    """
    Mask an API key for display.

    Returns None for a missing or empty key, "****" for keys of at most eight
    characters, and otherwise the first and last four characters joined by "****".
    """
    if api_key:
        if len(api_key) > 8:
            return f"{api_key[:4]}****{api_key[-4:]}"
        return "****"
    return None
|
||
|
||
|
||
def datetime_to_str(dt: datetime) -> str:
    """
    Format a datetime as an ISO 8601 string via ``isoformat()``.

    A falsy value such as None yields an empty string, which lets callers pass
    nullable database columns straight through.
    """
    if not dt:
        return ""
    return dt.isoformat()
|
||
|
||
|
||
# ============== Model management API ==============

@router.get("/api/v1/models", response_model=List[ModelResponse])
async def get_models(authorization: Optional[str] = Header(None)):
    """
    List every configured model: the default model first, then newest first.

    Args:
        authorization: Bearer token (verify_auth only checks that it is present).

    Returns:
        List[ModelResponse]: All models, with API keys masked.
    """
    verify_auth(authorization)

    db_pool = get_db_pool_manager().pool
    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute("""
                SELECT id, name, provider, model, server, api_key, is_default, created_at, updated_at
                FROM agent_models
                ORDER BY is_default DESC, created_at DESC
            """)
            records = await cur.fetchall()

    models = []
    for (model_uuid, name, provider, model, server, api_key,
         is_default, created_at, updated_at) in records:
        models.append(ModelResponse(
            id=str(model_uuid),
            name=name,
            provider=provider,
            model=model,
            server=server,
            api_key=mask_api_key(api_key),
            is_default=is_default,
            created_at=datetime_to_str(created_at),
            updated_at=datetime_to_str(updated_at),
        ))
    return models
|
||
|
||
|
||
@router.post("/api/v1/models", response_model=ModelResponse)
|
||
async def create_model(request: ModelCreate, authorization: Optional[str] = Header(None)):
|
||
"""
|
||
创建新模型
|
||
|
||
Args:
|
||
request: 模型创建请求
|
||
authorization: Bearer token
|
||
|
||
Returns:
|
||
ModelResponse: 创建的模型信息
|
||
"""
|
||
verify_auth(authorization)
|
||
|
||
pool = get_db_pool_manager().pool
|
||
|
||
# 如果设置为默认,需要先取消其他默认模型
|
||
if request.is_default:
|
||
async with pool.connection() as conn:
|
||
async with conn.cursor() as cursor:
|
||
await cursor.execute("UPDATE agent_models SET is_default = FALSE WHERE is_default = TRUE")
|
||
await conn.commit()
|
||
|
||
async with pool.connection() as conn:
|
||
async with conn.cursor() as cursor:
|
||
await cursor.execute("""
|
||
INSERT INTO agent_models (name, provider, model, server, api_key, is_default)
|
||
VALUES (%s, %s, %s, %s, %s, %s)
|
||
RETURNING id, created_at, updated_at
|
||
""", (
|
||
request.name,
|
||
request.provider,
|
||
request.model,
|
||
request.server,
|
||
request.api_key,
|
||
request.is_default
|
||
))
|
||
row = await cursor.fetchone()
|
||
await conn.commit()
|
||
|
||
return ModelResponse(
|
||
id=str(row[0]),
|
||
name=request.name,
|
||
provider=request.provider,
|
||
model=request.model,
|
||
server=request.server,
|
||
api_key=mask_api_key(request.api_key),
|
||
is_default=request.is_default,
|
||
created_at=datetime_to_str(row[1]),
|
||
updated_at=datetime_to_str(row[2])
|
||
)
|
||
|
||
|
||
@router.put("/api/v1/models/{model_id}", response_model=ModelResponse)
async def update_model(
    model_id: str,
    request: ModelUpdate,
    authorization: Optional[str] = Header(None)
):
    """
    Partially update a model configuration.

    Only fields that are not None in ``request`` are written. When
    ``is_default`` is True, the previous default is cleared in the same
    transaction as the update. Previously the clearing UPDATE was committed
    first, so an unknown ``model_id`` (404) or a failing UPDATE left the system
    without any default model.

    Args:
        model_id: ID of the model to update.
        request: Fields to change.
        authorization: Bearer token (verify_auth only checks that it is present).

    Returns:
        ModelResponse: The updated model, API key masked.

    Raises:
        HTTPException: 400 if no field was provided, 404 if the model does not exist.
    """
    verify_auth(authorization)

    pool = get_db_pool_manager().pool

    # Column names are fixed literals, so interpolating them into the SQL text
    # below is safe; every value still goes through a %s placeholder.
    candidate_fields = (
        ("name", request.name),
        ("provider", request.provider),
        ("model", request.model),
        ("server", request.server),
        ("api_key", request.api_key),
        ("is_default", request.is_default),
    )
    provided = [(column, value) for column, value in candidate_fields if value is not None]

    if not provided:
        raise HTTPException(status_code=400, detail="No fields to update")

    update_fields = [f"{column} = %s" for column, _ in provided]
    update_fields.append("updated_at = NOW()")
    values = [value for _, value in provided]
    values.append(model_id)

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            if request.is_default is True:
                # Only one default model at a time: demote the current one.
                await cursor.execute("UPDATE agent_models SET is_default = FALSE WHERE is_default = TRUE")

            await cursor.execute(f"""
                UPDATE agent_models
                SET {', '.join(update_fields)}
                WHERE id = %s
                RETURNING id, name, provider, model, server, api_key, is_default, created_at, updated_at
            """, values)
            row = await cursor.fetchone()

            if not row:
                # Raising here skips the commit below, so the demotion above
                # is not committed for an unknown model.
                raise HTTPException(status_code=404, detail="Model not found")

        await conn.commit()

    return ModelResponse(
        id=str(row[0]),
        name=row[1],
        provider=row[2],
        model=row[3],
        server=row[4],
        api_key=mask_api_key(row[5]),
        is_default=row[6],
        created_at=datetime_to_str(row[7]),
        updated_at=datetime_to_str(row[8])
    )
|
||
|
||
|
||
@router.delete("/api/v1/models/{model_id}", response_model=SuccessResponse)
async def delete_model(model_id: str, authorization: Optional[str] = Header(None)):
    """
    Delete a model configuration by ID.

    Args:
        model_id: Model ID.
        authorization: Bearer token.

    Returns:
        SuccessResponse: Confirmation of the deletion.

    Raises:
        HTTPException: 404 if no model with that ID exists.
    """
    verify_auth(authorization)

    db_pool = get_db_pool_manager().pool

    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            # RETURNING tells us whether a row was actually removed.
            await cur.execute("DELETE FROM agent_models WHERE id = %s RETURNING id", (model_id,))
            deleted = await cur.fetchone()
            if deleted is None:
                raise HTTPException(status_code=404, detail="Model not found")
            await conn.commit()

    return SuccessResponse(success=True, message="Model deleted successfully")
|
||
|
||
|
||
@router.patch("/api/v1/models/{model_id}/default", response_model=SuccessResponse)
async def set_default_model(model_id: str, authorization: Optional[str] = Header(None)):
    """
    Mark one model as the default, clearing the flag on all others.

    The existence check, the reset and the new assignment run on one
    connection and are committed together.

    Args:
        model_id: Model ID.
        authorization: Bearer token.

    Returns:
        SuccessResponse: Confirmation of the change.

    Raises:
        HTTPException: 404 if the model does not exist.
    """
    verify_auth(authorization)

    db_pool = get_db_pool_manager().pool

    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            # Verify the target before touching any flags.
            await cur.execute("SELECT id FROM agent_models WHERE id = %s", (model_id,))
            if await cur.fetchone() is None:
                raise HTTPException(status_code=404, detail="Model not found")

            # Reset every current default, then flag the requested model.
            await cur.execute("UPDATE agent_models SET is_default = FALSE WHERE is_default = TRUE")
            await cur.execute("UPDATE agent_models SET is_default = TRUE WHERE id = %s", (model_id,))

            await conn.commit()

    return SuccessResponse(success=True, message="Default model updated successfully")
|
||
|
||
|
||
# ============== Bot 管理 API ==============
|
||
|
||
@router.get("/api/v1/bots", response_model=List[BotResponse])
async def get_bots(authorization: Optional[str] = Header(None)):
    """
    List the bots visible to the caller.

    Admin tokens see every bot. User tokens see the bots they own plus bots
    shared with them whose share has no expiry or has not yet expired.

    Args:
        authorization: Bearer token (admin or user).

    Returns:
        List[BotResponse]: Bots ordered by creation time, newest first.

    Raises:
        HTTPException: 401 if the token is neither a valid admin nor user token.
    """
    # Both admin and user tokens are accepted.
    admin_valid, _ = await verify_admin_auth(authorization)
    user_valid, user_id, _ = await verify_user_auth(authorization)

    if not admin_valid and not user_valid:
        raise HTTPException(
            status_code=401,
            detail="Unauthorized"
        )

    def settings_value(settings, key):
        # ``settings`` is the bot's settings column value; NULL/empty yields None.
        return settings.get(key) if settings else None

    def owner_of(row):
        # Columns 6/7 are the LEFT JOINed owner id/username (NULL for ownerless bots).
        return {"id": str(row[6]), "username": row[7]} if row[6] else None

    pool = get_db_pool_manager().pool

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            if admin_valid:
                # Admins can see every bot.
                await cursor.execute("""
                    SELECT b.id, b.name, b.bot_id, b.created_at, b.updated_at, b.settings,
                           u.id as owner_id, u.username as owner_username
                    FROM agent_bots b
                    LEFT JOIN agent_user u ON b.owner_id = u.id
                    ORDER BY b.created_at DESC
                """)
                rows = await cursor.fetchall()

                return [
                    BotResponse(
                        id=str(row[0]),
                        name=row[1],
                        bot_id=row[2],
                        is_owner=True,
                        is_shared=False,
                        owner=owner_of(row),
                        role=None,
                        description=settings_value(row[5], 'description'),
                        avatar_url=settings_value(row[5], 'avatar_url'),
                        created_at=datetime_to_str(row[3]),
                        updated_at=datetime_to_str(row[4])
                    )
                    for row in rows
                ]

            # Users: owned bots plus bots shared with them (unexpired shares only).
            await cursor.execute("""
                SELECT b.id, b.name, b.bot_id, b.created_at, b.updated_at, b.settings,
                       u.id as owner_id, u.username as owner_username,
                       s.role, s.shared_at, s.expires_at
                FROM agent_bots b
                LEFT JOIN agent_user u ON b.owner_id = u.id
                LEFT JOIN bot_shares s ON b.id = s.bot_id AND s.user_id = %s
                WHERE b.owner_id = %s
                   OR (s.user_id = %s AND (s.expires_at IS NULL OR s.expires_at > NOW()))
                ORDER BY b.created_at DESC
            """, (user_id, user_id, user_id))
            rows = await cursor.fetchall()

    bots = []
    for row in rows:
        role = row[8]
        is_owner = bool(row[6]) and str(row[6]) == user_id
        bots.append(BotResponse(
            id=str(row[0]),
            name=row[1],
            bot_id=row[2],
            is_owner=is_owner,
            # A share row means the bot reached this user through sharing,
            # unless they also own it. The previous check returned False for
            # ownerless (admin-created) bots even when they were shared.
            is_shared=not is_owner and role is not None,
            owner=owner_of(row),
            role=role if role else None,
            shared_at=datetime_to_str(row[9]) if row[9] else None,
            expires_at=row[10].isoformat() if row[10] else None,
            description=settings_value(row[5], 'description'),
            avatar_url=settings_value(row[5], 'avatar_url'),
            created_at=datetime_to_str(row[3]),
            updated_at=datetime_to_str(row[4])
        ))
    return bots
|
||
|
||
|
||
@router.post("/api/v1/bots", response_model=BotResponse)
async def create_bot(request: BotCreate, authorization: Optional[str] = Header(None)):
    """
    Create a new bot with an auto-generated ``bot_id``.

    A bot created with a user token is owned by that user. A bot created with
    an admin token has no owner (``owner_id`` is NULL).

    Args:
        request: Bot creation payload.
        authorization: Bearer token (admin or user).

    Returns:
        BotResponse: The created bot.

    Raises:
        HTTPException: 401 if unauthenticated, 400 if the insert hits a
            duplicate-key violation.
    """
    # Both admin and user tokens are accepted.
    admin_valid, _ = await verify_admin_auth(authorization)
    user_valid, user_id, user_username = await verify_user_auth(authorization)

    if not admin_valid and not user_valid:
        raise HTTPException(
            status_code=401,
            detail="Unauthorized"
        )

    pool = get_db_pool_manager().pool

    # bot_id is generated server-side.
    bot_id = str(uuid.uuid4())

    # Owner is the calling user; admin-created bots have no owner.
    owner_id = user_id if user_valid else None

    # Keep the try limited to the DB work so only database errors are translated.
    try:
        async with pool.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    INSERT INTO agent_bots (name, bot_id, owner_id)
                    VALUES (%s, %s, %s)
                    RETURNING id, created_at, updated_at, owner_id
                """, (request.name, bot_id, owner_id))
                row = await cursor.fetchone()
                await conn.commit()
    except Exception as e:
        if "duplicate key" in str(e):
            raise HTTPException(status_code=400, detail="Bot ID already exists") from e
        raise

    return BotResponse(
        id=str(row[0]),
        name=request.name,
        bot_id=bot_id,
        is_owner=True,
        is_shared=False,
        owner={"id": str(owner_id), "username": user_username} if owner_id else None,
        created_at=datetime_to_str(row[1]),
        updated_at=datetime_to_str(row[2])
    )
|
||
|
||
|
||
@router.put("/api/v1/bots/{bot_uuid}", response_model=BotResponse)
async def update_bot(
    bot_uuid: str,
    request: BotUpdate,
    authorization: Optional[str] = Header(None)
):
    """
    Update a bot's name and/or ``bot_id``.

    Admin tokens may update any bot; user tokens must own the bot.

    Args:
        bot_uuid: Internal bot UUID.
        request: Bot update payload.
        authorization: Bearer token (admin or user).

    Returns:
        BotResponse: The updated bot.

    Raises:
        HTTPException: 401 unauthenticated, 403 not the owner, 400 no fields or
            duplicate key, 404 bot not found.
    """
    # Both admin and user tokens are accepted.
    admin_valid, _ = await verify_admin_auth(authorization)
    user_valid, user_id, _ = await verify_user_auth(authorization)

    if not admin_valid and not user_valid:
        raise HTTPException(
            status_code=401,
            detail="Unauthorized"
        )

    # Non-admin callers must own the bot.
    if user_valid and not await is_bot_owner(bot_uuid, user_id):
        raise HTTPException(
            status_code=403,
            detail="Only bot owner can update"
        )

    pool = get_db_pool_manager().pool

    # Build the SET clause from the provided fields.
    update_fields = []
    values = []

    if request.name is not None:
        update_fields.append("name = %s")
        values.append(request.name)
    if request.bot_id is not None:
        update_fields.append("bot_id = %s")
        values.append(request.bot_id)

    if not update_fields:
        raise HTTPException(status_code=400, detail="No fields to update")

    update_fields.append("updated_at = NOW()")
    values.append(bot_uuid)

    try:
        async with pool.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(f"""
                    UPDATE agent_bots
                    SET {', '.join(update_fields)}
                    WHERE id = %s
                    RETURNING id, name, bot_id, created_at, updated_at
                """, values)
                row = await cursor.fetchone()

                if not row:
                    raise HTTPException(status_code=404, detail="Bot not found")

                await conn.commit()
    except HTTPException:
        raise
    except Exception as e:
        # Changing a value onto one that already exists violates a unique
        # constraint (presumably on bot_id, as create_bot's message suggests).
        # Report it as create_bot does instead of a 500.
        if "duplicate key" in str(e):
            raise HTTPException(status_code=400, detail="Bot ID already exists") from e
        raise

    return BotResponse(
        id=str(row[0]),
        name=row[1],
        bot_id=row[2],
        is_owner=True,
        is_shared=False,
        created_at=datetime_to_str(row[3]),
        updated_at=datetime_to_str(row[4])
    )
|
||
|
||
|
||
@router.delete("/api/v1/bots/{bot_uuid}", response_model=SuccessResponse)
async def delete_bot(bot_uuid: str, authorization: Optional[str] = Header(None)):
    """
    Delete a bot.

    Admin tokens may delete any bot; user tokens must own it. Related rows
    (settings, sessions, MCP config) are expected to be removed by cascading
    foreign keys. This is not visible here; confirm against the schema.

    Args:
        bot_uuid: Internal bot UUID.
        authorization: Bearer token (admin or user).

    Returns:
        SuccessResponse: Confirmation of the deletion.

    Raises:
        HTTPException: 401 unauthenticated, 403 not the owner, 404 bot not found.
    """
    is_admin, _ = await verify_admin_auth(authorization)
    is_user, caller_id, _ = await verify_user_auth(authorization)

    if not (is_admin or is_user):
        raise HTTPException(status_code=401, detail="Unauthorized")

    # User tokens need ownership; admin tokens skip this check.
    if is_user and not await is_bot_owner(bot_uuid, caller_id):
        raise HTTPException(status_code=403, detail="Only bot owner can delete")

    db_pool = get_db_pool_manager().pool

    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute("DELETE FROM agent_bots WHERE id = %s RETURNING id", (bot_uuid,))
            if await cur.fetchone() is None:
                raise HTTPException(status_code=404, detail="Bot not found")
            await conn.commit()

    return SuccessResponse(success=True, message="Bot deleted successfully")
|
||
|
||
|
||
# ============== Bot 设置 API ==============
|
||
|
||
@router.get("/api/v1/bots/{bot_uuid}/settings", response_model=BotSettingsResponse)
async def get_bot_settings(bot_uuid: str, authorization: Optional[str] = Header(None)):
    """
    Return a bot's settings.

    Admin tokens may read any bot. User tokens must pass
    ``check_bot_access(..., 'read')``. If the settings reference a ``model_id``
    that still exists, the model's details are included with the API key masked.

    Args:
        bot_uuid: Internal bot UUID.
        authorization: Bearer token (admin or user).

    Returns:
        BotSettingsResponse: The bot's settings.

    Raises:
        HTTPException: 401 unauthenticated, 403 no read access, 404 bot not found.
    """
    is_admin, _ = await verify_admin_auth(authorization)
    is_user, caller_id, _ = await verify_user_auth(authorization)

    if not (is_admin or is_user):
        raise HTTPException(status_code=401, detail="Unauthorized")

    if is_user and not await check_bot_access(bot_uuid, caller_id, 'read'):
        raise HTTPException(status_code=403, detail="You don't have access to this bot")

    db_pool = get_db_pool_manager().pool

    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute("""
                SELECT id, settings, updated_at
                FROM agent_bots
                WHERE id = %s
            """, (bot_uuid,))
            bot_row = await cur.fetchone()
            if not bot_row:
                raise HTTPException(status_code=404, detail="Bot not found")

            bot_pk, raw_settings, last_updated = bot_row
            settings = raw_settings or {}
            linked_model_id = settings.get('model_id')

            # Resolve the referenced model, if any; a dangling id yields model=None.
            linked_model = None
            if linked_model_id:
                await cur.execute("""
                    SELECT id, name, provider, model, server, api_key
                    FROM agent_models WHERE id = %s
                """, (linked_model_id,))
                model_row = await cur.fetchone()
                if model_row:
                    m_id, m_name, m_provider, m_model, m_server, m_key = model_row
                    linked_model = ModelInfo(
                        id=str(m_id),
                        name=m_name,
                        provider=m_provider,
                        model=m_model,
                        server=m_server,
                        api_key=mask_api_key(m_key)
                    )

    # Keys that default to None when absent, and flags that default to False.
    nullable_keys = ('robot_type', 'avatar_url', 'description', 'suggestions',
                     'dataset_ids', 'system_prompt', 'skills')
    boolean_keys = ('enable_memori', 'enable_thinking', 'tool_response')

    return BotSettingsResponse(
        bot_id=str(bot_pk),
        model_id=linked_model_id,
        model=linked_model,
        language=settings.get('language', 'zh'),
        updated_at=datetime_to_str(last_updated),
        **{key: settings.get(key) for key in nullable_keys},
        **{key: settings.get(key, False) for key in boolean_keys},
    )
|
||
|
||
|
||
@router.put("/api/v1/bots/{bot_uuid}/settings", response_model=SuccessResponse)
async def update_bot_settings(
    bot_uuid: str,
    request: BotSettingsUpdate,
    authorization: Optional[str] = Header(None)
):
    """
    Merge the provided fields into a bot's settings.

    Only fields that are not ``None`` in ``request`` are written. An empty or
    whitespace-only ``model_id`` clears the model binding. The bot row is
    locked (``SELECT ... FOR UPDATE``) while existing settings are read and
    rewritten, so concurrent updates are serialized instead of overwriting
    each other's keys.

    Args:
        bot_uuid: Internal bot UUID.
        request: Settings update payload.
        authorization: Bearer token (admin or user).

    Returns:
        SuccessResponse: Confirmation of the update.

    Raises:
        HTTPException: 401 unauthenticated, 403 no write access, 400 unknown
            model or no fields, 404 bot not found.
    """
    # Both admin and user tokens are accepted.
    admin_valid, _ = await verify_admin_auth(authorization)
    user_valid, user_id, _ = await verify_user_auth(authorization)

    if not admin_valid and not user_valid:
        raise HTTPException(
            status_code=401,
            detail="Unauthorized"
        )

    # User tokens need write access.
    if user_valid and not await check_bot_access(bot_uuid, user_id, 'write'):
        raise HTTPException(
            status_code=403,
            detail="You don't have permission to modify this bot"
        )

    pool = get_db_pool_manager().pool

    # Normalize model_id: empty / whitespace-only becomes None (clears the binding).
    model_id_value = request.model_id.strip() if request.model_id else None

    # Collect the settings keys to merge, in a fixed order.
    update_json = {}
    if request.model_id is not None:
        update_json['model_id'] = model_id_value or None
    for key in ('language', 'robot_type', 'avatar_url', 'description', 'suggestions',
                'dataset_ids', 'system_prompt', 'enable_memori', 'enable_thinking',
                'tool_response', 'skills'):
        value = getattr(request, key)
        if value is not None:
            update_json[key] = value

    if not update_json:
        raise HTTPException(status_code=400, detail="No fields to update")

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            # Validate the referenced model within the same transaction as the write.
            if model_id_value:
                await cursor.execute("SELECT id FROM agent_models WHERE id = %s", (model_id_value,))
                if not await cursor.fetchone():
                    raise HTTPException(status_code=400, detail=f"Model with id '{request.model_id}' not found")

            # Lock the bot row for the read-modify-write below.
            await cursor.execute(
                "SELECT id, settings FROM agent_bots WHERE id = %s FOR UPDATE",
                (bot_uuid,)
            )
            row = await cursor.fetchone()

            if not row:
                raise HTTPException(status_code=404, detail="Bot not found")

            # Merge the new keys over the existing settings.
            merged_settings = dict(row[1]) if row[1] else {}
            merged_settings.update(update_json)

            await cursor.execute("""
                UPDATE agent_bots
                SET settings = %s, updated_at = NOW()
                WHERE id = %s
            """, (json.dumps(merged_settings), bot_uuid))

            await conn.commit()

    return SuccessResponse(success=True, message="Bot settings updated successfully")
|
||
|
||
|
||
# ============== 会话管理 API ==============
|
||
|
||
@router.get("/api/v1/bots/{bot_uuid}/sessions", response_model=List[SessionResponse])
async def get_bot_sessions(bot_uuid: str, authorization: Optional[str] = Header(None)):
    """
    List a bot's chat sessions, most recently updated first.

    Args:
        bot_uuid: Internal bot UUID.
        authorization: Bearer token.

    Returns:
        List[SessionResponse]: The bot's sessions.
    """
    verify_auth(authorization)

    db_pool = get_db_pool_manager().pool

    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute("""
                SELECT id, bot_id, title, created_at, updated_at
                FROM agent_chat_sessions
                WHERE bot_id = %s
                ORDER BY updated_at DESC
            """, (bot_uuid,))
            records = await cur.fetchall()

    sessions = []
    for session_pk, owner_bot, title, created, updated in records:
        sessions.append(SessionResponse(
            id=str(session_pk),
            bot_id=str(owner_bot),
            title=title,
            created_at=datetime_to_str(created),
            updated_at=datetime_to_str(updated)
        ))
    return sessions
|
||
|
||
|
||
@router.post("/api/v1/bots/{bot_uuid}/sessions", response_model=SessionResponse)
async def create_session(
    bot_uuid: str,
    request: SessionCreate,
    authorization: Optional[str] = Header(None)
):
    """
    Create a chat session for a bot.

    Args:
        bot_uuid: Internal bot UUID.
        request: Session creation payload.
        authorization: Bearer token.

    Returns:
        SessionResponse: The created session.

    Raises:
        HTTPException: 404 if the bot does not exist.
    """
    verify_auth(authorization)

    db_pool = get_db_pool_manager().pool

    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            # Refuse to create sessions for unknown bots.
            await cur.execute("SELECT id FROM agent_bots WHERE id = %s", (bot_uuid,))
            bot_exists = await cur.fetchone()
            if not bot_exists:
                raise HTTPException(status_code=404, detail="Bot not found")

            await cur.execute("""
                INSERT INTO agent_chat_sessions (bot_id, title)
                VALUES (%s, %s)
                RETURNING id, created_at, updated_at
            """, (bot_uuid, request.title))
            new_id, created, updated = await cur.fetchone()

            await conn.commit()

    return SessionResponse(
        id=str(new_id),
        bot_id=bot_uuid,
        title=request.title,
        created_at=datetime_to_str(created),
        updated_at=datetime_to_str(updated)
    )
|
||
|
||
|
||
@router.delete("/api/v1/sessions/{session_id}", response_model=SuccessResponse)
async def delete_session(session_id: str, authorization: Optional[str] = Header(None)):
    """
    Delete a chat session by ID.

    Args:
        session_id: Session ID.
        authorization: Bearer token.

    Returns:
        SuccessResponse: Confirmation of the deletion.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    verify_auth(authorization)

    db_pool = get_db_pool_manager().pool

    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute("DELETE FROM agent_chat_sessions WHERE id = %s RETURNING id", (session_id,))
            removed = await cur.fetchone()
            if removed is None:
                raise HTTPException(status_code=404, detail="Session not found")
            await conn.commit()

    return SuccessResponse(success=True, message="Session deleted successfully")
|
||
|
||
|
||
# ============== MCP 服务器 API ==============
|
||
|
||
@router.get("/api/v1/bots/{bot_uuid}/mcp", response_model=List[MCPServerResponse])
async def get_mcp_servers(bot_uuid: str, authorization: Optional[str] = Header(None)):
    """
    List a bot's MCP server configurations, newest first.

    Args:
        bot_uuid: Internal bot UUID.
        authorization: Bearer token.

    Returns:
        List[MCPServerResponse]: The bot's MCP servers.
    """
    verify_auth(authorization)

    db_pool = get_db_pool_manager().pool

    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute("""
                SELECT id, bot_id, name, type, config, enabled, created_at, updated_at
                FROM agent_mcp_servers
                WHERE bot_id = %s
                ORDER BY created_at DESC
            """, (bot_uuid,))
            records = await cur.fetchall()

    servers = []
    for server_pk, owner_bot, name, kind, config, enabled, created, updated in records:
        servers.append(MCPServerResponse(
            id=str(server_pk),
            bot_id=str(owner_bot),
            name=name,
            type=kind,
            config=config,
            enabled=enabled,
            created_at=datetime_to_str(created),
            updated_at=datetime_to_str(updated)
        ))
    return servers
|
||
|
||
|
||
@router.put("/api/v1/bots/{bot_uuid}/mcp", response_model=SuccessResponse)
async def update_mcp_servers(
    bot_uuid: str,
    servers: List[MCPServerCreate],
    authorization: Optional[str] = Header(None)
):
    """
    Replace a bot's entire MCP server configuration.

    All existing MCP rows for the bot are deleted and ``servers`` is inserted
    in their place, in one transaction. An empty list clears the configuration.

    Args:
        bot_uuid: Internal bot UUID.
        servers: The complete new list of MCP servers.
        authorization: Bearer token.

    Returns:
        SuccessResponse: Confirmation including the number of servers stored.

    Raises:
        HTTPException: 404 if the bot does not exist.
    """
    verify_auth(authorization)

    pool = get_db_pool_manager().pool

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            # Same existence check as add_mcp_server; without it an unknown bot
            # surfaced as a database error (500) on the first INSERT.
            await cursor.execute("SELECT id FROM agent_bots WHERE id = %s", (bot_uuid,))
            if not await cursor.fetchone():
                raise HTTPException(status_code=404, detail="Bot not found")

            # Drop the old configuration.
            await cursor.execute("DELETE FROM agent_mcp_servers WHERE bot_id = %s", (bot_uuid,))

            # Insert the new configuration.
            # NOTE(review): server.config is passed through as-is; if it is a
            # dict, confirm the DB driver has a JSON adapter registered.
            for server in servers:
                await cursor.execute("""
                    INSERT INTO agent_mcp_servers (bot_id, name, type, config, enabled)
                    VALUES (%s, %s, %s, %s, %s)
                """, (
                    bot_uuid,
                    server.name,
                    server.type,
                    server.config,
                    server.enabled
                ))

            await conn.commit()

    return SuccessResponse(
        success=True,
        message=f"MCP servers updated successfully ({len(servers)} servers)"
    )
|
||
|
||
|
||
@router.post("/api/v1/bots/{bot_uuid}/mcp", response_model=MCPServerResponse)
async def add_mcp_server(
    bot_uuid: str,
    request: MCPServerCreate,
    authorization: Optional[str] = Header(None)
):
    """
    Add a single MCP server to a bot.

    Args:
        bot_uuid: Internal bot UUID.
        request: MCP server creation payload.
        authorization: Bearer token.

    Returns:
        MCPServerResponse: The created MCP server.

    Raises:
        HTTPException: 404 if the bot does not exist.
    """
    verify_auth(authorization)

    db_pool = get_db_pool_manager().pool

    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute("SELECT id FROM agent_bots WHERE id = %s", (bot_uuid,))
            bot_exists = await cur.fetchone()
            if not bot_exists:
                raise HTTPException(status_code=404, detail="Bot not found")

            # NOTE(review): request.config is stored as-is; confirm the driver
            # can adapt its type (e.g. dict) to the config column.
            params = (bot_uuid, request.name, request.type, request.config, request.enabled)
            await cur.execute("""
                INSERT INTO agent_mcp_servers (bot_id, name, type, config, enabled)
                VALUES (%s, %s, %s, %s, %s)
                RETURNING id, created_at, updated_at
            """, params)
            new_id, created, updated = await cur.fetchone()

            await conn.commit()

    return MCPServerResponse(
        id=str(new_id),
        bot_id=bot_uuid,
        name=request.name,
        type=request.type,
        config=request.config,
        enabled=request.enabled,
        created_at=datetime_to_str(created),
        updated_at=datetime_to_str(updated)
    )
|
||
|
||
|
||
@router.delete("/api/v1/bots/{bot_uuid}/mcp/{mcp_id}", response_model=SuccessResponse)
async def delete_mcp_server(
    bot_uuid: str,
    mcp_id: str,
    authorization: Optional[str] = Header(None)
):
    """
    Delete one MCP server belonging to a bot.

    The row must match both ``mcp_id`` and ``bot_uuid``, so a server cannot be
    deleted through another bot's URL.

    Args:
        bot_uuid: Internal bot UUID.
        mcp_id: MCP server ID.
        authorization: Bearer token.

    Returns:
        SuccessResponse: Confirmation of the deletion.

    Raises:
        HTTPException: 404 if no matching MCP server exists.
    """
    verify_auth(authorization)

    db_pool = get_db_pool_manager().pool

    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM agent_mcp_servers WHERE id = %s AND bot_id = %s RETURNING id",
                (mcp_id, bot_uuid)
            )
            if await cur.fetchone() is None:
                raise HTTPException(status_code=404, detail="MCP server not found")
            await conn.commit()

    return SuccessResponse(success=True, message="MCP server deleted successfully")
|
||
|
||
|
||
# ============== Admin 登录 API ==============
|
||
|
||
@router.post("/api/v1/admin/login", response_model=AdminLoginResponse)
async def admin_login(request: AdminLoginRequest):
    """
    Log in as the built-in administrator.

    Credentials are checked against ``ADMIN_USERNAME`` / ``ADMIN_PASSWORD``. On
    success, existing tokens for that username are deleted, and a new random
    token valid for ``TOKEN_EXPIRE_HOURS`` hours is stored and returned.

    Args:
        request: Login payload (username and password).

    Returns:
        AdminLoginResponse: The new token, username and ISO-8601 expiry.

    Raises:
        HTTPException: 401 if the credentials do not match.
    """
    # Constant-time comparison so response timing does not leak how much of the
    # credentials matched. Both checks run before branching. Values are compared
    # as UTF-8 bytes because compare_digest rejects non-ASCII str arguments.
    username_ok = secrets.compare_digest(
        request.username.encode("utf-8"), ADMIN_USERNAME.encode("utf-8")
    )
    password_ok = secrets.compare_digest(
        request.password.encode("utf-8"), ADMIN_PASSWORD.encode("utf-8")
    )
    if not (username_ok and password_ok):
        raise HTTPException(
            status_code=401,
            detail="用户名或密码错误"
        )

    pool = get_db_pool_manager().pool

    # Issue a new token.
    # NOTE(review): datetime.now() is naive local time, while verify_admin_auth
    # compares against the database's NOW(); confirm both use the same timezone.
    token = secrets.token_urlsafe(32)
    expires_at = datetime.now() + timedelta(hours=TOKEN_EXPIRE_HOURS)

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            # One active admin token per username: remove older ones.
            await cursor.execute("DELETE FROM agent_admin_tokens WHERE username = %s", (request.username,))

            await cursor.execute("""
                INSERT INTO agent_admin_tokens (username, token, expires_at)
                VALUES (%s, %s, %s)
            """, (request.username, token, expires_at))

            await conn.commit()

    return AdminLoginResponse(
        token=token,
        username=request.username,
        expires_at=expires_at.isoformat()
    )
|
||
|
||
|
||
@router.post("/api/v1/admin/verify", response_model=AdminVerifyResponse)
async def admin_verify(authorization: Optional[str] = Header(None)):
    """
    Report whether the supplied admin token is valid.

    This never raises for a bad token; it returns ``valid=False`` instead.

    Args:
        authorization: Bearer token.

    Returns:
        AdminVerifyResponse: Validity flag and the token's username (or None).
    """
    is_valid, admin_name = await verify_admin_auth(authorization)
    return AdminVerifyResponse(valid=is_valid, username=admin_name)
|
||
|
||
|
||
@router.post("/api/v1/admin/logout", response_model=SuccessResponse)
async def admin_logout(authorization: Optional[str] = Header(None)):
    """
    Log out an admin by deleting the presented token.

    Deleting an unknown token is not an error; the call still succeeds.

    Args:
        authorization: Bearer token.

    Returns:
        SuccessResponse: Confirmation of the logout.

    Raises:
        HTTPException: 401 if no token is supplied.
    """
    bearer = extract_api_key_from_auth(authorization)
    if not bearer:
        raise HTTPException(status_code=401, detail="Authorization header is required")

    async with get_db_pool_manager().pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute("DELETE FROM agent_admin_tokens WHERE token = %s", (bearer,))
            await conn.commit()

    return SuccessResponse(success=True, message="Logged out successfully")
|
||
|
||
|
||
# ============== 用户认证 API ==============
|
||
|
||
@router.post("/api/v1/auth/register", response_model=UserLoginResponse)
async def user_register(request: UserRegisterRequest):
    """
    Register a new user (an invitation code is required).

    The invitation code is checked first, before any database work. Username
    and email uniqueness are pre-checked for clear error messages. A concurrent
    registration that slips past those checks is also reported as 400 instead
    of a 500. On success the user is logged in immediately.

    Args:
        request: Registration payload (username, email, password, invitation code).

    Returns:
        UserLoginResponse: A new token and the user's info.

    Raises:
        HTTPException: 400 for an invalid invitation code or a duplicate
            username/email.
    """
    # 1. Validate the invitation code (a single fixed code). No DB work needed.
    if request.invitation_code != "WELCOME2026":
        raise HTTPException(
            status_code=400,
            detail="邀请码无效"
        )

    pool = get_db_pool_manager().pool

    try:
        async with pool.connection() as conn:
            async with conn.cursor() as cursor:
                # 2. Reject an existing username.
                await cursor.execute("""
                    SELECT id FROM agent_user WHERE username = %s
                """, (request.username,))
                if await cursor.fetchone():
                    raise HTTPException(
                        status_code=400,
                        detail="用户名已存在"
                    )

                # 3. Reject an existing email (only when one is provided).
                if request.email:
                    await cursor.execute("""
                        SELECT id FROM agent_user WHERE email = %s
                    """, (request.email,))
                    if await cursor.fetchone():
                        raise HTTPException(
                            status_code=400,
                            detail="邮箱已被注册"
                        )

                # 4. Create the user.
                password_hash = hash_password(request.password)
                await cursor.execute("""
                    INSERT INTO agent_user (username, email, password_hash)
                    VALUES (%s, %s, %s)
                    RETURNING id, created_at
                """, (request.username, request.email, password_hash))
                user_id, _created_at = await cursor.fetchone()

                # 5. Issue a login token.
                token = secrets.token_urlsafe(32)
                expires_at = datetime.now() + timedelta(hours=TOKEN_EXPIRE_HOURS)

                await cursor.execute("""
                    INSERT INTO agent_user_tokens (user_id, token, expires_at)
                    VALUES (%s, %s, %s)
                """, (user_id, token, expires_at))

                await conn.commit()
    except HTTPException:
        raise
    except Exception as e:
        # A concurrent registration can pass the pre-checks above and then hit
        # the unique constraint on INSERT; report it as a client error.
        if "duplicate key" in str(e):
            raise HTTPException(status_code=400, detail="用户名或邮箱已存在") from e
        raise

    return UserLoginResponse(
        token=token,
        user_id=str(user_id),
        username=request.username,
        email=request.email,
        is_admin=False,
        expires_at=expires_at.isoformat()
    )
|
||
|
||
|
||
@router.post("/api/v1/auth/login", response_model=UserLoginResponse)
async def user_login(request: UserLoginRequest):
    """
    Log a user in with username and password.

    On success, ``last_login`` is updated, all previous tokens for the user are
    deleted, and a new token valid for ``TOKEN_EXPIRE_HOURS`` hours is issued.

    Args:
        request: Login payload (username and password).

    Returns:
        UserLoginResponse: The new token and the user's info.

    Raises:
        HTTPException: 401 for wrong credentials, 403 for a disabled account.
    """
    db_pool = get_db_pool_manager().pool

    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            # Stored passwords are hash_password() output, so match on the hash.
            hashed = hash_password(request.password)
            await cur.execute("""
                SELECT id, username, email, is_active, is_admin
                FROM agent_user
                WHERE username = %s AND password_hash = %s
            """, (request.username, hashed))
            account = await cur.fetchone()

            if not account:
                raise HTTPException(status_code=401, detail="用户名或密码错误")

            account_id, account_name, account_email, active, admin = account

            if not active:
                raise HTTPException(status_code=403, detail="账户已被禁用")

            await cur.execute("""
                UPDATE agent_user
                SET last_login = NOW()
                WHERE id = %s
            """, (account_id,))

            # One active session per user: drop any earlier tokens.
            await cur.execute("""
                DELETE FROM agent_user_tokens
                WHERE user_id = %s
            """, (account_id,))

            new_token = secrets.token_urlsafe(32)
            token_expiry = datetime.now() + timedelta(hours=TOKEN_EXPIRE_HOURS)

            await cur.execute("""
                INSERT INTO agent_user_tokens (user_id, token, expires_at)
                VALUES (%s, %s, %s)
            """, (account_id, new_token, token_expiry))

            await conn.commit()

    return UserLoginResponse(
        token=new_token,
        user_id=str(account_id),
        username=account_name,
        email=account_email,
        is_admin=admin or False,
        expires_at=token_expiry.isoformat()
    )
|
||
|
||
|
||
@router.post("/api/v1/auth/verify", response_model=UserVerifyResponse)
async def user_verify(authorization: Optional[str] = Header(None)):
    """
    Report whether the supplied user token is valid.

    For a valid token the user's ``is_admin`` flag is looked up as well.
    An invalid token yields ``valid=False`` rather than an error.

    Args:
        authorization: Bearer token.

    Returns:
        UserVerifyResponse: Validity, user id/username and admin flag.
    """
    valid, user_id, username = await verify_user_auth(authorization)

    admin_flag = False
    if valid and user_id:
        async with get_db_pool_manager().pool.connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute("""
                    SELECT is_admin FROM agent_user WHERE id = %s
                """, (user_id,))
                record = await cur.fetchone()
        if record:
            # A NULL is_admin column counts as "not admin".
            admin_flag = record[0] or False

    return UserVerifyResponse(
        valid=valid,
        user_id=user_id,
        username=username,
        is_admin=admin_flag
    )
|
||
|
||
|
||
@router.get("/api/v1/users/me", response_model=UserInfoResponse)
async def get_current_user(authorization: Optional[str] = Header(None)):
    """
    Return the profile of the user who owns the supplied token.

    Args:
        authorization: Bearer token.

    Returns:
        UserInfoResponse: The user's profile.

    Raises:
        HTTPException: 401 for an invalid token, 404 if the user row is gone.
    """
    is_valid, current_id, _ = await verify_user_auth(authorization)
    if not is_valid:
        raise HTTPException(status_code=401, detail="Invalid token")

    db_pool = get_db_pool_manager().pool

    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute("""
                SELECT id, username, email, is_admin, created_at, last_login
                FROM agent_user
                WHERE id = %s
            """, (current_id,))
            profile = await cur.fetchone()

    if not profile:
        raise HTTPException(status_code=404, detail="User not found")

    uid, name, mail, admin, joined, last_seen = profile

    return UserInfoResponse(
        id=str(uid),
        username=name,
        email=mail,
        is_admin=admin or False,
        created_at=joined.isoformat() if joined else "",
        last_login=last_seen.isoformat() if last_seen else None
    )
|
||
|
||
|
||
def _escape_like_pattern(term: str) -> str:
    """Escape LIKE/ILIKE metacharacters so *term* is matched literally.

    PostgreSQL's default LIKE escape character is the backslash, so the
    backslash itself is escaped first, then ``%`` and ``_``.
    """
    return term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


@router.get("/api/v1/users", response_model=List[UserSearchResponse])
async def search_users(
    q: str = "",
    authorization: Optional[str] = Header(None)
):
    """
    Search active users by username or email (used when sharing a bot).

    Both admin tokens and regular user tokens are accepted. When the token
    belongs to a regular user, that user is excluded from the results.

    Args:
        q: Case-insensitive substring to look for in username or email.
            Percent, underscore and backslash are matched literally.
        authorization: Bearer token.

    Returns:
        List[UserSearchResponse]: Up to 20 matches ordered by username, or an
        empty list when ``q`` is empty.

    Raises:
        HTTPException: 401 when the token is neither a valid admin token nor
            a valid user token.
    """
    # Accept either an admin token or a user token.
    admin_valid, _ = await verify_admin_auth(authorization)
    user_valid, user_id, _ = await verify_user_auth(authorization)

    if not admin_valid and not user_valid:
        raise HTTPException(status_code=401, detail="Unauthorized")

    if not q:
        return []

    # Escape wildcards so input such as "%" or "_" cannot widen the match to
    # every user (and enumerate all emails).
    pattern = f"%{_escape_like_pattern(q)}%"

    sql = """
        SELECT id, username, email
        FROM agent_user
        WHERE is_active = TRUE
        AND (username ILIKE %s OR email ILIKE %s)
    """
    params = [pattern, pattern]
    if user_id:
        # Do not offer the caller themselves as a share target.
        sql += " AND id != %s"
        params.append(user_id)
    sql += " ORDER BY username LIMIT 20"

    pool = get_db_pool_manager().pool

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            await cursor.execute(sql, tuple(params))
            rows = await cursor.fetchall()

    return [
        UserSearchResponse(id=str(row[0]), username=row[1], email=row[2])
        for row in rows
    ]
|
||
|
||
|
||
# ============== 用户管理 API(仅管理员)=============
|
||
|
||
class UserListResponse(BaseModel):
    """A user record as returned by the admin user-management endpoints."""

    # Field order is kept as-is: it defines the order of keys in the JSON output.
    id: str  # stringified primary key of agent_user
    username: str
    email: Optional[str] = None
    is_admin: bool = False
    is_active: bool = True
    created_at: str  # ISO-8601 timestamp; endpoints send "" when the column is empty
    last_login: Optional[str] = None  # ISO-8601 timestamp, or None when not recorded
|
||
|
||
|
||
class UserCreateRequest(BaseModel):
    """Payload for creating a user account (admin only)."""

    email: str
    # Optional: when omitted, create_user derives a username from the email.
    username: Optional[str] = None
    # Plain-text password; create_user hashes it before storing it.
    password: str
    is_admin: bool = False
|
||
|
||
|
||
class UserUpdateRequest(BaseModel):
    """Payload for editing a user's profile (admin only).

    Any field left as None is not changed.
    """

    email: Optional[str] = None
    username: Optional[str] = None
|
||
|
||
|
||
class ResetPasswordRequest(BaseModel):
    """Payload for an admin-initiated password reset."""

    # Plain-text password; reset_user_password hashes it before storing it.
    new_password: str
|
||
|
||
|
||
@router.get("/api/v1/users/list", response_model=List[UserListResponse])
async def get_all_users(authorization: Optional[str] = Header(None)):
    """
    List every user account, newest first (admin only).

    Args:
        authorization: Bearer token of an admin.

    Returns:
        List[UserListResponse]: All rows of agent_user ordered by
        ``created_at`` descending.

    Raises:
        HTTPException: 403 when the caller is not an admin.
    """
    if not await is_admin_user(authorization):
        raise HTTPException(status_code=403, detail="Admin access required")

    db_pool = get_db_pool_manager().pool
    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute("""
                SELECT id, username, email, is_admin, is_active, created_at, last_login
                FROM agent_user
                ORDER BY created_at DESC
            """)
            records = await cur.fetchall()

    return [
        UserListResponse(
            id=str(uid),
            username=name,
            email=email,
            is_admin=admin or False,
            is_active=active,
            created_at=created.isoformat() if created else "",
            last_login=last_seen.isoformat() if last_seen else None,
        )
        for uid, name, email, admin, active, created, last_seen in records
    ]
|
||
|
||
|
||
@router.get("/api/v1/users/{user_id}", response_model=UserListResponse)
async def get_user(user_id: str, authorization: Optional[str] = Header(None)):
    """
    Fetch a single user account by id (admin only).

    Args:
        user_id: Id of the user to fetch.
        authorization: Bearer token of an admin.

    Returns:
        UserListResponse: The user's account details.

    Raises:
        HTTPException: 403 when the caller is not an admin, 404 when no user
            has this id.
    """
    if not await is_admin_user(authorization):
        raise HTTPException(status_code=403, detail="Admin access required")

    db_pool = get_db_pool_manager().pool
    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute("""
                SELECT id, username, email, is_admin, is_active, created_at, last_login
                FROM agent_user
                WHERE id = %s
            """, (user_id,))
            record = await cur.fetchone()

    if record is None:
        raise HTTPException(status_code=404, detail="User not found")

    uid, name, email, admin, active, created, last_seen = record
    return UserListResponse(
        id=str(uid),
        username=name,
        email=email,
        is_admin=admin or False,
        is_active=active,
        created_at=created.isoformat() if created else "",
        last_login=last_seen.isoformat() if last_seen else None,
    )
|
||
|
||
|
||
@router.post("/api/v1/users", response_model=UserListResponse)
async def create_user(
    request: UserCreateRequest,
    authorization: Optional[str] = Header(None)
):
    """
    Create a new user account (admin only).

    When ``request.username`` is omitted, the part of the email before the
    first "@" is used as the username. The duplicate-username check applies
    to this effective username, whether it was supplied or derived.

    Args:
        request: Email, optional username, plain-text password and admin flag.
        authorization: Bearer token of an admin.

    Returns:
        UserListResponse: The newly created, active user.

    Raises:
        HTTPException: 403 when the caller is not an admin; 400 when the
            username or email is already in use.
    """
    if not await is_admin_user(authorization):
        raise HTTPException(status_code=403, detail="Admin access required")

    # Resolve the effective username first. Previously only an explicitly
    # supplied username was checked, so a name derived from the email could
    # silently collide with an existing user.
    username = request.username or request.email.split('@')[0]

    pool = get_db_pool_manager().pool

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            await cursor.execute("""
                SELECT id FROM agent_user WHERE username = %s
            """, (username,))
            if await cursor.fetchone():
                raise HTTPException(status_code=400, detail="用户名已存在")

            await cursor.execute("""
                SELECT id FROM agent_user WHERE email = %s
            """, (request.email,))
            if await cursor.fetchone():
                raise HTTPException(status_code=400, detail="邮箱已被注册")

            # Only the hash is stored; the plain-text password never hits the DB.
            password_hash = hash_password(request.password)

            await cursor.execute("""
                INSERT INTO agent_user (username, email, password_hash, is_admin)
                VALUES (%s, %s, %s, %s)
                RETURNING id, created_at
            """, (username, request.email, password_hash, request.is_admin))
            user_id, created_at = await cursor.fetchone()

            await conn.commit()

    return UserListResponse(
        id=str(user_id),
        username=username,
        email=request.email,
        is_admin=request.is_admin,
        is_active=True,
        created_at=created_at.isoformat() if created_at else "",
        last_login=None
    )
|
||
|
||
|
||
@router.put("/api/v1/users/{user_id}", response_model=UserListResponse)
async def update_user(
    user_id: str,
    request: UserUpdateRequest,
    authorization: Optional[str] = Header(None)
):
    """
    Change a user's username and/or email (admin only).

    Only fields that are not None in ``request`` are updated; each one is
    first checked against other users for duplicates. ``updated_at`` is
    refreshed on every successful update.

    Args:
        user_id: Id of the user to update.
        request: New username and/or email.
        authorization: Bearer token of an admin.

    Returns:
        UserListResponse: The user as stored after the update.

    Raises:
        HTTPException: 403 when the caller is not an admin; 400 when a value
            is already used by another user or no field was provided; 404
            when no user has this id.
    """
    if not await is_admin_user(authorization):
        raise HTTPException(status_code=403, detail="Admin access required")

    # (column, requested value, error shown when another user already has it).
    # Column names come from this fixed list only, never from the request.
    candidates = (
        ("username", request.username, "用户名已存在"),
        ("email", request.email, "邮箱已被使用"),
    )

    db_pool = get_db_pool_manager().pool
    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            assignments = []
            params = []

            for column, new_value, taken_message in candidates:
                if new_value is None:
                    continue
                await cur.execute(
                    f"SELECT id FROM agent_user WHERE {column} = %s AND id != %s",
                    (new_value, user_id),
                )
                if await cur.fetchone():
                    raise HTTPException(status_code=400, detail=taken_message)
                assignments.append(f"{column} = %s")
                params.append(new_value)

            if not assignments:
                raise HTTPException(status_code=400, detail="No fields to update")

            assignments.append("updated_at = NOW()")
            params.append(user_id)

            await cur.execute(f"""
                UPDATE agent_user
                SET {', '.join(assignments)}
                WHERE id = %s
                RETURNING id, username, email, is_admin, is_active, created_at, last_login
            """, params)
            record = await cur.fetchone()

            if record is None:
                raise HTTPException(status_code=404, detail="User not found")

            await conn.commit()

    uid, name, email, admin, active, created, last_seen = record
    return UserListResponse(
        id=str(uid),
        username=name,
        email=email,
        is_admin=admin or False,
        is_active=active,
        created_at=created.isoformat() if created else "",
        last_login=last_seen.isoformat() if last_seen else None,
    )
|
||
|
||
|
||
@router.delete("/api/v1/users/{user_id}", response_model=SuccessResponse)
async def delete_user(user_id: str, authorization: Optional[str] = Header(None)):
    """
    Delete a user account (admin only).

    The built-in admin account (id 00000000-0000-0000-0000-000000000001)
    cannot be deleted, however its id is spelled in the request.

    Args:
        user_id: Id of the user to delete.
        authorization: Bearer token of an admin.

    Returns:
        SuccessResponse: Confirmation naming the deleted user.

    Raises:
        HTTPException: 403 when the caller is not an admin; 400 when the
            target is the built-in admin; 404 when no user has this id.
    """
    if not await is_admin_user(authorization):
        raise HTTPException(status_code=403, detail="Admin access required")

    # Protect the built-in admin account. agent_user.id appears to be a
    # PostgreSQL uuid column (the built-in id is a UUID literal). PostgreSQL
    # accepts the same UUID in several spellings (without hyphens, wrapped in
    # braces), so a raw string comparison could be bypassed. Compare parsed
    # UUIDs instead.
    default_admin_id = uuid.UUID('00000000-0000-0000-0000-000000000001')
    try:
        is_default_admin = uuid.UUID(user_id.strip()) == default_admin_id
    except ValueError:
        # Not a UUID Python can parse; leave it to the database lookup below.
        is_default_admin = False
    if is_default_admin:
        raise HTTPException(status_code=400, detail="Cannot delete default admin user")

    pool = get_db_pool_manager().pool

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            await cursor.execute("DELETE FROM agent_user WHERE id = %s RETURNING username", (user_id,))
            row = await cursor.fetchone()

            if not row:
                raise HTTPException(status_code=404, detail="User not found")

            await conn.commit()

    return SuccessResponse(
        success=True,
        message=f"User '{row[0]}' deleted successfully"
    )
|
||
|
||
|
||
@router.patch("/api/v1/users/{user_id}/toggle-admin", response_model=SuccessResponse)
async def toggle_user_admin(user_id: str, authorization: Optional[str] = Header(None)):
    """
    Flip a user's admin flag (admin only).

    The built-in admin account (id 00000000-0000-0000-0000-000000000001)
    cannot be modified, however its id is spelled in the request. A NULL flag
    is treated as False, so toggling it makes the user an admin.

    Args:
        user_id: Id of the user to modify.
        authorization: Bearer token of an admin.

    Returns:
        SuccessResponse: Message stating the user's new role.

    Raises:
        HTTPException: 403 when the caller is not an admin; 400 when the
            target is the built-in admin; 404 when no user has this id.
    """
    if not await is_admin_user(authorization):
        raise HTTPException(status_code=403, detail="Admin access required")

    # Protect the built-in admin account. agent_user.id appears to be a
    # PostgreSQL uuid column (the built-in id is a UUID literal). PostgreSQL
    # accepts the same UUID in several spellings (without hyphens, wrapped in
    # braces), so a raw string comparison could be bypassed. Compare parsed
    # UUIDs instead.
    default_admin_id = uuid.UUID('00000000-0000-0000-0000-000000000001')
    try:
        is_default_admin = uuid.UUID(user_id.strip()) == default_admin_id
    except ValueError:
        # Not a UUID Python can parse; leave it to the database lookup below.
        is_default_admin = False
    if is_default_admin:
        raise HTTPException(status_code=400, detail="Cannot modify default admin user")

    pool = get_db_pool_manager().pool

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            # COALESCE: `NOT NULL` is NULL in SQL, which would leave the flag
            # unset while the message below claims "a regular user".
            await cursor.execute("""
                UPDATE agent_user
                SET is_admin = NOT COALESCE(is_admin, FALSE),
                    updated_at = NOW()
                WHERE id = %s
                RETURNING is_admin, username
            """, (user_id,))
            row = await cursor.fetchone()

            if not row:
                raise HTTPException(status_code=404, detail="User not found")

            new_status, username = row

            await conn.commit()

    return SuccessResponse(
        success=True,
        message=f"User '{username}' is now {'an admin' if new_status else 'a regular user'}"
    )
|
||
|
||
|
||
@router.post("/api/v1/users/{user_id}/reset-password", response_model=SuccessResponse)
async def reset_user_password(
    user_id: str,
    request: ResetPasswordRequest,
    authorization: Optional[str] = Header(None)
):
    """
    Set a new password for a user and revoke their existing sessions (admin only).

    Args:
        user_id: Id of the user whose password is reset.
        request: Carries the new plain-text password; it is hashed with
            ``hash_password`` before storage.
        authorization: Bearer token of an admin.

    Returns:
        SuccessResponse: Confirmation naming the affected user.

    Raises:
        HTTPException: 403 when the caller is not an admin, 404 when no user
            has this id.
    """
    if not await is_admin_user(authorization):
        raise HTTPException(status_code=403, detail="Admin access required")

    pool = get_db_pool_manager().pool

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            password_hash = hash_password(request.new_password)

            await cursor.execute("""
                UPDATE agent_user
                SET password_hash = %s,
                    updated_at = NOW()
                WHERE id = %s
                RETURNING username
            """, (password_hash, user_id))
            row = await cursor.fetchone()

            if not row:
                raise HTTPException(status_code=404, detail="User not found")

            # Revoke tokens issued before the reset. Otherwise anyone holding
            # an old token (e.g. after a credential leak, the usual reason for
            # an admin reset) stays logged in. This runs in the same
            # transaction as the password change.
            await cursor.execute("""
                DELETE FROM agent_user_tokens
                WHERE user_id = %s
            """, (user_id,))

            await conn.commit()

    return SuccessResponse(
        success=True,
        message=f"Password reset for user '{row[0]}'"
    )
|
||
|
||
|
||
@router.post("/api/v1/auth/logout", response_model=SuccessResponse)
async def user_logout(authorization: Optional[str] = Header(None)):
    """
    Log the caller out by deleting their user token.

    Succeeds even if the token does not exist in agent_user_tokens.

    Args:
        authorization: Bearer token to revoke.

    Returns:
        SuccessResponse: Logout confirmation.

    Raises:
        HTTPException: 401 when no token can be extracted from the header.
    """
    token = extract_api_key_from_auth(authorization)
    if not token:
        raise HTTPException(status_code=401, detail="Authorization header is required")

    db_pool = get_db_pool_manager().pool
    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute("DELETE FROM agent_user_tokens WHERE token = %s", (token,))
            await conn.commit()

    return SuccessResponse(success=True, message="Logged out successfully")
|
||
|
||
|
||
# ============== 分享管理 API ==============
|
||
|
||
@router.get("/api/v1/bots/{bot_uuid}/shares", response_model=BotSharesListResponse)
async def get_bot_shares(
    bot_uuid: str,
    authorization: Optional[str] = Header(None)
):
    """
    List every share of a bot, most recent first (bot owner only).

    Args:
        bot_uuid: Id of the bot.
        authorization: Bearer token of the bot's owner.

    Returns:
        BotSharesListResponse: One entry per share, including the recipient's
        username/email and the username of whoever created the share (None
        when that user no longer exists; LEFT JOIN).

    Raises:
        HTTPException: 401 when the token is invalid, 403 when the caller is
            not the bot's owner.
    """
    authenticated, caller_id, _ = await verify_user_auth(authorization)
    if not authenticated:
        raise HTTPException(status_code=401, detail="Unauthorized")

    if not await is_bot_owner(bot_uuid, caller_id):
        raise HTTPException(status_code=403, detail="Only bot owner can view shares")

    db_pool = get_db_pool_manager().pool
    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute("""
                SELECT s.id, s.bot_id, s.user_id, u.username, u.email, s.role, s.shared_at, s.expires_at, su.username
                FROM bot_shares s
                JOIN agent_user u ON s.user_id = u.id
                LEFT JOIN agent_user su ON s.shared_by = su.id
                WHERE s.bot_id = %s
                ORDER BY s.shared_at DESC
            """, (bot_uuid,))
            records = await cur.fetchall()

    share_list = []
    for (share_id, bot_id, target_id, target_name, target_email,
         role, shared_at, expires_at, sharer_name) in records:
        share_list.append(BotShareResponse(
            id=str(share_id),
            bot_id=str(bot_id),
            user_id=str(target_id),
            username=target_name,
            email=target_email,
            role=role,
            shared_at=shared_at.isoformat() if shared_at else "",
            expires_at=expires_at.isoformat() if expires_at else None,
            shared_by=None if sharer_name is None else str(sharer_name),
        ))

    return BotSharesListResponse(bot_id=bot_uuid, shares=share_list)
|
||
|
||
|
||
@router.post("/api/v1/bots/{bot_uuid}/shares", response_model=SuccessResponse)
async def add_bot_share(
    bot_uuid: str,
    request: BotShareCreate,
    authorization: Optional[str] = Header(None)
):
    """
    Share a bot with one or more users, or update existing shares (owner only).

    Existing (bot, user) shares are upserted: role, sharer and expiry are
    overwritten with the values from this request. Repeated ids in
    ``request.user_ids`` are collapsed, so each user is processed and counted
    once. Either all shares are written or none are.

    Args:
        bot_uuid: Id of the bot to share.
        request: Target user ids, role and optional expiry.
        authorization: Bearer token of the bot's owner.

    Returns:
        SuccessResponse: Message with the number of users the bot was shared with.

    Raises:
        HTTPException: 401 when the token is invalid; 403 when the caller is
            not the owner; 400 when a target user does not exist or is
            inactive; 500 when writing the shares fails (nothing is committed).
    """
    valid, user_id, _ = await verify_user_auth(authorization)

    if not valid:
        raise HTTPException(status_code=401, detail="Unauthorized")

    if not await is_bot_owner(bot_uuid, user_id):
        raise HTTPException(status_code=403, detail="Only bot owner can share")

    # Drop repeated ids while keeping request order, so the count is accurate.
    target_user_ids = list(dict.fromkeys(request.user_ids))

    pool = get_db_pool_manager().pool

    async with pool.connection() as conn:
        async with conn.cursor() as cursor:
            # Validate every target before writing anything.
            for target_user_id in target_user_ids:
                await cursor.execute("""
                    SELECT id FROM agent_user WHERE id = %s AND is_active = TRUE
                """, (target_user_id,))
                if not await cursor.fetchone():
                    raise HTTPException(
                        status_code=400,
                        detail=f"User {target_user_id} not found"
                    )

            # ON CONFLICT already turns re-shares into updates, so an error
            # here is a real failure. In PostgreSQL a failed statement aborts
            # the whole transaction; the old `except Exception: pass` made
            # every later insert fail too while still reporting success.
            try:
                for target_user_id in target_user_ids:
                    await cursor.execute("""
                        INSERT INTO bot_shares (bot_id, user_id, shared_by, role, expires_at)
                        VALUES (%s, %s, %s, %s, %s)
                        ON CONFLICT (bot_id, user_id) DO UPDATE SET
                            role = EXCLUDED.role,
                            shared_by = EXCLUDED.shared_by,
                            expires_at = EXCLUDED.expires_at
                    """, (bot_uuid, target_user_id, user_id, request.role, request.expires_at))
            except Exception as exc:
                logger.exception("Failed to share bot %s", bot_uuid)
                raise HTTPException(status_code=500, detail="Failed to share bot") from exc

            await conn.commit()

    return SuccessResponse(
        success=True,
        message=f"Shared with {len(target_user_ids)} user(s)"
    )
|
||
|
||
|
||
@router.delete("/api/v1/bots/{bot_uuid}/shares/{user_id}", response_model=SuccessResponse)
async def remove_bot_share(
    bot_uuid: str,
    user_id: str,
    authorization: Optional[str] = Header(None)
):
    """
    Revoke one user's access to a shared bot (bot owner only).

    Args:
        bot_uuid: Id of the bot.
        user_id: Id of the user whose share is removed.
        authorization: Bearer token of the bot's owner.

    Returns:
        SuccessResponse: Removal confirmation.

    Raises:
        HTTPException: 401 when the token is invalid; 403 when the caller is
            not the owner; 404 when the bot is not shared with that user.
    """
    authenticated, caller_id, _ = await verify_user_auth(authorization)
    if not authenticated:
        raise HTTPException(status_code=401, detail="Unauthorized")

    if not await is_bot_owner(bot_uuid, caller_id):
        raise HTTPException(status_code=403, detail="Only bot owner can remove shares")

    db_pool = get_db_pool_manager().pool
    async with db_pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute("""
                DELETE FROM bot_shares
                WHERE bot_id = %s AND user_id = %s
                RETURNING id
            """, (bot_uuid, user_id))
            removed = await cur.fetchone()

            if removed is None:
                raise HTTPException(status_code=404, detail="Share not found")

            await conn.commit()

    return SuccessResponse(success=True, message="Share removed successfully")
|
||
|
||
|