98 lines
3.2 KiB
Python
98 lines
3.2 KiB
Python
"""用于处理 LangGraph checkpoint 相关的工具函数"""
|
||
|
||
import logging
|
||
from typing import List, Dict, Any, Optional
|
||
from langgraph.checkpoint.memory import MemorySaver
|
||
|
||
logger = logging.getLogger('app')
|
||
|
||
|
||
async def check_checkpoint_history(checkpointer: MemorySaver, thread_id: str) -> bool:
|
||
"""
|
||
检查指定的 thread_id 在 checkpointer 中是否已有历史记录
|
||
|
||
Args:
|
||
checkpointer: MemorySaver 实例
|
||
thread_id: 线程ID(通常是 session_id)
|
||
|
||
Returns:
|
||
bool: True 表示有历史记录,False 表示没有
|
||
"""
|
||
if not checkpointer or not thread_id:
|
||
logger.debug(f"No checkpointer or thread_id: checkpointer={bool(checkpointer)}, thread_id={thread_id}")
|
||
return False
|
||
|
||
try:
|
||
# 获取配置
|
||
config = {"configurable": {"thread_id": thread_id}}
|
||
|
||
# 调试信息:检查 checkpointer 类型
|
||
logger.debug(f"Checkpointer type: {type(checkpointer)}")
|
||
logger.debug(f"Checkpointer dir: {[attr for attr in dir(checkpointer) if not attr.startswith('_')]}")
|
||
|
||
latest_checkpoint = await checkpointer.aget_tuple(config)
|
||
logger.debug(f"aget_tuple result: {latest_checkpoint}")
|
||
|
||
if latest_checkpoint is not None:
|
||
logger.info(f"Found latest checkpoint for thread_id: {thread_id}")
|
||
# 解构 checkpoint tuple
|
||
return True
|
||
except Exception as e:
|
||
import traceback
|
||
logger.error(f"Error checking checkpoint history for thread_id {thread_id}: {e}")
|
||
logger.error(f"Full traceback: {traceback.format_exc()}")
|
||
# 出错时保守处理,返回 False
|
||
return False
|
||
|
||
|
||
def prepare_messages_for_agent(
|
||
messages: List[Dict[str, Any]],
|
||
has_history: bool
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
根据是否有历史记录来准备要发送给 agent 的消息
|
||
|
||
Args:
|
||
messages: 完整的消息列表
|
||
has_history: 是否已有历史记录
|
||
|
||
Returns:
|
||
List[Dict]: 要发送给 agent 的消息列表
|
||
"""
|
||
if not messages:
|
||
return []
|
||
|
||
# 如果有历史记录,只发送最后一条用户消息
|
||
if has_history:
|
||
# 找到最后一条用户消息
|
||
for msg in reversed(messages):
|
||
if msg.get('role') == 'user':
|
||
logger.info(f"Has history, sending only last user message: {msg.get('content', '')[:50]}...")
|
||
return [msg]
|
||
|
||
# 如果没有用户消息(理论上不应该发生),返回空列表
|
||
logger.warning("No user message found in messages")
|
||
return messages
|
||
|
||
# 如果没有历史记录,发送所有消息
|
||
logger.info(f"No history, sending all {len(messages)} messages")
|
||
return messages
|
||
|
||
|
||
def update_agent_config_for_checkpoint(
|
||
config_messages: List[Dict[str, Any]],
|
||
has_history: bool
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
更新 AgentConfig 中的 messages,根据是否有历史记录决定发送哪些消息
|
||
|
||
这个函数可以在调用 agent 之前使用,避免重复处理消息历史
|
||
|
||
Args:
|
||
config_messages: AgentConfig 中的原始消息列表
|
||
has_history: 是否已有历史记录
|
||
|
||
Returns:
|
||
List[Dict]: 更新后的消息列表
|
||
"""
|
||
return prepare_messages_for_agent(config_messages, has_history) |