Replace Memori with Mem0 for memory management: - Delete memori_config.py, memori_manager.py, memori_middleware.py - Add mem0_config.py, mem0_manager.py, mem0_middleware.py - Update environment variables (MEMORI_* -> MEM0_*) - Integrate Mem0 with LangGraph middleware - Add sync connection pool for Mem0 in DBPoolManager - Move checkpoint message prep to config creation Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
96 lines
3.8 KiB
Python
96 lines
3.8 KiB
Python
"""用于处理 LangGraph checkpoint 相关的工具函数"""
|
||
|
||
import logging
|
||
from typing import List, Dict, Any, Optional
|
||
from langgraph.checkpoint.memory import MemorySaver
|
||
|
||
# Module logger. It uses the shared "app" logger name rather than __name__,
# and both checkpoint helpers in this module log through it.
logger = logging.getLogger('app')
|
||
|
||
|
||
async def get_checkpoint_history(checkpointer: MemorySaver, thread_id: str) -> List:
    """Load the message history stored in the latest checkpoint of a thread.

    Args:
        checkpointer: Checkpoint saver to query. Only its ``aget_tuple``
            coroutine is used.
        thread_id: Thread identifier (usually the session id).

    Returns:
        A new list with the messages found under
        ``checkpoint["channel_values"]["messages"]``, or the items of
        ``channel_values`` itself when that is already a list. The items are
        whatever the checkpoint stored (presumably LangChain message objects;
        callers read their ``type``/``content`` attributes). An empty list is
        returned when an argument is missing, no checkpoint exists, the data
        has an unexpected shape, or an ``Exception`` is raised while reading
        it. Such errors are logged, not propagated.
    """
    if not checkpointer or not thread_id:
        logger.debug(
            "No checkpointer or thread_id: checkpointer=%s, thread_id=%s",
            bool(checkpointer),
            thread_id,
        )
        return []

    try:
        config = {"configurable": {"thread_id": thread_id}}
        checkpoint_tuple = await checkpointer.aget_tuple(config)

        if checkpoint_tuple is None or checkpoint_tuple.checkpoint is None:
            logger.debug("No checkpoint found for thread_id: %s", thread_id)
            return []

        checkpoint_data = checkpoint_tuple.checkpoint

        # LangGraph keeps graph state in channel_values; the conversation is
        # normally stored under the "messages" channel.
        if "channel_values" not in checkpoint_data:
            logger.debug("No channel_values in checkpoint for thread_id: %s", thread_id)
            return []

        channel_values = checkpoint_data["channel_values"]

        if isinstance(channel_values, dict) and "messages" in channel_values:
            messages = channel_values["messages"]
        elif isinstance(channel_values, list):
            # Some checkpoints store the message list directly.
            messages = channel_values
        else:
            logger.debug("Unexpected channel_values format: %s", type(channel_values))
            return []

        # An unset "messages" channel is "no history", not an error.
        if messages is None:
            logger.debug("Empty messages channel in checkpoint for thread_id: %s", thread_id)
            return []

        # Copy so callers cannot mutate the list held by the checkpoint.
        history = list(messages)
        logger.info("Loaded %d messages from checkpoint for thread_id: %s", len(history), thread_id)
        return history

    except Exception as e:
        # Best-effort lookup: log with the traceback and fall back to no history.
        logger.exception("Error getting checkpoint history for thread_id %s: %s", thread_id, e)
        return []
|
||
|
||
|
||
async def prepare_checkpoint_message(config, checkpointer, history_limit: int = 20):
    """Adjust ``config`` in place based on the thread's checkpoint history.

    1. Load the history for ``config.session_id`` from ``checkpointer``. Keep
       only messages whose ``type`` is ``"human"`` or ``"ai"`` and whose
       content does not contain ``<think>`` (case-insensitive). Store the most
       recent ``history_limit`` of them on ``config._session_history``.
    2. If any history exists, replace ``config.messages`` with just the last
       message whose ``role`` is ``"user"``. Otherwise leave
       ``config.messages`` untouched so every message is sent.

    Args:
        config: Request config. Reads ``session_id`` and ``messages`` (a list
            of dicts with ``role``/``content`` keys). May set
            ``_session_history`` and replace ``messages``.
        checkpointer: Checkpoint saver passed to ``get_checkpoint_history``.
        history_limit: Maximum number of filtered history messages kept on
            ``config._session_history``. Defaults to 20; a value <= 0 keeps
            none.

    Returns:
        None. ``config`` is modified in place.
    """
    # `not config.messages` also covers a None message list.
    if not config.session_id or not checkpointer or not config.messages:
        logger.debug("No session_id/checkpointer or empty messages, skipping checkpoint")
        return

    history = await get_checkpoint_history(checkpointer, config.session_id)
    if not history:
        logger.info("No history, sending all %d messages", len(config.messages))
        return

    # Keep only conversational turns, and drop turns carrying <think> reasoning.
    filtered_history = [
        h for h in history
        if getattr(h, "type", None) in ("human", "ai")
        and "<think>" not in str(getattr(h, "content", "")).lower()
    ]
    logger.info("Filtered %d human/ai messages from history", len(filtered_history))
    # Guard the limit: the slice [-0:] would return the whole list.
    config._session_history = filtered_history[-history_limit:] if history_limit > 0 else []

    # The checkpoint already holds earlier turns; send only the newest user message.
    last_user_msg = next(
        (m for m in reversed(config.messages) if m.get("role") == "user"),
        None,
    )
    if last_user_msg:
        config.messages = [last_user_msg]
        # Content may be None or non-str; coerce before slicing the preview.
        preview = str(last_user_msg.get("content") or "")[:50]
        logger.info("Has history, sending last user message: %s...", preview)
    else:
        logger.debug("Has history but no user message found; sending messages unchanged")
|