diff --git a/agent/checkpoint_manager.py b/agent/checkpoint_manager.py
index 3c5a796..d8afc34 100644
--- a/agent/checkpoint_manager.py
+++ b/agent/checkpoint_manager.py
@@ -5,7 +5,8 @@
 import asyncio
 import logging
 import os
-from typing import Optional
+from datetime import datetime, timedelta, timezone
+from typing import Optional, List, Dict, Any
 
 import aiosqlite
 from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
@@ -15,6 +16,9 @@ from utils.settings import (
     CHECKPOINT_WAL_MODE,
     CHECKPOINT_BUSY_TIMEOUT,
     CHECKPOINT_POOL_SIZE,
+    CHECKPOINT_CLEANUP_ENABLED,
+    CHECKPOINT_CLEANUP_INTERVAL_HOURS,
+    CHECKPOINT_CLEANUP_OLDER_THAN_DAYS,
 )
 
 logger = logging.getLogger('app')
@@ -38,6 +42,9 @@ class CheckpointerManager:
         self._closed = False
         self._pool_size = CHECKPOINT_POOL_SIZE
         self._db_path = CHECKPOINT_DB_PATH
+        # Background cleanup scheduler state
+        self._cleanup_task: Optional[asyncio.Task] = None
+        self._cleanup_stop_event = asyncio.Event()
 
     async def initialize(self) -> None:
         """初始化连接池"""
@@ -131,6 +138,19 @@ class CheckpointerManager:
         if self._closed:
             return
 
+        # Stop the background cleanup task and wait until it has really
+        # finished, so it can never touch the pool after it is closed.
+        cleanup_task, self._cleanup_task = self._cleanup_task, None
+        if cleanup_task is not None:
+            self._cleanup_stop_event.set()
+            cleanup_task.cancel()
+            try:
+                await cleanup_task
+            except asyncio.CancelledError:
+                pass
+            except Exception as e:
+                logger.warning(f"Cleanup task ended with error: {e}")
+
         async with self._lock:
             if self._closed:
                 return
@@ -160,6 +180,197 @@ class CheckpointerManager:
             "closed": self._closed
         }
 
