qwen_agent/agent/agent_config.py
2025-12-17 20:27:06 +08:00

144 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Agent配置类用于管理所有Agent相关的参数"""
from typing import Optional, List, Dict, Any, TYPE_CHECKING
from dataclasses import dataclass, field
import logging
import json
logger = logging.getLogger('app')
@dataclass
class AgentConfig:
    """Configuration holding every parameter needed to create and manage an Agent.

    Instances are normally built via :meth:`from_v1_request` or
    :meth:`from_v2_request`. :meth:`to_dict` exposes the plain parameters for
    functions that take ``**kwargs``, and :meth:`invoke_config` builds the
    config dict passed to LangChain.
    """

    # Basic parameters
    bot_id: str
    api_key: Optional[str] = None
    model_name: str = "qwen3-next"
    model_server: Optional[str] = None
    # NOTE(review): default "jp" differs from the "zh" fallback used by
    # from_v2_request (and from ISO 639-1 "ja") — confirm this is intended.
    language: Optional[str] = "jp"
    # Configuration parameters
    system_prompt: Optional[str] = None
    mcp_settings: Optional[List[Dict]] = field(default_factory=list)
    robot_type: Optional[str] = "general_agent"
    generate_cfg: Optional[Dict] = None
    enable_thinking: bool = False
    # Context parameters
    project_dir: Optional[str] = None
    user_identifier: Optional[str] = None
    session_id: Optional[str] = None
    # Response-control parameters
    stream: bool = False
    tool_response: bool = True
    preamble_text: Optional[str] = None
    messages: Optional[List] = field(default_factory=list)
    # Callback handler handed to LangChain by invoke_config(); deliberately
    # not included in to_dict().
    logging_handler: Optional['LoggingCallbackHandler'] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a new dict, for functions taking ``**kwargs``.

        ``logging_handler`` is excluded. Container values (``mcp_settings``,
        ``generate_cfg``, ``messages``) are returned by reference, not copied.
        """
        return {
            'bot_id': self.bot_id,
            'api_key': self.api_key,
            'model_name': self.model_name,
            'model_server': self.model_server,
            'language': self.language,
            'system_prompt': self.system_prompt,
            'mcp_settings': self.mcp_settings,
            'robot_type': self.robot_type,
            'generate_cfg': self.generate_cfg,
            'enable_thinking': self.enable_thinking,
            'project_dir': self.project_dir,
            'user_identifier': self.user_identifier,
            'session_id': self.session_id,
            'stream': self.stream,
            'tool_response': self.tool_response,
            'preamble_text': self.preamble_text,
            'messages': self.messages,
        }

    @staticmethod
    def _mask_secret(secret: str) -> str:
        """Return a log-safe rendering of ``secret``.

        The first 8 and last 6 characters are kept only when the hidden middle
        is at least as long as the visible part (length >= 28). Shorter
        secrets are replaced entirely by ``'***'``, so a short key is never
        revealed almost in full.
        """
        prefix_len, suffix_len = 8, 6
        visible = prefix_len + suffix_len
        if len(secret) < 2 * visible:
            return '***'
        return secret[:prefix_len] + '***' + secret[-suffix_len:]

    def safe_print(self):
        """Log the configuration at INFO level without exposing the API key.

        Any non-empty string ``api_key`` is masked, not only ``sk-`` keys, so
        credentials from other providers never reach the log in plain text.
        ``messages`` and ``generate_cfg`` are loosely typed. Values that json
        cannot encode are rendered with ``str()`` instead of raising, so
        logging cannot break config construction in the factory methods.
        """
        if not logger.isEnabledFor(logging.INFO):
            # Avoid serializing a possibly long message history for nothing.
            return
        safe_dict = self.to_dict()  # already a fresh dict; safe to mutate
        api_key = safe_dict.get('api_key')
        if isinstance(api_key, str) and api_key:
            safe_dict['api_key'] = self._mask_secret(api_key)
        # NOTE(review): mcp_settings / generate_cfg are logged unmasked —
        # confirm they never carry credentials.
        logger.info("config=%s", json.dumps(safe_dict, ensure_ascii=False, default=str))

    @classmethod
    def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None, messages: Optional[List] = None):
        """Create a config from a v1 request and log it via :meth:`safe_print`.

        Args:
            request: Request object. Reads ``bot_id``, ``model``,
                ``model_server``, ``language``, ``system_prompt``,
                ``mcp_settings``, ``robot_type``, ``user_identifier``,
                ``session_id``, ``enable_thinking``, ``stream`` and
                ``tool_response``.
            api_key: API key stored on the config.
            project_dir: Optional project directory.
            generate_cfg: Optional extra generation config, stored as given.
            messages: Conversation messages. ``None`` becomes a new empty list.

        Returns:
            A new ``AgentConfig`` with a fresh ``LoggingCallbackHandler``.
        """
        # Deferred imports to avoid a circular dependency.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text

        if messages is None:
            messages = []
        # get_preamble_text returns a (preamble_text, system_prompt) pair.
        preamble_text, system_prompt = get_preamble_text(request.language, request.system_prompt)
        config = cls(
            bot_id=request.bot_id,
            api_key=api_key,
            model_name=request.model,
            model_server=request.model_server,
            language=request.language,
            system_prompt=system_prompt,
            mcp_settings=request.mcp_settings,
            robot_type=request.robot_type,
            user_identifier=request.user_identifier,
            session_id=request.session_id,
            enable_thinking=request.enable_thinking,
            project_dir=project_dir,
            stream=request.stream,
            tool_response=request.tool_response,
            generate_cfg=generate_cfg,
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            preamble_text=preamble_text,
        )
        config.safe_print()
        return config

    @classmethod
    def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None):
        """Create a config from a v2 request plus a stored bot config, and log it.

        Model and prompt settings come from ``bot_config``, which may contain
        the keys ``language``, ``system_prompt``, ``api_key``, ``model``,
        ``model_server``, ``mcp_settings`` and ``robot_type``. Per-call
        settings come from ``request``. A truthy ``request.language``
        overrides ``bot_config["language"]``, which itself defaults to
        ``"zh"``.

        Args:
            request: Request object. Reads ``language``, ``bot_id``,
                ``user_identifier``, ``session_id``, ``enable_thinking``,
                ``stream`` and ``tool_response``.
            bot_config: Stored bot configuration dict.
            project_dir: Optional project directory.
            messages: Conversation messages. ``None`` becomes a new empty list.

        Returns:
            A new ``AgentConfig`` with ``generate_cfg={}`` and a fresh
            ``LoggingCallbackHandler``.
        """
        # Deferred imports to avoid a circular dependency.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text

        if messages is None:
            messages = []
        language = request.language or bot_config.get("language", "zh")
        preamble_text, system_prompt = get_preamble_text(language, bot_config.get("system_prompt"))
        config = cls(
            bot_id=request.bot_id,
            api_key=bot_config.get("api_key"),
            model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
            model_server=bot_config.get("model_server", ""),
            language=language,
            system_prompt=system_prompt,
            mcp_settings=bot_config.get("mcp_settings", []),
            robot_type=bot_config.get("robot_type", "general_agent"),
            user_identifier=request.user_identifier,
            session_id=request.session_id,
            enable_thinking=request.enable_thinking,
            project_dir=project_dir,
            stream=request.stream,
            tool_response=request.tool_response,
            generate_cfg={},  # the v2 endpoint does not pass an extra generate_cfg
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            preamble_text=preamble_text,
        )
        config.safe_print()
        return config

    def invoke_config(self) -> Dict[str, Any]:
        """Return the config dict required by LangChain.

        Includes ``callbacks`` only when a logging handler is set, and
        ``configurable.thread_id`` only when ``session_id`` is truthy.
        """
        config: Dict[str, Any] = {}
        if self.logging_handler:
            config["callbacks"] = [self.logging_handler]
        if self.session_id:
            config["configurable"] = {"thread_id": self.session_id}
        return config