Convert all Chinese comments, docstrings, logger/print output, HTTPException detail messages, and API response messages to English across the entire codebase. Functional zh/ja localized strings (e.g. prompt templates, timezone display names, date formats) are preserved as-is. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
460 lines
15 KiB
Python
460 lines
15 KiB
Python
"""
|
|
Mem0 agent middleware.
|
|
Implements memory recall and storage for AgentMiddleware.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import threading
|
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
|
|
|
from langchain.agents.middleware import AgentMiddleware, AgentState, ModelRequest
|
|
from langgraph.runtime import Runtime
|
|
|
|
from .mem0_config import Mem0Config
|
|
from .mem0_manager import Mem0Manager, get_mem0_manager
|
|
|
|
logger = logging.getLogger("app")

# Type-only imports, guarded by TYPE_CHECKING so they never execute at runtime
# (avoids circular imports).
# NOTE(review): "AgentConfig" is used as a string annotation in this module but
# is not imported anywhere visible here, so static type checkers will report it
# as an undefined name. TODO: import it under this guard once its defining
# module is confirmed.
if TYPE_CHECKING:
    from langchain_core.language_models import BaseChatModel
|
|
|
|
|
|
class Mem0Middleware(AgentMiddleware):
    """
    Mem0 memory middleware.

    Features:
    1. before_agent / abefore_agent: recall memories relevant to the latest user
       query and store the formatted memory prompt on
       ``agent_config._mem0_context`` (cleared when nothing is recalled).
    2. wrap_model_call / awrap_model_call: append that memory prompt to the
       system prompt of each model request.
    3. after_agent / aafter_agent: store the latest user/assistant exchange as a
       user-level memory from a background daemon thread, so the caller is not
       blocked.
    """

    def __init__(
        self,
        mem0_manager: Mem0Manager,
        config: Mem0Config,
        agent_config: "AgentConfig",
    ):
        """Initialize Mem0Middleware.

        Args:
            mem0_manager: Mem0Manager instance used for recall and storage
            config: Mem0Config configuration (enablement, attribution, prompts)
            agent_config: AgentConfig instance used to pass data between
                middlewares; the recalled memory prompt is kept on its
                ``_mem0_context`` attribute
        """
        self.mem0_manager = mem0_manager
        self.config = config
        self.agent_config = agent_config

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten message/prompt content into plain text.

        Content may be a plain string or a list of parts. String parts are
        kept as-is; dict parts contribute their string ``"text"`` entry, and
        dict parts without one (e.g. image blocks) are skipped rather than
        rendered as their ``repr``. Parts are joined with newlines.

        Args:
            content: String, list of parts, or any other object (``str()``-ed)

        Returns:
            Plain text, or "" for empty/None content
        """
        if not content:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: List[str] = []
            for item in content:
                if isinstance(item, dict):
                    text = item.get("text")
                    if isinstance(text, str) and text:
                        parts.append(text)
                elif item:
                    parts.append(str(item))
            return "\n".join(parts)
        return str(content)

    @staticmethod
    def _run_coroutine_blocking(coro: Any) -> Any:
        """Run a coroutine to completion from synchronous code and return its result.

        ``asyncio.run`` raises RuntimeError when the current thread already has
        a running event loop (e.g. a synchronous agent call issued from async
        code). In that case the coroutine is run on a fresh loop in a worker
        thread, and this call blocks until it finishes.

        Args:
            coro: Coroutine object to execute

        Returns:
            Whatever the coroutine returns; its exceptions propagate
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread: asyncio.run creates, runs and closes one.
            return asyncio.run(coro)

        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, coro).result()

    def _last_message_text(self, state: AgentState, message_type: type) -> str:
        """Return the text of the last message of ``message_type`` in state, or ""."""
        for msg in reversed(state.get("messages") or []):
            if isinstance(msg, message_type):
                return self._content_to_text(msg.content)
        return ""

    def _extract_user_query(self, state: AgentState) -> str:
        """Extract the user query from state, using the last HumanMessage.

        Args:
            state: Agent state

        Returns:
            User query text ("" if there is no HumanMessage)
        """
        from langchain_core.messages import HumanMessage

        return self._last_message_text(state, HumanMessage)

    def _extract_agent_response(self, state: AgentState) -> str:
        """Extract the agent response from state, using the last AIMessage.

        Args:
            state: Agent state

        Returns:
            Agent response text ("" if there is no AIMessage)
        """
        from langchain_core.messages import AIMessage

        return self._last_message_text(state, AIMessage)

    def _format_memories(self, memories: List[Dict[str, Any]]) -> str:
        """Format a list of memories as a numbered text list.

        Each entry is rendered as ``"{n}. [{fact_type}] {content}"``;
        ``fact_type`` defaults to ``"fact"`` and ``content`` to ``""``.

        Args:
            memories: List of memory dicts

        Returns:
            Formatted memory text ("" for an empty list)
        """
        return "\n".join(
            f"{i}. [{memory.get('fact_type', 'fact')}] {memory.get('content', '')}"
            for i, memory in enumerate(memories or [], 1)
        )

    def _set_memory_context(self, memories: List[Dict[str, Any]]) -> None:
        """Store the memory prompt for recalled memories on agent_config, or clear it."""
        if not memories:
            self.agent_config._mem0_context = None
            return
        memory_text = self._format_memories(memories)
        self.agent_config._mem0_context = self.config.get_memory_prompt([memory_text])
        logger.info("Recalled %d memories for context", len(memories))

    # ------------------------------------------------------------------
    # Recall
    # ------------------------------------------------------------------

    def before_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None:
        """Recall relevant memories before agent execution, synchronous version.

        The memory prompt is handed to wrap_model_call via
        ``agent_config._mem0_context``; errors are logged, never raised.

        Args:
            state: Agent state
            runtime: Runtime context (unused)

        Returns:
            None: this hook makes no graph-state updates
        """
        if not self.config.is_enabled():
            return None

        try:
            # Clear any context left by a previous run so stale memories are
            # never injected when this run has no query or recall fails.
            self.agent_config._mem0_context = None

            query = self._extract_user_query(state)
            if not query:
                return None

            user_id, agent_id = self.config.get_attribution_tuple()

            # Recall user-level memories across sessions (same as abefore_agent).
            memories = self._run_coroutine_blocking(
                self._recall_memories_async(query, user_id, agent_id)
            )
            self._set_memory_context(memories)
        except Exception as e:
            logger.error("Error in Mem0Middleware.before_agent: %s", e, exc_info=True)

        return None

    async def abefore_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None:
        """Recall relevant memories before agent execution, asynchronous version.

        The memory prompt is handed to awrap_model_call via
        ``agent_config._mem0_context``; errors are logged, never raised.

        Args:
            state: Agent state
            runtime: Runtime context (unused)

        Returns:
            None: this hook makes no graph-state updates
        """
        if not self.config.is_enabled():
            return None

        try:
            # Clear any context left by a previous run (see before_agent).
            self.agent_config._mem0_context = None

            query = self._extract_user_query(state)
            if not query:
                logger.debug("No user query found, skipping memory recall")
                return None

            user_id, agent_id = self.config.get_attribution_tuple()

            # Recall user-level memories across sessions
            memories = await self._recall_memories_async(query, user_id, agent_id)
            self._set_memory_context(memories)
        except Exception as e:
            logger.error("Error in Mem0Middleware.abefore_agent: %s", e, exc_info=True)

        return None

    async def _recall_memories_async(
        self, query: str, user_id: str, agent_id: str
    ) -> List[Dict[str, Any]]:
        """Recall memories asynchronously via the Mem0 manager.

        Args:
            query: Query text
            user_id: User ID
            agent_id: Agent/Bot ID

        Returns:
            List of memories
        """
        return await self.mem0_manager.recall_memories(
            query=query,
            user_id=user_id,
            agent_id=agent_id,
            config=self.config,
        )

    # ------------------------------------------------------------------
    # Augmentation (storage)
    # ------------------------------------------------------------------

    def _start_augmentation_thread(self, state: AgentState, runtime: Runtime) -> None:
        """Run _trigger_augmentation_sync in a fire-and-forget daemon thread.

        Being a daemon thread, a pending write is dropped if the process exits
        before it finishes.
        """
        threading.Thread(
            target=self._trigger_augmentation_sync,
            args=(state, runtime),
            daemon=True,
        ).start()

    def after_agent(self, state: AgentState, runtime: Runtime) -> None:
        """Trigger memory augmentation after agent execution, synchronous version.

        Runs in a background thread to avoid blocking the main flow.

        Args:
            state: Agent state
            runtime: Runtime context
        """
        if not self.config.is_enabled():
            return

        try:
            self._start_augmentation_thread(state, runtime)
        except Exception as e:
            logger.error("Error in Mem0Middleware.after_agent: %s", e, exc_info=True)

    async def aafter_agent(self, state: AgentState, runtime: Runtime) -> None:
        """Trigger memory augmentation after agent execution, asynchronous version.

        Runs in a background thread to avoid blocking the event loop.

        Args:
            state: Agent state
            runtime: Runtime context
        """
        if not self.config.is_enabled():
            return

        try:
            self._start_augmentation_thread(state, runtime)
        except Exception as e:
            logger.error("Error in Mem0Middleware.aafter_agent: %s", e, exc_info=True)

    def _build_conversation_text(self, state: AgentState) -> str:
        """Render the last user query and agent response as a transcript.

        Returns:
            ``"User: ...\\nAssistant: ..."``, or "" if either side is missing
        """
        user_query = self._extract_user_query(state)
        agent_response = self._extract_agent_response(state)
        if not (user_query and agent_response):
            return ""
        return f"User: {user_query}\nAssistant: {agent_response}"

    async def _store_conversation(self, conversation_text: str, user_id: str, agent_id: str) -> None:
        """Store a conversation transcript as a user-level memory."""
        await self.mem0_manager.add_memory(
            text=conversation_text,
            user_id=user_id,
            agent_id=agent_id,
            metadata={"type": "conversation"},
            config=self.config,
        )
        logger.debug("Stored conversation as memory for user=%s, agent=%s", user_id, agent_id)

    def _trigger_augmentation_sync(self, state: AgentState, runtime: Runtime) -> None:
        """Trigger the memory augmentation task, synchronous version executed in a thread.

        Extracts the latest exchange from the conversation and stores it in
        Mem0 at the user level across sessions. Errors are logged, never raised.

        Args:
            state: Agent state
            runtime: Runtime context (unused)
        """
        try:
            user_id, agent_id = self.config.get_attribution_tuple()
            conversation_text = self._build_conversation_text(state)
            if conversation_text:
                # Normally called from a fresh worker thread with no loop;
                # the helper also copes if a loop is already running.
                self._run_coroutine_blocking(
                    self._store_conversation(conversation_text, user_id, agent_id)
                )
        except Exception as e:
            logger.error("Error in _trigger_augmentation_sync: %s", e, exc_info=True)

    async def _trigger_augmentation_async(self, state: AgentState, runtime: Runtime) -> None:
        """Trigger the memory augmentation task.

        Extracts the latest exchange from the conversation and stores it in
        Mem0 at the user level across sessions. Errors are logged, never raised.

        Args:
            state: Agent state
            runtime: Runtime context (unused)
        """
        try:
            user_id, agent_id = self.config.get_attribution_tuple()
            conversation_text = self._build_conversation_text(state)
            if conversation_text:
                await self._store_conversation(conversation_text, user_id, agent_id)
        except Exception as e:
            logger.error("Error in _trigger_augmentation_async: %s", e, exc_info=True)

    # ------------------------------------------------------------------
    # Model-call wrapping
    # ------------------------------------------------------------------

    def _with_memory_prompt(self, request: ModelRequest) -> ModelRequest:
        """Return ``request`` with the recalled memory prompt appended to its system prompt.

        Returns the request unchanged when no memory prompt is set (including
        when recall never ran and ``_mem0_context`` does not exist yet).
        """
        memory_prompt = self._content_to_text(getattr(self.agent_config, "_mem0_context", None))
        if not memory_prompt:
            return request

        current_system_prompt = ""
        system_message = request.system_message
        if system_message:
            if hasattr(system_message, "content"):
                current_system_prompt = self._content_to_text(system_message.content)
            else:
                current_system_prompt = str(system_message)

        # Keep the memory block visually separate instead of gluing it onto
        # the last line of the existing prompt.
        needs_separator = (
            bool(current_system_prompt)
            and not current_system_prompt[-1].isspace()
            and not memory_prompt[0].isspace()
        )
        separator = "\n\n" if needs_separator else ""
        return request.override(system_prompt=current_system_prompt + separator + memory_prompt)

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Any],
    ) -> Any:
        """Wrap model calls and inject memories into the system prompt, synchronous version.

        Args:
            request: Model request
            handler: Original handler

        Returns:
            Model response
        """
        return handler(self._with_memory_prompt(request))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Any],
    ) -> Any:
        """Wrap model calls and inject memories into the system prompt, asynchronous version.

        Args:
            request: Model request
            handler: Original (awaitable) handler

        Returns:
            Model response
        """
        return await handler(self._with_memory_prompt(request))
|
|
|
|
|
|
def create_mem0_middleware(
    bot_id: str,
    user_identifier: str,
    session_id: str,
    agent_config: "AgentConfig",
    enabled: bool = True,
    semantic_search_top_k: int = 20,
    mem0_manager: Optional[Mem0Manager] = None,
    llm_instance: Optional["BaseChatModel"] = None,
) -> Optional[Mem0Middleware]:
    """Factory function for creating Mem0Middleware.

    Args:
        bot_id: Bot ID (used as the Mem0 agent_id)
        user_identifier: User identifier (used as the Mem0 user_id)
        session_id: Session ID
        agent_config: AgentConfig instance used to pass data between middlewares
        enabled: Whether the middleware is enabled; when False no middleware is built
        semantic_search_top_k: Number of semantic search results to return
        mem0_manager: Mem0Manager instance; uses the global instance if None
        llm_instance: LangChain LLM instance used for Mem0 memory extraction and augmentation

    Returns:
        Mem0Middleware instance, or None when ``enabled`` is False
    """
    if not enabled:
        return None

    # Explicit None check: an injected manager must be used even if the
    # object happens to be falsy (e.g. defines __len__/__bool__).
    manager = mem0_manager if mem0_manager is not None else get_mem0_manager()

    config = Mem0Config(
        enabled=True,
        user_id=user_identifier,
        agent_id=bot_id,
        session_id=session_id,
        semantic_search_top_k=semantic_search_top_k,
        llm_instance=llm_instance,
    )

    return Mem0Middleware(mem0_manager=manager, config=config, agent_config=agent_config)
|