qwen_agent/agent/db_pool_manager.py
2026-01-19 23:39:04 +08:00

266 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
全局 PostgreSQL 连接池管理器
被 checkpoint、chat_history 共享使用
"""
import asyncio
import logging
from typing import Optional
from psycopg_pool import AsyncConnectionPool
from utils.settings import (
CHECKPOINT_DB_URL,
CHECKPOINT_POOL_SIZE,
CHECKPOINT_CLEANUP_ENABLED,
CHECKPOINT_CLEANUP_INTERVAL_HOURS,
CHECKPOINT_CLEANUP_INACTIVE_DAYS,
)
# Module-level logger; it uses the fixed logger name 'app' rather than __name__.
logger = logging.getLogger('app')
class DBPoolManager:
    """Manager for a single PostgreSQL connection pool.

    Responsibilities:
        1. Own a ``psycopg_pool.AsyncConnectionPool``.
        2. Expose that pool to other components (the original notes name
           CheckpointerManager and ChatHistoryManager as the consumers).
        3. Periodically delete checkpoint data of inactive threads.
        4. Shut down gracefully (stop the cleanup task, then close the pool).
    """

    # Seconds close() waits for the background cleanup task to finish before
    # the task gets cancelled.
    _CLEANUP_STOP_TIMEOUT = 5.0

    # A thread is inactive only when its MOST RECENT checkpoint is older than
    # the cutoff, hence GROUP BY / HAVING MAX(...). The day count is bound as
    # a real parameter: a placeholder inside a quoted literal such as
    # INTERVAL '%s days' is not substituted by server-side binding.
    # NOTE(review): assumes checkpoint->>'ts' holds an ISO-8601 timestamp;
    # timestamptz honours a UTC offset if the string carries one.
    _FIND_INACTIVE_THREADS_SQL = """
        SELECT thread_id
        FROM checkpoints
        GROUP BY thread_id
        HAVING MAX((checkpoint->>'ts')::timestamptz)
               < NOW() - make_interval(days => %s::int)
    """

    # The id list is passed as ONE array parameter, so the statement does not
    # grow with the number of threads.
    _DELETE_CHECKPOINTS_SQL = "DELETE FROM checkpoints WHERE thread_id = ANY(%s)"
    _DELETE_WRITES_SQL = "DELETE FROM checkpoint_writes WHERE thread_id = ANY(%s)"

    def __init__(self):
        self._pool: Optional[AsyncConnectionPool] = None
        self._initialized = False
        self._closed = False
        # Background cleanup scheduler state.
        self._cleanup_task: Optional[asyncio.Task] = None
        self._cleanup_stop_event = asyncio.Event()

    @staticmethod
    def _redact_conninfo(conninfo) -> str:
        """Return *conninfo* as text with any password replaced by ``***``.

        Handles both URL style (``scheme://user:pw@host/db``) and
        keyword style (``password=pw`` / ``?password=pw``) connection strings.
        """
        import re

        text = str(conninfo)
        # URL userinfo: keep the user name, hide what follows the colon.
        text = re.sub(r"(?<=://)([^:@/\s]+):[^@/\s]*@", r"\1:***@", text)
        # key=value form, value optionally single-quoted.
        text = re.sub(
            r"(password\s*=\s*)('(?:[^'\\]|\\.)*'|[^\s&]+)",
            r"\1***",
            text,
            flags=re.IGNORECASE,
        )
        return text

    async def initialize(self) -> None:
        """Create and open the connection pool.

        Does nothing when the manager is already initialized. A manager that
        was closed earlier may be initialized again.

        Raises:
            Exception: whatever pool construction or ``open()`` raises; the
                error is logged and re-raised, and no pool is retained.
        """
        if self._initialized:
            return
        logger.info(
            f"Initializing PostgreSQL connection pool: "
            f"URL={self._redact_conninfo(CHECKPOINT_DB_URL)}, size={CHECKPOINT_POOL_SIZE}"
        )
        try:
            # Build the pool closed, then open it explicitly.
            pool = AsyncConnectionPool(
                CHECKPOINT_DB_URL,
                min_size=1,
                max_size=CHECKPOINT_POOL_SIZE,
                open=False,
            )
            await pool.open()
        except Exception as e:
            logger.error(f"Failed to initialize PostgreSQL connection pool: {e}")
            raise
        # Publish the pool only after it opened successfully.
        self._pool = pool
        self._initialized = True
        self._closed = False
        logger.info("PostgreSQL connection pool initialized successfully")

    @property
    def pool(self) -> "AsyncConnectionPool":
        """The open connection pool.

        Raises:
            RuntimeError: if the manager has been closed or was never
                initialized.
        """
        if self._closed:
            raise RuntimeError("DBPoolManager is closed")
        if not self._initialized:
            raise RuntimeError("DBPoolManager not initialized, call initialize() first")
        return self._pool

    async def close(self) -> None:
        """Stop the cleanup task (if any) and close the connection pool.

        Calling it again after a successful close is a no-op.
        """
        if self._closed:
            return
        # Stop the background cleanup before the pool goes away, so a running
        # cleanup is not left holding a connection from a closed pool.
        task, self._cleanup_task = self._cleanup_task, None
        if task is not None:
            self._cleanup_stop_event.set()
            if not task.done():
                try:
                    await asyncio.wait_for(task, timeout=self._CLEANUP_STOP_TIMEOUT)
                except asyncio.TimeoutError:
                    logger.warning("Cleanup task did not stop in time and was cancelled")
                except asyncio.CancelledError:
                    # Shutdown is best-effort: still go on to close the pool.
                    pass
                except Exception as e:
                    logger.warning(f"Error while stopping cleanup task: {e}")
        logger.info("Closing DBPoolManager...")
        if self._pool is not None:
            await self._pool.close()
            self._pool = None
        self._closed = True
        self._initialized = False
        logger.info("DBPoolManager closed")

    # ============================================================
    # Checkpoint cleanup (plain SQL against PostgreSQL)
    # ============================================================
    async def cleanup_old_checkpoints(self, inactive_days: Optional[int] = None) -> dict:
        """Delete all checkpoint rows of threads inactive for N days.

        A thread counts as inactive when its newest checkpoint is older than
        ``inactive_days`` days. Rows are removed from ``checkpoints`` and
        ``checkpoint_writes``.
        NOTE(review): any other checkpoint-related tables the schema may have
        are not touched here — confirm whether they need cleaning too.

        Args:
            inactive_days: inactivity threshold in days; defaults to
                ``CHECKPOINT_CLEANUP_INACTIVE_DAYS``.

        Returns:
            dict with ``deleted`` (total rows removed), ``threads_deleted``
            and ``inactive_days``; plus ``error`` when the cleanup failed.
            Errors are logged and reported in the dict, not raised.
        """
        if inactive_days is None:
            inactive_days = CHECKPOINT_CLEANUP_INACTIVE_DAYS
        logger.info(f"Starting checkpoint cleanup: removing threads inactive for {inactive_days} days")
        if self._pool is None:
            logger.warning("Connection pool not initialized, skipping cleanup")
            return {"deleted": 0, "threads_deleted": 0, "inactive_days": inactive_days}
        try:
            # Borrow one connection; both deletes share its transaction.
            async with self._pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(self._FIND_INACTIVE_THREADS_SQL, (inactive_days,))
                    inactive_thread_ids = [row[0] for row in await cursor.fetchall()]
                    if not inactive_thread_ids:
                        logger.info("No inactive threads found")
                        return {"deleted": 0, "threads_deleted": 0, "inactive_days": inactive_days}

                    await cursor.execute(self._DELETE_CHECKPOINTS_SQL, (inactive_thread_ids,))
                    deleted_checkpoints = cursor.rowcount

                    await cursor.execute(self._DELETE_WRITES_SQL, (inactive_thread_ids,))
                    deleted_writes = cursor.rowcount

                    logger.info(
                        f"Checkpoint cleanup completed: "
                        f"deleted {deleted_checkpoints} checkpoints, {deleted_writes} writes "
                        f"from {len(inactive_thread_ids)} inactive threads"
                    )
                    return {
                        "deleted": deleted_checkpoints + deleted_writes,
                        "threads_deleted": len(inactive_thread_ids),
                        "inactive_days": inactive_days,
                    }
        except Exception as e:
            # logger.exception records the traceback along with the message.
            logger.exception(f"Error during checkpoint cleanup: {e}")
            return {
                "deleted": 0,
                "threads_deleted": 0,
                "error": str(e),
                "inactive_days": inactive_days,
            }

    async def _cleanup_loop(self):
        """Background loop: run one cleanup per interval until asked to stop."""
        interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600
        logger.info(
            f"Checkpoint cleanup scheduler started: "
            f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, "
            f"inactive_days={CHECKPOINT_CLEANUP_INACTIVE_DAYS}"
        )
        while not self._cleanup_stop_event.is_set():
            try:
                # Sleep for one interval, waking early if a stop is requested.
                try:
                    await asyncio.wait_for(
                        self._cleanup_stop_event.wait(),
                        timeout=interval_seconds,
                    )
                    # The event fired: stop requested.
                    break
                except asyncio.TimeoutError:
                    # Interval elapsed without a stop request: time to clean.
                    pass
                if self._cleanup_stop_event.is_set():
                    break
                await self.cleanup_old_checkpoints()
            except asyncio.CancelledError:
                logger.info("Cleanup task cancelled")
                break
            except Exception as e:
                # Keep the scheduler alive across unexpected errors.
                logger.error(f"Error in cleanup loop: {e}")
        logger.info("Checkpoint cleanup scheduler stopped")

    def start_cleanup_scheduler(self):
        """Start the background cleanup task unless disabled or already running."""
        if not CHECKPOINT_CLEANUP_ENABLED:
            logger.info("Checkpoint cleanup is disabled")
            return
        if self._cleanup_task is not None and not self._cleanup_task.done():
            logger.warning("Cleanup scheduler is already running")
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running: create one and install it as current. The
            # task only makes progress once that loop is actually run.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        # A stop request left over from an earlier close() would make the new
        # loop exit immediately.
        self._cleanup_stop_event.clear()
        self._cleanup_task = loop.create_task(self._cleanup_loop())
        logger.info("Cleanup scheduler task created")
# Global singleton instance; created on first use by get_db_pool_manager().
_global_manager: Optional[DBPoolManager] = None
def get_db_pool_manager() -> DBPoolManager:
    """Return the global DBPoolManager, creating it on the first call."""
    global _global_manager
    manager = _global_manager
    if manager is None:
        # First access: build the singleton and remember it.
        manager = _global_manager = DBPoolManager()
    return manager
async def init_global_db_pool() -> DBPoolManager:
    """Initialize the global connection pool and return its manager."""
    db_manager = get_db_pool_manager()
    # initialize() returns immediately if the manager is already initialized.
    await db_manager.initialize()
    return db_manager
async def close_global_db_pool() -> None:
    """Close the global connection pool and drop the singleton.

    Does nothing when no global manager exists. After a successful close the
    global reference is reset, so the next get_db_pool_manager() call builds
    a fresh manager instead of handing back the closed one.
    """
    global _global_manager
    manager = _global_manager
    if manager is None:
        return
    await manager.close()
    # Reset only after close() succeeded, so a failed close can be retried.
    _global_manager = None