"""Global SQLite checkpointer manager.

Keeps a fixed-size pool of pre-configured ``AsyncSqliteSaver`` instances so
that requests reuse connections instead of opening a new one each time. This
is intended to reduce "database is locked" errors under high concurrency.
"""

import asyncio
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import Optional, List, Dict, Any

import aiosqlite
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

from utils.settings import (
    CHECKPOINT_DB_PATH,
    CHECKPOINT_WAL_MODE,
    CHECKPOINT_BUSY_TIMEOUT,
    CHECKPOINT_POOL_SIZE,
    CHECKPOINT_CLEANUP_ENABLED,
    CHECKPOINT_CLEANUP_INTERVAL_HOURS,
    CHECKPOINT_CLEANUP_OLDER_THAN_DAYS,
)

logger = logging.getLogger('app')


class CheckpointerManager:
    """Pool of reusable ``AsyncSqliteSaver`` checkpointers.

    Responsibilities:
        1. Create ``CHECKPOINT_POOL_SIZE`` connections once and reuse them.
        2. Apply ``busy_timeout`` (and optionally WAL mode) to every connection.
        3. Hand checkpointers out / take them back through an ``asyncio.Queue``.
        4. Close pooled connections and stop the cleanup task on shutdown.
        5. Optionally run a periodic task that deletes inactive threads.
    """

    def __init__(self):
        # Idle checkpointers; ``acquire_for_agent`` blocks while it is empty.
        self._pool: asyncio.Queue[AsyncSqliteSaver] = asyncio.Queue()
        # Serialises initialize() and close().
        self._lock = asyncio.Lock()
        self._initialized = False
        self._closed = False
        self._pool_size = CHECKPOINT_POOL_SIZE
        self._db_path = CHECKPOINT_DB_PATH

        # Background cleanup scheduling.
        self._cleanup_task: Optional[asyncio.Task] = None
        self._cleanup_stop_event = asyncio.Event()

    async def initialize(self) -> None:
        """Create the connection pool (no-op when already initialized).

        Raises:
            Exception: whatever connection creation or ``setup()`` raised.
                Connections created before the failure are closed first so a
                failed initialization does not leave a half-filled pool.
        """
        if self._initialized:
            return

        async with self._lock:
            # Re-check under the lock: another task may have finished first.
            if self._initialized:
                return

            logger.info(f"Initializing CheckpointerManager with pool_size={self._pool_size}")

            # Make sure the parent directory exists. dirname() is "" for a
            # bare filename, and os.makedirs("") raises, so skip that case.
            db_dir = os.path.dirname(self._db_path)
            if db_dir:
                os.makedirs(db_dir, exist_ok=True)

            for i in range(self._pool_size):
                conn = None
                try:
                    conn = await self._create_configured_connection()
                    checkpointer = AsyncSqliteSaver(conn=conn)
                    # Run setup() up front so the tables exist before first use.
                    await checkpointer.setup()
                    await self._pool.put(checkpointer)
                    conn = None  # now owned by the pool
                    logger.debug(f"Created checkpointer connection {i+1}/{self._pool_size}")
                except Exception as e:
                    logger.error(f"Failed to create checkpointer connection {i+1}: {e}")
                    # Release the connection that never made it into the pool
                    # and everything pooled so far, then propagate.
                    if conn is not None:
                        await self._close_connection(conn)
                    await self._drain_pool()
                    raise

            self._initialized = True
            # Allow close() to work again after a close() -> initialize() cycle.
            self._closed = False
            logger.info("CheckpointerManager initialized successfully")

    async def _create_configured_connection(self) -> aiosqlite.Connection:
        """Open a SQLite connection and apply the configured PRAGMAs.

        Configuration:
            1. ``busy_timeout`` - how long SQLite waits on a locked database.
            2. WAL journal mode (when ``CHECKPOINT_WAL_MODE`` is truthy) plus
               ``synchronous``, ``wal_autocheckpoint``, ``cache_size`` and
               ``temp_store`` tuning.

        Returns:
            An open, configured ``aiosqlite.Connection``.

        Raises:
            Exception: any error from connecting or executing a PRAGMA; the
                connection is closed before the error propagates.
        """
        # Awaiting the connect() result opens the connection.
        conn = await aiosqlite.connect(self._db_path)
        try:
            # busy_timeout can only be set once the connection is open.
            await conn.execute(f"PRAGMA busy_timeout = {CHECKPOINT_BUSY_TIMEOUT}")

            if CHECKPOINT_WAL_MODE:
                await conn.execute("PRAGMA journal_mode = WAL")
                await conn.execute("PRAGMA synchronous = NORMAL")
                # Tuning applied together with WAL mode.
                await conn.execute("PRAGMA wal_autocheckpoint = 1000")
                await conn.execute("PRAGMA cache_size = -64000")  # 64MB cache
                await conn.execute("PRAGMA temp_store = MEMORY")

            await conn.commit()
        except BaseException:
            # Do not leak the connection if configuration fails or the task
            # is cancelled midway; the original error is re-raised.
            await conn.close()
            raise

        return conn

    @staticmethod
    async def _close_connection(conn) -> None:
        """Close ``conn`` best-effort, logging instead of raising."""
        try:
            await conn.close()
        except Exception as e:
            logger.warning(f"Error closing checkpointer connection: {e}")

    async def _drain_pool(self) -> None:
        """Remove every idle checkpointer from the pool and close its connection."""
        while True:
            try:
                checkpointer = self._pool.get_nowait()
            except asyncio.QueueEmpty:
                break
            if checkpointer.conn:
                await self._close_connection(checkpointer.conn)

    async def acquire_for_agent(self) -> AsyncSqliteSaver:
        """Take a checkpointer from the pool, waiting until one is free.

        The caller owns the checkpointer until it hands it back with
        ``return_to_pool()``; it is not returned automatically.

        Returns:
            An ``AsyncSqliteSaver`` instance.

        Raises:
            RuntimeError: if ``initialize()`` has not completed.
        """
        if not self._initialized:
            raise RuntimeError("CheckpointerManager not initialized. Call initialize() first.")

        checkpointer = await self._pool.get()
        logger.debug(f"Acquired checkpointer from pool, remaining: {self._pool.qsize()}")
        return checkpointer

    async def return_to_pool(self, checkpointer: AsyncSqliteSaver) -> None:
        """Give a checkpointer back to the pool.

        If the manager has already been closed, the connection is closed
        instead of being queued, because nothing would ever close it later.

        Args:
            checkpointer: the instance obtained from ``acquire_for_agent()``.
        """
        if self._closed:
            if checkpointer.conn:
                await self._close_connection(checkpointer.conn)
            return

        await self._pool.put(checkpointer)
        logger.debug(f"Returned checkpointer to pool, remaining: {self._pool.qsize()}")

    async def close(self) -> None:
        """Stop the cleanup task and close every pooled connection.

        Checkpointers that are currently borrowed are not in the pool; they
        are closed when they are passed to ``return_to_pool()`` afterwards.
        """
        if self._closed:
            return

        # Stop the cleanup task and wait for it to finish, so any
        # checkpointer it holds is back in the pool before the pool is drained.
        task = self._cleanup_task
        if task is not None:
            self._cleanup_task = None
            self._cleanup_stop_event.set()
            task.cancel()
            await self._wait_for_task(task)

        async with self._lock:
            if self._closed:
                return

            logger.info("Closing CheckpointerManager...")

            await self._drain_pool()

            self._closed = True
            self._initialized = False
            logger.info("CheckpointerManager closed")

    @staticmethod
    async def _wait_for_task(task: asyncio.Task) -> None:
        """Wait until ``task`` has finished without re-raising its outcome."""
        if task.done():
            return
        # A task created on another event loop (see start_cleanup_scheduler)
        # cannot be awaited from this one.
        if task.get_loop() is not asyncio.get_running_loop():
            return
        # asyncio.wait() does not raise the task's own exception or
        # cancellation, so a CancelledError seen here can only be the
        # cancellation of close() itself, which must propagate.
        await asyncio.wait({task})
        if not task.cancelled() and task.exception() is not None:
            logger.warning(f"Cleanup task ended with error: {task.exception()}")

    def get_pool_stats(self) -> dict:
        """Return a snapshot of the pool state (path, sizes, lifecycle flags)."""
        return {
            "db_path": self._db_path,
            "pool_size": self._pool_size,
            "available_connections": self._pool.qsize(),
            "initialized": self._initialized,
            "closed": self._closed
        }

    # ============================================================
    # Checkpoint cleanup
    # ============================================================

    async def get_all_thread_ids(self) -> List[str]:
        """Return every distinct ``thread_id`` stored in ``checkpoints``.

        Uses a short-lived dedicated connection rather than a pooled one.

        Returns:
            List[str]: thread ids; empty when the manager is not initialized.
        """
        if not self._initialized:
            return []

        # ``async with`` opens the connection and closes it on exit.
        async with aiosqlite.connect(self._db_path) as conn:
            cursor = await conn.execute(
                "SELECT DISTINCT thread_id FROM checkpoints"
            )
            rows = await cursor.fetchall()
        return [row[0] for row in rows]

    async def get_thread_last_activity(self, thread_id: str) -> Optional[datetime]:
        """Return the last-activity time of a thread.

        The time is the ``ts`` field of the first checkpoint that
        ``alist(limit=1)`` yields for the thread.

        Args:
            thread_id: id of the thread to inspect.

        Returns:
            datetime: the parsed timestamp, or ``None`` when the manager is not
            initialized, the thread has no checkpoint/``ts``, or the lookup
            failed (failures are logged as warnings).
        """
        if not self._initialized:
            return None

        checkpointer = await self.acquire_for_agent()
        try:
            return await self._read_last_activity(checkpointer, thread_id)
        finally:
            await self.return_to_pool(checkpointer)

    async def _read_last_activity(
        self, checkpointer: AsyncSqliteSaver, thread_id: str
    ) -> Optional[datetime]:
        """Read a thread's last ``ts`` using an already-acquired checkpointer."""
        try:
            config = {"configurable": {"thread_id": thread_id}}
            last_checkpoint = None
            # Run the iterator to completion instead of ``break``-ing out:
            # with limit=1 that costs nothing extra, and it lets the iterator
            # finish and release whatever it holds before the checkpointer is
            # reused.
            async for item in checkpointer.alist(config=config, limit=1):
                if last_checkpoint is None:
                    last_checkpoint = item

            if last_checkpoint and last_checkpoint.checkpoint:
                ts_str = last_checkpoint.checkpoint.get("ts")
                if ts_str:
                    # Parse the ISO timestamp ("Z" suffix -> explicit offset).
                    parsed = datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
                    if parsed.tzinfo is None:
                        # Callers compare against an aware UTC cutoff; a naive
                        # value would raise TypeError there.
                        # NOTE(review): assumes offset-less timestamps are UTC.
                        parsed = parsed.replace(tzinfo=timezone.utc)
                    return parsed
        except Exception as e:
            logger.warning(f"Error getting last activity for thread {thread_id}: {e}")

        return None

    async def cleanup_old_threads(self, older_than_days: Optional[int] = None) -> Dict[str, Any]:
        """Delete threads with no activity for more than ``older_than_days`` days.

        Args:
            older_than_days: age threshold in days; defaults to
                ``CHECKPOINT_CLEANUP_OLDER_THAN_DAYS``.

        Returns:
            Dict: cleanup statistics
                - threads_deleted: number of threads deleted
                - threads_scanned: number of threads examined
                - cutoff_time: ISO timestamp of the cutoff
                - older_than_days: threshold that was applied

        Raises:
            RuntimeError: if the manager is not initialized.
        """
        if older_than_days is None:
            older_than_days = CHECKPOINT_CLEANUP_OLDER_THAN_DAYS

        # Timezone-aware so the comparison with parsed timestamps is valid.
        cutoff_time = datetime.now(timezone.utc) - timedelta(days=older_than_days)
        logger.info(f"Starting checkpoint cleanup: removing threads inactive since {cutoff_time.isoformat()}")

        all_thread_ids = await self.get_all_thread_ids()
        threads_deleted = 0
        threads_scanned = len(all_thread_ids)

        checkpointer = await self.acquire_for_agent()
        try:
            for thread_id in all_thread_ids:
                try:
                    # Reuse the checkpointer already held. Acquiring a second
                    # one per thread would block forever when the pool has a
                    # single connection.
                    last_activity = await self._read_last_activity(checkpointer, thread_id)

                    if last_activity and last_activity < cutoff_time:
                        # adelete_thread() takes the thread id itself, not a
                        # RunnableConfig dict.
                        await checkpointer.adelete_thread(thread_id)
                        threads_deleted += 1
                        logger.debug(f"Deleted old thread: {thread_id} (last activity: {last_activity.isoformat()})")

                except Exception as e:
                    # One bad thread must not abort the whole sweep.
                    logger.warning(f"Error processing thread {thread_id}: {e}")

        finally:
            await self.return_to_pool(checkpointer)

        result = {
            "threads_deleted": threads_deleted,
            "threads_scanned": threads_scanned,
            "cutoff_time": cutoff_time.isoformat(),
            "older_than_days": older_than_days
        }

        logger.info(f"Checkpoint cleanup completed: {result}")
        return result

    async def _cleanup_loop(self):
        """Background loop: run a cleanup every configured interval until stopped."""
        interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600

        logger.info(
            f"Checkpoint cleanup scheduler started: "
            f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, "
            f"older_than={CHECKPOINT_CLEANUP_OLDER_THAN_DAYS} days"
        )

        while not self._cleanup_stop_event.is_set():
            try:
                # Sleep for one interval, waking early on the stop signal.
                try:
                    await asyncio.wait_for(
                        self._cleanup_stop_event.wait(),
                        timeout=interval_seconds
                    )
                    # Stop signal received: leave the loop.
                    break
                except asyncio.TimeoutError:
                    # Interval elapsed: time to run a cleanup.
                    pass

                if self._cleanup_stop_event.is_set():
                    break

                await self.cleanup_old_threads()

            except asyncio.CancelledError:
                logger.info("Cleanup task cancelled")
                break
            except Exception as e:
                # Keep the scheduler alive; the next interval retries.
                logger.error(f"Error in cleanup loop: {e}")

        logger.info("Checkpoint cleanup scheduler stopped")

    def start_cleanup_scheduler(self):
        """Start the periodic cleanup task (no-op when disabled or already running)."""
        if not CHECKPOINT_CLEANUP_ENABLED:
            logger.info("Checkpoint cleanup is disabled")
            return

        if self._cleanup_task is not None and not self._cleanup_task.done():
            logger.warning("Cleanup scheduler is already running")
            return

        # Use the running loop when there is one; otherwise create a loop and
        # install it as the current one. In that case the task only runs once
        # that loop is actually run by the caller.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        # close() leaves the stop event set; clear it so a restarted
        # scheduler does not exit immediately.
        self._cleanup_stop_event.clear()
        self._cleanup_task = loop.create_task(self._cleanup_loop())
        logger.info("Cleanup scheduler task created")


# Process-wide singleton, created lazily by get_checkpointer_manager().
_global_manager: Optional[CheckpointerManager] = None


def get_checkpointer_manager() -> CheckpointerManager:
    """Return the global ``CheckpointerManager``, creating it on first use."""
    global _global_manager
    if _global_manager is None:
        _global_manager = CheckpointerManager()
    return _global_manager


async def init_global_checkpointer() -> None:
    """Initialize the global checkpointer manager's connection pool."""
    manager = get_checkpointer_manager()
    await manager.initialize()


async def close_global_checkpointer() -> None:
    """Close the global checkpointer manager if one has been created."""
    if _global_manager is not None:
        await _global_manager.close()