+    # ============================================================
+    # Checkpoint cleanup
+    # ============================================================
+
+    async def get_all_thread_ids(self) -> List[str]:
+        """
+        Return every distinct thread_id stored in the checkpoints table.
+
+        Returns:
+            List[str]: all thread ids; empty when the manager is not initialized
+        """
+        if not self._initialized:
+            return []
+
+        # async with closes the connection on every exit path, including a
+        # failure while connecting or querying.
+        async with aiosqlite.connect(self._db_path) as conn:
+            cursor = await conn.execute(
+                "SELECT DISTINCT thread_id FROM checkpoints"
+            )
+            rows = await cursor.fetchall()
+        return [row[0] for row in rows]
+
+    @staticmethod
+    def _parse_checkpoint_ts(ts_str: str) -> datetime:
+        """Parse a checkpoint "ts" string into a timezone-aware datetime."""
+        # fromisoformat() rejects a trailing "Z" before Python 3.11.
+        parsed = datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
+        if parsed.tzinfo is None:
+            # Assume UTC so the value can be compared with the aware cutoff.
+            parsed = parsed.replace(tzinfo=timezone.utc)
+        return parsed
+
+    async def _read_last_activity(self, checkpointer, thread_id: str) -> Optional[datetime]:
+        """Read the "ts" of the first checkpoint listed for a thread, using the given checkpointer."""
+        config = {"configurable": {"thread_id": thread_id}}
+        async for item in checkpointer.alist(config=config, limit=1):
+            # Only the first listed item is used (expected to be the newest).
+            ts_str = item.checkpoint.get("ts") if item.checkpoint else None
+            return self._parse_checkpoint_ts(ts_str) if ts_str else None
+        return None
+
+    async def get_thread_last_activity(self, thread_id: str) -> Optional[datetime]:
+        """
+        Return the last activity time of a thread.
+
+        The time is taken from the "ts" field of the thread's latest checkpoint.
+
+        Args:
+            thread_id: thread id
+
+        Returns:
+            datetime: timezone-aware last activity time, or None when it cannot
+            be determined (not initialized, no checkpoint, no "ts", or an error)
+        """
+        if not self._initialized:
+            return None
+
+        checkpointer = await self.acquire_for_agent()
+        try:
+            return await self._read_last_activity(checkpointer, thread_id)
+        except Exception as e:
+            logger.warning(f"Error getting last activity for thread {thread_id}: {e}")
+            return None
+        finally:
+            await self.return_to_pool(checkpointer)
+
+    async def cleanup_old_threads(self, older_than_days: Optional[int] = None) -> Dict[str, Any]:
+        """
+        Delete threads that have been inactive for longer than the given number of days.
+
+        Args:
+            older_than_days: inactivity threshold in days; defaults to
+                CHECKPOINT_CLEANUP_OLDER_THAN_DAYS
+
+        Returns:
+            Dict: cleanup statistics
+                - threads_deleted: number of threads deleted
+                - threads_scanned: number of threads scanned
+                - cutoff_time: ISO-format cutoff time
+                - older_than_days: threshold that was applied
+        """
+        if older_than_days is None:
+            older_than_days = CHECKPOINT_CLEANUP_OLDER_THAN_DAYS
+
+        # Timezone-aware cutoff, so it compares cleanly with checkpoint timestamps
+        cutoff_time = datetime.now(timezone.utc) - timedelta(days=older_than_days)
+        logger.info(f"Starting checkpoint cleanup: removing threads inactive since {cutoff_time.isoformat()}")
+
+        all_thread_ids = await self.get_all_thread_ids()
+        threads_deleted = 0
+
+        if all_thread_ids:
+            # One pooled checkpointer serves both the lookup and the delete.
+            # Acquiring a second one per thread while holding this one could
+            # block forever when the pool is exhausted.
+            checkpointer = await self.acquire_for_agent()
+            try:
+                for thread_id in all_thread_ids:
+                    try:
+                        last_activity = await self._read_last_activity(checkpointer, thread_id)
+                        if last_activity is None or last_activity >= cutoff_time:
+                            continue
+                        # adelete_thread expects the thread id itself, not a config dict
+                        await checkpointer.adelete_thread(thread_id)
+                        threads_deleted += 1
+                        logger.debug(f"Deleted old thread: {thread_id} (last activity: {last_activity.isoformat()})")
+                    except Exception as e:
+                        logger.warning(f"Error processing thread {thread_id}: {e}")
+            finally:
+                await self.return_to_pool(checkpointer)
+
+        result = {
+            "threads_deleted": threads_deleted,
+            "threads_scanned": len(all_thread_ids),
+            "cutoff_time": cutoff_time.isoformat(),
+            "older_than_days": older_than_days
+        }
+
+        logger.info(f"Checkpoint cleanup completed: {result}")
+        return result
+
+    async def _cleanup_loop(self):
+        """Background loop: run one cleanup per interval until stopped or cancelled."""
+        interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600
+        if interval_seconds <= 0:
+            # A non-positive timeout would turn this loop into a busy loop.
+            logger.error(
+                f"Invalid CHECKPOINT_CLEANUP_INTERVAL_HOURS={CHECKPOINT_CLEANUP_INTERVAL_HOURS}, "
+                "cleanup scheduler not started"
+            )
+            return
+
+        logger.info(
+            f"Checkpoint cleanup scheduler started: "
+            f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, "
+            f"older_than={CHECKPOINT_CLEANUP_OLDER_THAN_DAYS} days"
+        )
+
+        try:
+            while not self._cleanup_stop_event.is_set():
+                try:
+                    # Sleep for one interval, waking up early on a stop request
+                    await asyncio.wait_for(
+                        self._cleanup_stop_event.wait(),
+                        timeout=interval_seconds
+                    )
+                    break  # stop requested
+                except asyncio.TimeoutError:
+                    pass  # interval elapsed, run a cleanup
+
+                try:
+                    await self.cleanup_old_threads()
+                except asyncio.CancelledError:
+                    raise
+                except Exception as e:
+                    # Keep the scheduler alive; the next interval retries.
+                    logger.error(f"Error in cleanup loop: {e}")
+        except asyncio.CancelledError:
+            logger.info("Cleanup task cancelled")
+            raise
+        finally:
+            logger.info("Checkpoint cleanup scheduler stopped")
+
+    def start_cleanup_scheduler(self):
+        """
+        Start the periodic background cleanup task.
+
+        Must be called while an event loop is running (e.g. from a coroutine);
+        otherwise nothing is started and an error is logged.
+        """
+        if not CHECKPOINT_CLEANUP_ENABLED:
+            logger.info("Checkpoint cleanup is disabled")
+            return
+
+        if self._cleanup_task is not None and not self._cleanup_task.done():
+            logger.warning("Cleanup scheduler is already running")
+            return
+
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            # A task created on a fresh, never-run loop would silently never execute.
+            logger.error("Cleanup scheduler not started: no running event loop")
+            return
+
+        # Allow a restart after a previous stop request
+        self._cleanup_stop_event.clear()
+        self._cleanup_task = loop.create_task(self._cleanup_loop())
+        logger.info("Cleanup scheduler task created")
+
 
 # 全局单例
 _global_manager: Optional[CheckpointerManager] = None
