Commit dfc1c003c6 (parent 18d65513f6): introduce sanitize_model_kwargs
@ -18,7 +18,7 @@ from langchain.agents.middleware import SummarizationMiddleware as LangchainSumm
|
||||
from .summarization_middleware import SummarizationMiddleware
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from sympy.printing.cxx import none
|
||||
from utils.fastapi_utils import detect_provider
|
||||
from utils.fastapi_utils import detect_provider, sanitize_model_kwargs
|
||||
from .guideline_middleware import GuidelineMiddleware
|
||||
from .tool_output_length_middleware import ToolOutputLengthMiddleware
|
||||
from .tool_use_cleanup_middleware import ToolUseCleanupMiddleware
|
||||
@ -200,47 +200,22 @@ async def init_agent(config: AgentConfig):
|
||||
|
||||
# 检测或使用指定的提供商
|
||||
model_provider, base_url = detect_provider(config.model_name, config.model_server)
|
||||
# 构建模型参数
|
||||
model_kwargs = {
|
||||
"model": config.model_name,
|
||||
"model_provider": model_provider,
|
||||
"temperature": 0.8,
|
||||
"base_url": base_url,
|
||||
"api_key": config.api_key
|
||||
}
|
||||
if config.generate_cfg:
|
||||
# 内部使用的参数,不应传给任何 LLM
|
||||
internal_params = {
|
||||
'tool_output_max_length',
|
||||
'tool_output_truncation_strategy',
|
||||
'tool_output_filters',
|
||||
'tool_output_exclude',
|
||||
'preserve_code_blocks',
|
||||
'preserve_json',
|
||||
}
|
||||
|
||||
# Anthropic 不支持的 OpenAI 特有参数
|
||||
openai_only_params = {
|
||||
'n', # 生成多少个响应
|
||||
'presence_penalty',
|
||||
'frequency_penalty',
|
||||
'logprobs',
|
||||
'top_logprobs',
|
||||
'logit_bias',
|
||||
'seed',
|
||||
'suffix',
|
||||
'best_of',
|
||||
'echo',
|
||||
'user',
|
||||
}
|
||||
|
||||
# 根据提供商决定需要过滤的参数
|
||||
params_to_filter = internal_params.copy()
|
||||
if model_provider == 'anthropic':
|
||||
params_to_filter.update(openai_only_params)
|
||||
|
||||
filtered_cfg = {k: v for k, v in config.generate_cfg.items() if k not in params_to_filter}
|
||||
model_kwargs.update(filtered_cfg)
|
||||
model_kwargs, dropped_params, default_temperature_applied = sanitize_model_kwargs(
|
||||
model_name=config.model_name,
|
||||
model_provider=model_provider,
|
||||
base_url=base_url,
|
||||
api_key=config.api_key,
|
||||
generate_cfg=config.generate_cfg,
|
||||
source="init_agent"
|
||||
)
|
||||
if dropped_params:
|
||||
logger.info(
|
||||
"init_agent dropped_params=%s model=%s provider=%s default_temperature_applied=%s",
|
||||
dropped_params,
|
||||
config.model_name,
|
||||
model_provider,
|
||||
default_temperature_applied
|
||||
)
|
||||
llm_instance = init_chat_model(**model_kwargs)
|
||||
|
||||
# 创建新的 agent(不再缓存)
|
||||
|
||||
@ -514,6 +514,7 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
|
||||
# 收集额外参数作为 generate_cfg
|
||||
exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings' ,'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id', 'enable_thinking', 'skills', 'enable_memory', 'n', 'shell_env', 'max_tokens'}
|
||||
generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
|
||||
logger.info("chat_completions generate_cfg_keys=%s model=%s", list(generate_cfg.keys()), request.model)
|
||||
# 处理消息
|
||||
messages = process_messages(request.messages, request.language)
|
||||
# 创建 AgentConfig 对象
|
||||
@ -665,9 +666,8 @@ async def chat_warmup_v2(request: ChatRequestV2, authorization: Optional[str] =
|
||||
# 处理消息
|
||||
messages = process_messages(empty_messages, request.language or "ja")
|
||||
|
||||
# 收集额外参数作为 generate_cfg
|
||||
exclude_fields = {'messages', 'stream', 'tool_response', 'bot_id', 'language', 'user_identifier', 'session_id', 'n', 'model', 'model_server', 'api_key', 'shell_env', 'max_tokens'}
|
||||
generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
|
||||
logger.info("chat_warmup_v2 generate_cfg_keys=%s requested_model=%s", list(generate_cfg.keys()), request.model)
|
||||
# 从请求中提取 model/model_server/api_key,优先级高于 bot_config(排除 "whatever" 和空值)
|
||||
req_data = request.model_dump()
|
||||
req_model = req_data.get("model") or ""
|
||||
@ -775,6 +775,7 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st
|
||||
# 收集额外参数作为 generate_cfg
|
||||
exclude_fields = {'messages', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id', 'enable_thinking', 'skills', 'enable_memory', 'n', 'model', 'model_server', 'api_key', 'shell_env', 'max_tokens'}
|
||||
generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
|
||||
logger.info("chat_completions_v2 generate_cfg_keys=%s requested_model=%s", list(generate_cfg.keys()), request.model)
|
||||
# 从请求中提取 model/model_server/api_key,优先级高于 bot_config(排除 "whatever" 和空值)
|
||||
req_data = request.model_dump()
|
||||
req_model = req_data.get("model") or ""
|
||||
|
||||
@ -36,6 +36,83 @@ def detect_provider(model_name,model_server):
|
||||
# 默认使用 openai 兼容格式
|
||||
return "openai",model_server
|
||||
|
||||
|
||||
def is_anthropic_opus_model(model_name: Optional[str]) -> bool:
    """Return True when *model_name* refers to an Anthropic Opus model.

    The check is a case-insensitive substring match on "opus"; a None or
    empty name is never treated as an Opus model.
    """
    if not model_name:
        return False
    return "opus" in model_name.lower()
|
||||
|
||||
|
||||
def sanitize_model_kwargs(
    model_name: str,
    model_provider: str,
    base_url: Optional[str],
    api_key: Optional[str],
    generate_cfg: Optional[Dict[str, Any]] = None,
    source: str = "agent"
) -> tuple[Dict[str, Any], List[str], bool]:
    """Build sanitized model kwargs, dropping provider-incompatible params.

    Args:
        model_name: Target model identifier.
        model_provider: Provider key (e.g. "openai", "anthropic").
        base_url: Optional endpoint override passed through unchanged.
        api_key: Optional API key passed through unchanged.
        generate_cfg: Caller-supplied generation parameters; incompatible
            keys are removed, the rest are merged into the result.
        source: Caller label used only for logging.

    Returns:
        A triple of (model_kwargs, dropped_param_names,
        default_temperature_applied).
    """
    cfg = generate_cfg or {}

    # Internal-only knobs consumed by our middleware; they must never be
    # forwarded to any LLM provider.
    internal_only = {
        'tool_output_max_length',
        'tool_output_truncation_strategy',
        'tool_output_filters',
        'tool_output_exclude',
        'preserve_code_blocks',
        'preserve_json',
    }

    # OpenAI-specific sampling options that Anthropic does not accept.
    openai_specific = {
        'n',
        'presence_penalty',
        'frequency_penalty',
        'logprobs',
        'top_logprobs',
        'logit_bias',
        'seed',
        'suffix',
        'best_of',
        'echo',
        'user',
    }

    blocked = set(internal_only)
    opus = model_provider == 'anthropic' and is_anthropic_opus_model(model_name)
    if model_provider == 'anthropic':
        blocked |= openai_specific
        if opus:
            # Opus models must not receive a caller-supplied temperature.
            blocked.add('temperature')

    original_keys = list(cfg.keys())
    dropped_params = [key for key in original_keys if key in blocked]

    sanitized: Dict[str, Any] = {
        "model": model_name,
        "model_provider": model_provider,
        "base_url": base_url,
        "api_key": api_key,
    }

    applied_default_temp = False
    if not opus:
        sanitized["temperature"] = 0.8
        applied_default_temp = True

    # Surviving caller config is merged last, so a caller-provided
    # temperature (when allowed) overrides the 0.8 default above.
    sanitized.update({k: v for k, v in cfg.items() if k not in blocked})

    logger.info(
        "sanitize_model_kwargs source=%s provider=%s model=%s original_keys=%s dropped_keys=%s default_temperature_applied=%s",
        source,
        model_provider,
        model_name,
        original_keys,
        dropped_params,
        applied_default_temp
    )

    return sanitized, dropped_params, applied_default_temp
|
||||
|
||||
|
||||
def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extension: str) -> tuple[str, int]:
|
||||
"""
|
||||
获取带版本号的文件名,自动处理文件删除和版本递增
|
||||
@ -451,15 +528,22 @@ async def _sync_call_llm(llm_config, messages) -> str:
|
||||
api_key = llm_config.get('api_key')
|
||||
# 检测或使用指定的提供商
|
||||
model_provider,base_url = detect_provider(model_name,model_server)
|
||||
|
||||
# 构建模型参数
|
||||
model_kwargs = {
|
||||
"model": model_name,
|
||||
"model_provider": model_provider,
|
||||
"temperature": 0.8,
|
||||
"base_url":base_url,
|
||||
"api_key":api_key
|
||||
}
|
||||
|
||||
model_kwargs, dropped_params, default_temperature_applied = sanitize_model_kwargs(
|
||||
model_name=model_name,
|
||||
model_provider=model_provider,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
source="_sync_call_llm"
|
||||
)
|
||||
if dropped_params:
|
||||
logger.info(
|
||||
"_sync_call_llm dropped_params=%s model=%s provider=%s default_temperature_applied=%s",
|
||||
dropped_params,
|
||||
model_name,
|
||||
model_provider,
|
||||
default_temperature_applied
|
||||
)
|
||||
llm_instance = init_chat_model(**model_kwargs)
|
||||
|
||||
# 转换消息格式为LangChain格式
|
||||
|
||||
Loading…
Reference in New Issue
Block a user