diff --git a/agent/checkpoint_utils.py b/agent/checkpoint_utils.py index a346070..d4ee1da 100644 --- a/agent/checkpoint_utils.py +++ b/agent/checkpoint_utils.py @@ -30,41 +30,13 @@ async def check_checkpoint_history(checkpointer: MemorySaver, thread_id: str) -> logger.debug(f"Checkpointer type: {type(checkpointer)}") logger.debug(f"Checkpointer dir: {[attr for attr in dir(checkpointer) if not attr.startswith('_')]}") - # 先尝试获取最新的 checkpoint - try: - latest_checkpoint = await checkpointer.aget_tuple(config) - logger.debug(f"aget_tuple result: {latest_checkpoint}") - - if latest_checkpoint is not None: - logger.info(f"Found latest checkpoint for thread_id: {thread_id}") - # 解构 checkpoint tuple - checkpoint_config, checkpoint, metadata = latest_checkpoint - logger.debug(f"Checkpoint metadata: {metadata}") - return True - except Exception as e: - logger.warning(f"aget_tuple failed: {e}") - - # 如果没有最新的,再列出所有 - logger.debug(f"No latest checkpoint for thread_id: {thread_id}, checking all checkpoints...") - try: - checkpoints = [] - async for c in checkpointer.alist(config): - checkpoints.append(c) - logger.debug(f"Found checkpoint: {c}") - - # 如果有至少一个 checkpoint,说明有历史记录 - has_history = len(checkpoints) > 0 - - if has_history: - logger.info(f"Found {len(checkpoints)} checkpoints in total for thread_id: {thread_id}") - else: - logger.info(f"No existing history for thread_id: {thread_id}") - - return has_history - except Exception as e: - logger.warning(f"alist failed: {e}") - return False + latest_checkpoint = await checkpointer.aget_tuple(config) + logger.debug(f"aget_tuple result: {latest_checkpoint}") + if latest_checkpoint is not None: + logger.info(f"Found latest checkpoint for thread_id: {thread_id}") + # 解构 checkpoint tuple + return True except Exception as e: import traceback logger.error(f"Error checking checkpoint history for thread_id {thread_id}: {e}") diff --git a/agent/deep_assistant.py b/agent/deep_assistant.py index 4373888..2a70cc3 100644 --- a/agent/deep_assistant.py +++ b/agent/deep_assistant.py @@ -11,6 +11,9 @@ from langchain_mcp_adapters.client import MultiServerMCPClient from langgraph.checkpoint.memory import MemorySaver from utils.fastapi_utils import detect_provider +# 全局 MemorySaver 实例 +_global_checkpointer = MemorySaver() + from .guideline_middleware import GuidelineMiddleware from .tool_output_length_middleware import ToolOutputLengthMiddleware from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH @@ -117,7 +120,7 @@ async def init_agent(config: AgentConfig): checkpointer = None if config.session_id: - checkpointer = MemorySaver() + checkpointer = _global_checkpointer summarization_middleware = SummarizationMiddleware( model=llm_instance, max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS,