266 lines
8.7 KiB
Python
266 lines
8.7 KiB
Python
"""
|
||
全局 PostgreSQL 连接池管理器
|
||
被 checkpoint、chat_history 共享使用
|
||
"""
|
||
import asyncio
|
||
import logging
|
||
from typing import Optional
|
||
|
||
from psycopg_pool import AsyncConnectionPool
|
||
|
||
from utils.settings import (
|
||
CHECKPOINT_DB_URL,
|
||
CHECKPOINT_POOL_SIZE,
|
||
CHECKPOINT_CLEANUP_ENABLED,
|
||
CHECKPOINT_CLEANUP_INTERVAL_HOURS,
|
||
CHECKPOINT_CLEANUP_INACTIVE_DAYS,
|
||
)
|
||
|
||
# Module logger; uses the shared 'app' logger name rather than __name__.
logger = logging.getLogger('app')
|
||
|
||
|
||
class DBPoolManager:
    """
    Global PostgreSQL connection pool manager.

    Main responsibilities:
    1. Manage connections with ``psycopg_pool.AsyncConnectionPool``.
    2. Be shared by CheckpointerManager and ChatHistoryManager.
    3. Periodically delete checkpoint data of inactive threads.
    4. Shut down gracefully: stop the cleanup task, then close the pool.
    """

    # Seconds close() waits for the cleanup task to exit on its own before
    # cancelling it (it may be in the middle of a cleanup query).
    _CLEANUP_STOP_TIMEOUT = 5.0

    def __init__(self):
        self._pool: Optional[AsyncConnectionPool] = None
        self._initialized = False
        self._closed = False
        # Background cleanup scheduling
        self._cleanup_task: Optional[asyncio.Task] = None
        # Set by close() to ask _cleanup_loop() to exit.
        self._cleanup_stop_event = asyncio.Event()

    @staticmethod
    def _mask_db_url(url: str) -> str:
        """
        Return *url* with any password replaced by ``***`` so it is safe to log.

        Handles URI-style DSNs (``postgresql://user:pw@host/db``) as well as
        ``password=...`` pairs in key/value conninfo strings or URI queries.
        """
        import re
        from urllib.parse import urlsplit, urlunsplit

        text = str(url)
        try:
            parts = urlsplit(text)
        except ValueError:
            parts = None
        if parts is not None and parts.scheme and parts.password is not None:
            # The last '@' separates userinfo from host (passwords may contain '@').
            userinfo, _, hostinfo = parts.netloc.rpartition("@")
            user = userinfo.split(":", 1)[0]
            text = urlunsplit(parts._replace(netloc=f"{user}:***@{hostinfo}"))
        # key/value form ("host=... password=secret") or "?password=..." query.
        return re.sub(
            r"(password\s*=\s*)('(?:[^'\\]|\\.)*'|[^\s&]+)",
            r"\1***",
            text,
            flags=re.IGNORECASE,
        )

    async def initialize(self) -> None:
        """
        Create and open the connection pool.

        Returns immediately if the pool is already initialized. A manager that
        was closed with close() can be initialized again.

        Raises:
            Exception: whatever ``AsyncConnectionPool`` raises while being
                created or opened; a partially created pool is closed first.
        """
        if self._initialized:
            return

        if self._closed:
            # Re-initialization after close(): without resetting these, the
            # new pool would be opened but the `pool` property would keep
            # raising "DBPoolManager is closed".
            self._closed = False
            self._cleanup_stop_event.clear()

        logger.info(
            f"Initializing PostgreSQL connection pool: "
            f"URL={self._mask_db_url(CHECKPOINT_DB_URL)}, size={CHECKPOINT_POOL_SIZE}"
        )

        pool = None
        try:
            # Create the pool closed, then open it explicitly inside the loop.
            pool = AsyncConnectionPool(
                CHECKPOINT_DB_URL,
                min_size=1,
                max_size=CHECKPOINT_POOL_SIZE,
                open=False,
            )
            await pool.open()
        except Exception as e:
            logger.error(f"Failed to initialize PostgreSQL connection pool: {e}")
            if pool is not None:
                # Do not leak a half-opened pool (and its worker tasks).
                try:
                    await pool.close()
                except Exception:
                    logger.exception("Error while discarding failed connection pool")
            raise

        self._pool = pool
        self._initialized = True
        logger.info("PostgreSQL connection pool initialized successfully")

    @property
    def pool(self) -> "AsyncConnectionPool":
        """
        Return the open connection pool.

        Raises:
            RuntimeError: if the manager has been closed or not initialized.
        """
        if self._closed:
            raise RuntimeError("DBPoolManager is closed")

        if not self._initialized:
            raise RuntimeError("DBPoolManager not initialized, call initialize() first")

        return self._pool

    async def close(self) -> None:
        """
        Stop the cleanup scheduler, then close the connection pool.

        Calling it again after a successful close is a no-op.
        """
        if self._closed:
            return

        # Make sure the cleanup task has finished before the pool goes away,
        # so it cannot run a query against a closing pool.
        await self._stop_cleanup_task()

        logger.info("Closing DBPoolManager...")

        if self._pool is not None:
            await self._pool.close()
            self._pool = None

        self._closed = True
        self._initialized = False
        logger.info("DBPoolManager closed")

    async def _stop_cleanup_task(self) -> None:
        """Signal the cleanup loop to exit and wait for it; cancel it if it overruns."""
        task = self._cleanup_task
        if task is None:
            return
        self._cleanup_task = None
        self._cleanup_stop_event.set()
        if task.done():
            return
        try:
            # On timeout, wait_for cancels the task before raising.
            await asyncio.wait_for(task, timeout=self._CLEANUP_STOP_TIMEOUT)
        except asyncio.TimeoutError:
            logger.warning("Cleanup task did not stop in time and was cancelled")
        except Exception as e:
            logger.error(f"Cleanup task ended with an error: {e}")

    # ============================================================
    # Checkpoint cleanup (PostgreSQL SQL queries)
    # ============================================================

    async def cleanup_old_checkpoints(self, inactive_days: Optional[int] = None) -> dict:
        """
        Delete all checkpoint data belonging to inactive threads.

        A thread counts as inactive when its *newest* checkpoint is older than
        ``inactive_days`` days. For each such thread, its rows in
        ``checkpoints`` and ``checkpoint_writes`` are deleted on one pooled
        connection.

        Args:
            inactive_days: inactivity threshold in days; when None,
                CHECKPOINT_CLEANUP_INACTIVE_DAYS is used.

        Returns:
            dict: cleanup statistics. ``deleted`` is the number of checkpoint
            rows plus write rows removed. ``inactive_days`` is the threshold used.
            ``threads_deleted`` is present once the queries have run.
            ``error`` (the message) is present if they failed.
            Errors are logged and reported in the result, never raised.
        """
        if inactive_days is None:
            inactive_days = CHECKPOINT_CLEANUP_INACTIVE_DAYS

        logger.info(f"Starting checkpoint cleanup: removing threads inactive for {inactive_days} days")

        if self._pool is None:
            logger.warning("Connection pool not initialized, skipping cleanup")
            return {"deleted": 0, "inactive_days": inactive_days}

        # A thread is inactive only if its NEWEST checkpoint is older than the
        # cutoff. Filtering individual rows would also match active threads
        # that merely have old history.
        # The cutoff is `%s * INTERVAL '1 day'`, not INTERVAL '%s days':
        # psycopg 3 binds parameters server-side, so a placeholder inside a
        # quoted literal is never substituted.
        # ::timestamptz honours a UTC offset embedded in 'ts'. This assumes
        # 'ts' is an ISO-8601 timestamp string; confirm against the checkpointer.
        query_find_inactive = """
            SELECT thread_id
            FROM checkpoints
            GROUP BY thread_id
            HAVING MAX((checkpoint->>'ts')::timestamptz) < NOW() - %s::float8 * INTERVAL '1 day'
        """
        # `= ANY(%s)` passes all ids as one array parameter. The statement does
        # not grow with the number of threads or hit the bind-parameter limit.
        # NOTE(review): other per-thread tables (e.g. checkpoint_blobs, if the
        # checkpointer uses one) are not cleaned here. Confirm whether they should be.
        query_delete_checkpoints = "DELETE FROM checkpoints WHERE thread_id = ANY(%s)"
        query_delete_writes = "DELETE FROM checkpoint_writes WHERE thread_id = ANY(%s)"

        try:
            # Borrow a connection from the pool for the whole cleanup.
            async with self._pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(query_find_inactive, (inactive_days,))
                    inactive_thread_ids = [row[0] for row in await cursor.fetchall()]

                    if not inactive_thread_ids:
                        logger.info("No inactive threads found")
                        return {"deleted": 0, "threads_deleted": 0, "inactive_days": inactive_days}

                    await cursor.execute(query_delete_checkpoints, (inactive_thread_ids,))
                    deleted_checkpoints = cursor.rowcount

                    # Also remove the pending writes of the same threads.
                    await cursor.execute(query_delete_writes, (inactive_thread_ids,))
                    deleted_writes = cursor.rowcount

        except Exception as e:
            logger.exception(f"Error during checkpoint cleanup: {e}")
            return {"deleted": 0, "error": str(e), "inactive_days": inactive_days}

        logger.info(
            f"Checkpoint cleanup completed: "
            f"deleted {deleted_checkpoints} checkpoints, {deleted_writes} writes "
            f"from {len(inactive_thread_ids)} inactive threads"
        )

        return {
            "deleted": deleted_checkpoints + deleted_writes,
            "threads_deleted": len(inactive_thread_ids),
            "inactive_days": inactive_days,
        }

    async def _cleanup_loop(self):
        """Background loop: run cleanup_old_checkpoints() once per interval until stopped."""
        interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600

        logger.info(
            f"Checkpoint cleanup scheduler started: "
            f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, "
            f"inactive_days={CHECKPOINT_CLEANUP_INACTIVE_DAYS}"
        )

        try:
            while not self._cleanup_stop_event.is_set():
                try:
                    # Sleep for one interval, waking early if a stop is requested.
                    try:
                        await asyncio.wait_for(
                            self._cleanup_stop_event.wait(),
                            timeout=interval_seconds,
                        )
                        # Stop signal received.
                        break
                    except asyncio.TimeoutError:
                        # Interval elapsed: time to clean up.
                        pass

                    # Guard against a stop requested right as the timeout fired.
                    if self._cleanup_stop_event.is_set():
                        break

                    await self.cleanup_old_checkpoints()
                except asyncio.CancelledError:
                    # Must not be swallowed by the generic handler below on
                    # Python versions where it subclasses Exception.
                    raise
                except Exception as e:
                    logger.error(f"Error in cleanup loop: {e}")
        except asyncio.CancelledError:
            logger.info("Cleanup task cancelled")
            # Propagate so the task is actually marked as cancelled.
            raise
        finally:
            logger.info("Checkpoint cleanup scheduler stopped")

    def start_cleanup_scheduler(self):
        """
        Start the background cleanup task.

        Does nothing if cleanup is disabled (CHECKPOINT_CLEANUP_ENABLED) or a
        cleanup task is already running.
        """
        if not CHECKPOINT_CLEANUP_ENABLED:
            logger.info("Checkpoint cleanup is disabled")
            return

        if self._cleanup_task is not None and not self._cleanup_task.done():
            logger.warning("Cleanup scheduler is already running")
            return

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # Called outside a running loop: create one and make it current.
            # The task is only scheduled on it, so it will not run until the
            # caller actually runs that loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            logger.warning(
                "No running event loop; cleanup task will only run once the new event loop is started"
            )

        self._cleanup_task = loop.create_task(self._cleanup_loop())
        logger.info("Cleanup scheduler task created")
|
||
|
||
|
||
# 全局单例
|
||
# Process-wide singleton: None until get_db_pool_manager() first creates it;
# close_global_db_pool() closes it but leaves the reference in place.
_global_manager: Optional[DBPoolManager] = None
|
||
|
||
|
||
def get_db_pool_manager() -> DBPoolManager:
    """
    Return the shared DBPoolManager singleton.

    The instance is created lazily on the first call. It is not initialized
    here; callers still need to await ``initialize()`` on it, as
    init_global_db_pool() does.
    """
    global _global_manager
    manager = _global_manager
    if manager is None:
        manager = _global_manager = DBPoolManager()
    return manager
|
||
|
||
|
||
async def init_global_db_pool() -> DBPoolManager:
    """Obtain the global DBPoolManager, open its connection pool, and return it."""
    pool_manager = get_db_pool_manager()
    await pool_manager.initialize()
    return pool_manager
|
||
|
||
|
||
async def close_global_db_pool() -> None:
    """Close the global DBPoolManager's pool; a no-op if it was never created."""
    manager = _global_manager
    if manager is None:
        return
    await manager.close()
|