199 lines
7.8 KiB
Python
199 lines
7.8 KiB
Python
"""Agent配置类,用于管理所有Agent相关的参数"""
|
||
|
||
from typing import Optional, List, Dict, Any, TYPE_CHECKING
|
||
from dataclasses import dataclass, field
|
||
import logging
|
||
import json
|
||
import hashlib
|
||
|
||
# Module logger, registered under the 'app' name; AgentConfig.safe_print writes through it.
logger: logging.Logger = logging.getLogger('app')
@dataclass
class AgentConfig:
    """Agent configuration holding every parameter needed to create and manage an Agent."""

    # --- Basic parameters ---
    bot_id: str
    api_key: Optional[str] = None
    model_name: str = "qwen3-next"
    model_server: Optional[str] = None
    language: Optional[str] = "jp"

    # --- Configuration parameters ---
    system_prompt: Optional[str] = None
    mcp_settings: Optional[List[Dict]] = field(default_factory=list)
    robot_type: Optional[str] = "general_agent"
    generate_cfg: Optional[Dict] = None
    enable_thinking: bool = False

    # --- Context parameters ---
    project_dir: Optional[str] = None
    user_identifier: Optional[str] = None
    session_id: Optional[str] = None
    dataset_ids: Optional[List[str]] = field(default_factory=list)

    # --- Response-control parameters ---
    stream: bool = False
    tool_response: bool = True
    preamble_text: Optional[str] = None
    messages: Optional[List] = field(default_factory=list)
    # Messages as originally received. The from_*_request factories store a
    # shallow copy here so in-place edits to `messages` do not alter it.
    _origin_messages: Optional[List] = field(default_factory=list)

    # Callback handler attached to LangChain invocations by invoke_config().
    logging_handler: Optional['LoggingCallbackHandler'] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the public parameters as a dict, for functions taking ``**kwargs``.

        ``logging_handler`` and ``_origin_messages`` are deliberately excluded.
        """
        return {
            'bot_id': self.bot_id,
            'api_key': self.api_key,
            'model_name': self.model_name,
            'model_server': self.model_server,
            'language': self.language,
            'system_prompt': self.system_prompt,
            'mcp_settings': self.mcp_settings,
            'robot_type': self.robot_type,
            'generate_cfg': self.generate_cfg,
            'enable_thinking': self.enable_thinking,
            'project_dir': self.project_dir,
            'user_identifier': self.user_identifier,
            'session_id': self.session_id,
            'dataset_ids': self.dataset_ids,
            'stream': self.stream,
            'tool_response': self.tool_response,
            'preamble_text': self.preamble_text,
            'messages': self.messages,
        }

    @staticmethod
    def _mask_api_key(api_key: Any) -> Any:
        """Return a log-safe form of ``api_key``.

        Non-string and empty values are returned unchanged. Strings of at least
        24 characters keep their first 8 and last 6 characters around ``'***'``.
        Shorter strings become ``'***'``, because a partial reveal would expose
        most of the secret.
        """
        if not isinstance(api_key, str) or not api_key:
            return api_key
        if len(api_key) < 24:
            return '***'
        return api_key[:8] + '***' + api_key[-6:]

    def safe_print(self):
        """Log the configuration at INFO level with the API key masked."""
        safe_dict = self.to_dict()
        # Mask every key, not only 'sk-' ones. Keys from other providers would
        # otherwise be written to the logs in plain text.
        safe_dict['api_key'] = self._mask_api_key(safe_dict['api_key'])
        # default=str stops values json cannot encode (e.g. non-dict message
        # objects) from raising here and aborting config construction.
        logger.info("config=%s", json.dumps(safe_dict, ensure_ascii=False, default=str))

    @staticmethod
    def _normalize_robot_type(robot_type: Optional[str]) -> Optional[str]:
        """Map 'catalog_agent' to 'deep_agent'; return any other value unchanged."""
        return "deep_agent" if robot_type == "catalog_agent" else robot_type

    @staticmethod
    def _resolve_enable_thinking(requested: Any, raw_system_prompt: Optional[str]) -> bool:
        """Enable thinking only if requested AND the raw system prompt contains '<guidelines>'.

        A missing (None) prompt counts as "no tag" instead of raising TypeError.
        """
        return bool(requested) and "<guidelines>" in (raw_system_prompt or "")

    @classmethod
    def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None, messages: Optional[List] = None):
        """Build a config from a v1 request and log it via ``safe_print``.

        ``request`` must expose: bot_id, model, model_server, language,
        system_prompt, mcp_settings, robot_type, user_identifier, session_id,
        enable_thinking, stream, tool_response and dataset_ids.

        Args:
            request: The v1 request object.
            api_key: API key to store on the config.
            project_dir: Optional project directory.
            generate_cfg: Optional generation settings, stored as given.
            messages: Conversation messages; defaults to an empty list.

        Returns:
            The new ``AgentConfig``.
        """
        # Deferred imports to avoid circular dependencies.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text

        if messages is None:
            messages = []

        preamble_text, system_prompt = get_preamble_text(request.language, request.system_prompt)

        config = cls(
            bot_id=request.bot_id,
            api_key=api_key,
            model_name=request.model,
            model_server=request.model_server,
            language=request.language,
            system_prompt=system_prompt,
            mcp_settings=request.mcp_settings,
            robot_type=cls._normalize_robot_type(request.robot_type),
            user_identifier=request.user_identifier,
            session_id=request.session_id,
            # The '<guidelines>' check runs on the prompt as received,
            # before get_preamble_text processing.
            enable_thinking=cls._resolve_enable_thinking(request.enable_thinking, request.system_prompt),
            project_dir=project_dir,
            stream=request.stream,
            tool_response=request.tool_response,
            generate_cfg=generate_cfg,
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            # Separate list so later in-place edits to `messages` leave the origin intact.
            _origin_messages=list(messages),
            preamble_text=preamble_text,
            dataset_ids=request.dataset_ids,
        )
        config.safe_print()
        return config

    @classmethod
    def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None):
        """Build a config from a v2 request plus a backend ``bot_config`` dict, and log it.

        Model, prompt, tool and dataset settings come from ``bot_config``, with
        these fallbacks:
        - language "zh" (a non-empty ``request.language`` overrides it);
        - model "qwen/qwen3-next-80b-a3b-instruct";
        - model_server "";
        - robot_type "general_agent";
        - mcp_settings [] and dataset_ids [].

        ``request`` supplies bot_id, language, user_identifier, session_id,
        enable_thinking, stream and tool_response.

        Returns:
            The new ``AgentConfig``.
        """
        # Deferred imports to avoid circular dependencies.
        from .logging_handler import LoggingCallbackHandler
        from utils.fastapi_utils import get_preamble_text

        if messages is None:
            messages = []

        language = request.language or bot_config.get("language", "zh")
        raw_system_prompt = bot_config.get("system_prompt")
        preamble_text, system_prompt = get_preamble_text(language, raw_system_prompt)

        config = cls(
            bot_id=request.bot_id,
            api_key=bot_config.get("api_key"),
            model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
            model_server=bot_config.get("model_server", ""),
            language=language,
            system_prompt=system_prompt,
            mcp_settings=bot_config.get("mcp_settings", []),
            robot_type=cls._normalize_robot_type(bot_config.get("robot_type", "general_agent")),
            user_identifier=request.user_identifier,
            session_id=request.session_id,
            enable_thinking=cls._resolve_enable_thinking(request.enable_thinking, raw_system_prompt),
            project_dir=project_dir,
            stream=request.stream,
            tool_response=request.tool_response,
            generate_cfg={},  # The v2 API passes no extra generate_cfg.
            logging_handler=LoggingCallbackHandler(),
            messages=messages,
            # Separate list so later in-place edits to `messages` leave the origin intact.
            _origin_messages=list(messages),
            preamble_text=preamble_text,
            dataset_ids=bot_config.get("dataset_ids", []),  # dataset_ids come from the backend config.
        )
        config.safe_print()
        return config

    def invoke_config(self):
        """Return the config dict for a LangChain invocation.

        Includes ``callbacks`` when a logging handler is set, and
        ``configurable.thread_id`` when a session_id is set.
        """
        config = {}
        if self.logging_handler:
            config["callbacks"] = [self.logging_handler]
        if self.session_id:
            config["configurable"] = {"thread_id": self.session_id}
        return config

    def get_unique_cache_id(self) -> Optional[str]:
        """Build a deterministic cache key from the agent-defining parameters.

        The key hashes these values:
        bot_id, system_prompt, mcp_settings, model_name, language, generate_cfg,
        enable_thinking, user_identifier, session_id, dataset_ids and project_dir.

        A key is always returned, even when session_id is None; a missing
        session_id is simply hashed as null. The ``Optional`` return annotation
        is kept for backward compatibility.

        Returns:
            str: ``"agent_cache_"`` followed by the first 16 hex digits of the
            SHA-256 digest.
        """
        cache_components = {
            'bot_id': self.bot_id,
            'system_prompt': self.system_prompt,
            'mcp_settings': self.mcp_settings,
            'model_name': self.model_name,
            'language': self.language,
            'generate_cfg': self.generate_cfg,
            'enable_thinking': self.enable_thinking,
            'user_identifier': self.user_identifier,
            'session_id': self.session_id,
            'dataset_ids': self.dataset_ids,
            # Included because dataset_ids influence project_dir.
            'project_dir': self.project_dir,
        }

        # sort_keys makes the serialization, and therefore the hash, independent of dict order.
        cache_string = json.dumps(cache_components, sort_keys=True)
        cache_hash = hashlib.sha256(cache_string.encode('utf-8')).hexdigest()

        # Readable prefix for debugging; 16 hex digits of the digest are kept.
        return f"agent_cache_{cache_hash[:16]}"