From dfc1c003c67babc8dac42a31c97512060fbc604e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?=
Date: Mon, 20 Apr 2026 19:00:15 +0800
Subject: [PATCH] sanitize_model_kwargs

---
 agent/deep_assistant.py |  59 +++++++----------
 routes/chat.py          |   5 +-
 utils/fastapi_utils.py  | 102 ++++++++++++++++++++++++++++++++++++----
 3 files changed, 113 insertions(+), 53 deletions(-)

diff --git a/agent/deep_assistant.py b/agent/deep_assistant.py
index ce827de..2521b6e 100644
--- a/agent/deep_assistant.py
+++ b/agent/deep_assistant.py
@@ -18,7 +18,7 @@ from langchain.agents.middleware import SummarizationMiddleware as LangchainSumm
 from .summarization_middleware import SummarizationMiddleware
 from langchain_mcp_adapters.client import MultiServerMCPClient
 from sympy.printing.cxx import none
-from utils.fastapi_utils import detect_provider
+from utils.fastapi_utils import detect_provider, sanitize_model_kwargs
 from .guideline_middleware import GuidelineMiddleware
 from .tool_output_length_middleware import ToolOutputLengthMiddleware
 from .tool_use_cleanup_middleware import ToolUseCleanupMiddleware
@@ -200,47 +200,22 @@ async def init_agent(config: AgentConfig):
 
     # 检测或使用指定的提供商
     model_provider, base_url = detect_provider(config.model_name, config.model_server)
-    # 构建模型参数
-    model_kwargs = {
-        "model": config.model_name,
-        "model_provider": model_provider,
-        "temperature": 0.8,
-        "base_url": base_url,
-        "api_key": config.api_key
-    }
-    if config.generate_cfg:
-        # 内部使用的参数,不应传给任何 LLM
-        internal_params = {
-            'tool_output_max_length',
-            'tool_output_truncation_strategy',
-            'tool_output_filters',
-            'tool_output_exclude',
-            'preserve_code_blocks',
-            'preserve_json',
-        }
-
-        # Anthropic 不支持的 OpenAI 特有参数
-        openai_only_params = {
-            'n',  # 生成多少个响应
-            'presence_penalty',
-            'frequency_penalty',
-            'logprobs',
-            'top_logprobs',
-            'logit_bias',
-            'seed',
-            'suffix',
-            'best_of',
-            'echo',
-            'user',
-        }
-
-        # 根据提供商决定需要过滤的参数
-        params_to_filter = internal_params.copy()
-        if model_provider == 'anthropic':
-            params_to_filter.update(openai_only_params)
-
-        filtered_cfg = {k: v for k, v in config.generate_cfg.items() if k not in params_to_filter}
-        model_kwargs.update(filtered_cfg)
+    model_kwargs, dropped_params, default_temperature_applied = sanitize_model_kwargs(
+        model_name=config.model_name,
+        model_provider=model_provider,
+        base_url=base_url,
+        api_key=config.api_key,
+        generate_cfg=config.generate_cfg,
+        source="init_agent"
+    )
+    if dropped_params:
+        logger.info(
+            "init_agent dropped_params=%s model=%s provider=%s default_temperature_applied=%s",
+            dropped_params,
+            config.model_name,
+            model_provider,
+            default_temperature_applied
+        )
     llm_instance = init_chat_model(**model_kwargs)
 
     # 创建新的 agent(不再缓存)
diff --git a/routes/chat.py b/routes/chat.py
index ec4129a..9b9f6f9 100644
--- a/routes/chat.py
+++ b/routes/chat.py
@@ -514,6 +514,7 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
     # 收集额外参数作为 generate_cfg
     exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings' ,'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id', 'enable_thinking', 'skills', 'enable_memory', 'n', 'shell_env', 'max_tokens'}
     generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
+    logger.info("chat_completions generate_cfg_keys=%s model=%s", list(generate_cfg.keys()), request.model)
     # 处理消息
     messages = process_messages(request.messages, request.language)
     # 创建 AgentConfig 对象
@@ -665,9 +666,10 @@
 
     # 处理消息
     messages = process_messages(empty_messages, request.language or "ja")
     # 收集额外参数作为 generate_cfg
     exclude_fields = {'messages', 'stream', 'tool_response', 'bot_id', 'language', 'user_identifier', 'session_id', 'n', 'model', 'model_server', 'api_key', 'shell_env', 'max_tokens'}
     generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
+    logger.info("chat_warmup_v2 generate_cfg_keys=%s requested_model=%s", list(generate_cfg.keys()), request.model)
     # 从请求中提取 model/model_server/api_key,优先级高于 bot_config(排除 "whatever" 和空值)
     req_data = request.model_dump()
     req_model = req_data.get("model") or ""
