Add Memori (https://github.com/MemoriLabs/Memori) integration for persistent cross-session memory capabilities in both create_agent and create_deep_agent. ## New Files - agent/memori_config.py: MemoriConfig dataclass for configuration - agent/memori_manager.py: MemoriManager for connection and instance management - agent/memori_middleware.py: MemoriMiddleware for memory recall/storage - tests/: Unit tests for Memori components ## Modified Files - agent/agent_config.py: Added enable_memori, memori_semantic_search_top_k, etc. - agent/deep_assistant.py: Integrated MemoriMiddleware into init_agent() - utils/settings.py: Added MEMORI_* environment variables - pyproject.toml: Added memori>=3.1.0 dependency ## Features - Semantic memory search with configurable top-k and threshold - Multi-tenant isolation (entity_id=user, process_id=bot, session_id) - Memory injection into system prompt - Background asynchronous memory augmentation - Graceful degradation when Memori is unavailable 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
343 lines
11 KiB
Python
343 lines
11 KiB
Python
"""
|
||
Memori Agent 中间件
|
||
实现记忆召回和存储的 AgentMiddleware
|
||
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||
from langgraph.runtime import Runtime
|
||
|
||
from .memori_config import MemoriConfig
|
||
from .memori_manager import MemoriManager, get_memori_manager
|
||
|
||
# Module logger. It uses the shared "app" logger name instead of __name__,
# presumably so records go through the application's configured handlers
# (confirm against the logging setup).
logger = logging.getLogger("app")
|
||
|
||
|
||
class MemoriMiddleware(AgentMiddleware):
    """Memori memory middleware.

    Hooks:
    1. ``before_agent`` / ``abefore_agent``: recall memories relevant to the
       newest user message and inject them into the conversation as a
       dedicated system message.
    2. ``after_agent`` / ``aafter_agent``: when
       ``config.augmentation_wait_timeout`` is set, wait (bounded by that
       timeout) for Memori's memory augmentation of this session.

    Every hook degrades gracefully. Exceptions are logged with a traceback and
    the agent run continues without memory.
    """

    # Fixed id of the injected memory system message. Reusing the same id on
    # every turn lets the injection replace its previous copy instead of
    # stacking another memory block into the checkpointed history.
    MEMORY_CONTEXT_MESSAGE_ID = "memori-memory-context"

    def __init__(
        self,
        memori_manager: MemoriManager,
        config: MemoriConfig,
    ):
        """Initialise the middleware.

        Args:
            memori_manager: Manager used to recall memories and to wait for
                augmentation.
            config: Memori settings (attribution ids, search parameters,
                injection and augmentation switches).
        """
        self.memori_manager = memori_manager
        self.config = config

    # ------------------------------------------------------------------
    # Message helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _message_field(message: Any, name: str, default: Any = None) -> Any:
        """Read ``name`` from a message object or from a plain ``dict`` message."""
        if isinstance(message, dict):
            return message.get(name, default)
        return getattr(message, name, default)

    @staticmethod
    def _message_role(message: Any) -> str:
        """Return the message role (e.g. ``"human"``, ``"system"``) or ``""``.

        LangChain message objects expose the role as ``.type``. Dict messages
        may use either ``"role"`` (OpenAI style) or ``"type"``.
        """
        if isinstance(message, dict):
            role = message.get("role") or message.get("type")
        else:
            role = getattr(message, "type", None)
        return role if isinstance(role, str) else ""

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten message content to plain text.

        Content may be a string or a list of content blocks (strings, or dicts
        carrying a ``"text"`` string). Non-text blocks are skipped rather than
        rendered as a Python ``repr``.
        """
        if not content:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            return "\n".join(part for part in parts if part)
        return str(content)

    def _extract_user_query(self, state: AgentState) -> str:
        """Return the recall query: the text of the newest user message.

        Falls back to the last message when no human/user message exists.

        Args:
            state: Agent state; only its ``"messages"`` list is read.

        Returns:
            The query text, or ``""`` when there are no messages or no text.
        """
        messages = state.get("messages") or []
        if not messages:
            return ""

        target = next(
            (msg for msg in reversed(messages) if self._message_role(msg) in ("human", "user")),
            messages[-1],
        )
        return self._content_to_text(self._message_field(target, "content"))

    def _format_memories(self, memories: List[Dict[str, Any]]) -> str:
        """Render memories as numbered lines ``"<n>. [<fact_type>] <content>"``.

        ``fact_type`` defaults to ``"fact"`` and ``content`` to ``""``.

        Args:
            memories: Memory dicts as returned by the manager's recall.

        Returns:
            The formatted text, or ``""`` for an empty list.
        """
        return "\n".join(
            f"{index}. [{memory.get('fact_type', 'fact')}] {memory.get('content', '')}"
            for index, memory in enumerate(memories or [], 1)
        )

    def _inject_memory_context(self, state: AgentState, memory_text: str) -> Optional[Dict[str, Any]]:
        """Build a state update that places ``memory_text`` in a system message.

        The memory prompt goes into its own ``SystemMessage`` (stable id
        ``MEMORY_CONTEXT_MESSAGE_ID``). That message sits right after any
        leading system message(s), or first if there are none. A copy injected
        on an earlier turn is dropped first. The caller's message objects are
        never mutated.

        Args:
            state: Current agent state.
            memory_text: Pre-formatted memory text.

        Returns:
            A partial update for the ``messages`` channel, or ``None`` when
            there is nothing to inject (empty text, injection disabled, or no
            messages).
        """
        if not memory_text or not self.config.inject_memory_to_system_prompt:
            return None

        messages = state.get("messages") or []
        if not messages:
            return None

        from langchain_core.messages import RemoveMessage, SystemMessage
        from langgraph.graph.message import REMOVE_ALL_MESSAGES

        memory_message = SystemMessage(
            content=self.config.get_memory_prompt([memory_text]),
            id=self.MEMORY_CONTEXT_MESSAGE_ID,
        )

        # Drop the memory message injected on a previous turn (if any).
        new_messages = [
            msg for msg in messages
            if self._message_field(msg, "id") != self.MEMORY_CONTEXT_MESSAGE_ID
        ]

        # Place the memory right after the leading system prompt(s).
        insert_at = 0
        while insert_at < len(new_messages) and self._message_role(new_messages[insert_at]) == "system":
            insert_at += 1
        new_messages.insert(insert_at, memory_message)

        # The add_messages reducer appends messages with unseen ids at the end
        # of the history. Replacing the whole list keeps the position above.
        return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), *new_messages]}

    # ------------------------------------------------------------------
    # Execution helpers
    # ------------------------------------------------------------------

    def _resolve_session_id(self, runtime: Runtime) -> str:
        """Return the session id used to scope Memori calls.

        Resolution order:
        1. the explicitly configured ``config.session_id``;
        2. the run config's ``configurable.thread_id``;
        3. ``"default"``.
        """
        if self.config.session_id:
            return self.config.session_id

        run_config = getattr(runtime, "config", None)
        if run_config is None:
            # NOTE(review): langgraph's Runtime may not carry the RunnableConfig;
            # fall back to the config of the currently executing run.
            try:
                from langgraph.config import get_config

                run_config = get_config()
            except (ImportError, RuntimeError):
                run_config = None

        configurable = (run_config or {}).get("configurable") or {}
        return configurable.get("thread_id") or "default"

    @staticmethod
    def _run_coroutine_sync(coro: Any) -> Any:
        """Run ``coro`` to completion from synchronous code and return its result.

        ``asyncio.run`` raises ``RuntimeError`` when this thread already has a
        running event loop, e.g. a sync hook invoked from async code. In that
        case the coroutine runs on a fresh loop in a short-lived worker thread.
        NOTE(review): resources inside ``memori_manager`` that are bound to
        another event loop will not work on the fresh loop. Confirm the manager
        is loop-agnostic.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)

        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    async def _recall_and_build_update(self, state: AgentState, runtime: Runtime) -> Optional[Dict[str, Any]]:
        """Recall memories for the newest user message and build the state update.

        Shared by ``before_agent`` and ``abefore_agent``.

        Returns:
            The ``messages`` update from ``_inject_memory_context``, or ``None``
            when there is no query, no memory, or nothing was injected.
        """
        query = self._extract_user_query(state)
        if not query:
            logger.debug("No user query found, skipping memory recall")
            return None

        entity_id, process_id = self.config.get_attribution_tuple()
        session_id = self._resolve_session_id(runtime)

        memories = await self._recall_memories_async(query, entity_id, process_id, session_id)
        if not memories:
            return None

        update = self._inject_memory_context(state, self._format_memories(memories))
        if update is not None:
            logger.info(
                "Injected %d memories into context (similarity > %s)",
                len(memories),
                self.config.semantic_search_threshold,
            )
        return update

    # ------------------------------------------------------------------
    # Agent hooks
    # ------------------------------------------------------------------

    def before_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None:
        """Before the agent runs: recall relevant memories (sync hook).

        Works whether or not an event loop is already running in this thread.

        Args:
            state: Agent state.
            runtime: Runtime context.

        Returns:
            A ``messages`` state update, or ``None`` (disabled, nothing
            recalled, or an error that was logged).
        """
        if not self.config.is_enabled():
            return None

        try:
            return self._run_coroutine_sync(self._recall_and_build_update(state, runtime))
        except Exception as e:
            logger.error("Error in MemoriMiddleware.before_agent: %s", e, exc_info=True)
            return None

    async def abefore_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None:
        """Before the agent runs: recall relevant memories (async hook).

        Args:
            state: Agent state.
            runtime: Runtime context.

        Returns:
            A ``messages`` state update, or ``None`` (disabled, nothing
            recalled, or an error that was logged).
        """
        if not self.config.is_enabled():
            return None

        try:
            return await self._recall_and_build_update(state, runtime)
        except Exception as e:
            logger.error("Error in MemoriMiddleware.abefore_agent: %s", e, exc_info=True)
            return None

    async def _recall_memories_async(
        self, query: str, entity_id: str, process_id: str, session_id: str
    ) -> List[Dict[str, Any]]:
        """Recall memories through the manager.

        Args:
            query: Query text.
            entity_id: Entity (user) id.
            process_id: Process (bot) id.
            session_id: Session id.

        Returns:
            Whatever ``memori_manager.recall_memories`` returns (a list of
            memory dicts).
        """
        return await self.memori_manager.recall_memories(
            query=query,
            entity_id=entity_id,
            process_id=process_id,
            session_id=session_id,
            config=self.config,
        )

    async def _wait_for_augmentation(self, runtime: Runtime) -> None:
        """Wait for Memori augmentation when ``augmentation_wait_timeout`` is set.

        This is a no-op when the timeout is ``None``.
        """
        timeout = self.config.augmentation_wait_timeout
        if timeout is None:
            return

        entity_id, process_id = self.config.get_attribution_tuple()
        await self.memori_manager.wait_for_augmentation(
            entity_id=entity_id,
            process_id=process_id,
            session_id=self._resolve_session_id(runtime),
            timeout=timeout,
        )

    def after_agent(self, state: AgentState, runtime: Runtime) -> None:
        """After the agent runs: optionally wait for augmentation (sync hook).

        Mirrors ``aafter_agent``. The previous implementation called
        ``asyncio.create_task`` here, which always fails without a running
        event loop.

        Args:
            state: Agent state (unused).
            runtime: Runtime context.
        """
        if not self.config.is_enabled() or not self.config.augmentation_enabled:
            return
        if self.config.augmentation_wait_timeout is None:
            # Nothing to wait for, so avoid spinning up an event loop.
            return

        try:
            self._run_coroutine_sync(self._wait_for_augmentation(runtime))
        except Exception as e:
            logger.error("Error in MemoriMiddleware.after_agent: %s", e, exc_info=True)

    async def aafter_agent(self, state: AgentState, runtime: Runtime) -> None:
        """After the agent runs: optionally wait for augmentation (async hook).

        Per the original design notes, Memori augments conversations in the
        background on its own. This hook only waits when
        ``augmentation_wait_timeout`` is configured.

        Args:
            state: Agent state (unused).
            runtime: Runtime context.
        """
        if not self.config.is_enabled() or not self.config.augmentation_enabled:
            return

        try:
            await self._wait_for_augmentation(runtime)
        except Exception as e:
            logger.error("Error in MemoriMiddleware.aafter_agent: %s", e, exc_info=True)

    async def _trigger_augmentation_async(self, state: AgentState, runtime: Runtime) -> None:
        """Extension point for manually triggering memory augmentation.

        Currently a no-op, and no hook in this class calls it. Per the original
        notes, Memori's registered LLM client captures and augments
        conversations automatically. The method is kept for subclasses and
        backward compatibility.

        Args:
            state: Agent state.
            runtime: Runtime context.
        """
        pass
|
||
|
||
|
||
def create_memori_middleware(
    bot_id: str,
    user_identifier: str,
    session_id: str,
    enabled: bool = True,
    semantic_search_top_k: int = 5,
    semantic_search_threshold: float = 0.7,
    memori_manager: Optional[MemoriManager] = None,
) -> Optional[MemoriMiddleware]:
    """Factory for a MemoriMiddleware scoped to one user, bot and session.

    Attribution mapping: ``user_identifier`` -> ``entity_id``,
    ``bot_id`` -> ``process_id``, ``session_id`` -> ``session_id``.

    Args:
        bot_id: Bot ID, used as Memori's ``process_id``.
        user_identifier: User identifier, used as Memori's ``entity_id``.
        session_id: Session ID.
        enabled: When False, nothing is built and ``None`` is returned.
        semantic_search_top_k: Number of results requested from semantic search.
        semantic_search_threshold: Similarity threshold for semantic search.
        memori_manager: Manager to use. A falsy value falls back to the
            global instance from ``get_memori_manager()``.

    Returns:
        A configured MemoriMiddleware, or ``None`` when ``enabled`` is False.
    """
    if enabled:
        # Resolve the manager before building the config, as before.
        manager = memori_manager if memori_manager else get_memori_manager()
        settings = MemoriConfig(
            enabled=True,
            entity_id=user_identifier,
            process_id=bot_id,
            session_id=session_id,
            semantic_search_top_k=semantic_search_top_k,
            semantic_search_threshold=semantic_search_threshold,
        )
        return MemoriMiddleware(memori_manager=manager, config=settings)
    return None
|