"""
Global PostgreSQL connection pool manager.

Shared by checkpoint, chat_history, and mem0.
"""

import asyncio
import logging
from typing import Optional
from urllib.parse import urlsplit, urlunsplit

from psycopg_pool import AsyncConnectionPool
from psycopg2 import pool as psycopg2_pool

from utils.settings import (
    CHECKPOINT_DB_URL,
    CHECKPOINT_POOL_SIZE,
    MEM0_POOL_SIZE,
    CHECKPOINT_CLEANUP_ENABLED,
    CHECKPOINT_CLEANUP_INTERVAL_HOURS,
    CHECKPOINT_CLEANUP_INACTIVE_DAYS,
)

logger = logging.getLogger('app')

# How long close() waits for an in-flight cleanup run to finish before it
# cancels the scheduler task.
_CLEANUP_STOP_TIMEOUT_SECONDS = 10.0


def _redact_db_url(db_url: str) -> str:
    """Return a form of *db_url* that is safe to log (password masked).

    Args:
        db_url: Connection string, normally a ``postgresql://user:pw@host/db`` URI.

    Returns:
        The URI with the password replaced by ``***``. If the string is not a URI
        (or cannot be parsed), a placeholder is returned instead so that secrets
        are never echoed.
    """
    try:
        parts = urlsplit(db_url)
    except ValueError:
        return "<unparseable connection string>"
    if not parts.netloc:
        # Not a URI (e.g. a libpq "key=value" conninfo); don't risk logging secrets.
        return "<connection string redacted>"
    userinfo, sep, hostinfo = parts.netloc.rpartition("@")
    if not sep or ":" not in userinfo:
        return db_url  # no password present
    user = userinfo.partition(":")[0]
    return urlunsplit(parts._replace(netloc=f"{user}:***@{hostinfo}"))


