db_pool
This commit is contained in:
parent
05744cb9f4
commit
af63c54778
@ -1,5 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
聊天历史记录管理器
|
聊天历史记录管理器
|
||||||
|
使用共享的数据库连接池
|
||||||
直接保存完整的原始聊天消息到数据库,不受 checkpoint summary 影响
|
直接保存完整的原始聊天消息到数据库,不受 checkpoint summary 影响
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -27,8 +28,7 @@ class ChatHistoryManager:
|
|||||||
"""
|
"""
|
||||||
聊天历史管理器
|
聊天历史管理器
|
||||||
|
|
||||||
使用独立的数据库表存储完整的聊天历史记录
|
使用共享的 PostgreSQL 连接池存储完整的聊天历史记录
|
||||||
复用 checkpoint_manager 的 PostgreSQL 连接池
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, pool: AsyncConnectionPool):
|
def __init__(self, pool: AsyncConnectionPool):
|
||||||
@ -36,7 +36,7 @@ class ChatHistoryManager:
|
|||||||
初始化聊天历史管理器
|
初始化聊天历史管理器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pool: PostgreSQL 连接池
|
pool: PostgreSQL 连接池(来自 DBPoolManager)
|
||||||
"""
|
"""
|
||||||
self._pool = pool
|
self._pool = pool
|
||||||
|
|
||||||
@ -250,15 +250,20 @@ _global_manager: Optional['ChatHistoryManagerWithPool'] = None
|
|||||||
class ChatHistoryManagerWithPool:
|
class ChatHistoryManagerWithPool:
|
||||||
"""
|
"""
|
||||||
带连接池的聊天历史管理器单例
|
带连接池的聊天历史管理器单例
|
||||||
复用 checkpoint_manager 的连接池
|
使用共享的 PostgreSQL 连接池
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._pool: Optional[AsyncConnectionPool] = None
|
self._pool: Optional[AsyncConnectionPool] = None
|
||||||
self._manager: Optional[ChatHistoryManager] = None
|
self._manager: Optional[ChatHistoryManager] = None
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
self._closed = False
|
||||||
|
|
||||||
async def initialize(self, pool: AsyncConnectionPool) -> None:
|
async def initialize(self, pool: AsyncConnectionPool) -> None:
|
||||||
"""初始化管理器"""
|
"""初始化管理器,使用外部传入的连接池
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pool: AsyncConnectionPool 实例(来自 DBPoolManager)
|
||||||
|
"""
|
||||||
if self._initialized:
|
if self._initialized:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -266,15 +271,28 @@ class ChatHistoryManagerWithPool:
|
|||||||
self._manager = ChatHistoryManager(pool)
|
self._manager = ChatHistoryManager(pool)
|
||||||
await self._manager.create_table()
|
await self._manager.create_table()
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
logger.info("ChatHistoryManager initialized successfully")
|
logger.info("ChatHistoryManager initialized successfully (using shared pool)")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def manager(self) -> ChatHistoryManager:
|
def manager(self) -> ChatHistoryManager:
|
||||||
"""获取 ChatHistoryManager 实例"""
|
"""获取 ChatHistoryManager 实例"""
|
||||||
|
if self._closed:
|
||||||
|
raise RuntimeError("ChatHistoryManager is closed")
|
||||||
|
|
||||||
if not self._initialized or not self._manager:
|
if not self._initialized or not self._manager:
|
||||||
raise RuntimeError("ChatHistoryManager not initialized")
|
raise RuntimeError("ChatHistoryManager not initialized")
|
||||||
return self._manager
|
return self._manager
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""关闭管理器(连接池由 DBPoolManager 管理,这里不关闭)"""
|
||||||
|
if self._closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Closing ChatHistoryManager...")
|
||||||
|
self._closed = True
|
||||||
|
self._initialized = False
|
||||||
|
logger.info("ChatHistoryManager closed (pool managed by DBPoolManager)")
|
||||||
|
|
||||||
|
|
||||||
def get_chat_history_manager() -> ChatHistoryManagerWithPool:
|
def get_chat_history_manager() -> ChatHistoryManagerWithPool:
|
||||||
"""获取全局 ChatHistoryManager 单例"""
|
"""获取全局 ChatHistoryManager 单例"""
|
||||||
@ -285,6 +303,17 @@ def get_chat_history_manager() -> ChatHistoryManagerWithPool:
|
|||||||
|
|
||||||
|
|
||||||
async def init_chat_history_manager(pool: AsyncConnectionPool) -> None:
|
async def init_chat_history_manager(pool: AsyncConnectionPool) -> None:
|
||||||
"""初始化全局聊天历史管理器"""
|
"""初始化全局聊天历史管理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pool: AsyncConnectionPool 实例(来自 DBPoolManager)
|
||||||
|
"""
|
||||||
manager = get_chat_history_manager()
|
manager = get_chat_history_manager()
|
||||||
await manager.initialize(pool)
|
await manager.initialize(pool)
|
||||||
|
|
||||||
|
|
||||||
|
async def close_chat_history_manager() -> None:
|
||||||
|
"""关闭全局聊天历史管理器"""
|
||||||
|
global _global_manager
|
||||||
|
if _global_manager is not None:
|
||||||
|
await _global_manager.close()
|
||||||
|
|||||||
@ -1,81 +1,55 @@
|
|||||||
"""
|
"""
|
||||||
全局 PostgreSQL Checkpointer 管理器
|
全局 PostgreSQL Checkpointer 管理器
|
||||||
使用 psycopg_pool 连接池,AsyncPostgresSaver 原生支持
|
使用共享的数据库连接池
|
||||||
"""
|
"""
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from psycopg_pool import AsyncConnectionPool
|
from psycopg_pool import AsyncConnectionPool
|
||||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
|
|
||||||
from utils.settings import (
|
|
||||||
CHECKPOINT_DB_URL,
|
|
||||||
CHECKPOINT_POOL_SIZE,
|
|
||||||
CHECKPOINT_CLEANUP_ENABLED,
|
|
||||||
CHECKPOINT_CLEANUP_INTERVAL_HOURS,
|
|
||||||
CHECKPOINT_CLEANUP_INACTIVE_DAYS,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger('app')
|
logger = logging.getLogger('app')
|
||||||
|
|
||||||
|
|
||||||
class CheckpointerManager:
|
class CheckpointerManager:
|
||||||
"""
|
"""
|
||||||
全局 Checkpointer 管理器,使用 PostgreSQL 连接池
|
全局 Checkpointer 管理器,使用共享的 PostgreSQL 连接池
|
||||||
|
|
||||||
主要功能:
|
主要功能:
|
||||||
1. 使用 psycopg_pool.AsyncConnectionPool 管理连接
|
1. 使用 DBPoolManager 的共享连接池
|
||||||
2. AsyncPostgresSaver 原生支持连接池,自动获取/释放连接
|
2. AsyncPostgresSaver 原生支持连接池,自动获取/释放连接
|
||||||
3. 无需手动归还,避免长请求占用连接
|
3. 无需手动归还,避免长请求占用连接
|
||||||
4. 基于 SQL 查询的清理机制
|
|
||||||
5. 优雅关闭机制
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._pool: Optional[AsyncConnectionPool] = None
|
self._pool: Optional[AsyncConnectionPool] = None
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
self._closed = False
|
self._closed = False
|
||||||
# 清理调度任务
|
|
||||||
self._cleanup_task: Optional[asyncio.Task] = None
|
|
||||||
self._cleanup_stop_event = asyncio.Event()
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self, pool: AsyncConnectionPool) -> None:
|
||||||
"""初始化连接池"""
|
"""初始化 checkpointer,使用外部传入的连接池
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pool: AsyncConnectionPool 实例(来自 DBPoolManager)
|
||||||
|
"""
|
||||||
if self._initialized:
|
if self._initialized:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
self._pool = pool
|
||||||
f"Initializing PostgreSQL checkpointer pool: "
|
|
||||||
f"URL={CHECKPOINT_DB_URL}, size={CHECKPOINT_POOL_SIZE}"
|
logger.info("Initializing PostgreSQL checkpointer (using shared connection pool)...")
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 创建 psycopg 连接池
|
|
||||||
self._pool = AsyncConnectionPool(
|
|
||||||
CHECKPOINT_DB_URL,
|
|
||||||
min_size=1,
|
|
||||||
max_size=CHECKPOINT_POOL_SIZE,
|
|
||||||
open=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 打开连接池
|
|
||||||
await self._pool.open()
|
|
||||||
|
|
||||||
# 创建表结构(需要 autocommit 模式来执行 CREATE INDEX CONCURRENTLY)
|
# 创建表结构(需要 autocommit 模式来执行 CREATE INDEX CONCURRENTLY)
|
||||||
async with self._pool.connection() as conn:
|
async with self._pool.connection() as conn:
|
||||||
await conn.set_autocommit(True)
|
await conn.set_autocommit(True)
|
||||||
checkpointer = AsyncPostgresSaver(conn=conn)
|
checkpointer = AsyncPostgresSaver(conn=conn)
|
||||||
await checkpointer.setup()
|
await checkpointer.setup()
|
||||||
|
|
||||||
# 初始化 ChatHistoryManager(复用同一个连接池)
|
|
||||||
from .chat_history_manager import init_chat_history_manager
|
|
||||||
await init_chat_history_manager(self._pool)
|
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
logger.info("PostgreSQL checkpointer pool initialized successfully")
|
logger.info("PostgreSQL checkpointer initialized successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize PostgreSQL checkpointer pool: {e}")
|
logger.error(f"Failed to initialize PostgreSQL checkpointer: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -89,170 +63,26 @@ class CheckpointerManager:
|
|||||||
|
|
||||||
return AsyncPostgresSaver(conn=self._pool)
|
return AsyncPostgresSaver(conn=self._pool)
|
||||||
|
|
||||||
async def close(self) -> None:
|
@property
|
||||||
"""关闭连接池"""
|
def pool(self) -> AsyncConnectionPool:
|
||||||
|
"""获取连接池(用于共享给其他管理器)"""
|
||||||
if self._closed:
|
if self._closed:
|
||||||
return
|
raise RuntimeError("CheckpointerManager is closed")
|
||||||
|
|
||||||
# 停止清理任务
|
if not self._initialized:
|
||||||
if self._cleanup_task is not None:
|
raise RuntimeError("CheckpointerManager not initialized, call initialize() first")
|
||||||
self._cleanup_stop_event.set()
|
|
||||||
try:
|
|
||||||
self._cleanup_task.cancel()
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
self._cleanup_task = None
|
|
||||||
|
|
||||||
|
return self._pool
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""关闭 Checkpointer(连接池由 DBPoolManager 管理,这里不关闭)"""
|
||||||
if self._closed:
|
if self._closed:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("Closing CheckpointerManager...")
|
logger.info("Closing CheckpointerManager...")
|
||||||
|
|
||||||
if self._pool is not None:
|
|
||||||
await self._pool.close()
|
|
||||||
self._pool = None
|
|
||||||
|
|
||||||
self._closed = True
|
self._closed = True
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
logger.info("CheckpointerManager closed")
|
logger.info("CheckpointerManager closed (pool managed by DBPoolManager)")
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Checkpoint 清理方法(基于 PostgreSQL SQL 查询)
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
async def cleanup_old_dbs(self, inactive_days: int = None) -> dict:
|
|
||||||
"""
|
|
||||||
清理旧的 checkpoint 记录
|
|
||||||
|
|
||||||
删除 N 天未活动的 thread 的所有 checkpoint。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inactive_days: 删除 N 天未活动的 thread,默认使用配置值
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: 清理统计信息
|
|
||||||
"""
|
|
||||||
if inactive_days is None:
|
|
||||||
inactive_days = CHECKPOINT_CLEANUP_INACTIVE_DAYS
|
|
||||||
|
|
||||||
logger.info(f"Starting checkpoint cleanup: removing threads inactive for {inactive_days} days")
|
|
||||||
|
|
||||||
if self._pool is None:
|
|
||||||
logger.warning("Connection pool not initialized, skipping cleanup")
|
|
||||||
return {"deleted": 0, "inactive_days": inactive_days}
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 从池中获取连接执行清理
|
|
||||||
async with self._pool.connection() as conn:
|
|
||||||
async with conn.cursor() as cursor:
|
|
||||||
# 查找不活跃的 thread
|
|
||||||
query_find_inactive = """
|
|
||||||
SELECT DISTINCT thread_id
|
|
||||||
FROM checkpoints
|
|
||||||
WHERE (checkpoint->>'ts')::timestamp < NOW() - INTERVAL '%s days'
|
|
||||||
"""
|
|
||||||
|
|
||||||
await cursor.execute(query_find_inactive, (inactive_days,))
|
|
||||||
inactive_threads = await cursor.fetchall()
|
|
||||||
inactive_thread_ids = [row[0] for row in inactive_threads]
|
|
||||||
|
|
||||||
if not inactive_thread_ids:
|
|
||||||
logger.info("No inactive threads found")
|
|
||||||
return {"deleted": 0, "threads_deleted": 0, "inactive_days": inactive_days}
|
|
||||||
|
|
||||||
# 删除不活跃 thread 的 checkpoints
|
|
||||||
placeholders = ','.join(['%s'] * len(inactive_thread_ids))
|
|
||||||
query_delete_checkpoints = f"""
|
|
||||||
DELETE FROM checkpoints
|
|
||||||
WHERE thread_id IN ({placeholders})
|
|
||||||
"""
|
|
||||||
|
|
||||||
await cursor.execute(query_delete_checkpoints, inactive_thread_ids)
|
|
||||||
deleted_checkpoints = cursor.rowcount
|
|
||||||
|
|
||||||
# 同时清理 checkpoint_writes
|
|
||||||
query_delete_writes = f"""
|
|
||||||
DELETE FROM checkpoint_writes
|
|
||||||
WHERE thread_id IN ({placeholders})
|
|
||||||
"""
|
|
||||||
await cursor.execute(query_delete_writes, inactive_thread_ids)
|
|
||||||
deleted_writes = cursor.rowcount
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Checkpoint cleanup completed: "
|
|
||||||
f"deleted {deleted_checkpoints} checkpoints, {deleted_writes} writes "
|
|
||||||
f"from {len(inactive_thread_ids)} inactive threads"
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"deleted": deleted_checkpoints + deleted_writes,
|
|
||||||
"threads_deleted": len(inactive_thread_ids),
|
|
||||||
"inactive_days": inactive_days
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error during checkpoint cleanup: {e}")
|
|
||||||
import traceback
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return {"deleted": 0, "error": str(e), "inactive_days": inactive_days}
|
|
||||||
|
|
||||||
async def _cleanup_loop(self):
|
|
||||||
"""后台清理循环"""
|
|
||||||
interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Checkpoint cleanup scheduler started: "
|
|
||||||
f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, "
|
|
||||||
f"inactive_days={CHECKPOINT_CLEANUP_INACTIVE_DAYS}"
|
|
||||||
)
|
|
||||||
|
|
||||||
while not self._cleanup_stop_event.is_set():
|
|
||||||
try:
|
|
||||||
# 等待间隔时间或停止信号
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(
|
|
||||||
self._cleanup_stop_event.wait(),
|
|
||||||
timeout=interval_seconds
|
|
||||||
)
|
|
||||||
# 收到停止信号,退出循环
|
|
||||||
break
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
# 超时,执行清理任务
|
|
||||||
pass
|
|
||||||
|
|
||||||
if self._cleanup_stop_event.is_set():
|
|
||||||
break
|
|
||||||
|
|
||||||
# 执行清理
|
|
||||||
await self.cleanup_old_dbs()
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.info("Cleanup task cancelled")
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in cleanup loop: {e}")
|
|
||||||
|
|
||||||
logger.info("Checkpoint cleanup scheduler stopped")
|
|
||||||
|
|
||||||
def start_cleanup_scheduler(self):
|
|
||||||
"""启动后台定时清理任务"""
|
|
||||||
if not CHECKPOINT_CLEANUP_ENABLED:
|
|
||||||
logger.info("Checkpoint cleanup is disabled")
|
|
||||||
return
|
|
||||||
|
|
||||||
if self._cleanup_task is not None and not self._cleanup_task.done():
|
|
||||||
logger.warning("Cleanup scheduler is already running")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
self._cleanup_task = loop.create_task(self._cleanup_loop())
|
|
||||||
logger.info("Cleanup scheduler task created")
|
|
||||||
|
|
||||||
|
|
||||||
# 全局单例
|
# 全局单例
|
||||||
@ -267,10 +97,14 @@ def get_checkpointer_manager() -> CheckpointerManager:
|
|||||||
return _global_manager
|
return _global_manager
|
||||||
|
|
||||||
|
|
||||||
async def init_global_checkpointer() -> None:
|
async def init_global_checkpointer(pool: AsyncConnectionPool) -> None:
|
||||||
"""初始化全局 checkpointer 管理器"""
|
"""初始化全局 checkpointer 管理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pool: AsyncConnectionPool 实例(来自 DBPoolManager)
|
||||||
|
"""
|
||||||
manager = get_checkpointer_manager()
|
manager = get_checkpointer_manager()
|
||||||
await manager.initialize()
|
await manager.initialize(pool)
|
||||||
|
|
||||||
|
|
||||||
async def close_global_checkpointer() -> None:
|
async def close_global_checkpointer() -> None:
|
||||||
|
|||||||
265
agent/db_pool_manager.py
Normal file
265
agent/db_pool_manager.py
Normal file
@ -0,0 +1,265 @@
|
|||||||
|
"""
|
||||||
|
全局 PostgreSQL 连接池管理器
|
||||||
|
被 checkpoint、chat_history 共享使用
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
from utils.settings import (
|
||||||
|
CHECKPOINT_DB_URL,
|
||||||
|
CHECKPOINT_POOL_SIZE,
|
||||||
|
CHECKPOINT_CLEANUP_ENABLED,
|
||||||
|
CHECKPOINT_CLEANUP_INTERVAL_HOURS,
|
||||||
|
CHECKPOINT_CLEANUP_INACTIVE_DAYS,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger('app')
|
||||||
|
|
||||||
|
|
||||||
|
class DBPoolManager:
|
||||||
|
"""
|
||||||
|
全局 PostgreSQL 连接池管理器
|
||||||
|
|
||||||
|
主要功能:
|
||||||
|
1. 使用 psycopg_pool.AsyncConnectionPool 管理连接
|
||||||
|
2. 被 CheckpointerManager、ChatHistoryManager 共享
|
||||||
|
3. 自动清理旧的 checkpoint 数据
|
||||||
|
4. 优雅关闭机制
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._pool: Optional[AsyncConnectionPool] = None
|
||||||
|
self._initialized = False
|
||||||
|
self._closed = False
|
||||||
|
# 清理调度任务
|
||||||
|
self._cleanup_task: Optional[asyncio.Task] = None
|
||||||
|
self._cleanup_stop_event = asyncio.Event()
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""初始化连接池"""
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Initializing PostgreSQL connection pool: "
|
||||||
|
f"URL={CHECKPOINT_DB_URL}, size={CHECKPOINT_POOL_SIZE}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 创建 psycopg 连接池
|
||||||
|
self._pool = AsyncConnectionPool(
|
||||||
|
CHECKPOINT_DB_URL,
|
||||||
|
min_size=1,
|
||||||
|
max_size=CHECKPOINT_POOL_SIZE,
|
||||||
|
open=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 打开连接池
|
||||||
|
await self._pool.open()
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
logger.info("PostgreSQL connection pool initialized successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize PostgreSQL connection pool: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pool(self) -> AsyncConnectionPool:
|
||||||
|
"""获取连接池"""
|
||||||
|
if self._closed:
|
||||||
|
raise RuntimeError("DBPoolManager is closed")
|
||||||
|
|
||||||
|
if not self._initialized:
|
||||||
|
raise RuntimeError("DBPoolManager not initialized, call initialize() first")
|
||||||
|
|
||||||
|
return self._pool
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""关闭连接池"""
|
||||||
|
if self._closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 停止清理任务
|
||||||
|
if self._cleanup_task is not None:
|
||||||
|
self._cleanup_stop_event.set()
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._cleanup_task = None
|
||||||
|
|
||||||
|
logger.info("Closing DBPoolManager...")
|
||||||
|
|
||||||
|
if self._pool is not None:
|
||||||
|
await self._pool.close()
|
||||||
|
self._pool = None
|
||||||
|
|
||||||
|
self._closed = True
|
||||||
|
self._initialized = False
|
||||||
|
logger.info("DBPoolManager closed")
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Checkpoint 清理方法(基于 PostgreSQL SQL 查询)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
async def cleanup_old_checkpoints(self, inactive_days: int = None) -> dict:
|
||||||
|
"""
|
||||||
|
清理旧的 checkpoint 记录
|
||||||
|
|
||||||
|
删除 N 天未活动的 thread 的所有 checkpoint。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inactive_days: 删除 N 天未活动的 thread,默认使用配置值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 清理统计信息
|
||||||
|
"""
|
||||||
|
if inactive_days is None:
|
||||||
|
inactive_days = CHECKPOINT_CLEANUP_INACTIVE_DAYS
|
||||||
|
|
||||||
|
logger.info(f"Starting checkpoint cleanup: removing threads inactive for {inactive_days} days")
|
||||||
|
|
||||||
|
if self._pool is None:
|
||||||
|
logger.warning("Connection pool not initialized, skipping cleanup")
|
||||||
|
return {"deleted": 0, "inactive_days": inactive_days}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 从池中获取连接执行清理
|
||||||
|
async with self._pool.connection() as conn:
|
||||||
|
async with conn.cursor() as cursor:
|
||||||
|
# 查找不活跃的 thread
|
||||||
|
query_find_inactive = """
|
||||||
|
SELECT DISTINCT thread_id
|
||||||
|
FROM checkpoints
|
||||||
|
WHERE (checkpoint->>'ts')::timestamp < NOW() - INTERVAL '%s days'
|
||||||
|
"""
|
||||||
|
|
||||||
|
await cursor.execute(query_find_inactive, (inactive_days,))
|
||||||
|
inactive_threads = await cursor.fetchall()
|
||||||
|
inactive_thread_ids = [row[0] for row in inactive_threads]
|
||||||
|
|
||||||
|
if not inactive_thread_ids:
|
||||||
|
logger.info("No inactive threads found")
|
||||||
|
return {"deleted": 0, "threads_deleted": 0, "inactive_days": inactive_days}
|
||||||
|
|
||||||
|
# 删除不活跃 thread 的 checkpoints
|
||||||
|
placeholders = ','.join(['%s'] * len(inactive_thread_ids))
|
||||||
|
query_delete_checkpoints = f"""
|
||||||
|
DELETE FROM checkpoints
|
||||||
|
WHERE thread_id IN ({placeholders})
|
||||||
|
"""
|
||||||
|
|
||||||
|
await cursor.execute(query_delete_checkpoints, inactive_thread_ids)
|
||||||
|
deleted_checkpoints = cursor.rowcount
|
||||||
|
|
||||||
|
# 同时清理 checkpoint_writes
|
||||||
|
query_delete_writes = f"""
|
||||||
|
DELETE FROM checkpoint_writes
|
||||||
|
WHERE thread_id IN ({placeholders})
|
||||||
|
"""
|
||||||
|
await cursor.execute(query_delete_writes, inactive_thread_ids)
|
||||||
|
deleted_writes = cursor.rowcount
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Checkpoint cleanup completed: "
|
||||||
|
f"deleted {deleted_checkpoints} checkpoints, {deleted_writes} writes "
|
||||||
|
f"from {len(inactive_thread_ids)} inactive threads"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"deleted": deleted_checkpoints + deleted_writes,
|
||||||
|
"threads_deleted": len(inactive_thread_ids),
|
||||||
|
"inactive_days": inactive_days
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during checkpoint cleanup: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return {"deleted": 0, "error": str(e), "inactive_days": inactive_days}
|
||||||
|
|
||||||
|
async def _cleanup_loop(self):
|
||||||
|
"""后台清理循环"""
|
||||||
|
interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Checkpoint cleanup scheduler started: "
|
||||||
|
f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, "
|
||||||
|
f"inactive_days={CHECKPOINT_CLEANUP_INACTIVE_DAYS}"
|
||||||
|
)
|
||||||
|
|
||||||
|
while not self._cleanup_stop_event.is_set():
|
||||||
|
try:
|
||||||
|
# 等待间隔时间或停止信号
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._cleanup_stop_event.wait(),
|
||||||
|
timeout=interval_seconds
|
||||||
|
)
|
||||||
|
# 收到停止信号,退出循环
|
||||||
|
break
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# 超时,执行清理任务
|
||||||
|
pass
|
||||||
|
|
||||||
|
if self._cleanup_stop_event.is_set():
|
||||||
|
break
|
||||||
|
|
||||||
|
# 执行清理
|
||||||
|
await self.cleanup_old_checkpoints()
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("Cleanup task cancelled")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in cleanup loop: {e}")
|
||||||
|
|
||||||
|
logger.info("Checkpoint cleanup scheduler stopped")
|
||||||
|
|
||||||
|
def start_cleanup_scheduler(self):
|
||||||
|
"""启动后台定时清理任务"""
|
||||||
|
if not CHECKPOINT_CLEANUP_ENABLED:
|
||||||
|
logger.info("Checkpoint cleanup is disabled")
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._cleanup_task is not None and not self._cleanup_task.done():
|
||||||
|
logger.warning("Cleanup scheduler is already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
self._cleanup_task = loop.create_task(self._cleanup_loop())
|
||||||
|
logger.info("Cleanup scheduler task created")
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_global_manager: Optional[DBPoolManager] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_pool_manager() -> DBPoolManager:
|
||||||
|
"""获取全局 DBPoolManager 单例"""
|
||||||
|
global _global_manager
|
||||||
|
if _global_manager is None:
|
||||||
|
_global_manager = DBPoolManager()
|
||||||
|
return _global_manager
|
||||||
|
|
||||||
|
|
||||||
|
async def init_global_db_pool() -> DBPoolManager:
|
||||||
|
"""初始化全局数据库连接池"""
|
||||||
|
manager = get_db_pool_manager()
|
||||||
|
await manager.initialize()
|
||||||
|
return manager
|
||||||
|
|
||||||
|
|
||||||
|
async def close_global_db_pool() -> None:
|
||||||
|
"""关闭全局数据库连接池"""
|
||||||
|
global _global_manager
|
||||||
|
if _global_manager is not None:
|
||||||
|
await _global_manager.close()
|
||||||
@ -23,6 +23,7 @@ from agent.agent_config import AgentConfig
|
|||||||
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
||||||
from agent.agent_memory_cache import get_memory_cache_manager
|
from agent.agent_memory_cache import get_memory_cache_manager
|
||||||
from .checkpoint_utils import prepare_checkpoint_message
|
from .checkpoint_utils import prepare_checkpoint_message
|
||||||
|
from .checkpoint_manager import get_checkpointer_manager
|
||||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
from langchain.tools import BaseTool
|
from langchain.tools import BaseTool
|
||||||
@ -181,10 +182,13 @@ async def init_agent(config: AgentConfig):
|
|||||||
|
|
||||||
# 从连接池获取 checkpointer
|
# 从连接池获取 checkpointer
|
||||||
if config.session_id:
|
if config.session_id:
|
||||||
from .checkpoint_manager import get_checkpointer_manager
|
try:
|
||||||
manager = get_checkpointer_manager()
|
manager = get_checkpointer_manager()
|
||||||
checkpointer = manager.checkpointer
|
checkpointer = manager.checkpointer
|
||||||
await prepare_checkpoint_message(config, checkpointer)
|
if checkpointer:
|
||||||
|
await prepare_checkpoint_message(config, checkpointer)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load checkpointer: {e}")
|
||||||
|
|
||||||
|
|
||||||
if config.robot_type == "deep_agent":
|
if config.robot_type == "deep_agent":
|
||||||
@ -328,6 +332,7 @@ def create_custom_cli_agent(
|
|||||||
middleware: list[AgentMiddleware] = [],
|
middleware: list[AgentMiddleware] = [],
|
||||||
workspace_root: str | None = None,
|
workspace_root: str | None = None,
|
||||||
checkpointer: Checkpointer | None = None,
|
checkpointer: Checkpointer | None = None,
|
||||||
|
store: BaseStore | None = None,
|
||||||
) -> tuple[Pregel, CompositeBackend]:
|
) -> tuple[Pregel, CompositeBackend]:
|
||||||
"""Create a CLI-configured agent with custom workspace_root for shell commands.
|
"""Create a CLI-configured agent with custom workspace_root for shell commands.
|
||||||
|
|
||||||
@ -350,6 +355,8 @@ def create_custom_cli_agent(
|
|||||||
enable_skills: Enable SkillsMiddleware for custom agent skills
|
enable_skills: Enable SkillsMiddleware for custom agent skills
|
||||||
enable_shell: Enable ShellMiddleware for local shell execution (only in local mode)
|
enable_shell: Enable ShellMiddleware for local shell execution (only in local mode)
|
||||||
workspace_root: Working directory for shell commands. If None, uses Path.cwd().
|
workspace_root: Working directory for shell commands. If None, uses Path.cwd().
|
||||||
|
checkpointer: Optional checkpointer for persisting conversation state
|
||||||
|
store: Optional BaseStore for persisting user preferences and agent memory
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
2-tuple of (agent_graph, composite_backend)
|
2-tuple of (agent_graph, composite_backend)
|
||||||
@ -464,5 +471,6 @@ def create_custom_cli_agent(
|
|||||||
middleware=agent_middleware,
|
middleware=agent_middleware,
|
||||||
interrupt_on=interrupt_on,
|
interrupt_on=interrupt_on,
|
||||||
checkpointer=checkpointer,
|
checkpointer=checkpointer,
|
||||||
|
store=store,
|
||||||
).with_config(config)
|
).with_config(config)
|
||||||
return agent, composite_backend
|
return agent, composite_backend
|
||||||
|
|||||||
@ -15,6 +15,8 @@ import logging
|
|||||||
|
|
||||||
from utils.log_util.logger import init_with_fastapi
|
from utils.log_util.logger import init_with_fastapi
|
||||||
|
|
||||||
|
# Initialize logger
|
||||||
|
logger = logging.getLogger('app')
|
||||||
|
|
||||||
# Import route modules
|
# Import route modules
|
||||||
from routes import chat, files, projects, system, skill_manager
|
from routes import chat, files, projects, system, skill_manager
|
||||||
@ -25,44 +27,61 @@ async def lifespan(app: FastAPI):
|
|||||||
"""FastAPI 应用生命周期管理"""
|
"""FastAPI 应用生命周期管理"""
|
||||||
# 启动时初始化
|
# 启动时初始化
|
||||||
logger.info("Starting up...")
|
logger.info("Starting up...")
|
||||||
|
from agent.db_pool_manager import (
|
||||||
|
init_global_db_pool,
|
||||||
|
get_db_pool_manager,
|
||||||
|
close_global_db_pool
|
||||||
|
)
|
||||||
from agent.checkpoint_manager import (
|
from agent.checkpoint_manager import (
|
||||||
init_global_checkpointer,
|
init_global_checkpointer,
|
||||||
get_checkpointer_manager,
|
|
||||||
close_global_checkpointer
|
close_global_checkpointer
|
||||||
)
|
)
|
||||||
|
from agent.chat_history_manager import (
|
||||||
|
init_chat_history_manager,
|
||||||
|
close_chat_history_manager
|
||||||
|
)
|
||||||
from utils.settings import CHECKPOINT_CLEANUP_ENABLED
|
from utils.settings import CHECKPOINT_CLEANUP_ENABLED
|
||||||
|
|
||||||
await init_global_checkpointer()
|
# 1. 初始化共享的数据库连接池
|
||||||
|
db_pool_manager = await init_global_db_pool()
|
||||||
|
logger.info("Global DB pool initialized")
|
||||||
|
|
||||||
|
# 2. 初始化 checkpoint (使用共享连接池)
|
||||||
|
await init_global_checkpointer(db_pool_manager.pool)
|
||||||
logger.info("Global checkpointer initialized")
|
logger.info("Global checkpointer initialized")
|
||||||
|
|
||||||
# 启动 checkpoint 清理调度器
|
# 3. 初始化 chat_history (使用共享连接池)
|
||||||
|
await init_chat_history_manager(db_pool_manager.pool)
|
||||||
|
logger.info("Chat history manager initialized")
|
||||||
|
|
||||||
|
# 4. 启动 checkpoint 清理调度器
|
||||||
if CHECKPOINT_CLEANUP_ENABLED:
|
if CHECKPOINT_CLEANUP_ENABLED:
|
||||||
manager = get_checkpointer_manager()
|
|
||||||
# 启动时立即执行一次清理
|
# 启动时立即执行一次清理
|
||||||
try:
|
try:
|
||||||
result = await manager.cleanup_old_dbs()
|
result = await db_pool_manager.cleanup_old_checkpoints()
|
||||||
logger.info(f"Startup cleanup completed: {result}")
|
logger.info(f"Startup cleanup completed: {result}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Startup cleanup failed (non-fatal): {e}")
|
logger.warning(f"Startup cleanup failed (non-fatal): {e}")
|
||||||
# 启动定时清理调度器
|
# 启动定时清理调度器
|
||||||
manager.start_cleanup_scheduler()
|
db_pool_manager.start_cleanup_scheduler()
|
||||||
logger.info("Checkpoint cleanup scheduler started")
|
logger.info("Checkpoint cleanup scheduler started")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# 关闭时清理
|
# 关闭时清理(按相反顺序)
|
||||||
logger.info("Shutting down...")
|
logger.info("Shutting down...")
|
||||||
|
await close_chat_history_manager()
|
||||||
|
logger.info("Chat history manager closed")
|
||||||
await close_global_checkpointer()
|
await close_global_checkpointer()
|
||||||
logger.info("Global checkpointer closed")
|
logger.info("Global checkpointer closed")
|
||||||
|
await close_global_db_pool()
|
||||||
|
logger.info("Global DB pool closed")
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Database Assistant API", version="1.0.0", lifespan=lifespan)
|
app = FastAPI(title="Database Assistant API", version="1.0.0", lifespan=lifespan)
|
||||||
|
|
||||||
init_with_fastapi(app)
|
init_with_fastapi(app)
|
||||||
|
|
||||||
logger = logging.getLogger('app')
|
|
||||||
|
|
||||||
|
|
||||||
# 挂载public文件夹为静态文件服务
|
# 挂载public文件夹为静态文件服务
|
||||||
app.mount("/public", StaticFiles(directory="public"), name="static")
|
app.mount("/public", StaticFiles(directory="public"), name="static")
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user