qwen_agent/agent/checkpoint_utils.py
朱潮 f694101747 refactor: migrate from Memori to Mem0 for long-term memory
Replace Memori with Mem0 for memory management:
- Delete memori_config.py, memori_manager.py, memori_middleware.py
- Add mem0_config.py, mem0_manager.py, mem0_middleware.py
- Update environment variables (MEMORI_* -> MEM0_*)
- Integrate Mem0 with LangGraph middleware
- Add sync connection pool for Mem0 in DBPoolManager
- Move checkpoint message prep to config creation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 21:15:30 +08:00

96 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Utility functions for LangGraph checkpoints.

Reads the stored chat history of a thread and prepares the per-request
messages based on whether such history exists.
"""
import logging
from typing import List, Dict, Any, Optional
from langgraph.checkpoint.memory import MemorySaver

# Uses the explicit 'app' logger name rather than __name__, presumably so these
# records go to the application's shared logger — confirm against logging config.
logger = logging.getLogger('app')
async def get_checkpoint_history(checkpointer: MemorySaver, thread_id: str) -> List:
    """Load the stored message history of a thread from a checkpointer.

    Fetches the latest checkpoint for ``thread_id`` via
    ``checkpointer.aget_tuple`` and returns the messages found in its
    ``channel_values``. This is a best-effort read: any failure is logged
    (with traceback) and an empty list is returned instead of raising.

    Args:
        checkpointer: Checkpointer exposing an async ``aget_tuple`` method
            (annotated as ``MemorySaver``).
        thread_id: Thread identifier, usually the session id.

    Returns:
        List: A new list containing the stored messages as they were
        deserialized from the checkpoint. These are presumably LangChain
        message objects, not dicts: ``prepare_checkpoint_message`` reads
        their ``.type``/``.content`` via ``getattr``. The list is empty
        in any of these cases:

        - the checkpointer or thread id is missing;
        - no checkpoint exists;
        - the checkpoint has no message channel;
        - the stored channel is ``None``;
        - an error occurs.
    """
    if not checkpointer or not thread_id:
        logger.debug(f"No checkpointer or thread_id: checkpointer={bool(checkpointer)}, thread_id={thread_id}")
        return []

    try:
        config = {"configurable": {"thread_id": thread_id}}
        checkpoint_tuple = await checkpointer.aget_tuple(config)
        if checkpoint_tuple is None or checkpoint_tuple.checkpoint is None:
            logger.debug(f"No checkpoint found for thread_id: {thread_id}")
            return []

        checkpoint_data = checkpoint_tuple.checkpoint
        # The chat history normally lives in channel_values['messages'].
        if "channel_values" not in checkpoint_data:
            logger.debug(f"No channel_values in checkpoint for thread_id: {thread_id}")
            return []
        channel_values = checkpoint_data["channel_values"]

        if isinstance(channel_values, dict) and "messages" in channel_values:
            history_messages = channel_values["messages"]
        elif isinstance(channel_values, list):
            # In some cases channel_values is itself the message list.
            history_messages = channel_values
        else:
            logger.debug(f"Unexpected channel_values format: {type(channel_values)}")
            return []

        # A present-but-None channel means "no history"; previously len(None)
        # raised and was reported as an error. Copy into a plain list so the
        # return type is always a list.
        converted = list(history_messages or [])
        logger.info(f"Loaded {len(converted)} messages from checkpoint for thread_id: {thread_id}")
        return converted
    except Exception as e:
        # Best-effort read: record the full traceback and fall back to no history.
        logger.exception(f"Error getting checkpoint history for thread_id {thread_id}: {e}")
        return []
async def prepare_checkpoint_message(config, checkpointer, *, max_history: int = 20):
    """Prepare checkpoint-related message state on ``config``, in place.

    Steps:

    1. Load the thread history for ``config.session_id`` from
       ``checkpointer``. Keep only messages whose ``type`` is ``"human"``
       or ``"ai"`` and whose content does not contain ``<think>``
       (case-insensitive). Store the most recent ``max_history`` of them
       on ``config._session_history``.
    2. If any history exists, reduce ``config.messages`` to the last
       message whose ``role`` is ``'user'`` (left unchanged if there is
       none). Without history, all messages are kept.

    Nothing is done when ``config.session_id``, ``checkpointer`` or
    ``config.messages`` is empty.

    Args:
        config: Request config object. Reads ``session_id`` and
            ``messages`` (dict-like items accessed via ``.get('role')`` /
            ``.get('content')``). May overwrite ``messages`` and set
            ``_session_history``.
        checkpointer: Checkpointer passed to ``get_checkpoint_history``.
        max_history: Maximum number of filtered history messages to keep.
            Defaults to 20, the previously hard-coded limit. A value of 0
            or less keeps none.

    Returns:
        None. Results are written onto ``config``.
    """
    # `not config.messages` also covers messages=None, where len() would raise.
    if not config.session_id or not checkpointer or not config.messages:
        logger.debug("No session_id/checkpointer or empty messages, skipping checkpoint")
        return

    history = await get_checkpoint_history(checkpointer, config.session_id)
    if not history:
        logger.info(f"No history, sending all {len(config.messages)} messages")
        return

    # Keep only human/ai turns and drop any message carrying <think> reasoning text.
    filtered_history = [
        h for h in history
        if getattr(h, "type", None) in ("human", "ai")
        and "<think>" not in str(getattr(h, "content", "")).lower()
    ]
    logger.info(f"Filtered {len(filtered_history)} human/ai messages from history")
    # A slice of [-0:] would return the whole list, so non-positive limits are handled explicitly.
    config._session_history = filtered_history[-max_history:] if max_history > 0 else []

    # History exists, so only the newest user message needs to be sent.
    last_user_msg = next((m for m in reversed(config.messages) if m.get('role') == 'user'), None)
    if last_user_msg:
        config.messages = [last_user_msg]
        # content may be None or a non-string (e.g. a list of parts); coerce
        # before truncating so the log preview cannot raise.
        preview = str(last_user_msg.get('content') or '')[:50]
        logger.info(f"Has history, sending last user message: {preview}...")