deep_agent支持 checkpoint

This commit is contained in:
朱潮 2026-01-11 00:08:19 +08:00
parent b93c40d5a5
commit 174a5e2059
2 changed files with 48 additions and 25 deletions

View File

@ -30,6 +30,8 @@ from langchain_core.language_models import BaseChatModel
from langgraph.pregel import Pregel
from deepagents_cli.shell import ShellMiddleware
from deepagents_cli.agent_memory import AgentMemoryMiddleware
from langchain.agents.middleware import AgentMiddleware
from langgraph.types import Checkpointer
from deepagents_cli.skills import SkillsMiddleware
from deepagents_cli.config import settings, get_default_coding_instructions
import os
@ -161,6 +163,30 @@ async def init_agent(config: AgentConfig):
checkpointer = None
create_start = time.time()
# 构建中间件列表
middleware = []
# 首先添加 ToolUseCleanupMiddleware 来清理孤立的 tool_use
middleware.append(ToolUseCleanupMiddleware())
# 添加工具输出长度控制中间件
tool_output_middleware = ToolOutputLengthMiddleware(
max_length=getattr(config.generate_cfg, 'tool_output_max_length', None) if config.generate_cfg else None or TOOL_OUTPUT_MAX_LENGTH,
truncation_strategy=getattr(config.generate_cfg, 'tool_output_truncation_strategy', 'smart') if config.generate_cfg else 'smart',
tool_filters=getattr(config.generate_cfg, 'tool_output_filters', None) if config.generate_cfg else None,
exclude_tools=getattr(config.generate_cfg, 'tool_output_exclude', []) if config.generate_cfg else [],
preserve_code_blocks=getattr(config.generate_cfg, 'preserve_code_blocks', True) if config.generate_cfg else True,
preserve_json=getattr(config.generate_cfg, 'preserve_json', True) if config.generate_cfg else True
)
middleware.append(tool_output_middleware)
# 从连接池获取 checkpointer
if config.session_id:
from .checkpoint_manager import get_checkpointer_manager
manager = get_checkpointer_manager()
checkpointer = manager.checkpointer
await prepare_checkpoint_message(config, checkpointer)
if config.robot_type == "deep_agent":
# 使用 DeepAgentX 创建 agent自定义 workspace_root
workspace_root = f"projects/robot/{config.bot_id}"
@ -172,34 +198,16 @@ async def init_agent(config: AgentConfig):
tools=mcp_tools,
auto_approve=True,
enable_memory=False,
workspace_root=workspace_root
workspace_root=workspace_root,
middleware=middleware,
checkpointer=checkpointer
)
else:
# 构建中间件列表
middleware = []
# 首先添加 ToolUseCleanupMiddleware 来清理孤立的 tool_use
middleware.append(ToolUseCleanupMiddleware())
# 只有在 enable_thinking 为 True 时才添加 GuidelineMiddleware
if config.enable_thinking:
middleware.append(GuidelineMiddleware(llm_instance, config, system_prompt))
# 添加工具输出长度控制中间件
tool_output_middleware = ToolOutputLengthMiddleware(
max_length=getattr(config.generate_cfg, 'tool_output_max_length', None) if config.generate_cfg else None or TOOL_OUTPUT_MAX_LENGTH,
truncation_strategy=getattr(config.generate_cfg, 'tool_output_truncation_strategy', 'smart') if config.generate_cfg else 'smart',
tool_filters=getattr(config.generate_cfg, 'tool_output_filters', None) if config.generate_cfg else None,
exclude_tools=getattr(config.generate_cfg, 'tool_output_exclude', []) if config.generate_cfg else [],
preserve_code_blocks=getattr(config.generate_cfg, 'preserve_code_blocks', True) if config.generate_cfg else True,
preserve_json=getattr(config.generate_cfg, 'preserve_json', True) if config.generate_cfg else True
)
middleware.append(tool_output_middleware)
# 从连接池获取 checkpointer
if config.session_id:
from .checkpoint_manager import get_checkpointer_manager
manager = get_checkpointer_manager()
checkpointer = manager.checkpointer
await prepare_checkpoint_message(config, checkpointer)
summarization_middleware = SummarizationMiddleware(
model=llm_instance,
max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS,
@ -317,7 +325,9 @@ def create_custom_cli_agent(
enable_memory: bool = True,
enable_skills: bool = True,
enable_shell: bool = True,
middleware: list[AgentMiddleware] = [],
workspace_root: str | None = None,
checkpointer: Checkpointer | None = None,
) -> tuple[Pregel, CompositeBackend]:
"""Create a CLI-configured agent with custom workspace_root for shell commands.
@ -358,7 +368,7 @@ def create_custom_cli_agent(
agent_md.write_text(source_content)
# Build middleware stack based on enabled features
agent_middleware = []
agent_middleware = middleware
# CONDITIONAL SETUP: Local vs Remote Sandbox
if sandbox is None:
@ -453,6 +463,6 @@ def create_custom_cli_agent(
backend=composite_backend,
middleware=agent_middleware,
interrupt_on=interrupt_on,
checkpointer=InMemorySaver(),
checkpointer=checkpointer,
).with_config(config)
return agent, composite_backend

View File

@ -17,7 +17,7 @@ from utils.fastapi_utils import (
call_preamble_llm,
create_stream_chunk
)
from langchain_core.messages import AIMessageChunk, ToolMessage, AIMessage
from langchain_core.messages import AIMessageChunk, ToolMessage, AIMessage, HumanMessage
from utils.settings import MAX_OUTPUT_TOKENS
from agent.agent_config import AgentConfig
from agent.deep_assistant import init_agent
@ -201,7 +201,20 @@ async def create_agent_and_generate_response(
agent, checkpointer = await init_agent(config)
# 使用更新后的 messages
agent_responses = await agent.ainvoke({"messages": config.messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
append_messages = agent_responses["messages"][len(config.messages):]
# 从后往前找第一个 HumanMessage之后的内容都给 append_messages
all_messages = agent_responses["messages"]
first_human_idx = None
for i in range(len(all_messages) - 1, -1, -1):
if isinstance(all_messages[i], HumanMessage):
first_human_idx = i
break
if first_human_idx is not None:
append_messages = all_messages[first_human_idx + 1:]
else:
# 如果没找到 HumanMessage取所有消息
append_messages = all_messages
response_text = ""
for msg in append_messages:
if isinstance(msg,AIMessage):