qwen_agent/agent/agent_config.py
2026-02-06 22:07:47 +08:00

338 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Agent配置类用于管理所有Agent相关的参数"""
from typing import Optional, List, Dict, Any, TYPE_CHECKING
from dataclasses import dataclass, field
import logging
import json
import hashlib
logger = logging.getLogger('app')
@dataclass
class AgentConfig:
    """Configuration for creating and managing an Agent.

    Collects model/connection settings, per-request context, response
    control flags and long-term-memory options, plus factory classmethods
    that build a config from the v1 / v2 / v3 request formats.
    """

    # --- Basic parameters ---
    bot_id: str
    api_key: Optional[str] = None
    model_name: str = "qwen3-next"
    model_server: Optional[str] = None
    language: Optional[str] = "jp"

    # --- Behaviour configuration ---
    system_prompt: Optional[str] = None
    mcp_settings: Optional[List[Dict]] = field(default_factory=list)
    generate_cfg: Optional[Dict] = None
    enable_thinking: bool = False
    # Agent flavour chosen by the backend bot config (used by v3 requests).
    # NOTE: this field is required by `from_v3_request`, which passes
    # `robot_type=...` to the constructor; without it every v3 request
    # raised a TypeError.
    robot_type: Optional[str] = "general_agent"

    # --- Request context ---
    project_dir: Optional[str] = None
    user_identifier: Optional[str] = None
    session_id: Optional[str] = None
    dataset_ids: Optional[List[str]] = field(default_factory=list)
    trace_id: Optional[str] = None  # request trace id, taken from the X-Request-ID header

    # --- Response control ---
    stream: bool = False
    tool_response: bool = True
    preamble_text: Optional[str] = None
    messages: Optional[List] = field(default_factory=list)
    _origin_messages: Optional[List] = field(default_factory=list)
    logging_handler: Optional['LoggingCallbackHandler'] = None

    # --- Mem0 long-term memory ---
    enable_memori: bool = False  # name kept for API compatibility; Mem0 is used underneath
    memori_semantic_search_top_k: int = 20
    _mem0_context: Optional[str] = None  # memory context recalled by Mem0; handed between middlewares

    # --- Checkpointer session history ---
    _session_history: Optional[List] = field(default_factory=list)  # chat history read from the checkpointer

    def to_dict(self) -> Dict[str, Any]:
        """Return the public fields as a dict, for passing into **kwargs.

        Deliberately excludes the private fields, `logging_handler` and
        `robot_type` so the key set stays compatible with existing
        receivers that unpack this dict.
        """
        return {
            'bot_id': self.bot_id,
            'api_key': self.api_key,
            'model_name': self.model_name,
            'model_server': self.model_server,
            'language': self.language,
            'system_prompt': self.system_prompt,
            'mcp_settings': self.mcp_settings,
            'generate_cfg': self.generate_cfg,
            'enable_thinking': self.enable_thinking,
            'project_dir': self.project_dir,
            'user_identifier': self.user_identifier,
            'session_id': self.session_id,
            'dataset_ids': self.dataset_ids,
            'stream': self.stream,
            'tool_response': self.tool_response,
            'preamble_text': self.preamble_text,
            'messages': self.messages,
            'enable_memori': self.enable_memori,
            'memori_semantic_search_top_k': self.memori_semantic_search_top_k,
            'trace_id': self.trace_id,
        }

    def safe_print(self):
        """Log the configuration with the API key partially masked."""
        safe_dict = self.to_dict()  # to_dict() already returns a fresh dict; no copy needed
        api_key = safe_dict.get('api_key')
        if isinstance(api_key, str) and api_key.startswith('sk-'):
            # Keep a short prefix/suffix for identification, hide the middle.
            safe_dict['api_key'] = api_key[:8] + '***' + api_key[-6:]
        logger.info(f"config={json.dumps(safe_dict, ensure_ascii=False)}")

    @staticmethod
    def _current_trace_id() -> Optional[str]:
        """Best-effort read of the trace id from the request-scoped context ``g``."""
        # Imported lazily to avoid a circular dependency.
        from utils.log_util.context import g
        try:
            return getattr(g, 'trace_id', None)
        except LookupError:
            # No active request context (e.g. running outside a request).
            return None

    async def _load_checkpoint_history(self) -> None:
        """Eagerly load this session's checkpoint messages, if possible.

        A missing or broken checkpointer must not fail the request, so any
        error is logged as a warning and swallowed.
        """
        if not self.session_id:
            return
        # Imported lazily to avoid circular dependencies.
        from .checkpoint_utils import prepare_checkpoint_message
        from .checkpoint_manager import get_checkpointer_manager
        try:
            checkpointer = get_checkpointer_manager().checkpointer
            if checkpointer:
                await prepare_checkpoint_message(self, checkpointer)
        except Exception as e:
            logger.warning(f"Failed to load checkpointer: {e}")

    @classmethod
    async def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None,
                              generate_cfg: Optional[Dict] = None, messages: Optional[List] = None):
        """Build a config from a v1 request (all settings carried on the request)."""
        # Imported lazily to avoid circular dependencies.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text
        from utils.settings import MEM0_SEMANTIC_SEARCH_TOP_K

        if messages is None:
            messages = []
        preamble_text, system_prompt = get_preamble_text(request.language, request.system_prompt)
        config = cls(
            bot_id=request.bot_id,
            api_key=api_key,
            model_name=request.model,
            model_server=request.model_server,
            language=request.language,
            system_prompt=system_prompt,
            mcp_settings=request.mcp_settings,
            user_identifier=request.user_identifier,
            session_id=request.session_id,
            enable_thinking=request.enable_thinking,
            project_dir=project_dir,
            stream=request.stream,
            tool_response=request.tool_response,
            generate_cfg=generate_cfg,
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            # Snapshot the incoming messages so later in-place edits of
            # `messages` don't also rewrite the recorded originals.
            _origin_messages=list(messages),
            preamble_text=preamble_text,
            dataset_ids=request.dataset_ids,
            enable_memori=request.enable_memory,
            memori_semantic_search_top_k=getattr(request, 'memori_semantic_search_top_k', None) or MEM0_SEMANTIC_SEARCH_TOP_K,
            trace_id=cls._current_trace_id(),
        )
        # Prepare checkpoint messages as early as possible after creation.
        await config._load_checkpoint_history()
        config.safe_print()
        return config

    @classmethod
    async def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None,
                              messages: Optional[List] = None):
        """Build a config from a v2 request (settings mostly from ``bot_config``)."""
        # Imported lazily to avoid circular dependencies.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text
        from utils.settings import MEM0_SEMANTIC_SEARCH_TOP_K

        if messages is None:
            messages = []
        language = request.language or bot_config.get("language", "zh")
        preamble_text, system_prompt = get_preamble_text(language, bot_config.get("system_prompt"))
        config = cls(
            bot_id=request.bot_id,
            api_key=bot_config.get("api_key"),
            model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
            model_server=bot_config.get("model_server", ""),
            language=language,
            system_prompt=system_prompt,
            mcp_settings=bot_config.get("mcp_settings", []),
            user_identifier=request.user_identifier,
            session_id=request.session_id,
            enable_thinking=bot_config.get("enable_thinking", False),
            project_dir=project_dir,
            stream=request.stream,
            tool_response=request.tool_response,
            generate_cfg={},  # the v2 endpoint passes no extra generate_cfg
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            # Snapshot the incoming messages so later in-place edits of
            # `messages` don't also rewrite the recorded originals.
            _origin_messages=list(messages),
            preamble_text=preamble_text,
            dataset_ids=bot_config.get("dataset_ids", []),  # dataset_ids come from the backend config
            enable_memori=bot_config.get("enable_memory", False),
            memori_semantic_search_top_k=bot_config.get("memori_semantic_search_top_k", MEM0_SEMANTIC_SEARCH_TOP_K),
            trace_id=cls._current_trace_id(),
        )
        # Prepare checkpoint messages as early as possible after creation.
        await config._load_checkpoint_history()
        config.safe_print()
        return config

    @classmethod
    async def from_v3_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None,
                              messages: Optional[List] = None, language: Optional[str] = None):
        """Build a config from a v3 request — all settings read from the database config."""
        # Imported lazily to avoid circular dependencies.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text
        from utils.settings import MEM0_SEMANTIC_SEARCH_TOP_K

        if messages is None:
            messages = []
        # Language falls back to the database config when not supplied.
        if language is None:
            language = bot_config.get("language", "zh")
        preamble_text, system_prompt = get_preamble_text(language, bot_config.get("system_prompt", ""))
        # Map the legacy catalog_agent flavour onto deep_agent.
        robot_type = bot_config.get("robot_type", "general_agent")
        if robot_type == "catalog_agent":
            robot_type = "deep_agent"
        config = cls(
            bot_id=request.bot_id,
            api_key=bot_config.get("api_key", ""),
            model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
            model_server=bot_config.get("model_server", ""),
            language=language,
            system_prompt=system_prompt,
            mcp_settings=bot_config.get("mcp_settings", []),
            robot_type=robot_type,
            user_identifier=bot_config.get("user_identifier", ""),
            session_id=request.session_id,
            enable_thinking=bot_config.get("enable_thinking", False),
            project_dir=project_dir,
            stream=request.stream,
            tool_response=bot_config.get("tool_response", True),
            generate_cfg={},  # the v3 endpoint passes no extra generate_cfg
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            # Snapshot the incoming messages so later in-place edits of
            # `messages` don't also rewrite the recorded originals.
            _origin_messages=list(messages),
            preamble_text=preamble_text,
            dataset_ids=bot_config.get("dataset_ids", []),
            enable_memori=bot_config.get("enable_memori", False),
            memori_semantic_search_top_k=bot_config.get("memori_semantic_search_top_k", MEM0_SEMANTIC_SEARCH_TOP_K),
            trace_id=cls._current_trace_id(),
        )
        # Prepare checkpoint messages as early as possible after creation.
        await config._load_checkpoint_history()
        config.safe_print()
        return config

    def invoke_config(self):
        """Return the configuration dict expected by LangChain's ``invoke``."""
        config = {}
        if self.logging_handler:
            config["callbacks"] = [self.logging_handler]
        if self.session_id:
            config["configurable"] = {"thread_id": self.session_id}
        return config

    def get_unique_cache_id(self) -> Optional[str]:
        """Derive a stable cache key for this configuration.

        The key is a prefixed, truncated SHA-256 digest over the fields
        that affect agent construction (bot_id, system_prompt,
        mcp_settings, dataset_ids, ...).

        Returns:
            str: the cache key, or None when there is no session_id
            (caching disabled), as documented.
        """
        # Documented contract: no session means no caching.
        if not self.session_id:
            return None
        # Fields that participate in the cache identity.
        cache_components = {
            'bot_id': self.bot_id,
            'system_prompt': self.system_prompt,
            'mcp_settings': self.mcp_settings,
            'model_name': self.model_name,
            'language': self.language,
            'generate_cfg': self.generate_cfg,
            'enable_thinking': self.enable_thinking,
            'user_identifier': self.user_identifier,
            'session_id': self.session_id,
            'dataset_ids': self.dataset_ids,  # dataset_ids are part of the cache key
            'project_dir': self.project_dir,  # included because dataset_ids influence project_dir
        }
        # sort_keys makes the serialization (and thus the hash) deterministic.
        cache_string = json.dumps(cache_components, sort_keys=True)
        cache_hash = hashlib.sha256(cache_string.encode('utf-8')).hexdigest()
        # Prefixed, truncated digest keeps keys short and easy to spot in logs.
        return f"agent_cache_{cache_hash[:16]}"