qwen_agent/agent/agent_config.py
朱潮 5dfe2eba28 feat: add bot_id and model to langfuse metadata
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-14 18:14:41 +08:00

297 lines
12 KiB
Python

"""Agent configuration class for managing all agent-related parameters."""
from typing import Optional, List, Dict, Any, TYPE_CHECKING
from dataclasses import dataclass, field
import logging
import json
import hashlib
# Application-wide logger named "app"; used below for config dumps and checkpoint warnings.
logger = logging.getLogger(name="app")
@dataclass
class AgentConfig:
    """Agent configuration class containing all parameters required to create and manage an agent.

    Instances are normally built with :meth:`from_v1_request` or :meth:`from_v2_request`.
    :meth:`invoke_config` produces the per-call LangChain config dict, and
    :meth:`get_unique_cache_id` produces a hash-based cache key prefixed with ``agent_cache_``.
    """

    # Basic parameters
    bot_id: str
    api_key: Optional[str] = None
    model_name: str = "qwen3-next"
    model_server: Optional[str] = None
    language: Optional[str] = "jp"
    # Configuration parameters
    system_prompt: Optional[str] = None
    mcp_settings: Optional[List[Dict]] = field(default_factory=list)
    generate_cfg: Optional[Dict] = None
    enable_thinking: bool = False
    enable_self_knowledge: bool = False
    # Context parameters
    project_dir: Optional[str] = None
    user_identifier: Optional[str] = None
    session_id: Optional[str] = None
    dataset_ids: Optional[List[str]] = field(default_factory=list)
    trace_id: Optional[str] = None  # Request trace ID, obtained from the X-Request-ID header
    request_started_at: Optional[float] = None
    # Response control parameters
    stream: bool = False
    tool_response: bool = True
    preamble_text: Optional[str] = None
    messages: Optional[List] = field(default_factory=list)
    # The factory methods store a shallow copy of the incoming messages here.
    _origin_messages: Optional[List] = field(default_factory=list)
    logging_handler: Optional['LoggingCallbackHandler'] = None
    # Mem0 long-term memory configuration
    enable_memori: bool = False  # Keep the name for API compatibility; Mem0 is used internally
    memori_semantic_search_top_k: int = 20
    _mem0_context: Optional[str] = None  # Mem0 recalled memory context, used for passing data between middlewares
    # Custom shell environment variables
    shell_env: Optional[Dict[str, str]] = field(default_factory=dict)
    # Checkpointer session history
    _session_history: Optional[List] = field(default_factory=list)  # Historical chat records loaded from the checkpointer

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary for passing into functions that require **kwargs.

        A new dict is built on every call. ``request_started_at``, ``logging_handler``
        and the underscore-prefixed fields are not included.
        """
        return {
            'bot_id': self.bot_id,
            'api_key': self.api_key,
            'model_name': self.model_name,
            'model_server': self.model_server,
            'language': self.language,
            'system_prompt': self.system_prompt,
            'mcp_settings': self.mcp_settings,
            'generate_cfg': self.generate_cfg,
            'enable_thinking': self.enable_thinking,
            'enable_self_knowledge': self.enable_self_knowledge,
            'project_dir': self.project_dir,
            'user_identifier': self.user_identifier,
            'session_id': self.session_id,
            'dataset_ids': self.dataset_ids,
            'stream': self.stream,
            'tool_response': self.tool_response,
            'preamble_text': self.preamble_text,
            'messages': self.messages,
            'enable_memori': self.enable_memori,
            'memori_semantic_search_top_k': self.memori_semantic_search_top_k,
            'trace_id': self.trace_id,
            'shell_env': self.shell_env,
        }

    @staticmethod
    def _mask_secret(value: Any) -> Any:
        """Return a log-safe form of a secret string.

        Non-string and empty values are returned unchanged. Long strings keep their
        first 8 and last 6 characters around ``***``. Anything shorter than 28
        characters is fully replaced by ``***``, so at least half of the secret
        always stays hidden.
        """
        if not isinstance(value, str) or not value:
            return value
        head, tail = 8, 6
        if len(value) < 2 * (head + tail):
            return '***'
        return value[:head] + '***' + value[-tail:]

    def _safe_dict(self) -> Dict[str, Any]:
        """Return ``to_dict()`` with the API key and shell environment values redacted.

        The config itself is never modified: ``shell_env`` is rebuilt as a new dict.
        """
        safe_dict = self.to_dict()
        # Mask every API key, not only "sk-" ones: v2 keys come from bot_config and may use any prefix.
        safe_dict['api_key'] = self._mask_secret(safe_dict.get('api_key'))
        shell_env = safe_dict.get('shell_env')
        if isinstance(shell_env, dict):
            # Environment variable values frequently carry tokens; keep only the names.
            safe_dict['shell_env'] = {name: '***' for name in shell_env}
        return safe_dict

    def safe_print(self):
        """Safely log the configuration without printing sensitive information."""
        # default=str: logging must never make config creation fail on a non-JSON value.
        logger.info(f"config={json.dumps(self._safe_dict(), ensure_ascii=False, default=str)}")

    @staticmethod
    def _current_trace_id() -> Optional[str]:
        """Return ``trace_id`` from the request context object ``g``, or None if unavailable."""
        from utils.log_util.context import g
        try:
            return getattr(g, 'trace_id', None)
        except LookupError:
            # NOTE(review): presumably raised when ``g`` is read outside a request context.
            return None

    async def _load_checkpoint_history(self) -> None:
        """Best-effort: let ``prepare_checkpoint_message`` populate this config from the checkpointer.

        Does nothing without a ``session_id`` or when no checkpointer is configured.
        Any failure is logged as a warning and swallowed.
        """
        # Delay imports to avoid circular dependencies.
        from .checkpoint_utils import prepare_checkpoint_message
        from .checkpoint_manager import get_checkpointer_manager
        if not self.session_id:
            return
        try:
            checkpointer = get_checkpointer_manager().checkpointer
            if checkpointer:
                await prepare_checkpoint_message(self, checkpointer)
        except Exception as e:
            logger.warning(f"Failed to load checkpointer: {e}")

    @classmethod
    async def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None, messages: Optional[List] = None):
        """Create configuration from a v1 request.

        Args:
            request: v1 request object; its fields (bot_id, model, language, ...) are copied in.
            api_key: API key for the model server.
            project_dir: Optional project directory.
            generate_cfg: Optional extra generation config.
            messages: Chat messages; ``None`` is treated as an empty list.

        Returns:
            The new config, after checkpoint history has been loaded (best-effort) and logged.
        """
        # Delay imports to avoid circular dependencies.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text
        from utils.settings import MEM0_SEMANTIC_SEARCH_TOP_K

        if messages is None:
            messages = []
        preamble_text, system_prompt = get_preamble_text(request.language, request.system_prompt)
        config = cls(
            bot_id=request.bot_id,
            api_key=api_key,
            model_name=request.model,
            model_server=request.model_server,
            language=request.language,
            system_prompt=system_prompt,
            mcp_settings=request.mcp_settings,
            user_identifier=request.user_identifier,
            session_id=request.session_id,
            enable_thinking=request.enable_thinking,
            enable_self_knowledge=request.enable_self_knowledge,
            project_dir=project_dir,
            stream=request.stream,
            tool_response=request.tool_response,
            generate_cfg=generate_cfg,
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            # Shallow copy so in-place edits of config.messages cannot rewrite the origin snapshot.
            _origin_messages=list(messages),
            preamble_text=preamble_text,
            dataset_ids=request.dataset_ids,
            enable_memori=request.enable_memory,
            memori_semantic_search_top_k=getattr(request, 'memori_semantic_search_top_k', None) or MEM0_SEMANTIC_SEARCH_TOP_K,
            trace_id=cls._current_trace_id(),
            shell_env=getattr(request, 'shell_env', None) or {},
        )
        # Prepare checkpoint messages as early as possible when creating config.
        await config._load_checkpoint_history()
        config.safe_print()
        return config

    @classmethod
    async def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None, generate_cfg: Optional[Dict] = None, model_name: Optional[str] = None, model_server: Optional[str] = None, api_key: Optional[str] = None):
        """Create configuration from a v2 request.

        Most settings come from the backend ``bot_config``. The explicit ``model_name``,
        ``model_server`` and ``api_key`` arguments override it when truthy, and
        ``request.language`` overrides ``bot_config["language"]``.

        Args:
            request: v2 request object (bot_id, language, user_identifier, session_id, stream, tool_response).
            bot_config: Backend bot configuration dict.
            project_dir: Optional project directory.
            messages: Chat messages; ``None`` is treated as an empty list.
            generate_cfg: Optional extra generation config (defaults to ``{}``).
            model_name, model_server, api_key: Optional overrides for the bot_config values.

        Returns:
            The new config, after checkpoint history has been loaded (best-effort) and logged.
        """
        # Delay imports to avoid circular dependencies.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text
        from utils.settings import MEM0_SEMANTIC_SEARCH_TOP_K

        if messages is None:
            messages = []
        language = request.language or bot_config.get("language", "zh")
        preamble_text, system_prompt = get_preamble_text(language, bot_config.get("system_prompt"))
        config = cls(
            bot_id=request.bot_id,
            api_key=api_key or bot_config.get("api_key"),
            model_name=model_name or bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
            model_server=model_server or bot_config.get("model_server", ""),
            language=language,
            system_prompt=system_prompt,
            mcp_settings=bot_config.get("mcp_settings", []),
            user_identifier=request.user_identifier,
            session_id=request.session_id,
            enable_thinking=bot_config.get("enable_thinking", False),
            enable_self_knowledge=bot_config.get("enable_self_knowledge", False),
            project_dir=project_dir,
            stream=request.stream,
            tool_response=request.tool_response,
            generate_cfg=generate_cfg or {},  # The v2 API also supports passing extra generate_cfg.
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            # Shallow copy so in-place edits of config.messages cannot rewrite the origin snapshot.
            _origin_messages=list(messages),
            preamble_text=preamble_text,
            dataset_ids=bot_config.get("dataset_ids", []),  # Get dataset_ids from backend configuration.
            enable_memori=bot_config.get("enable_memory", False),
            # `or` (as in v1): an explicit null in bot_config must not leak None into an int field.
            memori_semantic_search_top_k=bot_config.get("memori_semantic_search_top_k") or MEM0_SEMANTIC_SEARCH_TOP_K,
            trace_id=cls._current_trace_id(),
            shell_env=bot_config.get("shell_env") or {},
        )
        # Prepare checkpoint messages as early as possible when creating config.
        await config._load_checkpoint_history()
        config.safe_print()
        return config

    def invoke_config(self):
        """Return the configuration dictionary required by LangChain.

        Keys are set only when they have content:
          - ``callbacks``: the logging handler, plus a Langfuse handler when LANGFUSE_ENABLED.
          - ``metadata``: Langfuse session/user ids, bot_id and model (only when LANGFUSE_ENABLED).
          - ``configurable.thread_id``: the session_id, when present.
        """
        from utils.settings import LANGFUSE_ENABLED

        config = {}
        callbacks = []
        metadata = {}
        if self.logging_handler:
            callbacks.append(self.logging_handler)
        if LANGFUSE_ENABLED:
            from langfuse.langchain import CallbackHandler as LangfuseCallbackHandler
            trace_context = {"trace_id": self.trace_id} if self.trace_id else None
            callbacks.append(LangfuseCallbackHandler(trace_context=trace_context))
            # langfuse_session_id / langfuse_user_id are metadata keys recognised by
            # Langfuse's LangChain integration; bot_id and model are plain extra metadata.
            for key, value in (
                ("langfuse_session_id", self.session_id),
                ("langfuse_user_id", self.user_identifier),
                ("bot_id", self.bot_id),
                ("model", self.model_name),
            ):
                if value:
                    metadata[key] = value
        if callbacks:
            config["callbacks"] = callbacks
        if metadata:
            config["metadata"] = metadata
        if self.session_id:
            config["configurable"] = {"thread_id": self.session_id}
        return config

    def get_unique_cache_id(self) -> Optional[str]:
        """
        Generate a unique cache key.

        The key is a SHA256 hash over bot_id, system_prompt, mcp_settings, model_name,
        language, generate_cfg, the feature flags, user_identifier, session_id,
        dataset_ids and project_dir. A missing session_id is hashed as null, so a key
        is still produced (this implementation never returns None).

        Returns:
            str: ``"agent_cache_"`` followed by the first 16 hex characters of the hash.

        Raises:
            TypeError: if any component is not JSON-serializable.
        """
        # Create the data used for generating the hash.
        cache_components = {
            'bot_id': self.bot_id,
            'system_prompt': self.system_prompt,
            'mcp_settings': self.mcp_settings,
            'model_name': self.model_name,
            'language': self.language,
            'generate_cfg': self.generate_cfg,
            'enable_thinking': self.enable_thinking,
            'enable_self_knowledge': self.enable_self_knowledge,
            'user_identifier': self.user_identifier,
            'session_id': self.session_id,
            'dataset_ids': self.dataset_ids,  # Include dataset_ids in cache key generation.
            'project_dir': self.project_dir,  # project_dir should also be included because dataset_ids affect it.
        }
        # sort_keys makes the serialization (and therefore the hash) independent of dict order.
        cache_string = json.dumps(cache_components, sort_keys=True)
        # Generate a SHA256 hash to use as the cache key.
        cache_hash = hashlib.sha256(cache_string.encode('utf-8')).hexdigest()
        # Return a prefixed cache key to make debugging easier.
        return f"agent_cache_{cache_hash[:16]}"  # Use the first 16 characters of the hash.