sanitize_model_kwargs

朱潮 2026-04-20 19:00:15 +08:00
parent 18d65513f6
commit dfc1c003c6
3 changed files with 113 additions and 53 deletions

View File

@@ -18,7 +18,7 @@ from langchain.agents.middleware import SummarizationMiddleware as LangchainSumm
from .summarization_middleware import SummarizationMiddleware
from langchain_mcp_adapters.client import MultiServerMCPClient
from utils.fastapi_utils import detect_provider
from utils.fastapi_utils import detect_provider, sanitize_model_kwargs
from .guideline_middleware import GuidelineMiddleware
from .tool_output_length_middleware import ToolOutputLengthMiddleware
from .tool_use_cleanup_middleware import ToolUseCleanupMiddleware
@@ -200,47 +200,22 @@ async def init_agent(config: AgentConfig):
    # Detect the provider, or use the one explicitly specified
    model_provider, base_url = detect_provider(config.model_name, config.model_server)
    # Build the model parameters
    model_kwargs = {
        "model": config.model_name,
        "model_provider": model_provider,
        "temperature": 0.8,
        "base_url": base_url,
        "api_key": config.api_key
    }
    if config.generate_cfg:
        # Internal-only parameters; these must not be passed to any LLM
        internal_params = {
            'tool_output_max_length',
            'tool_output_truncation_strategy',
            'tool_output_filters',
            'tool_output_exclude',
            'preserve_code_blocks',
            'preserve_json',
        }
        # OpenAI-specific parameters that Anthropic does not support
        openai_only_params = {
            'n',  # how many completions to generate
            'presence_penalty',
            'frequency_penalty',
            'logprobs',
            'top_logprobs',
            'logit_bias',
            'seed',
            'suffix',
            'best_of',
            'echo',
            'user',
        }
        # Decide which parameters to filter based on the provider
        params_to_filter = internal_params.copy()
        if model_provider == 'anthropic':
            params_to_filter.update(openai_only_params)
        filtered_cfg = {k: v for k, v in config.generate_cfg.items() if k not in params_to_filter}
        model_kwargs.update(filtered_cfg)
    model_kwargs, dropped_params, default_temperature_applied = sanitize_model_kwargs(
        model_name=config.model_name,
        model_provider=model_provider,
        base_url=base_url,
        api_key=config.api_key,
        generate_cfg=config.generate_cfg,
        source="init_agent"
    )
    if dropped_params:
        logger.info(
            "init_agent dropped_params=%s model=%s provider=%s default_temperature_applied=%s",
            dropped_params,
            config.model_name,
            model_provider,
            default_temperature_applied
        )
    llm_instance = init_chat_model(**model_kwargs)
    # Create a new agent (no longer cached)

View File

@@ -514,6 +514,7 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
    # Collect the extra parameters as generate_cfg
    exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id', 'enable_thinking', 'skills', 'enable_memory', 'n', 'shell_env', 'max_tokens'}
    generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
    logger.info("chat_completions generate_cfg_keys=%s model=%s", list(generate_cfg.keys()), request.model)
    # Process the messages
    messages = process_messages(request.messages, request.language)
    # Create the AgentConfig object
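For illustration, a minimal sketch of the exclude-fields pattern above (the request fields are made up): every field not explicitly excluded flows through as generate_cfg, which is why sanitize_model_kwargs has to filter it downstream.

# Hypothetical request dict; field names and values are illustrative only.
req = {"model": "gpt-4o", "messages": [], "temperature": 0.3, "seed": 7}
exclude_fields = {"messages", "model"}
generate_cfg = {k: v for k, v in req.items() if k not in exclude_fields}
assert generate_cfg == {"temperature": 0.3, "seed": 7}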
@@ -665,9 +666,8 @@ async def chat_warmup_v2(request: ChatRequestV2, authorization: Optional[str] =
    # Process the messages
    messages = process_messages(empty_messages, request.language or "ja")
    # Collect the extra parameters as generate_cfg
    exclude_fields = {'messages', 'stream', 'tool_response', 'bot_id', 'language', 'user_identifier', 'session_id', 'n', 'model', 'model_server', 'api_key', 'shell_env', 'max_tokens'}
    generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
    logger.info("chat_warmup_v2 generate_cfg_keys=%s requested_model=%s", list(generate_cfg.keys()), request.model)
    # Extract model/model_server/api_key from the request (takes priority over bot_config; "whatever" and empty values are excluded)
    req_data = request.model_dump()
    req_model = req_data.get("model") or ""
@@ -775,6 +775,7 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st
    # Collect the extra parameters as generate_cfg
    exclude_fields = {'messages', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id', 'enable_thinking', 'skills', 'enable_memory', 'n', 'model', 'model_server', 'api_key', 'shell_env', 'max_tokens'}
    generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
    logger.info("chat_completions_v2 generate_cfg_keys=%s requested_model=%s", list(generate_cfg.keys()), request.model)
    # Extract model/model_server/api_key from the request (takes priority over bot_config; "whatever" and empty values are excluded)
    req_data = request.model_dump()
    req_model = req_data.get("model") or ""

View File

@@ -36,6 +36,83 @@ def detect_provider(model_name, model_server):
    # Default to the OpenAI-compatible format
    return "openai", model_server

def is_anthropic_opus_model(model_name: Optional[str]) -> bool:
    """Return True if the model name refers to an Anthropic Opus model."""
    return bool(model_name and "opus" in model_name.lower())
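A quick behavioral sketch of is_anthropic_opus_model; the model names are illustrative, not taken from this commit:

assert is_anthropic_opus_model("claude-opus-4") is True       # case-insensitive substring match on "opus"
assert is_anthropic_opus_model("Claude-OPUS-4") is True
assert is_anthropic_opus_model("claude-sonnet-4") is False
assert is_anthropic_opus_model(None) is False                 # tolerates a missing name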

def sanitize_model_kwargs(
    model_name: str,
    model_provider: str,
    base_url: Optional[str],
    api_key: Optional[str],
    generate_cfg: Optional[Dict[str, Any]] = None,
    source: str = "agent"
) -> tuple[Dict[str, Any], List[str], bool]:
    """Sanitize model kwargs: filter incompatible parameters and return what the caller needs for logging."""
    model_kwargs = {
        "model": model_name,
        "model_provider": model_provider,
        "base_url": base_url,
        "api_key": api_key
    }
    # Internal-only parameters; these must not be passed to any LLM
    internal_params = {
        'tool_output_max_length',
        'tool_output_truncation_strategy',
        'tool_output_filters',
        'tool_output_exclude',
        'preserve_code_blocks',
        'preserve_json',
    }
    # OpenAI-specific parameters that Anthropic does not support
    openai_only_params = {
        'n',
        'presence_penalty',
        'frequency_penalty',
        'logprobs',
        'top_logprobs',
        'logit_bias',
        'seed',
        'suffix',
        'best_of',
        'echo',
        'user',
    }
    params_to_filter = set(internal_params)
    is_opus_model = model_provider == 'anthropic' and is_anthropic_opus_model(model_name)
    if model_provider == 'anthropic':
        params_to_filter.update(openai_only_params)
    if is_opus_model:
        params_to_filter.add('temperature')
    original_keys = list((generate_cfg or {}).keys())
    filtered_cfg = {k: v for k, v in (generate_cfg or {}).items() if k not in params_to_filter}
    dropped_params = [k for k in original_keys if k in params_to_filter]
    default_temperature_applied = False
    # Report the 0.8 default as applied only when the caller did not supply a temperature
    if not is_opus_model and 'temperature' not in filtered_cfg:
        model_kwargs["temperature"] = 0.8
        default_temperature_applied = True
    model_kwargs.update(filtered_cfg)
    logger.info(
        "sanitize_model_kwargs source=%s provider=%s model=%s original_keys=%s dropped_keys=%s default_temperature_applied=%s",
        source,
        model_provider,
        model_name,
        original_keys,
        dropped_params,
        default_temperature_applied
    )
    return model_kwargs, dropped_params, default_temperature_applied
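For reference, a hedged usage sketch of sanitize_model_kwargs as defined above; the model name, base_url, and api_key are placeholders, not values from this commit:

# Hypothetical call; the model name, URL, and key below are placeholders.
kwargs, dropped, default_temp = sanitize_model_kwargs(
    model_name="claude-opus-4",              # treated as an Opus model
    model_provider="anthropic",
    base_url="https://api.anthropic.com",    # placeholder
    api_key="sk-placeholder",
    generate_cfg={"temperature": 0.2, "n": 2, "top_p": 0.9,
                  "tool_output_max_length": 4096},
    source="example",
)
# 'temperature' (Opus), 'n' (OpenAI-only), and 'tool_output_max_length'
# (internal) are all filtered; only 'top_p' passes through.
assert sorted(dropped) == ["n", "temperature", "tool_output_max_length"]
assert "top_p" in kwargs and "temperature" not in kwargs
assert default_temp is False  # Opus models never receive the 0.8 default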

def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extension: str) -> tuple[str, int]:
    """
    Get a filename with a version number; automatically handles file deletion and version incrementing
@@ -451,15 +528,22 @@ async def _sync_call_llm(llm_config, messages) -> str:
    api_key = llm_config.get('api_key')
    # Detect the provider, or use the one explicitly specified
    model_provider, base_url = detect_provider(model_name, model_server)
    # Build the model parameters
    model_kwargs = {
        "model": model_name,
        "model_provider": model_provider,
        "temperature": 0.8,
        "base_url": base_url,
        "api_key": api_key
    }
    model_kwargs, dropped_params, default_temperature_applied = sanitize_model_kwargs(
        model_name=model_name,
        model_provider=model_provider,
        base_url=base_url,
        api_key=api_key,
        source="_sync_call_llm"
    )
    if dropped_params:
        logger.info(
            "_sync_call_llm dropped_params=%s model=%s provider=%s default_temperature_applied=%s",
            dropped_params,
            model_name,
            model_provider,
            default_temperature_applied
        )
    llm_instance = init_chat_model(**model_kwargs)
    # Convert the message format to LangChain format
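And the non-Opus path that _sync_call_llm now relies on, as a minimal sketch (the model name and key are placeholders):

kwargs, dropped, default_temp = sanitize_model_kwargs(
    model_name="gpt-4o",          # placeholder model name
    model_provider="openai",
    base_url=None,
    api_key="sk-placeholder",
    source="_sync_call_llm",
)
# No generate_cfg was passed, so nothing is dropped and the 0.8 default applies.
assert dropped == []
assert kwargs["temperature"] == 0.8 and default_temp is True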