qwen_agent/agent/mem0_middleware.py
朱潮 425f3c5bb4 chore: replace Chinese comments and log messages with English
Convert all Chinese comments, docstrings, logger/print output,
HTTPException detail messages, and API response messages to English
across the entire codebase. Functional zh/ja localized strings
(e.g. prompt templates, timezone display names, date formats) are
preserved as-is.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-30 19:45:35 +08:00

460 lines
15 KiB
Python

"""
Mem0 agent middleware.
Implements memory recall and storage for AgentMiddleware.
"""
import asyncio
import logging
import threading
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from langchain.agents.middleware import AgentMiddleware, AgentState, ModelRequest
from langgraph.runtime import Runtime
from .mem0_config import Mem0Config
from .mem0_manager import Mem0Manager, get_mem0_manager
logger = logging.getLogger("app")
# Avoid circular imports
if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel
class Mem0Middleware(AgentMiddleware):
    """Mem0 memory middleware.

    Features:
        1. ``before_agent`` / ``abefore_agent``: recall memories for the last
           user message and stash the resulting prompt on
           ``agent_config._mem0_context``.
        2. ``wrap_model_call`` / ``awrap_model_call``: append the stashed
           memory prompt to the system prompt of each model request.
        3. ``after_agent`` / ``aafter_agent``: store the latest user/assistant
           exchange as a new memory from a background daemon thread.
    """

    def __init__(
        self,
        mem0_manager: Mem0Manager,
        config: Mem0Config,
        agent_config: "AgentConfig",
    ):
        """Initialize Mem0Middleware.

        Args:
            mem0_manager: Mem0Manager instance used to recall and add memories.
            config: Mem0Config configuration.
            agent_config: AgentConfig instance used to pass data between
                middlewares (the memory prompt is stored on it).
        """
        self.mem0_manager = mem0_manager
        self.config = config
        self.agent_config = agent_config

    # ------------------------------------------------------------------
    # State helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _last_message_text(state: AgentState, message_type: type) -> str:
        """Return the content of the last message of ``message_type`` as text.

        Args:
            state: Agent state; its ``"messages"`` entry is scanned backwards.
            message_type: Message class to look for.

        Returns:
            ``str(content)`` of the most recent matching message, or ``""``
            when there are no messages, no match, or the content is falsy.
        """
        messages = state.get("messages", [])
        if not messages:
            return ""
        for msg in reversed(messages):
            if isinstance(msg, message_type):
                return str(msg.content) if msg.content else ""
        return ""

    def _extract_user_query(self, state: AgentState) -> str:
        """Extract the user query from state, using the last HumanMessage.

        Args:
            state: Agent state.

        Returns:
            User query text, or ``""`` if none is found.
        """
        # Imported lazily, as in the rest of this module.
        from langchain_core.messages import HumanMessage

        return self._last_message_text(state, HumanMessage)

    def _extract_agent_response(self, state: AgentState) -> str:
        """Extract the agent response from state, using the last AIMessage.

        Args:
            state: Agent state.

        Returns:
            Agent response text, or ``""`` if none is found.
        """
        from langchain_core.messages import AIMessage

        return self._last_message_text(state, AIMessage)

    def _format_memories(self, memories: List[Dict[str, Any]]) -> str:
        """Format a list of memories as a numbered text block.

        Each memory becomes ``"<n>. [<fact_type>] <content>"``; a missing
        ``fact_type`` defaults to ``"fact"`` and missing ``content`` to ``""``.

        Args:
            memories: List of memory dicts.

        Returns:
            Newline-joined memory lines, or ``""`` for an empty list.
        """
        if not memories:
            return ""
        return "\n".join(
            f"{i}. [{memory.get('fact_type', 'fact')}] {memory.get('content', '')}"
            for i, memory in enumerate(memories, 1)
        )

    # ------------------------------------------------------------------
    # Recall (before the agent runs)
    # ------------------------------------------------------------------
    def _store_memory_context(self, memories: List[Dict[str, Any]]) -> None:
        """Publish recalled memories on ``agent_config`` for the model wrappers.

        Sets ``agent_config._mem0_context`` to the prompt returned by
        ``config.get_memory_prompt`` when memories were found, otherwise to
        ``None``.
        """
        if memories:
            memory_text = self._format_memories(memories)
            self.agent_config._mem0_context = self.config.get_memory_prompt([memory_text])
            logger.info(f"Recalled {len(memories)} memories for context")
        else:
            self.agent_config._mem0_context = None

    def before_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None:
        """Recall relevant memories before agent execution, synchronous version.

        Uses ``asyncio.run`` to drive the async recall, so it must not be
        called from a thread that already has a running event loop; in that
        case the resulting ``RuntimeError`` is logged and ``None`` is returned.

        Args:
            state: Agent state.
            runtime: Runtime context (unused).

        Returns:
            ``state`` after a recall attempt, or ``None`` when the middleware
            is disabled, there is no user query, or an error occurred.
        """
        if not self.config.is_enabled():
            return None
        try:
            # Clear any prompt left over from an earlier run so that a skipped
            # or failed recall cannot leak stale memories into the model call.
            self.agent_config._mem0_context = None
            query = self._extract_user_query(state)
            if not query:
                logger.debug("No user query found, skipping memory recall")
                return None
            user_id, agent_id = self.config.get_attribution_tuple()
            # Recall is user-level (across sessions), matching abefore_agent;
            # _recall_memories_async takes no session argument.
            memories = asyncio.run(self._recall_memories_async(query, user_id, agent_id))
            self._store_memory_context(memories)
            return state
        except Exception as e:
            logger.error(f"Error in Mem0Middleware.before_agent: {e}")
            return None

    async def abefore_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None:
        """Recall relevant memories before agent execution, asynchronous version.

        Args:
            state: Agent state.
            runtime: Runtime context (unused).

        Returns:
            ``state`` after a recall attempt, or ``None`` when the middleware
            is disabled, there is no user query, or an error occurred.
        """
        if not self.config.is_enabled():
            return None
        try:
            # Clear any prompt left over from an earlier run so that a skipped
            # or failed recall cannot leak stale memories into the model call.
            self.agent_config._mem0_context = None
            query = self._extract_user_query(state)
            if not query:
                logger.debug("No user query found, skipping memory recall")
                return None
            user_id, agent_id = self.config.get_attribution_tuple()
            # Recall user-level memories across sessions.
            memories = await self._recall_memories_async(query, user_id, agent_id)
            self._store_memory_context(memories)
            return state
        except Exception as e:
            logger.error(f"Error in Mem0Middleware.abefore_agent: {e}")
            return None

    async def _recall_memories_async(
        self, query: str, user_id: str, agent_id: str
    ) -> List[Dict[str, Any]]:
        """Recall memories asynchronously via the Mem0 manager.

        Args:
            query: Query text.
            user_id: User ID.
            agent_id: Agent/Bot ID.

        Returns:
            List of memories as returned by ``mem0_manager.recall_memories``.
        """
        return await self.mem0_manager.recall_memories(
            query=query,
            user_id=user_id,
            agent_id=agent_id,
            config=self.config,
        )

    # ------------------------------------------------------------------
    # Storage (after the agent runs)
    # ------------------------------------------------------------------
    def _start_augmentation_thread(self, state: AgentState, runtime: Runtime) -> None:
        """Start a daemon thread that runs ``_trigger_augmentation_sync``."""
        thread = threading.Thread(
            target=self._trigger_augmentation_sync,
            args=(state, runtime),
            daemon=True,
        )
        thread.start()

    def after_agent(self, state: AgentState, runtime: Runtime) -> None:
        """Trigger memory augmentation after agent execution, synchronous version.

        Runs in a background thread to avoid blocking the main flow.

        Args:
            state: Agent state.
            runtime: Runtime context.
        """
        if not self.config.is_enabled():
            return
        try:
            self._start_augmentation_thread(state, runtime)
        except Exception as e:
            logger.error(f"Error in Mem0Middleware.after_agent: {e}")

    async def aafter_agent(self, state: AgentState, runtime: Runtime) -> None:
        """Trigger memory augmentation after agent execution, asynchronous version.

        Runs in a background thread to avoid blocking the event loop.

        Args:
            state: Agent state.
            runtime: Runtime context.
        """
        if not self.config.is_enabled():
            return
        try:
            self._start_augmentation_thread(state, runtime)
        except Exception as e:
            logger.error(f"Error in Mem0Middleware.aafter_agent: {e}")

    async def _store_conversation(self, state: AgentState) -> None:
        """Store the latest user/assistant exchange as a memory.

        Does nothing unless both a user query and an agent response are
        present in ``state``. Exceptions propagate to the caller, which is
        responsible for logging them.
        """
        user_id, agent_id = self.config.get_attribution_tuple()
        user_query = self._extract_user_query(state)
        agent_response = self._extract_agent_response(state)
        if not (user_query and agent_response):
            return
        conversation_text = f"User: {user_query}\nAssistant: {agent_response}"
        await self.mem0_manager.add_memory(
            text=conversation_text,
            user_id=user_id,
            agent_id=agent_id,
            metadata={"type": "conversation"},
            config=self.config,
        )
        logger.debug(f"Stored conversation as memory for user={user_id}, agent={agent_id}")

    def _trigger_augmentation_sync(self, state: AgentState, runtime: Runtime) -> None:
        """Trigger the memory augmentation task, synchronous version executed in a thread.

        Extracts the last exchange from the conversation and stores it in
        Mem0. Errors are logged and never raised.

        Args:
            state: Agent state.
            runtime: Runtime context (unused).
        """
        try:
            # This runs on a worker thread with no event loop of its own;
            # asyncio.run creates one and closes it again when done.
            asyncio.run(self._store_conversation(state))
        except Exception as e:
            logger.error(f"Error in _trigger_augmentation_sync: {e}")

    async def _trigger_augmentation_async(self, state: AgentState, runtime: Runtime) -> None:
        """Trigger the memory augmentation task on the current event loop.

        Extracts the last exchange from the conversation and stores it in
        Mem0. Errors are logged and never raised.

        Args:
            state: Agent state.
            runtime: Runtime context (unused).
        """
        try:
            await self._store_conversation(state)
        except Exception as e:
            logger.error(f"Error in _trigger_augmentation_async: {e}")

    # ------------------------------------------------------------------
    # Prompt injection (around each model call)
    # ------------------------------------------------------------------
    @staticmethod
    def _as_prompt_text(value: Any) -> str:
        """Coerce prompt content to ``str``.

        Lists are joined with newlines using ``str()`` on each item; any
        other falsy value becomes ``""``.
        """
        if isinstance(value, list):
            return "\n".join(str(item) for item in value)
        return str(value) if value else ""

    def _with_memory_prompt(self, request: ModelRequest) -> ModelRequest:
        """Return ``request`` with the recalled memory prompt appended.

        The request is returned unchanged when no memory prompt is available.
        """
        # getattr: the attribute may not exist if no recall has run yet.
        memory_prompt = getattr(self.agent_config, "_mem0_context", None)
        if not memory_prompt:
            return request
        current_system_prompt = ""
        system_message = request.system_message
        if system_message:
            content = (
                system_message.content
                if hasattr(system_message, "content")
                else str(system_message)
            )
            # content may be a list or a string; normalize to str.
            current_system_prompt = self._as_prompt_text(content)
        new_system_prompt = current_system_prompt + self._as_prompt_text(memory_prompt)
        return request.override(system_prompt=new_system_prompt)

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Any],
    ) -> Any:
        """Wrap model calls and inject memories into the system prompt, synchronous version.

        Args:
            request: Model request.
            handler: Original handler.

        Returns:
            Model response returned by ``handler``.
        """
        return handler(self._with_memory_prompt(request))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Any],
    ) -> Any:
        """Wrap model calls and inject memories into the system prompt, asynchronous version.

        Args:
            request: Model request.
            handler: Original handler (awaited).

        Returns:
            Model response returned by ``handler``.
        """
        return await handler(self._with_memory_prompt(request))
