deep_agent支持 checkpoint
This commit is contained in:
parent
b93c40d5a5
commit
174a5e2059
@ -30,6 +30,8 @@ from langchain_core.language_models import BaseChatModel
|
||||
from langgraph.pregel import Pregel
|
||||
from deepagents_cli.shell import ShellMiddleware
|
||||
from deepagents_cli.agent_memory import AgentMemoryMiddleware
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.types import Checkpointer
|
||||
from deepagents_cli.skills import SkillsMiddleware
|
||||
from deepagents_cli.config import settings, get_default_coding_instructions
|
||||
import os
|
||||
@ -161,6 +163,30 @@ async def init_agent(config: AgentConfig):
|
||||
|
||||
checkpointer = None
|
||||
create_start = time.time()
|
||||
|
||||
# 构建中间件列表
|
||||
middleware = []
|
||||
# 首先添加 ToolUseCleanupMiddleware 来清理孤立的 tool_use
|
||||
middleware.append(ToolUseCleanupMiddleware())
|
||||
# 添加工具输出长度控制中间件
|
||||
tool_output_middleware = ToolOutputLengthMiddleware(
|
||||
max_length=getattr(config.generate_cfg, 'tool_output_max_length', None) if config.generate_cfg else None or TOOL_OUTPUT_MAX_LENGTH,
|
||||
truncation_strategy=getattr(config.generate_cfg, 'tool_output_truncation_strategy', 'smart') if config.generate_cfg else 'smart',
|
||||
tool_filters=getattr(config.generate_cfg, 'tool_output_filters', None) if config.generate_cfg else None,
|
||||
exclude_tools=getattr(config.generate_cfg, 'tool_output_exclude', []) if config.generate_cfg else [],
|
||||
preserve_code_blocks=getattr(config.generate_cfg, 'preserve_code_blocks', True) if config.generate_cfg else True,
|
||||
preserve_json=getattr(config.generate_cfg, 'preserve_json', True) if config.generate_cfg else True
|
||||
)
|
||||
middleware.append(tool_output_middleware)
|
||||
|
||||
# 从连接池获取 checkpointer
|
||||
if config.session_id:
|
||||
from .checkpoint_manager import get_checkpointer_manager
|
||||
manager = get_checkpointer_manager()
|
||||
checkpointer = manager.checkpointer
|
||||
await prepare_checkpoint_message(config, checkpointer)
|
||||
|
||||
|
||||
if config.robot_type == "deep_agent":
|
||||
# 使用 DeepAgentX 创建 agent,自定义 workspace_root
|
||||
workspace_root = f"projects/robot/{config.bot_id}"
|
||||
@ -172,34 +198,16 @@ async def init_agent(config: AgentConfig):
|
||||
tools=mcp_tools,
|
||||
auto_approve=True,
|
||||
enable_memory=False,
|
||||
workspace_root=workspace_root
|
||||
workspace_root=workspace_root,
|
||||
middleware=middleware,
|
||||
checkpointer=checkpointer
|
||||
)
|
||||
else:
|
||||
# 构建中间件列表
|
||||
middleware = []
|
||||
# 首先添加 ToolUseCleanupMiddleware 来清理孤立的 tool_use
|
||||
middleware.append(ToolUseCleanupMiddleware())
|
||||
# 只有在 enable_thinking 为 True 时才添加 GuidelineMiddleware
|
||||
if config.enable_thinking:
|
||||
middleware.append(GuidelineMiddleware(llm_instance, config, system_prompt))
|
||||
|
||||
# 添加工具输出长度控制中间件
|
||||
tool_output_middleware = ToolOutputLengthMiddleware(
|
||||
max_length=getattr(config.generate_cfg, 'tool_output_max_length', None) if config.generate_cfg else None or TOOL_OUTPUT_MAX_LENGTH,
|
||||
truncation_strategy=getattr(config.generate_cfg, 'tool_output_truncation_strategy', 'smart') if config.generate_cfg else 'smart',
|
||||
tool_filters=getattr(config.generate_cfg, 'tool_output_filters', None) if config.generate_cfg else None,
|
||||
exclude_tools=getattr(config.generate_cfg, 'tool_output_exclude', []) if config.generate_cfg else [],
|
||||
preserve_code_blocks=getattr(config.generate_cfg, 'preserve_code_blocks', True) if config.generate_cfg else True,
|
||||
preserve_json=getattr(config.generate_cfg, 'preserve_json', True) if config.generate_cfg else True
|
||||
)
|
||||
middleware.append(tool_output_middleware)
|
||||
|
||||
# 从连接池获取 checkpointer
|
||||
if config.session_id:
|
||||
from .checkpoint_manager import get_checkpointer_manager
|
||||
manager = get_checkpointer_manager()
|
||||
checkpointer = manager.checkpointer
|
||||
await prepare_checkpoint_message(config, checkpointer)
|
||||
summarization_middleware = SummarizationMiddleware(
|
||||
model=llm_instance,
|
||||
max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS,
|
||||
@ -317,7 +325,9 @@ def create_custom_cli_agent(
|
||||
enable_memory: bool = True,
|
||||
enable_skills: bool = True,
|
||||
enable_shell: bool = True,
|
||||
middleware: list[AgentMiddleware] = [],
|
||||
workspace_root: str | None = None,
|
||||
checkpointer: Checkpointer | None = None,
|
||||
) -> tuple[Pregel, CompositeBackend]:
|
||||
"""Create a CLI-configured agent with custom workspace_root for shell commands.
|
||||
|
||||
@ -358,7 +368,7 @@ def create_custom_cli_agent(
|
||||
agent_md.write_text(source_content)
|
||||
|
||||
# Build middleware stack based on enabled features
|
||||
agent_middleware = []
|
||||
agent_middleware = middleware
|
||||
|
||||
# CONDITIONAL SETUP: Local vs Remote Sandbox
|
||||
if sandbox is None:
|
||||
@ -453,6 +463,6 @@ def create_custom_cli_agent(
|
||||
backend=composite_backend,
|
||||
middleware=agent_middleware,
|
||||
interrupt_on=interrupt_on,
|
||||
checkpointer=InMemorySaver(),
|
||||
checkpointer=checkpointer,
|
||||
).with_config(config)
|
||||
return agent, composite_backend
|
||||
@ -17,7 +17,7 @@ from utils.fastapi_utils import (
|
||||
call_preamble_llm,
|
||||
create_stream_chunk
|
||||
)
|
||||
from langchain_core.messages import AIMessageChunk, ToolMessage, AIMessage
|
||||
from langchain_core.messages import AIMessageChunk, ToolMessage, AIMessage, HumanMessage
|
||||
from utils.settings import MAX_OUTPUT_TOKENS
|
||||
from agent.agent_config import AgentConfig
|
||||
from agent.deep_assistant import init_agent
|
||||
@ -201,7 +201,20 @@ async def create_agent_and_generate_response(
|
||||
agent, checkpointer = await init_agent(config)
|
||||
# 使用更新后的 messages
|
||||
agent_responses = await agent.ainvoke({"messages": config.messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
|
||||
append_messages = agent_responses["messages"][len(config.messages):]
|
||||
|
||||
# 从后往前找第一个 HumanMessage,之后的内容都给 append_messages
|
||||
all_messages = agent_responses["messages"]
|
||||
first_human_idx = None
|
||||
for i in range(len(all_messages) - 1, -1, -1):
|
||||
if isinstance(all_messages[i], HumanMessage):
|
||||
first_human_idx = i
|
||||
break
|
||||
|
||||
if first_human_idx is not None:
|
||||
append_messages = all_messages[first_human_idx + 1:]
|
||||
else:
|
||||
# 如果没找到 HumanMessage,取所有消息
|
||||
append_messages = all_messages
|
||||
response_text = ""
|
||||
for msg in append_messages:
|
||||
if isinstance(msg,AIMessage):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user