diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index d612ec8a59..cd2a809ba2 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -24,6 +24,7 @@ ) from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType +from axolotl.utils.schemas.fsdp import FSDPConfig from axolotl.utils.schemas.integrations import ( CometConfig, GradioConfig, @@ -667,19 +668,12 @@ class AxolotlInputConfig( json_schema_extra={"description": "FSDP configuration"}, deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ", ) - # TODO @SalmanMohammadi strongly type this as its own schema - fsdp_config: dict[str, Any] | None = Field( + fsdp_config: FSDPConfig | None = Field( default=None, json_schema_extra={"description": "FSDP configuration options"} ) - fsdp_version: int | None = Field( - default=None, - json_schema_extra={"description": "FSDP version"}, - ) - fsdp_final_state_dict_type: ( - Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None - ) = Field( - default=None, - deprecated="Configuring FSDP final state dict type using `fsdp_final_state_dict_type` is deprecated. Please use `fsdp_config.final_state_dict_type` instead.", + + fsdp_version: Literal[1, 2] | None = Field( + default=None, json_schema_extra={"description": "FSDP version (1 or 2)"} ) val_set_size: float | None = Field( @@ -1281,23 +1275,6 @@ def check_qat_config(cls, data): return data - @model_validator(mode="before") - @classmethod - def check_fsdp_torch_version(cls, data): - env_capabilities = data.get("env_capabilities", {}) - torch_version = env_capabilities.get("torch_version") - - if torch_version is None: - import torch - - torch_version = str(torch.__version__).split("+", maxsplit=1)[0] - - if data.get("fsdp_config") and str(data.get("fsdp_version")) == "2": - if version.parse(torch_version) < version.parse("2.7.0"): - raise ValueError("FSDP2 is not supported on torch version < 2.7.0") - - return data - @model_validator(mode="before") @classmethod def default_dataloader_opts(cls, data): diff --git a/src/axolotl/utils/schemas/fsdp.py b/src/axolotl/utils/schemas/fsdp.py new file mode 100644 index 0000000000..2c8774d8ec --- /dev/null +++ b/src/axolotl/utils/schemas/fsdp.py @@ -0,0 +1,93 @@ +from typing import Literal + +from packaging import version +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +class FSDPConfig(BaseModel): + fsdp: list[str] | None = Field( + default=None, + json_schema_extra={"description": "FSDP configuration"}, + deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ", + ) + final_state_dict_type: ( + Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None + ) = Field( + default=None, + deprecated="Configuring FSDP final state dict type using `fsdp_final_state_dict_type` is deprecated. Please use `fsdp_config.final_state_dict_type` instead.", + ) + activation_checkpointing: bool | None = Field( + default=None, description="Enable activation checkpointing for FSDP." + ) + offload_params: bool | None = Field( + default=None, description="Enable parameter offloading to CPU for FSDP." + ) + sync_module_states: bool | None = Field( + default=None, description="Synchronize module states across FSDP processes." + ) + use_orig_params: bool | None = Field( + default=None, description="Use original parameters for FSDP." + ) + state_dict_type: ( + Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None + ) = Field(default=None, description="Type of state dict to use for FSDP.") + auto_wrap_policy: Literal["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP"] | None = ( + Field(default=None, description="Auto wrap policy for FSDP.") + ) + transformer_layer_cls_to_wrap: str | None = Field( + default=None, description="List of transformer layer classes to wrap with FSDP." + ) + cpu_ram_efficient_loading: bool | None = Field( + default=None, description="Enable CPU RAM efficient loading for FSDP." + ) + reshard_after_forward: bool | None = Field( + default=None, description="Reshard parameters after forward pass in FSDP." + ) + + @model_validator(mode="before") + @classmethod + def check_fsdp_torch_version(cls, data): + if not isinstance(data, dict): + return data + + fsdp_version = data.get("fsdp_version") + if fsdp_version == 2 or str(fsdp_version) == "2": + try: + import torch + + torch_version = str(torch.__version__).split("+", maxsplit=1)[0] + if version.parse(torch_version) < version.parse("2.7.0"): + raise ValueError("FSDP2 is not supported on torch version < 2.7.0") + except ImportError: + pass + + return data + + @model_validator(mode="before") + @classmethod + def check_fsdp_config_kwargs_prefix(cls, data): + if not isinstance(data, dict): + return data + should_fix = False + for key, _ in data.items(): + if key.startswith("fsdp_"): + should_fix = True + LOG.warning_once( + "Configuring FSDP fields with the `fsdp_` prefix is deprecated. " + "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`." + ) + break + if should_fix: + update_fsdp_config = {} + for key, value in data.items(): + if key.startswith("fsdp_") and key != "fsdp_version": + update_fsdp_config[key.replace("fsdp_", "")] = value + else: + update_fsdp_config[key] = value + data.clear() + data.update(update_fsdp_config) + return data diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 64018ca486..320bee7edf 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -823,9 +823,13 @@ def check_fsdp2_base_model_quant_ram_efficient_loading(self): load_in_8bit = self.load_in_8bit if hasattr(self, "load_in_8bit") else None load_in_4bit = self.load_in_4bit if hasattr(self, "load_in_4bit") else None if fsdp_config and fsdp_version == 2: - if fsdp_config.get("cpu_ram_efficient_loading") and ( - load_in_8bit or load_in_4bit - ): + cpu_ram_efficient_loading = None + if hasattr(fsdp_config, "cpu_ram_efficient_loading"): + cpu_ram_efficient_loading = fsdp_config.cpu_ram_efficient_loading + elif isinstance(fsdp_config, dict): + cpu_ram_efficient_loading = fsdp_config.get("cpu_ram_efficient_loading") + + if cpu_ram_efficient_loading and (load_in_8bit or load_in_4bit): raise ValueError( "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, " "set fsdp_version to 1, or disable cpu_ram_efficient_loading." @@ -860,28 +864,6 @@ def check_fsdp_version_in_fsdp_config(cls, data): data["fsdp_version"] = fsdp_config.pop("fsdp_version") return data - @model_validator(mode="before") - @classmethod - def check_fsdp_config_kwargs_prefix(cls, data): - if fsdp_config := data.get("fsdp_config"): - should_fix = False - for key, _ in fsdp_config.items(): - if key.startswith("fsdp_"): - should_fix = True - LOG.warning_once( - "Configuring FSDP fields with the `fsdp_` prefix is deprecated. " - "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`." - ) - if should_fix: - update_fsdp_config = {} - for key, value in fsdp_config.items(): - if key.startswith("fsdp_") and key != "fsdp_version": - update_fsdp_config[key.replace("fsdp_", "")] = value - else: - update_fsdp_config[key] = value - data["fsdp_config"] = update_fsdp_config - return data - @model_validator(mode="after") def check_fsdp_offload_w_8bit_optimizer(self): if ( @@ -889,12 +871,18 @@ def check_fsdp_offload_w_8bit_optimizer(self): and self.fsdp_config and self.optimizer and "8bit" in self.optimizer.value - and self.fsdp_config["offload_params"] and str(self.fsdp_version) != "2" ): - raise ValueError( - f"FSDP Offload not compatible with {str(self.optimizer.value)}" - ) + offload_params = None + if hasattr(self.fsdp_config, "offload_params"): + offload_params = self.fsdp_config.offload_params + elif isinstance(self.fsdp_config, dict): + offload_params = self.fsdp_config.get("offload_params") + + if offload_params: + raise ValueError( + f"FSDP Offload not compatible with {str(self.optimizer.value)}" + ) return self @model_validator(mode="after")