Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 5 additions & 28 deletions src/axolotl/utils/schemas/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
from axolotl.utils.schemas.fsdp import FSDPConfig
from axolotl.utils.schemas.integrations import (
CometConfig,
GradioConfig,
Expand Down Expand Up @@ -667,19 +668,12 @@ class AxolotlInputConfig(
json_schema_extra={"description": "FSDP configuration"},
deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ",
)
# TODO @SalmanMohammadi strongly type this as its own schema
fsdp_config: dict[str, Any] | None = Field(
fsdp_config: FSDPConfig | None = Field(
default=None, json_schema_extra={"description": "FSDP configuration options"}
)
Comment on lines +671 to 673

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Typed FSDP config: good move. Also fix downstream validators that still read top-level fsdp_version.

This change makes fsdp_version nested. Validators below still do str(data.get("fsdp_version")) == "2", which will now miss FSDP2 and mis-gate LoRA kernels on multi-GPU.

Patch both validators to read the nested version:

@@ def check_multigpu_lora_kernels(cls, data):
-            capabilities = data.get("capabilities")
-            is_fsdp = data.get("fsdp_config") is not None
-            is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2"
+            capabilities = data.get("capabilities")
+            fsdp_cfg = data.get("fsdp_config")
+            is_fsdp = fsdp_cfg is not None
+            fsdp_version = (
+                fsdp_cfg.get("fsdp_version")
+                if isinstance(fsdp_cfg, dict)
+                else getattr(fsdp_cfg, "fsdp_version", None)
+            )
+            is_fsdp2 = is_fsdp and str(fsdp_version) == "2"
@@ def check_auto_enable_lora_kernels(cls, data):
-            capabilities = data.get("capabilities")
-            is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
-            is_fsdp = data.get("fsdp_config") is not None
-            is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2"
+            capabilities = data.get("capabilities")
+            is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
+            fsdp_cfg = data.get("fsdp_config")
+            is_fsdp = fsdp_cfg is not None
+            fsdp_version = (
+                fsdp_cfg.get("fsdp_version")
+                if isinstance(fsdp_cfg, dict)
+                else getattr(fsdp_cfg, "fsdp_version", None)
+            )
+            is_fsdp2 = is_fsdp and str(fsdp_version) == "2"

Run to find any remaining top-level uses:


🏁 Script executed:

#!/bin/bash
rg -n -C2 -g 'src/**' -e '\bfsdp_version\b' -e '\bfsdp_final_state_dict_type\b' -e '\bfp8_enable_fsdp_float8_all_gather\b' | grep -v 'schemas/fsdp.py'

Length of output: 9737


Fix all top-level fsdp_version checks — read version from fsdp_config (critical)

fsdp_version was nested into fsdp_config; existing top-level checks (data.get("fsdp_version") / cfg.fsdp_version / str(...)) will miss FSDP2 and mis-gate runtime behavior. Patch the two validators below and every other occurrence to read the nested value (handle dict or typed object, optionally fall back to top-level for compatibility).

Patch for the two validators (unchanged from suggestion):

@@ def check_multigpu_lora_kernels(cls, data):
-            capabilities = data.get("capabilities")
-            is_fsdp = data.get("fsdp_config") is not None
-            is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2"
+            capabilities = data.get("capabilities")
+            fsdp_cfg = data.get("fsdp_config")
+            is_fsdp = fsdp_cfg is not None
+            fsdp_version = (
+                fsdp_cfg.get("fsdp_version")
+                if isinstance(fsdp_cfg, dict)
+                else getattr(fsdp_cfg, "fsdp_version", None)
+            )
+            is_fsdp2 = is_fsdp and str(fsdp_version) == "2"
@@ def check_auto_enable_lora_kernels(cls, data):
-            capabilities = data.get("capabilities")
-            is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
-            is_fsdp = data.get("fsdp_config") is not None
-            is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2"
+            capabilities = data.get("capabilities")
+            is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
+            fsdp_cfg = data.get("fsdp_config")
+            is_fsdp = fsdp_cfg is not None
+            fsdp_version = (
+                fsdp_cfg.get("fsdp_version")
+                if isinstance(fsdp_cfg, dict)
+                else getattr(fsdp_cfg, "fsdp_version", None)
+            )
+            is_fsdp2 = is_fsdp and str(fsdp_version) == "2"

Other places found by the grep that must be updated to use the same pattern (update logic or centralize a helper to extract fsdp_version from fsdp_config):

  • src/axolotl/utils/trainer.py: line ~584 — if str(cfg.fsdp_version) == "2":
  • src/axolotl/utils/schemas/config.py: lines ~1097–1101 and ~1133–1136 — (the two validator sites above; diff applied)
  • src/axolotl/utils/schemas/validation.py: multiple locations (~417, ~810, ~822–827, ~838–846, ~855–861, ~891–907) — update data.get/self.fsdp_version checks and the migration/warning in check_fsdp_version_in_fsdp_config (it currently pops nested fsdp_version into top-level; reverse or remove this migration).
  • src/axolotl/loaders/patch_manager.py: lines ~112, ~119, ~179, ~286 — str(self.cfg.fsdp_version) == "2" checks
  • src/axolotl/loaders/model.py: lines ~203, ~730 — self.cfg.fsdp_version comparisons
  • src/axolotl/monkeypatch/accelerate/parallelism_config.py: line ~69 — self.fsdp_plugin.fsdp_version == 2

Actionable: replace direct top-level reads with a small centralized getter (or inline pattern above) that:

  • reads fsdp_cfg = cfg.fsdp_config (or data.get("fsdp_config"))
  • if dict -> fsdp_cfg.get("fsdp_version"); else -> getattr(fsdp_cfg, "fsdp_version", None)
  • falls back to getattr(cfg, "fsdp_version", None) only if backward compatibility is needed
Once the value is resolved, compare it consistently: either str(...) == "2" or an int comparison.

Committable suggestion skipped: line range outside the PR's diff.

fsdp_version: int | None = Field(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also should not have been moved

default=None,
json_schema_extra={"description": "FSDP version"},
)
fsdp_final_state_dict_type: (
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
) = Field(
default=None,
deprecated="Configuring FSDP final state dict type using `fsdp_final_state_dict_type` is deprecated. Please use `fsdp_config.final_state_dict_type` instead.",

fsdp_version: Literal[1, 2] | None = Field(
default=None, json_schema_extra={"description": "FSDP version (1 or 2)"}
)

val_set_size: float | None = Field(
Expand Down Expand Up @@ -1281,23 +1275,6 @@ def check_qat_config(cls, data):

return data

@model_validator(mode="before")
@classmethod
def check_fsdp_torch_version(cls, data):
    """Reject FSDP2 configurations when torch is older than 2.7.0.

    Args:
        data: Raw config mapping, inspected before field validation.

    Returns:
        ``data`` unchanged.

    Raises:
        ValueError: If ``fsdp_config`` is truthy, ``fsdp_version`` is 2 and
            the resolved torch version is below 2.7.0.
    """
    # Only FSDP2 has a minimum torch version. Bail out first so that torch is
    # not imported for configs that never needed the check.
    if not (data.get("fsdp_config") and str(data.get("fsdp_version")) == "2"):
        return data

    # `or {}` also covers the key being present with an explicit null value.
    env_capabilities = data.get("env_capabilities") or {}
    torch_version = env_capabilities.get("torch_version")

    if torch_version is None:
        import torch

        # Keep only the part of the version string before any "+" suffix.
        torch_version = str(torch.__version__).split("+", maxsplit=1)[0]

    if version.parse(torch_version) < version.parse("2.7.0"):
        raise ValueError("FSDP2 is not supported on torch version < 2.7.0")

    return data

@model_validator(mode="before")
@classmethod
def default_dataloader_opts(cls, data):
Expand Down
93 changes: 93 additions & 0 deletions src/axolotl/utils/schemas/fsdp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import Literal

from packaging import version
from pydantic import BaseModel, ConfigDict, Field, model_validator

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


class FSDPConfig(BaseModel):
fsdp: list[str] | None = Field(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be here

default=None,
json_schema_extra={"description": "FSDP configuration"},
deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ",
)
final_state_dict_type: (
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
) = Field(
default=None,
deprecated="Configuring FSDP final state dict type using `fsdp_final_state_dict_type` is deprecated. Please use `fsdp_config.final_state_dict_type` instead.",
)
activation_checkpointing: bool | None = Field(
default=None, description="Enable activation checkpointing for FSDP."
)
offload_params: bool | None = Field(
default=None, description="Enable parameter offloading to CPU for FSDP."
)
sync_module_states: bool | None = Field(
default=None, description="Synchronize module states across FSDP processes."
)
use_orig_params: bool | None = Field(
default=None, description="Use original parameters for FSDP."
)
state_dict_type: (
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
) = Field(default=None, description="Type of state dict to use for FSDP.")
auto_wrap_policy: Literal["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP"] | None = (
Field(default=None, description="Auto wrap policy for FSDP.")
)
transformer_layer_cls_to_wrap: str | None = Field(
default=None, description="List of transformer layer classes to wrap with FSDP."
)
cpu_ram_efficient_loading: bool | None = Field(
default=None, description="Enable CPU RAM efficient loading for FSDP."
)
reshard_after_forward: bool | None = Field(
default=None, description="Reshard parameters after forward pass in FSDP."
)

@model_validator(mode="before")
@classmethod
def check_fsdp_torch_version(cls, data):
    """Refuse FSDP2 when the installed torch is older than 2.7.0.

    Args:
        data: Raw input for this model; anything that is not a dict is
            returned untouched.

    Returns:
        ``data`` unchanged.

    Raises:
        ValueError: If ``fsdp_version`` in ``data`` is 2 (int or string) and
            the installed torch version is below 2.7.0.
    """
    if not isinstance(data, dict):
        return data

    requested = data.get("fsdp_version")
    wants_fsdp2 = requested == 2 or str(requested) == "2"
    if not wants_fsdp2:
        return data

    try:
        import torch
    except ImportError:
        # Best effort: with no torch importable the check is skipped.
        return data

    # Keep only the part of the version string before any "+" suffix.
    installed = str(torch.__version__).split("+", maxsplit=1)[0]
    if version.parse(installed) < version.parse("2.7.0"):
        raise ValueError("FSDP2 is not supported on torch version < 2.7.0")

    return data

@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
    """Rename deprecated ``fsdp_``-prefixed keys to their unprefixed form.

    ``fsdp_version`` is the one prefixed key that keeps its name. When at
    least one other key carries the prefix, a deprecation warning is logged
    and the dict is rewritten in place.

    Args:
        data: Raw input for this model; anything that is not a dict is
            returned untouched.

    Returns:
        The same ``data`` object, with deprecated keys renamed.
    """
    if not isinstance(data, dict):
        return data

    prefix = "fsdp_"
    # `fsdp_version` is never renamed, so on its own it must not trigger a
    # warning that asks the user to drop the prefix.
    deprecated_keys = {
        key for key in data if key.startswith(prefix) and key != "fsdp_version"
    }
    if not deprecated_keys:
        return data

    LOG.warning_once(
        "Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
        "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
    )

    # removeprefix strips only the leading prefix; str.replace would also
    # delete any later "fsdp_" occurrence inside the key.
    update_fsdp_config = {
        (key.removeprefix(prefix) if key in deprecated_keys else key): value
        for key, value in data.items()
    }

    # Rewrite in place so the dict object handed in keeps its identity.
    data.clear()
    data.update(update_fsdp_config)
    return data
46 changes: 17 additions & 29 deletions src/axolotl/utils/schemas/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,9 +823,13 @@ def check_fsdp2_base_model_quant_ram_efficient_loading(self):
load_in_8bit = self.load_in_8bit if hasattr(self, "load_in_8bit") else None
load_in_4bit = self.load_in_4bit if hasattr(self, "load_in_4bit") else None
if fsdp_config and fsdp_version == 2:
if fsdp_config.get("cpu_ram_efficient_loading") and (
load_in_8bit or load_in_4bit
):
cpu_ram_efficient_loading = None
if hasattr(fsdp_config, "cpu_ram_efficient_loading"):
cpu_ram_efficient_loading = fsdp_config.cpu_ram_efficient_loading
elif isinstance(fsdp_config, dict):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When would this be the case?

cpu_ram_efficient_loading = fsdp_config.get("cpu_ram_efficient_loading")

if cpu_ram_efficient_loading and (load_in_8bit or load_in_4bit):
raise ValueError(
"FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
"set fsdp_version to 1, or disable cpu_ram_efficient_loading."
Expand Down Expand Up @@ -860,41 +864,25 @@ def check_fsdp_version_in_fsdp_config(cls, data):
data["fsdp_version"] = fsdp_config.pop("fsdp_version")
return data

@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
    """Rename deprecated ``fsdp_``-prefixed keys inside ``fsdp_config``.

    ``fsdp_version`` is the one prefixed key that keeps its name. When at
    least one other key carries the prefix, a deprecation warning is logged
    and ``data["fsdp_config"]`` is replaced by a new dict with the prefix
    stripped from those keys.

    Args:
        data: Raw config mapping, inspected before field validation.

    Returns:
        ``data``, with ``fsdp_config`` normalized when it needed fixing.
    """
    if fsdp_config := data.get("fsdp_config"):
        prefix = "fsdp_"
        # `fsdp_version` is never renamed, so on its own it must not trigger
        # a warning that asks the user to drop the prefix.
        deprecated_keys = {
            key
            for key in fsdp_config
            if key.startswith(prefix) and key != "fsdp_version"
        }
        if deprecated_keys:
            LOG.warning_once(
                "Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
                "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
            )
            # removeprefix strips only the leading prefix; str.replace would
            # also delete any later "fsdp_" occurrence inside the key.
            data["fsdp_config"] = {
                (key.removeprefix(prefix) if key in deprecated_keys else key): value
                for key, value in fsdp_config.items()
            }
    return data

@model_validator(mode="after")
def check_fsdp_offload_w_8bit_optimizer(self):
if (
hasattr(self, "fsdp_config")
and self.fsdp_config
and self.optimizer
and "8bit" in self.optimizer.value
and self.fsdp_config["offload_params"]
and str(self.fsdp_version) != "2"
):
raise ValueError(
f"FSDP Offload not compatible with {str(self.optimizer.value)}"
)
offload_params = None
if hasattr(self.fsdp_config, "offload_params"):
offload_params = self.fsdp_config.offload_params
elif isinstance(self.fsdp_config, dict):
offload_params = self.fsdp_config.get("offload_params")

if offload_params:
raise ValueError(
f"FSDP Offload not compatible with {str(self.optimizer.value)}"
)
return self

@model_validator(mode="after")
Expand Down
Loading