refactor(mem0): optimize connection pool and async memory handling

- Fix mem0 connection pool exhaustion error with proper pooling
- Convert memory operations to async tasks
- Optimize docker-compose configuration
- Add skill upload functionality
- Reduce cache size for better performance
- Update dependencies

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
朱潮 2026-01-22 19:39:12 +08:00
parent f29fd1fb54
commit 3dc119bca8
11 changed files with 243 additions and 41 deletions

View File

@ -105,9 +105,7 @@ class AgentConfig:
enable_thinking = request.enable_thinking and "<guidelines>" in request.system_prompt enable_thinking = request.enable_thinking and "<guidelines>" in request.system_prompt
# 从请求中获取 Mem0 配置,如果没有则使用全局配置 # 从请求中获取 Mem0 配置,如果没有则使用全局配置
enable_memori = getattr(request, 'enable_memori', None) enable_memori = getattr(request, 'enable_memori', MEM0_ENABLED)
if enable_memori is None:
enable_memori = MEM0_ENABLED
config = cls( config = cls(
bot_id=request.bot_id, bot_id=request.bot_id,
@ -171,10 +169,7 @@ class AgentConfig:
enable_thinking = request.enable_thinking and "<guidelines>" in bot_config.get("system_prompt") enable_thinking = request.enable_thinking and "<guidelines>" in bot_config.get("system_prompt")
# 从请求或后端配置中获取 Mem0 配置 # 从请求或后端配置中获取 Mem0 配置
enable_memori = getattr(request, 'enable_memori', None) enable_memori = getattr(request, 'enable_memori', MEM0_ENABLED)
if enable_memori is None:
enable_memori = bot_config.get("enable_memori", MEM0_ENABLED)
config = cls( config = cls(
bot_id=request.bot_id, bot_id=request.bot_id,

View File

@ -12,6 +12,7 @@ from psycopg2 import pool as psycopg2_pool
from utils.settings import ( from utils.settings import (
CHECKPOINT_DB_URL, CHECKPOINT_DB_URL,
CHECKPOINT_POOL_SIZE, CHECKPOINT_POOL_SIZE,
MEM0_POOL_SIZE,
CHECKPOINT_CLEANUP_ENABLED, CHECKPOINT_CLEANUP_ENABLED,
CHECKPOINT_CLEANUP_INTERVAL_HOURS, CHECKPOINT_CLEANUP_INTERVAL_HOURS,
CHECKPOINT_CLEANUP_INACTIVE_DAYS, CHECKPOINT_CLEANUP_INACTIVE_DAYS,
@ -62,7 +63,7 @@ class DBPoolManager:
await self._pool.open() await self._pool.open()
# 2. 创建同步 psycopg2 连接池(供 Mem0 使用) # 2. 创建同步 psycopg2 连接池(供 Mem0 使用)
self._sync_pool = self._create_sync_pool(CHECKPOINT_DB_URL, CHECKPOINT_POOL_SIZE) self._sync_pool = self._create_sync_pool(CHECKPOINT_DB_URL, MEM0_POOL_SIZE)
self._initialized = True self._initialized = True
logger.info("PostgreSQL connection pool initialized successfully") logger.info("PostgreSQL connection pool initialized successfully")

View File

@ -5,11 +5,16 @@ Mem0 连接和实例管理器
import logging import logging
import asyncio import asyncio
import threading
import concurrent.futures
from typing import Any, Dict, List, Optional, Literal from typing import Any, Dict, List, Optional, Literal
from collections import OrderedDict
from embedding.manager import GlobalModelManager, get_model_manager from embedding.manager import GlobalModelManager, get_model_manager
import json_repair import json_repair
from psycopg2 import pool from psycopg2 import pool
from utils.settings import (
MEM0_POOL_SIZE
)
from .mem0_config import Mem0Config from .mem0_config import Mem0Config
logger = logging.getLogger("app") logger = logging.getLogger("app")
@ -27,7 +32,9 @@ class CustomMem0Embedding:
这样 Mem0 就不需要再次加载同一个模型节省内存 这样 Mem0 就不需要再次加载同一个模型节省内存
""" """
_model_manager = None # 缓存 GlobalModelManager 实例 _model = None # 类变量,缓存模型实例
_lock = threading.Lock() # 线程安全锁
_executor = None # 线程池执行器
def __init__(self, config: Optional[Any] = None): def __init__(self, config: Optional[Any] = None):
"""初始化自定义 Embedding""" """初始化自定义 Embedding"""
@ -41,6 +48,41 @@ class CustomMem0Embedding:
"""获取 embedding 维度""" """获取 embedding 维度"""
return 384 # gte-tiny 的维度 return 384 # gte-tiny 的维度
def _get_model_sync(self):
    """Synchronously obtain the embedding model, avoiding asyncio.run().

    First tries the manager's already-loaded model; if the model is not
    loaded yet, runs the async initialization on a dedicated worker thread
    with its own event loop (safe to call from a thread that may already
    be inside a running loop, where asyncio.run() would raise).

    NOTE(review): callers are expected to hold CustomMem0Embedding._lock
    (see embed()) — the lazy creation of _executor below is not itself
    thread-safe. Confirm no other call sites bypass the lock.

    Returns:
        The loaded SentenceTransformer model (also cached on the class).
    """
    # First try to get an already-loaded model from the manager.
    manager = get_model_manager()
    model = manager.get_model_sync()
    if model is not None:
        # Cache the model on the class for fast reuse in embed().
        CustomMem0Embedding._model = model
        return model
    # Model not loaded yet: use a single-worker thread pool to run the
    # async initialization without touching the caller's event loop.
    if CustomMem0Embedding._executor is None:
        CustomMem0Embedding._executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=1,
            thread_name_prefix="mem0_embed"
        )

    # Run the async loading code in an isolated thread with a fresh loop.
    def run_async_in_thread():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            result = loop.run_until_complete(manager.get_model())
            return result
        finally:
            # Always close the throwaway loop so its resources are freed.
            loop.close()

    future = CustomMem0Embedding._executor.submit(run_async_in_thread)
    model = future.result(timeout=30)  # 30-second timeout

    # Cache the model on the class.
    CustomMem0Embedding._model = model
    return model
def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None): def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
""" """
获取文本的 embedding 向量同步方法 Mem0 调用 获取文本的 embedding 向量同步方法 Mem0 调用
@ -52,8 +94,13 @@ class CustomMem0Embedding:
Returns: Returns:
list: embedding 向量 list: embedding 向量
""" """
manager = get_model_manager() # 线程安全地获取模型
model = asyncio.run(manager.get_model()) if CustomMem0Embedding._model is None:
with CustomMem0Embedding._lock:
if CustomMem0Embedding._model is None:
self._get_model_sync()
model = CustomMem0Embedding._model
embeddings = model.encode(text, convert_to_numpy=True) embeddings = model.encode(text, convert_to_numpy=True)
return embeddings.tolist() return embeddings.tolist()
@ -136,8 +183,9 @@ class Mem0Manager:
""" """
self._sync_pool = sync_pool self._sync_pool = sync_pool
# 缓存 Mem0 实例: key = f"{user_id}:{agent_id}" # 使用 OrderedDict 实现 LRU 缓存,最多保留 50 个实例
self._instances: Dict[str, Any] = {} self._instances: OrderedDict[str, Any] = OrderedDict()
self._max_instances = MEM0_POOL_SIZE/2 # 最大缓存实例数
self._initialized = False self._initialized = False
async def initialize(self) -> None: async def initialize(self) -> None:
@ -194,10 +242,16 @@ class Mem0Manager:
llm_suffix = f":{id(config.llm_instance)}" llm_suffix = f":{id(config.llm_instance)}"
cache_key = f"{user_id}:{agent_id}{llm_suffix}" cache_key = f"{user_id}:{agent_id}{llm_suffix}"
# 检查缓存 # 检查缓存(同时移动到末尾表示最近使用)
if cache_key in self._instances: if cache_key in self._instances:
self._instances.move_to_end(cache_key)
return self._instances[cache_key] return self._instances[cache_key]
# 检查缓存大小,超过则移除最旧的
if len(self._instances) >= self._max_instances:
removed_key, _ = self._instances.popitem(last=False)
logger.debug(f"Mem0 instance cache full, removed oldest entry: {removed_key}")
# 创建新实例 # 创建新实例
mem0_instance = await self._create_mem0_instance( mem0_instance = await self._create_mem0_instance(
user_id=user_id, user_id=user_id,
@ -206,7 +260,7 @@ class Mem0Manager:
config=config, config=config,
) )
# 缓存实例 # 缓存实例(新实例自动在末尾)
self._instances[cache_key] = mem0_instance self._instances[cache_key] = mem0_instance
return mem0_instance return mem0_instance

View File

@ -3,7 +3,9 @@ Mem0 Agent 中间件
实现记忆召回和存储的 AgentMiddleware 实现记忆召回和存储的 AgentMiddleware
""" """
import asyncio
import logging import logging
import threading
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from langchain.agents.middleware import AgentMiddleware, AgentState, ModelRequest from langchain.agents.middleware import AgentMiddleware, AgentState, ModelRequest
@ -123,8 +125,6 @@ class Mem0Middleware(AgentMiddleware):
return None return None
try: try:
import asyncio
# 提取用户查询 # 提取用户查询
query = self._extract_user_query(state) query = self._extract_user_query(state)
if not query: if not query:
@ -216,6 +216,8 @@ class Mem0Middleware(AgentMiddleware):
def after_agent(self, state: AgentState, runtime: Runtime) -> None: def after_agent(self, state: AgentState, runtime: Runtime) -> None:
"""Agent 执行后:触发记忆增强(同步版本) """Agent 执行后:触发记忆增强(同步版本)
使用后台线程执行避免阻塞主流程
Args: Args:
state: Agent 状态 state: Agent 状态
runtime: 运行时上下文 runtime: 运行时上下文
@ -224,16 +226,21 @@ class Mem0Middleware(AgentMiddleware):
return return
try: try:
import asyncio # 在后台线程中执行,完全不阻塞主流程
thread = threading.Thread(
# 触发后台增强任务 target=self._trigger_augmentation_sync,
asyncio.create_task(self._trigger_augmentation_async(state, runtime)) args=(state, runtime),
daemon=True,
)
thread.start()
except Exception as e: except Exception as e:
logger.error(f"Error in Mem0Middleware.after_agent: {e}") logger.error(f"Error in Mem0Middleware.after_agent: {e}")
async def aafter_agent(self, state: AgentState, runtime: Runtime) -> None: async def aafter_agent(self, state: AgentState, runtime: Runtime) -> None:
"""Agent 执行后:触发记忆增强(异步版本) """Agent 执行后:触发记忆增强(异步版本)
使用后台线程执行避免阻塞事件循环
Args: Args:
state: Agent 状态 state: Agent 状态
runtime: 运行时上下文 runtime: 运行时上下文
@ -242,10 +249,57 @@ class Mem0Middleware(AgentMiddleware):
return return
try: try:
await self._trigger_augmentation_async(state, runtime) # 在后台线程中执行,完全不阻塞事件循环
thread = threading.Thread(
target=self._trigger_augmentation_sync,
args=(state, runtime),
daemon=True,
)
thread.start()
except Exception as e: except Exception as e:
logger.error(f"Error in Mem0Middleware.aafter_agent: {e}") logger.error(f"Error in Mem0Middleware.aafter_agent: {e}")
def _trigger_augmentation_sync(self, state: AgentState, runtime: Runtime) -> None:
    """Trigger the memory-augmentation task (sync version, runs on a thread).

    Extracts the latest user query and agent response from the agent state
    and stores the pair in Mem0 as a user-level (cross-session) memory.
    All exceptions are caught and logged so a memory failure never
    propagates into the calling thread.

    Args:
        state: Agent state (source of the user query / agent response).
        runtime: Runtime context (currently unused here).
    """
    try:
        # Resolve attribution (who this memory belongs to).
        user_id, agent_id = self.config.get_attribution_tuple()

        # Extract the user query and the agent's response from state.
        user_query = self._extract_user_query(state)
        agent_response = self._extract_agent_response(state)

        # Store the exchange as a memory (user level) only when both
        # sides of the conversation are present.
        if user_query and agent_response:
            conversation_text = f"User: {user_query}\nAssistant: {agent_response}"

            # Run the async add_memory in a fresh event loop — this method
            # executes on a background thread, which has no running loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                loop.run_until_complete(
                    self.mem0_manager.add_memory(
                        text=conversation_text,
                        user_id=user_id,
                        agent_id=agent_id,
                        metadata={"type": "conversation"},
                        config=self.config,
                    )
                )
                logger.debug(f"Stored conversation as memory for user={user_id}, agent={agent_id}")
            finally:
                # Close the per-call loop so thread resources are released.
                loop.close()
    except Exception as e:
        # Best-effort: memory storage must never crash the agent flow.
        logger.error(f"Error in _trigger_augmentation_sync: {e}")
async def _trigger_augmentation_async(self, state: AgentState, runtime: Runtime) -> None: async def _trigger_augmentation_async(self, state: AgentState, runtime: Runtime) -> None:
"""触发记忆增强任务 """触发记忆增强任务

View File

@ -2,7 +2,7 @@ version: "3.8"
services: services:
postgres: postgres:
image: postgres:16-alpine image: pgvector/pgvector:pg16
container_name: qwen-agent-postgres container_name: qwen-agent-postgres
environment: environment:
- POSTGRES_USER=postgres - POSTGRES_USER=postgres
@ -37,6 +37,7 @@ services:
volumes: volumes:
# 挂载项目数据目录 # 挂载项目数据目录
- ./projects:/app/projects - ./projects:/app/projects
- ./models:/app/models
depends_on: depends_on:
postgres: postgres:
condition: service_healthy condition: service_healthy

View File

@ -108,6 +108,16 @@ class GlobalModelManager:
logger.error(f"文本编码失败: {e}") logger.error(f"文本编码失败: {e}")
raise raise
def get_model_sync(self) -> Optional[SentenceTransformer]:
    """Synchronously return the model instance (for use from sync contexts).

    Returns:
        The already-loaded SentenceTransformer model, or None if the model
        has not been loaded yet — callers should ensure the model was first
        initialized through the async loading path.
    """
    return self._model
def get_model_info(self) -> Dict[str, Any]: def get_model_info(self) -> Dict[str, Any]:
"""获取模型信息""" """获取模型信息"""
return { return {

16
poetry.lock generated
View File

@ -1773,21 +1773,21 @@ tiktoken = ">=0.7.0,<1.0.0"
[[package]] [[package]]
name = "langgraph" name = "langgraph"
version = "1.0.4" version = "1.0.6"
description = "Building stateful, multi-actor applications with LLMs" description = "Building stateful, multi-actor applications with LLMs"
optional = false optional = false
python-versions = ">=3.10" python-versions = ">=3.10"
groups = ["main"] groups = ["main"]
files = [ files = [
{file = "langgraph-1.0.4-py3-none-any.whl", hash = "sha256:b1a835ceb0a8d69b9db48075e1939e28b1ad70ee23fa3fa8f90149904778bacf"}, {file = "langgraph-1.0.6-py3-none-any.whl", hash = "sha256:bcfce190974519c72e29f6e5b17f0023914fd6f936bfab8894083215b271eb89"},
{file = "langgraph-1.0.4.tar.gz", hash = "sha256:86d08e25d7244340f59c5200fa69fdd11066aa999b3164b531e2a20036fac156"}, {file = "langgraph-1.0.6.tar.gz", hash = "sha256:dd8e754c76d34a07485308d7117221acf63990e7de8f46ddf5fe256b0a22e6c5"},
] ]
[package.dependencies] [package.dependencies]
langchain-core = ">=0.1" langchain-core = ">=0.1"
langgraph-checkpoint = ">=2.1.0,<4.0.0" langgraph-checkpoint = ">=2.1.0,<5.0.0"
langgraph-prebuilt = ">=1.0.2,<1.1.0" langgraph-prebuilt = ">=1.0.2,<1.1.0"
langgraph-sdk = ">=0.2.2,<0.3.0" langgraph-sdk = ">=0.3.0,<0.4.0"
pydantic = ">=2.7.4" pydantic = ">=2.7.4"
xxhash = ">=3.5.0" xxhash = ">=3.5.0"
@ -1843,14 +1843,14 @@ langgraph-checkpoint = ">=2.1.0,<4.0.0"
[[package]] [[package]]
name = "langgraph-sdk" name = "langgraph-sdk"
version = "0.2.15" version = "0.3.3"
description = "SDK for interacting with LangGraph API" description = "SDK for interacting with LangGraph API"
optional = false optional = false
python-versions = ">=3.10" python-versions = ">=3.10"
groups = ["main"] groups = ["main"]
files = [ files = [
{file = "langgraph_sdk-0.2.15-py3-none-any.whl", hash = "sha256:746566a5d89aa47160eccc17d71682a78771c754126f6c235a68353d61ed7462"}, {file = "langgraph_sdk-0.3.3-py3-none-any.whl", hash = "sha256:a52ebaf09d91143e55378bb2d0b033ed98f57f48c9ad35c8f81493b88705fc7b"},
{file = "langgraph_sdk-0.2.15.tar.gz", hash = "sha256:8faaafe2c1193b89f782dd66c591060cd67862aa6aaf283749b7846f331d5334"}, {file = "langgraph_sdk-0.3.3.tar.gz", hash = "sha256:c34c3dce3b6848755eb61f0c94369d1ba04aceeb1b76015db1ea7362c544fb26"},
] ]
[package.dependencies] [package.dependencies]

View File

@ -220,17 +220,15 @@
color: var(--text); color: var(--text);
} }
</style> </style>
<script src="https://unpkg.com/lucide@latest/dist/umd/lucide.js"></script>
<!-- Fonts --> <!-- Fonts -->
<link rel="preconnect" href="https://fonts.googleapis.com"> <link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin> <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;500;600;700&family=Poppins:wght@400;500;600;700&display=swap" rel="stylesheet"> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;500;600;700&family=Poppins:wght@400;500;600;700&display=swap" rel="stylesheet">
<!-- Highlight.js --> <!-- Highlight.js -->
<link rel="stylesheet" href="https://cdn.staticfile.net/highlight.js/11.9.0/styles/github.min.css"> <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/github.min.css">
<script src="https://cdn.staticfile.net/highlight.js/11.9.0/highlight.min.js"></script> <script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js"></script>
<!-- Marked.js --> <!-- Marked.js -->
<script src="https://cdn.staticfile.net/marked/4.3.0/marked.min.js"></script> <script src="https://cdn.jsdelivr.net/npm/marked@4.3.0/marked.min.js"></script>
<!-- Lucide Icons --> <!-- Lucide Icons -->
<script src="https://unpkg.com/lucide@latest/dist/umd/lucide.js"></script> <script src="https://unpkg.com/lucide@latest/dist/umd/lucide.js"></script>
@ -1911,6 +1909,13 @@
<label class="settings-label" for="user-identifier">用户标识</label> <label class="settings-label" for="user-identifier">用户标识</label>
<input type="text" id="user-identifier" class="settings-input" placeholder="输入用户标识..."> <input type="text" id="user-identifier" class="settings-input" placeholder="输入用户标识...">
</div> </div>
<div class="settings-group">
<div class="settings-checkbox-wrapper">
<input type="checkbox" id="enable-memori" class="settings-checkbox">
<label class="settings-label" for="enable-memori" style="margin-bottom: 0;">启用记忆存储</label>
</div>
<p style="font-size: 11px; color: var(--text-muted); margin-top: 4px;">启用后AI 会记住对话中的信息以提供更个性化的回复</p>
</div>
</div> </div>
</div> </div>
@ -2842,6 +2847,7 @@
'dataset-ids': document.getElementById('dataset-ids').value, 'dataset-ids': document.getElementById('dataset-ids').value,
'system-prompt': document.getElementById('system-prompt').value, 'system-prompt': document.getElementById('system-prompt').value,
'user-identifier': document.getElementById('user-identifier').value, 'user-identifier': document.getElementById('user-identifier').value,
'enable-memori': document.getElementById('enable-memori').checked,
'skills': selectedSkills.join(','), 'skills': selectedSkills.join(','),
'mcp-settings': mcpSettingsValue, 'mcp-settings': mcpSettingsValue,
'tool-response': document.getElementById('tool-response').checked 'tool-response': document.getElementById('tool-response').checked
@ -3235,6 +3241,7 @@
systemPrompt: getValue('system-prompt'), systemPrompt: getValue('system-prompt'),
sessionId, sessionId,
userIdentifier: getValue('user-identifier'), userIdentifier: getValue('user-identifier'),
enableMemori: getChecked('enable-memori'),
skills, skills,
mcpSettings, mcpSettings,
toolResponse: getChecked('tool-response') toolResponse: getChecked('tool-response')
@ -3492,6 +3499,7 @@
if (settings.systemPrompt) requestBody.system_prompt = settings.systemPrompt; if (settings.systemPrompt) requestBody.system_prompt = settings.systemPrompt;
if (settings.sessionId) requestBody.session_id = settings.sessionId; if (settings.sessionId) requestBody.session_id = settings.sessionId;
if (settings.userIdentifier) requestBody.user_identifier = settings.userIdentifier; if (settings.userIdentifier) requestBody.user_identifier = settings.userIdentifier;
if (settings.enableMemori) requestBody.enable_memori = settings.enableMemori;
if (settings.skills?.length) requestBody.skills = settings.skills; if (settings.skills?.length) requestBody.skills = settings.skills;
if (settings.datasetIds?.length) requestBody.dataset_ids = settings.datasetIds; if (settings.datasetIds?.length) requestBody.dataset_ids = settings.datasetIds;
if (settings.mcpSettings?.length) requestBody.mcp_settings = settings.mcpSettings; if (settings.mcpSettings?.length) requestBody.mcp_settings = settings.mcpSettings;

View File

@ -64,8 +64,8 @@ langchain==1.1.3 ; python_version >= "3.12" and python_version < "4.0"
langgraph-checkpoint-postgres==2.0.25 ; python_version >= "3.12" and python_version < "4.0" langgraph-checkpoint-postgres==2.0.25 ; python_version >= "3.12" and python_version < "4.0"
langgraph-checkpoint==2.1.2 ; python_version >= "3.12" and python_version < "4.0" langgraph-checkpoint==2.1.2 ; python_version >= "3.12" and python_version < "4.0"
langgraph-prebuilt==1.0.5 ; python_version >= "3.12" and python_version < "4.0" langgraph-prebuilt==1.0.5 ; python_version >= "3.12" and python_version < "4.0"
langgraph-sdk==0.2.15 ; python_version >= "3.12" and python_version < "4.0" langgraph-sdk==0.3.3 ; python_version >= "3.12" and python_version < "4.0"
langgraph==1.0.4 ; python_version >= "3.12" and python_version < "4.0" langgraph==1.0.6 ; python_version >= "3.12" and python_version < "4.0"
langsmith==0.4.59 ; python_version >= "3.12" and python_version < "4.0" langsmith==0.4.59 ; python_version >= "3.12" and python_version < "4.0"
markdown-it-py==4.0.0 ; python_version >= "3.12" and python_version < "4.0" markdown-it-py==4.0.0 ; python_version >= "3.12" and python_version < "4.0"
markdownify==1.2.2 ; python_version >= "3.12" and python_version < "4.0" markdownify==1.2.2 ; python_version >= "3.12" and python_version < "4.0"

View File

@ -199,6 +199,73 @@ async def safe_extract_zip(zip_path: str, extract_dir: str) -> None:
raise HTTPException(status_code=400, detail=f"无效的 zip 文件: {str(e)}") raise HTTPException(status_code=400, detail=f"无效的 zip 文件: {str(e)}")
async def validate_and_rename_skill_folder(
    extract_dir: str,
    has_top_level_dirs: bool
) -> str:
    """Validate and rename the extracted skill folder.

    Checks whether the extracted folder name matches the ``name`` declared
    in the skill's SKILL.md front matter; if not, the folder is renamed.
    Validation is best-effort: any failure is logged and the upload
    proceeds with the original layout.

    Args:
        extract_dir: Extraction target directory.
        has_top_level_dirs: Whether the zip contained top-level directories.

    Returns:
        str: The final extraction path (may change due to a rename).
    """
    try:
        if has_top_level_dirs:
            # The zip contained directories: inspect each extracted folder.
            for folder_name in os.listdir(extract_dir):
                folder_path = os.path.join(extract_dir, folder_name)
                if not os.path.isdir(folder_path):
                    continue
                skill_md_path = os.path.join(folder_path, 'SKILL.md')
                if not os.path.exists(skill_md_path):
                    continue
                # Parse front matter off the event loop (blocking file I/O).
                metadata = await asyncio.to_thread(
                    parse_skill_frontmatter, skill_md_path
                )
                if not metadata or 'name' not in metadata:
                    continue
                expected_name = metadata['name']
                if folder_name == expected_name:
                    continue
                new_folder_path = os.path.join(extract_dir, expected_name)
                # BUGFIX: shutil.move into an existing directory nests the
                # source inside it instead of renaming — skip in that case.
                if os.path.exists(new_folder_path):
                    logger.warning(
                        f"Skill rename target already exists, keeping original name: {folder_name}"
                    )
                    continue
                await asyncio.to_thread(
                    shutil.move, folder_path, new_folder_path
                )
                logger.info(
                    f"Renamed skill folder: {folder_name} -> {expected_name}"
                )
            return extract_dir
        else:
            # The zip contained bare files: check SKILL.md in extract_dir itself.
            skill_md_path = os.path.join(extract_dir, 'SKILL.md')
            if os.path.exists(skill_md_path):
                metadata = await asyncio.to_thread(
                    parse_skill_frontmatter, skill_md_path
                )
                if metadata and 'name' in metadata:
                    expected_name = metadata['name']
                    # Current folder name of the extraction directory.
                    current_name = os.path.basename(extract_dir)
                    if current_name != expected_name:
                        parent_dir = os.path.dirname(extract_dir)
                        new_folder_path = os.path.join(parent_dir, expected_name)
                        # BUGFIX: avoid nesting into a pre-existing target dir.
                        if os.path.exists(new_folder_path):
                            logger.warning(
                                f"Skill rename target already exists, keeping original name: {current_name}"
                            )
                            return extract_dir
                        await asyncio.to_thread(
                            shutil.move, extract_dir, new_folder_path
                        )
                        logger.info(
                            f"Renamed skill folder: {current_name} -> {expected_name}"
                        )
                        return new_folder_path
            return extract_dir
    except Exception as e:
        # Best-effort validation: never fail the upload because of a rename.
        logger.warning(f"Failed to validate/rename skill folder: {e}")
        return extract_dir
async def save_upload_file_async(file: UploadFile, destination: str) -> None: async def save_upload_file_async(file: UploadFile, destination: str) -> None:
"""异步保存上传文件到目标路径""" """异步保存上传文件到目标路径"""
async with aiofiles.open(destination, 'wb') as f: async with aiofiles.open(destination, 'wb') as f:
@ -453,13 +520,24 @@ async def upload_skill(file: UploadFile = File(...), bot_id: Optional[str] = For
await safe_extract_zip(file_path, extract_target) await safe_extract_zip(file_path, extract_target)
logger.info(f"Extracted to: {extract_target}") logger.info(f"Extracted to: {extract_target}")
# 验证并重命名文件夹以匹配 SKILL.md 中的 name
final_extract_path = await validate_and_rename_skill_folder(
extract_target, has_top_level_dirs
)
# 获取最终的 skill 名称
if has_top_level_dirs:
final_skill_name = folder_name
else:
final_skill_name = os.path.basename(final_extract_path)
return { return {
"success": True, "success": True,
"message": f"Skill文件上传并解压成功", "message": f"Skill文件上传并解压成功",
"file_path": file_path, "file_path": file_path,
"extract_path": extract_target, "extract_path": final_extract_path,
"original_filename": original_filename, "original_filename": original_filename,
"skill_name": folder_name "skill_name": final_skill_name
} }
except HTTPException: except HTTPException:

View File

@ -54,6 +54,7 @@ CHECKPOINT_DB_URL = os.getenv("CHECKPOINT_DB_URL", "postgresql://moshui:@localho
# 连接池大小 # 连接池大小
# 同时可以持有的最大连接数 # 同时可以持有的最大连接数
CHECKPOINT_POOL_SIZE = int(os.getenv("CHECKPOINT_POOL_SIZE", "20")) CHECKPOINT_POOL_SIZE = int(os.getenv("CHECKPOINT_POOL_SIZE", "20"))
MEM0_POOL_SIZE = int(os.getenv("MEM0_POOL_SIZE", "20"))
# Checkpoint 自动清理配置 # Checkpoint 自动清理配置
# 是否启用自动清理旧 session # 是否启用自动清理旧 session