327 lines
12 KiB
Python
327 lines
12 KiB
Python
"""
|
||
全局定时任务调度器
|
||
|
||
扫描所有 projects/robot/{bot_id}/users/{user_id}/tasks.yaml 文件,
|
||
找到到期的任务并调用 create_agent_and_generate_response 执行。
|
||
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import time
|
||
import yaml
|
||
import aiohttp
|
||
import json
|
||
from datetime import datetime, timezone, timedelta
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
|
||
# Module-level logger; emits under the application-wide 'app' logger name.
logger = logging.getLogger('app')
|
||
|
||
|
||
class ScheduleExecutor:
    """Global scheduled-task executor, run as an asyncio background task.

    Periodically scans every ``projects/robot/{bot_id}/users/{user_id}/tasks.yaml``,
    finds tasks whose ``next_run_at`` is due, triggers the agent API for each,
    writes an execution log entry and then reschedules (cron) or completes
    (one-shot) the task.
    """

    def __init__(self, scan_interval: int = 60, max_concurrent: int = 5):
        """
        Args:
            scan_interval: Seconds to wait between two scans of the task files.
            max_concurrent: Maximum number of tasks allowed to execute at once.
        """
        self._scan_interval = scan_interval
        self._max_concurrent = max_concurrent
        self._task: Optional[asyncio.Task] = None
        self._stop_event = asyncio.Event()
        self._executing_tasks: set = set()  # IDs of in-flight tasks, prevents double-triggering
        self._semaphore: Optional[asyncio.Semaphore] = None
        # Strong references to spawned execution tasks: the event loop keeps
        # only weak references to tasks, so without this an in-flight task
        # could be garbage-collected before it finishes.
        self._spawned_tasks: set = set()

    def start(self):
        """Start the scheduler (no-op if it is already running)."""
        if self._task is not None and not self._task.done():
            logger.warning("Schedule executor is already running")
            return

        self._stop_event.clear()
        self._semaphore = asyncio.Semaphore(self._max_concurrent)
        self._task = asyncio.create_task(self._scan_loop())
        logger.info(
            f"Schedule executor started: interval={self._scan_interval}s, "
            f"max_concurrent={self._max_concurrent}"
        )

    async def stop(self):
        """Stop the scheduler and wait for the scan loop to exit."""
        self._stop_event.set()
        if self._task and not self._task.done():
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
        logger.info("Schedule executor stopped")

    async def _scan_loop(self):
        """Main loop: scan once, then sleep until the next round or stop."""
        while not self._stop_event.is_set():
            try:
                await self._scan_and_execute()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Schedule scan error: {e}")

            # Wait for either the next scan interval or the stop signal.
            try:
                await asyncio.wait_for(
                    self._stop_event.wait(),
                    timeout=self._scan_interval
                )
                break  # Stop signal received
            except asyncio.TimeoutError:
                pass  # Timed out: continue with the next scan round

    async def _scan_and_execute(self):
        """Scan every tasks.yaml and trigger execution of due tasks."""
        now = datetime.now(timezone.utc)
        robot_dir = Path("projects/robot")

        if not robot_dir.exists():
            return

        tasks_files = list(robot_dir.glob("*/users/*/tasks.yaml"))
        if not tasks_files:
            return

        for tasks_file in tasks_files:
            try:
                with open(tasks_file, 'r', encoding='utf-8') as f:
                    data = yaml.safe_load(f)

                if not data or not data.get("tasks"):
                    continue

                # Extract bot_id and user_id from the path, whose shape is
                # .../projects/robot/{bot_id}/users/{user_id}/tasks.yaml
                parts = tasks_file.parts
                bot_id = parts[-4]
                user_id = parts[-2]

                for task in data["tasks"]:
                    if task.get("status") != "active":
                        continue
                    if task["id"] in self._executing_tasks:
                        continue

                    next_run_str = task.get("next_run_at")
                    if not next_run_str:
                        continue

                    try:
                        next_run = datetime.fromisoformat(next_run_str)
                        if next_run.tzinfo is None:
                            # Naive timestamps are interpreted as UTC.
                            next_run = next_run.replace(tzinfo=timezone.utc)
                    except (ValueError, TypeError):
                        logger.warning(f"Invalid next_run_at for task {task['id']}: {next_run_str}")
                        continue

                    if next_run <= now:
                        # Due. Reserve the id synchronously so nothing can
                        # double-trigger this task before its coroutine has
                        # had a chance to start running.
                        self._executing_tasks.add(task["id"])
                        exec_task = asyncio.create_task(
                            self._execute_task(bot_id, user_id, task, tasks_file)
                        )
                        # Keep a strong reference until the task completes.
                        self._spawned_tasks.add(exec_task)
                        exec_task.add_done_callback(self._spawned_tasks.discard)

            except Exception as e:
                logger.error(f"Error reading {tasks_file}: {e}")

    async def _execute_task(self, bot_id: str, user_id: str, task: dict, tasks_file: Path):
        """Execute a single due task: call the agent, log, reschedule."""
        task_id = task["id"]
        self._executing_tasks.add(task_id)
        start_time = time.time()

        try:
            async with self._semaphore:
                logger.info(f"Executing scheduled task: {task_id} ({task.get('name', '')}) for bot={bot_id} user={user_id}")

                # Call the agent.
                response_text = await self._call_agent_v3(bot_id, user_id, task)

                # Write the execution log entry.
                duration_ms = int((time.time() - start_time) * 1000)
                self._write_log(bot_id, user_id, task, response_text, "success", duration_ms)

                # Update tasks.yaml (reschedule or mark done).
                self._update_task_after_execution(task_id, tasks_file)

                logger.info(f"Task {task_id} completed in {duration_ms}ms")

        except Exception as e:
            duration_ms = int((time.time() - start_time) * 1000)
            logger.error(f"Task {task_id} execution failed: {e}")
            self._write_log(bot_id, user_id, task, f"ERROR: {e}", "error", duration_ms)
            # Advance next_run_at even on failure to avoid endless retries.
            self._update_task_after_execution(task_id, tasks_file)
        finally:
            self._executing_tasks.discard(task_id)

    async def _call_agent_v2(self, bot_id: str, user_id: str, task: dict) -> str:
        """Call the /api/v2/chat/completions endpoint over HTTP.

        Returns the assistant message content; raises RuntimeError on any
        non-200 response.
        """
        from utils.fastapi_utils import generate_v2_auth_token

        url = "http://127.0.0.1:8001/api/v2/chat/completions"
        auth_token = generate_v2_auth_token(bot_id)

        payload = {
            "messages": [{"role": "user", "content": task["message"]}],
            "stream": False,
            "bot_id": bot_id,
            "tool_response": False,
            "session_id": f"schedule_{task['id']}",
            "user_identifier": user_id,
        }

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {auth_token}",
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload, headers=headers, timeout=aiohttp.ClientTimeout(total=300)) as resp:
                if resp.status != 200:
                    body = await resp.text()
                    raise RuntimeError(f"API returned {resp.status}: {body}")
                data = await resp.json()

        return data["choices"][0]["message"]["content"]

    async def _call_agent_v3(self, bot_id: str, user_id: str, task: dict) -> str:
        """Call the /api/v3/chat/completions endpoint over HTTP
        (the server reads the bot configuration from the database).

        Returns the assistant message content; raises RuntimeError on any
        non-200 response.
        """
        url = "http://127.0.0.1:8001/api/v3/chat/completions"

        payload = {
            "messages": [{"role": "user", "content": task["message"]}],
            "stream": False,
            "bot_id": bot_id,
            "session_id": f"schedule_{task['id']}",
            "user_identifier": user_id,
        }

        headers = {
            "Content-Type": "application/json",
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload, headers=headers, timeout=aiohttp.ClientTimeout(total=300)) as resp:
                if resp.status != 200:
                    body = await resp.text()
                    raise RuntimeError(f"API returned {resp.status}: {body}")
                data = await resp.json()

        return data["choices"][0]["message"]["content"]

    def _update_task_after_execution(self, task_id: str, tasks_file: Path):
        """Update tasks.yaml after an execution attempt.

        Records last_executed_at / execution_count, marks one-shot tasks as
        done and computes the next run time for cron tasks. Errors are logged
        and swallowed so a broken file cannot kill the scheduler.
        """
        try:
            with open(tasks_file, 'r', encoding='utf-8') as f:
                data = yaml.safe_load(f)

            if not data or not data.get("tasks"):
                return

            now_utc = datetime.now(timezone.utc).isoformat()

            for task in data["tasks"]:
                if task["id"] != task_id:
                    continue

                task["last_executed_at"] = now_utc
                task["execution_count"] = task.get("execution_count", 0) + 1

                if task["type"] == "once":
                    task["status"] = "done"
                    task["next_run_at"] = None
                elif task["type"] == "cron" and task.get("schedule"):
                    # Compute the next execution time.
                    task["next_run_at"] = self._compute_next_run(
                        task["schedule"],
                        task.get("timezone", "UTC")
                    )
                break

            with open(tasks_file, 'w', encoding='utf-8') as f:
                yaml.dump(data, f, allow_unicode=True, default_flow_style=False, sort_keys=False)

        except Exception as e:
            logger.error(f"Failed to update task {task_id}: {e}")

    def _compute_next_run(self, schedule: str, tz: str) -> str:
        """Compute the next UTC run time (ISO string) for a cron task.

        The zone is resolved via the stdlib zoneinfo database so DST
        transitions are handled correctly; see _resolve_timezone for the
        fixed-offset fallback.
        """
        from croniter import croniter

        tzinfo = self._resolve_timezone(tz)

        # croniter evaluates the expression against naive local time;
        # localize the result back afterwards.
        now_local = datetime.now(tzinfo).replace(tzinfo=None)
        cron = croniter(schedule, now_local)
        next_local = cron.get_next(datetime)

        next_utc = next_local.replace(tzinfo=tzinfo).astimezone(timezone.utc)
        return next_utc.isoformat()

    @staticmethod
    def _resolve_timezone(tz: str):
        """Map an IANA zone name to a tzinfo, preferring zoneinfo.

        Falls back to a small fixed-offset table (standard-time offsets, no
        DST; unknown zones treated as UTC) when the zone database is
        unavailable on the host.
        """
        try:
            from zoneinfo import ZoneInfo
            return ZoneInfo(tz)
        except Exception:
            # Fallback: standard-time offsets only.
            tz_offsets = {
                'Asia/Shanghai': 8,
                'Asia/Tokyo': 9,
                'UTC': 0,
                'America/New_York': -5,
                'America/Los_Angeles': -8,
                'Europe/London': 0,
                'Europe/Berlin': 1,
            }
            return timezone(timedelta(hours=tz_offsets.get(tz, 0)))

    def _write_log(self, bot_id: str, user_id: str, task: dict,
                   response: str, status: str, duration_ms: int):
        """Append an execution log entry to the per-user YAML log file."""
        logs_dir = Path("projects/robot") / bot_id / "users" / user_id / "task_logs"
        logs_dir.mkdir(parents=True, exist_ok=True)
        log_file = logs_dir / "execution.log"

        log_entry = {
            "task_id": task["id"],
            "task_name": task.get("name", ""),
            "executed_at": datetime.now(timezone.utc).isoformat(),
            "status": status,
            "response": response[:2000] if response else "",  # Truncate very long responses
            "duration_ms": duration_ms,
        }

        # The log file holds a YAML list; read, append, rewrite.
        existing_logs = []
        if log_file.exists():
            try:
                with open(log_file, 'r', encoding='utf-8') as f:
                    existing_logs = yaml.safe_load(f) or []
            except Exception:
                # Corrupt or unreadable log: start over rather than crash.
                existing_logs = []

        existing_logs.append(log_entry)

        # Keep only the most recent 100 entries.
        if len(existing_logs) > 100:
            existing_logs = existing_logs[-100:]

        with open(log_file, 'w', encoding='utf-8') as f:
            yaml.dump(existing_logs, f, allow_unicode=True, default_flow_style=False, sort_keys=False)
|
||
|
||
|
||
# Global singleton instance
_executor: Optional[ScheduleExecutor] = None


def get_schedule_executor() -> ScheduleExecutor:
    """Return the process-wide scheduler, creating it on first use."""
    global _executor
    if _executor is not None:
        return _executor
    from utils.settings import SCHEDULE_SCAN_INTERVAL, SCHEDULE_MAX_CONCURRENT
    _executor = ScheduleExecutor(
        scan_interval=SCHEDULE_SCAN_INTERVAL,
        max_concurrent=SCHEDULE_MAX_CONCURRENT,
    )
    return _executor
|