新增checkpoint清理机制
This commit is contained in:
parent
bf11975183
commit
e117f1ee07
@ -5,7 +5,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||||||
@ -15,6 +16,9 @@ from utils.settings import (
|
|||||||
CHECKPOINT_WAL_MODE,
|
CHECKPOINT_WAL_MODE,
|
||||||
CHECKPOINT_BUSY_TIMEOUT,
|
CHECKPOINT_BUSY_TIMEOUT,
|
||||||
CHECKPOINT_POOL_SIZE,
|
CHECKPOINT_POOL_SIZE,
|
||||||
|
CHECKPOINT_CLEANUP_ENABLED,
|
||||||
|
CHECKPOINT_CLEANUP_INTERVAL_HOURS,
|
||||||
|
CHECKPOINT_CLEANUP_OLDER_THAN_DAYS,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger('app')
|
logger = logging.getLogger('app')
|
||||||
@ -38,6 +42,9 @@ class CheckpointerManager:
|
|||||||
self._closed = False
|
self._closed = False
|
||||||
self._pool_size = CHECKPOINT_POOL_SIZE
|
self._pool_size = CHECKPOINT_POOL_SIZE
|
||||||
self._db_path = CHECKPOINT_DB_PATH
|
self._db_path = CHECKPOINT_DB_PATH
|
||||||
|
# 清理调度任务
|
||||||
|
self._cleanup_task: Optional[asyncio.Task] = None
|
||||||
|
self._cleanup_stop_event = asyncio.Event()
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
"""初始化连接池"""
|
"""初始化连接池"""
|
||||||
@ -131,6 +138,16 @@ class CheckpointerManager:
|
|||||||
if self._closed:
|
if self._closed:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 停止清理任务
|
||||||
|
if self._cleanup_task is not None:
|
||||||
|
self._cleanup_stop_event.set()
|
||||||
|
try:
|
||||||
|
self._cleanup_task.cancel()
|
||||||
|
await asyncio.sleep(0.1) # 给任务一点时间清理
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._cleanup_task = None
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
if self._closed:
|
if self._closed:
|
||||||
return
|
return
|
||||||
@ -160,6 +177,181 @@ class CheckpointerManager:
|
|||||||
"closed": self._closed
|
"closed": self._closed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Checkpoint 清理方法
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
async def get_all_thread_ids(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
获取数据库中所有唯一的 thread_id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 所有 thread_id 列表
|
||||||
|
"""
|
||||||
|
if not self._initialized:
|
||||||
|
return []
|
||||||
|
|
||||||
|
conn = aiosqlite.connect(self._db_path)
|
||||||
|
await conn.__aenter__()
|
||||||
|
|
||||||
|
try:
|
||||||
|
cursor = await conn.execute(
|
||||||
|
"SELECT DISTINCT thread_id FROM checkpoints"
|
||||||
|
)
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
return [row[0] for row in rows]
|
||||||
|
finally:
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
async def get_thread_last_activity(self, thread_id: str) -> Optional[datetime]:
|
||||||
|
"""
|
||||||
|
获取指定 thread 的最后活动时间
|
||||||
|
|
||||||
|
通过查询该 thread 最新的 checkpoint 中的 ts 字段获取时间
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: 线程ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
datetime: 最后活动时间,如果找不到则返回 None
|
||||||
|
"""
|
||||||
|
if not self._initialized:
|
||||||
|
return None
|
||||||
|
|
||||||
|
checkpointer = await self.acquire_for_agent()
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = {"configurable": {"thread_id": thread_id}}
|
||||||
|
result = checkpointer.alist(config=config, limit=1)
|
||||||
|
|
||||||
|
last_checkpoint = None
|
||||||
|
async for item in result:
|
||||||
|
last_checkpoint = item
|
||||||
|
break
|
||||||
|
|
||||||
|
if last_checkpoint and last_checkpoint.checkpoint:
|
||||||
|
ts_str = last_checkpoint.checkpoint.get("ts")
|
||||||
|
if ts_str:
|
||||||
|
# 解析 ISO 格式时间戳
|
||||||
|
return datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error getting last activity for thread {thread_id}: {e}")
|
||||||
|
finally:
|
||||||
|
await self.return_to_pool(checkpointer)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def cleanup_old_threads(self, older_than_days: int = None) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
清理超过指定天数未活动的 thread
|
||||||
|
|
||||||
|
Args:
|
||||||
|
older_than_days: 清理多少天前的数据,默认使用配置值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 清理统计信息
|
||||||
|
- threads_deleted: 删除的 thread 数量
|
||||||
|
- threads_scanned: 扫描的 thread 总数
|
||||||
|
- cutoff_time: 截止时间
|
||||||
|
"""
|
||||||
|
if older_than_days is None:
|
||||||
|
older_than_days = CHECKPOINT_CLEANUP_OLDER_THAN_DAYS
|
||||||
|
|
||||||
|
# 使用带时区的时间,避免比较时出错
|
||||||
|
cutoff_time = datetime.now(timezone.utc) - timedelta(days=older_than_days)
|
||||||
|
logger.info(f"Starting checkpoint cleanup: removing threads inactive since {cutoff_time.isoformat()}")
|
||||||
|
|
||||||
|
all_thread_ids = await self.get_all_thread_ids()
|
||||||
|
threads_deleted = 0
|
||||||
|
threads_scanned = len(all_thread_ids)
|
||||||
|
|
||||||
|
checkpointer = await self.acquire_for_agent()
|
||||||
|
|
||||||
|
try:
|
||||||
|
for thread_id in all_thread_ids:
|
||||||
|
try:
|
||||||
|
last_activity = await self.get_thread_last_activity(thread_id)
|
||||||
|
|
||||||
|
if last_activity and last_activity < cutoff_time:
|
||||||
|
# 删除旧 thread
|
||||||
|
config = {"configurable": {"thread_id": thread_id}}
|
||||||
|
await checkpointer.adelete_thread(config)
|
||||||
|
threads_deleted += 1
|
||||||
|
logger.debug(f"Deleted old thread: {thread_id} (last activity: {last_activity.isoformat()})")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error processing thread {thread_id}: {e}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await self.return_to_pool(checkpointer)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"threads_deleted": threads_deleted,
|
||||||
|
"threads_scanned": threads_scanned,
|
||||||
|
"cutoff_time": cutoff_time.isoformat(),
|
||||||
|
"older_than_days": older_than_days
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"Checkpoint cleanup completed: {result}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _cleanup_loop(self):
|
||||||
|
"""后台清理循环"""
|
||||||
|
interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Checkpoint cleanup scheduler started: "
|
||||||
|
f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, "
|
||||||
|
f"older_than={CHECKPOINT_CLEANUP_OLDER_THAN_DAYS} days"
|
||||||
|
)
|
||||||
|
|
||||||
|
while not self._cleanup_stop_event.is_set():
|
||||||
|
try:
|
||||||
|
# 等待间隔时间或停止信号
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._cleanup_stop_event.wait(),
|
||||||
|
timeout=interval_seconds
|
||||||
|
)
|
||||||
|
# 收到停止信号,退出循环
|
||||||
|
break
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# 超时,执行清理任务
|
||||||
|
pass
|
||||||
|
|
||||||
|
if self._cleanup_stop_event.is_set():
|
||||||
|
break
|
||||||
|
|
||||||
|
# 执行清理
|
||||||
|
await self.cleanup_old_threads()
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("Cleanup task cancelled")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in cleanup loop: {e}")
|
||||||
|
|
||||||
|
logger.info("Checkpoint cleanup scheduler stopped")
|
||||||
|
|
||||||
|
def start_cleanup_scheduler(self):
|
||||||
|
"""启动后台定时清理任务"""
|
||||||
|
if not CHECKPOINT_CLEANUP_ENABLED:
|
||||||
|
logger.info("Checkpoint cleanup is disabled")
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._cleanup_task is not None and not self._cleanup_task.done():
|
||||||
|
logger.warning("Cleanup scheduler is already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 创建新的事件循环(如果需要在后台运行)
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
self._cleanup_task = loop.create_task(self._cleanup_loop())
|
||||||
|
logger.info("Cleanup scheduler task created")
|
||||||
|
|
||||||
|
|
||||||
# 全局单例
|
# 全局单例
|
||||||
_global_manager: Optional[CheckpointerManager] = None
|
_global_manager: Optional[CheckpointerManager] = None
|
||||||
|
|||||||
@ -26,15 +26,33 @@ async def lifespan(app: FastAPI):
|
|||||||
"""FastAPI 应用生命周期管理"""
|
"""FastAPI 应用生命周期管理"""
|
||||||
# 启动时初始化
|
# 启动时初始化
|
||||||
logger.info("Starting up...")
|
logger.info("Starting up...")
|
||||||
from agent.checkpoint_manager import init_global_checkpointer
|
from agent.checkpoint_manager import (
|
||||||
|
init_global_checkpointer,
|
||||||
|
get_checkpointer_manager,
|
||||||
|
close_global_checkpointer
|
||||||
|
)
|
||||||
|
from utils.settings import CHECKPOINT_CLEANUP_ENABLED
|
||||||
|
|
||||||
await init_global_checkpointer()
|
await init_global_checkpointer()
|
||||||
logger.info("Global checkpointer initialized")
|
logger.info("Global checkpointer initialized")
|
||||||
|
|
||||||
|
# 启动 checkpoint 清理调度器
|
||||||
|
if CHECKPOINT_CLEANUP_ENABLED:
|
||||||
|
manager = get_checkpointer_manager()
|
||||||
|
# 启动时立即执行一次清理
|
||||||
|
try:
|
||||||
|
result = await manager.cleanup_old_threads()
|
||||||
|
logger.info(f"Startup cleanup completed: {result}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Startup cleanup failed (non-fatal): {e}")
|
||||||
|
# 启动定时清理调度器
|
||||||
|
manager.start_cleanup_scheduler()
|
||||||
|
logger.info("Checkpoint cleanup scheduler started")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# 关闭时清理
|
# 关闭时清理
|
||||||
logger.info("Shutting down...")
|
logger.info("Shutting down...")
|
||||||
from agent.checkpoint_manager import close_global_checkpointer
|
|
||||||
await close_global_checkpointer()
|
await close_global_checkpointer()
|
||||||
logger.info("Global checkpointer closed")
|
logger.info("Global checkpointer closed")
|
||||||
|
|
||||||
|
|||||||
@ -59,4 +59,16 @@ CHECKPOINT_BUSY_TIMEOUT = int(os.getenv("CHECKPOINT_BUSY_TIMEOUT", "10000"))
|
|||||||
# 同时可以持有的最大连接数
|
# 同时可以持有的最大连接数
|
||||||
CHECKPOINT_POOL_SIZE = int(os.getenv("CHECKPOINT_POOL_SIZE", "15"))
|
CHECKPOINT_POOL_SIZE = int(os.getenv("CHECKPOINT_POOL_SIZE", "15"))
|
||||||
|
|
||||||
|
# Checkpoint 自动清理配置
|
||||||
|
# 是否启用自动清理旧 session
|
||||||
|
CHECKPOINT_CLEANUP_ENABLED = os.getenv("CHECKPOINT_CLEANUP_ENABLED", "true") == "true"
|
||||||
|
|
||||||
|
# 清理间隔(小时)
|
||||||
|
# 每隔多少小时执行一次清理任务
|
||||||
|
CHECKPOINT_CLEANUP_INTERVAL_HOURS = int(os.getenv("CHECKPOINT_CLEANUP_INTERVAL_HOURS", "24"))
|
||||||
|
|
||||||
|
# 清理多少天前的数据
|
||||||
|
# 超过 N 天未活动的 thread 会被删除
|
||||||
|
CHECKPOINT_CLEANUP_OLDER_THAN_DAYS = int(os.getenv("CHECKPOINT_CLEANUP_OLDER_THAN_DAYS", "3"))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user