From 3dc119bca89fc9c5592f0f45d93a75033d51f1d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?= Date: Thu, 22 Jan 2026 19:39:12 +0800 Subject: [PATCH] refactor(mem0): optimize connection pool and async memory handling - Fix mem0 connection pool exhausted error with proper pooling - Convert memory operations to async tasks - Optimize docker-compose configuration - Add skill upload functionality - Reduce cache size for better performance - Update dependencies Co-Authored-By: Claude Opus 4.5 --- agent/agent_config.py | 9 +--- agent/db_pool_manager.py | 3 +- agent/mem0_manager.py | 70 ++++++++++++++++++++++++++---- agent/mem0_middleware.py | 68 ++++++++++++++++++++++++++--- docker-compose-with-pgsql.yml | 3 +- embedding/manager.py | 10 +++++ poetry.lock | 16 +++---- public/index.html | 18 +++++--- requirements.txt | 4 +- routes/skill_manager.py | 82 ++++++++++++++++++++++++++++++++++- utils/settings.py | 1 + 11 files changed, 243 insertions(+), 41 deletions(-) diff --git a/agent/agent_config.py b/agent/agent_config.py index badb285..16c4374 100644 --- a/agent/agent_config.py +++ b/agent/agent_config.py @@ -105,9 +105,7 @@ class AgentConfig: enable_thinking = request.enable_thinking and "" in request.system_prompt # 从请求中获取 Mem0 配置,如果没有则使用全局配置 - enable_memori = getattr(request, 'enable_memori', None) - if enable_memori is None: - enable_memori = MEM0_ENABLED + enable_memori = getattr(request, 'enable_memori', MEM0_ENABLED) config = cls( bot_id=request.bot_id, @@ -171,10 +169,7 @@ class AgentConfig: enable_thinking = request.enable_thinking and "" in bot_config.get("system_prompt") # 从请求或后端配置中获取 Mem0 配置 - enable_memori = getattr(request, 'enable_memori', None) - if enable_memori is None: - enable_memori = bot_config.get("enable_memori", MEM0_ENABLED) - + enable_memori = getattr(request, 'enable_memori', MEM0_ENABLED) config = cls( bot_id=request.bot_id, diff --git a/agent/db_pool_manager.py b/agent/db_pool_manager.py index 058b382..769a5d3 100644 --- a/agent/db_pool_manager.py +++ b/agent/db_pool_manager.py @@ -12,6 +12,7 @@ from psycopg2 import pool as psycopg2_pool from utils.settings import ( CHECKPOINT_DB_URL, CHECKPOINT_POOL_SIZE, + MEM0_POOL_SIZE, CHECKPOINT_CLEANUP_ENABLED, CHECKPOINT_CLEANUP_INTERVAL_HOURS, CHECKPOINT_CLEANUP_INACTIVE_DAYS, @@ -62,7 +63,7 @@ class DBPoolManager: await self._pool.open() # 2. 创建同步 psycopg2 连接池(供 Mem0 使用) - self._sync_pool = self._create_sync_pool(CHECKPOINT_DB_URL, CHECKPOINT_POOL_SIZE) + self._sync_pool = self._create_sync_pool(CHECKPOINT_DB_URL, MEM0_POOL_SIZE) self._initialized = True logger.info("PostgreSQL connection pool initialized successfully") diff --git a/agent/mem0_manager.py b/agent/mem0_manager.py index 96a020a..4fcc3b2 100644 --- a/agent/mem0_manager.py +++ b/agent/mem0_manager.py @@ -5,11 +5,16 @@ Mem0 连接和实例管理器 import logging import asyncio +import threading +import concurrent.futures from typing import Any, Dict, List, Optional, Literal +from collections import OrderedDict from embedding.manager import GlobalModelManager, get_model_manager import json_repair from psycopg2 import pool - +from utils.settings import ( + MEM0_POOL_SIZE +) from .mem0_config import Mem0Config logger = logging.getLogger("app") @@ -27,7 +32,9 @@ class CustomMem0Embedding: 这样 Mem0 就不需要再次加载同一个模型,节省内存 """ - _model_manager = None # 缓存 GlobalModelManager 实例 + _model = None # 类变量,缓存模型实例 + _lock = threading.Lock() # 线程安全锁 + _executor = None # 线程池执行器 def __init__(self, config: Optional[Any] = None): """初始化自定义 Embedding""" @@ -41,6 +48,41 @@ class CustomMem0Embedding: """获取 embedding 维度""" return 384 # gte-tiny 的维度 + def _get_model_sync(self): + """同步获取模型,避免 asyncio.run()""" + # 首先尝试从 manager 获取已加载的模型 + manager = get_model_manager() + model = manager.get_model_sync() + + if model is not None: + # 缓存模型 + CustomMem0Embedding._model = model + return model + + # 如果模型未加载,使用线程池运行异步初始化 + if CustomMem0Embedding._executor is None: + CustomMem0Embedding._executor = concurrent.futures.ThreadPoolExecutor( + max_workers=1, + thread_name_prefix="mem0_embed" + ) + + # 在独立线程中运行异步代码 + def run_async_in_thread(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + result = loop.run_until_complete(manager.get_model()) + return result + finally: + loop.close() + + future = CustomMem0Embedding._executor.submit(run_async_in_thread) + model = future.result(timeout=30) # 30秒超时 + + # 缓存模型 + CustomMem0Embedding._model = model + return model + def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None): """ 获取文本的 embedding 向量(同步方法,供 Mem0 调用) @@ -52,8 +94,13 @@ class CustomMem0Embedding: Returns: list: embedding 向量 """ - manager = get_model_manager() - model = asyncio.run(manager.get_model()) + # 线程安全地获取模型 + if CustomMem0Embedding._model is None: + with CustomMem0Embedding._lock: + if CustomMem0Embedding._model is None: + self._get_model_sync() + + model = CustomMem0Embedding._model embeddings = model.encode(text, convert_to_numpy=True) return embeddings.tolist() @@ -136,8 +183,9 @@ class Mem0Manager: """ self._sync_pool = sync_pool - # 缓存 Mem0 实例: key = f"{user_id}:{agent_id}" - self._instances: Dict[str, Any] = {} + # 使用 OrderedDict 实现 LRU 缓存,最多保留 50 个实例 + self._instances: OrderedDict[str, Any] = OrderedDict() + self._max_instances = MEM0_POOL_SIZE/2 # 最大缓存实例数 self._initialized = False async def initialize(self) -> None: @@ -194,10 +242,16 @@ class Mem0Manager: llm_suffix = f":{id(config.llm_instance)}" cache_key = f"{user_id}:{agent_id}{llm_suffix}" - # 检查缓存 + # 检查缓存(同时移动到末尾表示最近使用) if cache_key in self._instances: + self._instances.move_to_end(cache_key) return self._instances[cache_key] + # 检查缓存大小,超过则移除最旧的 + if len(self._instances) >= self._max_instances: + removed_key, _ = self._instances.popitem(last=False) + logger.debug(f"Mem0 instance cache full, removed oldest entry: {removed_key}") + # 创建新实例 mem0_instance = await self._create_mem0_instance( user_id=user_id, @@ -206,7 +260,7 @@ class Mem0Manager: config=config, ) - # 缓存实例 + # 缓存实例(新实例自动在末尾) self._instances[cache_key] = mem0_instance return mem0_instance diff --git a/agent/mem0_middleware.py b/agent/mem0_middleware.py index 8b50cbd..4cf8d10 100644 --- a/agent/mem0_middleware.py +++ b/agent/mem0_middleware.py @@ -3,7 +3,9 @@ Mem0 Agent 中间件 实现记忆召回和存储的 AgentMiddleware """ +import asyncio import logging +import threading from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from langchain.agents.middleware import AgentMiddleware, AgentState, ModelRequest @@ -123,8 +125,6 @@ class Mem0Middleware(AgentMiddleware): return None try: - import asyncio - # 提取用户查询 query = self._extract_user_query(state) if not query: @@ -216,6 +216,8 @@ class Mem0Middleware(AgentMiddleware): def after_agent(self, state: AgentState, runtime: Runtime) -> None: """Agent 执行后:触发记忆增强(同步版本) + 使用后台线程执行,避免阻塞主流程 + Args: state: Agent 状态 runtime: 运行时上下文 @@ -224,16 +226,21 @@ class Mem0Middleware(AgentMiddleware): return try: - import asyncio - - # 触发后台增强任务 - asyncio.create_task(self._trigger_augmentation_async(state, runtime)) + # 在后台线程中执行,完全不阻塞主流程 + thread = threading.Thread( + target=self._trigger_augmentation_sync, + args=(state, runtime), + daemon=True, + ) + thread.start() except Exception as e: logger.error(f"Error in Mem0Middleware.after_agent: {e}") async def aafter_agent(self, state: AgentState, runtime: Runtime) -> None: """Agent 执行后:触发记忆增强(异步版本) + 使用后台线程执行,避免阻塞事件循环 + Args: state: Agent 状态 runtime: 运行时上下文 @@ -242,10 +249,57 @@ class Mem0Middleware(AgentMiddleware): return try: - await self._trigger_augmentation_async(state, runtime) + # 在后台线程中执行,完全不阻塞事件循环 + thread = threading.Thread( + target=self._trigger_augmentation_sync, + args=(state, runtime), + daemon=True, + ) + thread.start() except Exception as e: logger.error(f"Error in Mem0Middleware.aafter_agent: {e}") + def _trigger_augmentation_sync(self, state: AgentState, runtime: Runtime) -> None: + """触发记忆增强任务(同步版本,在线程中执行) + + 从对话中提取信息并存储到 Mem0(用户级别,跨会话) + + Args: + state: Agent 状态 + runtime: 运行时上下文 + """ + try: + # 获取 attribution 参数 + user_id, agent_id = self.config.get_attribution_tuple() + + # 提取用户查询和 Agent 响应 + user_query = self._extract_user_query(state) + agent_response = self._extract_agent_response(state) + + # 将对话作为记忆存储(用户级别) + if user_query and agent_response: + conversation_text = f"User: {user_query}\nAssistant: {agent_response}" + + # 在新的事件循环中运行异步代码(因为在线程中) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + self.mem0_manager.add_memory( + text=conversation_text, + user_id=user_id, + agent_id=agent_id, + metadata={"type": "conversation"}, + config=self.config, + ) + ) + logger.debug(f"Stored conversation as memory for user={user_id}, agent={agent_id}") + finally: + loop.close() + + except Exception as e: + logger.error(f"Error in _trigger_augmentation_sync: {e}") + async def _trigger_augmentation_async(self, state: AgentState, runtime: Runtime) -> None: """触发记忆增强任务 diff --git a/docker-compose-with-pgsql.yml b/docker-compose-with-pgsql.yml index f15ac75..2c9ca52 100644 --- a/docker-compose-with-pgsql.yml +++ b/docker-compose-with-pgsql.yml @@ -2,7 +2,7 @@ version: "3.8" services: postgres: - image: postgres:16-alpine + image: pgvector/pgvector:pg16 container_name: qwen-agent-postgres environment: - POSTGRES_USER=postgres @@ -37,6 +37,7 @@ services: volumes: # 挂载项目数据目录 - ./projects:/app/projects + - ./models:/app/models depends_on: postgres: condition: service_healthy diff --git a/embedding/manager.py b/embedding/manager.py index 807a57b..aab5ffe 100644 --- a/embedding/manager.py +++ b/embedding/manager.py @@ -108,6 +108,16 @@ class GlobalModelManager: logger.error(f"文本编码失败: {e}") raise + def get_model_sync(self) -> Optional[SentenceTransformer]: + """同步获取模型实例(供同步上下文使用) + + 如果模型未加载,返回 None。调用者应确保先通过异步方法初始化模型。 + + Returns: + 已加载的 SentenceTransformer 模型,或 None + """ + return self._model + def get_model_info(self) -> Dict[str, Any]: """获取模型信息""" return { diff --git a/poetry.lock b/poetry.lock index 6e1c2db..d0ad387 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1773,21 +1773,21 @@ tiktoken = ">=0.7.0,<1.0.0" [[package]] name = "langgraph" -version = "1.0.4" +version = "1.0.6" description = "Building stateful, multi-actor applications with LLMs" optional = false python-versions = ">=3.10" groups = ["main"] files = [ - {file = "langgraph-1.0.4-py3-none-any.whl", hash = "sha256:b1a835ceb0a8d69b9db48075e1939e28b1ad70ee23fa3fa8f90149904778bacf"}, - {file = "langgraph-1.0.4.tar.gz", hash = "sha256:86d08e25d7244340f59c5200fa69fdd11066aa999b3164b531e2a20036fac156"}, + {file = "langgraph-1.0.6-py3-none-any.whl", hash = "sha256:bcfce190974519c72e29f6e5b17f0023914fd6f936bfab8894083215b271eb89"}, + {file = "langgraph-1.0.6.tar.gz", hash = "sha256:dd8e754c76d34a07485308d7117221acf63990e7de8f46ddf5fe256b0a22e6c5"}, ] [package.dependencies] langchain-core = ">=0.1" -langgraph-checkpoint = ">=2.1.0,<4.0.0" +langgraph-checkpoint = ">=2.1.0,<5.0.0" langgraph-prebuilt = ">=1.0.2,<1.1.0" -langgraph-sdk = ">=0.2.2,<0.3.0" +langgraph-sdk = ">=0.3.0,<0.4.0" pydantic = ">=2.7.4" xxhash = ">=3.5.0" @@ -1843,14 +1843,14 @@ langgraph-checkpoint = ">=2.1.0,<4.0.0" [[package]] name = "langgraph-sdk" -version = "0.2.15" +version = "0.3.3" description = "SDK for interacting with LangGraph API" optional = false python-versions = ">=3.10" groups = ["main"] files = [ - {file = "langgraph_sdk-0.2.15-py3-none-any.whl", hash = "sha256:746566a5d89aa47160eccc17d71682a78771c754126f6c235a68353d61ed7462"}, - {file = "langgraph_sdk-0.2.15.tar.gz", hash = "sha256:8faaafe2c1193b89f782dd66c591060cd67862aa6aaf283749b7846f331d5334"}, + {file = "langgraph_sdk-0.3.3-py3-none-any.whl", hash = "sha256:a52ebaf09d91143e55378bb2d0b033ed98f57f48c9ad35c8f81493b88705fc7b"}, + {file = "langgraph_sdk-0.3.3.tar.gz", hash = "sha256:c34c3dce3b6848755eb61f0c94369d1ba04aceeb1b76015db1ea7362c544fb26"}, ] [package.dependencies] diff --git a/public/index.html b/public/index.html index 0a010e1..9ec1a85 100644 --- a/public/index.html +++ b/public/index.html @@ -220,17 +220,15 @@ color: var(--text); } - - - - + + - + @@ -1911,6 +1909,13 @@ +
+
+ + +
+

