Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/axolotl/cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from axolotl.integrations.base import PluginManager
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
migrate_fsdp_config,
normalize_cfg_datasets,
normalize_config,
validate_config,
Expand Down Expand Up @@ -227,7 +226,6 @@ def load_cfg(
},
)

migrate_fsdp_config(cfg)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg)
Expand Down
13 changes: 0 additions & 13 deletions src/axolotl/utils/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,16 +313,3 @@ def prepare_plugins(cfg):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)


# TODO @SalmanMohammadi remove this function in 0.12
def migrate_fsdp_config(cfg):
    """Migrate a legacy ``fsdp_config`` layout in place.

    Two legacy forms are rewritten:

    * ``fsdp_config["fsdp_version"]`` is moved to the top-level
      ``fsdp_version`` key of ``cfg``.
    * Every remaining key in ``fsdp_config`` that starts with ``fsdp_`` is
      renamed with that leading prefix removed.

    Args:
        cfg: Mutable mapping holding the configuration. It is modified in
            place; nothing happens when ``fsdp_config`` is missing or empty.
    """
    fsdp_config = cfg.get("fsdp_config")
    if not fsdp_config:
        return

    if "fsdp_version" in fsdp_config:
        cfg["fsdp_version"] = fsdp_config.pop("fsdp_version")

    prefix = "fsdp_"
    # Iterate over a snapshot of the keys because the mapping is mutated.
    for key in list(fsdp_config):
        if key.startswith(prefix):
            # Slice instead of str.replace so that only the leading prefix is
            # removed; "fsdp_" occurring later in the key must be preserved.
            fsdp_config[key[len(prefix):]] = fsdp_config.pop(key)
66 changes: 0 additions & 66 deletions src/axolotl/utils/schemas/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,72 +1143,6 @@ def check_fsdp_torch_version(cls, data):

return data

@model_validator(mode="before")
@classmethod
def check_fsdp_version(cls, data):
    """Log a deprecation notice when FSDP is configured without version 2.

    Args:
        data: Mapping of config values being validated.

    Returns:
        ``data``, unchanged.
    """
    fsdp_config = data.get("fsdp_config", {})
    # Compare as strings so that both 2 and "2" count as FSDP2.
    if fsdp_config and str(data.get("fsdp_version")) != "2":
        LOG.info(
            # Adjacent literals are concatenated verbatim, so each sentence
            # needs its own trailing space.
            "FSDP1 will be deprecated in an upcoming release of Axolotl. "
            "We recommend that you use FSDP version 2 for better performance and compatibility. "
            "Please see this link for more details: https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp "
            "For more details on migrating your config. "
        )
    return data

@model_validator(mode="before")
@classmethod
def check_fsdp2_base_model_quant_ram_efficient_loading(cls, data):
    """Reject FSDP2 + quantized loading + ``cpu_ram_efficient_loading``.

    Args:
        data: Mapping of config values being validated.

    Returns:
        ``data``, unchanged, when the combination is not present.

    Raises:
        ValueError: If ``fsdp_version`` is 2, ``fsdp_config`` enables
            ``cpu_ram_efficient_loading`` and either ``load_in_8bit`` or
            ``load_in_4bit`` is set.
    """
    fsdp_config = data.get("fsdp_config")
    # Normalise through str() so that a version given as "2" is treated the
    # same as the integer 2, matching the check in check_fsdp_version.
    if fsdp_config and str(data.get("fsdp_version")) == "2":
        if fsdp_config.get("cpu_ram_efficient_loading") and (
            data.get("load_in_8bit") or data.get("load_in_4bit")
        ):
            raise ValueError(
                "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
                "set fsdp_version to 1, or disable cpu_ram_efficient_loading."
            )
    return data

@model_validator(mode="before")
@classmethod
def check_fsdp2_base_model_quant_dpo(cls, data):
    """Reject FSDP2 with 8-bit/4-bit loading for DPO, KTO, ORPO and IPO.

    Args:
        data: Mapping of config values being validated.

    Returns:
        ``data``, unchanged, when the combination is not present.

    Raises:
        ValueError: If ``fsdp_version`` is 2, ``rl`` is one of the listed
            types and ``load_in_8bit`` or ``load_in_4bit`` is set.
    """
    # str() makes the version test agree with check_fsdp_version, which
    # accepts both 2 and "2".
    if str(data.get("fsdp_version")) == "2" and data.get("rl") in [
        RLType.DPO,
        RLType.KTO,
        RLType.ORPO,
        RLType.IPO,
    ]:
        if data.get("load_in_8bit") or data.get("load_in_4bit"):
            # Name the configured RL type; the check is not DPO-specific.
            raise ValueError(
                f"FSDP2 does not support load_in_8bit or load_in_4bit with {data.get('rl')}. Please use DeepSpeed or set `fsdp_version` to 1."
            )

    return data

