qwen_agent/agent/checkpoint_utils.py
2025-12-18 00:38:04 +08:00

108 lines
3.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""用于处理 LangGraph checkpoint 相关的工具函数"""
import logging
from typing import List, Dict, Any, Optional
from langgraph.checkpoint.memory import MemorySaver
logger = logging.getLogger('app')
async def check_checkpoint_history(checkpointer: MemorySaver, thread_id: str) -> bool:
"""
检查指定的 thread_id 在 checkpointer 中是否已有历史记录
Args:
checkpointer: MemorySaver 实例
thread_id: 线程ID通常是 session_id
Returns:
bool: True 表示有历史记录False 表示没有
"""
if not checkpointer or not thread_id:
logger.debug(f"No checkpointer or thread_id: checkpointer={bool(checkpointer)}, thread_id={thread_id}")
return False
try:
# 获取配置
config = {"configurable": {"thread_id": thread_id}}
# 调试信息:检查 checkpointer 类型
logger.debug(f"Checkpointer type: {type(checkpointer)}")
logger.debug(f"Checkpointer dir: {[attr for attr in dir(checkpointer) if not attr.startswith('_')]}")
latest_checkpoint = await checkpointer.aget_tuple(config)
logger.debug(f"aget_tuple result: {latest_checkpoint}")
if latest_checkpoint is not None:
logger.info(f"Found latest checkpoint for thread_id: {thread_id}")
# 解构 checkpoint tuple
return True
except Exception as e:
import traceback
logger.error(f"Error checking checkpoint history for thread_id {thread_id}: {e}")
logger.error(f"Full traceback: {traceback.format_exc()}")
# 出错时保守处理,返回 False
return False
def prepare_messages_for_agent(
messages: List[Dict[str, Any]],
has_history: bool
) -> List[Dict[str, Any]]:
"""
根据是否有历史记录来准备要发送给 agent 的消息
Args:
messages: 完整的消息列表
has_history: 是否已有历史记录
Returns:
List[Dict]: 要发送给 agent 的消息列表
"""
if not messages:
return []
# 如果有历史记录,只发送最后一条用户消息
if has_history:
# 找到最后一条用户消息
for msg in reversed(messages):
if msg.get('role') == 'user':
logger.info(f"Has history, sending only last user message: {msg.get('content', '')[:50]}...")
return [msg]
# 如果没有用户消息(理论上不应该发生),返回空列表
logger.warning("No user message found in messages")
return messages
# 如果没有历史记录,发送所有消息
logger.info(f"No history, sending all {len(messages)} messages")
return messages
def update_agent_config_for_checkpoint(
config_messages: List[Dict[str, Any]],
has_history: bool
) -> List[Dict[str, Any]]:
"""
更新 AgentConfig 中的 messages根据是否有历史记录决定发送哪些消息
这个函数可以在调用 agent 之前使用,避免重复处理消息历史
Args:
config_messages: AgentConfig 中的原始消息列表
has_history: 是否已有历史记录
Returns:
List[Dict]: 更新后的消息列表
"""
return prepare_messages_for_agent(config_messages, has_history)
async def prepare_checkpoint_message(config,checkpointer):
# 如果有 checkpointer检查是否有历史记录
if config.session_id and checkpointer and len(config.messages) > 0:
has_history = await check_checkpoint_history(checkpointer, config.session_id)
config.messages = prepare_messages_for_agent(config.messages, has_history)
logger.info(f"Session {config.session_id}: has_history={has_history}, sending {len(config.messages)} messages")
else:
logger.debug(f"No session_id provided, skipping checkpoint check")