qwen_agent/agent/memori_middleware.py
朱潮 455a48409d feat: integrate Memori long-term memory system
Add Memori (https://github.com/MemoriLabs/Memori) integration for
persistent cross-session memory capabilities in both create_agent
and create_deep_agent.

## New Files

- agent/memori_config.py: MemoriConfig dataclass for configuration
- agent/memori_manager.py: MemoriManager for connection and instance management
- agent/memori_middleware.py: MemoriMiddleware for memory recall/storage
- tests/: Unit tests for Memori components

## Modified Files

- agent/agent_config.py: Added enable_memori, memori_semantic_search_top_k, etc.
- agent/deep_assistant.py: Integrated MemoriMiddleware into init_agent()
- utils/settings.py: Added MEMORI_* environment variables
- pyproject.toml: Added memori>=3.1.0 dependency

## Features

- Semantic memory search with configurable top-k and threshold
- Multi-tenant isolation (entity_id=user, process_id=bot, session_id)
- Memory injection into system prompt
- Background asynchronous memory augmentation
- Graceful degradation when Memori is unavailable

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-20 00:12:43 +08:00

343 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Memori Agent 中间件
实现记忆召回和存储的 AgentMiddleware
"""
import asyncio
import logging
from typing import Any, Dict, List, Optional
from langchain.agents.middleware import AgentMiddleware, AgentState
from langgraph.runtime import Runtime
from .memori_config import MemoriConfig
from .memori_manager import MemoriManager, get_memori_manager
# Module-level logger; bound to the shared "app" logger name rather than a per-module name.
logger = logging.getLogger("app")
class MemoriMiddleware(AgentMiddleware):
    """Agent middleware that recalls Memori memories and hooks augmentation.

    Hooks:
        1. ``before_agent`` / ``abefore_agent``: recall memories relevant to
           the latest message and inject them into the conversation context.
        2. ``after_agent`` / ``aafter_agent``: post-run augmentation hooks.
    """

    def __init__(
        self,
        memori_manager: MemoriManager,
        config: MemoriConfig,
    ):
        """Initialise the middleware.

        Args:
            memori_manager: MemoriManager instance used for recall and for
                waiting on augmentation.
            config: MemoriConfig carrying the feature flags, attribution and
                search settings read by the hooks.
        """
        super().__init__()
        self.memori_manager = memori_manager
        self.config = config
        # Strong references to fire-and-forget tasks: the event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._background_tasks: set = set()

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _content_blocks_to_text(blocks: List[Any]) -> str:
        """Join the textual parts of list-style message content.

        Plain strings and dict blocks carrying a string ``"text"`` entry
        (with ``"type"`` equal to ``"text"`` or absent) are kept; every other
        block is ignored.
        """
        parts = []
        for block in blocks:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type", "text") == "text":
                text = block.get("text")
                if isinstance(text, str):
                    parts.append(text)
        return "\n".join(part for part in parts if part)

    @staticmethod
    def _is_system_message(message: Any) -> bool:
        """Return True for a message object or dict that is a system message."""
        if getattr(message, "type", None) == "system":
            return True
        return isinstance(message, dict) and message.get("role") == "system"

    @staticmethod
    def _append_prompt(content: Any, prompt: str) -> Any:
        """Append ``prompt`` to message content given as a string or a list."""
        if isinstance(content, list):
            return [*content, {"type": "text", "text": prompt}]
        # ``content or ""`` keeps a missing/None content from raising TypeError.
        return (content or "") + prompt

    @staticmethod
    def _run_coroutine_sync(coro_factory: Any) -> Any:
        """Run ``coro_factory()`` to completion from synchronous code.

        ``asyncio.run`` raises ``RuntimeError`` when the calling thread already
        has a running event loop. In that case the coroutine is executed on a
        one-off worker thread with its own loop and this call blocks until it
        finishes. A factory is taken (rather than a coroutine object) so the
        coroutine is created in the thread that actually runs it.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: asyncio.run is safe.
            return asyncio.run(coro_factory())

        from concurrent.futures import ThreadPoolExecutor

        # NOTE(review): the manager's coroutine then runs on a different loop
        # than the caller's; confirm MemoriManager holds no loop-bound state.
        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(lambda: asyncio.run(coro_factory())).result()

    def _resolve_session_id(self, runtime: Runtime) -> str:
        """Return the configured session id, else the runtime thread id.

        Falls back to ``"default"`` when neither is available, including when
        ``runtime`` exposes no ``config`` mapping.
        """
        if self.config.session_id:
            return self.config.session_id
        runtime_config = getattr(runtime, "config", None) or {}
        configurable = runtime_config.get("configurable", {}) or {}
        return configurable.get("thread_id", "default")

    def _recall_arguments(
        self, state: AgentState, runtime: Runtime
    ) -> Optional[Dict[str, str]]:
        """Build the keyword arguments for a recall, or None without a query."""
        query = self._extract_user_query(state)
        if not query:
            logger.debug("No user query found, skipping memory recall")
            return None
        entity_id, process_id = self.config.get_attribution_tuple()
        return {
            "query": query,
            "entity_id": entity_id,
            "process_id": process_id,
            "session_id": self._resolve_session_id(runtime),
        }

    def _on_background_task_done(self, task: Any) -> None:
        """Release a finished background task and log its failure, if any."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logger.error(f"Error in MemoriMiddleware.after_agent: {error}")

    def _extract_user_query(self, state: AgentState) -> str:
        """Extract the query text from the last message in the state.

        Args:
            state: Agent state.

        Returns:
            The content of the last message as text, or ``""`` when there are
            no messages or the content is empty. List-style content is reduced
            to its text parts instead of being stringified as a Python repr.
        """
        messages = state.get("messages", [])
        if not messages:
            return ""
        last_message = messages[-1]
        content = getattr(last_message, "content", None)
        if content is None and isinstance(last_message, dict):
            content = last_message.get("content", "")
        if not content:
            return ""
        if isinstance(content, list):
            return self._content_blocks_to_text(content)
        return str(content)

    def _format_memories(self, memories: List[Dict[str, Any]]) -> str:
        """Format a list of memories as numbered lines.

        Args:
            memories: Memory dicts; ``fact_type`` (default ``"fact"``) and
                ``content`` (default ``""``) are read from each.

        Returns:
            One ``"<n>. [<fact_type>] <content>"`` line per memory, joined by
            newlines, or ``""`` for an empty list.
        """
        if not memories:
            return ""
        return "\n".join(
            f"{index}. [{memory.get('fact_type', 'fact')}] {memory.get('content', '')}"
            for index, memory in enumerate(memories, 1)
        )

    def _inject_memory_context(self, state: AgentState, memory_text: str) -> AgentState:
        """Inject the memory prompt into the state's messages.

        Args:
            state: Original state.
            memory_text: Formatted memory text.

        Returns:
            ``state`` itself when nothing is injected or when an existing
            system message was extended in place; otherwise a shallow copy of
            the state whose ``messages`` list starts with a new system message.
        """
        if not memory_text or not self.config.inject_memory_to_system_prompt:
            return state
        memory_prompt = self.config.get_memory_prompt([memory_text])
        messages = state.get("messages", [])
        if not messages:
            return state

        system_message = next(
            (msg for msg in messages if self._is_system_message(msg)), None
        )
        if system_message is None:
            from langchain_core.messages import SystemMessage

            # NOTE(review): where the prepended message finally lands depends
            # on how the graph merges "messages" updates - confirm ordering.
            return {
                **state,
                "messages": [SystemMessage(content=memory_prompt), *messages],
            }

        # Extend the first system message in place (as the original code did).
        if isinstance(system_message, dict):
            system_message["content"] = self._append_prompt(
                system_message.get("content", ""), memory_prompt
            )
        elif hasattr(system_message, "content"):
            system_message.content = self._append_prompt(
                system_message.content, memory_prompt
            )
        return state

    # ------------------------------------------------------------------
    # Recall hooks
    # ------------------------------------------------------------------
    def before_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None:
        """Recall relevant memories before the agent runs (sync version).

        Args:
            state: Agent state.
            runtime: Runtime context.

        Returns:
            The updated state, or None when disabled, when there is no query,
            when nothing was recalled, or when recall failed (the error is
            logged, never raised).
        """
        if not self.config.is_enabled():
            return None
        try:
            recall_args = self._recall_arguments(state, runtime)
            if recall_args is None:
                return None
            # A bare asyncio.run() fails when a loop is already running in
            # this thread, so go through the loop-aware helper.
            memories = self._run_coroutine_sync(
                lambda: self._recall_memories_async(**recall_args)
            )
            if not memories:
                return None
            updated_state = self._inject_memory_context(
                state, self._format_memories(memories)
            )
            logger.info(f"Injected {len(memories)} memories into context")
            return updated_state
        except Exception as e:
            logger.error(f"Error in MemoriMiddleware.before_agent: {e}")
            return None

    async def abefore_agent(self, state: AgentState, runtime: Runtime) -> Dict[str, Any] | None:
        """Recall relevant memories before the agent runs (async version).

        Args:
            state: Agent state.
            runtime: Runtime context.

        Returns:
            The updated state, or None when disabled, when there is no query,
            when nothing was recalled, or when recall failed (the error is
            logged, never raised).
        """
        if not self.config.is_enabled():
            return None
        try:
            recall_args = self._recall_arguments(state, runtime)
            if recall_args is None:
                return None
            memories = await self._recall_memories_async(**recall_args)
            if not memories:
                return None
            updated_state = self._inject_memory_context(
                state, self._format_memories(memories)
            )
            logger.info(f"Injected {len(memories)} memories into context (similarity > {self.config.semantic_search_threshold})")
            return updated_state
        except Exception as e:
            logger.error(f"Error in MemoriMiddleware.abefore_agent: {e}")
            return None

    async def _recall_memories_async(
        self, query: str, entity_id: str, process_id: str, session_id: str
    ) -> List[Dict[str, Any]]:
        """Recall memories through the manager.

        Args:
            query: Query text.
            entity_id: Entity ID.
            process_id: Process ID.
            session_id: Session ID.

        Returns:
            Whatever ``memori_manager.recall_memories`` returns for these
            arguments and this middleware's config.
        """
        return await self.memori_manager.recall_memories(
            query=query,
            entity_id=entity_id,
            process_id=process_id,
            session_id=session_id,
            config=self.config,
        )

    # ------------------------------------------------------------------
    # Augmentation hooks
    # ------------------------------------------------------------------
    def after_agent(self, state: AgentState, runtime: Runtime) -> None:
        """Trigger the augmentation hook after the agent runs (sync version).

        With a running event loop the hook is scheduled as a background task
        (and kept referenced until done); without one it is run to completion
        here, since ``asyncio.create_task`` requires a running loop.

        Args:
            state: Agent state.
            runtime: Runtime context.
        """
        if not self.config.is_enabled() or not self.config.augmentation_enabled:
            return
        try:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None
            if loop is None:
                asyncio.run(self._trigger_augmentation_async(state, runtime))
                return
            task = loop.create_task(self._trigger_augmentation_async(state, runtime))
            self._background_tasks.add(task)
            task.add_done_callback(self._on_background_task_done)
        except Exception as e:
            logger.error(f"Error in MemoriMiddleware.after_agent: {e}")

    async def aafter_agent(self, state: AgentState, runtime: Runtime) -> None:
        """Optionally wait for augmentation after the agent runs (async version).

        Per the original author's note, Memori performs augmentation in the
        background on its own; this hook only waits for it, and only when
        ``augmentation_wait_timeout`` is configured.

        Args:
            state: Agent state.
            runtime: Runtime context.
        """
        if not self.config.is_enabled() or not self.config.augmentation_enabled:
            return
        try:
            timeout = self.config.augmentation_wait_timeout
            if timeout is None:
                return
            entity_id, process_id = self.config.get_attribution_tuple()
            await self.memori_manager.wait_for_augmentation(
                entity_id=entity_id,
                process_id=process_id,
                session_id=self._resolve_session_id(runtime),
                timeout=timeout,
            )
        except Exception as e:
            logger.error(f"Error in MemoriMiddleware.aafter_agent: {e}")

    async def _trigger_augmentation_async(self, state: AgentState, runtime: Runtime) -> None:
        """Augmentation hook; currently a deliberate no-op.

        Per the original author's note, Memori captures the conversation and
        augments it automatically once its LLM client is registered, so no
        manual trigger is needed. Implement here if one becomes necessary.

        Args:
            state: Agent state.
            runtime: Runtime context.
        """
        return None
def create_memori_middleware(
    bot_id: str,
    user_identifier: str,
    session_id: str,
    enabled: bool = True,
    semantic_search_top_k: int = 5,
    semantic_search_threshold: float = 0.7,
    memori_manager: Optional[MemoriManager] = None,
) -> Optional[MemoriMiddleware]:
    """Factory for ``MemoriMiddleware``.

    Args:
        bot_id: Bot ID; passed to the config as ``process_id``.
        user_identifier: User identifier; passed to the config as ``entity_id``.
        session_id: Session ID.
        enabled: When falsy, no middleware is created.
        semantic_search_top_k: Number of results for semantic search.
        semantic_search_threshold: Similarity threshold for semantic search.
        memori_manager: Manager to use; when falsy, the one returned by
            ``get_memori_manager()`` is used instead.

    Returns:
        A ``MemoriMiddleware`` instance, or None when ``enabled`` is falsy.
    """
    if enabled:
        # Resolve the manager first, then build the config, as before.
        resolved_manager = memori_manager or get_memori_manager()
        config_kwargs = {
            "enabled": True,
            "entity_id": user_identifier,
            "process_id": bot_id,
            "session_id": session_id,
            "semantic_search_top_k": semantic_search_top_k,
            "semantic_search_threshold": semantic_search_threshold,
        }
        return MemoriMiddleware(
            memori_manager=resolved_manager,
            config=MemoriConfig(**config_kwargs),
        )
    return None