sqlite pool and change agent cache to tools cache
This commit is contained in:
parent
09a9c8be93
commit
d8dc973b95
@ -2,10 +2,12 @@
|
||||
基于内存的 Agent 缓存管理模块
|
||||
使用 cachetools 库实现 TTLCache 和 LRUCache
|
||||
"""
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
from typing import Any, Optional, Dict, Tuple
|
||||
from typing import Any, Optional, Dict, Tuple, List
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
@ -310,6 +312,48 @@ class AgentMemoryCacheManager:
|
||||
"""返回缓存中的项数"""
|
||||
return len(self.cache)
|
||||
|
||||
def get_mcp_tools(self, mcp_settings: dict) -> Optional[List]:
|
||||
"""
|
||||
获取缓存的 MCP tools
|
||||
|
||||
Args:
|
||||
mcp_settings: MCP 配置字典
|
||||
|
||||
Returns:
|
||||
缓存的 tools 列表或 None
|
||||
"""
|
||||
cache_key = self._get_mcp_cache_key(mcp_settings)
|
||||
return self.get(cache_key)
|
||||
|
||||
def set_mcp_tools(self, mcp_settings: dict, tools: List, ttl: Optional[int] = None) -> bool:
|
||||
"""
|
||||
缓存 MCP tools
|
||||
|
||||
Args:
|
||||
mcp_settings: MCP 配置字典
|
||||
tools: 要缓存的 tools 列表
|
||||
ttl: 过期时间(秒),如果为 None 则使用默认值
|
||||
|
||||
Returns:
|
||||
是否成功设置缓存
|
||||
"""
|
||||
cache_key = self._get_mcp_cache_key(mcp_settings)
|
||||
return self.set(cache_key, tools, ttl=ttl)
|
||||
|
||||
def _get_mcp_cache_key(self, mcp_settings: dict) -> str:
|
||||
"""
|
||||
根据 mcp_settings 生成缓存键
|
||||
|
||||
Args:
|
||||
mcp_settings: MCP 配置字典
|
||||
|
||||
Returns:
|
||||
缓存键字符串
|
||||
"""
|
||||
# 将 mcp_settings 转换为 JSON 字符串并生成哈希
|
||||
settings_str = json.dumps(mcp_settings, sort_keys=True)
|
||||
return f"mcp_tools:{hashlib.md5(settings_str.encode()).hexdigest()}"
|
||||
|
||||
|
||||
# 全局缓存管理器实例
|
||||
_global_cache_manager: Optional[AgentMemoryCacheManager] = None
|
||||
|
||||
186
agent/checkpoint_manager.py
Normal file
186
agent/checkpoint_manager.py
Normal file
@ -0,0 +1,186 @@
|
||||
"""
|
||||
全局 SQLite Checkpointer 管理器
|
||||
解决高并发场景下的数据库锁定问题
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import aiosqlite
|
||||
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||||
|
||||
from utils.settings import (
|
||||
CHECKPOINT_DB_PATH,
|
||||
CHECKPOINT_WAL_MODE,
|
||||
CHECKPOINT_BUSY_TIMEOUT,
|
||||
CHECKPOINT_POOL_SIZE,
|
||||
)
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
|
||||
class CheckpointerManager:
|
||||
"""
|
||||
全局 Checkpointer 管理器,使用连接池复用 SQLite 连接
|
||||
|
||||
主要功能:
|
||||
1. 全局单例连接管理,避免每次请求创建新连接
|
||||
2. 预配置 WAL 模式和 busy_timeout
|
||||
3. 连接池支持高并发访问
|
||||
4. 优雅关闭机制
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._pool: asyncio.Queue[AsyncSqliteSaver] = asyncio.Queue()
|
||||
self._lock = asyncio.Lock()
|
||||
self._initialized = False
|
||||
self._closed = False
|
||||
self._pool_size = CHECKPOINT_POOL_SIZE
|
||||
self._db_path = CHECKPOINT_DB_PATH
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化连接池"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
async with self._lock:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
logger.info(f"Initializing CheckpointerManager with pool_size={self._pool_size}")
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(self._db_path), exist_ok=True)
|
||||
|
||||
# 创建连接池
|
||||
for i in range(self._pool_size):
|
||||
try:
|
||||
conn = await self._create_configured_connection()
|
||||
checkpointer = AsyncSqliteSaver(conn=conn)
|
||||
# 预先调用 setup 确保表结构已创建
|
||||
await checkpointer.setup()
|
||||
await self._pool.put(checkpointer)
|
||||
logger.debug(f"Created checkpointer connection {i+1}/{self._pool_size}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create checkpointer connection {i+1}: {e}")
|
||||
raise
|
||||
|
||||
self._initialized = True
|
||||
logger.info("CheckpointerManager initialized successfully")
|
||||
|
||||
async def _create_configured_connection(self) -> aiosqlite.Connection:
|
||||
"""
|
||||
创建已配置的 SQLite 连接
|
||||
|
||||
配置包括:
|
||||
1. WAL 模式 (Write-Ahead Logging) - 允许读写并发
|
||||
2. busy_timeout - 等待锁定的最长时间
|
||||
3. 其他优化参数
|
||||
"""
|
||||
conn = aiosqlite.connect(self._db_path)
|
||||
|
||||
# 等待连接建立
|
||||
await conn.__aenter__()
|
||||
|
||||
# 设置 busy timeout(必须在连接建立后设置)
|
||||
await conn.execute(f"PRAGMA busy_timeout = {CHECKPOINT_BUSY_TIMEOUT}")
|
||||
|
||||
# 如果启用 WAL 模式
|
||||
if CHECKPOINT_WAL_MODE:
|
||||
await conn.execute("PRAGMA journal_mode = WAL")
|
||||
await conn.execute("PRAGMA synchronous = NORMAL")
|
||||
# WAL 模式下的优化配置
|
||||
await conn.execute("PRAGMA wal_autocheckpoint = 1000")
|
||||
await conn.execute("PRAGMA cache_size = -64000") # 64MB 缓存
|
||||
await conn.execute("PRAGMA temp_store = MEMORY")
|
||||
|
||||
await conn.commit()
|
||||
|
||||
return conn
|
||||
|
||||
async def acquire_for_agent(self) -> AsyncSqliteSaver:
|
||||
"""
|
||||
为 agent 获取 checkpointer
|
||||
|
||||
注意:此方法获取的 checkpointer 需要手动归还
|
||||
使用 return_to_pool() 方法归还
|
||||
|
||||
Returns:
|
||||
AsyncSqliteSaver 实例
|
||||
"""
|
||||
if not self._initialized:
|
||||
raise RuntimeError("CheckpointerManager not initialized. Call initialize() first.")
|
||||
|
||||
checkpointer = await self._pool.get()
|
||||
logger.debug(f"Acquired checkpointer from pool, remaining: {self._pool.qsize()}")
|
||||
return checkpointer
|
||||
|
||||
async def return_to_pool(self, checkpointer: AsyncSqliteSaver) -> None:
|
||||
"""
|
||||
归还 checkpointer 到池
|
||||
|
||||
Args:
|
||||
checkpointer: 要归还的 checkpointer 实例
|
||||
"""
|
||||
await self._pool.put(checkpointer)
|
||||
logger.debug(f"Returned checkpointer to pool, remaining: {self._pool.qsize()}")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭所有连接"""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
async with self._lock:
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
logger.info("Closing CheckpointerManager...")
|
||||
|
||||
# 清空池并关闭所有连接
|
||||
while not self._pool.empty():
|
||||
try:
|
||||
checkpointer = self._pool.get_nowait()
|
||||
if checkpointer.conn:
|
||||
await checkpointer.conn.close()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
self._closed = True
|
||||
self._initialized = False
|
||||
logger.info("CheckpointerManager closed")
|
||||
|
||||
def get_pool_stats(self) -> dict:
|
||||
"""获取连接池状态统计"""
|
||||
return {
|
||||
"db_path": self._db_path,
|
||||
"pool_size": self._pool_size,
|
||||
"available_connections": self._pool.qsize(),
|
||||
"initialized": self._initialized,
|
||||
"closed": self._closed
|
||||
}
|
||||
|
||||
|
||||
# 全局单例
|
||||
_global_manager: Optional[CheckpointerManager] = None
|
||||
|
||||
|
||||
def get_checkpointer_manager() -> CheckpointerManager:
|
||||
"""获取全局 CheckpointerManager 单例"""
|
||||
global _global_manager
|
||||
if _global_manager is None:
|
||||
_global_manager = CheckpointerManager()
|
||||
return _global_manager
|
||||
|
||||
|
||||
async def init_global_checkpointer() -> None:
|
||||
"""初始化全局 checkpointer 管理器"""
|
||||
manager = get_checkpointer_manager()
|
||||
await manager.initialize()
|
||||
|
||||
|
||||
async def close_global_checkpointer() -> None:
|
||||
"""关闭全局 checkpointer 管理器"""
|
||||
global _global_manager
|
||||
if _global_manager is not None:
|
||||
await _global_manager.close()
|
||||
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import copy
|
||||
from typing import Any, Dict
|
||||
from langchain.chat_models import init_chat_model
|
||||
# from deepagents import create_deep_agent
|
||||
@ -41,14 +42,26 @@ def read_mcp_settings():
|
||||
|
||||
|
||||
async def get_tools_from_mcp(mcp):
|
||||
"""从MCP配置中提取工具"""
|
||||
"""从MCP配置中提取工具(带缓存)"""
|
||||
start_time = time.time()
|
||||
# 防御式处理:确保 mcp 是列表且长度大于 0,且包含 mcpServers
|
||||
if not isinstance(mcp, list) or len(mcp) == 0 or "mcpServers" not in mcp[0]:
|
||||
logger.info(f"get_tools_from_mcp: invalid mcp config, elapsed: {time.time() - start_time:.3f}s")
|
||||
return []
|
||||
|
||||
# 修改 mcp[0]["mcpServers"] 列表,把 type 字段改成 transport
|
||||
# 尝试从缓存获取
|
||||
cache_manager = get_memory_cache_manager()
|
||||
cached_tools = cache_manager.get_mcp_tools(mcp)
|
||||
if cached_tools is not None:
|
||||
logger.info(f"get_tools_from_mcp: cached {len(cached_tools)} tools, elapsed: {time.time() - start_time:.3f}s")
|
||||
return cached_tools
|
||||
|
||||
# 深拷贝 mcp 配置,避免修改原始配置(影响缓存键)
|
||||
mcp_copy = copy.deepcopy(mcp)
|
||||
|
||||
# 修改 mcp_copy[0]["mcpServers"] 列表,把 type 字段改成 transport
|
||||
# 如果没有 transport,则根据是否存在 url 默认 transport 为 http 或 stdio
|
||||
for cfg in mcp[0]["mcpServers"].values():
|
||||
for cfg in mcp_copy[0]["mcpServers"].values():
|
||||
if "type" in cfg:
|
||||
cfg.pop("type")
|
||||
if "transport" not in cfg:
|
||||
@ -62,53 +75,49 @@ async def get_tools_from_mcp(mcp):
|
||||
if "sse_read_timeout" not in cfg:
|
||||
cfg["sse_read_timeout"] = MCP_SSE_READ_TIMEOUT
|
||||
|
||||
# 确保 mcp[0]["mcpServers"] 是字典类型
|
||||
if not isinstance(mcp[0]["mcpServers"], dict):
|
||||
# 确保 mcp_copy[0]["mcpServers"] 是字典类型
|
||||
if not isinstance(mcp_copy[0]["mcpServers"], dict):
|
||||
logger.info(f"get_tools_from_mcp: mcpServers is not dict, elapsed: {time.time() - start_time:.3f}s")
|
||||
return []
|
||||
|
||||
try:
|
||||
mcp_client = MultiServerMCPClient(mcp[0]["mcpServers"])
|
||||
mcp_client = MultiServerMCPClient(mcp_copy[0]["mcpServers"])
|
||||
mcp_tools = await mcp_client.get_tools()
|
||||
|
||||
# 缓存结果
|
||||
cache_manager.set_mcp_tools(mcp, mcp_tools)
|
||||
|
||||
logger.info(f"get_tools_from_mcp: loaded {len(mcp_tools)} tools, elapsed: {time.time() - start_time:.3f}s")
|
||||
return mcp_tools
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# 发生异常时返回空列表,避免上层调用报错
|
||||
logger.info(f"get_tools_from_mcp: error {e}, elapsed: {time.time() - start_time:.3f}s")
|
||||
return []
|
||||
|
||||
async def init_agent(config: AgentConfig):
|
||||
"""
|
||||
初始化 Agent,支持持久化内存和对话摘要
|
||||
|
||||
注意:不再缓存 agent,只缓存 mcp_tools
|
||||
返回 (agent, checkpointer) 元组,调用后需要归还 checkpointer
|
||||
|
||||
Args:
|
||||
config: AgentConfig 对象,包含所有初始化参数
|
||||
mcp: MCP配置(如果为None则使用配置中的mcp_settings)
|
||||
|
||||
Returns:
|
||||
(agent, checkpointer) 元组
|
||||
"""
|
||||
|
||||
# 初始化 checkpointer 和中间件
|
||||
|
||||
# 从连接池获取 checkpointer
|
||||
checkpointer = None
|
||||
if config.session_id:
|
||||
os.makedirs("projects/memory", exist_ok=True)
|
||||
conn = aiosqlite.connect("projects/memory/checkpoints.db")
|
||||
checkpointer = AsyncSqliteSaver(conn=conn)
|
||||
from .checkpoint_manager import get_checkpointer_manager
|
||||
manager = get_checkpointer_manager()
|
||||
checkpointer = await manager.acquire_for_agent()
|
||||
await prepare_checkpoint_message(config, checkpointer)
|
||||
# 获取缓存管理器
|
||||
cache_manager = get_memory_cache_manager()
|
||||
|
||||
# 获取唯一的缓存键
|
||||
cache_key = config.get_unique_cache_id()
|
||||
|
||||
# 如果有缓存键,检查缓存
|
||||
if cache_key:
|
||||
# 尝试从缓存中获取 agent
|
||||
cached_agent = cache_manager.get(cache_key)
|
||||
if cached_agent is not None:
|
||||
logger.info(f"Using cached agent for session: {config.session_id}, cache_key: {cache_key}")
|
||||
return cached_agent
|
||||
else:
|
||||
logger.info(f"Cache miss for session: {config.session_id}, cache_key: {cache_key}")
|
||||
|
||||
# 没有缓存或缓存已过期,创建新的 agent
|
||||
logger.info(f"Creating new agent for session: {getattr(config, 'session_id', 'no-session')}")
|
||||
|
||||
# 加载配置
|
||||
final_system_prompt = await load_system_prompt_async(
|
||||
config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier
|
||||
)
|
||||
@ -123,10 +132,11 @@ async def init_agent(config: AgentConfig):
|
||||
config.system_prompt = mcp_settings
|
||||
config.mcp_settings = system_prompt
|
||||
|
||||
# 获取 mcp_tools(缓存逻辑已内置到 get_tools_from_mcp 中)
|
||||
mcp_tools = await get_tools_from_mcp(mcp_settings)
|
||||
|
||||
# 检测或使用指定的提供商
|
||||
model_provider,base_url = detect_provider(config.model_name, config.model_server)
|
||||
model_provider, base_url = detect_provider(config.model_name, config.model_server)
|
||||
# 构建模型参数
|
||||
model_kwargs = {
|
||||
"model": config.model_name,
|
||||
@ -139,6 +149,10 @@ async def init_agent(config: AgentConfig):
|
||||
model_kwargs.update(config.generate_cfg)
|
||||
llm_instance = init_chat_model(**model_kwargs)
|
||||
|
||||
# 创建新的 agent(不再缓存)
|
||||
logger.info(f"Creating new agent for session: {getattr(config, 'session_id', 'no-session')}")
|
||||
|
||||
create_start = time.time()
|
||||
if config.robot_type == "deep_agent":
|
||||
# 使用 DeepAgentX 创建 agent
|
||||
agent, composite_backend = create_cli_agent(
|
||||
@ -160,8 +174,8 @@ async def init_agent(config: AgentConfig):
|
||||
tool_output_middleware = ToolOutputLengthMiddleware(
|
||||
max_length=getattr(config.generate_cfg, 'tool_output_max_length', None) if config.generate_cfg else None or TOOL_OUTPUT_MAX_LENGTH,
|
||||
truncation_strategy=getattr(config.generate_cfg, 'tool_output_truncation_strategy', 'smart') if config.generate_cfg else 'smart',
|
||||
tool_filters=getattr(config.generate_cfg, 'tool_output_filters', None) if config.generate_cfg else None, # 可配置特定工具
|
||||
exclude_tools=getattr(config.generate_cfg, 'tool_output_exclude', []) if config.generate_cfg else [], # 排除的工具
|
||||
tool_filters=getattr(config.generate_cfg, 'tool_output_filters', None) if config.generate_cfg else None,
|
||||
exclude_tools=getattr(config.generate_cfg, 'tool_output_exclude', []) if config.generate_cfg else [],
|
||||
preserve_code_blocks=getattr(config.generate_cfg, 'preserve_code_blocks', True) if config.generate_cfg else True,
|
||||
preserve_json=getattr(config.generate_cfg, 'preserve_json', True) if config.generate_cfg else True
|
||||
)
|
||||
@ -171,7 +185,7 @@ async def init_agent(config: AgentConfig):
|
||||
summarization_middleware = SummarizationMiddleware(
|
||||
model=llm_instance,
|
||||
max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS,
|
||||
messages_to_keep=20, # 摘要后保留最近 20 条消息
|
||||
messages_to_keep=20,
|
||||
summary_prompt="请简洁地总结以上对话的要点,包括重要的用户信息、讨论过的话题和关键结论。"
|
||||
)
|
||||
middleware.append(summarization_middleware)
|
||||
@ -181,13 +195,7 @@ async def init_agent(config: AgentConfig):
|
||||
system_prompt=system_prompt,
|
||||
tools=mcp_tools,
|
||||
middleware=middleware,
|
||||
checkpointer=checkpointer # 传入 checkpointer 以启用持久化
|
||||
checkpointer=checkpointer
|
||||
)
|
||||
|
||||
# 如果有缓存键,将 agent 加入缓存
|
||||
if cache_key:
|
||||
# 使用 DiskCache 缓存管理器存储 agent
|
||||
cache_manager.set(cache_key, agent)
|
||||
logger.info(f"Cached agent for session: {config.session_id}, cache_key: {cache_key}")
|
||||
|
||||
return agent
|
||||
logger.info(f"create {config.robot_type} elapsed: {time.time() - create_start:.3f}s")
|
||||
return agent, checkpointer
|
||||
@ -4,6 +4,7 @@ import uuid
|
||||
import time
|
||||
import multiprocessing
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
@ -19,7 +20,26 @@ from utils.log_util.logger import init_with_fastapi
|
||||
# Import route modules
|
||||
from routes import chat, files, projects, system
|
||||
|
||||
app = FastAPI(title="Database Assistant API", version="1.0.0")
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""FastAPI 应用生命周期管理"""
|
||||
# 启动时初始化
|
||||
logger.info("Starting up...")
|
||||
from agent.checkpoint_manager import init_global_checkpointer
|
||||
await init_global_checkpointer()
|
||||
logger.info("Global checkpointer initialized")
|
||||
|
||||
yield
|
||||
|
||||
# 关闭时清理
|
||||
logger.info("Shutting down...")
|
||||
from agent.checkpoint_manager import close_global_checkpointer
|
||||
await close_global_checkpointer()
|
||||
logger.info("Global checkpointer closed")
|
||||
|
||||
|
||||
app = FastAPI(title="Database Assistant API", version="1.0.0", lifespan=lifespan)
|
||||
|
||||
init_with_fastapi(app)
|
||||
|
||||
|
||||
@ -64,12 +64,13 @@ async def enhanced_generate_stream_response(
|
||||
|
||||
# Agent 任务(准备 + 流式处理)
|
||||
async def agent_task():
|
||||
checkpointer = None
|
||||
try:
|
||||
# 开始流式处理
|
||||
logger.info(f"Starting agent stream response")
|
||||
chunk_id = 0
|
||||
message_tag = ""
|
||||
agent = await init_agent(config)
|
||||
agent, checkpointer = await init_agent(config)
|
||||
async for msg, metadata in agent.astream({"messages": config.messages}, stream_mode="messages", config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS):
|
||||
new_content = ""
|
||||
|
||||
@ -115,6 +116,11 @@ async def enhanced_generate_stream_response(
|
||||
except Exception as e:
|
||||
logger.error(f"Error in agent task: {e}")
|
||||
await output_queue.put(("agent_done", None))
|
||||
finally:
|
||||
if checkpointer:
|
||||
from agent.checkpoint_manager import get_checkpointer_manager
|
||||
manager = get_checkpointer_manager()
|
||||
await manager.return_to_pool(checkpointer)
|
||||
|
||||
# 并发执行任务
|
||||
# 只有在 enable_thinking 为 True 时才执行 preamble 任务
|
||||
@ -203,7 +209,8 @@ async def create_agent_and_generate_response(
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
|
||||
)
|
||||
|
||||
agent = await init_agent(config)
|
||||
agent, checkpointer = await init_agent(config)
|
||||
try:
|
||||
# 使用更新后的 messages
|
||||
agent_responses = await agent.ainvoke({"messages": config.messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
|
||||
append_messages = agent_responses["messages"][len(config.messages):]
|
||||
@ -212,7 +219,7 @@ async def create_agent_and_generate_response(
|
||||
if isinstance(msg,AIMessage):
|
||||
if len(msg.text)>0:
|
||||
meta_message_tag = msg.additional_kwargs.get("message_tag", "ANSWER")
|
||||
output_text = msg.text.replace("<think>","").replace("</think>","") if meta_message_tag == "THINK" else msg.text
|
||||
output_text = msg.text.replace("````","").replace("````","") if meta_message_tag == "THINK" else msg.text
|
||||
response_text += f"[{meta_message_tag}]\n"+output_text+ "\n"
|
||||
if len(msg.tool_calls)>0:
|
||||
response_text += "".join([f"[TOOL_CALL] {tool['name']}\n{json.dumps(tool["args"]) if isinstance(tool["args"],dict) else tool["args"]}\n" for tool in msg.tool_calls])
|
||||
@ -221,7 +228,7 @@ async def create_agent_and_generate_response(
|
||||
|
||||
if len(response_text) > 0:
|
||||
# 构造OpenAI格式的响应
|
||||
return ChatResponse(
|
||||
result = ChatResponse(
|
||||
choices=[{
|
||||
"index": 0,
|
||||
"message": {
|
||||
@ -238,6 +245,13 @@ async def create_agent_and_generate_response(
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="No response from agent")
|
||||
finally:
|
||||
if checkpointer:
|
||||
from agent.checkpoint_manager import get_checkpointer_manager
|
||||
manager = get_checkpointer_manager()
|
||||
await manager.return_to_pool(checkpointer)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/api/v1/chat/completions")
|
||||
@ -348,18 +362,27 @@ async def chat_warmup_v1(request: ChatRequest, authorization: Optional[str] = He
|
||||
# 创建 AgentConfig 对象
|
||||
config = AgentConfig.from_v1_request(request, api_key, project_dir, generate_cfg, messages)
|
||||
|
||||
# 预热:初始化agent(这会触发缓存)
|
||||
logger.info(f"Warming up agent for bot_id: {bot_id}")
|
||||
agent = await init_agent(config)
|
||||
# 预热 mcp_tools 缓存
|
||||
logger.info(f"Warming up mcp_tools for bot_id: {bot_id}")
|
||||
from agent.deep_assistant import get_tools_from_mcp
|
||||
from agent.prompt_loader import load_mcp_settings_async
|
||||
|
||||
# 获取缓存键
|
||||
cache_key = config.get_unique_cache_id() if hasattr(config, 'get_unique_cache_id') else None
|
||||
# 加载 mcp_settings
|
||||
final_mcp_settings = await load_mcp_settings_async(
|
||||
config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
|
||||
)
|
||||
mcp_settings = final_mcp_settings if final_mcp_settings else []
|
||||
if not isinstance(mcp_settings, list) or len(mcp_settings) == 0:
|
||||
mcp_settings = []
|
||||
|
||||
# 预热 mcp_tools(缓存逻辑已内置到 get_tools_from_mcp)
|
||||
mcp_tools = await get_tools_from_mcp(mcp_settings)
|
||||
|
||||
return {
|
||||
"status": "warmed_up",
|
||||
"bot_id": bot_id,
|
||||
"cache_key": cache_key,
|
||||
"message": "Agent has been initialized and cached successfully"
|
||||
"mcp_tools_count": len(mcp_tools),
|
||||
"message": "MCP tools have been cached successfully"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@ -435,18 +458,27 @@ async def chat_warmup_v2(request: ChatRequestV2, authorization: Optional[str] =
|
||||
# 创建 AgentConfig 对象
|
||||
config = AgentConfig.from_v2_request(request, bot_config, project_dir, messages)
|
||||
|
||||
# 预热:初始化agent(这会触发缓存)
|
||||
logger.info(f"Warming up agent for bot_id: {bot_id}")
|
||||
agent = await init_agent(config)
|
||||
# 预热 mcp_tools 缓存
|
||||
logger.info(f"Warming up mcp_tools for bot_id: {bot_id}")
|
||||
from agent.deep_assistant import get_tools_from_mcp
|
||||
from agent.prompt_loader import load_mcp_settings_async
|
||||
|
||||
# 获取缓存键
|
||||
cache_key = config.get_unique_cache_id() if hasattr(config, 'get_unique_cache_id') else None
|
||||
# 加载 mcp_settings
|
||||
final_mcp_settings = await load_mcp_settings_async(
|
||||
config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
|
||||
)
|
||||
mcp_settings = final_mcp_settings if final_mcp_settings else []
|
||||
if not isinstance(mcp_settings, list) or len(mcp_settings) == 0:
|
||||
mcp_settings = []
|
||||
|
||||
# 预热 mcp_tools(缓存逻辑已内置到 get_tools_from_mcp)
|
||||
mcp_tools = await get_tools_from_mcp(mcp_settings)
|
||||
|
||||
return {
|
||||
"status": "warmed_up",
|
||||
"bot_id": bot_id,
|
||||
"cache_key": cache_key,
|
||||
"message": "Agent has been initialized and cached successfully"
|
||||
"mcp_tools_count": len(mcp_tools),
|
||||
"message": "MCP tools have been cached successfully"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
|
||||
@ -45,9 +45,12 @@ class Formatter(logging.Formatter):
|
||||
return s
|
||||
|
||||
def format(self, record):
|
||||
# 处理 trace_id
|
||||
# 处理 trace_id - 在没有请求上下文时使用默认值
|
||||
if not hasattr(record, "trace_id"):
|
||||
try:
|
||||
record.trace_id = getattr(g, "trace_id")
|
||||
except LookupError:
|
||||
record.trace_id = "N/A"
|
||||
# 处理 user_id
|
||||
# if not hasattr(record, "user_id"):
|
||||
# record.user_id = getattr(g, "user_id")
|
||||
|
||||
@ -37,4 +37,24 @@ DEFAULT_THINKING_ENABLE = os.getenv("DEFAULT_THINKING_ENABLE", "true") == "true"
|
||||
MCP_HTTP_TIMEOUT = int(os.getenv("MCP_HTTP_TIMEOUT", 60)) # HTTP 请求超时(秒)
|
||||
MCP_SSE_READ_TIMEOUT = int(os.getenv("MCP_SSE_READ_TIMEOUT", 300)) # SSE 读取超时(秒)
|
||||
|
||||
# ============================================================
|
||||
# SQLite Checkpoint Configuration
|
||||
# ============================================================
|
||||
|
||||
# Checkpoint 数据库路径
|
||||
CHECKPOINT_DB_PATH = os.getenv("CHECKPOINT_DB_PATH", "./projects/memory/checkpoints.db")
|
||||
|
||||
# 启用 WAL 模式 (Write-Ahead Logging)
|
||||
# WAL 模式允许读写并发,大幅提升并发性能
|
||||
CHECKPOINT_WAL_MODE = os.getenv("CHECKPOINT_WAL_MODE", "true") == "true"
|
||||
|
||||
# Busy Timeout (毫秒)
|
||||
# 当数据库被锁定时,等待的最长时间(毫秒)
|
||||
CHECKPOINT_BUSY_TIMEOUT = int(os.getenv("CHECKPOINT_BUSY_TIMEOUT", "10000"))
|
||||
|
||||
|
||||
# 连接池大小
|
||||
# 同时可以持有的最大连接数
|
||||
CHECKPOINT_POOL_SIZE = int(os.getenv("CHECKPOINT_POOL_SIZE", "30"))
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user