""" 全局 SQLite Checkpointer 管理器 解决高并发场景下的数据库锁定问题 每个 session 使用独立的数据库文件,避免并发锁竞争 """ import asyncio import logging import os import time from typing import Optional, Dict, Any, Tuple import aiosqlite from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver from utils.settings import ( CHECKPOINT_DB_PATH, CHECKPOINT_WAL_MODE, CHECKPOINT_BUSY_TIMEOUT, CHECKPOINT_CLEANUP_ENABLED, CHECKPOINT_CLEANUP_INTERVAL_HOURS, CHECKPOINT_CLEANUP_OLDER_THAN_DAYS, ) logger = logging.getLogger('app') # 每个 session 的连接池大小(单个 session 串行处理,1 个连接即可) POOL_SIZE_PER_SESSION = 1 class CheckpointerManager: """ 全局 Checkpointer 管理器,按 session_id 分离数据库文件 主要功能: 1. 每个 session_id 独立的数据库文件和连接池 2. 按需创建连接池,不用的 session 不占用资源 3. 预配置 WAL 模式和 busy_timeout 4. 基于文件修改时间的简单清理机制 5. 优雅关闭机制 """ def __init__(self): # 每个 (bot_id, session_id) 一个连接池 self._pools: Dict[Tuple[str, str], asyncio.Queue[AsyncSqliteSaver]] = {} # 每个 session 的初始化锁 self._locks: Dict[Tuple[str, str], asyncio.Lock] = {} # 全局锁,用于保护 pools 和 locks 字典的访问 self._global_lock = asyncio.Lock() self._closed = False # 清理调度任务 self._cleanup_task: Optional[asyncio.Task] = None self._cleanup_stop_event = asyncio.Event() def _get_db_path(self, bot_id: str, session_id: str) -> str: """获取指定 session 的数据库文件路径""" return os.path.join(CHECKPOINT_DB_PATH, bot_id, session_id, "checkpoints.db") def _get_pool_key(self, bot_id: str, session_id: str) -> Tuple[str, str]: """获取连接池的键""" return (bot_id, session_id) async def _initialize_session_pool(self, bot_id: str, session_id: str) -> None: """初始化指定 session 的连接池""" pool_key = self._get_pool_key(bot_id, session_id) if pool_key in self._pools: return logger.info(f"Initializing checkpointer pool for bot_id={bot_id}, session_id={session_id}") db_path = self._get_db_path(bot_id, session_id) os.makedirs(os.path.dirname(db_path), exist_ok=True) pool = asyncio.Queue() for i in range(POOL_SIZE_PER_SESSION): try: conn = await self._create_configured_connection(db_path) checkpointer = AsyncSqliteSaver(conn=conn) # 预先调用 setup 确保表结构已创建 await checkpointer.setup() await pool.put(checkpointer) logger.debug(f"Created checkpointer connection {i+1}/{POOL_SIZE_PER_SESSION} for session={session_id}") except Exception as e: logger.error(f"Failed to create checkpointer connection {i+1} for session={session_id}: {e}") raise self._pools[pool_key] = pool self._locks[pool_key] = asyncio.Lock() logger.info(f"Checkpointer pool initialized for bot_id={bot_id}, session_id={session_id}") async def _create_configured_connection(self, db_path: str) -> aiosqlite.Connection: """ 创建已配置的 SQLite 连接 配置包括: 1. WAL 模式 (Write-Ahead Logging) - 允许读写并发 2. busy_timeout - 等待锁定的最长时间 3. 其他优化参数 """ conn = aiosqlite.connect(db_path) # 等待连接建立 await conn.__aenter__() # 设置 busy timeout(必须在连接建立后设置) await conn.execute(f"PRAGMA busy_timeout = {CHECKPOINT_BUSY_TIMEOUT}") # 如果启用 WAL 模式 if CHECKPOINT_WAL_MODE: await conn.execute("PRAGMA journal_mode = WAL") await conn.execute("PRAGMA synchronous = NORMAL") # WAL 模式下的优化配置 await conn.execute("PRAGMA wal_autocheckpoint = 10000") # 增加到 10000 await conn.execute("PRAGMA cache_size = -64000") # 64MB 缓存 await conn.execute("PRAGMA temp_store = MEMORY") await conn.execute("PRAGMA journal_size_limit = 1048576") # 1MB await conn.commit() return conn async def initialize(self) -> None: """初始化管理器(不再预创建连接池,改为按需创建)""" logger.info("CheckpointerManager initialized (pools will be created on-demand)") async def acquire_for_agent(self, bot_id: str, session_id: str) -> AsyncSqliteSaver: """ 获取指定 session 的 checkpointer 注意:此方法获取的 checkpointer 需要手动归还 使用 return_to_pool() 方法归还 Args: bot_id: 机器人 ID session_id: 会话 ID Returns: AsyncSqliteSaver 实例 """ if self._closed: raise RuntimeError("CheckpointerManager is closed") pool_key = self._get_pool_key(bot_id, session_id) async with self._global_lock: if pool_key not in self._pools: await self._initialize_session_pool(bot_id, session_id) # 获取该 session 的锁,确保连接池操作线程安全 async with self._locks[pool_key]: checkpointer = await self._pools[pool_key].get() logger.debug(f"Acquired checkpointer for bot_id={bot_id}, session_id={session_id}, remaining: {self._pools[pool_key].qsize()}") return checkpointer async def return_to_pool(self, bot_id: str, session_id: str, checkpointer: AsyncSqliteSaver) -> None: """ 归还 checkpointer 到对应 session 的池 Args: bot_id: 机器人 ID session_id: 会话 ID checkpointer: 要归还的 checkpointer 实例 """ pool_key = self._get_pool_key(bot_id, session_id) if pool_key in self._pools: async with self._locks[pool_key]: await self._pools[pool_key].put(checkpointer) logger.debug(f"Returned checkpointer for bot_id={bot_id}, session_id={session_id}, remaining: {self._pools[pool_key].qsize()}") async def _close_session_pool(self, bot_id: str, session_id: str) -> None: """关闭指定 session 的连接池""" pool_key = self._get_pool_key(bot_id, session_id) if pool_key not in self._pools: return logger.info(f"Closing checkpointer pool for bot_id={bot_id}, session_id={session_id}") pool = self._pools[pool_key] while not pool.empty(): try: checkpointer = pool.get_nowait() if checkpointer.conn: await checkpointer.conn.close() except asyncio.QueueEmpty: break del self._pools[pool_key] if pool_key in self._locks: del self._locks[pool_key] logger.info(f"Checkpointer pool closed for bot_id={bot_id}, session_id={session_id}") async def close(self) -> None: """关闭所有连接池""" if self._closed: return # 停止清理任务 if self._cleanup_task is not None: self._cleanup_stop_event.set() try: self._cleanup_task.cancel() await asyncio.sleep(0.1) # 给任务一点时间清理 except asyncio.CancelledError: pass self._cleanup_task = None async with self._global_lock: if self._closed: return logger.info("Closing CheckpointerManager...") # 关闭所有 session 的连接池 pool_keys = list(self._pools.keys()) for bot_id, session_id in pool_keys: await self._close_session_pool(bot_id, session_id) self._closed = True logger.info("CheckpointerManager closed") def get_pool_stats(self) -> dict: """获取连接池状态统计""" return { "session_count": len(self._pools), "pools": { f"{bot_id}/{session_id}": { "available": pool.qsize(), "pool_size": POOL_SIZE_PER_SESSION } for (bot_id, session_id), pool in self._pools.items() }, "closed": self._closed } # ============================================================ # Checkpoint 清理方法(基于文件修改时间) # ============================================================ async def cleanup_old_dbs(self, older_than_days: int = None) -> Dict[str, Any]: """ 根据数据库文件的修改时间清理旧数据库文件 Args: older_than_days: 清理多少天前的数据,默认使用配置值 Returns: Dict: 清理统计信息 - deleted: 删除的 session 目录数量 - scanned: 扫描的 session 目录数量 - cutoff_time: 截止时间戳 """ if older_than_days is None: older_than_days = CHECKPOINT_CLEANUP_OLDER_THAN_DAYS cutoff_time = time.time() - older_than_days * 86400 logger.info(f"Starting checkpoint cleanup: removing db files not modified since {cutoff_time}") db_dir = CHECKPOINT_DB_PATH deleted_count = 0 scanned_count = 0 if not os.path.exists(db_dir): logger.info(f"Checkpoint directory does not exist: {db_dir}") return {"deleted": 0, "scanned": 0, "cutoff_time": cutoff_time} # 遍历 bot_id 目录 for bot_id in os.listdir(db_dir): bot_path = os.path.join(db_dir, bot_id) # 跳过非目录文件 if not os.path.isdir(bot_path): continue # 遍历 session_id 目录 for session_id in os.listdir(bot_path): session_path = os.path.join(bot_path, session_id) if not os.path.isdir(session_path): continue db_file = os.path.join(session_path, "checkpoints.db") if not os.path.exists(db_file): continue scanned_count += 1 mtime = os.path.getmtime(db_file) if mtime < cutoff_time: # 关闭该 session 的连接池(如果有) await self._close_session_pool(bot_id, session_id) # 删除整个 session 目录 try: import shutil shutil.rmtree(session_path) deleted_count += 1 logger.info(f"Deleted old checkpoint session: {bot_id}/{session_id}/ (last modified: {mtime})") except Exception as e: logger.warning(f"Failed to delete {session_path}: {e}") # 清理空的 bot_id 目录 for bot_id in os.listdir(db_dir): bot_path = os.path.join(db_dir, bot_id) if os.path.isdir(bot_path) and not os.listdir(bot_path): try: os.rmdir(bot_path) logger.debug(f"Removed empty bot directory: {bot_id}/") except Exception: pass result = { "deleted": deleted_count, "scanned": scanned_count, "cutoff_time": cutoff_time, "older_than_days": older_than_days } logger.info(f"Checkpoint cleanup completed: {result}") return result async def _cleanup_loop(self): """后台清理循环""" interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600 logger.info( f"Checkpoint cleanup scheduler started: " f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, " f"older_than={CHECKPOINT_CLEANUP_OLDER_THAN_DAYS} days" ) while not self._cleanup_stop_event.is_set(): try: # 等待间隔时间或停止信号 try: await asyncio.wait_for( self._cleanup_stop_event.wait(), timeout=interval_seconds ) # 收到停止信号,退出循环 break except asyncio.TimeoutError: # 超时,执行清理任务 pass if self._cleanup_stop_event.is_set(): break # 执行清理 await self.cleanup_old_dbs() except asyncio.CancelledError: logger.info("Cleanup task cancelled") break except Exception as e: logger.error(f"Error in cleanup loop: {e}") logger.info("Checkpoint cleanup scheduler stopped") def start_cleanup_scheduler(self): """启动后台定时清理任务""" if not CHECKPOINT_CLEANUP_ENABLED: logger.info("Checkpoint cleanup is disabled") return if self._cleanup_task is not None and not self._cleanup_task.done(): logger.warning("Cleanup scheduler is already running") return # 创建新的事件循环(如果需要在后台运行) try: loop = asyncio.get_running_loop() except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) self._cleanup_task = loop.create_task(self._cleanup_loop()) logger.info("Cleanup scheduler task created") # 全局单例 _global_manager: Optional[CheckpointerManager] = None def get_checkpointer_manager() -> CheckpointerManager: """获取全局 CheckpointerManager 单例""" global _global_manager if _global_manager is None: _global_manager = CheckpointerManager() return _global_manager async def init_global_checkpointer() -> None: """初始化全局 checkpointer 管理器""" manager = get_checkpointer_manager() await manager.initialize() async def close_global_checkpointer() -> None: """关闭全局 checkpointer 管理器""" global _global_manager if _global_manager is not None: await _global_manager.close()