This commit is contained in:
朱潮 2025-12-24 11:05:10 +08:00
parent e117f1ee07
commit b86a8364e9
6 changed files with 196 additions and 169 deletions

View File

@ -1,12 +1,14 @@
 """
 Global SQLite checkpointer manager
 Works around database lock contention under high concurrency
+Each session uses its own database file, avoiding lock contention between concurrent sessions
 """
 import asyncio
 import logging
 import os
-from datetime import datetime, timedelta, timezone
-from typing import Optional, List, Dict, Any
+import time
+from typing import Optional, Dict, Any, Tuple
 import aiosqlite
 from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
@ -15,7 +17,6 @@ from utils.settings import (
     CHECKPOINT_DB_PATH,
     CHECKPOINT_WAL_MODE,
     CHECKPOINT_BUSY_TIMEOUT,
-    CHECKPOINT_POOL_SIZE,
     CHECKPOINT_CLEANUP_ENABLED,
     CHECKPOINT_CLEANUP_INTERVAL_HOURS,
     CHECKPOINT_CLEANUP_OLDER_THAN_DAYS,
@ -23,60 +24,71 @@ from utils.settings import (
 logger = logging.getLogger('app')

+# Pool size per session (a single session is processed serially, so one connection is enough)
+POOL_SIZE_PER_SESSION = 1
+
 class CheckpointerManager:
     """
-    Global checkpointer manager that reuses SQLite connections via a pool
+    Global checkpointer manager with a separate database file per session_id

     Features:
-    1. Global singleton connection management, avoiding a new connection per request
-    2. Pre-configured WAL mode and busy_timeout
-    3. Connection pool for high-concurrency access
-    4. Graceful shutdown
+    1. A dedicated database file and connection pool per session_id
+    2. Pools are created on demand; unused sessions consume no resources
+    3. Pre-configured WAL mode and busy_timeout
+    4. Simple cleanup based on file modification time
+    5. Graceful shutdown
     """

     def __init__(self):
-        self._pool: asyncio.Queue[AsyncSqliteSaver] = asyncio.Queue()
-        self._lock = asyncio.Lock()
-        self._initialized = False
+        # One connection pool per (bot_id, session_id)
+        self._pools: Dict[Tuple[str, str], asyncio.Queue[AsyncSqliteSaver]] = {}
+        # One initialization lock per session
+        self._locks: Dict[Tuple[str, str], asyncio.Lock] = {}
+        # Global lock guarding access to the pools and locks dicts
+        self._global_lock = asyncio.Lock()
         self._closed = False
-        self._pool_size = CHECKPOINT_POOL_SIZE
-        self._db_path = CHECKPOINT_DB_PATH
         # Cleanup scheduler task
         self._cleanup_task: Optional[asyncio.Task] = None
         self._cleanup_stop_event = asyncio.Event()

-    async def initialize(self) -> None:
-        """Initialize the connection pool"""
-        if self._initialized:
-            return
-
-        async with self._lock:
-            if self._initialized:
-                return
-
-            logger.info(f"Initializing CheckpointerManager with pool_size={self._pool_size}")
-
-            # Make sure the directory exists
-            os.makedirs(os.path.dirname(self._db_path), exist_ok=True)
-
-            # Create the connection pool
-            for i in range(self._pool_size):
-                try:
-                    conn = await self._create_configured_connection()
-                    checkpointer = AsyncSqliteSaver(conn=conn)
-                    # Call setup() up front so the tables exist
-                    await checkpointer.setup()
-                    await self._pool.put(checkpointer)
-                    logger.debug(f"Created checkpointer connection {i+1}/{self._pool_size}")
-                except Exception as e:
-                    logger.error(f"Failed to create checkpointer connection {i+1}: {e}")
-                    raise
-
-            self._initialized = True
-            logger.info("CheckpointerManager initialized successfully")
-
-    async def _create_configured_connection(self) -> aiosqlite.Connection:
+    def _get_db_path(self, bot_id: str, session_id: str) -> str:
+        """Return the database file path for the given session"""
+        return os.path.join(CHECKPOINT_DB_PATH, bot_id, session_id, "checkpoints.db")
+
+    def _get_pool_key(self, bot_id: str, session_id: str) -> Tuple[str, str]:
+        """Return the key for a session's pool"""
+        return (bot_id, session_id)
+
+    async def _initialize_session_pool(self, bot_id: str, session_id: str) -> None:
+        """Initialize the connection pool for the given session"""
+        pool_key = self._get_pool_key(bot_id, session_id)
+        if pool_key in self._pools:
+            return
+
+        logger.info(f"Initializing checkpointer pool for bot_id={bot_id}, session_id={session_id}")
+
+        db_path = self._get_db_path(bot_id, session_id)
+        os.makedirs(os.path.dirname(db_path), exist_ok=True)
+
+        pool = asyncio.Queue()
+        for i in range(POOL_SIZE_PER_SESSION):
+            try:
+                conn = await self._create_configured_connection(db_path)
+                checkpointer = AsyncSqliteSaver(conn=conn)
+                # Call setup() up front so the tables exist
+                await checkpointer.setup()
+                await pool.put(checkpointer)
+                logger.debug(f"Created checkpointer connection {i+1}/{POOL_SIZE_PER_SESSION} for session={session_id}")
+            except Exception as e:
+                logger.error(f"Failed to create checkpointer connection {i+1} for session={session_id}: {e}")
+                raise
+
+        self._pools[pool_key] = pool
+        self._locks[pool_key] = asyncio.Lock()
+        logger.info(f"Checkpointer pool initialized for bot_id={bot_id}, session_id={session_id}")
+
+    async def _create_configured_connection(self, db_path: str) -> aiosqlite.Connection:
         """
         Create a pre-configured SQLite connection
@ -85,7 +97,7 @@ class CheckpointerManager:
         2. busy_timeout - how long to wait for a lock
         3. other tuning parameters
         """
-        conn = aiosqlite.connect(self._db_path)
+        conn = aiosqlite.connect(db_path)

         # Wait for the connection to be established
         await conn.__aenter__()
@ -98,43 +110,87 @@ class CheckpointerManager:
         await conn.execute("PRAGMA journal_mode = WAL")
         await conn.execute("PRAGMA synchronous = NORMAL")
         # WAL-mode tuning
-        await conn.execute("PRAGMA wal_autocheckpoint = 1000")
+        await conn.execute("PRAGMA wal_autocheckpoint = 10000")  # raised to 10000
         await conn.execute("PRAGMA cache_size = -64000")  # 64MB cache
         await conn.execute("PRAGMA temp_store = MEMORY")
+        await conn.execute("PRAGMA journal_size_limit = 1048576")  # 1MB
         await conn.commit()

         return conn

-    async def acquire_for_agent(self) -> AsyncSqliteSaver:
+    async def initialize(self) -> None:
+        """Initialize the manager (pools are no longer pre-created; they are created on demand)"""
+        logger.info("CheckpointerManager initialized (pools will be created on-demand)")
+
+    async def acquire_for_agent(self, bot_id: str, session_id: str) -> AsyncSqliteSaver:
         """
-        Acquire a checkpointer for an agent
+        Acquire the checkpointer for the given session

         Note: a checkpointer acquired here must be returned manually
         via return_to_pool()

+        Args:
+            bot_id: bot ID
+            session_id: session ID
+
         Returns:
             an AsyncSqliteSaver instance
         """
-        if not self._initialized:
-            raise RuntimeError("CheckpointerManager not initialized. Call initialize() first.")
-
-        checkpointer = await self._pool.get()
-        logger.debug(f"Acquired checkpointer from pool, remaining: {self._pool.qsize()}")
-        return checkpointer
+        if self._closed:
+            raise RuntimeError("CheckpointerManager is closed")
+
+        pool_key = self._get_pool_key(bot_id, session_id)
+        async with self._global_lock:
+            if pool_key not in self._pools:
+                await self._initialize_session_pool(bot_id, session_id)
+
+        # Take the session lock so pool operations stay safe under concurrency
+        async with self._locks[pool_key]:
+            checkpointer = await self._pools[pool_key].get()
+            logger.debug(f"Acquired checkpointer for bot_id={bot_id}, session_id={session_id}, remaining: {self._pools[pool_key].qsize()}")
+            return checkpointer

-    async def return_to_pool(self, checkpointer: AsyncSqliteSaver) -> None:
+    async def return_to_pool(self, bot_id: str, session_id: str, checkpointer: AsyncSqliteSaver) -> None:
         """
-        Return a checkpointer to the pool
+        Return a checkpointer to its session's pool

         Args:
+            bot_id: bot ID
+            session_id: session ID
             checkpointer: the checkpointer instance to return
         """
-        await self._pool.put(checkpointer)
-        logger.debug(f"Returned checkpointer to pool, remaining: {self._pool.qsize()}")
+        pool_key = self._get_pool_key(bot_id, session_id)
+        if pool_key in self._pools:
+            async with self._locks[pool_key]:
+                await self._pools[pool_key].put(checkpointer)
+                logger.debug(f"Returned checkpointer for bot_id={bot_id}, session_id={session_id}, remaining: {self._pools[pool_key].qsize()}")
+
+    async def _close_session_pool(self, bot_id: str, session_id: str) -> None:
+        """Close the connection pool for the given session"""
+        pool_key = self._get_pool_key(bot_id, session_id)
+        if pool_key not in self._pools:
+            return
+
+        logger.info(f"Closing checkpointer pool for bot_id={bot_id}, session_id={session_id}")
+        pool = self._pools[pool_key]
+        while not pool.empty():
+            try:
+                checkpointer = pool.get_nowait()
+                if checkpointer.conn:
+                    await checkpointer.conn.close()
+            except asyncio.QueueEmpty:
+                break
+
+        del self._pools[pool_key]
+        if pool_key in self._locks:
+            del self._locks[pool_key]
+        logger.info(f"Checkpointer pool closed for bot_id={bot_id}, session_id={session_id}")

     async def close(self) -> None:
         """Close all connections"""
         if self._closed:
             return
@ -148,146 +204,112 @@ class CheckpointerManager:
                 pass
             self._cleanup_task = None

-        async with self._lock:
+        async with self._global_lock:
             if self._closed:
                 return

             logger.info("Closing CheckpointerManager...")

-            # Drain the pool and close every connection
-            while not self._pool.empty():
-                try:
-                    checkpointer = self._pool.get_nowait()
-                    if checkpointer.conn:
-                        await checkpointer.conn.close()
-                except asyncio.QueueEmpty:
-                    break
+            # Close every session's connection pool
+            pool_keys = list(self._pools.keys())
+            for bot_id, session_id in pool_keys:
+                await self._close_session_pool(bot_id, session_id)

             self._closed = True
-            self._initialized = False
             logger.info("CheckpointerManager closed")

     def get_pool_stats(self) -> dict:
         """Return connection pool statistics"""
         return {
-            "db_path": self._db_path,
-            "pool_size": self._pool_size,
-            "available_connections": self._pool.qsize(),
-            "initialized": self._initialized,
+            "session_count": len(self._pools),
+            "pools": {
+                f"{bot_id}/{session_id}": {
+                    "available": pool.qsize(),
+                    "pool_size": POOL_SIZE_PER_SESSION
+                }
+                for (bot_id, session_id), pool in self._pools.items()
+            },
             "closed": self._closed
         }

     # ============================================================
-    # Checkpoint cleanup methods
+    # Checkpoint cleanup methods (based on file modification time)
     # ============================================================

-    async def get_all_thread_ids(self) -> List[str]:
-        """
-        Return all distinct thread_ids in the database
-
-        Returns:
-            List[str]: list of all thread_ids
-        """
-        if not self._initialized:
-            return []
-
-        conn = aiosqlite.connect(self._db_path)
-        await conn.__aenter__()
-        try:
-            cursor = await conn.execute(
-                "SELECT DISTINCT thread_id FROM checkpoints"
-            )
-            rows = await cursor.fetchall()
-            return [row[0] for row in rows]
-        finally:
-            await conn.close()
-
-    async def get_thread_last_activity(self, thread_id: str) -> Optional[datetime]:
-        """
-        Return the last activity time of the given thread
-
-        Reads the ts field of the thread's most recent checkpoint
-
-        Args:
-            thread_id: thread ID
-
-        Returns:
-            datetime: last activity time, or None if not found
-        """
-        if not self._initialized:
-            return None
-
-        checkpointer = await self.acquire_for_agent()
-        try:
-            config = {"configurable": {"thread_id": thread_id}}
-            result = checkpointer.alist(config=config, limit=1)
-            last_checkpoint = None
-            async for item in result:
-                last_checkpoint = item
-                break
-
-            if last_checkpoint and last_checkpoint.checkpoint:
-                ts_str = last_checkpoint.checkpoint.get("ts")
-                if ts_str:
-                    # Parse the ISO-format timestamp
-                    return datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
-        except Exception as e:
-            logger.warning(f"Error getting last activity for thread {thread_id}: {e}")
-        finally:
-            await self.return_to_pool(checkpointer)
-
-        return None
-
-    async def cleanup_old_threads(self, older_than_days: int = None) -> Dict[str, Any]:
-        """
-        Delete threads that have been inactive for more than the given number of days
+    async def cleanup_old_dbs(self, older_than_days: int = None) -> Dict[str, Any]:
+        """
+        Delete old database files based on their modification time

         Args:
             older_than_days: delete data older than this many days (defaults to the configured value)

         Returns:
             Dict: cleanup statistics
-                - threads_deleted: number of threads deleted
-                - threads_scanned: total number of threads scanned
-                - cutoff_time: the cutoff time
+                - deleted: number of session directories deleted
+                - scanned: number of session directories scanned
+                - cutoff_time: the cutoff timestamp
         """
         if older_than_days is None:
             older_than_days = CHECKPOINT_CLEANUP_OLDER_THAN_DAYS

-        # Use a timezone-aware time to avoid comparison errors
-        cutoff_time = datetime.now(timezone.utc) - timedelta(days=older_than_days)
-        logger.info(f"Starting checkpoint cleanup: removing threads inactive since {cutoff_time.isoformat()}")
-
-        all_thread_ids = await self.get_all_thread_ids()
-        threads_deleted = 0
-        threads_scanned = len(all_thread_ids)
-
-        checkpointer = await self.acquire_for_agent()
-
-        try:
-            for thread_id in all_thread_ids:
-                try:
-                    last_activity = await self.get_thread_last_activity(thread_id)
-
-                    if last_activity and last_activity < cutoff_time:
-                        # Delete the old thread
-                        config = {"configurable": {"thread_id": thread_id}}
-                        await checkpointer.adelete_thread(config)
-                        threads_deleted += 1
-                        logger.debug(f"Deleted old thread: {thread_id} (last activity: {last_activity.isoformat()})")
-                except Exception as e:
-                    logger.warning(f"Error processing thread {thread_id}: {e}")
-        finally:
-            await self.return_to_pool(checkpointer)
+        cutoff_time = time.time() - older_than_days * 86400
+        logger.info(f"Starting checkpoint cleanup: removing db files not modified since {cutoff_time}")
+
+        db_dir = CHECKPOINT_DB_PATH
+        deleted_count = 0
+        scanned_count = 0
+
+        if not os.path.exists(db_dir):
+            logger.info(f"Checkpoint directory does not exist: {db_dir}")
+            return {"deleted": 0, "scanned": 0, "cutoff_time": cutoff_time}
+
+        # Walk the bot_id directories
+        for bot_id in os.listdir(db_dir):
+            bot_path = os.path.join(db_dir, bot_id)
+            # Skip anything that is not a directory
+            if not os.path.isdir(bot_path):
+                continue
+
+            # Walk the session_id directories
+            for session_id in os.listdir(bot_path):
+                session_path = os.path.join(bot_path, session_id)
+                if not os.path.isdir(session_path):
+                    continue
+
+                db_file = os.path.join(session_path, "checkpoints.db")
+                if not os.path.exists(db_file):
+                    continue
+
+                scanned_count += 1
+                mtime = os.path.getmtime(db_file)
+                if mtime < cutoff_time:
+                    # Close this session's pool, if any
+                    await self._close_session_pool(bot_id, session_id)
+                    # Delete the whole session directory
+                    try:
+                        import shutil
+                        shutil.rmtree(session_path)
+                        deleted_count += 1
+                        logger.info(f"Deleted old checkpoint session: {bot_id}/{session_id}/ (last modified: {mtime})")
+                    except Exception as e:
+                        logger.warning(f"Failed to delete {session_path}: {e}")
+
+        # Remove empty bot_id directories
+        for bot_id in os.listdir(db_dir):
+            bot_path = os.path.join(db_dir, bot_id)
+            if os.path.isdir(bot_path) and not os.listdir(bot_path):
+                try:
+                    os.rmdir(bot_path)
+                    logger.debug(f"Removed empty bot directory: {bot_id}/")
+                except Exception:
+                    pass

         result = {
-            "threads_deleted": threads_deleted,
-            "threads_scanned": threads_scanned,
-            "cutoff_time": cutoff_time.isoformat(),
+            "deleted": deleted_count,
+            "scanned": scanned_count,
+            "cutoff_time": cutoff_time,
             "older_than_days": older_than_days
         }
@ -322,7 +344,7 @@ class CheckpointerManager:
                     break

                 # Run the cleanup
-                await self.cleanup_old_threads()
+                await self.cleanup_old_dbs()

             except asyncio.CancelledError:
                 logger.info("Cleanup task cancelled")

View File

@ -180,7 +180,7 @@ async def init_agent(config: AgentConfig):
     if config.session_id:
         from .checkpoint_manager import get_checkpointer_manager
         manager = get_checkpointer_manager()
-        checkpointer = await manager.acquire_for_agent()
+        checkpointer = await manager.acquire_for_agent(config.bot_id, config.session_id)
         await prepare_checkpoint_message(config, checkpointer)
         summarization_middleware = SummarizationMiddleware(
             model=llm_instance,
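
The acquire path is now lazy, so call sites need no explicit pool initialization. A sketch of the implied branch, assuming AgentConfig exposes bot_id alongside session_id:

    # Sketch of the call-site contract; AgentConfig field names are assumptions
    if config.session_id:
        checkpointer = await manager.acquire_for_agent(config.bot_id, config.session_id)
    else:
        checkpointer = None  # no session: run the agent without persistence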

View File

@ -41,7 +41,7 @@ async def lifespan(app: FastAPI):
     manager = get_checkpointer_manager()
     # Run one cleanup pass immediately at startup
     try:
-        result = await manager.cleanup_old_threads()
+        result = await manager.cleanup_old_dbs()
         logger.info(f"Startup cleanup completed: {result}")
     except Exception as e:
         logger.warning(f"Startup cleanup failed (non-fatal): {e}")

View File

@ -120,7 +120,7 @@ async def enhanced_generate_stream_response(
     if checkpointer:
         from agent.checkpoint_manager import get_checkpointer_manager
         manager = get_checkpointer_manager()
-        await manager.return_to_pool(checkpointer)
+        await manager.return_to_pool(config.bot_id, config.session_id, checkpointer)

     # Run the tasks concurrently
     # Only run the preamble task when enable_thinking is True
@ -249,7 +249,7 @@ async def create_agent_and_generate_response(
     if checkpointer:
         from agent.checkpoint_manager import get_checkpointer_manager
         manager = get_checkpointer_manager()
-        await manager.return_to_pool(checkpointer)
+        await manager.return_to_pool(config.bot_id, config.session_id, checkpointer)

     return result
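
Both call sites release the checkpointer after the response is produced; since each session's pool holds a single connection, an exception escaping between acquire and return would starve that session. A try/finally sketch, reusing the names from the call sites above (generate_response is a hypothetical helper):

    checkpointer = await manager.acquire_for_agent(config.bot_id, config.session_id)
    try:
        result = await generate_response(config, checkpointer)  # hypothetical helper
    finally:
        # Returned even if generation raises, so the session's one connection is never lost
        await manager.return_to_pool(config.bot_id, config.session_id, checkpointer)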

View File

@ -227,18 +227,21 @@ class ProcessManager:
             env_vars = {
                 'TOKENIZERS_PARALLELISM': 'false',
                 'TOOL_CACHE_MAX_SIZE': '10',
+                'CHECKPOINT_POOL_SIZE': '10',
             }
         elif args.profile == "balanced":
             env_vars = {
                 'TOKENIZERS_PARALLELISM': 'true',
                 'TOKENIZERS_FAST': '1',
                 'TOOL_CACHE_MAX_SIZE': '20',
+                'CHECKPOINT_POOL_SIZE': '15',
             }
         elif args.profile == "high_performance":
             env_vars = {
                 'TOKENIZERS_PARALLELISM': 'true',
                 'TOKENIZERS_FAST': '1',
-                'TOOL_CACHE_MAX_SIZE': '50',
+                'TOOL_CACHE_MAX_SIZE': '30',
+                'CHECKPOINT_POOL_SIZE': '20',
             }

         # Common optimizations
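
These profile env vars take effect through the os.getenv defaults in utils/settings.py (next file). A sketch of that mapping for one value:

    import os

    # Sketch: how a profile's env_vars reach the worker's settings module.
    os.environ['TOOL_CACHE_MAX_SIZE'] = '30'  # set by the high_performance profile
    TOOL_CACHE_MAX_SIZE = int(os.getenv("TOOL_CACHE_MAX_SIZE", 20))  # as in utils/settings.py
    assert TOOL_CACHE_MAX_SIZE == 30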

View File

@ -1,9 +1,16 @@
import os import os
# 必填参数
# API Settings
BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
MASTERKEY = os.getenv("MASTERKEY", "master")
FASTAPI_URL = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
# LLM Token Settings # LLM Token Settings
MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 262144)) MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 262144))
MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000)) MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000))
# 可选参数
# Summarization Settings # Summarization Settings
SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000 SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000
SUMMARIZATION_MESSAGES_TO_KEEP = int(os.getenv("SUMMARIZATION_MESSAGES_TO_KEEP", 20)) SUMMARIZATION_MESSAGES_TO_KEEP = int(os.getenv("SUMMARIZATION_MESSAGES_TO_KEEP", 20))
@ -13,11 +20,6 @@ TOOL_CACHE_MAX_SIZE = int(os.getenv("TOOL_CACHE_MAX_SIZE", 20))
 TOOL_CACHE_TTL = int(os.getenv("TOOL_CACHE_TTL", 180))
 TOOL_CACHE_AUTO_RENEW = os.getenv("TOOL_CACHE_AUTO_RENEW", "true") == "true"

-# API Settings
-BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
-MASTERKEY = os.getenv("MASTERKEY", "master")
-FASTAPI_URL = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
-
 # Project Settings
 PROJECT_DATA_DIR = os.getenv("PROJECT_DATA_DIR", "./projects/data")
@ -44,7 +46,7 @@ MCP_SSE_READ_TIMEOUT = int(os.getenv("MCP_SSE_READ_TIMEOUT", 300))  # SSE read
 # ============================================================
 # Checkpoint database path
-CHECKPOINT_DB_PATH = os.getenv("CHECKPOINT_DB_PATH", "./projects/memory/checkpoints.db")
+CHECKPOINT_DB_PATH = os.getenv("CHECKPOINT_DB_PATH", "./projects/memory/")
 # Enable WAL mode (Write-Ahead Logging)
 # WAL allows concurrent reads and writes, greatly improving concurrency
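
CHECKPOINT_DB_PATH is now a base directory rather than a single file; the manager's _get_db_path joins it with the bot and session IDs. A sketch of the resulting layout (the IDs are illustrative placeholders):

    import os

    # Sketch: the per-session layout implied by _get_db_path above.
    CHECKPOINT_DB_PATH = "./projects/memory/"
    path = os.path.join(CHECKPOINT_DB_PATH, "bot_42", "sess_abc", "checkpoints.db")
    # -> ./projects/memory/bot_42/sess_abc/checkpoints.db
    #    (plus checkpoints.db-wal / checkpoints.db-shm while WAL mode is active)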