Add checkpoint support for deep_agent

朱潮 2026-01-11 00:08:19 +08:00
parent b93c40d5a5
commit 174a5e2059
2 changed files with 48 additions and 25 deletions
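
The net effect of this commit: when an AgentConfig carries a session_id, init_agent pulls a checkpointer from the pooled checkpoint_manager and threads it, together with a shared middleware list, through both the deep_agent branch and the default branch, instead of create_custom_cli_agent always creating a fresh InMemorySaver(). A minimal usage sketch follows; the AgentConfig constructor fields are an assumption, since only the attributes read in this diff (robot_type, bot_id, session_id, messages, generate_cfg, enable_thinking) are visible.

# Hedged sketch only: AgentConfig keyword arguments are assumed from the
# attributes read in this diff; they may differ in agent/agent_config.py.
import asyncio

from agent.agent_config import AgentConfig
from agent.deep_assistant import init_agent

async def main():
    config = AgentConfig(
        robot_type="deep_agent",
        bot_id="demo-bot",                      # assumed field name
        session_id="session-123",               # non-empty -> checkpointer taken from the pool
        messages=[{"role": "user", "content": "hello"}],
    )
    agent, checkpointer = await init_agent(config)
    # invoke_config() is expected to carry the thread/session info the checkpointer keys on
    result = await agent.ainvoke({"messages": config.messages}, config=config.invoke_config())
    print(result["messages"][-1].content)

asyncio.run(main())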

View File

@@ -30,6 +30,8 @@ from langchain_core.language_models import BaseChatModel
 from langgraph.pregel import Pregel
 from deepagents_cli.shell import ShellMiddleware
 from deepagents_cli.agent_memory import AgentMemoryMiddleware
+from langchain.agents.middleware import AgentMiddleware
+from langgraph.types import Checkpointer
 from deepagents_cli.skills import SkillsMiddleware
 from deepagents_cli.config import settings, get_default_coding_instructions
 import os
@@ -161,6 +163,30 @@ async def init_agent(config: AgentConfig):
     checkpointer = None
     create_start = time.time()
+    # Build the middleware list
+    middleware = []
+    # Add ToolUseCleanupMiddleware first to clean up orphaned tool_use blocks
+    middleware.append(ToolUseCleanupMiddleware())
+    # Add the tool-output length-control middleware
+    tool_output_middleware = ToolOutputLengthMiddleware(
+        max_length=getattr(config.generate_cfg, 'tool_output_max_length', None) if config.generate_cfg else None or TOOL_OUTPUT_MAX_LENGTH,
+        truncation_strategy=getattr(config.generate_cfg, 'tool_output_truncation_strategy', 'smart') if config.generate_cfg else 'smart',
+        tool_filters=getattr(config.generate_cfg, 'tool_output_filters', None) if config.generate_cfg else None,
+        exclude_tools=getattr(config.generate_cfg, 'tool_output_exclude', []) if config.generate_cfg else [],
+        preserve_code_blocks=getattr(config.generate_cfg, 'preserve_code_blocks', True) if config.generate_cfg else True,
+        preserve_json=getattr(config.generate_cfg, 'preserve_json', True) if config.generate_cfg else True
+    )
+    middleware.append(tool_output_middleware)
+    # Get the checkpointer from the connection pool
+    if config.session_id:
+        from .checkpoint_manager import get_checkpointer_manager
+        manager = get_checkpointer_manager()
+        checkpointer = manager.checkpointer
+        await prepare_checkpoint_message(config, checkpointer)
     if config.robot_type == "deep_agent":
         # Create the agent with DeepAgentX, using a custom workspace_root
         workspace_root = f"projects/robot/{config.bot_id}"
@@ -172,34 +198,16 @@ async def init_agent(config: AgentConfig):
             tools=mcp_tools,
             auto_approve=True,
             enable_memory=False,
-            workspace_root=workspace_root
+            workspace_root=workspace_root,
+            middleware=middleware,
+            checkpointer=checkpointer
         )
     else:
-        # Build the middleware list
-        middleware = []
-        # Add ToolUseCleanupMiddleware first to clean up orphaned tool_use blocks
-        middleware.append(ToolUseCleanupMiddleware())
         # Only add GuidelineMiddleware when enable_thinking is True
         if config.enable_thinking:
             middleware.append(GuidelineMiddleware(llm_instance, config, system_prompt))
-        # Add the tool-output length-control middleware
-        tool_output_middleware = ToolOutputLengthMiddleware(
-            max_length=getattr(config.generate_cfg, 'tool_output_max_length', None) if config.generate_cfg else None or TOOL_OUTPUT_MAX_LENGTH,
-            truncation_strategy=getattr(config.generate_cfg, 'tool_output_truncation_strategy', 'smart') if config.generate_cfg else 'smart',
-            tool_filters=getattr(config.generate_cfg, 'tool_output_filters', None) if config.generate_cfg else None,
-            exclude_tools=getattr(config.generate_cfg, 'tool_output_exclude', []) if config.generate_cfg else [],
-            preserve_code_blocks=getattr(config.generate_cfg, 'preserve_code_blocks', True) if config.generate_cfg else True,
-            preserve_json=getattr(config.generate_cfg, 'preserve_json', True) if config.generate_cfg else True
-        )
-        middleware.append(tool_output_middleware)
-        # Get the checkpointer from the connection pool
         if config.session_id:
-            from .checkpoint_manager import get_checkpointer_manager
-            manager = get_checkpointer_manager()
-            checkpointer = manager.checkpointer
-            await prepare_checkpoint_message(config, checkpointer)
             summarization_middleware = SummarizationMiddleware(
                 model=llm_instance,
                 max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS,
@@ -317,7 +325,9 @@ def create_custom_cli_agent(
     enable_memory: bool = True,
     enable_skills: bool = True,
     enable_shell: bool = True,
+    middleware: list[AgentMiddleware] = [],
     workspace_root: str | None = None,
+    checkpointer: Checkpointer | None = None,
 ) -> tuple[Pregel, CompositeBackend]:
     """Create a CLI-configured agent with custom workspace_root for shell commands.
@@ -358,7 +368,7 @@ def create_custom_cli_agent(
         agent_md.write_text(source_content)
 
     # Build middleware stack based on enabled features
-    agent_middleware = []
+    agent_middleware = middleware
 
     # CONDITIONAL SETUP: Local vs Remote Sandbox
     if sandbox is None:
@@ -453,6 +463,6 @@ def create_custom_cli_agent(
         backend=composite_backend,
         middleware=agent_middleware,
         interrupt_on=interrupt_on,
-        checkpointer=InMemorySaver(),
+        checkpointer=checkpointer,
     ).with_config(config)
     return agent, composite_backend
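
For context on what passing checkpointer=checkpointer (instead of a hard-coded InMemorySaver()) buys: a LangGraph checkpointer persists graph state per thread_id, so a later invoke on the same thread resumes from the saved messages. The standalone sketch below is not from this repo; it only illustrates the LangGraph mechanism that config.session_id is presumably mapped onto.

# Generic LangGraph checkpointing demo (independent of this codebase).
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import StateGraph, MessagesState, START

def reply(state: MessagesState):
    # Echo how much history the checkpointer has accumulated for this thread
    return {"messages": [("ai", f"history so far: {len(state['messages'])} messages")]}

builder = StateGraph(MessagesState)
builder.add_node("reply", reply)
builder.add_edge(START, "reply")
graph = builder.compile(checkpointer=InMemorySaver())

cfg = {"configurable": {"thread_id": "session-123"}}      # analogous to config.session_id
graph.invoke({"messages": [("user", "hi")]}, cfg)
print(graph.invoke({"messages": [("user", "again")]}, cfg)["messages"][-1].content)

If session_id is empty, checkpointer stays None and the agent runs stateless, which matches the checkpointer = None default kept in init_agent.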

View File

@@ -17,7 +17,7 @@ from utils.fastapi_utils import (
     call_preamble_llm,
     create_stream_chunk
 )
-from langchain_core.messages import AIMessageChunk, ToolMessage, AIMessage
+from langchain_core.messages import AIMessageChunk, ToolMessage, AIMessage, HumanMessage
 from utils.settings import MAX_OUTPUT_TOKENS
 from agent.agent_config import AgentConfig
 from agent.deep_assistant import init_agent
@@ -201,7 +201,20 @@ async def create_agent_and_generate_response(
     agent, checkpointer = await init_agent(config)
     # Use the updated messages
     agent_responses = await agent.ainvoke({"messages": config.messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
-    append_messages = agent_responses["messages"][len(config.messages):]
+    # Scan backwards for the most recent HumanMessage; everything after it goes into append_messages
+    all_messages = agent_responses["messages"]
+    first_human_idx = None
+    for i in range(len(all_messages) - 1, -1, -1):
+        if isinstance(all_messages[i], HumanMessage):
+            first_human_idx = i
+            break
+    if first_human_idx is not None:
+        append_messages = all_messages[first_human_idx + 1:]
+    else:
+        # If no HumanMessage is found, keep all messages
+        append_messages = all_messages
     response_text = ""
     for msg in append_messages:
         if isinstance(msg,AIMessage):