From 4b5d034ac8f6814331d8498b7bf396dd2c0cdd23 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 13 Aug 2025 15:01:05 +0700 Subject: [PATCH 1/2] fix: fsdp_config validation being None --- src/axolotl/utils/schemas/validation.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 72991c9470..96d0683205 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -817,13 +817,13 @@ def check_fsdp2_base_model_quant_rl(cls, data): @model_validator(mode="before") @classmethod def check_fsdp_version_in_fsdp_config(cls, data): - if data.get("fsdp_config"): - if data.get("fsdp_config", {}).get("fsdp_version"): - LOG.warning( - "Configuring `fsdp_version` in `fsdp_config` is deprecated. " - "Please configure `fsdp_version` as a top-level field." - ) - data["fsdp_version"] = data.get("fsdp_config").pop("fsdp_version") + fsdp_config = data.get("fsdp_config") or {} + if fsdp_config and fsdp_config.get("fsdp_version"): + LOG.warning( + "Configuring `fsdp_version` in `fsdp_config` is deprecated. " + "Please configure `fsdp_version` as a top-level field." + ) + data["fsdp_version"] = fsdp_config.pop("fsdp_version") return data @model_validator(mode="before") @@ -1151,10 +1151,8 @@ def check_gptq_w_revision(cls, data): @classmethod def check_gpt_oss_fsdp_loading(cls, data): if data.get("model_quantization_config", "") == "Mxfp4Config": - if ( - data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False) - is True - ): + fsdp_config = data.get("fsdp_config") or {} + if fsdp_config.get("cpu_ram_efficient_loading", False) is True: raise ValueError( "FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization." ) From 156a32a3ac7839273471fa39421d725aff5d58f5 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 13 Aug 2025 15:25:32 +0700 Subject: [PATCH 2/2] fix: handling --- src/axolotl/utils/schemas/validation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 96d0683205..12c5f0ad41 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -369,10 +369,10 @@ def check_fp8_config(cls, data): "see speed improvements. Please consider setting `torch_compile: " "true` in your config." ) + fsdp_config = data.get("fsdp_config") or {} if data.get("fp8") and ( - data.get("fsdp_config", {}).get("activation_checkpointing", False) is True - or data.get("fsdp_config", {}).get("fsdp_activation_checkpointing", False) - is True + fsdp_config.get("activation_checkpointing", False) is True + or fsdp_config.get("fsdp_activation_checkpointing", False) is True ): LOG.warning( "FP8 + FSDP2 + activation checkpointing may be slower than BF16 "