@@ -775,6 +777,7 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st
     # 收集额外参数作为 generate_cfg
     exclude_fields = {'messages', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id', 'enable_thinking', 'skills', 'enable_memory', 'n', 'model', 'model_server', 'api_key', 'shell_env', 'max_tokens'}
     generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
+    logger.info("chat_completions_v2 generate_cfg_keys=%s requested_model=%s", list(generate_cfg.keys()), request.model)
     # 从请求中提取 model/model_server/api_key,优先级高于 bot_config(排除 "whatever" 和空值)
     req_data = request.model_dump()
     req_model = req_data.get("model") or ""
diff --git a/utils/fastapi_utils.py b/utils/fastapi_utils.py
index ed6847b..589fde8 100644
--- a/utils/fastapi_utils.py
+++ b/utils/fastapi_utils.py
@@ -36,6 +36,83 @@ def detect_provider(model_name,model_server):
     # 默认使用 openai 兼容格式
     return "openai",model_server
 
+
+def is_anthropic_opus_model(model_name: Optional[str]) -> bool:
+    """判断是否为 Anthropic Opus 模型"""
+    return bool(model_name and "opus" in model_name.lower())
+
+
+def sanitize_model_kwargs(
+    model_name: str,
+    model_provider: str,
+    base_url: Optional[str],
+    api_key: Optional[str],
+    generate_cfg: Optional[Dict[str, Any]] = None,
+    source: str = "agent"
+) -> tuple[Dict[str, Any], List[str], bool]:
+    """清洗模型参数,过滤不兼容参数并返回日志所需信息"""
+    model_kwargs = {
+        "model": model_name,
+        "model_provider": model_provider,
+        "base_url": base_url,
+        "api_key": api_key
+    }
+
+    internal_params = {
+        'tool_output_max_length',
+        'tool_output_truncation_strategy',
+        'tool_output_filters',
+        'tool_output_exclude',
+        'preserve_code_blocks',
+        'preserve_json',
+    }
+
+    openai_only_params = {
+        'n',
+        'presence_penalty',
+        'frequency_penalty',
+        'logprobs',
+        'top_logprobs',
+        'logit_bias',
+        'seed',
+        'suffix',
+        'best_of',
+        'echo',
+        'user',
+    }
+
+    params_to_filter = set(internal_params)
+    is_opus_model = model_provider == 'anthropic' and is_anthropic_opus_model(model_name)
+
+    if model_provider == 'anthropic':
+        params_to_filter.update(openai_only_params)
+    if is_opus_model:
+        params_to_filter.add('temperature')
+
+    original_keys = list((generate_cfg or {}).keys())
+    filtered_cfg = {k: v for k, v in (generate_cfg or {}).items() if k not in params_to_filter}
+    dropped_params = [k for k in original_keys if k in params_to_filter]
+
+    default_temperature_applied = False
+    if not is_opus_model:
+        model_kwargs["temperature"] = 0.8
+        default_temperature_applied = True
+
+    model_kwargs.update(filtered_cfg)
+
+    logger.info(
+        "sanitize_model_kwargs source=%s provider=%s model=%s original_keys=%s dropped_keys=%s default_temperature_applied=%s",
+        source,
+        model_provider,
+        model_name,
+        original_keys,
+        dropped_params,
+        default_temperature_applied
+    )
+
+    return model_kwargs, dropped_params, default_temperature_applied
+
+
 def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extension: str) -> tuple[str, int]:
     """
     获取带版本号的文件名,自动处理文件删除和版本递增
@@ -451,15 +528,22 @@ async def _sync_call_llm(llm_config, messages) -> str:
     api_key = llm_config.get('api_key')
     # 检测或使用指定的提供商
     model_provider,base_url = detect_provider(model_name,model_server)
-
-    # 构建模型参数
-    model_kwargs = {
-        "model": model_name,
-        "model_provider": model_provider,
-        "temperature": 0.8,
-        "base_url":base_url,
-        "api_key":api_key
-    }
+
+    model_kwargs, dropped_params, default_temperature_applied = sanitize_model_kwargs(
+        model_name=model_name,
+        model_provider=model_provider,
+        base_url=base_url,
+        api_key=api_key,
+        source="_sync_call_llm"
+    )
+    if dropped_params:
+        logger.info(
+            "_sync_call_llm dropped_params=%s model=%s provider=%s default_temperature_applied=%s",
+            dropped_params,
+            model_name,
+            model_provider,
+            default_temperature_applied
+        )
     llm_instance = init_chat_model(**model_kwargs)
 
     # 转换消息格式为LangChain格式