diff --git a/agent/checkpoint_manager.py b/agent/checkpoint_manager.py index d8afc34..4962d45 100644 --- a/agent/checkpoint_manager.py +++ b/agent/checkpoint_manager.py @@ -1,12 +1,14 @@ """ 全局 SQLite Checkpointer 管理器 解决高并发场景下的数据库锁定问题 + +每个 session 使用独立的数据库文件,避免并发锁竞争 """ import asyncio import logging import os -from datetime import datetime, timedelta, timezone -from typing import Optional, List, Dict, Any +import time +from typing import Optional, Dict, Any, Tuple import aiosqlite from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver @@ -15,7 +17,6 @@ from utils.settings import ( CHECKPOINT_DB_PATH, CHECKPOINT_WAL_MODE, CHECKPOINT_BUSY_TIMEOUT, - CHECKPOINT_POOL_SIZE, CHECKPOINT_CLEANUP_ENABLED, CHECKPOINT_CLEANUP_INTERVAL_HOURS, CHECKPOINT_CLEANUP_OLDER_THAN_DAYS, @@ -23,60 +24,71 @@ from utils.settings import ( logger = logging.getLogger('app') +# 每个 session 的连接池大小(单个 session 串行处理,1 个连接即可) +POOL_SIZE_PER_SESSION = 1 + class CheckpointerManager: """ - 全局 Checkpointer 管理器,使用连接池复用 SQLite 连接 + 全局 Checkpointer 管理器,按 session_id 分离数据库文件 主要功能: - 1. 全局单例连接管理,避免每次请求创建新连接 - 2. 预配置 WAL 模式和 busy_timeout - 3. 连接池支持高并发访问 - 4. 优雅关闭机制 + 1. 每个 session_id 独立的数据库文件和连接池 + 2. 按需创建连接池,不用的 session 不占用资源 + 3. 预配置 WAL 模式和 busy_timeout + 4. 基于文件修改时间的简单清理机制 + 5. 优雅关闭机制 """ def __init__(self): - self._pool: asyncio.Queue[AsyncSqliteSaver] = asyncio.Queue() - self._lock = asyncio.Lock() - self._initialized = False + # 每个 (bot_id, session_id) 一个连接池 + self._pools: Dict[Tuple[str, str], asyncio.Queue[AsyncSqliteSaver]] = {} + # 每个 session 的初始化锁 + self._locks: Dict[Tuple[str, str], asyncio.Lock] = {} + # 全局锁,用于保护 pools 和 locks 字典的访问 + self._global_lock = asyncio.Lock() self._closed = False - self._pool_size = CHECKPOINT_POOL_SIZE - self._db_path = CHECKPOINT_DB_PATH # 清理调度任务 self._cleanup_task: Optional[asyncio.Task] = None self._cleanup_stop_event = asyncio.Event() - async def initialize(self) -> None: - """初始化连接池""" - if self._initialized: + def _get_db_path(self, bot_id: str, session_id: str) -> str: + """获取指定 session 的数据库文件路径""" + return os.path.join(CHECKPOINT_DB_PATH, bot_id, session_id, "checkpoints.db") + + def _get_pool_key(self, bot_id: str, session_id: str) -> Tuple[str, str]: + """获取连接池的键""" + return (bot_id, session_id) + + async def _initialize_session_pool(self, bot_id: str, session_id: str) -> None: + """初始化指定 session 的连接池""" + pool_key = self._get_pool_key(bot_id, session_id) + if pool_key in self._pools: return - async with self._lock: - if self._initialized: - return + logger.info(f"Initializing checkpointer pool for bot_id={bot_id}, session_id={session_id}") - logger.info(f"Initializing CheckpointerManager with pool_size={self._pool_size}") + db_path = self._get_db_path(bot_id, session_id) + os.makedirs(os.path.dirname(db_path), exist_ok=True) - # 确保目录存在 - os.makedirs(os.path.dirname(self._db_path), exist_ok=True) + pool = asyncio.Queue() + for i in range(POOL_SIZE_PER_SESSION): + try: + conn = await self._create_configured_connection(db_path) + checkpointer = AsyncSqliteSaver(conn=conn) + # 预先调用 setup 确保表结构已创建 + await checkpointer.setup() + await pool.put(checkpointer) + logger.debug(f"Created checkpointer connection {i+1}/{POOL_SIZE_PER_SESSION} for session={session_id}") + except Exception as e: + logger.error(f"Failed to create checkpointer connection {i+1} for session={session_id}: {e}") + raise - # 创建连接池 - for i in range(self._pool_size): - try: - conn = await self._create_configured_connection() - checkpointer = AsyncSqliteSaver(conn=conn) - # 预先调用 setup 确保表结构已创建 - await checkpointer.setup() - await self._pool.put(checkpointer) - logger.debug(f"Created checkpointer connection {i+1}/{self._pool_size}") - except Exception as e: - logger.error(f"Failed to create checkpointer connection {i+1}: {e}") - raise + self._pools[pool_key] = pool + self._locks[pool_key] = asyncio.Lock() + logger.info(f"Checkpointer pool initialized for bot_id={bot_id}, session_id={session_id}") - self._initialized = True - logger.info("CheckpointerManager initialized successfully") - - async def _create_configured_connection(self) -> aiosqlite.Connection: + async def _create_configured_connection(self, db_path: str) -> aiosqlite.Connection: """ 创建已配置的 SQLite 连接 @@ -85,7 +97,7 @@ class CheckpointerManager: 2. busy_timeout - 等待锁定的最长时间 3. 其他优化参数 """ - conn = aiosqlite.connect(self._db_path) + conn = aiosqlite.connect(db_path) # 等待连接建立 await conn.__aenter__() @@ -98,43 +110,87 @@ class CheckpointerManager: await conn.execute("PRAGMA journal_mode = WAL") await conn.execute("PRAGMA synchronous = NORMAL") # WAL 模式下的优化配置 - await conn.execute("PRAGMA wal_autocheckpoint = 1000") + await conn.execute("PRAGMA wal_autocheckpoint = 10000") # 增加到 10000 await conn.execute("PRAGMA cache_size = -64000") # 64MB 缓存 await conn.execute("PRAGMA temp_store = MEMORY") + await conn.execute("PRAGMA journal_size_limit = 1048576") # 1MB await conn.commit() return conn - async def acquire_for_agent(self) -> AsyncSqliteSaver: + async def initialize(self) -> None: + """初始化管理器(不再预创建连接池,改为按需创建)""" + logger.info("CheckpointerManager initialized (pools will be created on-demand)") + + async def acquire_for_agent(self, bot_id: str, session_id: str) -> AsyncSqliteSaver: """ - 为 agent 获取 checkpointer + 获取指定 session 的 checkpointer 注意:此方法获取的 checkpointer 需要手动归还 使用 return_to_pool() 方法归还 + Args: + bot_id: 机器人 ID + session_id: 会话 ID + Returns: AsyncSqliteSaver 实例 """ - if not self._initialized: - raise RuntimeError("CheckpointerManager not initialized. Call initialize() first.") + if self._closed: + raise RuntimeError("CheckpointerManager is closed") - checkpointer = await self._pool.get() - logger.debug(f"Acquired checkpointer from pool, remaining: {self._pool.qsize()}") - return checkpointer + pool_key = self._get_pool_key(bot_id, session_id) + async with self._global_lock: + if pool_key not in self._pools: + await self._initialize_session_pool(bot_id, session_id) - async def return_to_pool(self, checkpointer: AsyncSqliteSaver) -> None: + # 获取该 session 的锁,确保连接池操作线程安全 + async with self._locks[pool_key]: + checkpointer = await self._pools[pool_key].get() + logger.debug(f"Acquired checkpointer for bot_id={bot_id}, session_id={session_id}, remaining: {self._pools[pool_key].qsize()}") + return checkpointer + + async def return_to_pool(self, bot_id: str, session_id: str, checkpointer: AsyncSqliteSaver) -> None: """ - 归还 checkpointer 到池 + 归还 checkpointer 到对应 session 的池 Args: + bot_id: 机器人 ID + session_id: 会话 ID checkpointer: 要归还的 checkpointer 实例 """ - await self._pool.put(checkpointer) - logger.debug(f"Returned checkpointer to pool, remaining: {self._pool.qsize()}") + pool_key = self._get_pool_key(bot_id, session_id) + if pool_key in self._pools: + async with self._locks[pool_key]: + await self._pools[pool_key].put(checkpointer) + logger.debug(f"Returned checkpointer for bot_id={bot_id}, session_id={session_id}, remaining: {self._pools[pool_key].qsize()}") + + async def _close_session_pool(self, bot_id: str, session_id: str) -> None: + """关闭指定 session 的连接池""" + pool_key = self._get_pool_key(bot_id, session_id) + if pool_key not in self._pools: + return + + logger.info(f"Closing checkpointer pool for bot_id={bot_id}, session_id={session_id}") + + pool = self._pools[pool_key] + while not pool.empty(): + try: + checkpointer = pool.get_nowait() + if checkpointer.conn: + await checkpointer.conn.close() + except asyncio.QueueEmpty: + break + + del self._pools[pool_key] + if pool_key in self._locks: + del self._locks[pool_key] + + logger.info(f"Checkpointer pool closed for bot_id={bot_id}, session_id={session_id}") async def close(self) -> None: - """关闭所有连接""" + """关闭所有连接池""" if self._closed: return @@ -148,146 +204,112 @@ class CheckpointerManager: pass self._cleanup_task = None - async with self._lock: + async with self._global_lock: if self._closed: return logger.info("Closing CheckpointerManager...") - # 清空池并关闭所有连接 - while not self._pool.empty(): - try: - checkpointer = self._pool.get_nowait() - if checkpointer.conn: - await checkpointer.conn.close() - except asyncio.QueueEmpty: - break + # 关闭所有 session 的连接池 + pool_keys = list(self._pools.keys()) + for bot_id, session_id in pool_keys: + await self._close_session_pool(bot_id, session_id) self._closed = True - self._initialized = False logger.info("CheckpointerManager closed") def get_pool_stats(self) -> dict: """获取连接池状态统计""" return { - "db_path": self._db_path, - "pool_size": self._pool_size, - "available_connections": self._pool.qsize(), - "initialized": self._initialized, + "session_count": len(self._pools), + "pools": { + f"{bot_id}/{session_id}": { + "available": pool.qsize(), + "pool_size": POOL_SIZE_PER_SESSION + } + for (bot_id, session_id), pool in self._pools.items() + }, "closed": self._closed } # ============================================================ - # Checkpoint 清理方法 + # Checkpoint 清理方法(基于文件修改时间) # ============================================================ - async def get_all_thread_ids(self) -> List[str]: + async def cleanup_old_dbs(self, older_than_days: int = None) -> Dict[str, Any]: """ - 获取数据库中所有唯一的 thread_id - - Returns: - List[str]: 所有 thread_id 列表 - """ - if not self._initialized: - return [] - - conn = aiosqlite.connect(self._db_path) - await conn.__aenter__() - - try: - cursor = await conn.execute( - "SELECT DISTINCT thread_id FROM checkpoints" - ) - rows = await cursor.fetchall() - return [row[0] for row in rows] - finally: - await conn.close() - - async def get_thread_last_activity(self, thread_id: str) -> Optional[datetime]: - """ - 获取指定 thread 的最后活动时间 - - 通过查询该 thread 最新的 checkpoint 中的 ts 字段获取时间 - - Args: - thread_id: 线程ID - - Returns: - datetime: 最后活动时间,如果找不到则返回 None - """ - if not self._initialized: - return None - - checkpointer = await self.acquire_for_agent() - - try: - config = {"configurable": {"thread_id": thread_id}} - result = checkpointer.alist(config=config, limit=1) - - last_checkpoint = None - async for item in result: - last_checkpoint = item - break - - if last_checkpoint and last_checkpoint.checkpoint: - ts_str = last_checkpoint.checkpoint.get("ts") - if ts_str: - # 解析 ISO 格式时间戳 - return datetime.fromisoformat(ts_str.replace("Z", "+00:00")) - except Exception as e: - logger.warning(f"Error getting last activity for thread {thread_id}: {e}") - finally: - await self.return_to_pool(checkpointer) - - return None - - async def cleanup_old_threads(self, older_than_days: int = None) -> Dict[str, Any]: - """ - 清理超过指定天数未活动的 thread + 根据数据库文件的修改时间清理旧数据库文件 Args: older_than_days: 清理多少天前的数据,默认使用配置值 Returns: Dict: 清理统计信息 - - threads_deleted: 删除的 thread 数量 - - threads_scanned: 扫描的 thread 总数 - - cutoff_time: 截止时间 + - deleted: 删除的 session 目录数量 + - scanned: 扫描的 session 目录数量 + - cutoff_time: 截止时间戳 """ if older_than_days is None: older_than_days = CHECKPOINT_CLEANUP_OLDER_THAN_DAYS - # 使用带时区的时间,避免比较时出错 - cutoff_time = datetime.now(timezone.utc) - timedelta(days=older_than_days) - logger.info(f"Starting checkpoint cleanup: removing threads inactive since {cutoff_time.isoformat()}") + cutoff_time = time.time() - older_than_days * 86400 + logger.info(f"Starting checkpoint cleanup: removing db files not modified since {cutoff_time}") - all_thread_ids = await self.get_all_thread_ids() - threads_deleted = 0 - threads_scanned = len(all_thread_ids) + db_dir = CHECKPOINT_DB_PATH + deleted_count = 0 + scanned_count = 0 - checkpointer = await self.acquire_for_agent() + if not os.path.exists(db_dir): + logger.info(f"Checkpoint directory does not exist: {db_dir}") + return {"deleted": 0, "scanned": 0, "cutoff_time": cutoff_time} - try: - for thread_id in all_thread_ids: + # 遍历 bot_id 目录 + for bot_id in os.listdir(db_dir): + bot_path = os.path.join(db_dir, bot_id) + # 跳过非目录文件 + if not os.path.isdir(bot_path): + continue + + # 遍历 session_id 目录 + for session_id in os.listdir(bot_path): + session_path = os.path.join(bot_path, session_id) + if not os.path.isdir(session_path): + continue + + db_file = os.path.join(session_path, "checkpoints.db") + if not os.path.exists(db_file): + continue + + scanned_count += 1 + mtime = os.path.getmtime(db_file) + + if mtime < cutoff_time: + # 关闭该 session 的连接池(如果有) + await self._close_session_pool(bot_id, session_id) + + # 删除整个 session 目录 + try: + import shutil + shutil.rmtree(session_path) + deleted_count += 1 + logger.info(f"Deleted old checkpoint session: {bot_id}/{session_id}/ (last modified: {mtime})") + except Exception as e: + logger.warning(f"Failed to delete {session_path}: {e}") + + # 清理空的 bot_id 目录 + for bot_id in os.listdir(db_dir): + bot_path = os.path.join(db_dir, bot_id) + if os.path.isdir(bot_path) and not os.listdir(bot_path): try: - last_activity = await self.get_thread_last_activity(thread_id) - - if last_activity and last_activity < cutoff_time: - # 删除旧 thread - config = {"configurable": {"thread_id": thread_id}} - await checkpointer.adelete_thread(config) - threads_deleted += 1 - logger.debug(f"Deleted old thread: {thread_id} (last activity: {last_activity.isoformat()})") - except Exception as e: - logger.warning(f"Error processing thread {thread_id}: {e}") - - finally: - await self.return_to_pool(checkpointer) + os.rmdir(bot_path) + logger.debug(f"Removed empty bot directory: {bot_id}/") + except Exception: + pass result = { - "threads_deleted": threads_deleted, - "threads_scanned": threads_scanned, - "cutoff_time": cutoff_time.isoformat(), + "deleted": deleted_count, + "scanned": scanned_count, + "cutoff_time": cutoff_time, "older_than_days": older_than_days } @@ -322,7 +344,7 @@ class CheckpointerManager: break # 执行清理 - await self.cleanup_old_threads() + await self.cleanup_old_dbs() except asyncio.CancelledError: logger.info("Cleanup task cancelled") diff --git a/agent/deep_assistant.py b/agent/deep_assistant.py index fff0602..b0943d6 100644 --- a/agent/deep_assistant.py +++ b/agent/deep_assistant.py @@ -180,7 +180,7 @@ async def init_agent(config: AgentConfig): if config.session_id: from .checkpoint_manager import get_checkpointer_manager manager = get_checkpointer_manager() - checkpointer = await manager.acquire_for_agent() + checkpointer = await manager.acquire_for_agent(config.bot_id, config.session_id) await prepare_checkpoint_message(config, checkpointer) summarization_middleware = SummarizationMiddleware( model=llm_instance, diff --git a/fastapi_app.py b/fastapi_app.py index 79def96..ced584f 100644 --- a/fastapi_app.py +++ b/fastapi_app.py @@ -41,7 +41,7 @@ async def lifespan(app: FastAPI): manager = get_checkpointer_manager() # 启动时立即执行一次清理 try: - result = await manager.cleanup_old_threads() + result = await manager.cleanup_old_dbs() logger.info(f"Startup cleanup completed: {result}") except Exception as e: logger.warning(f"Startup cleanup failed (non-fatal): {e}") diff --git a/routes/chat.py b/routes/chat.py index 72f5d5a..5e82605 100644 --- a/routes/chat.py +++ b/routes/chat.py @@ -120,7 +120,7 @@ async def enhanced_generate_stream_response( if checkpointer: from agent.checkpoint_manager import get_checkpointer_manager manager = get_checkpointer_manager() - await manager.return_to_pool(checkpointer) + await manager.return_to_pool(config.bot_id, config.session_id, checkpointer) # 并发执行任务 # 只有在 enable_thinking 为 True 时才执行 preamble 任务 @@ -249,7 +249,7 @@ async def create_agent_and_generate_response( if checkpointer: from agent.checkpoint_manager import get_checkpointer_manager manager = get_checkpointer_manager() - await manager.return_to_pool(checkpointer) + await manager.return_to_pool(config.bot_id, config.session_id, checkpointer) return result diff --git a/start_unified.py b/start_unified.py index 21b4047..270b94b 100755 --- a/start_unified.py +++ b/start_unified.py @@ -227,18 +227,21 @@ class ProcessManager: env_vars = { 'TOKENIZERS_PARALLELISM': 'false', 'TOOL_CACHE_MAX_SIZE': '10', + 'CHECKPOINT_POOL_SIZE': '10', } elif args.profile == "balanced": env_vars = { 'TOKENIZERS_PARALLELISM': 'true', 'TOKENIZERS_FAST': '1', 'TOOL_CACHE_MAX_SIZE': '20', + 'CHECKPOINT_POOL_SIZE': '15', } elif args.profile == "high_performance": env_vars = { 'TOKENIZERS_PARALLELISM': 'true', 'TOKENIZERS_FAST': '1', - 'TOOL_CACHE_MAX_SIZE': '50', + 'TOOL_CACHE_MAX_SIZE': '30', + 'CHECKPOINT_POOL_SIZE': '20', } # 通用优化 diff --git a/utils/settings.py b/utils/settings.py index abd4076..75e708c 100644 --- a/utils/settings.py +++ b/utils/settings.py @@ -1,9 +1,16 @@ import os +# 必填参数 +# API Settings +BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai") +MASTERKEY = os.getenv("MASTERKEY", "master") +FASTAPI_URL = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001') + # LLM Token Settings MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 262144)) MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000)) +# 可选参数 # Summarization Settings SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000 SUMMARIZATION_MESSAGES_TO_KEEP = int(os.getenv("SUMMARIZATION_MESSAGES_TO_KEEP", 20)) @@ -13,11 +20,6 @@ TOOL_CACHE_MAX_SIZE = int(os.getenv("TOOL_CACHE_MAX_SIZE", 20)) TOOL_CACHE_TTL = int(os.getenv("TOOL_CACHE_TTL", 180)) TOOL_CACHE_AUTO_RENEW = os.getenv("TOOL_CACHE_AUTO_RENEW", "true") == "true" -# API Settings -BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai") -MASTERKEY = os.getenv("MASTERKEY", "master") -FASTAPI_URL = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001') - # Project Settings PROJECT_DATA_DIR = os.getenv("PROJECT_DATA_DIR", "./projects/data") @@ -44,7 +46,7 @@ MCP_SSE_READ_TIMEOUT = int(os.getenv("MCP_SSE_READ_TIMEOUT", 300)) # SSE 读取 # ============================================================ # Checkpoint 数据库路径 -CHECKPOINT_DB_PATH = os.getenv("CHECKPOINT_DB_PATH", "./projects/memory/checkpoints.db") +CHECKPOINT_DB_PATH = os.getenv("CHECKPOINT_DB_PATH", "./projects/memory/") # 启用 WAL 模式 (Write-Ahead Logging) # WAL 模式允许读写并发,大幅提升并发性能