deep_assistant/prompt_loader 用 config.bot_id 拼 projects/robot/{bot_id},
之前只改 project_dir 无效。from_v3_request 在 bot_config 带 project_dir_key 时
用它覆盖 bot_id,使 workspace、系统提示词路径、skills 目录都按用户隔离。
普通 bot 不返回 project_dir_key,行为不变。
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
424 lines
18 KiB
Python
424 lines
18 KiB
Python
"""Agent configuration class for managing all agent-related parameters."""
|
||
|
||
from typing import Optional, List, Dict, Any, TYPE_CHECKING
|
||
from dataclasses import dataclass, field
|
||
import logging
|
||
import json
|
||
import hashlib
|
||
|
||
# Module-level logger; all messages emitted by this module go to the logger named 'app'.
logger = logging.getLogger('app')
|
||
|
||
@dataclass
class AgentConfig:
    """Container for every parameter needed to create and run an agent.

    Instances are normally built through one of the async factories
    (``from_v1_request`` / ``from_v2_request`` / ``from_v3_request``), which
    also load checkpoint history and log a redacted copy of the configuration.
    """

    # Basic parameters
    bot_id: str
    api_key: Optional[str] = None
    model_name: str = "qwen3-next"
    model_server: Optional[str] = None
    language: Optional[str] = "jp"

    # Configuration parameters
    system_prompt: Optional[str] = None
    mcp_settings: Optional[List[Dict]] = field(default_factory=list)
    generate_cfg: Optional[Dict] = None
    enable_thinking: bool = False
    enable_self_knowledge: bool = False

    # Context parameters
    project_dir: Optional[str] = None
    user_identifier: Optional[str] = None
    session_id: Optional[str] = None
    dataset_ids: Optional[List[str]] = field(default_factory=list)
    trace_id: Optional[str] = None  # Request trace ID, obtained from the X-Request-ID header
    request_started_at: Optional[float] = None

    # Response control parameters
    stream: bool = False
    tool_response: bool = True
    preamble_text: Optional[str] = None
    messages: Optional[List] = field(default_factory=list)
    _origin_messages: Optional[List] = field(default_factory=list)

    logging_handler: Optional['LoggingCallbackHandler'] = None

    # Mem0 long-term memory configuration
    enable_memori: bool = False  # Keep the name for API compatibility; Mem0 is used internally
    memori_semantic_search_top_k: int = 20
    _mem0_context: Optional[str] = None  # Mem0 recalled memory context, used for passing data between middlewares

    # Custom shell environment variables
    shell_env: Optional[Dict[str, str]] = field(default_factory=dict)

    # Checkpointer session history
    _session_history: Optional[List] = field(default_factory=list)  # Historical chat records loaded from the checkpointer

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary for passing into functions that require **kwargs.

        Only the public request/config fields are exported; private fields
        (leading underscore), ``logging_handler`` and ``request_started_at``
        are intentionally left out. Values are returned by reference.
        """
        return {
            'bot_id': self.bot_id,
            'api_key': self.api_key,
            'model_name': self.model_name,
            'model_server': self.model_server,
            'language': self.language,
            'system_prompt': self.system_prompt,
            'mcp_settings': self.mcp_settings,
            'generate_cfg': self.generate_cfg,
            'enable_thinking': self.enable_thinking,
            'enable_self_knowledge': self.enable_self_knowledge,
            'project_dir': self.project_dir,
            'user_identifier': self.user_identifier,
            'session_id': self.session_id,
            'dataset_ids': self.dataset_ids,
            'stream': self.stream,
            'tool_response': self.tool_response,
            'preamble_text': self.preamble_text,
            'messages': self.messages,
            'enable_memori': self.enable_memori,
            'memori_semantic_search_top_k': self.memori_semantic_search_top_k,
            'trace_id': self.trace_id,
            'shell_env': self.shell_env,
        }

    @staticmethod
    def _redact_image_data(messages):
        """Return a copy of messages with image base64/data URL payloads truncated.

        Image content blocks can carry huge base64 strings; replace them with a short
        placeholder so logs stay readable. Anything that is not a list is returned
        unchanged, and the input messages are never mutated.
        """
        if not isinstance(messages, list):
            return messages

        def redact_value(value):
            # Only inline ``data:`` URLs are shortened; ordinary URLs stay intact.
            if isinstance(value, str) and value.startswith('data:'):
                return value[:32] + '...[truncated]'
            return value

        def redact_block(block):
            if not (isinstance(block, dict) and block.get('type') in ('image', 'image_url')):
                return block
            new_block = block.copy()
            if 'base64' in new_block:
                new_block['base64'] = '[truncated]'
            if 'image_url' in new_block:
                image_url = new_block['image_url']
                if isinstance(image_url, dict):
                    new_block['image_url'] = {**image_url, 'url': redact_value(image_url.get('url'))}
                else:
                    new_block['image_url'] = redact_value(image_url)
            if 'url' in new_block:
                new_block['url'] = redact_value(new_block['url'])
            return new_block

        redacted_messages = []
        for message in messages:
            content = message.get('content') if isinstance(message, dict) else None
            if not isinstance(content, list):
                # Non-dict messages and plain-string contents carry no image blocks.
                redacted_messages.append(message)
                continue
            redacted_messages.append({**message, 'content': [redact_block(block) for block in content]})
        return redacted_messages

    @staticmethod
    def _mask_api_key(api_key):
        """Return a log-safe rendering of ``api_key``.

        Non-string and empty values are returned unchanged. Long keys keep their
        first 8 and last 6 characters around ``***``; keys too short for that to
        hide at least half of the secret are replaced by ``***`` entirely.
        """
        if not isinstance(api_key, str) or not api_key:
            return api_key
        prefix_len, suffix_len = 8, 6
        if len(api_key) <= 2 * (prefix_len + suffix_len):
            return '***'
        return api_key[:prefix_len] + '***' + api_key[-suffix_len:]

    def _safe_log_payload(self) -> str:
        """Serialize the configuration to JSON with secrets and image payloads redacted."""
        safe_dict = self.to_dict()
        safe_dict['api_key'] = self._mask_api_key(safe_dict['api_key'])
        safe_dict['messages'] = self._redact_image_data(safe_dict['messages'])
        # default=str: a non-JSON-serialisable message object must not make logging raise.
        return json.dumps(safe_dict, ensure_ascii=False, default=str)

    def safe_print(self):
        """Safely log the configuration without printing sensitive information."""
        logger.info(f"config={self._safe_log_payload()}")

    @staticmethod
    def _current_trace_id() -> Optional[str]:
        """Read trace_id from the global logging context; None when it is not available."""
        # Delay imports to avoid circular dependencies.
        from utils.log_util.context import g

        try:
            return getattr(g, 'trace_id', None)
        except LookupError:
            return None

    async def _load_checkpoint_messages(self) -> None:
        """Prepare checkpoint messages for this config when a session_id is set.

        Best effort: any failure is logged as a warning and otherwise ignored.
        """
        if not self.session_id:
            return

        # Delay imports to avoid circular dependencies.
        from .checkpoint_utils import prepare_checkpoint_message
        from .checkpoint_manager import get_checkpointer_manager

        try:
            checkpointer = get_checkpointer_manager().checkpointer
            if checkpointer:
                await prepare_checkpoint_message(self, checkpointer)
        except Exception as e:
            logger.warning(f"Failed to load checkpointer: {e}")

    @classmethod
    async def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None, messages: Optional[List] = None):
        """Create configuration from a v1 request.

        All agent settings are taken from ``request``; ``api_key``,
        ``project_dir``, ``generate_cfg`` and ``messages`` are supplied by the
        caller. Returns the new config after loading checkpoint messages and
        logging a redacted copy.
        """
        # Delay imports to avoid circular dependencies.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text
        from utils.settings import MEM0_SEMANTIC_SEARCH_TOP_K

        if messages is None:
            messages = []

        # Get trace_id from the global context.
        trace_id = cls._current_trace_id()

        preamble_text, system_prompt = get_preamble_text(request.language, request.system_prompt)

        config = cls(
            bot_id=request.bot_id,
            api_key=api_key,
            model_name=request.model,
            model_server=request.model_server,
            language=request.language,
            system_prompt=system_prompt,
            mcp_settings=request.mcp_settings,
            user_identifier=request.user_identifier,
            session_id=request.session_id,
            enable_thinking=request.enable_thinking,
            enable_self_knowledge=request.enable_self_knowledge,
            project_dir=project_dir,
            stream=request.stream,
            tool_response=request.tool_response,
            generate_cfg=generate_cfg,
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            _origin_messages=messages,
            preamble_text=preamble_text,
            dataset_ids=request.dataset_ids,
            enable_memori=request.enable_memory,
            memori_semantic_search_top_k=getattr(request, 'memori_semantic_search_top_k', None) or MEM0_SEMANTIC_SEARCH_TOP_K,
            trace_id=trace_id,
            shell_env=getattr(request, 'shell_env', None) or {},
        )

        # Prepare checkpoint messages as early as possible when creating config.
        await config._load_checkpoint_messages()

        config.safe_print()
        return config

    @classmethod
    async def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None, generate_cfg: Optional[Dict] = None, model_name: Optional[str] = None, model_server: Optional[str] = None, api_key: Optional[str] = None):
        """Create configuration from a v2 request.

        Agent settings come from ``bot_config``; explicit ``model_name``,
        ``model_server`` and ``api_key`` arguments take precedence over the
        corresponding ``bot_config`` entries, and ``request.language`` takes
        precedence over the configured language.
        """
        # Delay imports to avoid circular dependencies.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text
        from utils.settings import MEM0_SEMANTIC_SEARCH_TOP_K

        if messages is None:
            messages = []

        # Get trace_id from the global context.
        trace_id = cls._current_trace_id()

        language = request.language or bot_config.get("language", "zh")
        preamble_text, system_prompt = get_preamble_text(language, bot_config.get("system_prompt"))

        enable_thinking = bot_config.get("enable_thinking", False)
        enable_memori = bot_config.get("enable_memory", False)
        enable_self_knowledge = bot_config.get("enable_self_knowledge", False)

        config = cls(
            bot_id=request.bot_id,
            api_key=api_key or bot_config.get("api_key"),
            model_name=model_name or bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
            model_server=model_server or bot_config.get("model_server", ""),
            language=language,
            system_prompt=system_prompt,
            mcp_settings=bot_config.get("mcp_settings", []),
            user_identifier=request.user_identifier,
            session_id=request.session_id,
            enable_thinking=enable_thinking,
            enable_self_knowledge=enable_self_knowledge,
            project_dir=project_dir,
            stream=request.stream,
            tool_response=request.tool_response,
            generate_cfg=generate_cfg or {},  # The v2 API also supports passing extra generate_cfg.
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            _origin_messages=messages,
            preamble_text=preamble_text,
            dataset_ids=bot_config.get("dataset_ids", []),  # Get dataset_ids from backend configuration.
            enable_memori=enable_memori,
            memori_semantic_search_top_k=bot_config.get("memori_semantic_search_top_k", MEM0_SEMANTIC_SEARCH_TOP_K),
            trace_id=trace_id,
            shell_env=bot_config.get("shell_env") or {},
        )

        # Prepare checkpoint messages as early as possible when creating config.
        await config._load_checkpoint_messages()

        config.safe_print()
        return config

    @classmethod
    async def from_v3_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None, language: Optional[str] = None):
        """Create configuration from a v3 request - all settings are read from the database config.

        Only ``bot_id`` (as a fallback), ``session_id`` and ``stream`` are taken
        from ``request``; everything else comes from ``bot_config``.
        """
        # Delay imports to avoid circular dependencies.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text
        from utils.settings import MEM0_SEMANTIC_SEARCH_TOP_K

        if messages is None:
            messages = []

        # Get trace_id from the global context.
        trace_id = cls._current_trace_id()

        # Fall back to the language stored in the database config when none was passed.
        if language is None:
            language = bot_config.get("language", "zh")

        # Resolve system_prompt and preamble.
        system_prompt_from_db = bot_config.get("system_prompt", "")
        preamble_text, system_prompt = get_preamble_text(language, system_prompt_from_db)

        # Remaining parameters come from the database config.
        enable_thinking = bot_config.get("enable_thinking", False)
        enable_memori = bot_config.get("enable_memori", False)

        # General-purpose agents are isolated per user: project_dir_key overrides bot_id
        # so that the workspace / system_prompt path / skills directory all resolve under
        # projects/robot/{bot_id}-{user}.
        effective_bot_id = bot_config.get("project_dir_key") or request.bot_id

        config = cls(
            bot_id=effective_bot_id,
            api_key=bot_config.get("api_key", ""),
            model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
            model_server=bot_config.get("model_server", ""),
            language=language,
            system_prompt=system_prompt,
            mcp_settings=bot_config.get("mcp_settings", []),
            user_identifier=bot_config.get("user_identifier", ""),
            session_id=request.session_id,
            enable_thinking=enable_thinking,
            project_dir=project_dir,
            stream=request.stream,
            tool_response=bot_config.get("tool_response", True),
            generate_cfg={},  # The v3 API does not pass extra generate_cfg.
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            _origin_messages=messages,
            preamble_text=preamble_text,
            dataset_ids=bot_config.get("dataset_ids", []),
            enable_memori=enable_memori,
            memori_semantic_search_top_k=bot_config.get("memori_semantic_search_top_k", MEM0_SEMANTIC_SEARCH_TOP_K),
            trace_id=trace_id,
            shell_env=bot_config.get("shell_env") or {},
        )

        # Prepare checkpoint messages as early as possible when creating config.
        await config._load_checkpoint_messages()

        config.safe_print()
        return config

    def invoke_config(self):
        """Return the configuration dictionary required by LangChain.

        The result may contain ``callbacks`` (logging handler and, when Langfuse is
        enabled, a Langfuse handler), ``metadata`` (Langfuse session/user/bot/model
        tags) and ``configurable`` (``thread_id`` set to the session id). Keys whose
        value would be empty are omitted.
        """
        config = {}
        callbacks = []
        langfuse_metadata = {}

        if self.logging_handler:
            callbacks.append(self.logging_handler)

        from utils.settings import LANGFUSE_ENABLED
        if LANGFUSE_ENABLED:
            from langfuse.langchain import CallbackHandler as LangfuseCallbackHandler
            trace_context = {"trace_id": self.trace_id} if self.trace_id else None
            callbacks.append(LangfuseCallbackHandler(trace_context=trace_context))

            if self.session_id:
                langfuse_metadata["langfuse_session_id"] = self.session_id
            if self.user_identifier:
                langfuse_metadata["langfuse_user_id"] = self.user_identifier
            if self.bot_id:
                langfuse_metadata["bot_id"] = self.bot_id
            if self.model_name:
                langfuse_metadata["model"] = self.model_name

        if callbacks:
            config["callbacks"] = callbacks
        if langfuse_metadata:
            config["metadata"] = langfuse_metadata
        if self.session_id:
            config["configurable"] = {"thread_id": self.session_id}
        return config

    def get_unique_cache_id(self) -> Optional[str]:
        """
        Generate a unique cache key.

        The key is derived from bot_id, system_prompt, mcp_settings, model_name,
        language, generate_cfg, enable_thinking, enable_self_knowledge,
        user_identifier, session_id, dataset_ids and project_dir.

        Note: a key is produced even when session_id is None; the Optional return
        annotation is kept only for interface compatibility.

        Returns:
            str: "agent_cache_" followed by the first 16 hex characters of the
            SHA256 hash of the JSON-serialized components.

        Raises:
            TypeError: If a component is not JSON-serializable.
        """
        # Create the data used for generating the hash.
        cache_components = {
            'bot_id': self.bot_id,
            'system_prompt': self.system_prompt,
            'mcp_settings': self.mcp_settings,
            'model_name': self.model_name,
            'language': self.language,
            'generate_cfg': self.generate_cfg,
            'enable_thinking': self.enable_thinking,
            'enable_self_knowledge': self.enable_self_knowledge,
            'user_identifier': self.user_identifier,
            'session_id': self.session_id,
            'dataset_ids': self.dataset_ids,  # Include dataset_ids in cache key generation.
            'project_dir': self.project_dir,  # project_dir should also be included because dataset_ids affect it.
        }

        # sort_keys makes the serialization independent of dict insertion order.
        cache_string = json.dumps(cache_components, sort_keys=True)

        # Generate a SHA256 hash to use as the cache key.
        cache_hash = hashlib.sha256(cache_string.encode('utf-8')).hexdigest()

        # Return a prefixed cache key to make debugging easier.
        return f"agent_cache_{cache_hash[:16]}"  # Use the first 16 characters of the hash.
|