187 lines
5.7 KiB
Python
187 lines
5.7 KiB
Python
"""
|
||
Global SQLite checkpointer manager.

Works around database-locked errors under highly concurrent access.
|
||
"""
|
||
import asyncio
|
||
import logging
|
||
import os
|
||
from typing import Optional
|
||
|
||
import aiosqlite
|
||
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||
|
||
from utils.settings import (
|
||
CHECKPOINT_DB_PATH,
|
||
CHECKPOINT_WAL_MODE,
|
||
CHECKPOINT_BUSY_TIMEOUT,
|
||
CHECKPOINT_POOL_SIZE,
|
||
)
|
||
|
||
# Module-level logger; it uses the fixed logger name 'app' rather than this module's __name__.
logger = logging.getLogger('app')
|
||
|
||
|
||
class CheckpointerManager:
    """Pool of pre-configured ``AsyncSqliteSaver`` checkpointers for one SQLite file.

    Features:
        1. Connections are opened once in ``initialize()`` and reused instead
           of being created per request.
        2. Every connection is configured with ``busy_timeout`` and, when
           ``CHECKPOINT_WAL_MODE`` is enabled, WAL journal mode.
        3. Checkpointers are handed out through an ``asyncio.Queue`` so
           concurrent tasks wait for a free one.
        4. ``close()`` closes every pooled connection.

    A checkpointer obtained with ``acquire_for_agent()`` must be handed back
    with ``return_to_pool()``.
    """

    def __init__(self):
        self._pool: "asyncio.Queue[AsyncSqliteSaver]" = asyncio.Queue()
        self._lock = asyncio.Lock()
        self._initialized = False
        self._closed = False
        self._pool_size = CHECKPOINT_POOL_SIZE
        self._db_path = CHECKPOINT_DB_PATH

    @staticmethod
    def _ensure_parent_dir(db_path: str) -> None:
        """Create the directory that will contain ``db_path`` if it has one.

        A bare file name (``"checkpoints.db"``) has an empty dirname, and
        ``os.makedirs("")`` raises ``FileNotFoundError``, so that case is
        skipped.
        """
        parent = os.path.dirname(db_path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    async def _close_connection(conn) -> None:
        """Close one connection, logging (not raising) on failure."""
        if conn is None:
            return
        try:
            await conn.close()
        except Exception as e:
            # Best effort: one broken connection must not stop the others
            # from being closed.
            logger.warning("Failed to close checkpointer connection: %s", e)

    async def _drain_pool(self) -> None:
        """Remove every checkpointer currently in the pool and close its connection."""
        while True:
            try:
                checkpointer = self._pool.get_nowait()
            except asyncio.QueueEmpty:
                break
            await self._close_connection(getattr(checkpointer, "conn", None))

    async def initialize(self) -> None:
        """Open ``pool_size`` connections and fill the pool.

        Calling it again while already initialized is a no-op. If any
        connection cannot be created or set up, every connection opened so far
        is closed and the original exception is re-raised.
        """
        if self._initialized:
            return

        async with self._lock:
            # Re-check under the lock: another task may have finished first.
            if self._initialized:
                return

            logger.info(
                "Initializing CheckpointerManager with pool_size=%s", self._pool_size
            )

            # Make sure the database directory exists.
            self._ensure_parent_dir(self._db_path)

            for i in range(self._pool_size):
                conn = None
                try:
                    conn = await self._create_configured_connection()
                    checkpointer = AsyncSqliteSaver(conn=conn)
                    # Run setup up front so the tables exist before first use.
                    await checkpointer.setup()
                except Exception as e:
                    logger.error(
                        "Failed to create checkpointer connection %s: %s", i + 1, e
                    )
                    # Do not leak the half-built connection or the ones that
                    # were already pooled.
                    await self._close_connection(conn)
                    await self._drain_pool()
                    raise

                await self._pool.put(checkpointer)
                logger.debug(
                    "Created checkpointer connection %s/%s", i + 1, self._pool_size
                )

            # A manager that was closed earlier is live again; without this
            # reset a later close() would return early and leak the new pool.
            self._closed = False
            self._initialized = True
            logger.info("CheckpointerManager initialized successfully")

    async def _create_configured_connection(self) -> "aiosqlite.Connection":
        """Open a connection to the database file and apply the PRAGMA settings.

        Settings applied:
            1. ``busy_timeout`` - how long to wait on a locked database.
            2. WAL journal mode and related tuning, only when
               ``CHECKPOINT_WAL_MODE`` is truthy.

        The connection is closed again if any of the statements fails.
        """
        conn = await aiosqlite.connect(self._db_path)
        try:
            # busy_timeout can only be set once the connection is open.
            statements = [f"PRAGMA busy_timeout = {int(CHECKPOINT_BUSY_TIMEOUT)}"]
            if CHECKPOINT_WAL_MODE:
                statements.extend(
                    [
                        "PRAGMA journal_mode = WAL",
                        "PRAGMA synchronous = NORMAL",
                        # Tuning used together with WAL mode.
                        "PRAGMA wal_autocheckpoint = 1000",
                        "PRAGMA cache_size = -64000",  # 64MB cache
                        "PRAGMA temp_store = MEMORY",
                    ]
                )

            for statement in statements:
                cursor = await conn.execute(statement)
                await cursor.close()

            await conn.commit()
        except BaseException:
            # Includes cancellation: never hand back or abandon an open
            # connection that was only partially configured.
            await conn.close()
            raise

        return conn

    async def acquire_for_agent(self) -> "AsyncSqliteSaver":
        """Take a checkpointer from the pool, waiting until one is free.

        The caller must give it back with ``return_to_pool()``.

        Returns:
            An ``AsyncSqliteSaver`` instance.

        Raises:
            RuntimeError: if ``initialize()` has not completed (or the manager
                has been closed).
        """
        if not self._initialized:
            raise RuntimeError("CheckpointerManager not initialized. Call initialize() first.")

        checkpointer = await self._pool.get()
        logger.debug(
            "Acquired checkpointer from pool, remaining: %s", self._pool.qsize()
        )
        return checkpointer

    async def return_to_pool(self, checkpointer: "AsyncSqliteSaver") -> None:
        """Give a checkpointer back to the pool.

        If the manager has been closed in the meantime, the checkpointer's
        connection is closed instead, because nothing would ever drain the
        pool again.

        Args:
            checkpointer: the instance previously returned by
                ``acquire_for_agent()``.
        """
        if self._closed:
            await self._close_connection(getattr(checkpointer, "conn", None))
            logger.debug("CheckpointerManager is closed; closed returned checkpointer")
            return

        await self._pool.put(checkpointer)
        logger.debug(
            "Returned checkpointer to pool, remaining: %s", self._pool.qsize()
        )

    async def close(self) -> None:
        """Close every pooled connection and mark the manager as closed.

        Calling it again while closed is a no-op. Checkpointers that are
        checked out at this moment are closed when they are returned.
        """
        if self._closed:
            return

        async with self._lock:
            if self._closed:
                return

            logger.info("Closing CheckpointerManager...")

            # Flip the flags first so that acquire/return calls arriving while
            # the connections are being closed see the closed state.
            self._closed = True
            self._initialized = False

            # Empty the pool and close all of its connections.
            await self._drain_pool()

            logger.info("CheckpointerManager closed")

    def get_pool_stats(self) -> dict:
        """Return a snapshot of the pool state."""
        return {
            "db_path": self._db_path,
            "pool_size": self._pool_size,
            "available_connections": self._pool.qsize(),
            "initialized": self._initialized,
            "closed": self._closed,
        }
|
||
|
||
|
||
# Global singleton
|
||
# Shared CheckpointerManager instance; None until get_checkpointer_manager() creates it.
_global_manager: Optional[CheckpointerManager] = None
|
||
|
||
|
||
def get_checkpointer_manager() -> CheckpointerManager:
    """Return the shared CheckpointerManager, creating it on first use.

    The instance is only constructed here; it is not initialized.
    """
    global _global_manager
    manager = _global_manager
    if manager is not None:
        return manager

    manager = CheckpointerManager()
    _global_manager = manager
    return manager
|
||
|
||
|
||
async def init_global_checkpointer() -> None:
    """Initialize the shared checkpointer manager, creating it if needed."""
    await get_checkpointer_manager().initialize()
|
||
|
||
|
||
async def close_global_checkpointer() -> None:
    """Close the shared checkpointer manager and forget the singleton.

    Does nothing when no manager has been created. After a successful close
    the module-level reference is cleared, so the next
    ``get_checkpointer_manager()`` call builds a fresh manager instead of
    handing out the closed one.
    """
    global _global_manager

    manager = _global_manager
    if manager is None:
        return

    await manager.close()

    # Clear the reference only after close() succeeded, and only if nobody
    # replaced the singleton while we were awaiting.
    if _global_manager is manager:
        _global_manager = None
|