qwen_agent/agent/agent_config.py
2026-06-22 16:39:04 +08:00

427 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Agent configuration class for managing all agent-related parameters."""
from typing import Optional, List, Dict, Any, TYPE_CHECKING
from dataclasses import dataclass, field
import logging
import json
import hashlib
logger = logging.getLogger('app')


@dataclass
class AgentConfig:
    """Agent configuration containing all parameters required to create and manage an agent.

    Instances are usually built by one of the async factories
    (``from_v1_request``, ``from_v2_request``, ``from_v3_request``). Each
    factory also calls ``prepare_checkpoint_message`` for the session when a
    checkpointer is available, on a best-effort basis. It then logs a redacted
    copy of the configuration.
    """

    # Basic parameters
    bot_id: str
    api_key: Optional[str] = None
    model_name: str = "qwen3-next"
    model_server: Optional[str] = None
    language: Optional[str] = "jp"
    # Configuration parameters
    system_prompt: Optional[str] = None
    mcp_settings: Optional[List[Dict]] = field(default_factory=list)
    generate_cfg: Optional[Dict] = None
    enable_thinking: bool = False
    enable_self_knowledge: bool = False
    # Context parameters
    project_dir: Optional[str] = None
    user_identifier: Optional[str] = None
    session_id: Optional[str] = None
    dataset_ids: Optional[List[str]] = field(default_factory=list)
    trace_id: Optional[str] = None  # Request trace ID, obtained from the X-Request-ID header
    request_started_at: Optional[float] = None
    # Response control parameters
    stream: bool = False
    tool_response: bool = True
    preamble_text: Optional[str] = None
    messages: Optional[List] = field(default_factory=list)
    _origin_messages: Optional[List] = field(default_factory=list)
    logging_handler: Optional['LoggingCallbackHandler'] = None
    # Mem0 long-term memory configuration
    enable_memori: bool = False  # Keep the name for API compatibility; Mem0 is used internally
    memori_semantic_search_top_k: int = 20
    _mem0_context: Optional[str] = None  # Mem0 recalled memory context, used for passing data between middlewares
    # Custom shell environment variables
    shell_env: Optional[Dict[str, str]] = field(default_factory=dict)
    # Checkpointer session history
    _session_history: Optional[List] = field(default_factory=list)  # Historical chat records loaded from the checkpointer

    def to_dict(self) -> Dict[str, Any]:
        """Return the public settings as a new dict, e.g. for ``**kwargs`` forwarding.

        ``_origin_messages``, ``_mem0_context``, ``_session_history``,
        ``logging_handler`` and ``request_started_at`` are not included.
        Values are not copied, so mutable values (lists, dicts) are shared
        with this instance.
        """
        return {
            'bot_id': self.bot_id,
            'api_key': self.api_key,
            'model_name': self.model_name,
            'model_server': self.model_server,
            'language': self.language,
            'system_prompt': self.system_prompt,
            'mcp_settings': self.mcp_settings,
            'generate_cfg': self.generate_cfg,
            'enable_thinking': self.enable_thinking,
            'enable_self_knowledge': self.enable_self_knowledge,
            'project_dir': self.project_dir,
            'user_identifier': self.user_identifier,
            'session_id': self.session_id,
            'dataset_ids': self.dataset_ids,
            'stream': self.stream,
            'tool_response': self.tool_response,
            'preamble_text': self.preamble_text,
            'messages': self.messages,
            'enable_memori': self.enable_memori,
            'memori_semantic_search_top_k': self.memori_semantic_search_top_k,
            'trace_id': self.trace_id,
            'shell_env': self.shell_env,
        }

    @staticmethod
    def _mask_secret(value):
        """Return a log-safe form of a secret string.

        Strings of 32+ characters keep their first 8 and last 6 characters so
        different keys remain distinguishable in logs. Shorter non-empty
        strings are replaced entirely by ``'***'``, because a prefix plus
        suffix would reveal most of them. Non-strings and empty strings are
        returned unchanged.
        """
        if not isinstance(value, str) or not value:
            return value
        if len(value) < 32:
            return '***'
        return value[:8] + '***' + value[-6:]

    @staticmethod
    def _redact_image_data(messages):
        """Return a copy of ``messages`` with inline image payloads truncated.

        Image content blocks (``type`` of ``'image'`` or ``'image_url'``) can
        carry huge base64 strings. The following payloads are shortened so
        logs stay readable:

        - ``base64`` values.
        - String ``data`` values.
        - ``data:`` URLs (in ``url`` or ``image_url``).

        Ordinary URLs are kept. Messages that are not dicts, or whose
        ``content`` is not a list, are passed through unchanged. The input is
        never mutated: changed messages and blocks are shallow copies. A
        non-list ``messages`` argument is returned as-is.
        """
        if not isinstance(messages, list):
            return messages

        def redact_value(value):
            # Only data: URLs embed the payload itself; keep a short prefix (scheme + mime type).
            if isinstance(value, str) and value.startswith('data:'):
                return value[:32] + '...[truncated]'
            return value

        redacted_messages = []
        for message in messages:
            content = message.get('content') if isinstance(message, dict) else None
            if not isinstance(content, list):
                redacted_messages.append(message)
                continue
            redacted_content = []
            for block in content:
                if not (isinstance(block, dict) and block.get('type') in ('image', 'image_url')):
                    redacted_content.append(block)
                    continue
                new_block = block.copy()
                if 'base64' in new_block:
                    new_block['base64'] = '[truncated]'
                if isinstance(new_block.get('data'), str):
                    new_block['data'] = '[truncated]'
                if 'image_url' in new_block:
                    image_url = new_block['image_url']
                    if isinstance(image_url, dict):
                        new_block['image_url'] = {**image_url, 'url': redact_value(image_url.get('url'))}
                    else:
                        new_block['image_url'] = redact_value(image_url)
                if 'url' in new_block:
                    new_block['url'] = redact_value(new_block['url'])
                redacted_content.append(new_block)
            redacted_messages.append({**message, 'content': redacted_content})
        return redacted_messages

    def safe_print(self):
        """Log the configuration at INFO level with sensitive values redacted.

        - ``api_key`` is masked with ``_mask_secret``, whatever its prefix.
        - ``shell_env`` keeps variable names but every value becomes ``'***'``.
        - Inline image payloads in ``messages`` are truncated.

        Values that are not JSON-serializable (e.g. message objects) are
        rendered with ``str()`` instead of raising.
        """
        safe_dict = self.to_dict()
        safe_dict['api_key'] = self._mask_secret(safe_dict.get('api_key'))
        shell_env = safe_dict.get('shell_env')
        if isinstance(shell_env, dict):
            safe_dict['shell_env'] = {name: '***' for name in shell_env}
        safe_dict['messages'] = self._redact_image_data(safe_dict.get('messages'))
        logger.info(f"config={json.dumps(safe_dict, ensure_ascii=False, default=str)}")

    @staticmethod
    def _current_trace_id() -> Optional[str]:
        """Return ``trace_id`` from the request-scoped global context ``g``, or None if unavailable."""
        # Delay imports to avoid circular dependencies.
        from utils.log_util.context import g
        try:
            return getattr(g, 'trace_id', None)
        except LookupError:
            # Raised when the context has not been set for the current task.
            return None

    async def _prepare_checkpoint_and_log(self) -> None:
        """Shared tail of the ``from_v*_request`` factories.

        When ``session_id`` is set and the checkpointer manager exposes a
        checkpointer, calls ``prepare_checkpoint_message`` on this config.
        Any exception from that step is logged as a warning and swallowed, so
        config creation still succeeds. Finally logs the redacted
        configuration via ``safe_print``.
        """
        # Delay imports to avoid circular dependencies.
        from .checkpoint_utils import prepare_checkpoint_message
        from .checkpoint_manager import get_checkpointer_manager

        # Prepare checkpoint messages as early as possible when creating config.
        if self.session_id:
            try:
                checkpointer = get_checkpointer_manager().checkpointer
                if checkpointer:
                    await prepare_checkpoint_message(self, checkpointer)
            except Exception as e:
                logger.warning(f"Failed to load checkpointer: {e}", exc_info=True)
        self.safe_print()

    @classmethod
    async def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None, messages: Optional[List] = None):
        """Create a configuration from a v1 request.

        Almost every setting is read directly from ``request``.
        ``memori_semantic_search_top_k`` and ``shell_env`` are optional
        attributes; they fall back to ``MEM0_SEMANTIC_SEARCH_TOP_K`` and
        ``{}`` when missing or falsy.
        """
        # Delay imports to avoid circular dependencies.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text
        from utils.settings import MEM0_SEMANTIC_SEARCH_TOP_K

        if messages is None:
            messages = []
        preamble_text, system_prompt = get_preamble_text(request.language, request.system_prompt)
        # NOTE(review): `messages` and `_origin_messages` receive the same list object,
        # so any in-place mutation of config.messages is visible in _origin_messages
        # too — confirm that is intended.
        config = cls(
            bot_id=request.bot_id,
            api_key=api_key,
            model_name=request.model,
            model_server=request.model_server,
            language=request.language,
            system_prompt=system_prompt,
            mcp_settings=request.mcp_settings,
            user_identifier=request.user_identifier,
            session_id=request.session_id,
            enable_thinking=request.enable_thinking,
            enable_self_knowledge=request.enable_self_knowledge,
            project_dir=project_dir,
            stream=request.stream,
            tool_response=request.tool_response,
            generate_cfg=generate_cfg,
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            _origin_messages=messages,
            preamble_text=preamble_text,
            dataset_ids=request.dataset_ids,
            enable_memori=request.enable_memory,
            memori_semantic_search_top_k=getattr(request, 'memori_semantic_search_top_k', None) or MEM0_SEMANTIC_SEARCH_TOP_K,
            trace_id=cls._current_trace_id(),
            shell_env=getattr(request, 'shell_env', None) or {},
        )
        await config._prepare_checkpoint_and_log()
        return config

    @classmethod
    async def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None, generate_cfg: Optional[Dict] = None, model_name: Optional[str] = None, model_server: Optional[str] = None, api_key: Optional[str] = None):
        """Create a configuration from a v2 request.

        Bot-level settings come from ``bot_config``. ``bot_id``,
        ``user_identifier``, ``session_id``, ``stream`` and ``tool_response``
        come from ``request``. ``request.language`` takes precedence over
        ``bot_config``'s language when truthy. The explicit ``model_name``,
        ``model_server`` and ``api_key`` arguments override ``bot_config``
        when truthy.
        """
        # Delay imports to avoid circular dependencies.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text
        from utils.settings import MEM0_SEMANTIC_SEARCH_TOP_K

        if messages is None:
            messages = []
        language = request.language or bot_config.get("language", "zh")
        preamble_text, system_prompt = get_preamble_text(language, bot_config.get("system_prompt"))
        config = cls(
            bot_id=request.bot_id,
            api_key=api_key or bot_config.get("api_key"),
            model_name=model_name or bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
            model_server=model_server or bot_config.get("model_server", ""),
            language=language,
            system_prompt=system_prompt,
            mcp_settings=bot_config.get("mcp_settings", []),
            user_identifier=request.user_identifier,
            session_id=request.session_id,
            enable_thinking=bot_config.get("enable_thinking", False),
            enable_self_knowledge=bot_config.get("enable_self_knowledge", False),
            project_dir=project_dir,
            stream=request.stream,
            tool_response=request.tool_response,
            generate_cfg=generate_cfg or {},  # The v2 API also supports passing extra generate_cfg.
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            _origin_messages=messages,
            preamble_text=preamble_text,
            dataset_ids=bot_config.get("dataset_ids", []),  # Get dataset_ids from backend configuration.
            # NOTE(review): v2 reads "enable_memory" from bot_config while v3 reads
            # "enable_memori" — confirm each key matches its config source.
            enable_memori=bot_config.get("enable_memory", False),
            # Same fallback as v1: a missing, null or 0 value uses the global default.
            memori_semantic_search_top_k=bot_config.get("memori_semantic_search_top_k") or MEM0_SEMANTIC_SEARCH_TOP_K,
            trace_id=cls._current_trace_id(),
            shell_env=bot_config.get("shell_env") or {},
        )
        await config._prepare_checkpoint_and_log()
        return config

    @classmethod
    async def from_v3_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None, language: Optional[str] = None):
        """Create a configuration from a v3 request, reading settings from the database ``bot_config``.

        Only ``session_id``, ``stream``, an optional ``model_id`` override and
        the fallback ``bot_id`` come from ``request``. ``language`` defaults to
        ``bot_config``'s language (``"zh"`` if absent) when not passed.
        """
        # Delay imports to avoid circular dependencies.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text
        from utils.settings import MEM0_SEMANTIC_SEARCH_TOP_K

        if messages is None:
            messages = []
        # Take the language from the database configuration if it was not passed in.
        if language is None:
            language = bot_config.get("language", "zh")
        # Derive system_prompt and preamble from the stored system prompt.
        preamble_text, system_prompt = get_preamble_text(language, bot_config.get("system_prompt", ""))
        # General-purpose agents are isolated per user: project_dir_key overrides bot_id,
        # so the workspace / system_prompt path / skills directory all land in
        # projects/robot/{bot_id}-{user}.
        effective_bot_id = bot_config.get("project_dir_key") or request.bot_id
        # The frontend may override model_id in the request body; it applies to this request only.
        effective_model_name = request.model_id or bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct")
        # NOTE(review): unlike v2, v3 never sets enable_self_knowledge (it stays False) —
        # confirm this is intended.
        config = cls(
            bot_id=effective_bot_id,
            api_key=bot_config.get("api_key", ""),
            model_name=effective_model_name,
            model_server=bot_config.get("model_server", ""),
            language=language,
            system_prompt=system_prompt,
            mcp_settings=bot_config.get("mcp_settings", []),
            user_identifier=bot_config.get("user_identifier", ""),
            session_id=request.session_id,
            enable_thinking=bot_config.get("enable_thinking", False),
            project_dir=project_dir,
            stream=request.stream,
            tool_response=bot_config.get("tool_response", True),
            generate_cfg={},  # The v3 API does not pass an extra generate_cfg.
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            _origin_messages=messages,
            preamble_text=preamble_text,
            dataset_ids=bot_config.get("dataset_ids", []),
            enable_memori=bot_config.get("enable_memori", False),
            # Same fallback as v1: a missing, null or 0 value uses the global default.
            memori_semantic_search_top_k=bot_config.get("memori_semantic_search_top_k") or MEM0_SEMANTIC_SEARCH_TOP_K,
            trace_id=cls._current_trace_id(),
            shell_env=bot_config.get("shell_env") or {},
        )
        await config._prepare_checkpoint_and_log()
        return config

    def invoke_config(self):
        """Return the ``config`` dict passed to LangChain/LangGraph runnables.

        Keys present only when they have content:

        - ``callbacks``: ``logging_handler`` (if set), then a Langfuse
          callback handler when ``LANGFUSE_ENABLED``.
        - ``metadata``: Langfuse session/user IDs, ``bot_id`` and ``model``
          (Langfuse only).
        - ``configurable``: ``{"thread_id": session_id}`` when a session ID
          exists. LangGraph checkpointers key saved state by ``thread_id``.
        """
        from utils.settings import LANGFUSE_ENABLED

        config = {}
        callbacks = []
        langfuse_metadata = {}
        if self.logging_handler:
            callbacks.append(self.logging_handler)
        if LANGFUSE_ENABLED:
            from langfuse.langchain import CallbackHandler as LangfuseCallbackHandler
            # Hand the request trace ID (when known) to Langfuse as its trace context.
            trace_context = {"trace_id": self.trace_id} if self.trace_id else None
            callbacks.append(LangfuseCallbackHandler(trace_context=trace_context))
            if self.session_id:
                langfuse_metadata["langfuse_session_id"] = self.session_id
            if self.user_identifier:
                langfuse_metadata["langfuse_user_id"] = self.user_identifier
            if self.bot_id:
                langfuse_metadata["bot_id"] = self.bot_id
            if self.model_name:
                langfuse_metadata["model"] = self.model_name
        if callbacks:
            config["callbacks"] = callbacks
        if langfuse_metadata:
            config["metadata"] = langfuse_metadata
        if self.session_id:
            config["configurable"] = {"thread_id": self.session_id}
        return config

    def get_unique_cache_id(self) -> Optional[str]:
        """Return a deterministic cache key for the agent built from this config.

        The key is ``"agent_cache_"`` followed by the first 16 hex characters
        of a SHA-256 digest. The digest covers a sorted-keys JSON dump of
        these fields: bot_id, system_prompt, mcp_settings, model_name,
        language, generate_cfg, enable_thinking, enable_self_knowledge,
        user_identifier, session_id, dataset_ids and project_dir.

        A key is always returned, even when ``session_id`` is missing (it is
        hashed as ``null``). The ``Optional`` return type is kept only for
        compatibility.

        Raises:
            TypeError: if any component is not JSON-serializable.
        """
        # NOTE(review): model_server, api_key, shell_env, tool_response and enable_memori
        # are not part of the key; if the cached agent depends on any of them, configs
        # that differ only there will share an agent — confirm.
        cache_components = {
            'bot_id': self.bot_id,
            'system_prompt': self.system_prompt,
            'mcp_settings': self.mcp_settings,
            'model_name': self.model_name,
            'language': self.language,
            'generate_cfg': self.generate_cfg,
            'enable_thinking': self.enable_thinking,
            'enable_self_knowledge': self.enable_self_knowledge,
            'user_identifier': self.user_identifier,
            'session_id': self.session_id,
            'dataset_ids': self.dataset_ids,  # Include dataset_ids in cache key generation.
            'project_dir': self.project_dir,  # project_dir should also be included because dataset_ids affect it.
        }
        # sort_keys makes the serialization, and therefore the hash, order-independent.
        cache_string = json.dumps(cache_components, sort_keys=True)
        cache_hash = hashlib.sha256(cache_string.encode('utf-8')).hexdigest()
        # Prefixed and shortened to 16 hex chars to keep keys readable when debugging.
        return f"agent_cache_{cache_hash[:16]}"