qwen_agent/agent/db_pool_manager.py
朱潮 3dc119bca8 refactor(mem0): optimize connection pool and async memory handling
- Fix mem0 connection pool exhausted error with proper pooling
- Convert memory operations to async tasks
- Optimize docker-compose configuration
- Add skill upload functionality
- Reduce cache size for better performance
- Update dependencies

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 19:39:12 +08:00

"""
全局 PostgreSQL 连接池管理器
被 checkpoint、chat_history、mem0 共享使用
"""
import asyncio
import logging
from typing import Optional
from psycopg_pool import AsyncConnectionPool
from psycopg2 import pool as psycopg2_pool
from utils.settings import (
CHECKPOINT_DB_URL,
CHECKPOINT_POOL_SIZE,
MEM0_POOL_SIZE,
CHECKPOINT_CLEANUP_ENABLED,
CHECKPOINT_CLEANUP_INTERVAL_HOURS,
CHECKPOINT_CLEANUP_INACTIVE_DAYS,
)
logger = logging.getLogger('app')
class DBPoolManager:
"""
全局 PostgreSQL 连接池管理器
主要功能:
1. 使用 psycopg_pool.AsyncConnectionPool 管理异步连接
2. 使用 psycopg2.pool.SimpleConnectionPool 管理同步连接(供 Mem0 使用)
3. 被 CheckpointerManager、ChatHistoryManager、Mem0Manager 共享
4. 自动清理旧的 checkpoint 数据
5. 优雅关闭机制
"""
def __init__(self):
self._pool: Optional[AsyncConnectionPool] = None
self._sync_pool: Optional[psycopg2_pool.SimpleConnectionPool] = None # 同步连接池
self._initialized = False
self._closed = False
# 清理调度任务
self._cleanup_task: Optional[asyncio.Task] = None
self._cleanup_stop_event = asyncio.Event()
    async def initialize(self) -> None:
        """Initialize the connection pools."""
        if self._initialized:
            return
        logger.info(
            f"Initializing PostgreSQL connection pool: "
            f"URL={CHECKPOINT_DB_URL}, size={CHECKPOINT_POOL_SIZE}"
        )
        try:
            # 1. Create the async psycopg connection pool
            self._pool = AsyncConnectionPool(
                CHECKPOINT_DB_URL,
                min_size=1,
                max_size=CHECKPOINT_POOL_SIZE,
                open=False,
            )
            await self._pool.open()
            # 2. Create the sync psycopg2 connection pool (used by Mem0)
            self._sync_pool = self._create_sync_pool(CHECKPOINT_DB_URL, MEM0_POOL_SIZE)
            self._initialized = True
            logger.info("PostgreSQL connection pool initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize PostgreSQL connection pool: {e}")
            raise
    def _create_sync_pool(self, db_url: str, pool_size: int) -> psycopg2_pool.SimpleConnectionPool:
        """Create the sync connection pool (used by Mem0)."""
        # Parse the connection URL: postgresql://user:password@host:port/database
        url_parts = db_url.replace("postgresql://", "").split("/")
        conn_part = url_parts[0]
        dbname = url_parts[1] if len(url_parts) > 1 else "postgres"
        if "@" in conn_part:
            auth_part, host_part = conn_part.rsplit("@", 1)
            # split(":", 1) keeps passwords that themselves contain a colon intact
            user, password = auth_part.split(":", 1) if ":" in auth_part else (auth_part, "")
        else:
            user = ""
            password = ""
            host_part = conn_part
        if ":" in host_part:
            host, port = host_part.split(":")
            port = int(port)
        else:
            host = host_part
            port = 5432
        return psycopg2_pool.SimpleConnectionPool(
            1, pool_size,
            user=user,
            password=password,
            host=host,
            port=port,
            database=dbname
        )
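
    # Examples of the URL shapes _create_sync_pool understands (values are
    # placeholders for illustration only):
    #   postgresql://app:secret@db:5432/agent -> user="app", password="secret",
    #                                            host="db", port=5432, database="agent"
    #   postgresql://db/agent                 -> user="", password="", host="db",
    #                                            port=5432, database="agent"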
    @property
    def pool(self) -> AsyncConnectionPool:
        """Return the async connection pool."""
        if self._closed:
            raise RuntimeError("DBPoolManager is closed")
        if not self._initialized:
            raise RuntimeError("DBPoolManager not initialized, call initialize() first")
        return self._pool

    @property
    def sync_pool(self) -> psycopg2_pool.SimpleConnectionPool:
        """Return the sync connection pool (used by Mem0)."""
        if self._closed:
            raise RuntimeError("DBPoolManager is closed")
        if not self._initialized:
            raise RuntimeError("DBPoolManager not initialized, call initialize() first")
        if self._sync_pool is None:
            raise RuntimeError("Sync pool not available")
        return self._sync_pool
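
    # Illustrative usage of the two pools (a sketch, not called anywhere in this
    # module; `manager` stands for an initialized DBPoolManager instance):
    #
    #   # Async consumers (checkpoint / chat history) borrow via a context manager:
    #   async with manager.pool.connection() as conn:
    #       async with conn.cursor() as cur:
    #           await cur.execute("SELECT 1")
    #
    #   # Sync consumers (Mem0) must explicitly return the connection:
    #   conn = manager.sync_pool.getconn()
    #   try:
    #       with conn.cursor() as cur:
    #           cur.execute("SELECT 1")
    #   finally:
    #       manager.sync_pool.putconn(conn)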
    async def close(self) -> None:
        """Close the connection pools."""
        if self._closed:
            return
        # Stop the background cleanup task and wait for it to finish
        if self._cleanup_task is not None:
            self._cleanup_stop_event.set()
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None
        logger.info("Closing DBPoolManager...")
        # Close the async connection pool
        if self._pool is not None:
            await self._pool.close()
            self._pool = None
        # Close the sync connection pool
        if self._sync_pool is not None:
            self._sync_pool.closeall()
            self._sync_pool = None
        self._closed = True
        self._initialized = False
        logger.info("DBPoolManager closed")
    # ============================================================
    # Checkpoint cleanup (plain PostgreSQL queries)
    # ============================================================
    async def cleanup_old_checkpoints(self, inactive_days: Optional[int] = None) -> dict:
        """
        Delete old checkpoint records.

        Removes all checkpoints belonging to threads that have been inactive for N days.

        Args:
            inactive_days: Delete threads inactive for N days; defaults to the configured value.

        Returns:
            dict: Cleanup statistics.
        """
        if inactive_days is None:
            inactive_days = CHECKPOINT_CLEANUP_INACTIVE_DAYS
        logger.info(f"Starting checkpoint cleanup: removing threads inactive for {inactive_days} days")
        if self._pool is None:
            logger.warning("Connection pool not initialized, skipping cleanup")
            return {"deleted": 0, "inactive_days": inactive_days}
        try:
            # Borrow a connection from the pool to run the cleanup
            async with self._pool.connection() as conn:
                async with conn.cursor() as cursor:
                    # Find inactive threads. A %s placeholder inside a quoted
                    # INTERVAL literal is not bound as a parameter, so multiply
                    # a one-day interval by the parameter instead.
                    query_find_inactive = """
                        SELECT DISTINCT thread_id
                        FROM checkpoints
                        WHERE (checkpoint->>'ts')::timestamp < NOW() - (%s * INTERVAL '1 day')
                    """
                    await cursor.execute(query_find_inactive, (inactive_days,))
                    inactive_threads = await cursor.fetchall()
                    inactive_thread_ids = [row[0] for row in inactive_threads]
                    if not inactive_thread_ids:
                        logger.info("No inactive threads found")
                        return {"deleted": 0, "threads_deleted": 0, "inactive_days": inactive_days}
                    # Delete the checkpoints of the inactive threads
                    placeholders = ','.join(['%s'] * len(inactive_thread_ids))
                    query_delete_checkpoints = f"""
                        DELETE FROM checkpoints
                        WHERE thread_id IN ({placeholders})
                    """
                    await cursor.execute(query_delete_checkpoints, inactive_thread_ids)
                    deleted_checkpoints = cursor.rowcount
                    # Also clean up checkpoint_writes
                    query_delete_writes = f"""
                        DELETE FROM checkpoint_writes
                        WHERE thread_id IN ({placeholders})
                    """
                    await cursor.execute(query_delete_writes, inactive_thread_ids)
                    deleted_writes = cursor.rowcount
                    logger.info(
                        f"Checkpoint cleanup completed: "
                        f"deleted {deleted_checkpoints} checkpoints, {deleted_writes} writes "
                        f"from {len(inactive_thread_ids)} inactive threads"
                    )
                    return {
                        "deleted": deleted_checkpoints + deleted_writes,
                        "threads_deleted": len(inactive_thread_ids),
                        "inactive_days": inactive_days
                    }
        except Exception as e:
            logger.error(f"Error during checkpoint cleanup: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return {"deleted": 0, "error": str(e), "inactive_days": inactive_days}
    async def _cleanup_loop(self):
        """Background cleanup loop."""
        interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600
        logger.info(
            f"Checkpoint cleanup scheduler started: "
            f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, "
            f"inactive_days={CHECKPOINT_CLEANUP_INACTIVE_DAYS}"
        )
        while not self._cleanup_stop_event.is_set():
            try:
                # Wait for the interval to elapse or for the stop signal
                try:
                    await asyncio.wait_for(
                        self._cleanup_stop_event.wait(),
                        timeout=interval_seconds
                    )
                    # Stop signal received, exit the loop
                    break
                except asyncio.TimeoutError:
                    # Interval elapsed, fall through to run the cleanup
                    pass
                if self._cleanup_stop_event.is_set():
                    break
                # Run the cleanup
                await self.cleanup_old_checkpoints()
            except asyncio.CancelledError:
                logger.info("Cleanup task cancelled")
                break
            except Exception as e:
                logger.error(f"Error in cleanup loop: {e}")
        logger.info("Checkpoint cleanup scheduler stopped")
    def start_cleanup_scheduler(self):
        """Start the background periodic cleanup task."""
        if not CHECKPOINT_CLEANUP_ENABLED:
            logger.info("Checkpoint cleanup is disabled")
            return
        if self._cleanup_task is not None and not self._cleanup_task.done():
            logger.warning("Cleanup scheduler is already running")
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: create and register one. The caller is then
            # responsible for running this loop, otherwise the task never executes.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        self._cleanup_task = loop.create_task(self._cleanup_loop())
        logger.info("Cleanup scheduler task created")
# Global singleton
_global_manager: Optional[DBPoolManager] = None


def get_db_pool_manager() -> DBPoolManager:
    """Return the global DBPoolManager singleton."""
    global _global_manager
    if _global_manager is None:
        _global_manager = DBPoolManager()
    return _global_manager


async def init_global_db_pool() -> DBPoolManager:
    """Initialize the global database connection pool."""
    manager = get_db_pool_manager()
    await manager.initialize()
    return manager


async def close_global_db_pool() -> None:
    """Close the global database connection pool."""
    global _global_manager
    if _global_manager is not None:
        await _global_manager.close()
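

# Minimal lifecycle sketch: initialize at startup, start the cleanup scheduler,
# run one cleanup pass, then shut down. This assumes CHECKPOINT_DB_URL points at
# a reachable PostgreSQL instance; the module is not normally run directly.
if __name__ == "__main__":
    async def _demo() -> None:
        manager = await init_global_db_pool()
        manager.start_cleanup_scheduler()
        try:
            stats = await manager.cleanup_old_checkpoints()
            print(stats)
        finally:
            await close_global_db_pool()

    asyncio.run(_demo())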