207 lines
7.7 KiB
Python
207 lines
7.7 KiB
Python
#!/usr/bin/env python3
|
||
# /// script
|
||
# requires-python = ">=3.10"
|
||
# dependencies = [
|
||
# "requests>=2.31.0",
|
||
# ]
|
||
# ///
|
||
"""使用 Agnes AI 生成视频(异步任务:创建 -> 轮询 -> 取结果)。
|
||
支持文生视频、图生视频、多图视频、关键帧动画。
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import shutil
|
||
import sys
|
||
import time
|
||
import urllib.parse
|
||
import urllib.request
|
||
|
||
import requests
|
||
|
||
CREATE_URL = "https://apihub.agnes-ai.com/v1/videos"
|
||
QUERY_URL = "https://apihub.agnes-ai.com/agnesapi"
|
||
DEFAULT_MODEL = "agnes-video-v2.0"
|
||
|
||
|
||
def log(msg):
|
||
"""进度信息打到 stderr,保持 stdout 只输出最终结果。"""
|
||
print(msg, file=sys.stderr, flush=True)
|
||
|
||
|
||
def normalize_frames(nf):
|
||
"""num_frames 必须 <=441 且满足 8n+1,规整到最接近的合法值。"""
|
||
nf = max(9, min(441, int(nf)))
|
||
n = round((nf - 1) / 8)
|
||
nf = 8 * n + 1
|
||
return max(9, min(441, nf))
|
||
|
||
|
||
def create_task(payload, api_key, retries=3):
|
||
headers = {
|
||
"Content-Type": "application/json",
|
||
"Authorization": f"Bearer {api_key}",
|
||
}
|
||
last_err = None
|
||
for attempt in range(1, retries + 1):
|
||
try:
|
||
resp = requests.post(CREATE_URL, headers=headers, json=payload, timeout=120)
|
||
resp.raise_for_status()
|
||
return resp.json()
|
||
except (requests.exceptions.SSLError,
|
||
requests.exceptions.ConnectionError,
|
||
requests.exceptions.Timeout) as e:
|
||
last_err = e
|
||
if attempt < retries:
|
||
time.sleep(attempt * 2)
|
||
continue
|
||
raise
|
||
raise last_err
|
||
|
||
|
||
def query_result(video_id, api_key, model=None):
|
||
"""用 video_id 查询任务结果(推荐方式)。"""
|
||
params = f"?video_id={urllib.parse.quote(video_id)}"
|
||
if model:
|
||
params += f"&model_name={urllib.parse.quote(model)}"
|
||
req = urllib.request.Request(
|
||
QUERY_URL + params,
|
||
headers={"Authorization": f"Bearer {api_key}"},
|
||
)
|
||
with urllib.request.urlopen(req, timeout=60) as r:
|
||
return json.loads(r.read())
|
||
|
||
|
||
def download(url, path):
|
||
os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
|
||
with urllib.request.urlopen(url, timeout=300) as r, open(path, "wb") as f:
|
||
shutil.copyfileobj(r, f)
|
||
|
||
|
||
def main():
|
||
parser = argparse.ArgumentParser(description="使用 Agnes AI 生成视频。")
|
||
parser.add_argument("--prompt", required=True, help="视频内容的文本描述(必选)")
|
||
parser.add_argument("--model", default=DEFAULT_MODEL, help=f"模型 ID,默认 {DEFAULT_MODEL}")
|
||
parser.add_argument("--image", action="append", default=None,
|
||
help="参考图 URL(图生视频/多图/关键帧,可多次指定)")
|
||
parser.add_argument("--keyframes", action="store_true",
|
||
help="关键帧动画模式(配合多张 --image 使用)")
|
||
parser.add_argument("--width", type=int, default=1152, help="视频宽度,默认 1152")
|
||
parser.add_argument("--height", type=int, default=768, help="视频高度,默认 768")
|
||
parser.add_argument("--num-frames", type=int, default=121,
|
||
help="帧数,<=441 且满足 8n+1,默认 121(约 5 秒)")
|
||
parser.add_argument("--frame-rate", type=float, default=24, help="帧率 1-60,默认 24")
|
||
parser.add_argument("--duration", type=float,
|
||
help="目标时长(秒),会按帧率换算 num_frames(覆盖 --num-frames)")
|
||
parser.add_argument("--negative-prompt", help="负向提示词,描述要避免的内容")
|
||
parser.add_argument("--seed", type=int, help="随机种子,用于结果复现")
|
||
parser.add_argument("--api-key", help="Agnes API Key(也可用 AGNES_API_KEY 环境变量)")
|
||
parser.add_argument("--save", help="下载视频并保存到本地路径")
|
||
parser.add_argument("--poll-interval", type=float, default=5, help="轮询间隔秒数,默认 5")
|
||
parser.add_argument("--max-wait", type=float, default=600, help="最大等待秒数,默认 600")
|
||
|
||
args = parser.parse_args()
|
||
|
||
api_key = args.api_key or os.environ.get("AGNES_API_KEY")
|
||
if not api_key:
|
||
print("ERROR: 缺少 API Key,请用 --api-key 或设置 AGNES_API_KEY 环境变量。")
|
||
sys.exit(1)
|
||
|
||
# 计算帧数
|
||
if args.duration:
|
||
num_frames = normalize_frames(args.duration * args.frame_rate)
|
||
else:
|
||
num_frames = normalize_frames(args.num_frames)
|
||
|
||
# 组装请求体
|
||
payload = {
|
||
"model": args.model,
|
||
"prompt": args.prompt,
|
||
"width": args.width,
|
||
"height": args.height,
|
||
"num_frames": num_frames,
|
||
"frame_rate": args.frame_rate,
|
||
}
|
||
if args.negative_prompt:
|
||
payload["negative_prompt"] = args.negative_prompt
|
||
if args.seed is not None:
|
||
payload["seed"] = args.seed
|
||
|
||
images = args.image or []
|
||
# 视频接口的参考图只支持可公网访问的 URL(不支持本地文件 / Base64)
|
||
for u in images:
|
||
if not (u.startswith("http://") or u.startswith("https://")):
|
||
print("ERROR: 视频生成的参考图必须是可公网访问的图片 URL(http/https),"
|
||
"不支持本地文件路径或 Base64。")
|
||
print(f" 问题输入: {u}")
|
||
print(" 建议:先用 generate_image.py 生成图片拿到其公网 URL,"
|
||
"或把本地图片上传到图床/对象存储后再用其 URL 传入。")
|
||
sys.exit(1)
|
||
|
||
if args.keyframes:
|
||
# 关键帧动画:extra_body.image + mode=keyframes
|
||
payload["extra_body"] = {"image": images, "mode": "keyframes"}
|
||
elif len(images) == 1:
|
||
# 图生视频:顶层 image
|
||
payload["image"] = images[0]
|
||
elif len(images) >= 2:
|
||
# 多图视频:extra_body.image
|
||
payload["extra_body"] = {"image": images}
|
||
|
||
# 1) 创建任务
|
||
log(f"创建视频任务({num_frames} 帧 @ {args.frame_rate}fps ≈ {num_frames/args.frame_rate:.1f}s)...")
|
||
try:
|
||
task = create_task(payload, api_key)
|
||
except requests.exceptions.RequestException as e:
|
||
print(f"ERROR: 创建任务失败: {e}")
|
||
if getattr(e, "response", None) is not None:
|
||
print(f"Response body: {e.response.text[:500]}")
|
||
sys.exit(1)
|
||
|
||
video_id = task.get("video_id")
|
||
task_id = task.get("task_id") or task.get("id")
|
||
if not video_id and not task_id:
|
||
print(f"ERROR: 创建任务响应缺少 video_id/task_id: {json.dumps(task)[:500]}")
|
||
sys.exit(1)
|
||
log(f"任务已创建 video_id={video_id} task_id={task_id} status={task.get('status')}")
|
||
|
||
# 2) 轮询结果
|
||
start = time.time()
|
||
video_url = None
|
||
while time.time() - start < args.max_wait:
|
||
time.sleep(args.poll_interval)
|
||
try:
|
||
data = query_result(video_id or task_id, api_key, model=args.model)
|
||
except Exception as e:
|
||
log(f" 查询出错(将重试): {e}")
|
||
continue
|
||
status = data.get("status")
|
||
progress = data.get("progress", 0)
|
||
log(f" 状态={status} 进度={progress}%")
|
||
if status == "completed":
|
||
video_url = data.get("remixed_from_video_id") # 文档:该字段为最终视频 URL
|
||
break
|
||
if status == "failed":
|
||
print(f"ERROR: 视频生成失败: {data.get('error')}")
|
||
sys.exit(1)
|
||
|
||
if not video_url:
|
||
print(f"ERROR: 等待超时({args.max_wait}s)或未返回视频 URL。")
|
||
sys.exit(1)
|
||
|
||
# 3) 输出结果
|
||
print(f"MEDIA_URL: {video_url}")
|
||
if args.save:
|
||
try:
|
||
log("下载视频中 ...")
|
||
download(video_url, args.save)
|
||
print(f"SAVED: {args.save}")
|
||
except Exception as e:
|
||
print(f"ERROR: 下载保存失败: {e}")
|
||
sys.exit(1)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|