"""用于处理 LangGraph checkpoint 相关的工具函数""" import logging from typing import List, Dict, Any, Optional from langgraph.checkpoint.memory import MemorySaver logger = logging.getLogger('app') async def get_checkpoint_history(checkpointer: MemorySaver, thread_id: str) -> List: """ 从 checkpointer 获取指定 thread_id 的历史聊天记录 Args: checkpointer: MemorySaver 实例 thread_id: 线程ID(通常是 session_id) Returns: List[Dict]: 历史消息列表,如果没有历史记录或出错则返回空列表 """ if not checkpointer or not thread_id: logger.debug(f"No checkpointer or thread_id: checkpointer={bool(checkpointer)}, thread_id={thread_id}") return [] try: config = {"configurable": {"thread_id": thread_id}} checkpoint_tuple = await checkpointer.aget_tuple(config) if checkpoint_tuple is None or checkpoint_tuple.checkpoint is None: logger.debug(f"No checkpoint found for thread_id: {thread_id}") return [] # 从 checkpoint 中提取消息历史 checkpoint_data = checkpoint_tuple.checkpoint # LangGraph checkpoint 中的消息通常在 channel_values['messages'] 中 if "channel_values" not in checkpoint_data: logger.debug(f"No channel_values in checkpoint for thread_id: {thread_id}") return [] channel_values = checkpoint_data["channel_values"] if isinstance(channel_values, dict) and "messages" in channel_values: history_messages = channel_values["messages"] converted = history_messages logger.info(f"Loaded {len(converted)} messages from checkpoint for thread_id: {thread_id}") return converted elif isinstance(channel_values, list): # 有些情况下 channel_values 直接是消息列表 converted = channel_values logger.info(f"Loaded {len(converted)} messages from checkpoint for thread_id: {thread_id}") return converted else: logger.debug(f"Unexpected channel_values format: {type(channel_values)}") return [] except Exception as e: import traceback logger.error(f"Error getting checkpoint history for thread_id {thread_id}: {e}") logger.error(f"Full traceback: {traceback.format_exc()}") return [] async def prepare_checkpoint_message(config, checkpointer): """ 准备 checkpoint 相关的消息: 1. 获取并过滤历史记录(去除包含双引号/think的消息) 2. 根据是否有历史决定发送哪些消息 """ if not config.session_id or not checkpointer or len(config.messages) == 0: logger.debug("No session_id/checkpointer or empty messages, skipping checkpoint") return # 获取历史记录 history = await get_checkpoint_history(checkpointer, config.session_id) has_history = len(history) > 0 # 处理历史记录:过滤并保留最近20条 if has_history: filtered_history = [ h for h in history if getattr(h, "type", None) in ("human", "ai") and "" not in str(getattr(h, "content", "")).lower() ] logger.info(f"Filtered {len(filtered_history)} human/ai messages from history") config._session_history = filtered_history[-20:] # 处理要发送的消息:有历史只发最后一条用户消息,否则全发 if has_history: last_user_msg = next((m for m in reversed(config.messages) if m.get('role') == 'user'), None) if last_user_msg: config.messages = [last_user_msg] logger.info(f"Has history, sending last user message: {last_user_msg.get('content', '')[:50]}...") else: logger.info(f"No history, sending all {len(config.messages)} messages")