def create_mem0_middleware(
    bot_id: str,
    user_identifier: str,
    session_id: str,
    agent_config: "AgentConfig",
    enabled: bool = True,
    semantic_search_top_k: int = 20,
    mem0_manager: Optional[Mem0Manager] = None,
    llm_instance: Optional["BaseChatModel"] = None,
) -> Optional[Mem0Middleware]:
    """Build a Mem0Middleware, or return None when memory is switched off.

    Args:
        bot_id: Bot ID; passed to Mem0Config as ``agent_id``.
        user_identifier: User identifier; passed to Mem0Config as ``user_id``.
        session_id: Session ID.
        agent_config: AgentConfig instance used to pass data between middlewares.
        enabled: Whether the middleware is enabled.
        semantic_search_top_k: Number of semantic search results to return.
        mem0_manager: Mem0Manager instance; ``get_mem0_manager()`` is used
            when this is falsy.
        llm_instance: LangChain LLM instance handed to Mem0Config.

    Returns:
        Mem0Middleware instance, or None if ``enabled`` is false.
    """
    if enabled:
        # Resolve the manager first: caller-supplied one wins over the default.
        resolved_manager = mem0_manager if mem0_manager else get_mem0_manager()
        settings = {
            "enabled": True,
            "user_id": user_identifier,
            "agent_id": bot_id,
            "session_id": session_id,
            "semantic_search_top_k": semantic_search_top_k,
            "llm_instance": llm_instance,
        }
        middleware_config = Mem0Config(**settings)
        return Mem0Middleware(
            mem0_manager=resolved_manager,
            config=middleware_config,
            agent_config=agent_config,
        )
    return None