115 lines
3.5 KiB
Python
115 lines
3.5 KiB
Python
"""
|
||
全局 PostgreSQL Checkpointer 管理器
|
||
使用共享的数据库连接池
|
||
"""
|
||
import logging
|
||
from typing import Optional
|
||
|
||
from psycopg_pool import AsyncConnectionPool
|
||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||
|
||
# Shared application logger (the "app" logger is configured elsewhere in the app).
logger = logging.getLogger("app")
|
||
|
||
|
||
class CheckpointerManager:
    """Global checkpointer manager backed by a shared PostgreSQL connection pool.

    Main points:
    1. Uses the shared connection pool owned by DBPoolManager; this manager
       never closes the pool.
    2. AsyncPostgresSaver accepts a connection pool directly and acquires /
       releases connections on its own.
    3. No manual connection hand-back is needed, so long-running requests do
       not pin a connection.
    """

    def __init__(self):
        self._pool: Optional[AsyncConnectionPool] = None
        self._initialized = False
        self._closed = False

    async def initialize(self, pool: AsyncConnectionPool) -> None:
        """Initialize the checkpointer (create its tables) using an external pool.

        A no-op while already initialized. Calling it again after ``close()``
        re-opens the manager.

        Args:
            pool: AsyncConnectionPool instance (provided by DBPoolManager).

        Raises:
            Exception: anything raised while borrowing a connection or by
                ``AsyncPostgresSaver.setup()`` is logged and re-raised; the
                manager then stays uninitialized.
        """
        if self._initialized:
            return

        logger.info("Initializing PostgreSQL checkpointer (using shared connection pool)...")

        try:
            # Creating the schema needs autocommit mode because setup() runs
            # CREATE INDEX CONCURRENTLY.
            async with pool.connection() as conn:
                previous_autocommit = conn.autocommit
                await conn.set_autocommit(True)
                try:
                    await AsyncPostgresSaver(conn=conn).setup()
                finally:
                    # The pool is shared with other components: return the
                    # connection in the autocommit mode it was borrowed in, so
                    # nobody else silently gets per-statement commits.
                    if not conn.closed:
                        await conn.set_autocommit(previous_autocommit)
        except Exception as e:
            # logger.exception keeps the traceback; same message text as before.
            logger.exception("Failed to initialize PostgreSQL checkpointer: %s", e)
            raise

        # Only remember the pool once setup has succeeded.
        self._pool = pool
        self._initialized = True
        # Re-open after a previous close(); otherwise initialize() would report
        # success while every accessor kept raising "closed".
        self._closed = False
        logger.info("PostgreSQL checkpointer initialized successfully")

    def _ensure_ready(self) -> None:
        """Raise RuntimeError unless the manager is initialized and not closed."""
        if self._closed:
            raise RuntimeError("CheckpointerManager is closed")
        if not self._initialized:
            raise RuntimeError("CheckpointerManager not initialized, call initialize() first")

    @property
    def checkpointer(self) -> AsyncPostgresSaver:
        """Return an AsyncPostgresSaver bound to the shared connection pool.

        A new saver object is constructed on every access (it is not cached).
        NOTE(review): presumably deliberate, since a single shared saver might
        serialize callers on internal state. Verify against langgraph before
        switching to a cached instance.

        Raises:
            RuntimeError: if the manager is closed or not yet initialized.
        """
        self._ensure_ready()
        return AsyncPostgresSaver(conn=self._pool)

    @property
    def pool(self) -> AsyncConnectionPool:
        """Return the connection pool (for sharing with other managers).

        Raises:
            RuntimeError: if the manager is closed or not yet initialized.
        """
        self._ensure_ready()
        return self._pool

    async def close(self) -> None:
        """Close the manager. The pool itself is owned by DBPoolManager and is left open."""
        if self._closed:
            return

        logger.info("Closing CheckpointerManager...")
        self._closed = True
        self._initialized = False
        # Drop our reference only; DBPoolManager is responsible for closing the pool.
        self._pool = None
        logger.info("CheckpointerManager closed (pool managed by DBPoolManager)")
|
||
|
||
|
||
# Process-wide singleton; created lazily on the first get_checkpointer_manager() call.
_global_manager: Optional[CheckpointerManager] = None
|
||
|
||
|
||
def get_checkpointer_manager() -> CheckpointerManager:
    """Return the process-wide CheckpointerManager, creating it on first use."""
    global _global_manager
    manager = _global_manager
    if manager is None:
        manager = CheckpointerManager()
        _global_manager = manager
    return manager
|
||
|
||
|
||
async def init_global_checkpointer(pool: AsyncConnectionPool) -> None:
    """Initialize the global CheckpointerManager with the given pool.

    Args:
        pool: AsyncConnectionPool instance (provided by DBPoolManager).
    """
    await get_checkpointer_manager().initialize(pool)
|
||
|
||
|
||
async def close_global_checkpointer() -> None:
    """Close the global CheckpointerManager and clear the singleton.

    Clearing the module-level reference means the next get_checkpointer_manager()
    call builds a fresh manager instead of handing back a closed one. Does
    nothing when no manager has been created.
    """
    global _global_manager
    manager = _global_manager
    if manager is None:
        return
    # Detach before awaiting so callers during the close get a fresh manager
    # rather than the one being shut down.
    _global_manager = None
    await manager.close()
|