401 lines
14 KiB
Python
401 lines
14 KiB
Python
"""
|
||
全局 SQLite Checkpointer 管理器
|
||
解决高并发场景下的数据库锁定问题
|
||
|
||
每个 session 使用独立的数据库文件,避免并发锁竞争
|
||
"""
|
||
import asyncio
import logging
import os
import shutil
import time
from typing import Optional, Dict, Any, Tuple

import aiosqlite
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||
|
||
from utils.settings import (
|
||
CHECKPOINT_DB_PATH,
|
||
CHECKPOINT_WAL_MODE,
|
||
CHECKPOINT_BUSY_TIMEOUT,
|
||
CHECKPOINT_CLEANUP_ENABLED,
|
||
CHECKPOINT_CLEANUP_INTERVAL_HOURS,
|
||
CHECKPOINT_CLEANUP_OLDER_THAN_DAYS,
|
||
)
|
||
|
||
logger = logging.getLogger('app')

# Pool size per session (a single session is processed serially,
# so one connection is sufficient).
POOL_SIZE_PER_SESSION = 1
|
||
|
||
|
||
class CheckpointerManager:
    """
    Global checkpointer manager that keeps SQLite state separated per session.

    Main features:
    1. One independent database file and connection pool per (bot_id, session_id).
    2. Pools are created lazily on first use; idle sessions consume nothing.
    3. Connections are pre-configured with WAL mode and busy_timeout.
    4. Simple cleanup mechanism based on database file modification time.
    5. Graceful shutdown of pools and the cleanup scheduler.
    """

    def __init__(self):
        # One connection pool per (bot_id, session_id).
        self._pools: Dict[Tuple[str, str], asyncio.Queue] = {}
        # Guards creation and teardown of entries in self._pools.
        # NOTE: per-session locks were removed — asyncio.Queue is already
        # safe for single-loop concurrency, and awaiting Queue.get() while
        # holding a lock that return_to_pool() also needed caused a deadlock
        # whenever the pool was empty.
        self._global_lock = asyncio.Lock()
        self._closed = False
        # Background cleanup scheduling.
        self._cleanup_task: Optional[asyncio.Task] = None
        self._cleanup_stop_event = asyncio.Event()

    def _get_db_path(self, bot_id: str, session_id: str) -> str:
        """Return the database file path for the given session."""
        return os.path.join(CHECKPOINT_DB_PATH, bot_id, session_id, "checkpoints.db")

    def _get_pool_key(self, bot_id: str, session_id: str) -> Tuple[str, str]:
        """Return the dictionary key identifying a session's pool."""
        return (bot_id, session_id)

    async def _initialize_session_pool(self, bot_id: str, session_id: str) -> None:
        """Create and fill the connection pool for one session.

        Caller must hold self._global_lock.
        """
        pool_key = self._get_pool_key(bot_id, session_id)
        if pool_key in self._pools:
            return

        logger.info(f"Initializing checkpointer pool for bot_id={bot_id}, session_id={session_id}")

        db_path = self._get_db_path(bot_id, session_id)
        os.makedirs(os.path.dirname(db_path), exist_ok=True)

        pool: asyncio.Queue = asyncio.Queue()
        for i in range(POOL_SIZE_PER_SESSION):
            try:
                conn = await self._create_configured_connection(db_path)
                checkpointer = AsyncSqliteSaver(conn=conn)
                # Call setup() up front so the checkpoint tables exist.
                await checkpointer.setup()
                await pool.put(checkpointer)
                logger.debug(f"Created checkpointer connection {i+1}/{POOL_SIZE_PER_SESSION} for session={session_id}")
            except Exception as e:
                logger.error(f"Failed to create checkpointer connection {i+1} for session={session_id}: {e}")
                raise

        self._pools[pool_key] = pool
        logger.info(f"Checkpointer pool initialized for bot_id={bot_id}, session_id={session_id}")

    async def _create_configured_connection(self, db_path: str) -> aiosqlite.Connection:
        """
        Open a SQLite connection pre-configured for concurrent use.

        Configuration applied:
        1. busy_timeout - maximum time to wait on a locked database.
        2. WAL mode (Write-Ahead Logging, when enabled) - allows concurrent
           readers alongside a writer, plus cache/journal tuning for WAL.
        """
        # FIX: aiosqlite.connect() returns an awaitable Connection; awaiting
        # it establishes the connection. The previous code called
        # `conn.__aenter__()` manually, which worked but bypassed the
        # intended API and made the lifecycle harder to follow.
        conn = await aiosqlite.connect(db_path)

        # busy_timeout must be set after the connection is established.
        await conn.execute(f"PRAGMA busy_timeout = {CHECKPOINT_BUSY_TIMEOUT}")

        if CHECKPOINT_WAL_MODE:
            await conn.execute("PRAGMA journal_mode = WAL")
            await conn.execute("PRAGMA synchronous = NORMAL")
            # WAL-specific tuning.
            await conn.execute("PRAGMA wal_autocheckpoint = 10000")
            await conn.execute("PRAGMA cache_size = -64000")  # 64MB cache
            await conn.execute("PRAGMA temp_store = MEMORY")
            await conn.execute("PRAGMA journal_size_limit = 1048576")  # 1MB

        await conn.commit()

        return conn

    async def initialize(self) -> None:
        """Initialize the manager (pools are created on demand, not here)."""
        logger.info("CheckpointerManager initialized (pools will be created on-demand)")

    async def acquire_for_agent(self, bot_id: str, session_id: str) -> AsyncSqliteSaver:
        """
        Acquire a checkpointer for the given session.

        The returned checkpointer MUST be handed back via return_to_pool().

        Args:
            bot_id: Bot ID.
            session_id: Session ID.

        Returns:
            An AsyncSqliteSaver instance.

        Raises:
            RuntimeError: If the manager has already been closed.
        """
        if self._closed:
            raise RuntimeError("CheckpointerManager is closed")

        pool_key = self._get_pool_key(bot_id, session_id)
        async with self._global_lock:
            if pool_key not in self._pools:
                await self._initialize_session_pool(bot_id, session_id)
            pool = self._pools[pool_key]

        # FIX(deadlock): do NOT await Queue.get() while holding a lock.
        # Previously a per-session lock was held across get(); with an empty
        # pool the acquirer waited on get() holding the lock, while the task
        # trying to return the checkpointer blocked on that same lock in
        # return_to_pool() — neither could ever proceed. asyncio.Queue is
        # safe to use concurrently within one event loop, so no lock is needed.
        checkpointer = await pool.get()
        logger.debug(f"Acquired checkpointer for bot_id={bot_id}, session_id={session_id}, remaining: {pool.qsize()}")
        return checkpointer

    async def return_to_pool(self, bot_id: str, session_id: str, checkpointer: AsyncSqliteSaver) -> None:
        """
        Return a checkpointer to its session's pool.

        If the pool was closed while the checkpointer was checked out
        (e.g. by the cleanup task), the orphaned connection is closed
        instead of being silently leaked.

        Args:
            bot_id: Bot ID.
            session_id: Session ID.
            checkpointer: The checkpointer instance being returned.
        """
        pool_key = self._get_pool_key(bot_id, session_id)
        pool = self._pools.get(pool_key)
        if pool is not None:
            await pool.put(checkpointer)
            logger.debug(f"Returned checkpointer for bot_id={bot_id}, session_id={session_id}, remaining: {pool.qsize()}")
            return

        # FIX(leak): pool is gone — close the connection rather than drop it.
        try:
            if checkpointer.conn:
                await checkpointer.conn.close()
        except Exception as e:
            logger.warning(f"Failed to close orphaned checkpointer for bot_id={bot_id}, session_id={session_id}: {e}")

    async def _close_session_pool(self, bot_id: str, session_id: str) -> None:
        """Close and discard the connection pool for one session.

        A checkpointer that is currently checked out is not closed here;
        return_to_pool() closes it once it notices the pool entry is gone.
        """
        pool_key = self._get_pool_key(bot_id, session_id)
        if pool_key not in self._pools:
            return

        logger.info(f"Closing checkpointer pool for bot_id={bot_id}, session_id={session_id}")

        # Remove the entry first so no new acquisitions see a half-closed pool.
        pool = self._pools.pop(pool_key)
        while not pool.empty():
            try:
                checkpointer = pool.get_nowait()
                if checkpointer.conn:
                    await checkpointer.conn.close()
            except asyncio.QueueEmpty:
                break

        logger.info(f"Checkpointer pool closed for bot_id={bot_id}, session_id={session_id}")

    async def close(self) -> None:
        """Stop the cleanup scheduler and close all session pools."""
        if self._closed:
            return

        # Stop the background cleanup task and actually wait for it to finish.
        if self._cleanup_task is not None:
            self._cleanup_stop_event.set()
            self._cleanup_task.cancel()
            try:
                # FIX: the old code slept 0.1s and hoped the task was done;
                # awaiting the task guarantees it has terminated before we
                # start tearing down the pools it might be using.
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

        async with self._global_lock:
            # Double-check under the lock in case of concurrent close() calls.
            if self._closed:
                return

            logger.info("Closing CheckpointerManager...")

            # Close every session's connection pool.
            for bot_id, session_id in list(self._pools.keys()):
                await self._close_session_pool(bot_id, session_id)

            self._closed = True
            logger.info("CheckpointerManager closed")

    def get_pool_stats(self) -> dict:
        """Return a snapshot of pool availability for monitoring/debugging."""
        return {
            "session_count": len(self._pools),
            "pools": {
                f"{bot_id}/{session_id}": {
                    "available": pool.qsize(),
                    "pool_size": POOL_SIZE_PER_SESSION
                }
                for (bot_id, session_id), pool in self._pools.items()
            },
            "closed": self._closed
        }

    # ============================================================
    # Checkpoint cleanup (based on file modification time)
    # ============================================================

    async def cleanup_old_dbs(self, older_than_days: Optional[int] = None) -> Dict[str, Any]:
        """
        Delete session databases whose files have not been modified recently.

        Args:
            older_than_days: Remove sessions older than this many days;
                defaults to CHECKPOINT_CLEANUP_OLDER_THAN_DAYS.

        Returns:
            Dict with cleanup statistics:
                - deleted: number of session directories removed
                - scanned: number of session directories examined
                - cutoff_time: cutoff timestamp used
                - older_than_days: effective age threshold (when dir exists)
        """
        if older_than_days is None:
            older_than_days = CHECKPOINT_CLEANUP_OLDER_THAN_DAYS

        cutoff_time = time.time() - older_than_days * 86400
        logger.info(f"Starting checkpoint cleanup: removing db files not modified since {cutoff_time}")

        db_dir = CHECKPOINT_DB_PATH
        deleted_count = 0
        scanned_count = 0

        if not os.path.exists(db_dir):
            logger.info(f"Checkpoint directory does not exist: {db_dir}")
            return {"deleted": 0, "scanned": 0, "cutoff_time": cutoff_time}

        # Layout: <db_dir>/<bot_id>/<session_id>/checkpoints.db
        for bot_id in os.listdir(db_dir):
            bot_path = os.path.join(db_dir, bot_id)
            # Skip stray non-directory entries.
            if not os.path.isdir(bot_path):
                continue

            for session_id in os.listdir(bot_path):
                session_path = os.path.join(bot_path, session_id)
                if not os.path.isdir(session_path):
                    continue

                db_file = os.path.join(session_path, "checkpoints.db")
                if not os.path.exists(db_file):
                    continue

                scanned_count += 1
                mtime = os.path.getmtime(db_file)

                if mtime < cutoff_time:
                    # Close the session's pool (if any) before deleting files.
                    await self._close_session_pool(bot_id, session_id)

                    # Remove the whole session directory.
                    # FIX: `import shutil` was previously done inside this
                    # loop on every deletion; it is now a top-level import.
                    try:
                        shutil.rmtree(session_path)
                        deleted_count += 1
                        logger.info(f"Deleted old checkpoint session: {bot_id}/{session_id}/ (last modified: {mtime})")
                    except Exception as e:
                        logger.warning(f"Failed to delete {session_path}: {e}")

        # Remove bot directories that are now empty.
        for bot_id in os.listdir(db_dir):
            bot_path = os.path.join(db_dir, bot_id)
            if os.path.isdir(bot_path) and not os.listdir(bot_path):
                try:
                    os.rmdir(bot_path)
                    logger.debug(f"Removed empty bot directory: {bot_id}/")
                except Exception:
                    pass

        result = {
            "deleted": deleted_count,
            "scanned": scanned_count,
            "cutoff_time": cutoff_time,
            "older_than_days": older_than_days
        }

        logger.info(f"Checkpoint cleanup completed: {result}")
        return result

    async def _cleanup_loop(self):
        """Background loop: periodically run cleanup until stopped."""
        interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600

        logger.info(
            f"Checkpoint cleanup scheduler started: "
            f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, "
            f"older_than={CHECKPOINT_CLEANUP_OLDER_THAN_DAYS} days"
        )

        while not self._cleanup_stop_event.is_set():
            try:
                # Wait for either the interval to elapse or a stop signal.
                try:
                    await asyncio.wait_for(
                        self._cleanup_stop_event.wait(),
                        timeout=interval_seconds
                    )
                    # Stop signal received — exit the loop.
                    break
                except asyncio.TimeoutError:
                    # Interval elapsed — fall through to run cleanup.
                    pass

                if self._cleanup_stop_event.is_set():
                    break

                # Run the cleanup pass.
                await self.cleanup_old_dbs()

            except asyncio.CancelledError:
                logger.info("Cleanup task cancelled")
                break
            except Exception as e:
                logger.error(f"Error in cleanup loop: {e}")

        logger.info("Checkpoint cleanup scheduler stopped")

    def start_cleanup_scheduler(self):
        """Start the periodic background cleanup task.

        Must be called from within a running event loop; otherwise the
        scheduler cannot be started.
        """
        if not CHECKPOINT_CLEANUP_ENABLED:
            logger.info("Checkpoint cleanup is disabled")
            return

        if self._cleanup_task is not None and not self._cleanup_task.done():
            logger.warning("Cleanup scheduler is already running")
            return

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # FIX: the old fallback created a brand-new event loop and
            # scheduled the task on it, but nothing ever ran that loop, so
            # cleanup silently never executed. Fail loudly instead.
            logger.error("start_cleanup_scheduler() requires a running event loop; cleanup not started")
            return

        self._cleanup_task = loop.create_task(self._cleanup_loop())
        logger.info("Cleanup scheduler task created")
||
|
||
|
||
# Module-level singleton instance, created lazily by get_checkpointer_manager().
_global_manager: Optional[CheckpointerManager] = None
|
||
|
||
|
||
def get_checkpointer_manager() -> CheckpointerManager:
    """Return the process-wide CheckpointerManager, creating it on first use."""
    global _global_manager
    if _global_manager is None:
        _global_manager = CheckpointerManager()
    return _global_manager
|
||
|
||
|
||
async def init_global_checkpointer() -> None:
    """Initialize the global checkpointer manager singleton."""
    await get_checkpointer_manager().initialize()
|
||
|
||
|
||
async def close_global_checkpointer() -> None:
    """Close the global checkpointer manager, if one was ever created."""
    global _global_manager
    manager = _global_manager
    if manager is not None:
        await manager.close()