启用后,AI 会记住对话中的信息以提供更个性化的回复

+
@@ -2842,6 +2847,7 @@ 'dataset-ids': document.getElementById('dataset-ids').value, 'system-prompt': document.getElementById('system-prompt').value, 'user-identifier': document.getElementById('user-identifier').value, + 'enable-memori': document.getElementById('enable-memori').checked, 'skills': selectedSkills.join(','), 'mcp-settings': mcpSettingsValue, 'tool-response': document.getElementById('tool-response').checked @@ -3235,6 +3241,7 @@ systemPrompt: getValue('system-prompt'), sessionId, userIdentifier: getValue('user-identifier'), + enableMemori: getChecked('enable-memori'), skills, mcpSettings, toolResponse: getChecked('tool-response') @@ -3492,6 +3499,7 @@ if (settings.systemPrompt) requestBody.system_prompt = settings.systemPrompt; if (settings.sessionId) requestBody.session_id = settings.sessionId; if (settings.userIdentifier) requestBody.user_identifier = settings.userIdentifier; + if (settings.enableMemori) requestBody.enable_memori = settings.enableMemori; if (settings.skills?.length) requestBody.skills = settings.skills; if (settings.datasetIds?.length) requestBody.dataset_ids = settings.datasetIds; if (settings.mcpSettings?.length) requestBody.mcp_settings = settings.mcpSettings; diff --git a/requirements.txt b/requirements.txt index 7b2669c..3fe136c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -64,8 +64,8 @@ langchain==1.1.3 ; python_version >= "3.12" and python_version < "4.0" langgraph-checkpoint-postgres==2.0.25 ; python_version >= "3.12" and python_version < "4.0" langgraph-checkpoint==2.1.2 ; python_version >= "3.12" and python_version < "4.0" langgraph-prebuilt==1.0.5 ; python_version >= "3.12" and python_version < "4.0" -langgraph-sdk==0.2.15 ; python_version >= "3.12" and python_version < "4.0" -langgraph==1.0.4 ; python_version >= "3.12" and python_version < "4.0" +langgraph-sdk==0.3.3 ; python_version >= "3.12" and python_version < "4.0" +langgraph==1.0.6 ; python_version >= "3.12" and python_version < "4.0" langsmith==0.4.59 ; python_version >= "3.12" and python_version < "4.0" markdown-it-py==4.0.0 ; python_version >= "3.12" and python_version < "4.0" markdownify==1.2.2 ; python_version >= "3.12" and python_version < "4.0" diff --git a/routes/skill_manager.py b/routes/skill_manager.py index 938db8d..2db3446 100644 --- a/routes/skill_manager.py +++ b/routes/skill_manager.py @@ -199,6 +199,73 @@ async def safe_extract_zip(zip_path: str, extract_dir: str) -> None: raise HTTPException(status_code=400, detail=f"无效的 zip 文件: {str(e)}") +async def validate_and_rename_skill_folder( + extract_dir: str, + has_top_level_dirs: bool +) -> str: + """验证并重命名解压后的 skill 文件夹 + + 检查解压后文件夹名称是否与 SKILL.md 中的 name 匹配, + 如果不匹配则重命名文件夹。 + + Args: + extract_dir: 解压目标目录 + has_top_level_dirs: zip 是否包含顶级目录 + + Returns: + str: 最终的解压路径(可能因为重命名而改变) + """ + try: + if has_top_level_dirs: + # zip 包含目录,检查每个目录 + for folder_name in os.listdir(extract_dir): + folder_path = os.path.join(extract_dir, folder_name) + if os.path.isdir(folder_path): + skill_md_path = os.path.join(folder_path, 'SKILL.md') + if os.path.exists(skill_md_path): + metadata = await asyncio.to_thread( + parse_skill_frontmatter, skill_md_path + ) + if metadata and 'name' in metadata: + expected_name = metadata['name'] + if folder_name != expected_name: + new_folder_path = os.path.join(extract_dir, expected_name) + await asyncio.to_thread( + shutil.move, folder_path, new_folder_path + ) + logger.info( + f"Renamed skill folder: {folder_name} -> {expected_name}" + ) + return extract_dir + else: + # zip 直接包含文件,检查当前目录的 SKILL.md + skill_md_path = os.path.join(extract_dir, 'SKILL.md') + if os.path.exists(skill_md_path): + metadata = await asyncio.to_thread( + parse_skill_frontmatter, skill_md_path + ) + if metadata and 'name' in metadata: + expected_name = metadata['name'] + # 获取当前文件夹名称 + current_name = os.path.basename(extract_dir) + if current_name != expected_name: + parent_dir = os.path.dirname(extract_dir) + new_folder_path = os.path.join(parent_dir, expected_name) + await asyncio.to_thread( + shutil.move, extract_dir, new_folder_path + ) + logger.info( + f"Renamed skill folder: {current_name} -> {expected_name}" + ) + return new_folder_path + return extract_dir + + except Exception as e: + logger.warning(f"Failed to validate/rename skill folder: {e}") + # 不抛出异常,允许上传继续 + return extract_dir + + async def save_upload_file_async(file: UploadFile, destination: str) -> None: """异步保存上传文件到目标路径""" async with aiofiles.open(destination, 'wb') as f: @@ -453,13 +520,24 @@ async def upload_skill(file: UploadFile = File(...), bot_id: Optional[str] = For await safe_extract_zip(file_path, extract_target) logger.info(f"Extracted to: {extract_target}") + # 验证并重命名文件夹以匹配 SKILL.md 中的 name + final_extract_path = await validate_and_rename_skill_folder( + extract_target, has_top_level_dirs + ) + + # 获取最终的 skill 名称 + if has_top_level_dirs: + final_skill_name = folder_name + else: + final_skill_name = os.path.basename(final_extract_path) + return { "success": True, "message": f"Skill文件上传并解压成功", "file_path": file_path, - "extract_path": extract_target, + "extract_path": final_extract_path, "original_filename": original_filename, - "skill_name": folder_name + "skill_name": final_skill_name } except HTTPException: diff --git a/utils/settings.py b/utils/settings.py index c7eed3a..920d275 100644 --- a/utils/settings.py +++ b/utils/settings.py @@ -54,6 +54,7 @@ CHECKPOINT_DB_URL = os.getenv("CHECKPOINT_DB_URL", "postgresql://moshui:@localho # 连接池大小 # 同时可以持有的最大连接数 CHECKPOINT_POOL_SIZE = int(os.getenv("CHECKPOINT_POOL_SIZE", "20")) +MEM0_POOL_SIZE = int(os.getenv("MEM0_POOL_SIZE", "20")) # Checkpoint 自动清理配置 # 是否启用自动清理旧 session