修改memory为异步任务

This commit is contained in:
朱潮 2026-01-21 22:30:05 +08:00
parent 8daa37c4c7
commit 89f9554be5
2 changed files with 71 additions and 4 deletions

View File

@ -3,7 +3,9 @@ Mem0 Agent 中间件
实现记忆召回和存储的 AgentMiddleware
"""
import asyncio
import logging
import threading
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from langchain.agents.middleware import AgentMiddleware, AgentState, ModelRequest
@ -11,7 +13,6 @@ from langgraph.runtime import Runtime
from .mem0_config import Mem0Config
from .mem0_manager import Mem0Manager, get_mem0_manager
import asyncio
logger = logging.getLogger("app")
@ -215,6 +216,8 @@ class Mem0Middleware(AgentMiddleware):
def after_agent(self, state: AgentState, runtime: Runtime) -> None:
"""Agent 执行后:触发记忆增强(同步版本)
使用后台线程执行避免阻塞主流程
Args:
state: Agent 状态
runtime: 运行时上下文
@ -223,14 +226,21 @@ class Mem0Middleware(AgentMiddleware):
return
try:
# 触发后台增强任务
asyncio.create_task(self._trigger_augmentation_async(state, runtime))
# 在后台线程中执行,完全不阻塞主流程
thread = threading.Thread(
target=self._trigger_augmentation_sync,
args=(state, runtime),
daemon=True,
)
thread.start()
except Exception as e:
logger.error(f"Error in Mem0Middleware.after_agent: {e}")
async def aafter_agent(self, state: AgentState, runtime: Runtime) -> None:
"""Agent 执行后:触发记忆增强(异步版本)
使用后台线程执行避免阻塞事件循环
Args:
state: Agent 状态
runtime: 运行时上下文
@ -239,10 +249,57 @@ class Mem0Middleware(AgentMiddleware):
return
try:
asyncio.create_task(self._trigger_augmentation_async(state, runtime))
# 在后台线程中执行,完全不阻塞事件循环
thread = threading.Thread(
target=self._trigger_augmentation_sync,
args=(state, runtime),
daemon=True,
)
thread.start()
except Exception as e:
logger.error(f"Error in Mem0Middleware.aafter_agent: {e}")
def _trigger_augmentation_sync(self, state: AgentState, runtime: Runtime) -> None:
"""触发记忆增强任务(同步版本,在线程中执行)
从对话中提取信息并存储到 Mem0用户级别跨会话
Args:
state: Agent 状态
runtime: 运行时上下文
"""
try:
# 获取 attribution 参数
user_id, agent_id = self.config.get_attribution_tuple()
# 提取用户查询和 Agent 响应
user_query = self._extract_user_query(state)
agent_response = self._extract_agent_response(state)
# 将对话作为记忆存储(用户级别)
if user_query and agent_response:
conversation_text = f"User: {user_query}\nAssistant: {agent_response}"
# 在新的事件循环中运行异步代码(因为在线程中)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
self.mem0_manager.add_memory(
text=conversation_text,
user_id=user_id,
agent_id=agent_id,
metadata={"type": "conversation"},
config=self.config,
)
)
logger.debug(f"Stored conversation as memory for user={user_id}, agent={agent_id}")
finally:
loop.close()
except Exception as e:
logger.error(f"Error in _trigger_augmentation_sync: {e}")
async def _trigger_augmentation_async(self, state: AgentState, runtime: Runtime) -> None:
"""触发记忆增强任务

View File

@ -1909,6 +1909,13 @@
<label class="settings-label" for="user-identifier">用户标识</label>
<input type="text" id="user-identifier" class="settings-input" placeholder="输入用户标识...">
</div>
<div class="settings-group">
<div class="settings-checkbox-wrapper">
<input type="checkbox" id="enable-memori" class="settings-checkbox">
<label class="settings-label" for="enable-memori" style="margin-bottom: 0;">启用记忆存储</label>
</div>
<p style="font-size: 11px; color: var(--text-muted); margin-top: 4px;">启用后AI 会记住对话中的信息以提供更个性化的回复</p>
</div>
</div>
</div>
@ -2840,6 +2847,7 @@
'dataset-ids': document.getElementById('dataset-ids').value,
'system-prompt': document.getElementById('system-prompt').value,
'user-identifier': document.getElementById('user-identifier').value,
'enable-memori': document.getElementById('enable-memori').checked,
'skills': selectedSkills.join(','),
'mcp-settings': mcpSettingsValue,
'tool-response': document.getElementById('tool-response').checked
@ -3233,6 +3241,7 @@
systemPrompt: getValue('system-prompt'),
sessionId,
userIdentifier: getValue('user-identifier'),
enableMemori: getChecked('enable-memori'),
skills,
mcpSettings,
toolResponse: getChecked('tool-response')
@ -3490,6 +3499,7 @@
if (settings.systemPrompt) requestBody.system_prompt = settings.systemPrompt;
if (settings.sessionId) requestBody.session_id = settings.sessionId;
if (settings.userIdentifier) requestBody.user_identifier = settings.userIdentifier;
if (settings.enableMemori) requestBody.enable_memori = settings.enableMemori;
if (settings.skills?.length) requestBody.skills = settings.skills;
if (settings.datasetIds?.length) requestBody.dataset_ids = settings.datasetIds;
if (settings.mcpSettings?.length) requestBody.mcp_settings = settings.mcpSettings;