@model_validator(mode="before")
@classmethod
def check_fsdp_version_in_fsdp_config(cls, data):
    """Warn when ``fsdp_version`` is set inside ``fsdp_config``.

    Args:
        data: Mapping of config values being validated.

    Returns:
        ``data``, unchanged.
    """
    nested = data.get("fsdp_config")
    has_nested_version = bool(nested) and bool(nested.get("fsdp_version"))
    if has_nested_version:
        LOG.warning(
            "Configuring `fsdp_version` in `fsdp_config` is deprecated. "
            "Please configure `fsdp_version` as a top-level field."
        )
    return data

@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
    """Warn when any ``fsdp_config`` key uses the ``fsdp_`` prefix.

    Args:
        data: Mapping of config values being validated.

    Returns:
        ``data``, unchanged.
    """
    fsdp_config = data.get("fsdp_config")
    # The message is identical for every offending key, so report it a single
    # time as soon as one prefixed key is found.
    if fsdp_config and any(key.startswith("fsdp_") for key in fsdp_config):
        LOG.warning_once(
            "Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
            "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
        )
    return data

@model_validator(mode="before")
@classmethod
def default_dataloader_opts(cls, data):
Expand Down
130 changes: 107 additions & 23 deletions src/axolotl/utils/schemas/validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Module with validation methods for config pydantic model."""

# pylint: disable=too-many-lines
# pylint: disable=too-many-lines,too-many-boolean-expressions

import logging

Expand Down Expand Up @@ -748,44 +748,128 @@ def check_xentropy_patch_conflicts(cls, data):

@model_validator(mode="before")
@classmethod
def check_fsdp_offload_w_8bit_optimizer(cls, data):
def check_fsdp_version(cls, data):
    """Log a deprecation notice when FSDP is configured without version 2.

    Args:
        data: Mapping of config values being validated.

    Returns:
        ``data``, unchanged.
    """
    fsdp_config = data.get("fsdp_config", {})
    # Compare as strings so that both 2 and "2" count as FSDP2.
    if fsdp_config and str(data.get("fsdp_version")) != "2":
        LOG.info(
            # Adjacent literals are concatenated verbatim, so each sentence
            # needs its own trailing space.
            "FSDP1 will be deprecated in an upcoming release of Axolotl. "
            "We recommend that you use FSDP version 2 for better performance and compatibility. "
            "Please see this link for more details: https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp "
            "For more details on migrating your config. "
        )
    return data

@model_validator(mode="after")
def check_fsdp2_base_model_quant_ram_efficient_loading(self):
    """Reject FSDP2 + quantized loading + ``cpu_ram_efficient_loading``.

    Attributes that are missing from the instance are treated as ``None``.

    Returns:
        ``self`` when the combination is not present.

    Raises:
        ValueError: If ``fsdp_version`` equals 2, ``fsdp_config`` enables
            ``cpu_ram_efficient_loading`` and ``load_in_8bit`` or
            ``load_in_4bit`` is truthy.
    """
    fsdp_settings = getattr(self, "fsdp_config", None)
    version = getattr(self, "fsdp_version", None)
    eight_bit = getattr(self, "load_in_8bit", None)
    four_bit = getattr(self, "load_in_4bit", None)

    # Guard clause: the restriction only concerns FSDP2 setups.
    if not fsdp_settings or version != 2:
        return self

    if fsdp_settings.get("cpu_ram_efficient_loading") and (eight_bit or four_bit):
        raise ValueError(
            "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
            "set fsdp_version to 1, or disable cpu_ram_efficient_loading."
        )
    return self

@model_validator(mode="before")
@classmethod
def check_fsdp2_base_model_quant_rl(cls, data):
    """Reject FSDP2 with 8-bit/4-bit loading for DPO, KTO, ORPO and IPO.

    Args:
        data: Mapping of config values being validated.

    Returns:
        ``data``, unchanged, when the combination is not present.

    Raises:
        ValueError: If ``fsdp_version`` is 2, ``rl`` is one of the listed
            types and ``load_in_8bit`` or ``load_in_4bit`` is set.
    """
    # str() makes the version test agree with check_fsdp_version, which
    # accepts both 2 and "2".
    if str(data.get("fsdp_version")) == "2" and data.get("rl") in [
        RLType.DPO,
        RLType.KTO,
        RLType.ORPO,
        RLType.IPO,
    ]:
        if data.get("load_in_8bit") or data.get("load_in_4bit"):
            raise ValueError(
                f"FSDP2 does not support load_in_8bit or load_in_4bit with {data.get('rl')}. Please use DeepSpeed or set `fsdp_version` to 1."
            )

    return data

@model_validator(mode="before")
@classmethod
def check_fsdp_version_in_fsdp_config(cls, data):
    """Hoist a nested ``fsdp_version`` to the top level, with a warning.

    When ``fsdp_config`` carries a truthy ``fsdp_version``, a deprecation
    warning is logged, the value is removed from ``fsdp_config`` and stored
    under ``data["fsdp_version"]``.

    Args:
        data: Mapping of config values being validated; may be mutated.

    Returns:
        ``data``.
    """
    nested = data.get("fsdp_config")
    if not nested or not nested.get("fsdp_version"):
        return data

    LOG.warning(
        "Configuring `fsdp_version` in `fsdp_config` is deprecated. "
        "Please configure `fsdp_version` as a top-level field."
    )
    data["fsdp_version"] = nested.pop("fsdp_version")
    return data

@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
    """Warn about ``fsdp_``-prefixed keys in ``fsdp_config`` and rename them.

    If any key in ``fsdp_config`` starts with ``fsdp_``, a deprecation
    warning is logged and ``data["fsdp_config"]`` is replaced by a new dict
    in which that leading prefix is removed from every such key. The
    ``fsdp_version`` key is kept under its own name.

    Args:
        data: Mapping of config values being validated; may be mutated.

    Returns:
        ``data``.
    """
    fsdp_config = data.get("fsdp_config")
    if not fsdp_config:
        return data

    prefix = "fsdp_"
    if not any(key.startswith(prefix) for key in fsdp_config):
        return data

    # The message is identical for every offending key, so report it once.
    LOG.warning_once(
        "Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
        "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
    )

    renamed = {}
    for key, value in fsdp_config.items():
        if key.startswith(prefix) and key != "fsdp_version":
            # Slice instead of str.replace so that only the leading prefix is
            # dropped; "fsdp_" occurring later in the key is preserved.
            renamed[key[len(prefix):]] = value
        else:
            renamed[key] = value
    data["fsdp_config"] = renamed
    return data

@model_validator(mode="after")
def check_fsdp_offload_w_8bit_optimizer(self):
    """Reject parameter offload with an 8-bit optimizer outside FSDP2.

    Returns:
        ``self`` when the combination is not present.

    Raises:
        ValueError: If ``fsdp_config`` is set with a truthy
            ``offload_params``, the optimizer's value contains ``"8bit"``
            and ``fsdp_version`` is not 2.
    """
    if (
        hasattr(self, "fsdp_config")
        and self.fsdp_config
        and self.optimizer
        and "8bit" in self.optimizer.value
        # .get() rather than indexing: "offload_params" is optional and a
        # plain dict would raise KeyError when it is missing.
        and self.fsdp_config.get("offload_params")
        and str(self.fsdp_version) != "2"
    ):
        raise ValueError(
            f"FSDP Offload not compatible with {self.optimizer.value}"
        )
    return self

@model_validator(mode="after")
def check_fsdp2_w_8bit_optimizer(self):
    """Reject the bitsandbytes 8-bit AdamW optimizers under FSDP2.

    Returns:
        ``self`` when the combination is not present.

    Raises:
        ValueError: If ``fsdp_config`` is set, ``fsdp_version`` is 2 and the
            optimizer's value is ``adamw_8bit`` or ``adamw_bnb_8bit``.
    """
    if (
        hasattr(self, "fsdp_config")
        and self.fsdp_config
        and self.optimizer
        and "8bit" in self.optimizer.value
        and str(self.fsdp_version) == "2"
    ):
        # Compare the optimizer's value, as the surrounding checks do, so the
        # match does not depend on the enum also comparing equal to a str.
        if self.optimizer.value in ("adamw_8bit", "adamw_bnb_8bit"):
            # CUDA ops errors with bnb 8bit optimizer + FSDP2
            raise ValueError(
                f"FSDP2 not compatible with {self.optimizer.value}, use `adamw_torch_8bit` instead"
            )

    return self

@model_validator(mode="before")
@classmethod
def check_fsdp_sharded_state_dict_w_safetensors(cls, data):
@model_validator(mode="after")
def check_fsdp_sharded_state_dict_w_safetensors(self):
    """Reject ``SHARDED_STATE_DICT`` together with ``save_safetensors``.

    Attributes that are missing from the instance are treated as unset.

    Returns:
        ``self`` when the combination is not present.

    Raises:
        ValueError: If ``save_safetensors`` is truthy and
            ``fsdp_config["state_dict_type"]`` is ``"SHARDED_STATE_DICT"``.
    """
    fsdp_config = getattr(self, "fsdp_config", None)
    if (
        fsdp_config
        and getattr(self, "save_safetensors", None)
        and fsdp_config.get("state_dict_type", "") == "SHARDED_STATE_DICT"
    ):
        raise ValueError(
            "FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
        )
    return self


class SystemValidationMixin:
Expand Down
30 changes: 18 additions & 12 deletions tests/test_normalize_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from unittest.mock import patch

from axolotl.utils.config import (
migrate_fsdp_config,
normalize_cfg_datasets,
normalize_config,
validate_config,
)
from axolotl.utils.dict import DictDefault

Expand All @@ -27,6 +27,13 @@ def _get_base_cfg(self):
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"learning_rate": 0.0001,
}
)

Expand Down Expand Up @@ -97,7 +104,7 @@ def test_bf16_disables_fp16(self, mock_bf16_avail):

def test_migrate_fsdp_config(self):
"""Test basic FSDP config migration with and without fsdp_version"""
cfg_with_version = DictDefault(
cfg_with_version = self._get_base_cfg() | DictDefault(
{
"fsdp_config": {
"fsdp_version": 2,
Expand All @@ -109,7 +116,7 @@ def test_migrate_fsdp_config(self):
}
)

migrate_fsdp_config(cfg_with_version)
cfg_with_version = validate_config(cfg_with_version)

self.assertEqual(cfg_with_version.fsdp_version, 2)
self.assertEqual(
Expand All @@ -125,7 +132,7 @@ def test_migrate_fsdp_config(self):
self.assertNotIn("fsdp_version", cfg_with_version.fsdp_config)
self.assertNotIn("version", cfg_with_version.fsdp_config)

cfg_without_version = DictDefault(
cfg_without_version = self._get_base_cfg() | DictDefault(
{
"fsdp_config": {
"fsdp_auto_wrap_policy": "SIZE_BASED_WRAP",
Expand All @@ -135,7 +142,7 @@ def test_migrate_fsdp_config(self):
}
)

migrate_fsdp_config(cfg_without_version)
cfg_without_version = validate_config(cfg_without_version)

self.assertNotIn("fsdp_version", cfg_without_version)
self.assertEqual(
Expand All @@ -149,26 +156,25 @@ def test_migrate_fsdp_config(self):

def test_migrate_fsdp_config_no_fsdp_config(self):
    """Validating a config without fsdp_config succeeds and adds no FSDP keys."""
    cfg = self._get_base_cfg()

    cfg = validate_config(cfg)

    self.assertNotIn("fsdp_config", cfg)
    self.assertNotIn("fsdp_version", cfg)

def test_migrate_fsdp_config_empty_fsdp_config(self):
    """Validating an empty fsdp_config leaves it empty and sets no fsdp_version."""
    cfg = self._get_base_cfg() | DictDefault({"fsdp_config": {}})

    cfg = validate_config(cfg)

    self.assertNotIn("fsdp_version", cfg)
    self.assertEqual(cfg.fsdp_config, {})

def test_migrate_fsdp_config_mixed_keys(self):
"""Test migration with a mix of fsdp_ and non-fsdp_ keys"""
cfg = DictDefault(
cfg = self._get_base_cfg() | DictDefault(
{
"fsdp_config": {
"fsdp_version": 1,
Expand All @@ -180,7 +186,7 @@ def test_migrate_fsdp_config_mixed_keys(self):
}
)

migrate_fsdp_config(cfg)
cfg = validate_config(cfg)

self.assertEqual(cfg.fsdp_version, 1)
self.assertEqual(cfg.fsdp_config.state_dict_type, "FULL_STATE_DICT")
Expand Down
Loading
Loading