qwen_agent/agent/checkpoint_manager.py

187 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
全局 SQLite Checkpointer 管理器
解决高并发场景下的数据库锁定问题
"""
import asyncio
import logging
import os
from typing import Optional
import aiosqlite
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from utils.settings import (
CHECKPOINT_DB_PATH,
CHECKPOINT_WAL_MODE,
CHECKPOINT_BUSY_TIMEOUT,
CHECKPOINT_POOL_SIZE,
)
# Module-level logger; obtained under the fixed name 'app' rather than __name__.
logger = logging.getLogger('app')
class CheckpointerManager:
    """
    Global checkpointer manager that reuses SQLite connections through a pool.

    Main features:
    1. One shared set of connections instead of a new connection per request
    2. WAL mode and busy_timeout are configured up front on every connection
    3. A queue-backed pool hands checkpointers out to concurrent callers
    4. Graceful shutdown via close()
    """

    def __init__(self):
        # Idle checkpointers; acquire_for_agent() waits on this queue when it
        # is empty.
        self._pool: asyncio.Queue[AsyncSqliteSaver] = asyncio.Queue()
        # Serializes initialize() and close() against each other.
        self._lock = asyncio.Lock()
        self._initialized = False
        self._closed = False
        self._pool_size = CHECKPOINT_POOL_SIZE
        self._db_path = CHECKPOINT_DB_PATH

    async def initialize(self) -> None:
        """
        Build the connection pool. A no-op when the pool is already initialized.

        Raises:
            Exception: whatever connection creation or ``setup()`` raised.
                Connections opened during the failed attempt are closed
                before the exception propagates.
        """
        if self._initialized:
            return
        async with self._lock:
            # Re-check under the lock: another task may have finished
            # initializing while this one was waiting.
            if self._initialized:
                return
            logger.info(f"Initializing CheckpointerManager with pool_size={self._pool_size}")
            # Make sure the parent directory exists. A bare filename has an
            # empty dirname and os.makedirs("") raises, so skip it then.
            db_dir = os.path.dirname(self._db_path)
            if db_dir:
                os.makedirs(db_dir, exist_ok=True)
            # Create the pooled connections.
            for i in range(self._pool_size):
                conn = None
                try:
                    conn = await self._create_configured_connection()
                    checkpointer = AsyncSqliteSaver(conn=conn)
                    # Call setup() up front so the tables already exist.
                    await checkpointer.setup()
                except Exception as e:
                    logger.error(f"Failed to create checkpointer connection {i+1}: {e}")
                    # Do not leak the half-built pool: close the connection
                    # that just failed and everything queued before it.
                    await self._close_connection(conn)
                    await self._drain_pool()
                    raise
                await self._pool.put(checkpointer)
                logger.debug(f"Created checkpointer connection {i+1}/{self._pool_size}")
            self._initialized = True
            # Allow a later close() to run again after a close/re-init cycle.
            self._closed = False
            logger.info("CheckpointerManager initialized successfully")

    async def _create_configured_connection(self) -> aiosqlite.Connection:
        """
        Open a SQLite connection and apply the pool's PRAGMA configuration.

        Configuration:
        1. busy_timeout - maximum time to wait for a lock
        2. WAL mode (Write-Ahead Logging), when enabled - allows concurrent
           reads and writes
        3. Additional tuning parameters that accompany WAL mode

        The connection is closed again if any configuration step fails.
        """
        conn = await aiosqlite.connect(self._db_path)
        try:
            # busy_timeout can only be set once the connection is established.
            await conn.execute(f"PRAGMA busy_timeout = {CHECKPOINT_BUSY_TIMEOUT}")
            if CHECKPOINT_WAL_MODE:
                await conn.execute("PRAGMA journal_mode = WAL")
                await conn.execute("PRAGMA synchronous = NORMAL")
                # Tuning that goes with WAL mode.
                await conn.execute("PRAGMA wal_autocheckpoint = 1000")
                await conn.execute("PRAGMA cache_size = -64000")  # about 64 MB of cache
                await conn.execute("PRAGMA temp_store = MEMORY")
            await conn.commit()
        except Exception:
            await self._close_connection(conn)
            raise
        return conn

    async def _close_connection(self, conn) -> None:
        """Close one connection, logging a failure instead of raising it."""
        if not conn:
            return
        try:
            await conn.close()
        except Exception as e:
            logger.warning(f"Failed to close checkpointer connection: {e}")

    async def _drain_pool(self) -> None:
        """Remove every idle checkpointer from the pool and close its connection."""
        while True:
            try:
                checkpointer = self._pool.get_nowait()
            except asyncio.QueueEmpty:
                break
            await self._close_connection(checkpointer.conn)

    async def acquire_for_agent(self) -> AsyncSqliteSaver:
        """
        Take a checkpointer out of the pool for an agent.

        The caller must hand it back with return_to_pool(). Waits while the
        pool is empty.

        Returns:
            An AsyncSqliteSaver instance.

        Raises:
            RuntimeError: if the manager is not initialized.
        """
        if not self._initialized:
            raise RuntimeError("CheckpointerManager not initialized. Call initialize() first.")
        checkpointer = await self._pool.get()
        logger.debug(f"Acquired checkpointer from pool, remaining: {self._pool.qsize()}")
        return checkpointer

    async def return_to_pool(self, checkpointer: AsyncSqliteSaver) -> None:
        """
        Give a checkpointer back to the pool.

        If the manager is shutting down or already closed, the checkpointer's
        connection is closed instead, because nothing would drain it later.

        Args:
            checkpointer: the checkpointer instance being returned.
        """
        if not self._initialized:
            await self._close_connection(checkpointer.conn)
            logger.debug("Pool is not active; closed returned checkpointer connection")
            return
        await self._pool.put(checkpointer)
        logger.debug(f"Returned checkpointer to pool, remaining: {self._pool.qsize()}")

    async def close(self) -> None:
        """
        Close every idle connection and mark the manager closed.

        Closing is best-effort: a connection that fails to close is logged and
        the remaining ones are still closed. Checkpointers that are checked
        out at this moment are closed when they come back via return_to_pool().
        """
        if self._closed:
            return
        async with self._lock:
            if self._closed:
                return
            logger.info("Closing CheckpointerManager...")
            # Flip this first so that, while the pool is being drained,
            # acquire_for_agent() fails fast and return_to_pool() closes
            # connections instead of re-queueing them.
            self._initialized = False
            # Empty the pool and close all connections.
            await self._drain_pool()
            self._closed = True
            logger.info("CheckpointerManager closed")

    def get_pool_stats(self) -> dict:
        """Return a snapshot of the pool state (path, size, idle count, flags)."""
        return {
            "db_path": self._db_path,
            "pool_size": self._pool_size,
            "available_connections": self._pool.qsize(),
            "initialized": self._initialized,
            "closed": self._closed
        }
# Global singleton: holds the one shared manager instance, or None until
# get_checkpointer_manager() creates it.
_global_manager: Optional[CheckpointerManager] = None
def get_checkpointer_manager() -> CheckpointerManager:
    """Return the global CheckpointerManager singleton, creating it on first use."""
    global _global_manager
    manager = _global_manager
    if manager is None:
        # First call: build the instance and remember it for later callers.
        manager = CheckpointerManager()
        _global_manager = manager
    return manager
async def init_global_checkpointer() -> None:
    """Initialize the global checkpointer manager, creating it if needed."""
    await get_checkpointer_manager().initialize()
async def close_global_checkpointer() -> None:
    """Close the global checkpointer manager; does nothing if none was created."""
    manager = _global_manager
    if manager is None:
        return
    await manager.close()