- Fix mem0 connection pool exhausted error with proper pooling
- Convert memory operations to async tasks
- Optimize docker-compose configuration
- Add skill upload functionality
- Reduce cache size for better performance
- Update dependencies

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
"""
|
||
全局 PostgreSQL 连接池管理器
|
||
被 checkpoint、chat_history、mem0 共享使用
|
||
"""
import asyncio
import logging
from typing import Optional

from psycopg_pool import AsyncConnectionPool
from psycopg2 import pool as psycopg2_pool

from utils.settings import (
    CHECKPOINT_DB_URL,
    CHECKPOINT_POOL_SIZE,
    MEM0_POOL_SIZE,
    CHECKPOINT_CLEANUP_ENABLED,
    CHECKPOINT_CLEANUP_INTERVAL_HOURS,
    CHECKPOINT_CLEANUP_INACTIVE_DAYS,
)

logger = logging.getLogger('app')

class DBPoolManager:
    """
    Global PostgreSQL connection pool manager.

    Responsibilities:
    1. Manage async connections with psycopg_pool.AsyncConnectionPool
    2. Manage sync connections with psycopg2.pool.SimpleConnectionPool (used by Mem0)
    3. Shared by CheckpointerManager, ChatHistoryManager, and Mem0Manager
    4. Periodically clean up old checkpoint data
    5. Graceful shutdown
    """

    def __init__(self):
        self._pool: Optional[AsyncConnectionPool] = None
        self._sync_pool: Optional[psycopg2_pool.SimpleConnectionPool] = None  # sync pool for Mem0
        self._initialized = False
        self._closed = False
        # Background cleanup task
        self._cleanup_task: Optional[asyncio.Task] = None
        self._cleanup_stop_event = asyncio.Event()
    async def initialize(self) -> None:
        """Initialize the connection pools."""
        if self._initialized:
            return

        logger.info(
            f"Initializing PostgreSQL connection pool: "
            f"URL={CHECKPOINT_DB_URL}, size={CHECKPOINT_POOL_SIZE}"
        )

        try:
            # 1. Create the async psycopg connection pool
            self._pool = AsyncConnectionPool(
                CHECKPOINT_DB_URL,
                min_size=1,
                max_size=CHECKPOINT_POOL_SIZE,
                open=False,
            )
            await self._pool.open()

            # 2. Create the sync psycopg2 connection pool (used by Mem0)
            self._sync_pool = self._create_sync_pool(CHECKPOINT_DB_URL, MEM0_POOL_SIZE)

            self._initialized = True
            logger.info("PostgreSQL connection pool initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize PostgreSQL connection pool: {e}")
            raise
    def _create_sync_pool(self, db_url: str, pool_size: int) -> psycopg2_pool.SimpleConnectionPool:
        """Create the sync connection pool (used by Mem0)."""
        # Parse the connection URL: postgresql://user:password@host:port/database
        url_parts = db_url.replace("postgresql://", "").split("/")
        conn_part = url_parts[0]
        dbname = url_parts[1] if len(url_parts) > 1 else "postgres"

        if "@" in conn_part:
            auth_part, host_part = conn_part.split("@")
            # Split on the first ':' only, so passwords containing ':' survive
            user, password = auth_part.split(":", 1) if ":" in auth_part else (auth_part, "")
        else:
            user = ""
            password = ""
            host_part = conn_part

        if ":" in host_part:
            host, port = host_part.split(":")
            port = int(port)
        else:
            host = host_part
            port = 5432

        return psycopg2_pool.SimpleConnectionPool(
            1, pool_size,
            user=user,
            password=password,
            host=host,
            port=port,
            database=dbname
        )
    @property
    def pool(self) -> AsyncConnectionPool:
        """Return the async connection pool."""
        if self._closed:
            raise RuntimeError("DBPoolManager is closed")

        if not self._initialized:
            raise RuntimeError("DBPoolManager not initialized, call initialize() first")

        return self._pool
    @property
    def sync_pool(self) -> psycopg2_pool.SimpleConnectionPool:
        """Return the sync connection pool (used by Mem0)."""
        if self._closed:
            raise RuntimeError("DBPoolManager is closed")

        if not self._initialized:
            raise RuntimeError("DBPoolManager not initialized, call initialize() first")

        if self._sync_pool is None:
            raise RuntimeError("Sync pool not available")

        return self._sync_pool
    async def close(self) -> None:
        """Close the connection pools."""
        if self._closed:
            return

        # Stop the background cleanup task and wait for it to finish
        if self._cleanup_task is not None:
            self._cleanup_stop_event.set()
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

        logger.info("Closing DBPoolManager...")

        # Close the async pool
        if self._pool is not None:
            await self._pool.close()
            self._pool = None

        # Close the sync pool
        if self._sync_pool is not None:
            self._sync_pool.closeall()
            self._sync_pool = None

        self._closed = True
        self._initialized = False
        logger.info("DBPoolManager closed")
    # ============================================================
    # Checkpoint cleanup (plain PostgreSQL queries)
    # ============================================================

    async def cleanup_old_checkpoints(self, inactive_days: Optional[int] = None) -> dict:
        """
        Delete old checkpoint records.

        Removes every checkpoint belonging to threads that have been inactive
        for N days.

        Args:
            inactive_days: delete threads inactive for this many days;
                defaults to the configured value

        Returns:
            Dict: cleanup statistics
        """
        if inactive_days is None:
            inactive_days = CHECKPOINT_CLEANUP_INACTIVE_DAYS

        logger.info(f"Starting checkpoint cleanup: removing threads inactive for {inactive_days} days")

        if self._pool is None:
            logger.warning("Connection pool not initialized, skipping cleanup")
            return {"deleted": 0, "inactive_days": inactive_days}

        try:
            # Borrow a connection from the pool to run the cleanup
            async with self._pool.connection() as conn:
                async with conn.cursor() as cursor:
                    # Find inactive threads. The placeholder must stay outside the
                    # quoted interval literal, otherwise it is not bound as a parameter.
                    query_find_inactive = """
                        SELECT DISTINCT thread_id
                        FROM checkpoints
                        WHERE (checkpoint->>'ts')::timestamp < NOW() - (%s * INTERVAL '1 day')
                    """

                    await cursor.execute(query_find_inactive, (inactive_days,))
                    inactive_threads = await cursor.fetchall()
                    inactive_thread_ids = [row[0] for row in inactive_threads]

                    if not inactive_thread_ids:
                        logger.info("No inactive threads found")
                        return {"deleted": 0, "threads_deleted": 0, "inactive_days": inactive_days}

                    # Delete the checkpoints of the inactive threads
                    placeholders = ','.join(['%s'] * len(inactive_thread_ids))
                    query_delete_checkpoints = f"""
                        DELETE FROM checkpoints
                        WHERE thread_id IN ({placeholders})
                    """

                    await cursor.execute(query_delete_checkpoints, inactive_thread_ids)
                    deleted_checkpoints = cursor.rowcount

                    # Also clean up checkpoint_writes
                    query_delete_writes = f"""
                        DELETE FROM checkpoint_writes
                        WHERE thread_id IN ({placeholders})
                    """
                    await cursor.execute(query_delete_writes, inactive_thread_ids)
                    deleted_writes = cursor.rowcount

                    logger.info(
                        f"Checkpoint cleanup completed: "
                        f"deleted {deleted_checkpoints} checkpoints, {deleted_writes} writes "
                        f"from {len(inactive_thread_ids)} inactive threads"
                    )

                    return {
                        "deleted": deleted_checkpoints + deleted_writes,
                        "threads_deleted": len(inactive_thread_ids),
                        "inactive_days": inactive_days
                    }

        except Exception as e:
            logger.error(f"Error during checkpoint cleanup: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return {"deleted": 0, "error": str(e), "inactive_days": inactive_days}
    async def _cleanup_loop(self):
        """Background cleanup loop."""
        interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600

        logger.info(
            f"Checkpoint cleanup scheduler started: "
            f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, "
            f"inactive_days={CHECKPOINT_CLEANUP_INACTIVE_DAYS}"
        )

        while not self._cleanup_stop_event.is_set():
            try:
                # Wait for the next interval or for the stop signal
                try:
                    await asyncio.wait_for(
                        self._cleanup_stop_event.wait(),
                        timeout=interval_seconds
                    )
                    # Stop signal received, exit the loop
                    break
                except asyncio.TimeoutError:
                    # Interval elapsed, fall through and run the cleanup
                    pass

                if self._cleanup_stop_event.is_set():
                    break

                # Run the cleanup
                await self.cleanup_old_checkpoints()

            except asyncio.CancelledError:
                logger.info("Cleanup task cancelled")
                break
            except Exception as e:
                logger.error(f"Error in cleanup loop: {e}")

        logger.info("Checkpoint cleanup scheduler stopped")
    def start_cleanup_scheduler(self):
        """Start the periodic background cleanup task."""
        if not CHECKPOINT_CLEANUP_ENABLED:
            logger.info("Checkpoint cleanup is disabled")
            return

        if self._cleanup_task is not None and not self._cleanup_task.done():
            logger.warning("Cleanup scheduler is already running")
            return

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # Without a running event loop the task would never execute,
            # so skip starting the scheduler instead of creating a new loop.
            logger.warning("No running event loop, cleanup scheduler not started")
            return

        self._cleanup_task = loop.create_task(self._cleanup_loop())
        logger.info("Cleanup scheduler task created")

# Global singleton
_global_manager: Optional[DBPoolManager] = None


def get_db_pool_manager() -> DBPoolManager:
    """Return the global DBPoolManager singleton."""
    global _global_manager
    if _global_manager is None:
        _global_manager = DBPoolManager()
    return _global_manager


async def init_global_db_pool() -> DBPoolManager:
    """Initialize the global database connection pool."""
    manager = get_db_pool_manager()
    await manager.initialize()
    return manager


async def close_global_db_pool() -> None:
    """Close the global database connection pool."""
    global _global_manager
    if _global_manager is not None:
        await _global_manager.close()
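
# Usage sketch: a minimal startup/shutdown sequence using only the helpers
# defined above. Illustrative only; it assumes CHECKPOINT_DB_URL in
# utils.settings points at a reachable PostgreSQL instance with the
# checkpoint tables already created.
if __name__ == "__main__":
    async def _demo() -> None:
        manager = await init_global_db_pool()            # open the async and sync pools
        manager.start_cleanup_scheduler()                # optional periodic cleanup
        stats = await manager.cleanup_old_checkpoints()  # run one pass immediately
        logger.info(f"Manual cleanup result: {stats}")
        await close_global_db_pool()                     # stop the scheduler, close both pools

    asyncio.run(_demo())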