deep_agent支持 checkpoint
This commit is contained in:
parent
b93c40d5a5
commit
174a5e2059
@ -30,6 +30,8 @@ from langchain_core.language_models import BaseChatModel
|
|||||||
from langgraph.pregel import Pregel
|
from langgraph.pregel import Pregel
|
||||||
from deepagents_cli.shell import ShellMiddleware
|
from deepagents_cli.shell import ShellMiddleware
|
||||||
from deepagents_cli.agent_memory import AgentMemoryMiddleware
|
from deepagents_cli.agent_memory import AgentMemoryMiddleware
|
||||||
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
from langgraph.types import Checkpointer
|
||||||
from deepagents_cli.skills import SkillsMiddleware
|
from deepagents_cli.skills import SkillsMiddleware
|
||||||
from deepagents_cli.config import settings, get_default_coding_instructions
|
from deepagents_cli.config import settings, get_default_coding_instructions
|
||||||
import os
|
import os
|
||||||
@ -161,6 +163,30 @@ async def init_agent(config: AgentConfig):
|
|||||||
|
|
||||||
checkpointer = None
|
checkpointer = None
|
||||||
create_start = time.time()
|
create_start = time.time()
|
||||||
|
|
||||||
|
# 构建中间件列表
|
||||||
|
middleware = []
|
||||||
|
# 首先添加 ToolUseCleanupMiddleware 来清理孤立的 tool_use
|
||||||
|
middleware.append(ToolUseCleanupMiddleware())
|
||||||
|
# 添加工具输出长度控制中间件
|
||||||
|
tool_output_middleware = ToolOutputLengthMiddleware(
|
||||||
|
max_length=getattr(config.generate_cfg, 'tool_output_max_length', None) if config.generate_cfg else None or TOOL_OUTPUT_MAX_LENGTH,
|
||||||
|
truncation_strategy=getattr(config.generate_cfg, 'tool_output_truncation_strategy', 'smart') if config.generate_cfg else 'smart',
|
||||||
|
tool_filters=getattr(config.generate_cfg, 'tool_output_filters', None) if config.generate_cfg else None,
|
||||||
|
exclude_tools=getattr(config.generate_cfg, 'tool_output_exclude', []) if config.generate_cfg else [],
|
||||||
|
preserve_code_blocks=getattr(config.generate_cfg, 'preserve_code_blocks', True) if config.generate_cfg else True,
|
||||||
|
preserve_json=getattr(config.generate_cfg, 'preserve_json', True) if config.generate_cfg else True
|
||||||
|
)
|
||||||
|
middleware.append(tool_output_middleware)
|
||||||
|
|
||||||
|
# 从连接池获取 checkpointer
|
||||||
|
if config.session_id:
|
||||||
|
from .checkpoint_manager import get_checkpointer_manager
|
||||||
|
manager = get_checkpointer_manager()
|
||||||
|
checkpointer = manager.checkpointer
|
||||||
|
await prepare_checkpoint_message(config, checkpointer)
|
||||||
|
|
||||||
|
|
||||||
if config.robot_type == "deep_agent":
|
if config.robot_type == "deep_agent":
|
||||||
# 使用 DeepAgentX 创建 agent,自定义 workspace_root
|
# 使用 DeepAgentX 创建 agent,自定义 workspace_root
|
||||||
workspace_root = f"projects/robot/{config.bot_id}"
|
workspace_root = f"projects/robot/{config.bot_id}"
|
||||||
@ -172,34 +198,16 @@ async def init_agent(config: AgentConfig):
|
|||||||
tools=mcp_tools,
|
tools=mcp_tools,
|
||||||
auto_approve=True,
|
auto_approve=True,
|
||||||
enable_memory=False,
|
enable_memory=False,
|
||||||
workspace_root=workspace_root
|
workspace_root=workspace_root,
|
||||||
|
middleware=middleware,
|
||||||
|
checkpointer=checkpointer
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 构建中间件列表
|
|
||||||
middleware = []
|
|
||||||
# 首先添加 ToolUseCleanupMiddleware 来清理孤立的 tool_use
|
|
||||||
middleware.append(ToolUseCleanupMiddleware())
|
|
||||||
# 只有在 enable_thinking 为 True 时才添加 GuidelineMiddleware
|
# 只有在 enable_thinking 为 True 时才添加 GuidelineMiddleware
|
||||||
if config.enable_thinking:
|
if config.enable_thinking:
|
||||||
middleware.append(GuidelineMiddleware(llm_instance, config, system_prompt))
|
middleware.append(GuidelineMiddleware(llm_instance, config, system_prompt))
|
||||||
|
|
||||||
# 添加工具输出长度控制中间件
|
|
||||||
tool_output_middleware = ToolOutputLengthMiddleware(
|
|
||||||
max_length=getattr(config.generate_cfg, 'tool_output_max_length', None) if config.generate_cfg else None or TOOL_OUTPUT_MAX_LENGTH,
|
|
||||||
truncation_strategy=getattr(config.generate_cfg, 'tool_output_truncation_strategy', 'smart') if config.generate_cfg else 'smart',
|
|
||||||
tool_filters=getattr(config.generate_cfg, 'tool_output_filters', None) if config.generate_cfg else None,
|
|
||||||
exclude_tools=getattr(config.generate_cfg, 'tool_output_exclude', []) if config.generate_cfg else [],
|
|
||||||
preserve_code_blocks=getattr(config.generate_cfg, 'preserve_code_blocks', True) if config.generate_cfg else True,
|
|
||||||
preserve_json=getattr(config.generate_cfg, 'preserve_json', True) if config.generate_cfg else True
|
|
||||||
)
|
|
||||||
middleware.append(tool_output_middleware)
|
|
||||||
|
|
||||||
# 从连接池获取 checkpointer
|
|
||||||
if config.session_id:
|
if config.session_id:
|
||||||
from .checkpoint_manager import get_checkpointer_manager
|
|
||||||
manager = get_checkpointer_manager()
|
|
||||||
checkpointer = manager.checkpointer
|
|
||||||
await prepare_checkpoint_message(config, checkpointer)
|
|
||||||
summarization_middleware = SummarizationMiddleware(
|
summarization_middleware = SummarizationMiddleware(
|
||||||
model=llm_instance,
|
model=llm_instance,
|
||||||
max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS,
|
max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS,
|
||||||
@ -317,7 +325,9 @@ def create_custom_cli_agent(
|
|||||||
enable_memory: bool = True,
|
enable_memory: bool = True,
|
||||||
enable_skills: bool = True,
|
enable_skills: bool = True,
|
||||||
enable_shell: bool = True,
|
enable_shell: bool = True,
|
||||||
|
middleware: list[AgentMiddleware] = [],
|
||||||
workspace_root: str | None = None,
|
workspace_root: str | None = None,
|
||||||
|
checkpointer: Checkpointer | None = None,
|
||||||
) -> tuple[Pregel, CompositeBackend]:
|
) -> tuple[Pregel, CompositeBackend]:
|
||||||
"""Create a CLI-configured agent with custom workspace_root for shell commands.
|
"""Create a CLI-configured agent with custom workspace_root for shell commands.
|
||||||
|
|
||||||
@ -358,7 +368,7 @@ def create_custom_cli_agent(
|
|||||||
agent_md.write_text(source_content)
|
agent_md.write_text(source_content)
|
||||||
|
|
||||||
# Build middleware stack based on enabled features
|
# Build middleware stack based on enabled features
|
||||||
agent_middleware = []
|
agent_middleware = middleware
|
||||||
|
|
||||||
# CONDITIONAL SETUP: Local vs Remote Sandbox
|
# CONDITIONAL SETUP: Local vs Remote Sandbox
|
||||||
if sandbox is None:
|
if sandbox is None:
|
||||||
@ -453,6 +463,6 @@ def create_custom_cli_agent(
|
|||||||
backend=composite_backend,
|
backend=composite_backend,
|
||||||
middleware=agent_middleware,
|
middleware=agent_middleware,
|
||||||
interrupt_on=interrupt_on,
|
interrupt_on=interrupt_on,
|
||||||
checkpointer=InMemorySaver(),
|
checkpointer=checkpointer,
|
||||||
).with_config(config)
|
).with_config(config)
|
||||||
return agent, composite_backend
|
return agent, composite_backend
|
||||||
@ -17,7 +17,7 @@ from utils.fastapi_utils import (
|
|||||||
call_preamble_llm,
|
call_preamble_llm,
|
||||||
create_stream_chunk
|
create_stream_chunk
|
||||||
)
|
)
|
||||||
from langchain_core.messages import AIMessageChunk, ToolMessage, AIMessage
|
from langchain_core.messages import AIMessageChunk, ToolMessage, AIMessage, HumanMessage
|
||||||
from utils.settings import MAX_OUTPUT_TOKENS
|
from utils.settings import MAX_OUTPUT_TOKENS
|
||||||
from agent.agent_config import AgentConfig
|
from agent.agent_config import AgentConfig
|
||||||
from agent.deep_assistant import init_agent
|
from agent.deep_assistant import init_agent
|
||||||
@ -201,7 +201,20 @@ async def create_agent_and_generate_response(
|
|||||||
agent, checkpointer = await init_agent(config)
|
agent, checkpointer = await init_agent(config)
|
||||||
# 使用更新后的 messages
|
# 使用更新后的 messages
|
||||||
agent_responses = await agent.ainvoke({"messages": config.messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
|
agent_responses = await agent.ainvoke({"messages": config.messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
|
||||||
append_messages = agent_responses["messages"][len(config.messages):]
|
|
||||||
|
# 从后往前找第一个 HumanMessage,之后的内容都给 append_messages
|
||||||
|
all_messages = agent_responses["messages"]
|
||||||
|
first_human_idx = None
|
||||||
|
for i in range(len(all_messages) - 1, -1, -1):
|
||||||
|
if isinstance(all_messages[i], HumanMessage):
|
||||||
|
first_human_idx = i
|
||||||
|
break
|
||||||
|
|
||||||
|
if first_human_idx is not None:
|
||||||
|
append_messages = all_messages[first_human_idx + 1:]
|
||||||
|
else:
|
||||||
|
# 如果没找到 HumanMessage,取所有消息
|
||||||
|
append_messages = all_messages
|
||||||
response_text = ""
|
response_text = ""
|
||||||
for msg in append_messages:
|
for msg in append_messages:
|
||||||
if isinstance(msg,AIMessage):
|
if isinstance(msg,AIMessage):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user