"""Global PostgreSQL checkpointer manager.

Uses a shared database connection pool.
"""
import logging
from typing import Optional

from psycopg_pool import AsyncConnectionPool
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver

logger = logging.getLogger('app')


class CheckpointerManager:
    """Global checkpointer manager backed by a shared PostgreSQL pool.

    Main features:
    1. Uses the shared connection pool provided by DBPoolManager.
    2. AsyncPostgresSaver natively supports connection pools and
       acquires/releases connections automatically.
    3. Connections never need to be returned manually, which avoids
       long-running requests holding on to a connection.
    """

    def __init__(self):
        # Shared pool handed in through initialize(); this class never owns it.
        self._pool: Optional[AsyncConnectionPool] = None
        self._initialized = False
        self._closed = False

    async def initialize(self, pool: AsyncConnectionPool) -> None:
        """Initialize the checkpointer using an externally supplied pool.

        Runs ``AsyncPostgresSaver.setup()`` once on a connection borrowed
        from ``pool``. Calling this on an already initialized manager is a
        no-op. Calling it after ``close()`` re-opens the manager.

        Args:
            pool: AsyncConnectionPool instance (from DBPoolManager).

        Raises:
            Exception: whatever the pool or ``setup()`` raised; the error is
                logged and re-raised unchanged.
        """
        if self._initialized:
            return

        self._pool = pool
        logger.info("Initializing PostgreSQL checkpointer (using shared connection pool)...")

        try:
            # Create the schema (autocommit mode is required to run
            # CREATE INDEX CONCURRENTLY).
            async with self._pool.connection() as conn:
                previous_autocommit = conn.autocommit
                await conn.set_autocommit(True)
                try:
                    checkpointer = AsyncPostgresSaver(conn=conn)
                    await checkpointer.setup()
                finally:
                    # The connection goes back to the *shared* pool, so put
                    # its autocommit flag back the way we found it instead
                    # of leaking the change to the next borrower.
                    if not conn.closed and conn.autocommit != previous_autocommit:
                        await conn.set_autocommit(previous_autocommit)

            self._initialized = True
            # A successful (re-)initialization re-opens a manager that was
            # closed earlier; otherwise the accessors below would keep
            # raising "closed" forever on the global singleton.
            self._closed = False
            logger.info("PostgreSQL checkpointer initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize PostgreSQL checkpointer: {e}")
            raise

    @property
    def checkpointer(self) -> AsyncPostgresSaver:
        """Return an AsyncPostgresSaver bound to the shared pool.

        A new saver object is constructed on every access.

        Raises:
            RuntimeError: if the manager is closed or not yet initialized.
        """
        if self._closed:
            raise RuntimeError("CheckpointerManager is closed")

        if not self._initialized:
            raise RuntimeError("CheckpointerManager not initialized, call initialize() first")

        return AsyncPostgresSaver(conn=self._pool)

    @property
    def pool(self) -> AsyncConnectionPool:
        """Return the connection pool (for sharing with other managers).

        Raises:
            RuntimeError: if the manager is closed or not yet initialized.
        """
        if self._closed:
            raise RuntimeError("CheckpointerManager is closed")

        if not self._initialized:
            raise RuntimeError("CheckpointerManager not initialized, call initialize() first")

        return self._pool

    async def close(self) -> None:
        """Close the checkpointer manager.

        The connection pool is managed by DBPoolManager and is NOT closed
        here. Calling this more than once is a no-op.
        """
        if self._closed:
            return

        logger.info("Closing CheckpointerManager...")
        self._closed = True
        self._initialized = False
        logger.info("CheckpointerManager closed (pool managed by DBPoolManager)")


# Global singleton
_global_manager: Optional[CheckpointerManager] = None


def get_checkpointer_manager() -> CheckpointerManager:
    """Return the global CheckpointerManager singleton, creating it lazily."""
    global _global_manager
    if _global_manager is None:
        _global_manager = CheckpointerManager()
    return _global_manager


async def init_global_checkpointer(pool: AsyncConnectionPool) -> None:
    """Initialize the global checkpointer manager.

    Args:
        pool: AsyncConnectionPool instance (from DBPoolManager).
    """
    manager = get_checkpointer_manager()
    await manager.initialize(pool)


async def close_global_checkpointer() -> None:
    """Close the global checkpointer manager, if one was ever created."""
    # Read-only access to the module global: no ``global`` statement needed.
    if _global_manager is not None:
        await _global_manager.close()