""" 全局 PostgreSQL 连接池管理器 被 checkpoint、chat_history、mem0 共享使用 """ import asyncio import logging from typing import Optional from psycopg_pool import AsyncConnectionPool from psycopg2 import pool as psycopg2_pool from utils.settings import ( CHECKPOINT_DB_URL, CHECKPOINT_POOL_SIZE, MEM0_POOL_SIZE, CHECKPOINT_CLEANUP_ENABLED, CHECKPOINT_CLEANUP_INTERVAL_HOURS, CHECKPOINT_CLEANUP_INACTIVE_DAYS, ) logger = logging.getLogger('app') class DBPoolManager: """ 全局 PostgreSQL 连接池管理器 主要功能: 1. 使用 psycopg_pool.AsyncConnectionPool 管理异步连接 2. 使用 psycopg2.pool.SimpleConnectionPool 管理同步连接(供 Mem0 使用) 3. 被 CheckpointerManager、ChatHistoryManager、Mem0Manager 共享 4. 自动清理旧的 checkpoint 数据 5. 优雅关闭机制 """ def __init__(self): self._pool: Optional[AsyncConnectionPool] = None self._sync_pool: Optional[psycopg2_pool.SimpleConnectionPool] = None # 同步连接池 self._initialized = False self._closed = False # 清理调度任务 self._cleanup_task: Optional[asyncio.Task] = None self._cleanup_stop_event = asyncio.Event() async def initialize(self) -> None: """初始化连接池""" if self._initialized: return logger.info( f"Initializing PostgreSQL connection pool: " f"URL={CHECKPOINT_DB_URL}, size={CHECKPOINT_POOL_SIZE}" ) try: # 1. 创建异步 psycopg 连接池 self._pool = AsyncConnectionPool( CHECKPOINT_DB_URL, min_size=1, max_size=CHECKPOINT_POOL_SIZE, open=False, ) await self._pool.open() # 2. 创建同步 psycopg2 连接池(供 Mem0 使用) self._sync_pool = self._create_sync_pool(CHECKPOINT_DB_URL, MEM0_POOL_SIZE) self._initialized = True logger.info("PostgreSQL connection pool initialized successfully") except Exception as e: logger.error(f"Failed to initialize PostgreSQL connection pool: {e}") raise def _create_sync_pool(self, db_url: str, pool_size: int) -> psycopg2_pool.SimpleConnectionPool: """创建同步连接池(供 Mem0 使用)""" # 解析连接 URL: postgresql://user:password@host:port/database url_parts = db_url.replace("postgresql://", "").split("/") conn_part = url_parts[0] if len(url_parts) > 1 else "" dbname = url_parts[1] if len(url_parts) > 1 else "postgres" if "@" in conn_part: auth_part, host_part = conn_part.split("@") user, password = auth_part.split(":") if ":" in auth_part else (auth_part, "") else: user = "" password = "" host_part = conn_part if ":" in host_part: host, port = host_part.split(":") port = int(port) else: host = host_part port = 5432 return psycopg2_pool.SimpleConnectionPool( 1, pool_size, user=user, password=password, host=host, port=port, database=dbname ) @property def pool(self) -> AsyncConnectionPool: """获取异步连接池""" if self._closed: raise RuntimeError("DBPoolManager is closed") if not self._initialized: raise RuntimeError("DBPoolManager not initialized, call initialize() first") return self._pool @property def sync_pool(self) -> psycopg2_pool.SimpleConnectionPool: """获取同步连接池(供 Mem0 使用)""" if self._closed: raise RuntimeError("DBPoolManager is closed") if not self._initialized: raise RuntimeError("DBPoolManager not initialized, call initialize() first") if self._sync_pool is None: raise RuntimeError("Sync pool not available") return self._sync_pool async def close(self) -> None: """关闭连接池""" if self._closed: return # 停止清理任务 if self._cleanup_task is not None: self._cleanup_stop_event.set() try: await asyncio.sleep(0.1) except asyncio.CancelledError: pass self._cleanup_task = None logger.info("Closing DBPoolManager...") # 关闭异步连接池 if self._pool is not None: await self._pool.close() self._pool = None # 关闭同步连接池 if self._sync_pool is not None: self._sync_pool.closeall() self._sync_pool = None self._closed = True self._initialized = False logger.info("DBPoolManager closed") # 
    # ============================================================
    # Checkpoint cleanup (plain PostgreSQL queries)
    # ============================================================

    async def cleanup_old_checkpoints(self, inactive_days: Optional[int] = None) -> dict:
        """
        Clean up old checkpoint records.

        Deletes all checkpoints belonging to threads that have been inactive
        for N days.

        Args:
            inactive_days: Remove threads inactive for this many days;
                defaults to the configured value.

        Returns:
            dict: Cleanup statistics.
        """
        if inactive_days is None:
            inactive_days = CHECKPOINT_CLEANUP_INACTIVE_DAYS

        logger.info(f"Starting checkpoint cleanup: removing threads inactive for {inactive_days} days")

        if self._pool is None:
            logger.warning("Connection pool not initialized, skipping cleanup")
            return {"deleted": 0, "inactive_days": inactive_days}

        try:
            # Borrow a connection from the pool for the cleanup
            async with self._pool.connection() as conn:
                async with conn.cursor() as cursor:
                    # Find threads whose most recent checkpoint is older than the cutoff
                    query_find_inactive = """
                        SELECT thread_id
                        FROM checkpoints
                        GROUP BY thread_id
                        HAVING MAX((checkpoint->>'ts')::timestamp) < NOW() - (%s * INTERVAL '1 day')
                    """
                    await cursor.execute(query_find_inactive, (inactive_days,))
                    inactive_threads = await cursor.fetchall()
                    inactive_thread_ids = [row[0] for row in inactive_threads]

                    if not inactive_thread_ids:
                        logger.info("No inactive threads found")
                        return {"deleted": 0, "threads_deleted": 0, "inactive_days": inactive_days}

                    # Delete the checkpoints of inactive threads
                    placeholders = ','.join(['%s'] * len(inactive_thread_ids))
                    query_delete_checkpoints = f"""
                        DELETE FROM checkpoints
                        WHERE thread_id IN ({placeholders})
                    """
                    await cursor.execute(query_delete_checkpoints, inactive_thread_ids)
                    deleted_checkpoints = cursor.rowcount

                    # Also clean up checkpoint_writes
                    query_delete_writes = f"""
                        DELETE FROM checkpoint_writes
                        WHERE thread_id IN ({placeholders})
                    """
                    await cursor.execute(query_delete_writes, inactive_thread_ids)
                    deleted_writes = cursor.rowcount

                    logger.info(
                        f"Checkpoint cleanup completed: "
                        f"deleted {deleted_checkpoints} checkpoints, {deleted_writes} writes "
                        f"from {len(inactive_thread_ids)} inactive threads"
                    )

                    return {
                        "deleted": deleted_checkpoints + deleted_writes,
                        "threads_deleted": len(inactive_thread_ids),
                        "inactive_days": inactive_days
                    }
        except Exception as e:
            logger.exception(f"Error during checkpoint cleanup: {e}")
            return {"deleted": 0, "error": str(e), "inactive_days": inactive_days}

    async def _cleanup_loop(self):
        """Background cleanup loop."""
        interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600
        logger.info(
            f"Checkpoint cleanup scheduler started: "
            f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, "
            f"inactive_days={CHECKPOINT_CLEANUP_INACTIVE_DAYS}"
        )

        while not self._cleanup_stop_event.is_set():
            try:
                # Wait for the interval to elapse or the stop signal
                try:
                    await asyncio.wait_for(
                        self._cleanup_stop_event.wait(),
                        timeout=interval_seconds
                    )
                    # Stop signal received, exit the loop
                    break
                except asyncio.TimeoutError:
                    # Interval elapsed, run the cleanup
                    pass

                if self._cleanup_stop_event.is_set():
                    break

                # Run the cleanup
                await self.cleanup_old_checkpoints()
            except asyncio.CancelledError:
                logger.info("Cleanup task cancelled")
                break
            except Exception as e:
                logger.error(f"Error in cleanup loop: {e}")

        logger.info("Checkpoint cleanup scheduler stopped")

    def start_cleanup_scheduler(self):
        """Start the periodic background cleanup task."""
        if not CHECKPOINT_CLEANUP_ENABLED:
            logger.info("Checkpoint cleanup is disabled")
            return

        if self._cleanup_task is not None and not self._cleanup_task.done():
            logger.warning("Cleanup scheduler is already running")
            return

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # A task scheduled on a freshly created loop would never run,
            # so this must be called from inside a running event loop.
            logger.error("start_cleanup_scheduler() must be called from a running event loop")
            return

        self._cleanup_task = loop.create_task(self._cleanup_loop())
        logger.info("Cleanup scheduler task created")


# Global singleton
_global_manager: Optional[DBPoolManager] = None
def get_db_pool_manager() -> DBPoolManager:
    """Return the global DBPoolManager singleton."""
    global _global_manager
    if _global_manager is None:
        _global_manager = DBPoolManager()
    return _global_manager


async def init_global_db_pool() -> DBPoolManager:
    """Initialize the global database connection pool."""
    manager = get_db_pool_manager()
    await manager.initialize()
    return manager


async def close_global_db_pool() -> None:
    """Close the global database connection pool."""
    global _global_manager
    if _global_manager is not None:
        await _global_manager.close()
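

# ------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's API): shows the
# intended lifecycle — initialize the global pool at startup, start the
# cleanup scheduler, borrow a connection from each pool, then shut down.
# The coroutine name `_example_lifecycle` and the `SELECT 1` statements are
# hypothetical placeholders for real caller code.
if __name__ == "__main__":

    async def _example_lifecycle() -> None:
        # Startup: open both pools and begin periodic checkpoint cleanup
        manager = await init_global_db_pool()
        manager.start_cleanup_scheduler()

        # Async path (checkpoint / chat_history): psycopg3 connection from the pool
        async with manager.pool.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("SELECT 1")
                await cursor.fetchone()

        # Sync path (Mem0): psycopg2 connection; always return it to the pool
        sync_conn = manager.sync_pool.getconn()
        try:
            with sync_conn.cursor() as cursor:
                cursor.execute("SELECT 1")
                cursor.fetchone()
        finally:
            manager.sync_pool.putconn(sync_conn)

        # Shutdown: stops the cleanup task and closes both pools
        await close_global_db_pool()

    asyncio.run(_example_lifecycle())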