""" 全局 PostgreSQL Checkpointer 管理器 使用 psycopg_pool 连接池,AsyncPostgresSaver 原生支持 """ import asyncio import logging from typing import Optional from psycopg_pool import AsyncConnectionPool from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from utils.settings import ( CHECKPOINT_DB_URL, CHECKPOINT_POOL_SIZE, CHECKPOINT_CLEANUP_ENABLED, CHECKPOINT_CLEANUP_INTERVAL_HOURS, CHECKPOINT_CLEANUP_INACTIVE_DAYS, ) logger = logging.getLogger('app') class CheckpointerManager: """ 全局 Checkpointer 管理器,使用 PostgreSQL 连接池 主要功能: 1. 使用 psycopg_pool.AsyncConnectionPool 管理连接 2. AsyncPostgresSaver 原生支持连接池,自动获取/释放连接 3. 无需手动归还,避免长请求占用连接 4. 基于 SQL 查询的清理机制 5. 优雅关闭机制 """ def __init__(self): self._pool: Optional[AsyncConnectionPool] = None self._initialized = False self._closed = False # 清理调度任务 self._cleanup_task: Optional[asyncio.Task] = None self._cleanup_stop_event = asyncio.Event() async def initialize(self) -> None: """初始化连接池""" if self._initialized: return logger.info( f"Initializing PostgreSQL checkpointer pool: " f"URL={CHECKPOINT_DB_URL}, size={CHECKPOINT_POOL_SIZE}" ) try: # 创建 psycopg 连接池 self._pool = AsyncConnectionPool( CHECKPOINT_DB_URL, min_size=1, max_size=CHECKPOINT_POOL_SIZE, open=False, ) # 打开连接池 await self._pool.open() # 创建表结构(需要 autocommit 模式来执行 CREATE INDEX CONCURRENTLY) async with self._pool.connection() as conn: await conn.set_autocommit(True) checkpointer = AsyncPostgresSaver(conn=conn) await checkpointer.setup() self._initialized = True logger.info("PostgreSQL checkpointer pool initialized successfully") except Exception as e: logger.error(f"Failed to initialize PostgreSQL checkpointer pool: {e}") raise @property def checkpointer(self) -> AsyncPostgresSaver: """获取全局 AsyncPostgresSaver 实例""" if self._closed: raise RuntimeError("CheckpointerManager is closed") if not self._initialized: raise RuntimeError("CheckpointerManager not initialized, call initialize() first") return AsyncPostgresSaver(conn=self._pool) async def close(self) -> None: """关闭连接池""" if self._closed: return # 停止清理任务 if self._cleanup_task is not None: self._cleanup_stop_event.set() try: self._cleanup_task.cancel() await asyncio.sleep(0.1) except asyncio.CancelledError: pass self._cleanup_task = None if self._closed: return logger.info("Closing CheckpointerManager...") if self._pool is not None: await self._pool.close() self._pool = None self._closed = True self._initialized = False logger.info("CheckpointerManager closed") # ============================================================ # Checkpoint 清理方法(基于 PostgreSQL SQL 查询) # ============================================================ async def cleanup_old_dbs(self, inactive_days: int = None) -> dict: """ 清理旧的 checkpoint 记录 删除 N 天未活动的 thread 的所有 checkpoint。 Args: inactive_days: 删除 N 天未活动的 thread,默认使用配置值 Returns: Dict: 清理统计信息 """ if inactive_days is None: inactive_days = CHECKPOINT_CLEANUP_INACTIVE_DAYS logger.info(f"Starting checkpoint cleanup: removing threads inactive for {inactive_days} days") if self._pool is None: logger.warning("Connection pool not initialized, skipping cleanup") return {"deleted": 0, "inactive_days": inactive_days} try: # 从池中获取连接执行清理 async with self._pool.connection() as conn: async with conn.cursor() as cursor: # 查找不活跃的 thread query_find_inactive = """ SELECT DISTINCT thread_id FROM checkpoints WHERE (checkpoint->>'ts')::timestamp < NOW() - INTERVAL '%s days' """ await cursor.execute(query_find_inactive, (inactive_days,)) inactive_threads = await cursor.fetchall() inactive_thread_ids = [row[0] for row in inactive_threads] if not inactive_thread_ids: logger.info("No inactive threads found") return {"deleted": 0, "threads_deleted": 0, "inactive_days": inactive_days} # 删除不活跃 thread 的 checkpoints placeholders = ','.join(['%s'] * len(inactive_thread_ids)) query_delete_checkpoints = f""" DELETE FROM checkpoints WHERE thread_id IN ({placeholders}) """ await cursor.execute(query_delete_checkpoints, inactive_thread_ids) deleted_checkpoints = cursor.rowcount # 同时清理 checkpoint_writes query_delete_writes = f""" DELETE FROM checkpoint_writes WHERE thread_id IN ({placeholders}) """ await cursor.execute(query_delete_writes, inactive_thread_ids) deleted_writes = cursor.rowcount logger.info( f"Checkpoint cleanup completed: " f"deleted {deleted_checkpoints} checkpoints, {deleted_writes} writes " f"from {len(inactive_thread_ids)} inactive threads" ) return { "deleted": deleted_checkpoints + deleted_writes, "threads_deleted": len(inactive_thread_ids), "inactive_days": inactive_days } except Exception as e: logger.error(f"Error during checkpoint cleanup: {e}") import traceback logger.error(traceback.format_exc()) return {"deleted": 0, "error": str(e), "inactive_days": inactive_days} async def _cleanup_loop(self): """后台清理循环""" interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600 logger.info( f"Checkpoint cleanup scheduler started: " f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, " f"inactive_days={CHECKPOINT_CLEANUP_INACTIVE_DAYS}" ) while not self._cleanup_stop_event.is_set(): try: # 等待间隔时间或停止信号 try: await asyncio.wait_for( self._cleanup_stop_event.wait(), timeout=interval_seconds ) # 收到停止信号,退出循环 break except asyncio.TimeoutError: # 超时,执行清理任务 pass if self._cleanup_stop_event.is_set(): break # 执行清理 await self.cleanup_old_dbs() except asyncio.CancelledError: logger.info("Cleanup task cancelled") break except Exception as e: logger.error(f"Error in cleanup loop: {e}") logger.info("Checkpoint cleanup scheduler stopped") def start_cleanup_scheduler(self): """启动后台定时清理任务""" if not CHECKPOINT_CLEANUP_ENABLED: logger.info("Checkpoint cleanup is disabled") return if self._cleanup_task is not None and not self._cleanup_task.done(): logger.warning("Cleanup scheduler is already running") return try: loop = asyncio.get_running_loop() except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) self._cleanup_task = loop.create_task(self._cleanup_loop()) logger.info("Cleanup scheduler task created") # 全局单例 _global_manager: Optional[CheckpointerManager] = None def get_checkpointer_manager() -> CheckpointerManager: """获取全局 CheckpointerManager 单例""" global _global_manager if _global_manager is None: _global_manager = CheckpointerManager() return _global_manager async def init_global_checkpointer() -> None: """初始化全局 checkpointer 管理器""" manager = get_checkpointer_manager() await manager.initialize() async def close_global_checkpointer() -> None: """关闭全局 checkpointer 管理器""" global _global_manager if _global_manager is not None: await _global_manager.close()