- Fix mem0 connection pool exhausted error with proper pooling - Convert memory operations to async tasks - Optimize docker-compose configuration - Add skill upload functionality - Reduce cache size for better performance - Update dependencies Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
436 lines
14 KiB
Python
436 lines
14 KiB
Python
"""
Mem0 agent middleware.

AgentMiddleware implementation that recalls memories from Mem0 before the
agent runs and stores new memories after it finishes.
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import threading
|
||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||
|
||
from langchain.agents.middleware import AgentMiddleware, AgentState, ModelRequest
|
||
from langgraph.runtime import Runtime
|
||
|
||
from .mem0_config import Mem0Config
|
||
from .mem0_manager import Mem0Manager, get_mem0_manager
|
||
|
||
logger = logging.getLogger("app")

# Imported only for static type checking, to avoid circular imports at runtime;
# BaseChatModel is referenced below only inside a string annotation.
# NOTE(review): "AgentConfig" is also used as a string annotation in this module
# but is not imported anywhere visible — consider adding it here as well.
if TYPE_CHECKING:
    from langchain_core.language_models import BaseChatModel
|
||
|
||
|
||
class Mem0Middleware(AgentMiddleware):
    """Mem0 memory middleware.

    Responsibilities:
    1. before_agent / abefore_agent: recall memories relevant to the latest user
       message and stash the rendered memory prompt on
       ``agent_config._mem0_context``.
    2. wrap_model_call / awrap_model_call: append that memory prompt to the
       system prompt of every model request.
    3. after_agent / aafter_agent: store the latest user/assistant exchange as a
       memory on a background daemon thread so the agent flow is not blocked.
    """

    def __init__(
        self,
        mem0_manager: Mem0Manager,
        config: Mem0Config,
        agent_config: "AgentConfig",
    ):
        """Initialize the middleware.

        Args:
            mem0_manager: Mem0Manager used to recall and add memories.
            config: Mem0Config holding enablement, attribution and recall settings.
            agent_config: AgentConfig instance used to pass data between the
                middleware hooks (the recalled memory prompt is stored on it).
        """
        self.mem0_manager = mem0_manager
        self.config = config
        self.agent_config = agent_config

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten LangChain message content into plain text.

        ``content`` may be a plain string or a list of content blocks (strings
        or dicts such as ``{"type": "text", "text": ...}``). Only textual parts
        are kept, so non-text blocks (images, tool calls, ...) do not leak their
        Python ``repr`` into recall queries, stored memories or system prompts.

        Args:
            content: Message content.

        Returns:
            The textual content, or "" when there is none.
        """
        if not content:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: List[str] = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    text = block.get("text")
                    if isinstance(text, str):
                        parts.append(text)
            return "".join(parts)
        return str(content)

    def _extract_last_text(self, state: AgentState, message_type: type) -> str:
        """Return the text of the last message of ``message_type`` in state, or ""."""
        # `or []` also covers an explicit `messages=None` in state.
        for msg in reversed(state.get("messages") or []):
            if isinstance(msg, message_type):
                return self._content_to_text(msg.content)
        return ""

    def _extract_user_query(self, state: AgentState) -> str:
        """Extract the user query (text of the last HumanMessage) from state.

        Args:
            state: Agent state.

        Returns:
            The user query text, or "" when there is no HumanMessage.
        """
        from langchain_core.messages import HumanMessage

        return self._extract_last_text(state, HumanMessage)

    def _extract_agent_response(self, state: AgentState) -> str:
        """Extract the agent response (text of the last AIMessage) from state.

        Args:
            state: Agent state.

        Returns:
            The agent response text, or "" when there is no AIMessage.
        """
        from langchain_core.messages import AIMessage

        return self._extract_last_text(state, AIMessage)

    def _format_memories(self, memories: List[Dict[str, Any]]) -> str:
        """Render memories as a numbered list, one ``N. [fact_type] content`` per line.

        Args:
            memories: Memory dicts; ``content`` defaults to "" and ``fact_type``
                to "fact" when missing.

        Returns:
            The formatted text, or "" when there are no memories.
        """
        if not memories:
            return ""
        return "\n".join(
            f"{i}. [{memory.get('fact_type', 'fact')}] {memory.get('content', '')}"
            for i, memory in enumerate(memories, 1)
        )

    def _apply_recalled_memories(self, memories: List[Dict[str, Any]]) -> None:
        """Store the memory prompt built from ``memories`` on agent_config (or None)."""
        if not memories:
            self.agent_config._mem0_context = None
            return
        memory_text = self._format_memories(memories)
        self.agent_config._mem0_context = self.config.get_memory_prompt([memory_text])
        logger.info("Recalled %d memories for context", len(memories))

    @staticmethod
    def _run_coroutine_blocking(coro_factory: Callable[[], Any]) -> Any:
        """Run a coroutine to completion from synchronous code and return its result.

        ``asyncio.run`` refuses to start when the current thread already has a
        running event loop (e.g. a sync ``invoke`` issued from async code). In
        that case the coroutine runs on a fresh loop in a short-lived worker
        thread and this call blocks until it finishes. The coroutine is created
        lazily by ``coro_factory`` so no un-awaited coroutine is left behind.

        Args:
            coro_factory: Zero-argument callable returning the coroutine to run.

        Returns:
            The coroutine's result.

        Raises:
            Whatever the coroutine raises.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # Common case: no loop running in this thread.
            return asyncio.run(coro_factory())

        outcome: Dict[str, Any] = {}

        def _worker() -> None:
            try:
                outcome["value"] = asyncio.run(coro_factory())
            except BaseException as exc:  # re-raised in the calling thread below
                outcome["error"] = exc

        worker = threading.Thread(target=_worker, daemon=True)
        worker.start()
        worker.join()
        if "error" in outcome:
            raise outcome["error"]
        return outcome.get("value")

    # -------------------------------------------------------------- recall hooks

    def before_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None:
        """Before the agent runs: recall relevant memories (sync version).

        The rendered memory prompt is stored on ``agent_config._mem0_context``
        for the model-call wrappers; it is cleared first so a skipped or failed
        recall never injects memories from a previous run. Errors are logged
        and swallowed so memory problems never break the agent.

        Args:
            state: Agent state.
            runtime: Runtime context (unused).

        Returns:
            Always None: memories are passed via agent_config, not graph state,
            so there is no state update (returning the whole state would needlessly
            re-apply every channel's reducer).
        """
        if not self.config.is_enabled():
            return None

        self.agent_config._mem0_context = None
        try:
            query = self._extract_user_query(state)
            if not query:
                logger.debug("No user query found, skipping memory recall")
                return None

            user_id, agent_id = self.config.get_attribution_tuple()
            # User-level recall (cross-session); no session id is involved.
            memories = self._run_coroutine_blocking(
                lambda: self._recall_memories_async(query, user_id, agent_id)
            )
            self._apply_recalled_memories(memories)
        except Exception as e:
            logger.exception("Error in Mem0Middleware.before_agent: %s", e)
        return None

    async def abefore_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None:
        """Before the agent runs: recall relevant memories (async version).

        Same contract as :meth:`before_agent`, awaiting the recall directly.

        Args:
            state: Agent state.
            runtime: Runtime context (unused).

        Returns:
            Always None (memories are passed via agent_config, not graph state).
        """
        if not self.config.is_enabled():
            return None

        self.agent_config._mem0_context = None
        try:
            query = self._extract_user_query(state)
            if not query:
                logger.debug("No user query found, skipping memory recall")
                return None

            user_id, agent_id = self.config.get_attribution_tuple()
            # User-level recall (cross-session).
            memories = await self._recall_memories_async(query, user_id, agent_id)
            self._apply_recalled_memories(memories)
        except Exception as e:
            logger.exception("Error in Mem0Middleware.abefore_agent: %s", e)
        return None

    async def _recall_memories_async(
        self, query: str, user_id: str, agent_id: str
    ) -> List[Dict[str, Any]]:
        """Recall memories asynchronously via the Mem0 manager.

        Args:
            query: Query text.
            user_id: User ID.
            agent_id: Agent/bot ID.

        Returns:
            List of memory dicts.
        """
        return await self.mem0_manager.recall_memories(
            query=query,
            user_id=user_id,
            agent_id=agent_id,
            config=self.config,
        )

    # -------------------------------------------------------------- store hooks

    def _start_augmentation_thread(self, state: AgentState, runtime: Runtime) -> None:
        """Fire-and-forget _trigger_augmentation_sync on a background thread.

        daemon=True means the thread does not keep the process alive, so a write
        still in flight at interpreter shutdown may be lost.
        """
        threading.Thread(
            target=self._trigger_augmentation_sync,
            args=(state, runtime),
            daemon=True,
        ).start()

    def after_agent(self, state: AgentState, runtime: Runtime) -> None:
        """After the agent runs: store the exchange as a memory (sync version).

        Runs on a background thread so the main flow is not blocked.

        Args:
            state: Agent state.
            runtime: Runtime context.
        """
        if not self.config.is_enabled():
            return
        try:
            self._start_augmentation_thread(state, runtime)
        except Exception as e:
            logger.exception("Error in Mem0Middleware.after_agent: %s", e)

    async def aafter_agent(self, state: AgentState, runtime: Runtime) -> None:
        """After the agent runs: store the exchange as a memory (async version).

        Runs on a background thread so the event loop is not blocked.

        Args:
            state: Agent state.
            runtime: Runtime context.
        """
        if not self.config.is_enabled():
            return
        try:
            self._start_augmentation_thread(state, runtime)
        except Exception as e:
            logger.exception("Error in Mem0Middleware.aafter_agent: %s", e)

    def _build_conversation_text(self, state: AgentState) -> str:
        """Render the latest user/assistant exchange, or "" if either side is missing."""
        user_query = self._extract_user_query(state)
        agent_response = self._extract_agent_response(state)
        if not (user_query and agent_response):
            return ""
        return f"User: {user_query}\nAssistant: {agent_response}"

    async def _add_conversation_memory(self, conversation_text: str) -> None:
        """Add ``conversation_text`` as a memory; no session id, so user-level/cross-session."""
        user_id, agent_id = self.config.get_attribution_tuple()
        await self.mem0_manager.add_memory(
            text=conversation_text,
            user_id=user_id,
            agent_id=agent_id,
            metadata={"type": "conversation"},
            config=self.config,
        )
        logger.debug("Stored conversation as memory for user=%s, agent=%s", user_id, agent_id)

    def _trigger_augmentation_sync(self, state: AgentState, runtime: Runtime) -> None:
        """Store the latest exchange in Mem0; intended to run on a worker thread.

        The worker thread has no event loop, so ``asyncio.run`` creates one and
        also closes it (shutting down async generators) when done. Errors are
        logged and swallowed.

        Args:
            state: Agent state.
            runtime: Runtime context (unused).
        """
        try:
            conversation_text = self._build_conversation_text(state)
            if conversation_text:
                asyncio.run(self._add_conversation_memory(conversation_text))
        except Exception as e:
            logger.exception("Error in _trigger_augmentation_sync: %s", e)

    async def _trigger_augmentation_async(self, state: AgentState, runtime: Runtime) -> None:
        """Store the latest exchange in Mem0 on the current event loop.

        Errors are logged and swallowed.

        Args:
            state: Agent state.
            runtime: Runtime context (unused).
        """
        try:
            conversation_text = self._build_conversation_text(state)
            if conversation_text:
                await self._add_conversation_memory(conversation_text)
        except Exception as e:
            logger.exception("Error in _trigger_augmentation_async: %s", e)

    # ------------------------------------------------------- model-call wrappers

    def _system_prompt_text(self, request: ModelRequest) -> str:
        """Return the request's current system prompt as plain text ("" if none)."""
        system_message = request.system_message
        if not system_message:
            return ""
        if hasattr(system_message, "content"):
            # Content may be a list of blocks; concatenating a list with the
            # memory prompt string would raise TypeError.
            return self._content_to_text(system_message.content)
        return str(system_message)

    def _with_memory_prompt(self, request: ModelRequest) -> ModelRequest:
        """Return ``request`` with the recalled memory prompt appended to its system prompt.

        The request is returned unchanged when no memory prompt was recalled.
        getattr is used because the attribute may never have been set on
        agent_config (e.g. recall disabled).
        """
        memory_prompt = getattr(self.agent_config, "_mem0_context", None)
        if not memory_prompt:
            return request
        new_system_prompt = self._system_prompt_text(request) + memory_prompt
        return request.override(system_prompt=new_system_prompt)

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Any],
    ) -> Any:
        """Wrap the model call, injecting recalled memories into the system prompt (sync).

        Args:
            request: Model request.
            handler: Next handler in the chain.

        Returns:
            The model response from ``handler``.
        """
        return handler(self._with_memory_prompt(request))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Any],
    ) -> Any:
        """Wrap the model call, injecting recalled memories into the system prompt (async).

        Args:
            request: Model request.
            handler: Next handler in the chain (awaitable).

        Returns:
            The model response from ``handler``.
        """
        return await handler(self._with_memory_prompt(request))
