qwen_agent/agent/checkpoint_manager.py
2026-01-19 23:39:04 +08:00

115 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
全局 PostgreSQL Checkpointer 管理器
使用共享的数据库连接池
"""
import logging
from typing import Optional
from psycopg_pool import AsyncConnectionPool
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
# Module-level logger. It uses the fixed logger name 'app' rather than this
# module's own name.
logger = logging.getLogger('app')
class CheckpointerManager:
    """Global checkpointer manager built on a shared PostgreSQL connection pool.

    Key points:
    1. Uses the shared connection pool owned by DBPoolManager.
    2. AsyncPostgresSaver natively supports connection pools and acquires /
       releases connections automatically.
    3. Connections never have to be returned by hand, so long-running
       requests do not hold on to a connection.

    Lifecycle: ``initialize()`` -> use ``checkpointer`` / ``pool`` ->
    ``close()``. A closed manager can be brought back with another successful
    ``initialize()`` call.
    """

    def __init__(self):
        # Shared pool handed in through initialize(); None until then.
        self._pool: Optional[AsyncConnectionPool] = None
        # True once the checkpointer setup has completed successfully.
        self._initialized = False
        # True after close(); cleared again by a successful initialize().
        self._closed = False

    async def initialize(self, pool: AsyncConnectionPool) -> None:
        """Initialize the checkpointer on top of an externally supplied pool.

        Does nothing when the manager is already initialized. Otherwise
        ``AsyncPostgresSaver.setup()`` is run once on a connection taken from
        ``pool``. On success the pool is stored, the manager is marked as
        initialized and any earlier closed state is cleared. On failure the
        manager's state is left as it was.

        Args:
            pool: AsyncConnectionPool instance (from DBPoolManager).

        Raises:
            Exception: Any error raised while acquiring the connection or
                running the setup; it is logged and re-raised unchanged.
        """
        if self._initialized:
            return

        logger.info("Initializing PostgreSQL checkpointer (using shared connection pool)...")
        try:
            # Creating the tables needs autocommit mode so that
            # CREATE INDEX CONCURRENTLY can be executed.
            # NOTE(review): autocommit is not switched back before the
            # connection is handed back to the pool - confirm that this is
            # fine for the other users of the shared pool.
            async with pool.connection() as conn:
                await conn.set_autocommit(True)
                await AsyncPostgresSaver(conn=conn).setup()
        except Exception as e:
            logger.error(f"Failed to initialize PostgreSQL checkpointer: {e}")
            raise
        else:
            self._pool = pool
            # Without this reset a manager that went through close() would
            # keep raising "is closed" even after a successful re-init.
            self._closed = False
            self._initialized = True
            logger.info("PostgreSQL checkpointer initialized successfully")

    def _ensure_usable(self) -> None:
        """Raise RuntimeError unless the manager is initialized and not closed."""
        if self._closed:
            raise RuntimeError("CheckpointerManager is closed")
        if not self._initialized:
            raise RuntimeError("CheckpointerManager not initialized, call initialize() first")

    @property
    def checkpointer(self) -> AsyncPostgresSaver:
        """Return an AsyncPostgresSaver bound to the shared pool.

        A new saver object is constructed on every access.

        Raises:
            RuntimeError: If the manager is closed or not yet initialized.
        """
        self._ensure_usable()
        return AsyncPostgresSaver(conn=self._pool)

    @property
    def pool(self) -> AsyncConnectionPool:
        """Return the connection pool (so it can be shared with other managers).

        Raises:
            RuntimeError: If the manager is closed or not yet initialized.
        """
        self._ensure_usable()
        return self._pool

    async def close(self) -> None:
        """Mark the manager as closed.

        The connection pool is managed by DBPoolManager and is therefore not
        closed here. Calling this on an already closed manager is a no-op.
        """
        if self._closed:
            return

        logger.info("Closing CheckpointerManager...")
        self._closed = True
        self._initialized = False
        logger.info("CheckpointerManager closed (pool managed by DBPoolManager)")
# Global singleton. Stays None until get_checkpointer_manager() creates the
# instance on first use.
_global_manager: Optional[CheckpointerManager] = None
def get_checkpointer_manager() -> CheckpointerManager:
    """Return the global CheckpointerManager singleton, creating it on first use."""
    global _global_manager
    if _global_manager is not None:
        return _global_manager

    # First call: build the instance and remember it for later callers.
    _global_manager = CheckpointerManager()
    return _global_manager
async def init_global_checkpointer(pool: AsyncConnectionPool) -> None:
    """Initialize the global checkpointer manager.

    Args:
        pool: AsyncConnectionPool instance (from DBPoolManager).
    """
    await get_checkpointer_manager().initialize(pool)
async def close_global_checkpointer() -> None:
    """Close the global checkpointer manager, if one has been created."""
    # The module-level name is only read here, so no ``global`` statement is needed.
    manager = _global_manager
    if manager is None:
        return
    await manager.close()