refactor(mem0): optimize connection pool and async memory handling

- Fix mem0 connection pool exhausted error with proper pooling
- Move memory storage to background threads so it never blocks the agent flow
- Optimize docker-compose configuration
- Validate and rename uploaded skill folders to match SKILL.md
- Cap the Mem0 instance cache with LRU eviction
- Update dependencies

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
朱潮 2026-01-22 19:39:12 +08:00
parent f29fd1fb54
commit 3dc119bca8
11 changed files with 243 additions and 41 deletions

View File

@@ -105,9 +105,7 @@ class AgentConfig:
enable_thinking = request.enable_thinking and "<guidelines>" in request.system_prompt
# Get the Mem0 config from the request; fall back to the global setting if absent
enable_memori = getattr(request, 'enable_memori', None)
if enable_memori is None:
enable_memori = MEM0_ENABLED
enable_memori = getattr(request, 'enable_memori', MEM0_ENABLED)
config = cls(
bot_id=request.bot_id,
@@ -171,10 +169,7 @@ class AgentConfig:
enable_thinking = request.enable_thinking and "<guidelines>" in bot_config.get("system_prompt")
# Get the Mem0 config from the request or the backend config
enable_memori = getattr(request, 'enable_memori', None)
if enable_memori is None:
enable_memori = bot_config.get("enable_memori", MEM0_ENABLED)
enable_memori = getattr(request, 'enable_memori', MEM0_ENABLED)
config = cls(
bot_id=request.bot_id,
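
The simplified lookup above relies on getattr's default argument, which is used only when the attribute is missing, not when it is present but set to None; a minimal illustrative sketch (the Request stub and values are hypothetical, not from the codebase):

```python
# Illustrative sketch of getattr's default semantics; the Request stub is hypothetical.
MEM0_ENABLED = True

class Request:
    pass

req = Request()
# Attribute missing: the default (MEM0_ENABLED) is returned.
print(getattr(req, "enable_memori", MEM0_ENABLED))  # True

req.enable_memori = None
# Attribute present but None: the default is NOT applied.
print(getattr(req, "enable_memori", MEM0_ENABLED))  # None
```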

View File

@@ -12,6 +12,7 @@ from psycopg2 import pool as psycopg2_pool
from utils.settings import (
CHECKPOINT_DB_URL,
CHECKPOINT_POOL_SIZE,
MEM0_POOL_SIZE,
CHECKPOINT_CLEANUP_ENABLED,
CHECKPOINT_CLEANUP_INTERVAL_HOURS,
CHECKPOINT_CLEANUP_INACTIVE_DAYS,
@@ -62,7 +63,7 @@ class DBPoolManager:
await self._pool.open()
# 2. Create a synchronous psycopg2 connection pool (used by Mem0)
self._sync_pool = self._create_sync_pool(CHECKPOINT_DB_URL, CHECKPOINT_POOL_SIZE)
self._sync_pool = self._create_sync_pool(CHECKPOINT_DB_URL, MEM0_POOL_SIZE)
self._initialized = True
logger.info("PostgreSQL connection pool initialized successfully")
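
The pool helper itself is not part of this hunk; a minimal sketch of a psycopg2 pool bounded by MEM0_POOL_SIZE, assuming psycopg2's ThreadedConnectionPool (the helper name here is illustrative):

```python
# Minimal sketch, not the project's actual _create_sync_pool implementation.
from psycopg2 import pool as psycopg2_pool

def create_sync_pool(db_url: str, max_size: int) -> psycopg2_pool.ThreadedConnectionPool:
    # minconn keeps one warm connection; maxconn caps how many connections Mem0 can hold at once
    return psycopg2_pool.ThreadedConnectionPool(minconn=1, maxconn=max_size, dsn=db_url)

# Typical usage: borrow a connection, then always return it to the pool.
# conn = sync_pool.getconn()
# try: ...  finally: sync_pool.putconn(conn)
```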

View File

@@ -5,11 +5,16 @@ Mem0 connection and instance manager
import logging
import asyncio
import threading
import concurrent.futures
from typing import Any, Dict, List, Optional, Literal
from collections import OrderedDict
from embedding.manager import GlobalModelManager, get_model_manager
import json_repair
from psycopg2 import pool
from utils.settings import (
MEM0_POOL_SIZE
)
from .mem0_config import Mem0Config
logger = logging.getLogger("app")
@@ -27,7 +32,9 @@ class CustomMem0Embedding:
This way Mem0 does not need to load the same model again, which saves memory
"""
_model_manager = None # Cached GlobalModelManager instance
_model = None # Class-level cache of the model instance
_lock = threading.Lock() # Lock for thread-safe initialization
_executor = None # Thread pool executor
def __init__(self, config: Optional[Any] = None):
"""初始化自定义 Embedding"""
@ -41,6 +48,41 @@ class CustomMem0Embedding:
"""获取 embedding 维度"""
return 384 # gte-tiny 的维度
def _get_model_sync(self):
"""同步获取模型,避免 asyncio.run()"""
# 首先尝试从 manager 获取已加载的模型
manager = get_model_manager()
model = manager.get_model_sync()
if model is not None:
# Cache the model
CustomMem0Embedding._model = model
return model
# If the model is not loaded yet, run the async initialization in a thread pool
if CustomMem0Embedding._executor is None:
CustomMem0Embedding._executor = concurrent.futures.ThreadPoolExecutor(
max_workers=1,
thread_name_prefix="mem0_embed"
)
# Run the async code in a dedicated thread
def run_async_in_thread():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(manager.get_model())
return result
finally:
loop.close()
future = CustomMem0Embedding._executor.submit(run_async_in_thread)
model = future.result(timeout=30) # 30-second timeout
# Cache the model
CustomMem0Embedding._model = model
return model
def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
"""
Get the embedding vector for a text (synchronous method, called by Mem0)
@@ -52,8 +94,13 @@
Returns:
list: the embedding vector
"""
manager = get_model_manager()
model = asyncio.run(manager.get_model())
# Get the model in a thread-safe way
if CustomMem0Embedding._model is None:
with CustomMem0Embedding._lock:
if CustomMem0Embedding._model is None:
self._get_model_sync()
model = CustomMem0Embedding._model
embeddings = model.encode(text, convert_to_numpy=True)
return embeddings.tolist()
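
A brief usage sketch of the synchronous path above, assuming the shared model has already been loaded by the global model manager at application startup:

```python
# Usage sketch: Mem0 calls embed() synchronously; the model is fetched
# (and cached at class level) on first use.
embedder = CustomMem0Embedding()
vector = embedder.embed("the user prefers concise answers")
print(len(vector))  # expected to be 384 for gte-tiny
```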
@@ -136,8 +183,9 @@ class Mem0Manager:
"""
self._sync_pool = sync_pool
# Cache of Mem0 instances: key = f"{user_id}:{agent_id}"
self._instances: Dict[str, Any] = {}
# Use an OrderedDict as an LRU cache, capped at MEM0_POOL_SIZE / 2 instances
self._instances: OrderedDict[str, Any] = OrderedDict()
self._max_instances = MEM0_POOL_SIZE/2 # Maximum number of cached instances
self._initialized = False
async def initialize(self) -> None:
@@ -194,10 +242,16 @@
llm_suffix = f":{id(config.llm_instance)}"
cache_key = f"{user_id}:{agent_id}{llm_suffix}"
# Check the cache
# Check the cache (and move the hit to the end to mark it as most recently used)
if cache_key in self._instances:
self._instances.move_to_end(cache_key)
return self._instances[cache_key]
# Check the cache size; evict the oldest entry when the limit is exceeded
if len(self._instances) >= self._max_instances:
removed_key, _ = self._instances.popitem(last=False)
logger.debug(f"Mem0 instance cache full, removed oldest entry: {removed_key}")
# Create a new instance
mem0_instance = await self._create_mem0_instance(
user_id=user_id,
@@ -206,7 +260,7 @@
config=config,
)
# Cache the instance
# Cache the instance (new entries are automatically appended at the end)
self._instances[cache_key] = mem0_instance
return mem0_instance
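
The caching behaviour introduced here follows the standard OrderedDict LRU pattern; a stripped-down, self-contained sketch of the same idea:

```python
# Stripped-down sketch of the OrderedDict-based LRU cache used for Mem0 instances.
from collections import OrderedDict
from typing import Any, Optional

class LRUCache:
    def __init__(self, max_items: int) -> None:
        self._items: "OrderedDict[str, Any]" = OrderedDict()
        self._max_items = max_items

    def get(self, key: str) -> Optional[Any]:
        if key in self._items:
            self._items.move_to_end(key)  # mark as most recently used
            return self._items[key]
        return None

    def put(self, key: str, value: Any) -> None:
        if key in self._items:
            self._items.move_to_end(key)
        elif len(self._items) >= self._max_items:
            self._items.popitem(last=False)  # evict the least recently used entry
        self._items[key] = value
```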

View File

@@ -3,7 +3,9 @@ Mem0 Agent middleware
AgentMiddleware implementing memory recall and storage
"""
import asyncio
import logging
import threading
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from langchain.agents.middleware import AgentMiddleware, AgentState, ModelRequest
@@ -123,8 +125,6 @@ class Mem0Middleware(AgentMiddleware):
return None
try:
import asyncio
# Extract the user query
query = self._extract_user_query(state)
if not query:
@@ -216,6 +216,8 @@
def after_agent(self, state: AgentState, runtime: Runtime) -> None:
"""After the agent run: trigger memory augmentation (synchronous version)
Executed in a background thread to avoid blocking the main flow
Args:
state: the agent state
runtime: the runtime context
@@ -224,16 +226,21 @@
return
try:
import asyncio
# Trigger the background augmentation task
asyncio.create_task(self._trigger_augmentation_async(state, runtime))
# Run in a background thread so the main flow is never blocked
thread = threading.Thread(
target=self._trigger_augmentation_sync,
args=(state, runtime),
daemon=True,
)
thread.start()
except Exception as e:
logger.error(f"Error in Mem0Middleware.after_agent: {e}")
async def aafter_agent(self, state: AgentState, runtime: Runtime) -> None:
"""Agent 执行后:触发记忆增强(异步版本)
使用后台线程执行避免阻塞事件循环
Args:
state: Agent 状态
runtime: 运行时上下文
@@ -242,10 +249,57 @@
return
try:
await self._trigger_augmentation_async(state, runtime)
# Run in a background thread so the event loop is never blocked
thread = threading.Thread(
target=self._trigger_augmentation_sync,
args=(state, runtime),
daemon=True,
)
thread.start()
except Exception as e:
logger.error(f"Error in Mem0Middleware.aafter_agent: {e}")
def _trigger_augmentation_sync(self, state: AgentState, runtime: Runtime) -> None:
"""触发记忆增强任务(同步版本,在线程中执行)
从对话中提取信息并存储到 Mem0用户级别跨会话
Args:
state: Agent 状态
runtime: 运行时上下文
"""
try:
# Get the attribution parameters
user_id, agent_id = self.config.get_attribution_tuple()
# Extract the user query and the agent response
user_query = self._extract_user_query(state)
agent_response = self._extract_agent_response(state)
# Store the conversation as a memory (user level)
if user_query and agent_response:
conversation_text = f"User: {user_query}\nAssistant: {agent_response}"
# Run the async code in a new event loop (we are inside a thread)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
self.mem0_manager.add_memory(
text=conversation_text,
user_id=user_id,
agent_id=agent_id,
metadata={"type": "conversation"},
config=self.config,
)
)
logger.debug(f"Stored conversation as memory for user={user_id}, agent={agent_id}")
finally:
loop.close()
except Exception as e:
logger.error(f"Error in _trigger_augmentation_sync: {e}")
async def _trigger_augmentation_async(self, state: AgentState, runtime: Runtime) -> None:
"""触发记忆增强任务

View File

@@ -2,7 +2,7 @@ version: "3.8"
services:
postgres:
image: postgres:16-alpine
image: pgvector/pgvector:pg16
container_name: qwen-agent-postgres
environment:
- POSTGRES_USER=postgres
@@ -37,6 +37,7 @@ services:
volumes:
# Mount the project data directory
- ./projects:/app/projects
- ./models:/app/models
depends_on:
postgres:
condition: service_healthy

View File

@@ -108,6 +108,16 @@ class GlobalModelManager:
logger.error(f"文本编码失败: {e}")
raise
def get_model_sync(self) -> Optional[SentenceTransformer]:
"""同步获取模型实例(供同步上下文使用)
如果模型未加载返回 None调用者应确保先通过异步方法初始化模型
Returns:
已加载的 SentenceTransformer 模型 None
"""
return self._model
def get_model_info(self) -> Dict[str, Any]:
"""获取模型信息"""
return {

poetry.lock (generated)
View File

@@ -1773,21 +1773,21 @@ tiktoken = ">=0.7.0,<1.0.0"
[[package]]
name = "langgraph"
version = "1.0.4"
version = "1.0.6"
description = "Building stateful, multi-actor applications with LLMs"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "langgraph-1.0.4-py3-none-any.whl", hash = "sha256:b1a835ceb0a8d69b9db48075e1939e28b1ad70ee23fa3fa8f90149904778bacf"},
{file = "langgraph-1.0.4.tar.gz", hash = "sha256:86d08e25d7244340f59c5200fa69fdd11066aa999b3164b531e2a20036fac156"},
{file = "langgraph-1.0.6-py3-none-any.whl", hash = "sha256:bcfce190974519c72e29f6e5b17f0023914fd6f936bfab8894083215b271eb89"},
{file = "langgraph-1.0.6.tar.gz", hash = "sha256:dd8e754c76d34a07485308d7117221acf63990e7de8f46ddf5fe256b0a22e6c5"},
]
[package.dependencies]
langchain-core = ">=0.1"
langgraph-checkpoint = ">=2.1.0,<4.0.0"
langgraph-checkpoint = ">=2.1.0,<5.0.0"
langgraph-prebuilt = ">=1.0.2,<1.1.0"
langgraph-sdk = ">=0.2.2,<0.3.0"
langgraph-sdk = ">=0.3.0,<0.4.0"
pydantic = ">=2.7.4"
xxhash = ">=3.5.0"
@@ -1843,14 +1843,14 @@ langgraph-checkpoint = ">=2.1.0,<4.0.0"
[[package]]
name = "langgraph-sdk"
version = "0.2.15"
version = "0.3.3"
description = "SDK for interacting with LangGraph API"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "langgraph_sdk-0.2.15-py3-none-any.whl", hash = "sha256:746566a5d89aa47160eccc17d71682a78771c754126f6c235a68353d61ed7462"},
{file = "langgraph_sdk-0.2.15.tar.gz", hash = "sha256:8faaafe2c1193b89f782dd66c591060cd67862aa6aaf283749b7846f331d5334"},
{file = "langgraph_sdk-0.3.3-py3-none-any.whl", hash = "sha256:a52ebaf09d91143e55378bb2d0b033ed98f57f48c9ad35c8f81493b88705fc7b"},
{file = "langgraph_sdk-0.3.3.tar.gz", hash = "sha256:c34c3dce3b6848755eb61f0c94369d1ba04aceeb1b76015db1ea7362c544fb26"},
]
[package.dependencies]

View File

@@ -220,17 +220,15 @@
color: var(--text);
}
</style>
<script src="https://unpkg.com/lucide@latest/dist/umd/lucide.js"></script>
<!-- Fonts -->
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;500;600;700&family=Poppins:wght@400;500;600;700&display=swap" rel="stylesheet">
<!-- Highlight.js -->
<link rel="stylesheet" href="https://cdn.staticfile.net/highlight.js/11.9.0/styles/github.min.css">
<script src="https://cdn.staticfile.net/highlight.js/11.9.0/highlight.min.js"></script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/github.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js"></script>
<!-- Marked.js -->
<script src="https://cdn.staticfile.net/marked/4.3.0/marked.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/marked@4.3.0/marked.min.js"></script>
<!-- Lucide Icons -->
<script src="https://unpkg.com/lucide@latest/dist/umd/lucide.js"></script>
@@ -1911,6 +1909,13 @@
<label class="settings-label" for="user-identifier">用户标识</label>
<input type="text" id="user-identifier" class="settings-input" placeholder="输入用户标识...">
</div>
<div class="settings-group">
<div class="settings-checkbox-wrapper">
<input type="checkbox" id="enable-memori" class="settings-checkbox">
<label class="settings-label" for="enable-memori" style="margin-bottom: 0;">启用记忆存储</label>
</div>
<p style="font-size: 11px; color: var(--text-muted); margin-top: 4px;">启用后AI 会记住对话中的信息以提供更个性化的回复</p>
</div>
</div>
</div>
@@ -2842,6 +2847,7 @@
'dataset-ids': document.getElementById('dataset-ids').value,
'system-prompt': document.getElementById('system-prompt').value,
'user-identifier': document.getElementById('user-identifier').value,
'enable-memori': document.getElementById('enable-memori').checked,
'skills': selectedSkills.join(','),
'mcp-settings': mcpSettingsValue,
'tool-response': document.getElementById('tool-response').checked
@@ -3235,6 +3241,7 @@
systemPrompt: getValue('system-prompt'),
sessionId,
userIdentifier: getValue('user-identifier'),
enableMemori: getChecked('enable-memori'),
skills,
mcpSettings,
toolResponse: getChecked('tool-response')
@@ -3492,6 +3499,7 @@
if (settings.systemPrompt) requestBody.system_prompt = settings.systemPrompt;
if (settings.sessionId) requestBody.session_id = settings.sessionId;
if (settings.userIdentifier) requestBody.user_identifier = settings.userIdentifier;
if (settings.enableMemori) requestBody.enable_memori = settings.enableMemori;
if (settings.skills?.length) requestBody.skills = settings.skills;
if (settings.datasetIds?.length) requestBody.dataset_ids = settings.datasetIds;
if (settings.mcpSettings?.length) requestBody.mcp_settings = settings.mcpSettings;

View File

@@ -64,8 +64,8 @@ langchain==1.1.3 ; python_version >= "3.12" and python_version < "4.0"
langgraph-checkpoint-postgres==2.0.25 ; python_version >= "3.12" and python_version < "4.0"
langgraph-checkpoint==2.1.2 ; python_version >= "3.12" and python_version < "4.0"
langgraph-prebuilt==1.0.5 ; python_version >= "3.12" and python_version < "4.0"
langgraph-sdk==0.2.15 ; python_version >= "3.12" and python_version < "4.0"
langgraph==1.0.4 ; python_version >= "3.12" and python_version < "4.0"
langgraph-sdk==0.3.3 ; python_version >= "3.12" and python_version < "4.0"
langgraph==1.0.6 ; python_version >= "3.12" and python_version < "4.0"
langsmith==0.4.59 ; python_version >= "3.12" and python_version < "4.0"
markdown-it-py==4.0.0 ; python_version >= "3.12" and python_version < "4.0"
markdownify==1.2.2 ; python_version >= "3.12" and python_version < "4.0"

View File

@@ -199,6 +199,73 @@ async def safe_extract_zip(zip_path: str, extract_dir: str) -> None:
raise HTTPException(status_code=400, detail=f"无效的 zip 文件: {str(e)}")
async def validate_and_rename_skill_folder(
extract_dir: str,
has_top_level_dirs: bool
) -> str:
"""验证并重命名解压后的 skill 文件夹
检查解压后文件夹名称是否与 SKILL.md 中的 name 匹配
如果不匹配则重命名文件夹
Args:
extract_dir: 解压目标目录
has_top_level_dirs: zip 是否包含顶级目录
Returns:
str: 最终的解压路径可能因为重命名而改变
"""
try:
if has_top_level_dirs:
# The zip contains directories; check each one
for folder_name in os.listdir(extract_dir):
folder_path = os.path.join(extract_dir, folder_name)
if os.path.isdir(folder_path):
skill_md_path = os.path.join(folder_path, 'SKILL.md')
if os.path.exists(skill_md_path):
metadata = await asyncio.to_thread(
parse_skill_frontmatter, skill_md_path
)
if metadata and 'name' in metadata:
expected_name = metadata['name']
if folder_name != expected_name:
new_folder_path = os.path.join(extract_dir, expected_name)
await asyncio.to_thread(
shutil.move, folder_path, new_folder_path
)
logger.info(
f"Renamed skill folder: {folder_name} -> {expected_name}"
)
return extract_dir
else:
# The zip contains files directly; check SKILL.md in the current directory
skill_md_path = os.path.join(extract_dir, 'SKILL.md')
if os.path.exists(skill_md_path):
metadata = await asyncio.to_thread(
parse_skill_frontmatter, skill_md_path
)
if metadata and 'name' in metadata:
expected_name = metadata['name']
# Get the current folder name
current_name = os.path.basename(extract_dir)
if current_name != expected_name:
parent_dir = os.path.dirname(extract_dir)
new_folder_path = os.path.join(parent_dir, expected_name)
await asyncio.to_thread(
shutil.move, extract_dir, new_folder_path
)
logger.info(
f"Renamed skill folder: {current_name} -> {expected_name}"
)
return new_folder_path
return extract_dir
except Exception as e:
logger.warning(f"Failed to validate/rename skill folder: {e}")
# Do not raise; allow the upload to continue
return extract_dir
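
parse_skill_frontmatter is referenced above but not shown in this diff; a hedged sketch of what a minimal SKILL.md frontmatter reader could look like (the function name below is explicitly a sketch, and the real implementation may differ):

```python
# Hypothetical sketch of a SKILL.md frontmatter parser; the real
# parse_skill_frontmatter used above may be implemented differently.
import yaml  # assumes PyYAML is available

def parse_skill_frontmatter_sketch(path: str) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()
    if not content.startswith("---"):
        return {}
    parts = content.split("---", 2)  # ["", frontmatter, body]
    if len(parts) < 3:
        return {}
    return yaml.safe_load(parts[1]) or {}
```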
async def save_upload_file_async(file: UploadFile, destination: str) -> None:
"""异步保存上传文件到目标路径"""
async with aiofiles.open(destination, 'wb') as f:
@@ -453,13 +520,24 @@ async def upload_skill(file: UploadFile = File(...), bot_id: Optional[str] = For
await safe_extract_zip(file_path, extract_target)
logger.info(f"Extracted to: {extract_target}")
# Validate and rename the folder to match the name in SKILL.md
final_extract_path = await validate_and_rename_skill_folder(
extract_target, has_top_level_dirs
)
# Determine the final skill name
if has_top_level_dirs:
final_skill_name = folder_name
else:
final_skill_name = os.path.basename(final_extract_path)
return {
"success": True,
"message": f"Skill文件上传并解压成功",
"file_path": file_path,
"extract_path": extract_target,
"extract_path": final_extract_path,
"original_filename": original_filename,
"skill_name": folder_name
"skill_name": final_skill_name
}
except HTTPException:

View File

@@ -54,6 +54,7 @@ CHECKPOINT_DB_URL = os.getenv("CHECKPOINT_DB_URL", "postgresql://moshui:@localho
# Connection pool size
# Maximum number of connections that can be held at the same time
CHECKPOINT_POOL_SIZE = int(os.getenv("CHECKPOINT_POOL_SIZE", "20"))
MEM0_POOL_SIZE = int(os.getenv("MEM0_POOL_SIZE", "20"))
# Checkpoint automatic cleanup configuration
# Whether to enable automatic cleanup of old sessions