qwen_agent/agent/checkpoint_manager.py
2025-12-24 20:43:10 +08:00

277 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
全局 PostgreSQL Checkpointer 管理器
使用 psycopg_pool 连接池AsyncPostgresSaver 原生支持
"""
import asyncio
import logging
from typing import Optional
from psycopg_pool import AsyncConnectionPool
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from utils.settings import (
CHECKPOINT_DB_URL,
CHECKPOINT_POOL_SIZE,
CHECKPOINT_CLEANUP_ENABLED,
CHECKPOINT_CLEANUP_INTERVAL_HOURS,
CHECKPOINT_CLEANUP_INACTIVE_DAYS,
)
# Module-level logger; bound to the shared 'app' logger name rather than __name__.
logger = logging.getLogger('app')
class CheckpointerManager:
    """Manager for a PostgreSQL-backed checkpointer built on a connection pool.

    Main features:
    1. Connections are managed by ``psycopg_pool.AsyncConnectionPool``.
    2. ``AsyncPostgresSaver`` is handed the pool itself, so it acquires and
       releases connections per operation instead of pinning one per request.
    3. SQL-based cleanup of checkpoints that belong to inactive threads.
    4. Orderly shutdown of the background cleanup task and of the pool.
    """

    def __init__(self):
        self._pool: Optional[AsyncConnectionPool] = None
        self._initialized = False
        self._closed = False
        # Background cleanup scheduler state.
        self._cleanup_task: Optional[asyncio.Task] = None
        self._cleanup_stop_event = asyncio.Event()

    @staticmethod
    def _redact_conninfo(conninfo: str) -> str:
        """Return *conninfo* with any password replaced by ``***`` for logging."""
        import re

        text = str(conninfo)
        # URI form: scheme://user:password@host/...
        text = re.sub(r"(://[^:/?#@\s]*):[^@/\s]*@", r"\1:***@", text)
        # Keyword form (and URI query parameters): password=secret
        text = re.sub(
            r"(password\s*=\s*)('(?:[^'\\]|\\.)*'|[^\s&]+)",
            r"\1***",
            text,
            flags=re.IGNORECASE,
        )
        return text

    async def initialize(self) -> None:
        """Open the connection pool and create the checkpoint tables.

        Does nothing when the manager is already initialized. Calling it
        again after ``close()`` re-opens the manager with a fresh pool.

        Raises:
            Exception: whatever the pool or ``AsyncPostgresSaver.setup()``
                raised; the partially opened pool is closed first.
        """
        if self._initialized:
            return

        # Never write the database password to the log.
        logger.info(
            f"Initializing PostgreSQL checkpointer pool: "
            f"URL={self._redact_conninfo(CHECKPOINT_DB_URL)}, size={CHECKPOINT_POOL_SIZE}"
        )

        pool = None
        try:
            # Create the psycopg pool without opening it in the constructor.
            pool = AsyncConnectionPool(
                CHECKPOINT_DB_URL,
                min_size=1,
                max_size=CHECKPOINT_POOL_SIZE,
                open=False,
            )
            await pool.open()

            # Create the schema. Autocommit is required because setup() runs
            # CREATE INDEX CONCURRENTLY, which cannot run inside a transaction.
            async with pool.connection() as conn:
                await conn.set_autocommit(True)
                await AsyncPostgresSaver(conn=conn).setup()
        except Exception as e:
            logger.error(f"Failed to initialize PostgreSQL checkpointer pool: {e}")
            # Do not leak a half-initialized pool.
            if pool is not None:
                try:
                    await pool.close()
                except Exception as close_error:
                    logger.warning(
                        f"Failed to close checkpointer pool after init error: {close_error}"
                    )
            raise

        self._pool = pool
        self._initialized = True
        if self._closed:
            # Re-opened after close(): make the manager usable again.
            self._closed = False
            self._cleanup_stop_event.clear()
        logger.info("PostgreSQL checkpointer pool initialized successfully")

    @property
    def checkpointer(self) -> "AsyncPostgresSaver":
        """Return an ``AsyncPostgresSaver`` bound to the shared pool.

        Raises:
            RuntimeError: if the manager is closed or not yet initialized.
        """
        if self._closed:
            raise RuntimeError("CheckpointerManager is closed")
        if not self._initialized:
            raise RuntimeError("CheckpointerManager not initialized, call initialize() first")
        return AsyncPostgresSaver(conn=self._pool)

    async def _stop_cleanup_task(self) -> None:
        """Signal the cleanup task to stop and wait until it has finished."""
        task = self._cleanup_task
        if task is None:
            return
        self._cleanup_task = None
        self._cleanup_stop_event.set()
        task.cancel()
        try:
            # Wait for the task so it releases its connection before the
            # pool is closed.
            await task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.warning(f"Cleanup task ended with an error during shutdown: {e}")

    async def close(self) -> None:
        """Stop the cleanup task and close the connection pool (idempotent)."""
        if self._closed:
            return

        await self._stop_cleanup_task()

        logger.info("Closing CheckpointerManager...")
        if self._pool is not None:
            await self._pool.close()
            self._pool = None

        self._closed = True
        self._initialized = False
        logger.info("CheckpointerManager closed")

    # ============================================================
    # Checkpoint cleanup (based on PostgreSQL SQL queries)
    # ============================================================

    async def cleanup_old_dbs(self, inactive_days: Optional[int] = None) -> dict:
        """Delete all checkpoint data of threads inactive for N days.

        A thread counts as inactive when its most recent checkpoint is older
        than ``inactive_days`` days.

        Args:
            inactive_days: inactivity threshold in days; defaults to
                ``CHECKPOINT_CLEANUP_INACTIVE_DAYS``.

        Returns:
            dict: cleanup statistics. ``deleted`` counts checkpoint and write
            rows, ``threads_deleted`` the affected threads; ``blobs_deleted``
            is present after a successful run and ``error`` after a failed one.
        """
        if inactive_days is None:
            inactive_days = CHECKPOINT_CLEANUP_INACTIVE_DAYS

        logger.info(f"Starting checkpoint cleanup: removing threads inactive for {inactive_days} days")

        if self._pool is None:
            logger.warning("Connection pool not initialized, skipping cleanup")
            return {"deleted": 0, "threads_deleted": 0, "inactive_days": inactive_days}

        # Select on the newest checkpoint per thread so that long-lived but
        # still active threads are kept. The placeholder must stay outside
        # the SQL string literal for server-side parameter binding.
        query_find_inactive = """
            SELECT thread_id
            FROM checkpoints
            GROUP BY thread_id
            HAVING MAX((checkpoint->>'ts')::timestamptz) < NOW() - %s * INTERVAL '1 day'
        """

        try:
            async with self._pool.connection() as conn:
                # An explicit transaction keeps the deletes atomic even if the
                # pooled connection was left in autocommit mode.
                async with conn.transaction():
                    async with conn.cursor() as cursor:
                        await cursor.execute(query_find_inactive, (inactive_days,))
                        inactive_thread_ids = [row[0] for row in await cursor.fetchall()]

                        if not inactive_thread_ids:
                            logger.info("No inactive threads found")
                            return {"deleted": 0, "threads_deleted": 0, "inactive_days": inactive_days}

                        # A single array parameter avoids building one
                        # placeholder per thread id.
                        params = (inactive_thread_ids,)

                        await cursor.execute(
                            "DELETE FROM checkpoints WHERE thread_id = ANY(%s)", params
                        )
                        deleted_checkpoints = cursor.rowcount

                        await cursor.execute(
                            "DELETE FROM checkpoint_writes WHERE thread_id = ANY(%s)", params
                        )
                        deleted_writes = cursor.rowcount

                        # Channel blobs are stored separately; remove them too
                        # so they are not left orphaned.
                        await cursor.execute(
                            "DELETE FROM checkpoint_blobs WHERE thread_id = ANY(%s)", params
                        )
                        deleted_blobs = cursor.rowcount

            logger.info(
                f"Checkpoint cleanup completed: "
                f"deleted {deleted_checkpoints} checkpoints, {deleted_writes} writes, "
                f"{deleted_blobs} blobs "
                f"from {len(inactive_thread_ids)} inactive threads"
            )
            return {
                "deleted": deleted_checkpoints + deleted_writes,
                "blobs_deleted": deleted_blobs,
                "threads_deleted": len(inactive_thread_ids),
                "inactive_days": inactive_days,
            }
        except Exception as e:
            # Best effort: report the failure instead of raising to the scheduler.
            logger.exception(f"Error during checkpoint cleanup: {e}")
            return {
                "deleted": 0,
                "threads_deleted": 0,
                "error": str(e),
                "inactive_days": inactive_days,
            }

    async def _cleanup_loop(self):
        """Run ``cleanup_old_dbs()`` periodically until stopped or cancelled."""
        interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600
        logger.info(
            f"Checkpoint cleanup scheduler started: "
            f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, "
            f"inactive_days={CHECKPOINT_CLEANUP_INACTIVE_DAYS}"
        )

        while not self._cleanup_stop_event.is_set():
            try:
                # Sleep for one interval, waking up early on the stop signal.
                try:
                    await asyncio.wait_for(
                        self._cleanup_stop_event.wait(),
                        timeout=interval_seconds,
                    )
                    # Stop signal received: leave the loop.
                    break
                except asyncio.TimeoutError:
                    # Interval elapsed: time to run a cleanup.
                    pass

                if self._cleanup_stop_event.is_set():
                    break

                await self.cleanup_old_dbs()
            except asyncio.CancelledError:
                logger.info("Cleanup task cancelled")
                break
            except Exception as e:
                logger.error(f"Error in cleanup loop: {e}")

        logger.info("Checkpoint cleanup scheduler stopped")

    def start_cleanup_scheduler(self):
        """Start the background cleanup task unless disabled or already running."""
        if not CHECKPOINT_CLEANUP_ENABLED:
            logger.info("Checkpoint cleanup is disabled")
            return

        if self._cleanup_task is not None and not self._cleanup_task.done():
            logger.warning("Cleanup scheduler is already running")
            return

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running: bind the task to a new current loop. It only
            # makes progress once the caller actually runs that loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            logger.warning(
                "No running event loop; cleanup task will start once the event loop runs"
            )

        self._cleanup_task = loop.create_task(self._cleanup_loop())
        logger.info("Cleanup scheduler task created")
# Global singleton; created lazily by get_checkpointer_manager().
_global_manager: Optional[CheckpointerManager] = None
def get_checkpointer_manager() -> CheckpointerManager:
    """Return the global CheckpointerManager singleton, creating it on first use."""
    global _global_manager
    manager = _global_manager
    if manager is None:
        manager = CheckpointerManager()
        _global_manager = manager
    return manager
async def init_global_checkpointer() -> None:
    """Initialize the global checkpointer manager."""
    await get_checkpointer_manager().initialize()
async def close_global_checkpointer() -> None:
    """Close the global checkpointer manager and drop the singleton.

    Does nothing when no manager has been created. The singleton is reset
    only after ``close()`` succeeds, so a later
    ``get_checkpointer_manager()`` call builds a fresh manager instead of
    returning one that is permanently closed.
    """
    global _global_manager
    manager = _global_manager
    if manager is None:
        return
    await manager.close()
    _global_manager = None