qwen_agent/agent/mem0_middleware.py
朱潮 f694101747 refactor: migrate from Memori to Mem0 for long-term memory
Replace Memori with Mem0 for memory management:
- Delete memori_config.py, memori_manager.py, memori_middleware.py
- Add mem0_config.py, mem0_manager.py, mem0_middleware.py
- Update environment variables (MEMORI_* -> MEM0_*)
- Integrate Mem0 with LangGraph middleware
- Add sync connection pool for Mem0 in DBPoolManager
- Move checkpoint message prep to config creation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 21:15:30 +08:00

382 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Mem0 Agent 中间件
实现记忆召回和存储的 AgentMiddleware
"""
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from langchain.agents.middleware import AgentMiddleware, AgentState, ModelRequest
from langgraph.runtime import Runtime
from .mem0_config import Mem0Config
from .mem0_manager import Mem0Manager, get_mem0_manager
logger = logging.getLogger("app")
# 避免循环导入
if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel
class Mem0Middleware(AgentMiddleware):
    """Agent middleware that recalls and stores long-term memories through Mem0.

    Responsibilities:
    1. ``before_agent`` / ``abefore_agent``: recall memories for the latest
       user message and stash the rendered prompt on
       ``agent_config._mem0_context``.
    2. ``wrap_model_call`` / ``awrap_model_call``: append that prompt to the
       system prompt of every model request.
    3. ``after_agent`` / ``aafter_agent``: store the latest user/assistant
       exchange as a new memory.
    """

    def __init__(
        self,
        mem0_manager: Mem0Manager,
        config: Mem0Config,
        agent_config: "AgentConfig",
    ):
        """Initialise the middleware.

        Args:
            mem0_manager: Mem0Manager instance used to recall and add memories.
            config: Mem0Config holding the enable flag and attribution data.
            agent_config: AgentConfig instance; used as a scratch pad to pass
                the recalled memory prompt from the ``before_agent`` hooks to
                the model-call wrappers.
        """
        super().__init__()
        self.mem0_manager = mem0_manager
        self.config = config
        self.agent_config = agent_config
        # Strong references to fire-and-forget tasks; the event loop only keeps
        # weak references, so an unreferenced task may be garbage collected
        # before it finishes.
        self._background_tasks: set = set()

    # ------------------------------------------------------------------
    # Message helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten message content into plain text.

        Args:
            content: A string, a list of strings / content-block dicts, or any
                other value.

        Returns:
            The string itself; for a list, the concatenation of its string
            items and of the ``"text"`` field of dict items whose ``"type"`` is
            ``"text"``; ``""`` for falsy content; ``str(content)`` otherwise.
        """
        if not content:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: List[str] = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    parts.append(str(block.get("text", "")))
            return "".join(parts)
        return str(content)

    def _extract_user_query(self, state: AgentState) -> str:
        """Return the text of the last ``HumanMessage`` in the state.

        Args:
            state: Agent state; its ``"messages"`` entry is scanned.

        Returns:
            The user query text, or ``""`` when there is no human message.
        """
        from langchain_core.messages import HumanMessage

        for msg in reversed(state.get("messages", []) or []):
            if isinstance(msg, HumanMessage):
                return self._content_to_text(msg.content)
        return ""

    def _extract_agent_response(self, state: AgentState) -> str:
        """Return the text of the last ``AIMessage`` in the state.

        Args:
            state: Agent state; its ``"messages"`` entry is scanned.

        Returns:
            The agent response text, or ``""`` when there is no AI message.
        """
        from langchain_core.messages import AIMessage

        for msg in reversed(state.get("messages", []) or []):
            if isinstance(msg, AIMessage):
                return self._content_to_text(msg.content)
        return ""

    def _format_memories(self, memories: List[Dict[str, Any]]) -> str:
        """Render memories as a numbered list, one memory per line.

        Args:
            memories: Memory dicts; ``"content"`` and ``"fact_type"`` are read
                (defaulting to ``""`` and ``"fact"``).

        Returns:
            Lines of the form ``"<n>. [<fact_type>] <content>"`` joined by
            newlines, or ``""`` when ``memories`` is empty.
        """
        if not memories:
            return ""
        return "\n".join(
            f"{i}. [{memory.get('fact_type', 'fact')}] {memory.get('content', '')}"
            for i, memory in enumerate(memories, 1)
        )

    # ------------------------------------------------------------------
    # Recall (before_agent)
    # ------------------------------------------------------------------
    @staticmethod
    def _run_coroutine_blocking(coro: Any) -> Any:
        """Run ``coro`` to completion from synchronous code and return its result.

        ``asyncio.run`` cannot be called from a thread that already has a
        running event loop, so in that case the coroutine is executed on a
        fresh loop in a short-lived worker thread instead.
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: asyncio.run is safe.
            return asyncio.run(coro)

        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    def _store_recalled_context(self, memories: List[Dict[str, Any]]) -> None:
        """Publish recalled memories on ``agent_config._mem0_context``.

        Sets the attribute to the prompt built by ``config.get_memory_prompt``
        when memories were found, and to ``None`` otherwise.
        """
        if memories:
            memory_text = self._format_memories(memories)
            self.agent_config._mem0_context = self.config.get_memory_prompt([memory_text])
            logger.info(f"Recalled {len(memories)} memories for context")
        else:
            self.agent_config._mem0_context = None

    def before_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None:
        """Recall relevant memories before the agent runs (sync version).

        Args:
            state: Agent state.
            runtime: Runtime context (unused).

        Returns:
            ``state`` when recall was attempted, otherwise ``None`` (memory
            disabled, no user query, or an error occurred).
        """
        if not self.config.is_enabled():
            return None
        try:
            # Drop context left by an earlier run so a skipped or failed recall
            # never injects stale memories.
            self.agent_config._mem0_context = None

            query = self._extract_user_query(state)
            if not query:
                logger.debug("No user query found, skipping memory recall")
                return None

            user_id, agent_id = self.config.get_attribution_tuple()
            # User-level recall, same arguments as the async version.
            memories = self._run_coroutine_blocking(
                self._recall_memories_async(query, user_id, agent_id)
            )
            self._store_recalled_context(memories)
            return state
        except Exception as e:
            # Best effort: a recall failure must not break the agent run.
            logger.error(f"Error in Mem0Middleware.before_agent: {e}", exc_info=True)
            return None

    async def abefore_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None:
        """Recall relevant memories before the agent runs (async version).

        Args:
            state: Agent state.
            runtime: Runtime context (unused).

        Returns:
            ``state`` when recall was attempted, otherwise ``None`` (memory
            disabled, no user query, or an error occurred).
        """
        if not self.config.is_enabled():
            return None
        try:
            # Drop context left by an earlier run so a skipped or failed recall
            # never injects stale memories.
            self.agent_config._mem0_context = None

            query = self._extract_user_query(state)
            if not query:
                logger.debug("No user query found, skipping memory recall")
                return None

            user_id, agent_id = self.config.get_attribution_tuple()
            # User-level recall (not scoped to a session).
            memories = await self._recall_memories_async(query, user_id, agent_id)
            self._store_recalled_context(memories)
            return state
        except Exception as e:
            # Best effort: a recall failure must not break the agent run.
            logger.error(f"Error in Mem0Middleware.abefore_agent: {e}", exc_info=True)
            return None

    async def _recall_memories_async(
        self, query: str, user_id: str, agent_id: str
    ) -> List[Dict[str, Any]]:
        """Recall memories through the manager.

        Args:
            query: Query text.
            user_id: User ID.
            agent_id: Agent/Bot ID.

        Returns:
            Whatever ``mem0_manager.recall_memories`` returns (a list of
            memory dicts).
        """
        return await self.mem0_manager.recall_memories(
            query=query,
            user_id=user_id,
            agent_id=agent_id,
            config=self.config,
        )

    # ------------------------------------------------------------------
    # Storage (after_agent)
    # ------------------------------------------------------------------
    def after_agent(self, state: AgentState, runtime: Runtime) -> None:
        """Store the finished exchange as a memory (sync version).

        When an event loop is running in the current thread the work is
        scheduled on it as a background task; otherwise it is run to
        completion before returning.

        Args:
            state: Agent state.
            runtime: Runtime context.
        """
        if not self.config.is_enabled():
            return
        try:
            import asyncio

            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None

            if loop is None:
                # create_task needs a running loop; without one, run inline.
                asyncio.run(self._trigger_augmentation_async(state, runtime))
                return

            task = loop.create_task(self._trigger_augmentation_async(state, runtime))
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)
        except Exception as e:
            logger.error(f"Error in Mem0Middleware.after_agent: {e}", exc_info=True)

    async def aafter_agent(self, state: AgentState, runtime: Runtime) -> None:
        """Store the finished exchange as a memory (async version).

        Args:
            state: Agent state.
            runtime: Runtime context.
        """
        if not self.config.is_enabled():
            return
        try:
            await self._trigger_augmentation_async(state, runtime)
        except Exception as e:
            logger.error(f"Error in Mem0Middleware.aafter_agent: {e}", exc_info=True)

    async def _trigger_augmentation_async(self, state: AgentState, runtime: Runtime) -> None:
        """Add the latest user/assistant exchange to Mem0 (user level).

        Nothing is stored unless both a user query and an agent response are
        present. Errors are logged and never propagated.

        Args:
            state: Agent state.
            runtime: Runtime context (unused).
        """
        try:
            user_id, agent_id = self.config.get_attribution_tuple()
            user_query = self._extract_user_query(state)
            agent_response = self._extract_agent_response(state)

            if user_query and agent_response:
                conversation_text = f"User: {user_query}\nAssistant: {agent_response}"
                await self.mem0_manager.add_memory(
                    text=conversation_text,
                    user_id=user_id,
                    agent_id=agent_id,
                    metadata={"type": "conversation"},
                    config=self.config,
                )
                logger.debug(f"Stored conversation as memory for user={user_id}, agent={agent_id}")
        except Exception as e:
            logger.error(f"Error in _trigger_augmentation_async: {e}", exc_info=True)

    # ------------------------------------------------------------------
    # Prompt injection (wrap_model_call)
    # ------------------------------------------------------------------
    def _request_with_memories(self, request: ModelRequest) -> ModelRequest:
        """Return ``request`` with the recalled memory prompt appended.

        The request is returned unchanged when no memory prompt is available.
        """
        # getattr: the attribute only exists once a before_agent hook has run.
        memory_prompt = getattr(self.agent_config, "_mem0_context", None)
        if not memory_prompt:
            return request

        current_system_prompt = ""
        system_message = request.system_message
        if system_message:
            if hasattr(system_message, "content"):
                # Content may be a list of blocks; flatten to text so it can be
                # concatenated (non-text blocks are dropped).
                current_system_prompt = self._content_to_text(system_message.content)
            else:
                current_system_prompt = str(system_message)

        return request.override(system_prompt=current_system_prompt + memory_prompt)

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Any],
    ) -> Any:
        """Wrap the model call, injecting memories into the system prompt (sync).

        Args:
            request: Model request.
            handler: Original handler.

        Returns:
            The handler's result.
        """
        return handler(self._request_with_memories(request))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Any],
    ) -> Any:
        """Wrap the model call, injecting memories into the system prompt (async).

        Args:
            request: Model request.
            handler: Original (awaitable) handler.

        Returns:
            The handler's awaited result.
        """
        return await handler(self._request_with_memories(request))
def create_mem0_middleware(
    bot_id: str,
    user_identifier: str,
    session_id: str,
    agent_config: "AgentConfig",
    enabled: bool = True,
    semantic_search_top_k: int = 20,
    mem0_manager: Optional[Mem0Manager] = None,
    llm_instance: Optional["BaseChatModel"] = None,
) -> Optional[Mem0Middleware]:
    """Factory for ``Mem0Middleware``.

    Args:
        bot_id: Bot ID; passed to ``Mem0Config`` as ``agent_id``.
        user_identifier: User identifier; passed to ``Mem0Config`` as ``user_id``.
        session_id: Session ID.
        agent_config: AgentConfig instance, used to pass data between middlewares.
        enabled: Whether memory is enabled.
        semantic_search_top_k: Number of results for semantic search.
        mem0_manager: Mem0Manager instance; when falsy, the one returned by
            ``get_mem0_manager()`` is used.
        llm_instance: LangChain LLM instance handed to ``Mem0Config``.

    Returns:
        A ``Mem0Middleware`` instance, or ``None`` when ``enabled`` is false.
    """
    if not enabled:
        return None

    # Resolve the manager first, then build the config.
    resolved_manager = mem0_manager if mem0_manager else get_mem0_manager()

    config_kwargs = {
        "enabled": True,
        "user_id": user_identifier,
        "agent_id": bot_id,
        "session_id": session_id,
        "semantic_search_top_k": semantic_search_top_k,
        "llm_instance": llm_instance,
    }
    mem0_config = Mem0Config(**config_kwargs)

    return Mem0Middleware(
        mem0_manager=resolved_manager,
        config=mem0_config,
        agent_config=agent_config,
    )