343 lines
14 KiB
Python
343 lines
14 KiB
Python
"""Agent configuration class for managing all agent-related parameters."""
|
|
|
|
from typing import Optional, List, Dict, Any, TYPE_CHECKING
|
|
from dataclasses import dataclass, field
|
|
import logging
|
|
import json
|
|
import hashlib
|
|
|
|
# Module-level logger; every record emitted by this module goes to the logger named 'app'.
logger = logging.getLogger('app')
|
|
|
|
@dataclass
class AgentConfig:
    """Agent configuration containing all parameters required to create and manage an agent.

    Instances are normally built through :meth:`from_v1_request` or
    :meth:`from_v2_request`; the plain dataclass constructor remains available
    and keeps the original field order.
    """

    # Basic parameters
    bot_id: str
    api_key: Optional[str] = None
    model_name: str = "qwen3-next"
    model_server: Optional[str] = None
    language: Optional[str] = "jp"

    # Configuration parameters
    system_prompt: Optional[str] = None
    mcp_settings: Optional[List[Dict]] = field(default_factory=list)
    generate_cfg: Optional[Dict] = None
    enable_thinking: bool = False
    enable_self_knowledge: bool = False

    # Context parameters
    project_dir: Optional[str] = None
    user_identifier: Optional[str] = None
    session_id: Optional[str] = None
    dataset_ids: Optional[List[str]] = field(default_factory=list)
    trace_id: Optional[str] = None  # Request trace ID, obtained from the X-Request-ID header
    request_started_at: Optional[float] = None

    # Response control parameters
    stream: bool = False
    tool_response: bool = True
    preamble_text: Optional[str] = None
    messages: Optional[List] = field(default_factory=list)
    _origin_messages: Optional[List] = field(default_factory=list)

    logging_handler: Optional['LoggingCallbackHandler'] = None

    # Mem0 long-term memory configuration
    enable_memori: bool = False  # Keep the name for API compatibility; Mem0 is used internally
    memori_semantic_search_top_k: int = 20
    _mem0_context: Optional[str] = None  # Mem0 recalled memory context, used for passing data between middlewares

    # Custom shell environment variables
    shell_env: Optional[Dict[str, str]] = field(default_factory=dict)

    # Checkpointer session history
    _session_history: Optional[List] = field(default_factory=list)  # Historical chat records loaded from the checkpointer

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary for passing into functions that require **kwargs.

        Only the public request/configuration fields are exported; the
        underscore-prefixed fields, ``logging_handler`` and
        ``request_started_at`` are left out. Values are returned by reference,
        not copied.
        """
        return {
            'bot_id': self.bot_id,
            'api_key': self.api_key,
            'model_name': self.model_name,
            'model_server': self.model_server,
            'language': self.language,
            'system_prompt': self.system_prompt,
            'mcp_settings': self.mcp_settings,
            'generate_cfg': self.generate_cfg,
            'enable_thinking': self.enable_thinking,
            'enable_self_knowledge': self.enable_self_knowledge,
            'project_dir': self.project_dir,
            'user_identifier': self.user_identifier,
            'session_id': self.session_id,
            'dataset_ids': self.dataset_ids,
            'stream': self.stream,
            'tool_response': self.tool_response,
            'preamble_text': self.preamble_text,
            'messages': self.messages,
            'enable_memori': self.enable_memori,
            'memori_semantic_search_top_k': self.memori_semantic_search_top_k,
            'trace_id': self.trace_id,
            'shell_env': self.shell_env,
        }

    @staticmethod
    def _redact_image_data(messages):
        """Return a copy of messages with image base64/data URL payloads truncated.

        Image content blocks can carry huge base64 strings; replace them with a
        short placeholder so logs stay readable. Only dict messages whose
        ``content`` is a list are rewritten, and only blocks whose ``type`` is
        ``'image'`` or ``'image_url'``; everything else is passed through as-is.
        The input structure is never mutated.

        Args:
            messages: Chat messages; anything that is not a list is returned unchanged.

        Returns:
            A new list (or the original non-list value).
        """
        if not isinstance(messages, list):
            return messages

        def redact_value(value):
            # Only data URLs embed the payload; ordinary http(s) URLs stay intact.
            if isinstance(value, str) and value.startswith('data:'):
                return value[:32] + '...[truncated]'
            return value

        def redact_block(block):
            if not (isinstance(block, dict) and block.get('type') in ('image', 'image_url')):
                return block
            new_block = block.copy()
            if 'base64' in new_block:
                new_block['base64'] = '[truncated]'
            # Inline payloads may also sit under `data`, either directly on the
            # block or inside a nested `source` mapping.
            if isinstance(new_block.get('data'), str):
                new_block['data'] = '[truncated]'
            source = new_block.get('source')
            if isinstance(source, dict) and isinstance(source.get('data'), str):
                new_block['source'] = {**source, 'data': '[truncated]'}
            if 'image_url' in new_block:
                image_url = new_block['image_url']
                if isinstance(image_url, dict):
                    new_block['image_url'] = {**image_url, 'url': redact_value(image_url.get('url'))}
                else:
                    new_block['image_url'] = redact_value(image_url)
            if 'url' in new_block:
                new_block['url'] = redact_value(new_block['url'])
            return new_block

        redacted_messages = []
        for message in messages:
            content = message.get('content') if isinstance(message, dict) else None
            if not isinstance(content, list):
                redacted_messages.append(message)
                continue
            redacted_messages.append({**message, 'content': [redact_block(block) for block in content]})
        return redacted_messages

    @staticmethod
    def _mask_secret(secret):
        """Mask a credential for logging.

        Non-string and empty values are returned unchanged. Strings shorter
        than 24 characters are hidden completely, because showing the first 8
        and last 6 characters would disclose most (or all) of them.
        """
        if not isinstance(secret, str) or not secret:
            return secret
        if len(secret) < 24:
            return '***'
        return secret[:8] + '***' + secret[-6:]

    def _build_safe_dict(self) -> Dict[str, Any]:
        """Return ``to_dict()`` with credentials and bulky image payloads redacted."""
        safe_dict = self.to_dict()  # to_dict() builds a fresh dict, so editing it is safe
        # Mask every string key, not only those following the 'sk-' convention.
        safe_dict['api_key'] = self._mask_secret(safe_dict.get('api_key'))
        shell_env = safe_dict.get('shell_env')
        if isinstance(shell_env, dict):
            # Environment variables commonly hold tokens: log names only.
            safe_dict['shell_env'] = {name: '***' for name in shell_env}
        safe_dict['messages'] = self._redact_image_data(safe_dict.get('messages'))
        return safe_dict

    def safe_print(self):
        """Safely log the configuration without printing sensitive information.

        Logging is best-effort: values that are not JSON-serializable are
        stringified, and a serialization failure is reported as a warning
        instead of propagating to the caller.
        """
        try:
            payload = json.dumps(self._build_safe_dict(), ensure_ascii=False, default=str)
        except (TypeError, ValueError) as exc:
            logger.warning("Failed to serialize config for logging: %s", exc)
            return
        logger.info("config=%s", payload)

    @staticmethod
    def _current_trace_id() -> Optional[str]:
        """Read ``trace_id`` from the global logging context, or None when unavailable."""
        # Delay imports to avoid circular dependencies.
        from utils.log_util.context import g

        try:
            return getattr(g, 'trace_id', None)
        except LookupError:
            return None

    async def _load_checkpoint_messages(self) -> None:
        """Prepare checkpoint messages for this config when a session is present.

        Best-effort: any failure is logged as a warning and otherwise ignored.
        """
        # Delay imports to avoid circular dependencies.
        from .checkpoint_utils import prepare_checkpoint_message
        from .checkpoint_manager import get_checkpointer_manager

        if not self.session_id:
            return
        try:
            manager = get_checkpointer_manager()
            checkpointer = manager.checkpointer
            if checkpointer:
                await prepare_checkpoint_message(self, checkpointer)
        except Exception as e:
            logger.warning("Failed to load checkpointer: %s", e)

    @classmethod
    async def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None, messages: Optional[List] = None):
        """Create configuration from a v1 request.

        Args:
            request: v1 request object; the attributes read here are ``bot_id``,
                ``model``, ``model_server``, ``language``, ``system_prompt``,
                ``mcp_settings``, ``user_identifier``, ``session_id``,
                ``enable_thinking``, ``enable_self_knowledge``, ``stream``,
                ``tool_response``, ``dataset_ids``, ``enable_memory`` and,
                optionally, ``memori_semantic_search_top_k`` and ``shell_env``.
            api_key: Credential stored on the config.
            project_dir: Optional project directory.
            generate_cfg: Optional generation settings, stored as given.
            messages: Chat messages; ``None`` is treated as an empty list.

        Returns:
            The populated config, after checkpoint preparation and logging.
        """
        # Delay imports to avoid circular dependencies.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text
        from utils.settings import MEM0_SEMANTIC_SEARCH_TOP_K

        if messages is None:
            messages = []

        trace_id = cls._current_trace_id()
        preamble_text, system_prompt = get_preamble_text(request.language, request.system_prompt)

        config = cls(
            bot_id=request.bot_id,
            api_key=api_key,
            model_name=request.model,
            model_server=request.model_server,
            language=request.language,
            system_prompt=system_prompt,
            mcp_settings=request.mcp_settings,
            user_identifier=request.user_identifier,
            session_id=request.session_id,
            enable_thinking=request.enable_thinking,
            enable_self_knowledge=request.enable_self_knowledge,
            project_dir=project_dir,
            stream=request.stream,
            tool_response=request.tool_response,
            generate_cfg=generate_cfg,
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            # Shallow copy so in-place edits of `messages` cannot alter the original snapshot.
            _origin_messages=list(messages),
            preamble_text=preamble_text,
            dataset_ids=request.dataset_ids,
            enable_memori=request.enable_memory,
            memori_semantic_search_top_k=getattr(request, 'memori_semantic_search_top_k', None) or MEM0_SEMANTIC_SEARCH_TOP_K,
            trace_id=trace_id,
            shell_env=getattr(request, 'shell_env', None) or {},
        )

        # Prepare checkpoint messages as early as possible when creating config.
        await config._load_checkpoint_messages()

        config.safe_print()
        return config

    @classmethod
    async def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None, generate_cfg: Optional[Dict] = None, model_name: Optional[str] = None, model_server: Optional[str] = None, api_key: Optional[str] = None):
        """Create configuration from a v2 request.

        Most settings come from ``bot_config``; explicit ``model_name``,
        ``model_server`` and ``api_key`` arguments take precedence over it.

        Args:
            request: v2 request object; the attributes read here are ``bot_id``,
                ``language``, ``user_identifier``, ``session_id``, ``stream``
                and ``tool_response``.
            bot_config: Backend bot configuration mapping.
            project_dir: Optional project directory.
            messages: Chat messages; ``None`` is treated as an empty list.
            generate_cfg: Optional generation settings; falsy values become ``{}``.
            model_name: Optional override for ``bot_config["model"]``.
            model_server: Optional override for ``bot_config["model_server"]``.
            api_key: Optional override for ``bot_config["api_key"]``.

        Returns:
            The populated config, after checkpoint preparation and logging.
        """
        # Delay imports to avoid circular dependencies.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text
        from utils.settings import MEM0_SEMANTIC_SEARCH_TOP_K

        if messages is None:
            messages = []

        trace_id = cls._current_trace_id()
        language = request.language or bot_config.get("language", "zh")
        preamble_text, system_prompt = get_preamble_text(language, bot_config.get("system_prompt"))

        enable_thinking = bot_config.get("enable_thinking", False)
        enable_memori = bot_config.get("enable_memory", False)
        enable_self_knowledge = bot_config.get("enable_self_knowledge", False)

        config = cls(
            bot_id=request.bot_id,
            api_key=api_key or bot_config.get("api_key"),
            model_name=model_name or bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
            model_server=model_server or bot_config.get("model_server", ""),
            language=language,
            system_prompt=system_prompt,
            mcp_settings=bot_config.get("mcp_settings", []),
            user_identifier=request.user_identifier,
            session_id=request.session_id,
            enable_thinking=enable_thinking,
            enable_self_knowledge=enable_self_knowledge,
            project_dir=project_dir,
            stream=request.stream,
            tool_response=request.tool_response,
            generate_cfg=generate_cfg or {},  # The v2 API also supports passing extra generate_cfg.
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            # Shallow copy so in-place edits of `messages` cannot alter the original snapshot.
            _origin_messages=list(messages),
            preamble_text=preamble_text,
            dataset_ids=bot_config.get("dataset_ids", []),  # Get dataset_ids from backend configuration.
            enable_memori=enable_memori,
            # `or` (as in the v1 path) also covers an explicit null in bot_config.
            memori_semantic_search_top_k=bot_config.get("memori_semantic_search_top_k") or MEM0_SEMANTIC_SEARCH_TOP_K,
            trace_id=trace_id,
            shell_env=bot_config.get("shell_env") or {},
        )

        # Prepare checkpoint messages as early as possible when creating config.
        await config._load_checkpoint_messages()

        config.safe_print()
        return config

    def invoke_config(self):
        """Return the configuration dictionary required by LangChain.

        The result may contain ``callbacks`` (the logging handler and, when
        Langfuse is enabled, a Langfuse handler), ``metadata`` (Langfuse
        session/user identifiers plus bot id and model, only when Langfuse is
        enabled) and ``configurable`` (``thread_id`` set to the session id).
        Keys whose value would be empty are omitted.
        """
        from utils.settings import LANGFUSE_ENABLED

        callbacks = []
        if self.logging_handler:
            callbacks.append(self.logging_handler)

        langfuse_metadata = {}
        if LANGFUSE_ENABLED:
            # Imported lazily so langfuse is only required when the feature is on.
            from langfuse.langchain import CallbackHandler as LangfuseCallbackHandler

            trace_context = {"trace_id": self.trace_id} if self.trace_id else None
            callbacks.append(LangfuseCallbackHandler(trace_context=trace_context))

            candidates = (
                ("langfuse_session_id", self.session_id),
                ("langfuse_user_id", self.user_identifier),
                ("bot_id", self.bot_id),
                ("model", self.model_name),
            )
            langfuse_metadata = {key: value for key, value in candidates if value}

        config = {}
        if callbacks:
            config["callbacks"] = callbacks
        if langfuse_metadata:
            config["metadata"] = langfuse_metadata
        if self.session_id:
            config["configurable"] = {"thread_id": self.session_id}
        return config

    def get_unique_cache_id(self) -> Optional[str]:
        """Generate a unique cache key.

        The key is derived from bot_id, system_prompt, mcp_settings,
        model_name, language, generate_cfg, enable_thinking,
        enable_self_knowledge, user_identifier, session_id, dataset_ids and
        project_dir.

        NOTE(review): this method always returns a key, even when session_id
        is missing. An earlier description said None would be returned in that
        case, but the code has never done so -- confirm with callers whether
        such a guard is wanted before adding it.

        Returns:
            str: ``"agent_cache_"`` followed by the first 16 hex characters of
            the SHA256 digest of the JSON-encoded components.
        """
        # Create the data used for generating the hash.
        cache_components = {
            'bot_id': self.bot_id,
            'system_prompt': self.system_prompt,
            'mcp_settings': self.mcp_settings,
            'model_name': self.model_name,
            'language': self.language,
            'generate_cfg': self.generate_cfg,
            'enable_thinking': self.enable_thinking,
            'enable_self_knowledge': self.enable_self_knowledge,
            'user_identifier': self.user_identifier,
            'session_id': self.session_id,
            'dataset_ids': self.dataset_ids,  # Include dataset_ids in cache key generation.
            'project_dir': self.project_dir,  # project_dir should also be included because dataset_ids affect it.
        }

        # sort_keys keeps the encoding independent of dict insertion order.
        # default=str only kicks in for values json cannot encode, so keys for
        # plain JSON data are unchanged while exotic values no longer raise.
        cache_string = json.dumps(cache_components, sort_keys=True, default=str)

        # Generate a SHA256 hash to use as the cache key.
        cache_hash = hashlib.sha256(cache_string.encode('utf-8')).hexdigest()

        # Return a prefixed cache key to make debugging easier.
        return f"agent_cache_{cache_hash[:16]}"  # Use the first 16 characters of the hash.