379 lines
12 KiB
Python
379 lines
12 KiB
Python
"""
|
||
全局 SQLite Checkpointer 管理器
|
||
解决高并发场景下的数据库锁定问题
|
||
"""
|
||
import asyncio
|
||
import logging
|
||
import os
|
||
from datetime import datetime, timedelta, timezone
|
||
from typing import Optional, List, Dict, Any
|
||
|
||
import aiosqlite
|
||
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||
|
||
from utils.settings import (
|
||
CHECKPOINT_DB_PATH,
|
||
CHECKPOINT_WAL_MODE,
|
||
CHECKPOINT_BUSY_TIMEOUT,
|
||
CHECKPOINT_POOL_SIZE,
|
||
CHECKPOINT_CLEANUP_ENABLED,
|
||
CHECKPOINT_CLEANUP_INTERVAL_HOURS,
|
||
CHECKPOINT_CLEANUP_OLDER_THAN_DAYS,
|
||
)
|
||
|
||
# Module logger; it uses the fixed logger name 'app' rather than this module's __name__.
logger = logging.getLogger('app')
|
||
|
||
|
||
class CheckpointerManager:
    """Manager for a pool of SQLite-backed LangGraph checkpointers.

    A fixed number of ``aiosqlite`` connections is opened once, each wrapped
    in an ``AsyncSqliteSaver`` and parked in an ``asyncio.Queue`` so callers
    reuse connections instead of opening one per request.

    Main features:
        1. Shared connection management (no new connection per request).
        2. Every pooled connection gets ``busy_timeout`` and, when enabled,
           WAL journaling configured up front.
        3. Borrow/return API: ``acquire_for_agent()`` / ``return_to_pool()``.
        4. ``close()`` stops the cleanup scheduler and closes pooled
           connections.
        5. Optional periodic deletion of threads that have been inactive for
           a configurable number of days.
    """

    def __init__(self):
        # Idle checkpointers. The queue is unbounded, so put() never blocks.
        self._pool: "asyncio.Queue[AsyncSqliteSaver]" = asyncio.Queue()
        # Serialises initialize() and close().
        self._lock = asyncio.Lock()
        self._initialized = False
        self._closed = False
        self._pool_size = CHECKPOINT_POOL_SIZE
        self._db_path = CHECKPOINT_DB_PATH
        # Background cleanup scheduler state.
        self._cleanup_task: Optional[asyncio.Task] = None
        self._cleanup_stop_event = asyncio.Event()

    async def initialize(self) -> None:
        """Open and configure the pooled connections.

        Calling this again once initialized is a no-op. If opening or setting
        up any connection fails, everything opened by this call is closed
        again and the exception is re-raised, so a retry starts from an empty
        pool. A successful call also clears the "closed" flag, which makes a
        manager usable again after ``close()``.

        Raises:
            Exception: whatever the connection or ``setup()`` call raised.
        """
        if self._initialized:
            return

        async with self._lock:
            # Re-check: another task may have finished while we waited.
            if self._initialized:
                return

            logger.info(f"Initializing CheckpointerManager with pool_size={self._pool_size}")

            # A bare filename has an empty dirname and os.makedirs("") raises
            # FileNotFoundError, so only create the directory when there is one.
            db_dir = os.path.dirname(self._db_path)
            if db_dir:
                os.makedirs(db_dir, exist_ok=True)

            opened = []  # connections opened by this call, kept for rollback
            try:
                for i in range(self._pool_size):
                    try:
                        conn = await self._create_configured_connection()
                        opened.append(conn)
                        checkpointer = AsyncSqliteSaver(conn=conn)
                        # Run setup() now so the tables exist before first use.
                        await checkpointer.setup()
                        await self._pool.put(checkpointer)
                        logger.debug(f"Created checkpointer connection {i+1}/{self._pool_size}")
                    except Exception as e:
                        logger.error(f"Failed to create checkpointer connection {i+1}: {e}")
                        raise
            except Exception:
                # Roll back the partially built pool so nothing leaks and a
                # retry does not stack extra connections on top of old ones.
                while not self._pool.empty():
                    self._pool.get_nowait()
                for conn in opened:
                    try:
                        await conn.close()
                    except Exception as close_err:
                        logger.warning(f"Error closing connection during rollback: {close_err}")
                raise

            self._initialized = True
            self._closed = False
            logger.info("CheckpointerManager initialized successfully")

    async def _create_configured_connection(self) -> "aiosqlite.Connection":
        """Open a SQLite connection and apply the configured PRAGMAs.

        Configuration applied:
            1. ``busy_timeout`` - how long to wait on a locked database.
            2. When ``CHECKPOINT_WAL_MODE`` is set: WAL journaling (allows
               concurrent readers and a writer) plus ``synchronous``,
               ``wal_autocheckpoint``, ``cache_size`` and ``temp_store``.

        Returns:
            The open, configured connection. The caller owns it.

        Raises:
            Exception: if a PRAGMA fails; the connection is closed first.
        """
        conn = await aiosqlite.connect(self._db_path)
        try:
            # busy_timeout can only be set once the connection exists.
            await conn.execute(f"PRAGMA busy_timeout = {CHECKPOINT_BUSY_TIMEOUT}")

            if CHECKPOINT_WAL_MODE:
                await conn.execute("PRAGMA journal_mode = WAL")
                await conn.execute("PRAGMA synchronous = NORMAL")
                # Extra tuning used together with WAL.
                await conn.execute("PRAGMA wal_autocheckpoint = 1000")
                await conn.execute("PRAGMA cache_size = -64000")  # 64MB cache
                await conn.execute("PRAGMA temp_store = MEMORY")

            await conn.commit()
        except Exception:
            # Do not leak a half-configured connection.
            await conn.close()
            raise

        return conn

    async def acquire_for_agent(self) -> "AsyncSqliteSaver":
        """Borrow a checkpointer from the pool, waiting until one is free.

        The caller must hand it back with ``return_to_pool()``.

        Returns:
            An ``AsyncSqliteSaver`` taken from the pool.

        Raises:
            RuntimeError: if ``initialize()`` has not completed.
        """
        if not self._initialized:
            raise RuntimeError("CheckpointerManager not initialized. Call initialize() first.")

        checkpointer = await self._pool.get()
        logger.debug(f"Acquired checkpointer from pool, remaining: {self._pool.qsize()}")
        return checkpointer

    async def return_to_pool(self, checkpointer: "AsyncSqliteSaver") -> None:
        """Hand a borrowed checkpointer back.

        If the manager has been closed in the meantime, the pool has already
        been drained, so the connection is closed instead of being parked in
        a queue nobody will drain again.

        Args:
            checkpointer: the instance obtained from ``acquire_for_agent()``.
        """
        if self._closed:
            await self._close_checkpointer(checkpointer)
            logger.debug("Manager is closed; closed returned checkpointer instead of pooling it")
            return

        await self._pool.put(checkpointer)
        logger.debug(f"Returned checkpointer to pool, remaining: {self._pool.qsize()}")

    async def _close_checkpointer(self, checkpointer: "AsyncSqliteSaver") -> None:
        """Close the connection behind *checkpointer*, logging any failure."""
        conn = getattr(checkpointer, "conn", None)
        if conn is None:
            return
        try:
            await conn.close()
        except Exception as e:
            # One bad connection must not stop the others from being closed.
            logger.warning(f"Error closing checkpointer connection: {e}")

    async def close(self) -> None:
        """Stop the cleanup scheduler and close every pooled connection.

        Calling this again once closed is a no-op. Checkpointers that are
        checked out at this moment are closed when they are returned.
        """
        if self._closed:
            return

        # Stop the cleanup task first.
        task = self._cleanup_task
        if task is not None:
            self._cleanup_task = None
            self._cleanup_stop_event.set()
            task.cancel()
            # Wait for the task to really finish so that a checkpointer it
            # holds is back in the pool before the pool is drained. A task
            # bound to a different loop cannot be awaited from here.
            if task.get_loop() is asyncio.get_running_loop():
                await asyncio.gather(task, return_exceptions=True)

        async with self._lock:
            if self._closed:
                return

            logger.info("Closing CheckpointerManager...")

            # Drain the pool and close every connection in it.
            while not self._pool.empty():
                try:
                    checkpointer = self._pool.get_nowait()
                except asyncio.QueueEmpty:
                    break
                await self._close_checkpointer(checkpointer)

            self._closed = True
            self._initialized = False
            logger.info("CheckpointerManager closed")

    def get_pool_stats(self) -> dict:
        """Return a snapshot of the pool state.

        Returns:
            Dict with ``db_path``, ``pool_size``, ``available_connections``
            (idle checkpointers in the queue), ``initialized`` and ``closed``.
        """
        return {
            "db_path": self._db_path,
            "pool_size": self._pool_size,
            "available_connections": self._pool.qsize(),
            "initialized": self._initialized,
            "closed": self._closed,
        }

    # ============================================================
    # Checkpoint cleanup
    # ============================================================

    async def get_all_thread_ids(self) -> List[str]:
        """Return every distinct ``thread_id`` in the ``checkpoints`` table.

        Uses a short-lived connection of its own instead of a pooled one.

        Returns:
            The thread ids, or an empty list when the manager is not
            initialized.
        """
        if not self._initialized:
            return []

        async with aiosqlite.connect(self._db_path) as conn:
            cursor = await conn.execute(
                "SELECT DISTINCT thread_id FROM checkpoints"
            )
            rows = await cursor.fetchall()
        return [row[0] for row in rows]

    async def _read_last_activity(
        self, checkpointer: "AsyncSqliteSaver", thread_id: str
    ) -> Optional[datetime]:
        """Read the ``ts`` of the first checkpoint listed for *thread_id*.

        Takes an already borrowed *checkpointer* so that callers holding one
        do not need a second pool slot.
        """
        config = {"configurable": {"thread_id": thread_id}}

        # Only the first listed checkpoint is used; this relies on alist()
        # yielding the newest checkpoint first.
        latest = None
        async for item in checkpointer.alist(config=config, limit=1):
            latest = item
            break

        if latest is None or not latest.checkpoint:
            return None
        ts_str = latest.checkpoint.get("ts")
        if not ts_str:
            return None
        # Parse the ISO timestamp; a trailing "Z" is rewritten as an offset.
        return datetime.fromisoformat(ts_str.replace("Z", "+00:00"))

    async def get_thread_last_activity(self, thread_id: str) -> Optional[datetime]:
        """Return the last-activity time of a thread.

        The time is the ``ts`` field of the thread's most recent checkpoint.

        Args:
            thread_id: the thread to inspect.

        Returns:
            The parsed timestamp, or ``None`` when the manager is not
            initialized, the thread has no checkpoint or timestamp, or the
            lookup failed (failures are logged as warnings).
        """
        if not self._initialized:
            return None

        checkpointer = await self.acquire_for_agent()
        try:
            return await self._read_last_activity(checkpointer, thread_id)
        except Exception as e:
            logger.warning(f"Error getting last activity for thread {thread_id}: {e}")
            return None
        finally:
            await self.return_to_pool(checkpointer)

    async def cleanup_old_threads(self, older_than_days: Optional[int] = None) -> Dict[str, Any]:
        """Delete threads that have been inactive for too long.

        Args:
            older_than_days: inactivity threshold in days; defaults to
                ``CHECKPOINT_CLEANUP_OLDER_THAN_DAYS``.

        Returns:
            Statistics dict:
                - threads_deleted: number of threads removed
                - threads_scanned: number of threads examined
                - cutoff_time: ISO timestamp of the cutoff
                - older_than_days: the threshold that was applied

        Raises:
            RuntimeError: if the manager is not initialized.
        """
        if older_than_days is None:
            older_than_days = CHECKPOINT_CLEANUP_OLDER_THAN_DAYS

        # Timezone-aware on purpose so it compares with aware timestamps.
        cutoff_time = datetime.now(timezone.utc) - timedelta(days=older_than_days)
        logger.info(f"Starting checkpoint cleanup: removing threads inactive since {cutoff_time.isoformat()}")

        all_thread_ids = await self.get_all_thread_ids()
        threads_deleted = 0
        threads_scanned = len(all_thread_ids)

        checkpointer = await self.acquire_for_agent()

        try:
            for thread_id in all_thread_ids:
                try:
                    # Reuse the checkpointer we already hold. Borrowing a
                    # second one here would block forever with a pool of 1.
                    last_activity = await self._read_last_activity(checkpointer, thread_id)
                    if last_activity is None:
                        continue
                    if last_activity.tzinfo is None:
                        # Naive and aware datetimes cannot be compared.
                        # NOTE(review): assumes naive timestamps are UTC.
                        last_activity = last_activity.replace(tzinfo=timezone.utc)

                    if last_activity < cutoff_time:
                        # adelete_thread() takes the thread id itself, not a
                        # RunnableConfig dict.
                        await checkpointer.adelete_thread(thread_id)
                        threads_deleted += 1
                        logger.debug(f"Deleted old thread: {thread_id} (last activity: {last_activity.isoformat()})")
                except Exception as e:
                    # Best effort: one bad thread must not stop the sweep.
                    logger.warning(f"Error processing thread {thread_id}: {e}")
        finally:
            await self.return_to_pool(checkpointer)

        result = {
            "threads_deleted": threads_deleted,
            "threads_scanned": threads_scanned,
            "cutoff_time": cutoff_time.isoformat(),
            "older_than_days": older_than_days,
        }

        logger.info(f"Checkpoint cleanup completed: {result}")
        return result

    async def _cleanup_loop(self):
        """Background loop: run a cleanup every configured interval.

        Sleeps on the stop event with a timeout, so ``close()`` can wake it
        immediately. Errors raised by a cleanup run are logged and the loop
        keeps going; cancellation ends the loop.
        """
        interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600

        logger.info(
            f"Checkpoint cleanup scheduler started: "
            f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, "
            f"older_than={CHECKPOINT_CLEANUP_OLDER_THAN_DAYS} days"
        )

        while not self._cleanup_stop_event.is_set():
            try:
                # Wait for either the interval to elapse or the stop signal.
                try:
                    await asyncio.wait_for(
                        self._cleanup_stop_event.wait(),
                        timeout=interval_seconds
                    )
                    # Stop signal received: leave the loop.
                    break
                except asyncio.TimeoutError:
                    # Interval elapsed: fall through and run the cleanup.
                    pass

                if self._cleanup_stop_event.is_set():
                    break

                await self.cleanup_old_threads()

            except asyncio.CancelledError:
                logger.info("Cleanup task cancelled")
                break
            except Exception as e:
                logger.error(f"Error in cleanup loop: {e}")

        logger.info("Checkpoint cleanup scheduler stopped")

    def start_cleanup_scheduler(self):
        """Start the periodic background cleanup task.

        Does nothing when cleanup is disabled by configuration or when a
        scheduler task is already running.
        """
        if not CHECKPOINT_CLEANUP_ENABLED:
            logger.info("Checkpoint cleanup is disabled")
            return

        if self._cleanup_task is not None and not self._cleanup_task.done():
            logger.warning("Cleanup scheduler is already running")
            return

        # A previous close() leaves the stop event set; clear it so a
        # restarted scheduler does not exit on its first check.
        self._cleanup_stop_event.clear()

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # NOTE(review): with no running loop the task is attached to a
            # fresh loop and only makes progress once something runs that
            # loop. Confirm callers really rely on this path.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        self._cleanup_task = loop.create_task(self._cleanup_loop())
        logger.info("Cleanup scheduler task created")
|
||
|
||
|
||
# 全局单例
|
||
# Module-level singleton slot; stays None until get_checkpointer_manager() creates the instance.
_global_manager: Optional[CheckpointerManager] = None
|
||
|
||
|
||
def get_checkpointer_manager() -> CheckpointerManager:
    """Return the module-wide CheckpointerManager, creating it on first use."""
    global _global_manager
    if _global_manager is not None:
        return _global_manager
    _global_manager = CheckpointerManager()
    return _global_manager
|
||
|
||
|
||
async def init_global_checkpointer() -> None:
    """Initialize the module-wide checkpointer manager."""
    await get_checkpointer_manager().initialize()
|
||
|
||
|
||
async def close_global_checkpointer() -> None:
    """Close the module-wide checkpointer manager if one has been created."""
    # The name is only read, never rebound, so no `global` statement is needed.
    manager = _global_manager
    if manager is None:
        return
    await manager.close()
|