refactor(mem0): optimize connection pool and async memory handling
- Fix mem0 connection pool exhausted error with proper pooling - Convert memory operations to async tasks - Optimize docker-compose configuration - Add skill upload functionality - Reduce cache size for better performance - Update dependencies Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
f29fd1fb54
commit
3dc119bca8
@ -105,9 +105,7 @@ class AgentConfig:
|
|||||||
enable_thinking = request.enable_thinking and "<guidelines>" in request.system_prompt
|
enable_thinking = request.enable_thinking and "<guidelines>" in request.system_prompt
|
||||||
|
|
||||||
# 从请求中获取 Mem0 配置,如果没有则使用全局配置
|
# 从请求中获取 Mem0 配置,如果没有则使用全局配置
|
||||||
enable_memori = getattr(request, 'enable_memori', None)
|
enable_memori = getattr(request, 'enable_memori', MEM0_ENABLED)
|
||||||
if enable_memori is None:
|
|
||||||
enable_memori = MEM0_ENABLED
|
|
||||||
|
|
||||||
config = cls(
|
config = cls(
|
||||||
bot_id=request.bot_id,
|
bot_id=request.bot_id,
|
||||||
@ -171,10 +169,7 @@ class AgentConfig:
|
|||||||
enable_thinking = request.enable_thinking and "<guidelines>" in bot_config.get("system_prompt")
|
enable_thinking = request.enable_thinking and "<guidelines>" in bot_config.get("system_prompt")
|
||||||
|
|
||||||
# 从请求或后端配置中获取 Mem0 配置
|
# 从请求或后端配置中获取 Mem0 配置
|
||||||
enable_memori = getattr(request, 'enable_memori', None)
|
enable_memori = getattr(request, 'enable_memori', MEM0_ENABLED)
|
||||||
if enable_memori is None:
|
|
||||||
enable_memori = bot_config.get("enable_memori", MEM0_ENABLED)
|
|
||||||
|
|
||||||
|
|
||||||
config = cls(
|
config = cls(
|
||||||
bot_id=request.bot_id,
|
bot_id=request.bot_id,
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from psycopg2 import pool as psycopg2_pool
|
|||||||
from utils.settings import (
|
from utils.settings import (
|
||||||
CHECKPOINT_DB_URL,
|
CHECKPOINT_DB_URL,
|
||||||
CHECKPOINT_POOL_SIZE,
|
CHECKPOINT_POOL_SIZE,
|
||||||
|
MEM0_POOL_SIZE,
|
||||||
CHECKPOINT_CLEANUP_ENABLED,
|
CHECKPOINT_CLEANUP_ENABLED,
|
||||||
CHECKPOINT_CLEANUP_INTERVAL_HOURS,
|
CHECKPOINT_CLEANUP_INTERVAL_HOURS,
|
||||||
CHECKPOINT_CLEANUP_INACTIVE_DAYS,
|
CHECKPOINT_CLEANUP_INACTIVE_DAYS,
|
||||||
@ -62,7 +63,7 @@ class DBPoolManager:
|
|||||||
await self._pool.open()
|
await self._pool.open()
|
||||||
|
|
||||||
# 2. 创建同步 psycopg2 连接池(供 Mem0 使用)
|
# 2. 创建同步 psycopg2 连接池(供 Mem0 使用)
|
||||||
self._sync_pool = self._create_sync_pool(CHECKPOINT_DB_URL, CHECKPOINT_POOL_SIZE)
|
self._sync_pool = self._create_sync_pool(CHECKPOINT_DB_URL, MEM0_POOL_SIZE)
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
logger.info("PostgreSQL connection pool initialized successfully")
|
logger.info("PostgreSQL connection pool initialized successfully")
|
||||||
|
|||||||
@ -5,11 +5,16 @@ Mem0 连接和实例管理器
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import concurrent.futures
|
||||||
from typing import Any, Dict, List, Optional, Literal
|
from typing import Any, Dict, List, Optional, Literal
|
||||||
|
from collections import OrderedDict
|
||||||
from embedding.manager import GlobalModelManager, get_model_manager
|
from embedding.manager import GlobalModelManager, get_model_manager
|
||||||
import json_repair
|
import json_repair
|
||||||
from psycopg2 import pool
|
from psycopg2 import pool
|
||||||
|
from utils.settings import (
|
||||||
|
MEM0_POOL_SIZE
|
||||||
|
)
|
||||||
from .mem0_config import Mem0Config
|
from .mem0_config import Mem0Config
|
||||||
|
|
||||||
logger = logging.getLogger("app")
|
logger = logging.getLogger("app")
|
||||||
@ -27,7 +32,9 @@ class CustomMem0Embedding:
|
|||||||
这样 Mem0 就不需要再次加载同一个模型,节省内存
|
这样 Mem0 就不需要再次加载同一个模型,节省内存
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_model_manager = None # 缓存 GlobalModelManager 实例
|
_model = None # 类变量,缓存模型实例
|
||||||
|
_lock = threading.Lock() # 线程安全锁
|
||||||
|
_executor = None # 线程池执行器
|
||||||
|
|
||||||
def __init__(self, config: Optional[Any] = None):
|
def __init__(self, config: Optional[Any] = None):
|
||||||
"""初始化自定义 Embedding"""
|
"""初始化自定义 Embedding"""
|
||||||
@ -41,6 +48,41 @@ class CustomMem0Embedding:
|
|||||||
"""获取 embedding 维度"""
|
"""获取 embedding 维度"""
|
||||||
return 384 # gte-tiny 的维度
|
return 384 # gte-tiny 的维度
|
||||||
|
|
||||||
|
def _get_model_sync(self):
|
||||||
|
"""同步获取模型,避免 asyncio.run()"""
|
||||||
|
# 首先尝试从 manager 获取已加载的模型
|
||||||
|
manager = get_model_manager()
|
||||||
|
model = manager.get_model_sync()
|
||||||
|
|
||||||
|
if model is not None:
|
||||||
|
# 缓存模型
|
||||||
|
CustomMem0Embedding._model = model
|
||||||
|
return model
|
||||||
|
|
||||||
|
# 如果模型未加载,使用线程池运行异步初始化
|
||||||
|
if CustomMem0Embedding._executor is None:
|
||||||
|
CustomMem0Embedding._executor = concurrent.futures.ThreadPoolExecutor(
|
||||||
|
max_workers=1,
|
||||||
|
thread_name_prefix="mem0_embed"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 在独立线程中运行异步代码
|
||||||
|
def run_async_in_thread():
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
result = loop.run_until_complete(manager.get_model())
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
future = CustomMem0Embedding._executor.submit(run_async_in_thread)
|
||||||
|
model = future.result(timeout=30) # 30秒超时
|
||||||
|
|
||||||
|
# 缓存模型
|
||||||
|
CustomMem0Embedding._model = model
|
||||||
|
return model
|
||||||
|
|
||||||
def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
|
def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
|
||||||
"""
|
"""
|
||||||
获取文本的 embedding 向量(同步方法,供 Mem0 调用)
|
获取文本的 embedding 向量(同步方法,供 Mem0 调用)
|
||||||
@ -52,8 +94,13 @@ class CustomMem0Embedding:
|
|||||||
Returns:
|
Returns:
|
||||||
list: embedding 向量
|
list: embedding 向量
|
||||||
"""
|
"""
|
||||||
manager = get_model_manager()
|
# 线程安全地获取模型
|
||||||
model = asyncio.run(manager.get_model())
|
if CustomMem0Embedding._model is None:
|
||||||
|
with CustomMem0Embedding._lock:
|
||||||
|
if CustomMem0Embedding._model is None:
|
||||||
|
self._get_model_sync()
|
||||||
|
|
||||||
|
model = CustomMem0Embedding._model
|
||||||
embeddings = model.encode(text, convert_to_numpy=True)
|
embeddings = model.encode(text, convert_to_numpy=True)
|
||||||
return embeddings.tolist()
|
return embeddings.tolist()
|
||||||
|
|
||||||
@ -136,8 +183,9 @@ class Mem0Manager:
|
|||||||
"""
|
"""
|
||||||
self._sync_pool = sync_pool
|
self._sync_pool = sync_pool
|
||||||
|
|
||||||
# 缓存 Mem0 实例: key = f"{user_id}:{agent_id}"
|
# 使用 OrderedDict 实现 LRU 缓存,最多保留 50 个实例
|
||||||
self._instances: Dict[str, Any] = {}
|
self._instances: OrderedDict[str, Any] = OrderedDict()
|
||||||
|
self._max_instances = MEM0_POOL_SIZE/2 # 最大缓存实例数
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
@ -194,10 +242,16 @@ class Mem0Manager:
|
|||||||
llm_suffix = f":{id(config.llm_instance)}"
|
llm_suffix = f":{id(config.llm_instance)}"
|
||||||
cache_key = f"{user_id}:{agent_id}{llm_suffix}"
|
cache_key = f"{user_id}:{agent_id}{llm_suffix}"
|
||||||
|
|
||||||
# 检查缓存
|
# 检查缓存(同时移动到末尾表示最近使用)
|
||||||
if cache_key in self._instances:
|
if cache_key in self._instances:
|
||||||
|
self._instances.move_to_end(cache_key)
|
||||||
return self._instances[cache_key]
|
return self._instances[cache_key]
|
||||||
|
|
||||||
|
# 检查缓存大小,超过则移除最旧的
|
||||||
|
if len(self._instances) >= self._max_instances:
|
||||||
|
removed_key, _ = self._instances.popitem(last=False)
|
||||||
|
logger.debug(f"Mem0 instance cache full, removed oldest entry: {removed_key}")
|
||||||
|
|
||||||
# 创建新实例
|
# 创建新实例
|
||||||
mem0_instance = await self._create_mem0_instance(
|
mem0_instance = await self._create_mem0_instance(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@ -206,7 +260,7 @@ class Mem0Manager:
|
|||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 缓存实例
|
# 缓存实例(新实例自动在末尾)
|
||||||
self._instances[cache_key] = mem0_instance
|
self._instances[cache_key] = mem0_instance
|
||||||
return mem0_instance
|
return mem0_instance
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,9 @@ Mem0 Agent 中间件
|
|||||||
实现记忆召回和存储的 AgentMiddleware
|
实现记忆召回和存储的 AgentMiddleware
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import threading
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.agents.middleware import AgentMiddleware, AgentState, ModelRequest
|
from langchain.agents.middleware import AgentMiddleware, AgentState, ModelRequest
|
||||||
@ -123,8 +125,6 @@ class Mem0Middleware(AgentMiddleware):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import asyncio
|
|
||||||
|
|
||||||
# 提取用户查询
|
# 提取用户查询
|
||||||
query = self._extract_user_query(state)
|
query = self._extract_user_query(state)
|
||||||
if not query:
|
if not query:
|
||||||
@ -216,6 +216,8 @@ class Mem0Middleware(AgentMiddleware):
|
|||||||
def after_agent(self, state: AgentState, runtime: Runtime) -> None:
|
def after_agent(self, state: AgentState, runtime: Runtime) -> None:
|
||||||
"""Agent 执行后:触发记忆增强(同步版本)
|
"""Agent 执行后:触发记忆增强(同步版本)
|
||||||
|
|
||||||
|
使用后台线程执行,避免阻塞主流程
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: Agent 状态
|
state: Agent 状态
|
||||||
runtime: 运行时上下文
|
runtime: 运行时上下文
|
||||||
@ -224,16 +226,21 @@ class Mem0Middleware(AgentMiddleware):
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import asyncio
|
# 在后台线程中执行,完全不阻塞主流程
|
||||||
|
thread = threading.Thread(
|
||||||
# 触发后台增强任务
|
target=self._trigger_augmentation_sync,
|
||||||
asyncio.create_task(self._trigger_augmentation_async(state, runtime))
|
args=(state, runtime),
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
thread.start()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in Mem0Middleware.after_agent: {e}")
|
logger.error(f"Error in Mem0Middleware.after_agent: {e}")
|
||||||
|
|
||||||
async def aafter_agent(self, state: AgentState, runtime: Runtime) -> None:
|
async def aafter_agent(self, state: AgentState, runtime: Runtime) -> None:
|
||||||
"""Agent 执行后:触发记忆增强(异步版本)
|
"""Agent 执行后:触发记忆增强(异步版本)
|
||||||
|
|
||||||
|
使用后台线程执行,避免阻塞事件循环
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: Agent 状态
|
state: Agent 状态
|
||||||
runtime: 运行时上下文
|
runtime: 运行时上下文
|
||||||
@ -242,10 +249,57 @@ class Mem0Middleware(AgentMiddleware):
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._trigger_augmentation_async(state, runtime)
|
# 在后台线程中执行,完全不阻塞事件循环
|
||||||
|
thread = threading.Thread(
|
||||||
|
target=self._trigger_augmentation_sync,
|
||||||
|
args=(state, runtime),
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
thread.start()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in Mem0Middleware.aafter_agent: {e}")
|
logger.error(f"Error in Mem0Middleware.aafter_agent: {e}")
|
||||||
|
|
||||||
|
def _trigger_augmentation_sync(self, state: AgentState, runtime: Runtime) -> None:
|
||||||
|
"""触发记忆增强任务(同步版本,在线程中执行)
|
||||||
|
|
||||||
|
从对话中提取信息并存储到 Mem0(用户级别,跨会话)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Agent 状态
|
||||||
|
runtime: 运行时上下文
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 获取 attribution 参数
|
||||||
|
user_id, agent_id = self.config.get_attribution_tuple()
|
||||||
|
|
||||||
|
# 提取用户查询和 Agent 响应
|
||||||
|
user_query = self._extract_user_query(state)
|
||||||
|
agent_response = self._extract_agent_response(state)
|
||||||
|
|
||||||
|
# 将对话作为记忆存储(用户级别)
|
||||||
|
if user_query and agent_response:
|
||||||
|
conversation_text = f"User: {user_query}\nAssistant: {agent_response}"
|
||||||
|
|
||||||
|
# 在新的事件循环中运行异步代码(因为在线程中)
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(
|
||||||
|
self.mem0_manager.add_memory(
|
||||||
|
text=conversation_text,
|
||||||
|
user_id=user_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
metadata={"type": "conversation"},
|
||||||
|
config=self.config,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.debug(f"Stored conversation as memory for user={user_id}, agent={agent_id}")
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in _trigger_augmentation_sync: {e}")
|
||||||
|
|
||||||
async def _trigger_augmentation_async(self, state: AgentState, runtime: Runtime) -> None:
|
async def _trigger_augmentation_async(self, state: AgentState, runtime: Runtime) -> None:
|
||||||
"""触发记忆增强任务
|
"""触发记忆增强任务
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@ version: "3.8"
|
|||||||
|
|
||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
image: postgres:16-alpine
|
image: pgvector/pgvector:pg16
|
||||||
container_name: qwen-agent-postgres
|
container_name: qwen-agent-postgres
|
||||||
environment:
|
environment:
|
||||||
- POSTGRES_USER=postgres
|
- POSTGRES_USER=postgres
|
||||||
@ -37,6 +37,7 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
# 挂载项目数据目录
|
# 挂载项目数据目录
|
||||||
- ./projects:/app/projects
|
- ./projects:/app/projects
|
||||||
|
- ./models:/app/models
|
||||||
depends_on:
|
depends_on:
|
||||||
postgres:
|
postgres:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
|||||||
@ -108,6 +108,16 @@ class GlobalModelManager:
|
|||||||
logger.error(f"文本编码失败: {e}")
|
logger.error(f"文本编码失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def get_model_sync(self) -> Optional[SentenceTransformer]:
|
||||||
|
"""同步获取模型实例(供同步上下文使用)
|
||||||
|
|
||||||
|
如果模型未加载,返回 None。调用者应确保先通过异步方法初始化模型。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
已加载的 SentenceTransformer 模型,或 None
|
||||||
|
"""
|
||||||
|
return self._model
|
||||||
|
|
||||||
def get_model_info(self) -> Dict[str, Any]:
|
def get_model_info(self) -> Dict[str, Any]:
|
||||||
"""获取模型信息"""
|
"""获取模型信息"""
|
||||||
return {
|
return {
|
||||||
|
|||||||
16
poetry.lock
generated
16
poetry.lock
generated
@ -1773,21 +1773,21 @@ tiktoken = ">=0.7.0,<1.0.0"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langgraph"
|
name = "langgraph"
|
||||||
version = "1.0.4"
|
version = "1.0.6"
|
||||||
description = "Building stateful, multi-actor applications with LLMs"
|
description = "Building stateful, multi-actor applications with LLMs"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.10"
|
python-versions = ">=3.10"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "langgraph-1.0.4-py3-none-any.whl", hash = "sha256:b1a835ceb0a8d69b9db48075e1939e28b1ad70ee23fa3fa8f90149904778bacf"},
|
{file = "langgraph-1.0.6-py3-none-any.whl", hash = "sha256:bcfce190974519c72e29f6e5b17f0023914fd6f936bfab8894083215b271eb89"},
|
||||||
{file = "langgraph-1.0.4.tar.gz", hash = "sha256:86d08e25d7244340f59c5200fa69fdd11066aa999b3164b531e2a20036fac156"},
|
{file = "langgraph-1.0.6.tar.gz", hash = "sha256:dd8e754c76d34a07485308d7117221acf63990e7de8f46ddf5fe256b0a22e6c5"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
langchain-core = ">=0.1"
|
langchain-core = ">=0.1"
|
||||||
langgraph-checkpoint = ">=2.1.0,<4.0.0"
|
langgraph-checkpoint = ">=2.1.0,<5.0.0"
|
||||||
langgraph-prebuilt = ">=1.0.2,<1.1.0"
|
langgraph-prebuilt = ">=1.0.2,<1.1.0"
|
||||||
langgraph-sdk = ">=0.2.2,<0.3.0"
|
langgraph-sdk = ">=0.3.0,<0.4.0"
|
||||||
pydantic = ">=2.7.4"
|
pydantic = ">=2.7.4"
|
||||||
xxhash = ">=3.5.0"
|
xxhash = ">=3.5.0"
|
||||||
|
|
||||||
@ -1843,14 +1843,14 @@ langgraph-checkpoint = ">=2.1.0,<4.0.0"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langgraph-sdk"
|
name = "langgraph-sdk"
|
||||||
version = "0.2.15"
|
version = "0.3.3"
|
||||||
description = "SDK for interacting with LangGraph API"
|
description = "SDK for interacting with LangGraph API"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.10"
|
python-versions = ">=3.10"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "langgraph_sdk-0.2.15-py3-none-any.whl", hash = "sha256:746566a5d89aa47160eccc17d71682a78771c754126f6c235a68353d61ed7462"},
|
{file = "langgraph_sdk-0.3.3-py3-none-any.whl", hash = "sha256:a52ebaf09d91143e55378bb2d0b033ed98f57f48c9ad35c8f81493b88705fc7b"},
|
||||||
{file = "langgraph_sdk-0.2.15.tar.gz", hash = "sha256:8faaafe2c1193b89f782dd66c591060cd67862aa6aaf283749b7846f331d5334"},
|
{file = "langgraph_sdk-0.3.3.tar.gz", hash = "sha256:c34c3dce3b6848755eb61f0c94369d1ba04aceeb1b76015db1ea7362c544fb26"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
|||||||
@ -220,17 +220,15 @@
|
|||||||
color: var(--text);
|
color: var(--text);
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
<script src="https://unpkg.com/lucide@latest/dist/umd/lucide.js"></script>
|
|
||||||
|
|
||||||
<!-- Fonts -->
|
<!-- Fonts -->
|
||||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||||
<link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;500;600;700&family=Poppins:wght@400;500;600;700&display=swap" rel="stylesheet">
|
<link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;500;600;700&family=Poppins:wght@400;500;600;700&display=swap" rel="stylesheet">
|
||||||
<!-- Highlight.js -->
|
<!-- Highlight.js -->
|
||||||
<link rel="stylesheet" href="https://cdn.staticfile.net/highlight.js/11.9.0/styles/github.min.css">
|
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/github.min.css">
|
||||||
<script src="https://cdn.staticfile.net/highlight.js/11.9.0/highlight.min.js"></script>
|
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js"></script>
|
||||||
<!-- Marked.js -->
|
<!-- Marked.js -->
|
||||||
<script src="https://cdn.staticfile.net/marked/4.3.0/marked.min.js"></script>
|
<script src="https://cdn.jsdelivr.net/npm/marked@4.3.0/marked.min.js"></script>
|
||||||
<!-- Lucide Icons -->
|
<!-- Lucide Icons -->
|
||||||
<script src="https://unpkg.com/lucide@latest/dist/umd/lucide.js"></script>
|
<script src="https://unpkg.com/lucide@latest/dist/umd/lucide.js"></script>
|
||||||
|
|
||||||
@ -1911,6 +1909,13 @@
|
|||||||
<label class="settings-label" for="user-identifier">用户标识</label>
|
<label class="settings-label" for="user-identifier">用户标识</label>
|
||||||
<input type="text" id="user-identifier" class="settings-input" placeholder="输入用户标识...">
|
<input type="text" id="user-identifier" class="settings-input" placeholder="输入用户标识...">
|
||||||
</div>
|
</div>
|
||||||
|
<div class="settings-group">
|
||||||
|
<div class="settings-checkbox-wrapper">
|
||||||
|
<input type="checkbox" id="enable-memori" class="settings-checkbox">
|
||||||
|
<label class="settings-label" for="enable-memori" style="margin-bottom: 0;">启用记忆存储</label>
|
||||||
|
</div>
|
||||||
|
<p style="font-size: 11px; color: var(--text-muted); margin-top: 4px;">启用后,AI 会记住对话中的信息以提供更个性化的回复</p>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@ -2842,6 +2847,7 @@
|
|||||||
'dataset-ids': document.getElementById('dataset-ids').value,
|
'dataset-ids': document.getElementById('dataset-ids').value,
|
||||||
'system-prompt': document.getElementById('system-prompt').value,
|
'system-prompt': document.getElementById('system-prompt').value,
|
||||||
'user-identifier': document.getElementById('user-identifier').value,
|
'user-identifier': document.getElementById('user-identifier').value,
|
||||||
|
'enable-memori': document.getElementById('enable-memori').checked,
|
||||||
'skills': selectedSkills.join(','),
|
'skills': selectedSkills.join(','),
|
||||||
'mcp-settings': mcpSettingsValue,
|
'mcp-settings': mcpSettingsValue,
|
||||||
'tool-response': document.getElementById('tool-response').checked
|
'tool-response': document.getElementById('tool-response').checked
|
||||||
@ -3235,6 +3241,7 @@
|
|||||||
systemPrompt: getValue('system-prompt'),
|
systemPrompt: getValue('system-prompt'),
|
||||||
sessionId,
|
sessionId,
|
||||||
userIdentifier: getValue('user-identifier'),
|
userIdentifier: getValue('user-identifier'),
|
||||||
|
enableMemori: getChecked('enable-memori'),
|
||||||
skills,
|
skills,
|
||||||
mcpSettings,
|
mcpSettings,
|
||||||
toolResponse: getChecked('tool-response')
|
toolResponse: getChecked('tool-response')
|
||||||
@ -3492,6 +3499,7 @@
|
|||||||
if (settings.systemPrompt) requestBody.system_prompt = settings.systemPrompt;
|
if (settings.systemPrompt) requestBody.system_prompt = settings.systemPrompt;
|
||||||
if (settings.sessionId) requestBody.session_id = settings.sessionId;
|
if (settings.sessionId) requestBody.session_id = settings.sessionId;
|
||||||
if (settings.userIdentifier) requestBody.user_identifier = settings.userIdentifier;
|
if (settings.userIdentifier) requestBody.user_identifier = settings.userIdentifier;
|
||||||
|
if (settings.enableMemori) requestBody.enable_memori = settings.enableMemori;
|
||||||
if (settings.skills?.length) requestBody.skills = settings.skills;
|
if (settings.skills?.length) requestBody.skills = settings.skills;
|
||||||
if (settings.datasetIds?.length) requestBody.dataset_ids = settings.datasetIds;
|
if (settings.datasetIds?.length) requestBody.dataset_ids = settings.datasetIds;
|
||||||
if (settings.mcpSettings?.length) requestBody.mcp_settings = settings.mcpSettings;
|
if (settings.mcpSettings?.length) requestBody.mcp_settings = settings.mcpSettings;
|
||||||
|
|||||||
@ -64,8 +64,8 @@ langchain==1.1.3 ; python_version >= "3.12" and python_version < "4.0"
|
|||||||
langgraph-checkpoint-postgres==2.0.25 ; python_version >= "3.12" and python_version < "4.0"
|
langgraph-checkpoint-postgres==2.0.25 ; python_version >= "3.12" and python_version < "4.0"
|
||||||
langgraph-checkpoint==2.1.2 ; python_version >= "3.12" and python_version < "4.0"
|
langgraph-checkpoint==2.1.2 ; python_version >= "3.12" and python_version < "4.0"
|
||||||
langgraph-prebuilt==1.0.5 ; python_version >= "3.12" and python_version < "4.0"
|
langgraph-prebuilt==1.0.5 ; python_version >= "3.12" and python_version < "4.0"
|
||||||
langgraph-sdk==0.2.15 ; python_version >= "3.12" and python_version < "4.0"
|
langgraph-sdk==0.3.3 ; python_version >= "3.12" and python_version < "4.0"
|
||||||
langgraph==1.0.4 ; python_version >= "3.12" and python_version < "4.0"
|
langgraph==1.0.6 ; python_version >= "3.12" and python_version < "4.0"
|
||||||
langsmith==0.4.59 ; python_version >= "3.12" and python_version < "4.0"
|
langsmith==0.4.59 ; python_version >= "3.12" and python_version < "4.0"
|
||||||
markdown-it-py==4.0.0 ; python_version >= "3.12" and python_version < "4.0"
|
markdown-it-py==4.0.0 ; python_version >= "3.12" and python_version < "4.0"
|
||||||
markdownify==1.2.2 ; python_version >= "3.12" and python_version < "4.0"
|
markdownify==1.2.2 ; python_version >= "3.12" and python_version < "4.0"
|
||||||
|
|||||||
@ -199,6 +199,73 @@ async def safe_extract_zip(zip_path: str, extract_dir: str) -> None:
|
|||||||
raise HTTPException(status_code=400, detail=f"无效的 zip 文件: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"无效的 zip 文件: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_and_rename_skill_folder(
|
||||||
|
extract_dir: str,
|
||||||
|
has_top_level_dirs: bool
|
||||||
|
) -> str:
|
||||||
|
"""验证并重命名解压后的 skill 文件夹
|
||||||
|
|
||||||
|
检查解压后文件夹名称是否与 SKILL.md 中的 name 匹配,
|
||||||
|
如果不匹配则重命名文件夹。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
extract_dir: 解压目标目录
|
||||||
|
has_top_level_dirs: zip 是否包含顶级目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 最终的解压路径(可能因为重命名而改变)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if has_top_level_dirs:
|
||||||
|
# zip 包含目录,检查每个目录
|
||||||
|
for folder_name in os.listdir(extract_dir):
|
||||||
|
folder_path = os.path.join(extract_dir, folder_name)
|
||||||
|
if os.path.isdir(folder_path):
|
||||||
|
skill_md_path = os.path.join(folder_path, 'SKILL.md')
|
||||||
|
if os.path.exists(skill_md_path):
|
||||||
|
metadata = await asyncio.to_thread(
|
||||||
|
parse_skill_frontmatter, skill_md_path
|
||||||
|
)
|
||||||
|
if metadata and 'name' in metadata:
|
||||||
|
expected_name = metadata['name']
|
||||||
|
if folder_name != expected_name:
|
||||||
|
new_folder_path = os.path.join(extract_dir, expected_name)
|
||||||
|
await asyncio.to_thread(
|
||||||
|
shutil.move, folder_path, new_folder_path
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Renamed skill folder: {folder_name} -> {expected_name}"
|
||||||
|
)
|
||||||
|
return extract_dir
|
||||||
|
else:
|
||||||
|
# zip 直接包含文件,检查当前目录的 SKILL.md
|
||||||
|
skill_md_path = os.path.join(extract_dir, 'SKILL.md')
|
||||||
|
if os.path.exists(skill_md_path):
|
||||||
|
metadata = await asyncio.to_thread(
|
||||||
|
parse_skill_frontmatter, skill_md_path
|
||||||
|
)
|
||||||
|
if metadata and 'name' in metadata:
|
||||||
|
expected_name = metadata['name']
|
||||||
|
# 获取当前文件夹名称
|
||||||
|
current_name = os.path.basename(extract_dir)
|
||||||
|
if current_name != expected_name:
|
||||||
|
parent_dir = os.path.dirname(extract_dir)
|
||||||
|
new_folder_path = os.path.join(parent_dir, expected_name)
|
||||||
|
await asyncio.to_thread(
|
||||||
|
shutil.move, extract_dir, new_folder_path
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Renamed skill folder: {current_name} -> {expected_name}"
|
||||||
|
)
|
||||||
|
return new_folder_path
|
||||||
|
return extract_dir
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to validate/rename skill folder: {e}")
|
||||||
|
# 不抛出异常,允许上传继续
|
||||||
|
return extract_dir
|
||||||
|
|
||||||
|
|
||||||
async def save_upload_file_async(file: UploadFile, destination: str) -> None:
|
async def save_upload_file_async(file: UploadFile, destination: str) -> None:
|
||||||
"""异步保存上传文件到目标路径"""
|
"""异步保存上传文件到目标路径"""
|
||||||
async with aiofiles.open(destination, 'wb') as f:
|
async with aiofiles.open(destination, 'wb') as f:
|
||||||
@ -453,13 +520,24 @@ async def upload_skill(file: UploadFile = File(...), bot_id: Optional[str] = For
|
|||||||
await safe_extract_zip(file_path, extract_target)
|
await safe_extract_zip(file_path, extract_target)
|
||||||
logger.info(f"Extracted to: {extract_target}")
|
logger.info(f"Extracted to: {extract_target}")
|
||||||
|
|
||||||
|
# 验证并重命名文件夹以匹配 SKILL.md 中的 name
|
||||||
|
final_extract_path = await validate_and_rename_skill_folder(
|
||||||
|
extract_target, has_top_level_dirs
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取最终的 skill 名称
|
||||||
|
if has_top_level_dirs:
|
||||||
|
final_skill_name = folder_name
|
||||||
|
else:
|
||||||
|
final_skill_name = os.path.basename(final_extract_path)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": f"Skill文件上传并解压成功",
|
"message": f"Skill文件上传并解压成功",
|
||||||
"file_path": file_path,
|
"file_path": file_path,
|
||||||
"extract_path": extract_target,
|
"extract_path": final_extract_path,
|
||||||
"original_filename": original_filename,
|
"original_filename": original_filename,
|
||||||
"skill_name": folder_name
|
"skill_name": final_skill_name
|
||||||
}
|
}
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@ -54,6 +54,7 @@ CHECKPOINT_DB_URL = os.getenv("CHECKPOINT_DB_URL", "postgresql://moshui:@localho
|
|||||||
# 连接池大小
|
# 连接池大小
|
||||||
# 同时可以持有的最大连接数
|
# 同时可以持有的最大连接数
|
||||||
CHECKPOINT_POOL_SIZE = int(os.getenv("CHECKPOINT_POOL_SIZE", "20"))
|
CHECKPOINT_POOL_SIZE = int(os.getenv("CHECKPOINT_POOL_SIZE", "20"))
|
||||||
|
MEM0_POOL_SIZE = int(os.getenv("MEM0_POOL_SIZE", "20"))
|
||||||
|
|
||||||
# Checkpoint 自动清理配置
|
# Checkpoint 自动清理配置
|
||||||
# 是否启用自动清理旧 session
|
# 是否启用自动清理旧 session
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user