diff --git a/fastapi_app.py b/fastapi_app.py
index f5302fa..79def96 100644
--- a/fastapi_app.py
+++ b/fastapi_app.py
@@ -26,15 +26,33 @@ async def lifespan(app: FastAPI):
     """FastAPI 应用生命周期管理"""
     # 启动时初始化
     logger.info("Starting up...")
-    from agent.checkpoint_manager import init_global_checkpointer
+    from agent.checkpoint_manager import (
+        init_global_checkpointer,
+        get_checkpointer_manager,
+        close_global_checkpointer
+    )
+    from utils.settings import CHECKPOINT_CLEANUP_ENABLED
+
     await init_global_checkpointer()
     logger.info("Global checkpointer initialized")
 
+    # Start the checkpoint cleanup scheduler
+    if CHECKPOINT_CLEANUP_ENABLED:
+        manager = get_checkpointer_manager()
+        # Run one cleanup immediately at startup
+        try:
+            result = await manager.cleanup_old_threads()
+            logger.info(f"Startup cleanup completed: {result}")
+        except Exception as e:
+            logger.warning(f"Startup cleanup failed (non-fatal): {e}")
+        # Start the periodic cleanup scheduler
+        manager.start_cleanup_scheduler()
+        logger.info("Checkpoint cleanup scheduler started")
+
     yield
 
     # 关闭时清理
     logger.info("Shutting down...")
-    from agent.checkpoint_manager import close_global_checkpointer
     await close_global_checkpointer()
     logger.info("Global checkpointer closed")
 
diff --git a/utils/settings.py b/utils/settings.py
index e2340ae..abd4076 100644
--- a/utils/settings.py
+++ b/utils/settings.py
@@ -59,4 +59,16 @@ CHECKPOINT_BUSY_TIMEOUT = int(os.getenv("CHECKPOINT_BUSY_TIMEOUT", "10000"))
 
 # 同时可以持有的最大连接数
 CHECKPOINT_POOL_SIZE = int(os.getenv("CHECKPOINT_POOL_SIZE", "15"))
 
+# Checkpoint automatic cleanup settings
+# Whether automatic cleanup of old sessions is enabled ("1", "true", "yes", "on"; case-insensitive)
+CHECKPOINT_CLEANUP_ENABLED = os.getenv("CHECKPOINT_CLEANUP_ENABLED", "true").strip().lower() in ("1", "true", "yes", "on")
+
+# Cleanup interval (hours)
+# The cleanup task runs once every N hours
+CHECKPOINT_CLEANUP_INTERVAL_HOURS = int(os.getenv("CHECKPOINT_CLEANUP_INTERVAL_HOURS", "24"))
+
+# Inactivity threshold (days)
+# Threads inactive for more than N days are deleted
+CHECKPOINT_CLEANUP_OLDER_THAN_DAYS = int(os.getenv("CHECKPOINT_CLEANUP_OLDER_THAN_DAYS", "3"))
+