sqlite pool and change agent cache to tools cache
This commit is contained in:
parent
09a9c8be93
commit
d8dc973b95
@ -2,10 +2,12 @@
|
|||||||
基于内存的 Agent 缓存管理模块
|
基于内存的 Agent 缓存管理模块
|
||||||
使用 cachetools 库实现 TTLCache 和 LRUCache
|
使用 cachetools 库实现 TTLCache 和 LRUCache
|
||||||
"""
|
"""
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Optional, Dict, Tuple
|
from typing import Any, Optional, Dict, Tuple, List
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
@ -310,6 +312,48 @@ class AgentMemoryCacheManager:
|
|||||||
"""返回缓存中的项数"""
|
"""返回缓存中的项数"""
|
||||||
return len(self.cache)
|
return len(self.cache)
|
||||||
|
|
||||||
|
def get_mcp_tools(self, mcp_settings: dict) -> Optional[List]:
|
||||||
|
"""
|
||||||
|
获取缓存的 MCP tools
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_settings: MCP 配置字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
缓存的 tools 列表或 None
|
||||||
|
"""
|
||||||
|
cache_key = self._get_mcp_cache_key(mcp_settings)
|
||||||
|
return self.get(cache_key)
|
||||||
|
|
||||||
|
def set_mcp_tools(self, mcp_settings: dict, tools: List, ttl: Optional[int] = None) -> bool:
|
||||||
|
"""
|
||||||
|
缓存 MCP tools
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_settings: MCP 配置字典
|
||||||
|
tools: 要缓存的 tools 列表
|
||||||
|
ttl: 过期时间(秒),如果为 None 则使用默认值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功设置缓存
|
||||||
|
"""
|
||||||
|
cache_key = self._get_mcp_cache_key(mcp_settings)
|
||||||
|
return self.set(cache_key, tools, ttl=ttl)
|
||||||
|
|
||||||
|
def _get_mcp_cache_key(self, mcp_settings: dict) -> str:
|
||||||
|
"""
|
||||||
|
根据 mcp_settings 生成缓存键
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_settings: MCP 配置字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
缓存键字符串
|
||||||
|
"""
|
||||||
|
# 将 mcp_settings 转换为 JSON 字符串并生成哈希
|
||||||
|
settings_str = json.dumps(mcp_settings, sort_keys=True)
|
||||||
|
return f"mcp_tools:{hashlib.md5(settings_str.encode()).hexdigest()}"
|
||||||
|
|
||||||
|
|
||||||
# 全局缓存管理器实例
|
# 全局缓存管理器实例
|
||||||
_global_cache_manager: Optional[AgentMemoryCacheManager] = None
|
_global_cache_manager: Optional[AgentMemoryCacheManager] = None
|
||||||
|
|||||||
186
agent/checkpoint_manager.py
Normal file
186
agent/checkpoint_manager.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
"""
|
||||||
|
全局 SQLite Checkpointer 管理器
|
||||||
|
解决高并发场景下的数据库锁定问题
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||||||
|
|
||||||
|
from utils.settings import (
|
||||||
|
CHECKPOINT_DB_PATH,
|
||||||
|
CHECKPOINT_WAL_MODE,
|
||||||
|
CHECKPOINT_BUSY_TIMEOUT,
|
||||||
|
CHECKPOINT_POOL_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger('app')
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointerManager:
|
||||||
|
"""
|
||||||
|
全局 Checkpointer 管理器,使用连接池复用 SQLite 连接
|
||||||
|
|
||||||
|
主要功能:
|
||||||
|
1. 全局单例连接管理,避免每次请求创建新连接
|
||||||
|
2. 预配置 WAL 模式和 busy_timeout
|
||||||
|
3. 连接池支持高并发访问
|
||||||
|
4. 优雅关闭机制
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._pool: asyncio.Queue[AsyncSqliteSaver] = asyncio.Queue()
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._initialized = False
|
||||||
|
self._closed = False
|
||||||
|
self._pool_size = CHECKPOINT_POOL_SIZE
|
||||||
|
self._db_path = CHECKPOINT_DB_PATH
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""初始化连接池"""
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Initializing CheckpointerManager with pool_size={self._pool_size}")
|
||||||
|
|
||||||
|
# 确保目录存在
|
||||||
|
os.makedirs(os.path.dirname(self._db_path), exist_ok=True)
|
||||||
|
|
||||||
|
# 创建连接池
|
||||||
|
for i in range(self._pool_size):
|
||||||
|
try:
|
||||||
|
conn = await self._create_configured_connection()
|
||||||
|
checkpointer = AsyncSqliteSaver(conn=conn)
|
||||||
|
# 预先调用 setup 确保表结构已创建
|
||||||
|
await checkpointer.setup()
|
||||||
|
await self._pool.put(checkpointer)
|
||||||
|
logger.debug(f"Created checkpointer connection {i+1}/{self._pool_size}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create checkpointer connection {i+1}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
logger.info("CheckpointerManager initialized successfully")
|
||||||
|
|
||||||
|
async def _create_configured_connection(self) -> aiosqlite.Connection:
|
||||||
|
"""
|
||||||
|
创建已配置的 SQLite 连接
|
||||||
|
|
||||||
|
配置包括:
|
||||||
|
1. WAL 模式 (Write-Ahead Logging) - 允许读写并发
|
||||||
|
2. busy_timeout - 等待锁定的最长时间
|
||||||
|
3. 其他优化参数
|
||||||
|
"""
|
||||||
|
conn = aiosqlite.connect(self._db_path)
|
||||||
|
|
||||||
|
# 等待连接建立
|
||||||
|
await conn.__aenter__()
|
||||||
|
|
||||||
|
# 设置 busy timeout(必须在连接建立后设置)
|
||||||
|
await conn.execute(f"PRAGMA busy_timeout = {CHECKPOINT_BUSY_TIMEOUT}")
|
||||||
|
|
||||||
|
# 如果启用 WAL 模式
|
||||||
|
if CHECKPOINT_WAL_MODE:
|
||||||
|
await conn.execute("PRAGMA journal_mode = WAL")
|
||||||
|
await conn.execute("PRAGMA synchronous = NORMAL")
|
||||||
|
# WAL 模式下的优化配置
|
||||||
|
await conn.execute("PRAGMA wal_autocheckpoint = 1000")
|
||||||
|
await conn.execute("PRAGMA cache_size = -64000") # 64MB 缓存
|
||||||
|
await conn.execute("PRAGMA temp_store = MEMORY")
|
||||||
|
|
||||||
|
await conn.commit()
|
||||||
|
|
||||||
|
return conn
|
||||||
|
|
||||||
|
async def acquire_for_agent(self) -> AsyncSqliteSaver:
|
||||||
|
"""
|
||||||
|
为 agent 获取 checkpointer
|
||||||
|
|
||||||
|
注意:此方法获取的 checkpointer 需要手动归还
|
||||||
|
使用 return_to_pool() 方法归还
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncSqliteSaver 实例
|
||||||
|
"""
|
||||||
|
if not self._initialized:
|
||||||
|
raise RuntimeError("CheckpointerManager not initialized. Call initialize() first.")
|
||||||
|
|
||||||
|
checkpointer = await self._pool.get()
|
||||||
|
logger.debug(f"Acquired checkpointer from pool, remaining: {self._pool.qsize()}")
|
||||||
|
return checkpointer
|
||||||
|
|
||||||
|
async def return_to_pool(self, checkpointer: AsyncSqliteSaver) -> None:
|
||||||
|
"""
|
||||||
|
归还 checkpointer 到池
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpointer: 要归还的 checkpointer 实例
|
||||||
|
"""
|
||||||
|
await self._pool.put(checkpointer)
|
||||||
|
logger.debug(f"Returned checkpointer to pool, remaining: {self._pool.qsize()}")
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""关闭所有连接"""
|
||||||
|
if self._closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
if self._closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Closing CheckpointerManager...")
|
||||||
|
|
||||||
|
# 清空池并关闭所有连接
|
||||||
|
while not self._pool.empty():
|
||||||
|
try:
|
||||||
|
checkpointer = self._pool.get_nowait()
|
||||||
|
if checkpointer.conn:
|
||||||
|
await checkpointer.conn.close()
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
break
|
||||||
|
|
||||||
|
self._closed = True
|
||||||
|
self._initialized = False
|
||||||
|
logger.info("CheckpointerManager closed")
|
||||||
|
|
||||||
|
def get_pool_stats(self) -> dict:
|
||||||
|
"""获取连接池状态统计"""
|
||||||
|
return {
|
||||||
|
"db_path": self._db_path,
|
||||||
|
"pool_size": self._pool_size,
|
||||||
|
"available_connections": self._pool.qsize(),
|
||||||
|
"initialized": self._initialized,
|
||||||
|
"closed": self._closed
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_global_manager: Optional[CheckpointerManager] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_checkpointer_manager() -> CheckpointerManager:
|
||||||
|
"""获取全局 CheckpointerManager 单例"""
|
||||||
|
global _global_manager
|
||||||
|
if _global_manager is None:
|
||||||
|
_global_manager = CheckpointerManager()
|
||||||
|
return _global_manager
|
||||||
|
|
||||||
|
|
||||||
|
async def init_global_checkpointer() -> None:
|
||||||
|
"""初始化全局 checkpointer 管理器"""
|
||||||
|
manager = get_checkpointer_manager()
|
||||||
|
await manager.initialize()
|
||||||
|
|
||||||
|
|
||||||
|
async def close_global_checkpointer() -> None:
|
||||||
|
"""关闭全局 checkpointer 管理器"""
|
||||||
|
global _global_manager
|
||||||
|
if _global_manager is not None:
|
||||||
|
await _global_manager.close()
|
||||||
@ -1,6 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
import copy
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
# from deepagents import create_deep_agent
|
# from deepagents import create_deep_agent
|
||||||
@ -41,14 +42,26 @@ def read_mcp_settings():
|
|||||||
|
|
||||||
|
|
||||||
async def get_tools_from_mcp(mcp):
|
async def get_tools_from_mcp(mcp):
|
||||||
"""从MCP配置中提取工具"""
|
"""从MCP配置中提取工具(带缓存)"""
|
||||||
|
start_time = time.time()
|
||||||
# 防御式处理:确保 mcp 是列表且长度大于 0,且包含 mcpServers
|
# 防御式处理:确保 mcp 是列表且长度大于 0,且包含 mcpServers
|
||||||
if not isinstance(mcp, list) or len(mcp) == 0 or "mcpServers" not in mcp[0]:
|
if not isinstance(mcp, list) or len(mcp) == 0 or "mcpServers" not in mcp[0]:
|
||||||
|
logger.info(f"get_tools_from_mcp: invalid mcp config, elapsed: {time.time() - start_time:.3f}s")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 修改 mcp[0]["mcpServers"] 列表,把 type 字段改成 transport
|
# 尝试从缓存获取
|
||||||
|
cache_manager = get_memory_cache_manager()
|
||||||
|
cached_tools = cache_manager.get_mcp_tools(mcp)
|
||||||
|
if cached_tools is not None:
|
||||||
|
logger.info(f"get_tools_from_mcp: cached {len(cached_tools)} tools, elapsed: {time.time() - start_time:.3f}s")
|
||||||
|
return cached_tools
|
||||||
|
|
||||||
|
# 深拷贝 mcp 配置,避免修改原始配置(影响缓存键)
|
||||||
|
mcp_copy = copy.deepcopy(mcp)
|
||||||
|
|
||||||
|
# 修改 mcp_copy[0]["mcpServers"] 列表,把 type 字段改成 transport
|
||||||
# 如果没有 transport,则根据是否存在 url 默认 transport 为 http 或 stdio
|
# 如果没有 transport,则根据是否存在 url 默认 transport 为 http 或 stdio
|
||||||
for cfg in mcp[0]["mcpServers"].values():
|
for cfg in mcp_copy[0]["mcpServers"].values():
|
||||||
if "type" in cfg:
|
if "type" in cfg:
|
||||||
cfg.pop("type")
|
cfg.pop("type")
|
||||||
if "transport" not in cfg:
|
if "transport" not in cfg:
|
||||||
@ -62,53 +75,49 @@ async def get_tools_from_mcp(mcp):
|
|||||||
if "sse_read_timeout" not in cfg:
|
if "sse_read_timeout" not in cfg:
|
||||||
cfg["sse_read_timeout"] = MCP_SSE_READ_TIMEOUT
|
cfg["sse_read_timeout"] = MCP_SSE_READ_TIMEOUT
|
||||||
|
|
||||||
# 确保 mcp[0]["mcpServers"] 是字典类型
|
# 确保 mcp_copy[0]["mcpServers"] 是字典类型
|
||||||
if not isinstance(mcp[0]["mcpServers"], dict):
|
if not isinstance(mcp_copy[0]["mcpServers"], dict):
|
||||||
|
logger.info(f"get_tools_from_mcp: mcpServers is not dict, elapsed: {time.time() - start_time:.3f}s")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mcp_client = MultiServerMCPClient(mcp[0]["mcpServers"])
|
mcp_client = MultiServerMCPClient(mcp_copy[0]["mcpServers"])
|
||||||
mcp_tools = await mcp_client.get_tools()
|
mcp_tools = await mcp_client.get_tools()
|
||||||
|
|
||||||
|
# 缓存结果
|
||||||
|
cache_manager.set_mcp_tools(mcp, mcp_tools)
|
||||||
|
|
||||||
|
logger.info(f"get_tools_from_mcp: loaded {len(mcp_tools)} tools, elapsed: {time.time() - start_time:.3f}s")
|
||||||
return mcp_tools
|
return mcp_tools
|
||||||
except Exception:
|
except Exception as e:
|
||||||
# 发生异常时返回空列表,避免上层调用报错
|
# 发生异常时返回空列表,避免上层调用报错
|
||||||
|
logger.info(f"get_tools_from_mcp: error {e}, elapsed: {time.time() - start_time:.3f}s")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def init_agent(config: AgentConfig):
|
async def init_agent(config: AgentConfig):
|
||||||
"""
|
"""
|
||||||
初始化 Agent,支持持久化内存和对话摘要
|
初始化 Agent,支持持久化内存和对话摘要
|
||||||
|
|
||||||
|
注意:不再缓存 agent,只缓存 mcp_tools
|
||||||
|
返回 (agent, checkpointer) 元组,调用后需要归还 checkpointer
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: AgentConfig 对象,包含所有初始化参数
|
config: AgentConfig 对象,包含所有初始化参数
|
||||||
mcp: MCP配置(如果为None则使用配置中的mcp_settings)
|
|
||||||
|
Returns:
|
||||||
|
(agent, checkpointer) 元组
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 初始化 checkpointer 和中间件
|
|
||||||
|
# 从连接池获取 checkpointer
|
||||||
checkpointer = None
|
checkpointer = None
|
||||||
if config.session_id:
|
if config.session_id:
|
||||||
os.makedirs("projects/memory", exist_ok=True)
|
from .checkpoint_manager import get_checkpointer_manager
|
||||||
conn = aiosqlite.connect("projects/memory/checkpoints.db")
|
manager = get_checkpointer_manager()
|
||||||
checkpointer = AsyncSqliteSaver(conn=conn)
|
checkpointer = await manager.acquire_for_agent()
|
||||||
await prepare_checkpoint_message(config, checkpointer)
|
await prepare_checkpoint_message(config, checkpointer)
|
||||||
# 获取缓存管理器
|
|
||||||
cache_manager = get_memory_cache_manager()
|
|
||||||
|
|
||||||
# 获取唯一的缓存键
|
|
||||||
cache_key = config.get_unique_cache_id()
|
|
||||||
|
|
||||||
# 如果有缓存键,检查缓存
|
|
||||||
if cache_key:
|
|
||||||
# 尝试从缓存中获取 agent
|
|
||||||
cached_agent = cache_manager.get(cache_key)
|
|
||||||
if cached_agent is not None:
|
|
||||||
logger.info(f"Using cached agent for session: {config.session_id}, cache_key: {cache_key}")
|
|
||||||
return cached_agent
|
|
||||||
else:
|
|
||||||
logger.info(f"Cache miss for session: {config.session_id}, cache_key: {cache_key}")
|
|
||||||
|
|
||||||
# 没有缓存或缓存已过期,创建新的 agent
|
|
||||||
logger.info(f"Creating new agent for session: {getattr(config, 'session_id', 'no-session')}")
|
|
||||||
|
|
||||||
|
# 加载配置
|
||||||
final_system_prompt = await load_system_prompt_async(
|
final_system_prompt = await load_system_prompt_async(
|
||||||
config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier
|
config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier
|
||||||
)
|
)
|
||||||
@ -123,10 +132,11 @@ async def init_agent(config: AgentConfig):
|
|||||||
config.system_prompt = mcp_settings
|
config.system_prompt = mcp_settings
|
||||||
config.mcp_settings = system_prompt
|
config.mcp_settings = system_prompt
|
||||||
|
|
||||||
|
# 获取 mcp_tools(缓存逻辑已内置到 get_tools_from_mcp 中)
|
||||||
mcp_tools = await get_tools_from_mcp(mcp_settings)
|
mcp_tools = await get_tools_from_mcp(mcp_settings)
|
||||||
|
|
||||||
# 检测或使用指定的提供商
|
# 检测或使用指定的提供商
|
||||||
model_provider,base_url = detect_provider(config.model_name, config.model_server)
|
model_provider, base_url = detect_provider(config.model_name, config.model_server)
|
||||||
# 构建模型参数
|
# 构建模型参数
|
||||||
model_kwargs = {
|
model_kwargs = {
|
||||||
"model": config.model_name,
|
"model": config.model_name,
|
||||||
@ -139,6 +149,10 @@ async def init_agent(config: AgentConfig):
|
|||||||
model_kwargs.update(config.generate_cfg)
|
model_kwargs.update(config.generate_cfg)
|
||||||
llm_instance = init_chat_model(**model_kwargs)
|
llm_instance = init_chat_model(**model_kwargs)
|
||||||
|
|
||||||
|
# 创建新的 agent(不再缓存)
|
||||||
|
logger.info(f"Creating new agent for session: {getattr(config, 'session_id', 'no-session')}")
|
||||||
|
|
||||||
|
create_start = time.time()
|
||||||
if config.robot_type == "deep_agent":
|
if config.robot_type == "deep_agent":
|
||||||
# 使用 DeepAgentX 创建 agent
|
# 使用 DeepAgentX 创建 agent
|
||||||
agent, composite_backend = create_cli_agent(
|
agent, composite_backend = create_cli_agent(
|
||||||
@ -160,8 +174,8 @@ async def init_agent(config: AgentConfig):
|
|||||||
tool_output_middleware = ToolOutputLengthMiddleware(
|
tool_output_middleware = ToolOutputLengthMiddleware(
|
||||||
max_length=getattr(config.generate_cfg, 'tool_output_max_length', None) if config.generate_cfg else None or TOOL_OUTPUT_MAX_LENGTH,
|
max_length=getattr(config.generate_cfg, 'tool_output_max_length', None) if config.generate_cfg else None or TOOL_OUTPUT_MAX_LENGTH,
|
||||||
truncation_strategy=getattr(config.generate_cfg, 'tool_output_truncation_strategy', 'smart') if config.generate_cfg else 'smart',
|
truncation_strategy=getattr(config.generate_cfg, 'tool_output_truncation_strategy', 'smart') if config.generate_cfg else 'smart',
|
||||||
tool_filters=getattr(config.generate_cfg, 'tool_output_filters', None) if config.generate_cfg else None, # 可配置特定工具
|
tool_filters=getattr(config.generate_cfg, 'tool_output_filters', None) if config.generate_cfg else None,
|
||||||
exclude_tools=getattr(config.generate_cfg, 'tool_output_exclude', []) if config.generate_cfg else [], # 排除的工具
|
exclude_tools=getattr(config.generate_cfg, 'tool_output_exclude', []) if config.generate_cfg else [],
|
||||||
preserve_code_blocks=getattr(config.generate_cfg, 'preserve_code_blocks', True) if config.generate_cfg else True,
|
preserve_code_blocks=getattr(config.generate_cfg, 'preserve_code_blocks', True) if config.generate_cfg else True,
|
||||||
preserve_json=getattr(config.generate_cfg, 'preserve_json', True) if config.generate_cfg else True
|
preserve_json=getattr(config.generate_cfg, 'preserve_json', True) if config.generate_cfg else True
|
||||||
)
|
)
|
||||||
@ -171,7 +185,7 @@ async def init_agent(config: AgentConfig):
|
|||||||
summarization_middleware = SummarizationMiddleware(
|
summarization_middleware = SummarizationMiddleware(
|
||||||
model=llm_instance,
|
model=llm_instance,
|
||||||
max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS,
|
max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS,
|
||||||
messages_to_keep=20, # 摘要后保留最近 20 条消息
|
messages_to_keep=20,
|
||||||
summary_prompt="请简洁地总结以上对话的要点,包括重要的用户信息、讨论过的话题和关键结论。"
|
summary_prompt="请简洁地总结以上对话的要点,包括重要的用户信息、讨论过的话题和关键结论。"
|
||||||
)
|
)
|
||||||
middleware.append(summarization_middleware)
|
middleware.append(summarization_middleware)
|
||||||
@ -181,13 +195,7 @@ async def init_agent(config: AgentConfig):
|
|||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
tools=mcp_tools,
|
tools=mcp_tools,
|
||||||
middleware=middleware,
|
middleware=middleware,
|
||||||
checkpointer=checkpointer # 传入 checkpointer 以启用持久化
|
checkpointer=checkpointer
|
||||||
)
|
)
|
||||||
|
logger.info(f"create {config.robot_type} elapsed: {time.time() - create_start:.3f}s")
|
||||||
# 如果有缓存键,将 agent 加入缓存
|
return agent, checkpointer
|
||||||
if cache_key:
|
|
||||||
# 使用 DiskCache 缓存管理器存储 agent
|
|
||||||
cache_manager.set(cache_key, agent)
|
|
||||||
logger.info(f"Cached agent for session: {config.session_id}, cache_key: {cache_key}")
|
|
||||||
|
|
||||||
return agent
|
|
||||||
@ -4,6 +4,7 @@ import uuid
|
|||||||
import time
|
import time
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import sys
|
import sys
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
@ -19,7 +20,26 @@ from utils.log_util.logger import init_with_fastapi
|
|||||||
# Import route modules
|
# Import route modules
|
||||||
from routes import chat, files, projects, system
|
from routes import chat, files, projects, system
|
||||||
|
|
||||||
app = FastAPI(title="Database Assistant API", version="1.0.0")
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""FastAPI 应用生命周期管理"""
|
||||||
|
# 启动时初始化
|
||||||
|
logger.info("Starting up...")
|
||||||
|
from agent.checkpoint_manager import init_global_checkpointer
|
||||||
|
await init_global_checkpointer()
|
||||||
|
logger.info("Global checkpointer initialized")
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# 关闭时清理
|
||||||
|
logger.info("Shutting down...")
|
||||||
|
from agent.checkpoint_manager import close_global_checkpointer
|
||||||
|
await close_global_checkpointer()
|
||||||
|
logger.info("Global checkpointer closed")
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title="Database Assistant API", version="1.0.0", lifespan=lifespan)
|
||||||
|
|
||||||
init_with_fastapi(app)
|
init_with_fastapi(app)
|
||||||
|
|
||||||
|
|||||||
130
routes/chat.py
130
routes/chat.py
@ -64,12 +64,13 @@ async def enhanced_generate_stream_response(
|
|||||||
|
|
||||||
# Agent 任务(准备 + 流式处理)
|
# Agent 任务(准备 + 流式处理)
|
||||||
async def agent_task():
|
async def agent_task():
|
||||||
|
checkpointer = None
|
||||||
try:
|
try:
|
||||||
# 开始流式处理
|
# 开始流式处理
|
||||||
logger.info(f"Starting agent stream response")
|
logger.info(f"Starting agent stream response")
|
||||||
chunk_id = 0
|
chunk_id = 0
|
||||||
message_tag = ""
|
message_tag = ""
|
||||||
agent = await init_agent(config)
|
agent, checkpointer = await init_agent(config)
|
||||||
async for msg, metadata in agent.astream({"messages": config.messages}, stream_mode="messages", config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS):
|
async for msg, metadata in agent.astream({"messages": config.messages}, stream_mode="messages", config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS):
|
||||||
new_content = ""
|
new_content = ""
|
||||||
|
|
||||||
@ -115,6 +116,11 @@ async def enhanced_generate_stream_response(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in agent task: {e}")
|
logger.error(f"Error in agent task: {e}")
|
||||||
await output_queue.put(("agent_done", None))
|
await output_queue.put(("agent_done", None))
|
||||||
|
finally:
|
||||||
|
if checkpointer:
|
||||||
|
from agent.checkpoint_manager import get_checkpointer_manager
|
||||||
|
manager = get_checkpointer_manager()
|
||||||
|
await manager.return_to_pool(checkpointer)
|
||||||
|
|
||||||
# 并发执行任务
|
# 并发执行任务
|
||||||
# 只有在 enable_thinking 为 True 时才执行 preamble 任务
|
# 只有在 enable_thinking 为 True 时才执行 preamble 任务
|
||||||
@ -203,41 +209,49 @@ async def create_agent_and_generate_response(
|
|||||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
|
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = await init_agent(config)
|
agent, checkpointer = await init_agent(config)
|
||||||
# 使用更新后的 messages
|
try:
|
||||||
agent_responses = await agent.ainvoke({"messages": config.messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
|
# 使用更新后的 messages
|
||||||
append_messages = agent_responses["messages"][len(config.messages):]
|
agent_responses = await agent.ainvoke({"messages": config.messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
|
||||||
response_text = ""
|
append_messages = agent_responses["messages"][len(config.messages):]
|
||||||
for msg in append_messages:
|
response_text = ""
|
||||||
if isinstance(msg,AIMessage):
|
for msg in append_messages:
|
||||||
if len(msg.text)>0:
|
if isinstance(msg,AIMessage):
|
||||||
meta_message_tag = msg.additional_kwargs.get("message_tag", "ANSWER")
|
if len(msg.text)>0:
|
||||||
output_text = msg.text.replace("<think>","").replace("</think>","") if meta_message_tag == "THINK" else msg.text
|
meta_message_tag = msg.additional_kwargs.get("message_tag", "ANSWER")
|
||||||
response_text += f"[{meta_message_tag}]\n"+output_text+ "\n"
|
output_text = msg.text.replace("````","").replace("````","") if meta_message_tag == "THINK" else msg.text
|
||||||
if len(msg.tool_calls)>0:
|
response_text += f"[{meta_message_tag}]\n"+output_text+ "\n"
|
||||||
response_text += "".join([f"[TOOL_CALL] {tool['name']}\n{json.dumps(tool["args"]) if isinstance(tool["args"],dict) else tool["args"]}\n" for tool in msg.tool_calls])
|
if len(msg.tool_calls)>0:
|
||||||
elif isinstance(msg,ToolMessage) and config.tool_response:
|
response_text += "".join([f"[TOOL_CALL] {tool['name']}\n{json.dumps(tool["args"]) if isinstance(tool["args"],dict) else tool["args"]}\n" for tool in msg.tool_calls])
|
||||||
response_text += f"[TOOL_RESPONSE] {msg.name}\n{msg.text}\n"
|
elif isinstance(msg,ToolMessage) and config.tool_response:
|
||||||
|
response_text += f"[TOOL_RESPONSE] {msg.name}\n{msg.text}\n"
|
||||||
|
|
||||||
if len(response_text) > 0:
|
if len(response_text) > 0:
|
||||||
# 构造OpenAI格式的响应
|
# 构造OpenAI格式的响应
|
||||||
return ChatResponse(
|
result = ChatResponse(
|
||||||
choices=[{
|
choices=[{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"message": {
|
"message": {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": response_text
|
"content": response_text
|
||||||
},
|
},
|
||||||
"finish_reason": "stop"
|
"finish_reason": "stop"
|
||||||
}],
|
}],
|
||||||
usage={
|
usage={
|
||||||
"prompt_tokens": sum(len(msg.get("content", "")) for msg in config.messages),
|
"prompt_tokens": sum(len(msg.get("content", "")) for msg in config.messages),
|
||||||
"completion_tokens": len(response_text),
|
"completion_tokens": len(response_text),
|
||||||
"total_tokens": sum(len(msg.get("content", "")) for msg in config.messages) + len(response_text)
|
"total_tokens": sum(len(msg.get("content", "")) for msg in config.messages) + len(response_text)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=500, detail="No response from agent")
|
raise HTTPException(status_code=500, detail="No response from agent")
|
||||||
|
finally:
|
||||||
|
if checkpointer:
|
||||||
|
from agent.checkpoint_manager import get_checkpointer_manager
|
||||||
|
manager = get_checkpointer_manager()
|
||||||
|
await manager.return_to_pool(checkpointer)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.post("/api/v1/chat/completions")
|
@router.post("/api/v1/chat/completions")
|
||||||
@ -348,18 +362,27 @@ async def chat_warmup_v1(request: ChatRequest, authorization: Optional[str] = He
|
|||||||
# 创建 AgentConfig 对象
|
# 创建 AgentConfig 对象
|
||||||
config = AgentConfig.from_v1_request(request, api_key, project_dir, generate_cfg, messages)
|
config = AgentConfig.from_v1_request(request, api_key, project_dir, generate_cfg, messages)
|
||||||
|
|
||||||
# 预热:初始化agent(这会触发缓存)
|
# 预热 mcp_tools 缓存
|
||||||
logger.info(f"Warming up agent for bot_id: {bot_id}")
|
logger.info(f"Warming up mcp_tools for bot_id: {bot_id}")
|
||||||
agent = await init_agent(config)
|
from agent.deep_assistant import get_tools_from_mcp
|
||||||
|
from agent.prompt_loader import load_mcp_settings_async
|
||||||
|
|
||||||
# 获取缓存键
|
# 加载 mcp_settings
|
||||||
cache_key = config.get_unique_cache_id() if hasattr(config, 'get_unique_cache_id') else None
|
final_mcp_settings = await load_mcp_settings_async(
|
||||||
|
config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
|
||||||
|
)
|
||||||
|
mcp_settings = final_mcp_settings if final_mcp_settings else []
|
||||||
|
if not isinstance(mcp_settings, list) or len(mcp_settings) == 0:
|
||||||
|
mcp_settings = []
|
||||||
|
|
||||||
|
# 预热 mcp_tools(缓存逻辑已内置到 get_tools_from_mcp)
|
||||||
|
mcp_tools = await get_tools_from_mcp(mcp_settings)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "warmed_up",
|
"status": "warmed_up",
|
||||||
"bot_id": bot_id,
|
"bot_id": bot_id,
|
||||||
"cache_key": cache_key,
|
"mcp_tools_count": len(mcp_tools),
|
||||||
"message": "Agent has been initialized and cached successfully"
|
"message": "MCP tools have been cached successfully"
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -435,18 +458,27 @@ async def chat_warmup_v2(request: ChatRequestV2, authorization: Optional[str] =
|
|||||||
# 创建 AgentConfig 对象
|
# 创建 AgentConfig 对象
|
||||||
config = AgentConfig.from_v2_request(request, bot_config, project_dir, messages)
|
config = AgentConfig.from_v2_request(request, bot_config, project_dir, messages)
|
||||||
|
|
||||||
# 预热:初始化agent(这会触发缓存)
|
# 预热 mcp_tools 缓存
|
||||||
logger.info(f"Warming up agent for bot_id: {bot_id}")
|
logger.info(f"Warming up mcp_tools for bot_id: {bot_id}")
|
||||||
agent = await init_agent(config)
|
from agent.deep_assistant import get_tools_from_mcp
|
||||||
|
from agent.prompt_loader import load_mcp_settings_async
|
||||||
|
|
||||||
# 获取缓存键
|
# 加载 mcp_settings
|
||||||
cache_key = config.get_unique_cache_id() if hasattr(config, 'get_unique_cache_id') else None
|
final_mcp_settings = await load_mcp_settings_async(
|
||||||
|
config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
|
||||||
|
)
|
||||||
|
mcp_settings = final_mcp_settings if final_mcp_settings else []
|
||||||
|
if not isinstance(mcp_settings, list) or len(mcp_settings) == 0:
|
||||||
|
mcp_settings = []
|
||||||
|
|
||||||
|
# 预热 mcp_tools(缓存逻辑已内置到 get_tools_from_mcp)
|
||||||
|
mcp_tools = await get_tools_from_mcp(mcp_settings)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "warmed_up",
|
"status": "warmed_up",
|
||||||
"bot_id": bot_id,
|
"bot_id": bot_id,
|
||||||
"cache_key": cache_key,
|
"mcp_tools_count": len(mcp_tools),
|
||||||
"message": "Agent has been initialized and cached successfully"
|
"message": "MCP tools have been cached successfully"
|
||||||
}
|
}
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@ -45,9 +45,12 @@ class Formatter(logging.Formatter):
|
|||||||
return s
|
return s
|
||||||
|
|
||||||
def format(self, record):
|
def format(self, record):
|
||||||
# 处理 trace_id
|
# 处理 trace_id - 在没有请求上下文时使用默认值
|
||||||
if not hasattr(record, "trace_id"):
|
if not hasattr(record, "trace_id"):
|
||||||
record.trace_id = getattr(g, "trace_id")
|
try:
|
||||||
|
record.trace_id = getattr(g, "trace_id")
|
||||||
|
except LookupError:
|
||||||
|
record.trace_id = "N/A"
|
||||||
# 处理 user_id
|
# 处理 user_id
|
||||||
# if not hasattr(record, "user_id"):
|
# if not hasattr(record, "user_id"):
|
||||||
# record.user_id = getattr(g, "user_id")
|
# record.user_id = getattr(g, "user_id")
|
||||||
|
|||||||
@ -37,4 +37,24 @@ DEFAULT_THINKING_ENABLE = os.getenv("DEFAULT_THINKING_ENABLE", "true") == "true"
|
|||||||
MCP_HTTP_TIMEOUT = int(os.getenv("MCP_HTTP_TIMEOUT", 60)) # HTTP 请求超时(秒)
|
MCP_HTTP_TIMEOUT = int(os.getenv("MCP_HTTP_TIMEOUT", 60)) # HTTP 请求超时(秒)
|
||||||
MCP_SSE_READ_TIMEOUT = int(os.getenv("MCP_SSE_READ_TIMEOUT", 300)) # SSE 读取超时(秒)
|
MCP_SSE_READ_TIMEOUT = int(os.getenv("MCP_SSE_READ_TIMEOUT", 300)) # SSE 读取超时(秒)
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# SQLite Checkpoint Configuration
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# Checkpoint 数据库路径
|
||||||
|
CHECKPOINT_DB_PATH = os.getenv("CHECKPOINT_DB_PATH", "./projects/memory/checkpoints.db")
|
||||||
|
|
||||||
|
# 启用 WAL 模式 (Write-Ahead Logging)
|
||||||
|
# WAL 模式允许读写并发,大幅提升并发性能
|
||||||
|
CHECKPOINT_WAL_MODE = os.getenv("CHECKPOINT_WAL_MODE", "true") == "true"
|
||||||
|
|
||||||
|
# Busy Timeout (毫秒)
|
||||||
|
# 当数据库被锁定时,等待的最长时间(毫秒)
|
||||||
|
CHECKPOINT_BUSY_TIMEOUT = int(os.getenv("CHECKPOINT_BUSY_TIMEOUT", "10000"))
|
||||||
|
|
||||||
|
|
||||||
|
# 连接池大小
|
||||||
|
# 同时可以持有的最大连接数
|
||||||
|
CHECKPOINT_POOL_SIZE = int(os.getenv("CHECKPOINT_POOL_SIZE", "30"))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user