class DBPoolManager:
    """
    Global PostgreSQL connection pool manager.

    Main responsibilities:
    1. Manage async connections with psycopg_pool.AsyncConnectionPool
    2. Manage sync connections with psycopg2.pool.SimpleConnectionPool for Mem0
    3. Share pools across CheckpointerManager, ChatHistoryManager, and Mem0Manager
    4. Automatically clean up old checkpoint data
    5. Provide graceful shutdown behavior
    """

    def __init__(self):
        self._pool: Optional[AsyncConnectionPool] = None
        # Synchronous connection pool (used by Mem0)
        self._sync_pool: Optional[psycopg2_pool.SimpleConnectionPool] = None
        self._initialized = False
        self._closed = False

        # Cleanup scheduler task and its stop signal. The event is recreated in
        # start_cleanup_scheduler() so each run starts un-set.
        self._cleanup_task: Optional[asyncio.Task] = None
        self._cleanup_stop_event = asyncio.Event()

    async def initialize(self) -> None:
        """Open the async and sync connection pools (no-op if already initialized).

        A manager that was previously closed can be initialized again.

        Raises:
            Exception: Whatever pool creation raised. Any pool opened before the
                failure is closed again first, so a retry starts from a clean state.
        """
        if self._initialized:
            return

        logger.info(
            f"Initializing PostgreSQL connection pool: "
            f"URL={_redact_db_url(CHECKPOINT_DB_URL)}, size={CHECKPOINT_POOL_SIZE}"
        )

        try:
            # 1. Create the async psycopg connection pool.
            self._pool = AsyncConnectionPool(
                CHECKPOINT_DB_URL,
                min_size=1,
                max_size=CHECKPOINT_POOL_SIZE,
                open=False,
            )
            await self._pool.open()

            # 2. Create the synchronous psycopg2 connection pool for Mem0.
            self._sync_pool = self._create_sync_pool(CHECKPOINT_DB_URL, MEM0_POOL_SIZE)
        except Exception as e:
            logger.error(f"Failed to initialize PostgreSQL connection pool: {e}")
            # Don't leak the async pool if the sync pool could not be created.
            await self._close_pools()
            raise

        self._initialized = True
        self._closed = False
        logger.info("PostgreSQL connection pool initialized successfully")

    def _create_sync_pool(self, db_url: str, pool_size: int) -> psycopg2_pool.SimpleConnectionPool:
        """Create the synchronous connection pool used by Mem0.

        The connection string is handed to libpq unchanged: psycopg2 forwards the
        pool's extra arguments to ``psycopg2.connect``. It is therefore interpreted
        exactly like the async pool's conninfo. Percent-encoded credentials,
        ``?sslmode=...`` options and ``postgres://`` URIs all work.

        Args:
            db_url: libpq connection URI (or ``key=value`` conninfo string).
            pool_size: Maximum number of connections; one is opened eagerly.

        Returns:
            A psycopg2 ``SimpleConnectionPool``.
        """
        # NOTE(review): psycopg2 documents SimpleConnectionPool as not shareable
        # across threads. If Mem0 uses this pool from several threads,
        # ThreadedConnectionPool is the safe choice.
        return psycopg2_pool.SimpleConnectionPool(1, pool_size, dsn=db_url)

    @property
    def pool(self) -> AsyncConnectionPool:
        """Get the async connection pool.

        Raises:
            RuntimeError: If the manager is closed or not yet initialized.
        """
        if self._closed:
            raise RuntimeError("DBPoolManager is closed")
        if not self._initialized:
            raise RuntimeError("DBPoolManager not initialized, call initialize() first")
        return self._pool

    @property
    def sync_pool(self) -> psycopg2_pool.SimpleConnectionPool:
        """Get the synchronous connection pool used by Mem0.

        Raises:
            RuntimeError: If the manager is closed, not initialized, or has no sync pool.
        """
        if self._closed:
            raise RuntimeError("DBPoolManager is closed")
        if not self._initialized:
            raise RuntimeError("DBPoolManager not initialized, call initialize() first")
        if self._sync_pool is None:
            raise RuntimeError("Sync pool not available")
        return self._sync_pool

    async def close(self) -> None:
        """Stop the cleanup scheduler, then close both connection pools.

        Calling this again after the manager is closed is a no-op.
        """
        if self._closed:
            return

        # Stop the cleanup task *before* closing the pools, so an in-flight
        # cleanup run never has its pool closed underneath it.
        await self._stop_cleanup_task()

        logger.info("Closing DBPoolManager...")
        await self._close_pools()

        self._closed = True
        self._initialized = False
        logger.info("DBPoolManager closed")

    async def _stop_cleanup_task(self) -> None:
        """Signal the cleanup loop to exit and wait for it; cancel it on timeout."""
        task, self._cleanup_task = self._cleanup_task, None
        if task is None:
            return

        self._cleanup_stop_event.set()
        if task.done():
            return

        done, _ = await asyncio.wait({task}, timeout=_CLEANUP_STOP_TIMEOUT_SECONDS)
        if done:
            return

        logger.warning("Cleanup task did not stop in time; cancelling it")
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def _close_pools(self) -> None:
        """Close whichever pools exist, logging close errors instead of raising them.

        Each pool is closed independently, so a failure on one does not leave the
        other open.
        """
        pool, self._pool = self._pool, None
        if pool is not None:
            try:
                await pool.close()
            except Exception as e:
                logger.error(f"Error closing async connection pool: {e}")

        sync_pool, self._sync_pool = self._sync_pool, None
        if sync_pool is not None:
            try:
                sync_pool.closeall()
            except Exception as e:
                logger.error(f"Error closing sync connection pool: {e}")

    # ============================================================
    # Checkpoint cleanup methods based on PostgreSQL SQL queries
    # ============================================================

    async def cleanup_old_checkpoints(self, inactive_days: Optional[int] = None) -> dict:
        """
        Clean up old checkpoint records.

        Delete all checkpoints (and checkpoint writes) for threads whose most
        recent checkpoint is older than N days.

        Args:
            inactive_days: Delete threads inactive for N days. Defaults to the configured value.

        Returns:
            Dict: Cleanup statistics. The keys are "deleted", "threads_deleted" and
            "inactive_days", plus "error" if the cleanup failed.
        """
        if inactive_days is None:
            inactive_days = CHECKPOINT_CLEANUP_INACTIVE_DAYS

        logger.info(f"Starting checkpoint cleanup: removing threads inactive for {inactive_days} days")

        if self._pool is None:
            logger.warning("Connection pool not initialized, skipping cleanup")
            return {"deleted": 0, "threads_deleted": 0, "inactive_days": inactive_days}

        try:
            # psycopg_pool's connection() commits on a clean exit and rolls back on
            # error, so the SELECT and both DELETEs below run as one transaction.
            async with self._pool.connection() as conn:
                async with conn.cursor() as cursor:
                    # A thread counts as inactive only if its NEWEST checkpoint is
                    # older than the cutoff. A plain WHERE filter would also match
                    # active threads that merely have some old checkpoints.
                    # The day count is a bound parameter (not inside a quoted
                    # literal), which works with server-side parameter binding.
                    query_find_inactive = """
                        SELECT thread_id
                        FROM checkpoints
                        GROUP BY thread_id
                        HAVING MAX((checkpoint->>'ts')::timestamptz)
                               < NOW() - make_interval(days => %s::int)
                    """
                    await cursor.execute(query_find_inactive, (inactive_days,))
                    inactive_thread_ids = [row[0] for row in await cursor.fetchall()]

                    if not inactive_thread_ids:
                        logger.info("No inactive threads found")
                        return {"deleted": 0, "threads_deleted": 0, "inactive_days": inactive_days}

                    # ANY(%s) binds the whole id list as a single array parameter
                    # instead of one placeholder per id, so the statement stays
                    # well under the protocol's bind-parameter limit.
                    await cursor.execute(
                        "DELETE FROM checkpoints WHERE thread_id = ANY(%s)",
                        (inactive_thread_ids,),
                    )
                    deleted_checkpoints = cursor.rowcount

                    # Also clean up checkpoint_writes.
                    await cursor.execute(
                        "DELETE FROM checkpoint_writes WHERE thread_id = ANY(%s)",
                        (inactive_thread_ids,),
                    )
                    deleted_writes = cursor.rowcount
                    # NOTE(review): if the checkpointer also stores a checkpoint_blobs
                    # table keyed by thread_id, its rows for these threads are not
                    # removed here. Confirm against the actual schema.

            logger.info(
                f"Checkpoint cleanup completed: "
                f"deleted {deleted_checkpoints} checkpoints, {deleted_writes} writes "
                f"from {len(inactive_thread_ids)} inactive threads"
            )

            return {
                "deleted": deleted_checkpoints + deleted_writes,
                "threads_deleted": len(inactive_thread_ids),
                "inactive_days": inactive_days,
            }

        except Exception as e:
            # logger.exception records the message together with the traceback.
            logger.exception(f"Error during checkpoint cleanup: {e}")
            return {"deleted": 0, "error": str(e), "inactive_days": inactive_days}

    async def _cleanup_loop(self) -> None:
        """Run cleanup_old_checkpoints() on a fixed interval until stopped.

        The loop sleeps CHECKPOINT_CLEANUP_INTERVAL_HOURS between runs. It exits as
        soon as the stop event is set, and re-raises cancellation so the task is
        correctly reported as cancelled.
        """
        stop_event = self._cleanup_stop_event
        interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600

        logger.info(
            f"Checkpoint cleanup scheduler started: "
            f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, "
            f"inactive_days={CHECKPOINT_CLEANUP_INACTIVE_DAYS}"
        )

        try:
            while not stop_event.is_set():
                # Wait for the interval or a stop signal.
                try:
                    await asyncio.wait_for(stop_event.wait(), timeout=interval_seconds)
                    # Stop signal received, exit the loop.
                    break
                except asyncio.TimeoutError:
                    # Timeout reached, run the cleanup task.
                    pass

                try:
                    await self.cleanup_old_checkpoints()
                except Exception as e:
                    logger.error(f"Error in cleanup loop: {e}")
        except asyncio.CancelledError:
            logger.info("Cleanup task cancelled")
            raise
        finally:
            logger.info("Checkpoint cleanup scheduler stopped")

    def start_cleanup_scheduler(self):
        """Start the background scheduled cleanup task.

        This must be called while an event loop is running (i.e. from async code).
        Otherwise a warning is logged and nothing is scheduled.
        """
        if not CHECKPOINT_CLEANUP_ENABLED:
            logger.info("Checkpoint cleanup is disabled")
            return

        if self._cleanup_task is not None and not self._cleanup_task.done():
            logger.warning("Cleanup scheduler is already running")
            return

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # A task created on an event loop that nobody runs never executes, so
            # report it instead of silently scheduling nothing.
            logger.warning(
                "No running event loop; cleanup scheduler not started "
                "(call start_cleanup_scheduler() from async code)"
            )
            return

        # Fresh, un-set stop event for this run. It is created inside the running
        # loop, which matters on Python < 3.10 where an Event binds to a loop when
        # it is constructed.
        self._cleanup_stop_event = asyncio.Event()
        self._cleanup_task = loop.create_task(self._cleanup_loop())
        logger.info("Cleanup scheduler task created")


# Global singleton
_global_manager: Optional[DBPoolManager] = None


def get_db_pool_manager() -> DBPoolManager:
    """Get the global DBPoolManager singleton, creating it on first use."""
    global _global_manager
    if _global_manager is None:
        _global_manager = DBPoolManager()
    return _global_manager


async def init_global_db_pool() -> DBPoolManager:
    """Initialize the global database connection pool and return its manager."""
    manager = get_db_pool_manager()
    await manager.initialize()
    return manager


async def close_global_db_pool() -> None:
    """Close the global database connection pool and drop the singleton.

    Resetting the singleton means a later init_global_db_pool() call builds a
    fresh manager instead of reusing a closed one.
    """
    global _global_manager
    if _global_manager is not None:
        await _global_manager.close()
        _global_manager = None