-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Create FSDPConfig schema
#3157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Create FSDPConfig schema
#3157
Changes from all commits
1be904c
f13109b
b7245c4
cf97f66
213829e
1b6b0a9
bc7ebf8
e4542d9
c656dc6
ff45c3b
26fe19b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ | |
| ) | ||
| from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters | ||
| from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType | ||
| from axolotl.utils.schemas.fsdp import FSDPConfig | ||
| from axolotl.utils.schemas.integrations import ( | ||
| CometConfig, | ||
| GradioConfig, | ||
|
|
@@ -667,19 +668,12 @@ class AxolotlInputConfig( | |
| json_schema_extra={"description": "FSDP configuration"}, | ||
| deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ", | ||
| ) | ||
| # TODO @SalmanMohammadi strongly type this as its own schema | ||
| fsdp_config: dict[str, Any] | None = Field( | ||
| fsdp_config: FSDPConfig | None = Field( | ||
| default=None, json_schema_extra={"description": "FSDP configuration options"} | ||
| ) | ||
| fsdp_version: int | None = Field( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This also should not have been moved |
||
| default=None, | ||
| json_schema_extra={"description": "FSDP version"}, | ||
| ) | ||
| fsdp_final_state_dict_type: ( | ||
| Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None | ||
| ) = Field( | ||
| default=None, | ||
| deprecated="Configuring FSDP final state dict type using `fsdp_final_state_dict_type` is deprecated. Please use `fsdp_config.final_state_dict_type` instead.", | ||
|
|
||
| fsdp_version: Literal[1, 2] | None = Field( | ||
| default=None, json_schema_extra={"description": "FSDP version (1 or 2)"} | ||
| ) | ||
|
|
||
| val_set_size: float | None = Field( | ||
|
|
@@ -1281,23 +1275,6 @@ def check_qat_config(cls, data): | |
|
|
||
| return data | ||
|
|
||
| @model_validator(mode="before") | ||
| @classmethod | ||
| def check_fsdp_torch_version(cls, data): | ||
| env_capabilities = data.get("env_capabilities", {}) | ||
| torch_version = env_capabilities.get("torch_version") | ||
|
|
||
| if torch_version is None: | ||
| import torch | ||
|
|
||
| torch_version = str(torch.__version__).split("+", maxsplit=1)[0] | ||
|
|
||
| if data.get("fsdp_config") and str(data.get("fsdp_version")) == "2": | ||
| if version.parse(torch_version) < version.parse("2.7.0"): | ||
| raise ValueError("FSDP2 is not supported on torch version < 2.7.0") | ||
|
|
||
| return data | ||
|
|
||
| @model_validator(mode="before") | ||
| @classmethod | ||
| def default_dataloader_opts(cls, data): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| from typing import Literal | ||
|
|
||
| from packaging import version | ||
| from pydantic import BaseModel, ConfigDict, Field, model_validator | ||
|
|
||
| from axolotl.utils.logging import get_logger | ||
|
|
||
| LOG = get_logger(__name__) | ||
|
|
||
|
|
||
| class FSDPConfig(BaseModel): | ||
| fsdp: list[str] | None = Field( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This shouldn't be here |
||
| default=None, | ||
| json_schema_extra={"description": "FSDP configuration"}, | ||
| deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ", | ||
| ) | ||
| final_state_dict_type: ( | ||
| Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None | ||
| ) = Field( | ||
| default=None, | ||
| deprecated="Configuring FSDP final state dict type using `fsdp_final_state_dict_type` is deprecated. Please use `fsdp_config.final_state_dict_type` instead.", | ||
| ) | ||
| activation_checkpointing: bool | None = Field( | ||
| default=None, description="Enable activation checkpointing for FSDP." | ||
| ) | ||
| offload_params: bool | None = Field( | ||
| default=None, description="Enable parameter offloading to CPU for FSDP." | ||
| ) | ||
| sync_module_states: bool | None = Field( | ||
| default=None, description="Synchronize module states across FSDP processes." | ||
| ) | ||
| use_orig_params: bool | None = Field( | ||
| default=None, description="Use original parameters for FSDP." | ||
| ) | ||
| state_dict_type: ( | ||
| Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None | ||
| ) = Field(default=None, description="Type of state dict to use for FSDP.") | ||
| auto_wrap_policy: Literal["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP"] | None = ( | ||
| Field(default=None, description="Auto wrap policy for FSDP.") | ||
| ) | ||
| transformer_layer_cls_to_wrap: str | None = Field( | ||
| default=None, description="List of transformer layer classes to wrap with FSDP." | ||
| ) | ||
| cpu_ram_efficient_loading: bool | None = Field( | ||
| default=None, description="Enable CPU RAM efficient loading for FSDP." | ||
| ) | ||
| reshard_after_forward: bool | None = Field( | ||
| default=None, description="Reshard parameters after forward pass in FSDP." | ||
| ) | ||
|
|
||
| @model_validator(mode="before") | ||
| @classmethod | ||
| def check_fsdp_torch_version(cls, data): | ||
| if not isinstance(data, dict): | ||
| return data | ||
|
|
||
| fsdp_version = data.get("fsdp_version") | ||
| if fsdp_version == 2 or str(fsdp_version) == "2": | ||
| try: | ||
| import torch | ||
|
|
||
| torch_version = str(torch.__version__).split("+", maxsplit=1)[0] | ||
| if version.parse(torch_version) < version.parse("2.7.0"): | ||
| raise ValueError("FSDP2 is not supported on torch version < 2.7.0") | ||
| except ImportError: | ||
| pass | ||
|
|
||
| return data | ||
|
|
||
| @model_validator(mode="before") | ||
| @classmethod | ||
| def check_fsdp_config_kwargs_prefix(cls, data): | ||
| if not isinstance(data, dict): | ||
| return data | ||
| should_fix = False | ||
| for key, _ in data.items(): | ||
| if key.startswith("fsdp_"): | ||
| should_fix = True | ||
| LOG.warning_once( | ||
| "Configuring FSDP fields with the `fsdp_` prefix is deprecated. " | ||
| "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`." | ||
| ) | ||
| break | ||
| if should_fix: | ||
| update_fsdp_config = {} | ||
| for key, value in data.items(): | ||
| if key.startswith("fsdp_") and key != "fsdp_version": | ||
| update_fsdp_config[key.replace("fsdp_", "")] = value | ||
| else: | ||
| update_fsdp_config[key] = value | ||
| data.clear() | ||
| data.update(update_fsdp_config) | ||
| return data | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -823,9 +823,13 @@ def check_fsdp2_base_model_quant_ram_efficient_loading(self): | |
| load_in_8bit = self.load_in_8bit if hasattr(self, "load_in_8bit") else None | ||
| load_in_4bit = self.load_in_4bit if hasattr(self, "load_in_4bit") else None | ||
| if fsdp_config and fsdp_version == 2: | ||
| if fsdp_config.get("cpu_ram_efficient_loading") and ( | ||
| load_in_8bit or load_in_4bit | ||
| ): | ||
| cpu_ram_efficient_loading = None | ||
| if hasattr(fsdp_config, "cpu_ram_efficient_loading"): | ||
| cpu_ram_efficient_loading = fsdp_config.cpu_ram_efficient_loading | ||
| elif isinstance(fsdp_config, dict): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When would this be the case? |
||
| cpu_ram_efficient_loading = fsdp_config.get("cpu_ram_efficient_loading") | ||
|
|
||
| if cpu_ram_efficient_loading and (load_in_8bit or load_in_4bit): | ||
| raise ValueError( | ||
| "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, " | ||
| "set fsdp_version to 1, or disable cpu_ram_efficient_loading." | ||
|
|
@@ -860,41 +864,25 @@ def check_fsdp_version_in_fsdp_config(cls, data): | |
| data["fsdp_version"] = fsdp_config.pop("fsdp_version") | ||
| return data | ||
|
|
||
| @model_validator(mode="before") | ||
| @classmethod | ||
| def check_fsdp_config_kwargs_prefix(cls, data): | ||
| if fsdp_config := data.get("fsdp_config"): | ||
| should_fix = False | ||
| for key, _ in fsdp_config.items(): | ||
| if key.startswith("fsdp_"): | ||
| should_fix = True | ||
| LOG.warning_once( | ||
| "Configuring FSDP fields with the `fsdp_` prefix is deprecated. " | ||
| "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`." | ||
| ) | ||
| if should_fix: | ||
| update_fsdp_config = {} | ||
| for key, value in fsdp_config.items(): | ||
| if key.startswith("fsdp_") and key != "fsdp_version": | ||
| update_fsdp_config[key.replace("fsdp_", "")] = value | ||
| else: | ||
| update_fsdp_config[key] = value | ||
| data["fsdp_config"] = update_fsdp_config | ||
| return data | ||
|
|
||
| @model_validator(mode="after") | ||
| def check_fsdp_offload_w_8bit_optimizer(self): | ||
| if ( | ||
| hasattr(self, "fsdp_config") | ||
| and self.fsdp_config | ||
| and self.optimizer | ||
| and "8bit" in self.optimizer.value | ||
| and self.fsdp_config["offload_params"] | ||
| and str(self.fsdp_version) != "2" | ||
| ): | ||
| raise ValueError( | ||
| f"FSDP Offload not compatible with {str(self.optimizer.value)}" | ||
| ) | ||
| offload_params = None | ||
| if hasattr(self.fsdp_config, "offload_params"): | ||
| offload_params = self.fsdp_config.offload_params | ||
| elif isinstance(self.fsdp_config, dict): | ||
| offload_params = self.fsdp_config.get("offload_params") | ||
|
|
||
| if offload_params: | ||
| raise ValueError( | ||
| f"FSDP Offload not compatible with {str(self.optimizer.value)}" | ||
| ) | ||
| return self | ||
|
|
||
| @model_validator(mode="after") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
Typed FSDP config: good move. Also fix downstream validators that still read top-level fsdp_version.
This change makes `fsdp_version` nested. Validators below still do `str(data.get("fsdp_version")) == "2"`, which will now miss FSDP2 and mis-gate LoRA kernels on multi-GPU.
Patch both validators to read the nested version:
Run to find any remaining top-level uses:
🏁 Script executed:
Length of output: 9737
Fix all top-level fsdp_version checks — read version from fsdp_config (critical)
fsdp_version was nested into fsdp_config; existing top-level checks (data.get("fsdp_version") / cfg.fsdp_version / str(...)) will miss FSDP2 and mis-gate runtime behavior. Patch the two validators below and every other occurrence to read the nested value (handle dict or typed object, optionally fall back to top-level for compatibility).
Patch for the two validators (unchanged from suggestion):
Other places found by the grep that must be updated to use the same pattern (update logic or centralize a helper to extract fsdp_version from fsdp_config):
- `if str(cfg.fsdp_version) == "2":`
- `str(self.cfg.fsdp_version) == "2"` checks
- `self.cfg.fsdp_version` comparisons
- `self.fsdp_plugin.fsdp_version == 2`

Actionable: replace direct top-level reads with a small centralized getter (or inline pattern above) that:
- uses `str(...) == "2"` or an int comparison consistently once the value is resolved.