db_pool
This commit is contained in:
parent
05744cb9f4
commit
af63c54778
@ -1,5 +1,6 @@
|
||||
"""
|
||||
聊天历史记录管理器
|
||||
使用共享的数据库连接池
|
||||
直接保存完整的原始聊天消息到数据库,不受 checkpoint summary 影响
|
||||
"""
|
||||
|
||||
@ -27,8 +28,7 @@ class ChatHistoryManager:
|
||||
"""
|
||||
聊天历史管理器
|
||||
|
||||
使用独立的数据库表存储完整的聊天历史记录
|
||||
复用 checkpoint_manager 的 PostgreSQL 连接池
|
||||
使用共享的 PostgreSQL 连接池存储完整的聊天历史记录
|
||||
"""
|
||||
|
||||
def __init__(self, pool: AsyncConnectionPool):
|
||||
@ -36,7 +36,7 @@ class ChatHistoryManager:
|
||||
初始化聊天历史管理器
|
||||
|
||||
Args:
|
||||
pool: PostgreSQL 连接池
|
||||
pool: PostgreSQL 连接池(来自 DBPoolManager)
|
||||
"""
|
||||
self._pool = pool
|
||||
|
||||
@ -250,15 +250,20 @@ _global_manager: Optional['ChatHistoryManagerWithPool'] = None
|
||||
class ChatHistoryManagerWithPool:
|
||||
"""
|
||||
带连接池的聊天历史管理器单例
|
||||
复用 checkpoint_manager 的连接池
|
||||
使用共享的 PostgreSQL 连接池
|
||||
"""
|
||||
def __init__(self):
|
||||
self._pool: Optional[AsyncConnectionPool] = None
|
||||
self._manager: Optional[ChatHistoryManager] = None
|
||||
self._initialized = False
|
||||
self._closed = False
|
||||
|
||||
async def initialize(self, pool: AsyncConnectionPool) -> None:
|
||||
"""初始化管理器"""
|
||||
"""初始化管理器,使用外部传入的连接池
|
||||
|
||||
Args:
|
||||
pool: AsyncConnectionPool 实例(来自 DBPoolManager)
|
||||
"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
@ -266,15 +271,28 @@ class ChatHistoryManagerWithPool:
|
||||
self._manager = ChatHistoryManager(pool)
|
||||
await self._manager.create_table()
|
||||
self._initialized = True
|
||||
logger.info("ChatHistoryManager initialized successfully")
|
||||
logger.info("ChatHistoryManager initialized successfully (using shared pool)")
|
||||
|
||||
@property
|
||||
def manager(self) -> ChatHistoryManager:
|
||||
"""获取 ChatHistoryManager 实例"""
|
||||
if self._closed:
|
||||
raise RuntimeError("ChatHistoryManager is closed")
|
||||
|
||||
if not self._initialized or not self._manager:
|
||||
raise RuntimeError("ChatHistoryManager not initialized")
|
||||
return self._manager
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭管理器(连接池由 DBPoolManager 管理,这里不关闭)"""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
logger.info("Closing ChatHistoryManager...")
|
||||
self._closed = True
|
||||
self._initialized = False
|
||||
logger.info("ChatHistoryManager closed (pool managed by DBPoolManager)")
|
||||
|
||||
|
||||
def get_chat_history_manager() -> ChatHistoryManagerWithPool:
|
||||
"""获取全局 ChatHistoryManager 单例"""
|
||||
@ -285,6 +303,17 @@ def get_chat_history_manager() -> ChatHistoryManagerWithPool:
|
||||
|
||||
|
||||
async def init_chat_history_manager(pool: AsyncConnectionPool) -> None:
|
||||
"""初始化全局聊天历史管理器"""
|
||||
"""初始化全局聊天历史管理器
|
||||
|
||||
Args:
|
||||
pool: AsyncConnectionPool 实例(来自 DBPoolManager)
|
||||
"""
|
||||
manager = get_chat_history_manager()
|
||||
await manager.initialize(pool)
|
||||
|
||||
|
||||
async def close_chat_history_manager() -> None:
|
||||
"""关闭全局聊天历史管理器"""
|
||||
global _global_manager
|
||||
if _global_manager is not None:
|
||||
await _global_manager.close()
|
||||
|
||||
@ -1,81 +1,55 @@
|
||||
"""
|
||||
全局 PostgreSQL Checkpointer 管理器
|
||||
使用 psycopg_pool 连接池,AsyncPostgresSaver 原生支持
|
||||
使用共享的数据库连接池
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
|
||||
from utils.settings import (
|
||||
CHECKPOINT_DB_URL,
|
||||
CHECKPOINT_POOL_SIZE,
|
||||
CHECKPOINT_CLEANUP_ENABLED,
|
||||
CHECKPOINT_CLEANUP_INTERVAL_HOURS,
|
||||
CHECKPOINT_CLEANUP_INACTIVE_DAYS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
|
||||
class CheckpointerManager:
|
||||
"""
|
||||
全局 Checkpointer 管理器,使用 PostgreSQL 连接池
|
||||
全局 Checkpointer 管理器,使用共享的 PostgreSQL 连接池
|
||||
|
||||
主要功能:
|
||||
1. 使用 psycopg_pool.AsyncConnectionPool 管理连接
|
||||
1. 使用 DBPoolManager 的共享连接池
|
||||
2. AsyncPostgresSaver 原生支持连接池,自动获取/释放连接
|
||||
3. 无需手动归还,避免长请求占用连接
|
||||
4. 基于 SQL 查询的清理机制
|
||||
5. 优雅关闭机制
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._pool: Optional[AsyncConnectionPool] = None
|
||||
self._initialized = False
|
||||
self._closed = False
|
||||
# 清理调度任务
|
||||
self._cleanup_task: Optional[asyncio.Task] = None
|
||||
self._cleanup_stop_event = asyncio.Event()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化连接池"""
|
||||
async def initialize(self, pool: AsyncConnectionPool) -> None:
|
||||
"""初始化 checkpointer,使用外部传入的连接池
|
||||
|
||||
Args:
|
||||
pool: AsyncConnectionPool 实例(来自 DBPoolManager)
|
||||
"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Initializing PostgreSQL checkpointer pool: "
|
||||
f"URL={CHECKPOINT_DB_URL}, size={CHECKPOINT_POOL_SIZE}"
|
||||
)
|
||||
self._pool = pool
|
||||
|
||||
logger.info("Initializing PostgreSQL checkpointer (using shared connection pool)...")
|
||||
|
||||
try:
|
||||
# 创建 psycopg 连接池
|
||||
self._pool = AsyncConnectionPool(
|
||||
CHECKPOINT_DB_URL,
|
||||
min_size=1,
|
||||
max_size=CHECKPOINT_POOL_SIZE,
|
||||
open=False,
|
||||
)
|
||||
|
||||
# 打开连接池
|
||||
await self._pool.open()
|
||||
|
||||
# 创建表结构(需要 autocommit 模式来执行 CREATE INDEX CONCURRENTLY)
|
||||
async with self._pool.connection() as conn:
|
||||
await conn.set_autocommit(True)
|
||||
checkpointer = AsyncPostgresSaver(conn=conn)
|
||||
await checkpointer.setup()
|
||||
|
||||
# 初始化 ChatHistoryManager(复用同一个连接池)
|
||||
from .chat_history_manager import init_chat_history_manager
|
||||
await init_chat_history_manager(self._pool)
|
||||
|
||||
self._initialized = True
|
||||
logger.info("PostgreSQL checkpointer pool initialized successfully")
|
||||
logger.info("PostgreSQL checkpointer initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize PostgreSQL checkpointer pool: {e}")
|
||||
logger.error(f"Failed to initialize PostgreSQL checkpointer: {e}")
|
||||
raise
|
||||
|
||||
@property
|
||||
@ -89,170 +63,26 @@ class CheckpointerManager:
|
||||
|
||||
return AsyncPostgresSaver(conn=self._pool)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭连接池"""
|
||||
@property
|
||||
def pool(self) -> AsyncConnectionPool:
|
||||
"""获取连接池(用于共享给其他管理器)"""
|
||||
if self._closed:
|
||||
return
|
||||
raise RuntimeError("CheckpointerManager is closed")
|
||||
|
||||
# 停止清理任务
|
||||
if self._cleanup_task is not None:
|
||||
self._cleanup_stop_event.set()
|
||||
try:
|
||||
self._cleanup_task.cancel()
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._cleanup_task = None
|
||||
if not self._initialized:
|
||||
raise RuntimeError("CheckpointerManager not initialized, call initialize() first")
|
||||
|
||||
return self._pool
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 Checkpointer(连接池由 DBPoolManager 管理,这里不关闭)"""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
logger.info("Closing CheckpointerManager...")
|
||||
|
||||
if self._pool is not None:
|
||||
await self._pool.close()
|
||||
self._pool = None
|
||||
|
||||
self._closed = True
|
||||
self._initialized = False
|
||||
logger.info("CheckpointerManager closed")
|
||||
|
||||
# ============================================================
|
||||
# Checkpoint 清理方法(基于 PostgreSQL SQL 查询)
|
||||
# ============================================================
|
||||
|
||||
async def cleanup_old_dbs(self, inactive_days: int = None) -> dict:
|
||||
"""
|
||||
清理旧的 checkpoint 记录
|
||||
|
||||
删除 N 天未活动的 thread 的所有 checkpoint。
|
||||
|
||||
Args:
|
||||
inactive_days: 删除 N 天未活动的 thread,默认使用配置值
|
||||
|
||||
Returns:
|
||||
Dict: 清理统计信息
|
||||
"""
|
||||
if inactive_days is None:
|
||||
inactive_days = CHECKPOINT_CLEANUP_INACTIVE_DAYS
|
||||
|
||||
logger.info(f"Starting checkpoint cleanup: removing threads inactive for {inactive_days} days")
|
||||
|
||||
if self._pool is None:
|
||||
logger.warning("Connection pool not initialized, skipping cleanup")
|
||||
return {"deleted": 0, "inactive_days": inactive_days}
|
||||
|
||||
try:
|
||||
# 从池中获取连接执行清理
|
||||
async with self._pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
# 查找不活跃的 thread
|
||||
query_find_inactive = """
|
||||
SELECT DISTINCT thread_id
|
||||
FROM checkpoints
|
||||
WHERE (checkpoint->>'ts')::timestamp < NOW() - INTERVAL '%s days'
|
||||
"""
|
||||
|
||||
await cursor.execute(query_find_inactive, (inactive_days,))
|
||||
inactive_threads = await cursor.fetchall()
|
||||
inactive_thread_ids = [row[0] for row in inactive_threads]
|
||||
|
||||
if not inactive_thread_ids:
|
||||
logger.info("No inactive threads found")
|
||||
return {"deleted": 0, "threads_deleted": 0, "inactive_days": inactive_days}
|
||||
|
||||
# 删除不活跃 thread 的 checkpoints
|
||||
placeholders = ','.join(['%s'] * len(inactive_thread_ids))
|
||||
query_delete_checkpoints = f"""
|
||||
DELETE FROM checkpoints
|
||||
WHERE thread_id IN ({placeholders})
|
||||
"""
|
||||
|
||||
await cursor.execute(query_delete_checkpoints, inactive_thread_ids)
|
||||
deleted_checkpoints = cursor.rowcount
|
||||
|
||||
# 同时清理 checkpoint_writes
|
||||
query_delete_writes = f"""
|
||||
DELETE FROM checkpoint_writes
|
||||
WHERE thread_id IN ({placeholders})
|
||||
"""
|
||||
await cursor.execute(query_delete_writes, inactive_thread_ids)
|
||||
deleted_writes = cursor.rowcount
|
||||
|
||||
logger.info(
|
||||
f"Checkpoint cleanup completed: "
|
||||
f"deleted {deleted_checkpoints} checkpoints, {deleted_writes} writes "
|
||||
f"from {len(inactive_thread_ids)} inactive threads"
|
||||
)
|
||||
|
||||
return {
|
||||
"deleted": deleted_checkpoints + deleted_writes,
|
||||
"threads_deleted": len(inactive_thread_ids),
|
||||
"inactive_days": inactive_days
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during checkpoint cleanup: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return {"deleted": 0, "error": str(e), "inactive_days": inactive_days}
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""后台清理循环"""
|
||||
interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600
|
||||
|
||||
logger.info(
|
||||
f"Checkpoint cleanup scheduler started: "
|
||||
f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, "
|
||||
f"inactive_days={CHECKPOINT_CLEANUP_INACTIVE_DAYS}"
|
||||
)
|
||||
|
||||
while not self._cleanup_stop_event.is_set():
|
||||
try:
|
||||
# 等待间隔时间或停止信号
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._cleanup_stop_event.wait(),
|
||||
timeout=interval_seconds
|
||||
)
|
||||
# 收到停止信号,退出循环
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# 超时,执行清理任务
|
||||
pass
|
||||
|
||||
if self._cleanup_stop_event.is_set():
|
||||
break
|
||||
|
||||
# 执行清理
|
||||
await self.cleanup_old_dbs()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Cleanup task cancelled")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cleanup loop: {e}")
|
||||
|
||||
logger.info("Checkpoint cleanup scheduler stopped")
|
||||
|
||||
def start_cleanup_scheduler(self):
|
||||
"""启动后台定时清理任务"""
|
||||
if not CHECKPOINT_CLEANUP_ENABLED:
|
||||
logger.info("Checkpoint cleanup is disabled")
|
||||
return
|
||||
|
||||
if self._cleanup_task is not None and not self._cleanup_task.done():
|
||||
logger.warning("Cleanup scheduler is already running")
|
||||
return
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
self._cleanup_task = loop.create_task(self._cleanup_loop())
|
||||
logger.info("Cleanup scheduler task created")
|
||||
logger.info("CheckpointerManager closed (pool managed by DBPoolManager)")
|
||||
|
||||
|
||||
# 全局单例
|
||||
@ -267,10 +97,14 @@ def get_checkpointer_manager() -> CheckpointerManager:
|
||||
return _global_manager
|
||||
|
||||
|
||||
async def init_global_checkpointer() -> None:
|
||||
"""初始化全局 checkpointer 管理器"""
|
||||
async def init_global_checkpointer(pool: AsyncConnectionPool) -> None:
|
||||
"""初始化全局 checkpointer 管理器
|
||||
|
||||
Args:
|
||||
pool: AsyncConnectionPool 实例(来自 DBPoolManager)
|
||||
"""
|
||||
manager = get_checkpointer_manager()
|
||||
await manager.initialize()
|
||||
await manager.initialize(pool)
|
||||
|
||||
|
||||
async def close_global_checkpointer() -> None:
|
||||
|
||||
265
agent/db_pool_manager.py
Normal file
265
agent/db_pool_manager.py
Normal file
@ -0,0 +1,265 @@
|
||||
"""
|
||||
全局 PostgreSQL 连接池管理器
|
||||
被 checkpoint、chat_history 共享使用
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
from utils.settings import (
|
||||
CHECKPOINT_DB_URL,
|
||||
CHECKPOINT_POOL_SIZE,
|
||||
CHECKPOINT_CLEANUP_ENABLED,
|
||||
CHECKPOINT_CLEANUP_INTERVAL_HOURS,
|
||||
CHECKPOINT_CLEANUP_INACTIVE_DAYS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
|
||||
class DBPoolManager:
|
||||
"""
|
||||
全局 PostgreSQL 连接池管理器
|
||||
|
||||
主要功能:
|
||||
1. 使用 psycopg_pool.AsyncConnectionPool 管理连接
|
||||
2. 被 CheckpointerManager、ChatHistoryManager 共享
|
||||
3. 自动清理旧的 checkpoint 数据
|
||||
4. 优雅关闭机制
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._pool: Optional[AsyncConnectionPool] = None
|
||||
self._initialized = False
|
||||
self._closed = False
|
||||
# 清理调度任务
|
||||
self._cleanup_task: Optional[asyncio.Task] = None
|
||||
self._cleanup_stop_event = asyncio.Event()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化连接池"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Initializing PostgreSQL connection pool: "
|
||||
f"URL={CHECKPOINT_DB_URL}, size={CHECKPOINT_POOL_SIZE}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 创建 psycopg 连接池
|
||||
self._pool = AsyncConnectionPool(
|
||||
CHECKPOINT_DB_URL,
|
||||
min_size=1,
|
||||
max_size=CHECKPOINT_POOL_SIZE,
|
||||
open=False,
|
||||
)
|
||||
|
||||
# 打开连接池
|
||||
await self._pool.open()
|
||||
|
||||
self._initialized = True
|
||||
logger.info("PostgreSQL connection pool initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize PostgreSQL connection pool: {e}")
|
||||
raise
|
||||
|
||||
@property
|
||||
def pool(self) -> AsyncConnectionPool:
|
||||
"""获取连接池"""
|
||||
if self._closed:
|
||||
raise RuntimeError("DBPoolManager is closed")
|
||||
|
||||
if not self._initialized:
|
||||
raise RuntimeError("DBPoolManager not initialized, call initialize() first")
|
||||
|
||||
return self._pool
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭连接池"""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
# 停止清理任务
|
||||
if self._cleanup_task is not None:
|
||||
self._cleanup_stop_event.set()
|
||||
try:
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._cleanup_task = None
|
||||
|
||||
logger.info("Closing DBPoolManager...")
|
||||
|
||||
if self._pool is not None:
|
||||
await self._pool.close()
|
||||
self._pool = None
|
||||
|
||||
self._closed = True
|
||||
self._initialized = False
|
||||
logger.info("DBPoolManager closed")
|
||||
|
||||
# ============================================================
|
||||
# Checkpoint 清理方法(基于 PostgreSQL SQL 查询)
|
||||
# ============================================================
|
||||
|
||||
async def cleanup_old_checkpoints(self, inactive_days: int = None) -> dict:
|
||||
"""
|
||||
清理旧的 checkpoint 记录
|
||||
|
||||
删除 N 天未活动的 thread 的所有 checkpoint。
|
||||
|
||||
Args:
|
||||
inactive_days: 删除 N 天未活动的 thread,默认使用配置值
|
||||
|
||||
Returns:
|
||||
Dict: 清理统计信息
|
||||
"""
|
||||
if inactive_days is None:
|
||||
inactive_days = CHECKPOINT_CLEANUP_INACTIVE_DAYS
|
||||
|
||||
logger.info(f"Starting checkpoint cleanup: removing threads inactive for {inactive_days} days")
|
||||
|
||||
if self._pool is None:
|
||||
logger.warning("Connection pool not initialized, skipping cleanup")
|
||||
return {"deleted": 0, "inactive_days": inactive_days}
|
||||
|
||||
try:
|
||||
# 从池中获取连接执行清理
|
||||
async with self._pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
# 查找不活跃的 thread
|
||||
query_find_inactive = """
|
||||
SELECT DISTINCT thread_id
|
||||
FROM checkpoints
|
||||
WHERE (checkpoint->>'ts')::timestamp < NOW() - INTERVAL '%s days'
|
||||
"""
|
||||
|
||||
await cursor.execute(query_find_inactive, (inactive_days,))
|
||||
inactive_threads = await cursor.fetchall()
|
||||
inactive_thread_ids = [row[0] for row in inactive_threads]
|
||||
|
||||
if not inactive_thread_ids:
|
||||
logger.info("No inactive threads found")
|
||||
return {"deleted": 0, "threads_deleted": 0, "inactive_days": inactive_days}
|
||||
|
||||
# 删除不活跃 thread 的 checkpoints
|
||||
placeholders = ','.join(['%s'] * len(inactive_thread_ids))
|
||||
query_delete_checkpoints = f"""
|
||||
DELETE FROM checkpoints
|
||||
WHERE thread_id IN ({placeholders})
|
||||
"""
|
||||
|
||||
await cursor.execute(query_delete_checkpoints, inactive_thread_ids)
|
||||
deleted_checkpoints = cursor.rowcount
|
||||
|
||||
# 同时清理 checkpoint_writes
|
||||
query_delete_writes = f"""
|
||||
DELETE FROM checkpoint_writes
|
||||
WHERE thread_id IN ({placeholders})
|
||||
"""
|
||||
await cursor.execute(query_delete_writes, inactive_thread_ids)
|
||||
deleted_writes = cursor.rowcount
|
||||
|
||||
logger.info(
|
||||
f"Checkpoint cleanup completed: "
|
||||
f"deleted {deleted_checkpoints} checkpoints, {deleted_writes} writes "
|
||||
f"from {len(inactive_thread_ids)} inactive threads"
|
||||
)
|
||||
|
||||
return {
|
||||
"deleted": deleted_checkpoints + deleted_writes,
|
||||
"threads_deleted": len(inactive_thread_ids),
|
||||
"inactive_days": inactive_days
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during checkpoint cleanup: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return {"deleted": 0, "error": str(e), "inactive_days": inactive_days}
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""后台清理循环"""
|
||||
interval_seconds = CHECKPOINT_CLEANUP_INTERVAL_HOURS * 3600
|
||||
|
||||
logger.info(
|
||||
f"Checkpoint cleanup scheduler started: "
|
||||
f"interval={CHECKPOINT_CLEANUP_INTERVAL_HOURS}h, "
|
||||
f"inactive_days={CHECKPOINT_CLEANUP_INACTIVE_DAYS}"
|
||||
)
|
||||
|
||||
while not self._cleanup_stop_event.is_set():
|
||||
try:
|
||||
# 等待间隔时间或停止信号
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._cleanup_stop_event.wait(),
|
||||
timeout=interval_seconds
|
||||
)
|
||||
# 收到停止信号,退出循环
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# 超时,执行清理任务
|
||||
pass
|
||||
|
||||
if self._cleanup_stop_event.is_set():
|
||||
break
|
||||
|
||||
# 执行清理
|
||||
await self.cleanup_old_checkpoints()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Cleanup task cancelled")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cleanup loop: {e}")
|
||||
|
||||
logger.info("Checkpoint cleanup scheduler stopped")
|
||||
|
||||
def start_cleanup_scheduler(self):
|
||||
"""启动后台定时清理任务"""
|
||||
if not CHECKPOINT_CLEANUP_ENABLED:
|
||||
logger.info("Checkpoint cleanup is disabled")
|
||||
return
|
||||
|
||||
if self._cleanup_task is not None and not self._cleanup_task.done():
|
||||
logger.warning("Cleanup scheduler is already running")
|
||||
return
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
self._cleanup_task = loop.create_task(self._cleanup_loop())
|
||||
logger.info("Cleanup scheduler task created")
|
||||
|
||||
|
||||
# 全局单例
|
||||
_global_manager: Optional[DBPoolManager] = None
|
||||
|
||||
|
||||
def get_db_pool_manager() -> DBPoolManager:
|
||||
"""获取全局 DBPoolManager 单例"""
|
||||
global _global_manager
|
||||
if _global_manager is None:
|
||||
_global_manager = DBPoolManager()
|
||||
return _global_manager
|
||||
|
||||
|
||||
async def init_global_db_pool() -> DBPoolManager:
|
||||
"""初始化全局数据库连接池"""
|
||||
manager = get_db_pool_manager()
|
||||
await manager.initialize()
|
||||
return manager
|
||||
|
||||
|
||||
async def close_global_db_pool() -> None:
|
||||
"""关闭全局数据库连接池"""
|
||||
global _global_manager
|
||||
if _global_manager is not None:
|
||||
await _global_manager.close()
|
||||
@ -23,6 +23,7 @@ from agent.agent_config import AgentConfig
|
||||
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
||||
from agent.agent_memory_cache import get_memory_cache_manager
|
||||
from .checkpoint_utils import prepare_checkpoint_message
|
||||
from .checkpoint_manager import get_checkpointer_manager
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langchain.tools import BaseTool
|
||||
@ -181,10 +182,13 @@ async def init_agent(config: AgentConfig):
|
||||
|
||||
# 从连接池获取 checkpointer
|
||||
if config.session_id:
|
||||
from .checkpoint_manager import get_checkpointer_manager
|
||||
manager = get_checkpointer_manager()
|
||||
checkpointer = manager.checkpointer
|
||||
await prepare_checkpoint_message(config, checkpointer)
|
||||
try:
|
||||
manager = get_checkpointer_manager()
|
||||
checkpointer = manager.checkpointer
|
||||
if checkpointer:
|
||||
await prepare_checkpoint_message(config, checkpointer)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load checkpointer: {e}")
|
||||
|
||||
|
||||
if config.robot_type == "deep_agent":
|
||||
@ -328,6 +332,7 @@ def create_custom_cli_agent(
|
||||
middleware: list[AgentMiddleware] = [],
|
||||
workspace_root: str | None = None,
|
||||
checkpointer: Checkpointer | None = None,
|
||||
store: BaseStore | None = None,
|
||||
) -> tuple[Pregel, CompositeBackend]:
|
||||
"""Create a CLI-configured agent with custom workspace_root for shell commands.
|
||||
|
||||
@ -350,6 +355,8 @@ def create_custom_cli_agent(
|
||||
enable_skills: Enable SkillsMiddleware for custom agent skills
|
||||
enable_shell: Enable ShellMiddleware for local shell execution (only in local mode)
|
||||
workspace_root: Working directory for shell commands. If None, uses Path.cwd().
|
||||
checkpointer: Optional checkpointer for persisting conversation state
|
||||
store: Optional BaseStore for persisting user preferences and agent memory
|
||||
|
||||
Returns:
|
||||
2-tuple of (agent_graph, composite_backend)
|
||||
@ -464,5 +471,6 @@ def create_custom_cli_agent(
|
||||
middleware=agent_middleware,
|
||||
interrupt_on=interrupt_on,
|
||||
checkpointer=checkpointer,
|
||||
store=store,
|
||||
).with_config(config)
|
||||
return agent, composite_backend
|
||||
return agent, composite_backend
|
||||
|
||||
@ -15,6 +15,8 @@ import logging
|
||||
|
||||
from utils.log_util.logger import init_with_fastapi
|
||||
|
||||
# Initialize logger
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
# Import route modules
|
||||
from routes import chat, files, projects, system, skill_manager
|
||||
@ -25,44 +27,61 @@ async def lifespan(app: FastAPI):
|
||||
"""FastAPI 应用生命周期管理"""
|
||||
# 启动时初始化
|
||||
logger.info("Starting up...")
|
||||
from agent.db_pool_manager import (
|
||||
init_global_db_pool,
|
||||
get_db_pool_manager,
|
||||
close_global_db_pool
|
||||
)
|
||||
from agent.checkpoint_manager import (
|
||||
init_global_checkpointer,
|
||||
get_checkpointer_manager,
|
||||
close_global_checkpointer
|
||||
)
|
||||
from agent.chat_history_manager import (
|
||||
init_chat_history_manager,
|
||||
close_chat_history_manager
|
||||
)
|
||||
from utils.settings import CHECKPOINT_CLEANUP_ENABLED
|
||||
|
||||
await init_global_checkpointer()
|
||||
# 1. 初始化共享的数据库连接池
|
||||
db_pool_manager = await init_global_db_pool()
|
||||
logger.info("Global DB pool initialized")
|
||||
|
||||
# 2. 初始化 checkpoint (使用共享连接池)
|
||||
await init_global_checkpointer(db_pool_manager.pool)
|
||||
logger.info("Global checkpointer initialized")
|
||||
|
||||
# 启动 checkpoint 清理调度器
|
||||
# 3. 初始化 chat_history (使用共享连接池)
|
||||
await init_chat_history_manager(db_pool_manager.pool)
|
||||
logger.info("Chat history manager initialized")
|
||||
|
||||
# 4. 启动 checkpoint 清理调度器
|
||||
if CHECKPOINT_CLEANUP_ENABLED:
|
||||
manager = get_checkpointer_manager()
|
||||
# 启动时立即执行一次清理
|
||||
try:
|
||||
result = await manager.cleanup_old_dbs()
|
||||
result = await db_pool_manager.cleanup_old_checkpoints()
|
||||
logger.info(f"Startup cleanup completed: {result}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Startup cleanup failed (non-fatal): {e}")
|
||||
# 启动定时清理调度器
|
||||
manager.start_cleanup_scheduler()
|
||||
db_pool_manager.start_cleanup_scheduler()
|
||||
logger.info("Checkpoint cleanup scheduler started")
|
||||
|
||||
yield
|
||||
|
||||
# 关闭时清理
|
||||
# 关闭时清理(按相反顺序)
|
||||
logger.info("Shutting down...")
|
||||
await close_chat_history_manager()
|
||||
logger.info("Chat history manager closed")
|
||||
await close_global_checkpointer()
|
||||
logger.info("Global checkpointer closed")
|
||||
await close_global_db_pool()
|
||||
logger.info("Global DB pool closed")
|
||||
|
||||
|
||||
app = FastAPI(title="Database Assistant API", version="1.0.0", lifespan=lifespan)
|
||||
|
||||
init_with_fastapi(app)
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
|
||||
# 挂载public文件夹为静态文件服务
|
||||
app.mount("/public", StaticFiles(directory="public"), name="static")
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user