|
||
|
||
|
||
def create_mem0_middleware(
    bot_id: str,
    user_identifier: str,
    session_id: str,
    agent_config: "AgentConfig",
    enabled: bool = True,
    semantic_search_top_k: int = 20,
    mem0_manager: Optional[Mem0Manager] = None,
    llm_instance: Optional["BaseChatModel"] = None,
) -> Optional[Mem0Middleware]:
    """Factory for :class:`Mem0Middleware`.

    Args:
        bot_id: Bot ID (used as the Mem0 agent_id).
        user_identifier: User identifier (used as the Mem0 user_id).
        session_id: Session ID.
        agent_config: AgentConfig instance used to pass data between middleware hooks.
        enabled: Whether the middleware is enabled; when False, nothing is created.
        semantic_search_top_k: Number of results returned by semantic search.
        mem0_manager: Mem0Manager instance; when None, the one returned by
            ``get_mem0_manager()`` (the global instance) is used.
        llm_instance: LangChain LLM instance passed to Mem0Config (used by Mem0
            for memory extraction and augmentation).

    Returns:
        A configured Mem0Middleware, or None when ``enabled`` is False.
    """
    if not enabled:
        return None

    # Fall back to the shared manager only when none was supplied; an explicit
    # `is None` check matches the documented contract rather than truthiness.
    manager = get_mem0_manager() if mem0_manager is None else mem0_manager

    config = Mem0Config(
        enabled=True,
        user_id=user_identifier,
        agent_id=bot_id,
        session_id=session_id,
        semantic_search_top_k=semantic_search_top_k,
        llm_instance=llm_instance,
    )

    return Mem0Middleware(mem0_manager=manager, config=config, agent_config=agent_config)
|