From 8e93e79c4534b7adb7e48eb31ac6116bcbf814ab Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 1 Jul 2025 21:59:31 +0000 Subject: [PATCH 01/28] debug --- src/axolotl/core/builders/base.py | 8 ++++++++ src/axolotl/core/builders/causal.py | 5 +++++ src/axolotl/utils/config/__init__.py | 7 ++++++- 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index d3a3b32424..178a704026 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -218,6 +218,14 @@ def _configure_warmup_and_logging( training_args_kwargs["warmup_steps"] = warmup_steps def _configure_precision_settings(self, training_args_kwargs: dict): + import torch.distributed as dist + + if dist.get_rank() == 0: + import ipdb + + ipdb.set_trace() + dist.barrier() + training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False training_args_kwargs["tf32"] = self.cfg.tf32 if self.cfg.bf16 == "full": diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 00cee35a72..9fcd51c1d5 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -310,6 +310,11 @@ def build(self, total_num_steps): self.cfg.neftune_noise_alpha ) + if self.cfg.accelerator_config: + training_arguments_kwargs["accelerator_config"] = ( + self.cfg.accelerator_config + ) + if self.cfg.image_size: training_arguments_kwargs["image_size"] = self.cfg.image_size if self.cfg.image_resize_algorithm: diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index c9613c39b1..fc24b351c7 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -59,8 +59,13 @@ def get_device(): def resolve_dtype(cfg): + print(cfg.bf16) + if ( - not cfg.fp16 and cfg.bf16 == "auto" and not cfg.use_ray + not cfg.mixed_precision == "fp8" + or not cfg.fp16 + and cfg.bf16 == "auto" + and not cfg.use_ray ): # if we use ray we want to defer this check to the worker node if is_torch_bf16_gpu_available(): LOG.debug("bf16 support detected, enabling for this configuration.") From 520fa821e9bf7ba988dfe694780b502eacd05b73 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 3 Jul 2025 00:48:07 +0000 Subject: [PATCH 02/28] debug --- src/axolotl/core/builders/base.py | 8 -------- src/axolotl/loaders/model.py | 6 ++++++ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 178a704026..d3a3b32424 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -218,14 +218,6 @@ def _configure_warmup_and_logging( training_args_kwargs["warmup_steps"] = warmup_steps def _configure_precision_settings(self, training_args_kwargs: dict): - import torch.distributed as dist - - if dist.get_rank() == 0: - import ipdb - - ipdb.set_trace() - dist.barrier() - training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False training_args_kwargs["tf32"] = self.cfg.tf32 if self.cfg.bf16 == "full": diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 1ce98ef319..b1e85d5527 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -620,6 +620,12 @@ def _configure_zero3_memory_efficient_loading( def _build_model(self) -> bool: """Load model, with load strategy depending on config.""" + import torch.distributed as dist + if dist.get_rank() == 0: + import ipdb + ipdb.set_trace() + dist.barrier() + skip_move_to_device = False if self.is_fsdp_enabled: if self.cfg.fsdp_config.cpu_ram_efficient_loading: From 77705b378292e0be80bad5cb78e9d5711ddb1843 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 3 Jul 2025 02:36:14 +0000 Subject: [PATCH 03/28] debug --- src/axolotl/loaders/model.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index b1e85d5527..1ce98ef319 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -620,12 +620,6 @@ def _configure_zero3_memory_efficient_loading( def _build_model(self) -> bool: """Load model, with load strategy depending on config.""" - import torch.distributed as dist - if dist.get_rank() == 0: - import ipdb - ipdb.set_trace() - dist.barrier() - skip_move_to_device = False if self.is_fsdp_enabled: if self.cfg.fsdp_config.cpu_ram_efficient_loading: From abdb95fcda60adb2a3e0ea10ac675d4c73f542ad Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 10 Jul 2025 20:59:01 +0000 Subject: [PATCH 04/28] revert unneeded change --- src/axolotl/utils/config/__init__.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index fc24b351c7..c9613c39b1 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -59,13 +59,8 @@ def get_device(): def resolve_dtype(cfg): - print(cfg.bf16) - if ( - not cfg.mixed_precision == "fp8" - or not cfg.fp16 - and cfg.bf16 == "auto" - and not cfg.use_ray + not cfg.fp16 and cfg.bf16 == "auto" and not cfg.use_ray ): # if we use ray we want to defer this check to the worker node if is_torch_bf16_gpu_available(): LOG.debug("bf16 support detected, enabling for this configuration.") From 265d9cd9be3ee9c3eb786c35aa5c046d8fa49761 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 11 Jul 2025 20:30:08 +0000 Subject: [PATCH 05/28] add accelerator config to base trainer builder --- src/axolotl/core/builders/base.py | 4 +++- src/axolotl/core/builders/causal.py | 5 ----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index d3a3b32424..39e09ad448 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -435,7 +435,9 @@ def _configure_torch_compile(self, training_args_kwargs: dict): def _configure_accelerator_config(self, training_args_kwargs: dict): if self.cfg.accelerator_config: - training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config + training_args_kwargs["accelerator_config"] = ( + self.cfg.accelerator_config + ) def _configure_gradient_checkpointing(self, training_args_kwargs: dict): if self.cfg.activation_offloading is True: diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 9fcd51c1d5..00cee35a72 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -310,11 +310,6 @@ def build(self, total_num_steps): self.cfg.neftune_noise_alpha ) - if self.cfg.accelerator_config: - training_arguments_kwargs["accelerator_config"] = ( - self.cfg.accelerator_config - ) - if self.cfg.image_size: training_arguments_kwargs["image_size"] = self.cfg.image_size if self.cfg.image_resize_algorithm: From 8afdd2f28a32302abb054e41510f13de37bb8f28 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 11 Jul 2025 20:32:18 +0000 Subject: [PATCH 06/28] add back accumulated_cache_size_limit setting --- src/axolotl/core/trainers/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index b983f10765..00b8f1495e 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -77,6 +77,7 @@ def __init__( if self.args.orpo_alpha: self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + def _create_multipack_sampler( self, base_sampler: Sampler, dataset: Dataset ) -> MultipackBatchSampler: From 1f9f366fe9baf3e9d9c6f2ab2cea01a6d858dad9 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 14 Jul 2025 15:50:43 +0000 Subject: [PATCH 07/28] lint --- src/axolotl/core/builders/base.py | 4 +--- src/axolotl/core/trainers/base.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 39e09ad448..d3a3b32424 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -435,9 +435,7 @@ def _configure_torch_compile(self, training_args_kwargs: dict): def _configure_accelerator_config(self, training_args_kwargs: dict): if self.cfg.accelerator_config: - training_args_kwargs["accelerator_config"] = ( - self.cfg.accelerator_config - ) + training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config def _configure_gradient_checkpointing(self, training_args_kwargs: dict): if self.cfg.activation_offloading is True: diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 00b8f1495e..b983f10765 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -77,7 +77,6 @@ def __init__( if self.args.orpo_alpha: self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") - def _create_multipack_sampler( self, base_sampler: Sampler, dataset: Dataset ) -> MultipackBatchSampler: From 950c7345b1a775c44ffa2106864be840bd0a5887 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 14 Jul 2025 20:31:00 +0000 Subject: [PATCH 08/28] accelerator constructor patch for single-GPU torch fp8 --- src/axolotl/loaders/patch_manager.py | 4 + .../accelerate_torchao_fsdp_check.py | 78 +++++++++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 src/axolotl/monkeypatch/accelerate_torchao_fsdp_check.py diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index f346c56e04..e0244bc16d 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -153,8 +153,12 @@ def _apply_fp8_patches(self): from axolotl.monkeypatch.trainer_accelerator_args import ( patch_create_accelerate_code_for_fp8, ) + from axolotl.monkeypatch.accelerate_torchao_fsdp_check import ( + patch_accelerator_constructor_code_for_fp8 + ) patch_create_accelerate_code_for_fp8() + patch_accelerator_constructor_code_for_fp8() def _apply_flash_attention_peft_patches(self): """Apply patches for Flash Attention with PEFT.""" diff --git a/src/axolotl/monkeypatch/accelerate_torchao_fsdp_check.py b/src/axolotl/monkeypatch/accelerate_torchao_fsdp_check.py new file mode 100644 index 0000000000..ed468794c2 --- /dev/null +++ b/src/axolotl/monkeypatch/accelerate_torchao_fsdp_check.py @@ -0,0 +1,78 @@ +""" +Fix check in accelerate.Accelerator constructor logic. + +This can be removed if / when this PR lands in a release: +https://github.com/huggingface/accelerate/pull/3677. +""" + +import inspect + +from accelerate.accelerator import Accelerator + +from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +ORIGINAL_ACCELERATE_CODE = """ + if self.fp8_backend == "AO" and self.state.fsdp_plugin.cpu_ram_efficient_loading: +""" + +PATCHED_ACCELERATE_CODE = """ + if self.fp8_backend == "AO" and hasattr(self.state, "fsdp_plugin") and self.state.fsdp_plugin.cpu_ram_efficient_loading: +""" + + + +def get_accelerator_constructor_code() -> str: + constructor = inspect.getsource(Accelerator.__init__) + return constructor + + +def check_accelerator_constructor_code_is_patchable() -> bool: + constructor_code = get_accelerator_constructor_code() + constructor_code, _ = detab_code(constructor_code) + return ORIGINAL_ACCELERATE_CODE in constructor_code + + +def patch_accelerator_constructor_code_for_fp8(): + """ + Monkeypatch for Accelerator constructor so torchao fp8 training works outside of + FSDP training. + """ + try: + constructor_code = get_accelerator_constructor_code() + except OSError: + return + + Accelerator._original__init__ = ( # pylint: disable=protected-access + constructor_code + ) + constructor_code, _ = detab_code(constructor_code) + if ORIGINAL_ACCELERATE_CODE not in constructor_code: + return + + constructor_code = constructor_code.replace(ORIGINAL_ACCELERATE_CODE, PATCHED_ACCELERATE_CODE) + constructor_code = constructor_code.replace( + "def __init__(", + "def _patched__init__(", + 1, + ) + + # load imports necessary + import accelerate.accelerator + + items_to_import = [] + for item in dir(accelerate.accelerator): + if item in constructor_code: + items_to_import.append(item) + + exec( # pylint: disable=exec-used # nosec B102 + "from accelerate.accelerator import (" + + ", ".join(x for x in items_to_import) + + ")", + globals(), + ) + exec(constructor_code, globals()) + Accelerator.__init__ = _patched__init__ # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821 + LOG.info("patched Accelerator.__init__ to fix torchao + FSDP guard") From d87b8b6a5d6f147126457b2f83ea87d462583b08 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 14 Jul 2025 20:31:43 +0000 Subject: [PATCH 09/28] lint --- src/axolotl/loaders/patch_manager.py | 6 +++--- src/axolotl/monkeypatch/accelerate_torchao_fsdp_check.py | 9 ++++----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index e0244bc16d..662e1b7406 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -150,12 +150,12 @@ def _apply_model_specific_patches(self): def _apply_fp8_patches(self): """Apply patches for FP8 support.""" if self.cfg.fp8: + from axolotl.monkeypatch.accelerate_torchao_fsdp_check import ( + patch_accelerator_constructor_code_for_fp8, + ) from axolotl.monkeypatch.trainer_accelerator_args import ( patch_create_accelerate_code_for_fp8, ) - from axolotl.monkeypatch.accelerate_torchao_fsdp_check import ( - patch_accelerator_constructor_code_for_fp8 - ) patch_create_accelerate_code_for_fp8() patch_accelerator_constructor_code_for_fp8() diff --git a/src/axolotl/monkeypatch/accelerate_torchao_fsdp_check.py b/src/axolotl/monkeypatch/accelerate_torchao_fsdp_check.py index ed468794c2..2d3ba1dfe1 100644 --- a/src/axolotl/monkeypatch/accelerate_torchao_fsdp_check.py +++ b/src/axolotl/monkeypatch/accelerate_torchao_fsdp_check.py @@ -23,7 +23,6 @@ """ - def get_accelerator_constructor_code() -> str: constructor = inspect.getsource(Accelerator.__init__) return constructor @@ -45,14 +44,14 @@ def patch_accelerator_constructor_code_for_fp8(): except OSError: return - Accelerator._original__init__ = ( # pylint: disable=protected-access - constructor_code - ) + Accelerator._original__init__ = constructor_code # pylint: disable=protected-access constructor_code, _ = detab_code(constructor_code) if ORIGINAL_ACCELERATE_CODE not in constructor_code: return - constructor_code = constructor_code.replace(ORIGINAL_ACCELERATE_CODE, PATCHED_ACCELERATE_CODE) + constructor_code = constructor_code.replace( + ORIGINAL_ACCELERATE_CODE, PATCHED_ACCELERATE_CODE + ) constructor_code = constructor_code.replace( "def __init__(", "def _patched__init__(", From e467ace2c735940aad3948d22ef5c213ed386c8f Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 14 Jul 2025 21:36:23 +0000 Subject: [PATCH 10/28] re-using existing fp8 code --- src/axolotl/core/trainers/base.py | 9 +++++++-- src/axolotl/loaders/patch_manager.py | 4 +++- src/axolotl/monkeypatch/trainer_accelerator_args.py | 9 +++++---- src/axolotl/utils/schemas/config.py | 1 + 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index b983f10765..d14b5ee0db 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -523,14 +523,19 @@ def create_accelerator_and_postprocess(self): return res def additional_accelerator_args( - self, fp8=None, **kwargs + self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs ): # pylint: disable=unused-argument ret_kwargs = {} if fp8: from accelerate.utils import AORecipeKwargs + from torchao.float8 import Float8LinearConfig + + config = Float8LinearConfig( + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather + ) ret_kwargs["mixed_precision"] = "fp8" - ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()] + ret_kwargs["kwargs_handlers"] = [AORecipeKwargs(config=config)] os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8" return ret_kwargs diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 662e1b7406..bf4321b949 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -157,7 +157,9 @@ def _apply_fp8_patches(self): patch_create_accelerate_code_for_fp8, ) - patch_create_accelerate_code_for_fp8() + patch_create_accelerate_code_for_fp8( + self.cfg.fp8_enable_fsdp_float8_all_gather + ) patch_accelerator_constructor_code_for_fp8() def _apply_flash_attention_peft_patches(self): diff --git a/src/axolotl/monkeypatch/trainer_accelerator_args.py b/src/axolotl/monkeypatch/trainer_accelerator_args.py index 0a5b27c13e..feccecf059 100644 --- a/src/axolotl/monkeypatch/trainer_accelerator_args.py +++ b/src/axolotl/monkeypatch/trainer_accelerator_args.py @@ -18,7 +18,7 @@ PATCHED_TRAINER_CODE = """ if hasattr(self, "additional_accelerator_args"): - additional_args = self.additional_accelerator_args(fp8=True, **args) + additional_args = self.additional_accelerator_args(fp8=True, enable_fsdp_float8_all_gather={}, **args) if additional_args: args.update(additional_args) @@ -38,9 +38,9 @@ def check_create_accelerate_code_is_patchable() -> bool: return ORIGINAL_TRAINER_CODE in create_code -def patch_create_accelerate_code_for_fp8(): +def patch_create_accelerate_code_for_fp8(enable_fsdp_float8_all_gather: bool): """ - monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs + Monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs. """ try: @@ -54,7 +54,8 @@ def patch_create_accelerate_code_for_fp8(): if ORIGINAL_TRAINER_CODE not in create_code: return - create_code = create_code.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE) + patched_trainer_code = PATCHED_TRAINER_CODE.format(enable_fsdp_float8_all_gather) + create_code = create_code.replace(ORIGINAL_TRAINER_CODE, patched_trainer_code) create_code = create_code.replace( "def create_accelerator_and_postprocess(", "def fixed_create_accelerator_and_postprocess(", diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index de928d11c5..04a7a1f85e 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -344,6 +344,7 @@ class AxolotlInputConfig( default=None, json_schema_extra={"description": "Use CUDA fp16"} ) fp8: bool | None = None + fp8_enable_fsdp_float8_all_gather: bool | None = None bfloat16: bool | None = Field( default=None, json_schema_extra={ From 73d993f2de2e643d3289f4c4419a5e951925b656 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 15 Jul 2025 19:11:29 +0000 Subject: [PATCH 11/28] lint --- src/axolotl/core/trainers/base.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index d14b5ee0db..b444120d5f 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -7,7 +7,7 @@ import os from collections import defaultdict from functools import partial, wraps -from typing import Callable, Literal, Optional +from typing import Any, Callable, Literal, Optional import datasets import torch @@ -522,20 +522,24 @@ def create_accelerator_and_postprocess(self): return res + # pylint: disable=unused-argument def additional_accelerator_args( self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs - ): # pylint: disable=unused-argument + ) -> dict[str, Any]: ret_kwargs = {} if fp8: from accelerate.utils import AORecipeKwargs from torchao.float8 import Float8LinearConfig + # By default, Float8LinearConfig is instantiated using the "tensorwise" + # scaling strategy. See more details here: + # https://github.com/pytorch/ao/tree/main/torchao/float8. config = Float8LinearConfig( enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather ) ret_kwargs["mixed_precision"] = "fp8" - ret_kwargs["kwargs_handlers"] = [AORecipeKwargs(config=config)] + ret_kwargs["kwargs_handlers"] = [AORecipeKwargs(config=config)] # type: ignore os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8" return ret_kwargs From d53f9a79e20d313ce710a10f327511c7ef622bc5 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 17 Jul 2025 19:09:58 +0000 Subject: [PATCH 12/28] remove accelerate patch now fix in latest release --- src/axolotl/loaders/patch_manager.py | 4 - .../accelerate_torchao_fsdp_check.py | 77 ------------------- 2 files changed, 81 deletions(-) delete mode 100644 src/axolotl/monkeypatch/accelerate_torchao_fsdp_check.py diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index bf4321b949..533bd0f7a7 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -150,9 +150,6 @@ def _apply_model_specific_patches(self): def _apply_fp8_patches(self): """Apply patches for FP8 support.""" if self.cfg.fp8: - from axolotl.monkeypatch.accelerate_torchao_fsdp_check import ( - patch_accelerator_constructor_code_for_fp8, - ) from axolotl.monkeypatch.trainer_accelerator_args import ( patch_create_accelerate_code_for_fp8, ) @@ -160,7 +157,6 @@ def _apply_fp8_patches(self): patch_create_accelerate_code_for_fp8( self.cfg.fp8_enable_fsdp_float8_all_gather ) - patch_accelerator_constructor_code_for_fp8() def _apply_flash_attention_peft_patches(self): """Apply patches for Flash Attention with PEFT.""" diff --git a/src/axolotl/monkeypatch/accelerate_torchao_fsdp_check.py b/src/axolotl/monkeypatch/accelerate_torchao_fsdp_check.py deleted file mode 100644 index 2d3ba1dfe1..0000000000 --- a/src/axolotl/monkeypatch/accelerate_torchao_fsdp_check.py +++ /dev/null @@ -1,77 +0,0 @@ -""" -Fix check in accelerate.Accelerator constructor logic. - -This can be removed if / when this PR lands in a release: -https://github.com/huggingface/accelerate/pull/3677. -""" - -import inspect - -from accelerate.accelerator import Accelerator - -from axolotl.monkeypatch.utils import detab_code -from axolotl.utils.logging import get_logger - -LOG = get_logger(__name__) - -ORIGINAL_ACCELERATE_CODE = """ - if self.fp8_backend == "AO" and self.state.fsdp_plugin.cpu_ram_efficient_loading: -""" - -PATCHED_ACCELERATE_CODE = """ - if self.fp8_backend == "AO" and hasattr(self.state, "fsdp_plugin") and self.state.fsdp_plugin.cpu_ram_efficient_loading: -""" - - -def get_accelerator_constructor_code() -> str: - constructor = inspect.getsource(Accelerator.__init__) - return constructor - - -def check_accelerator_constructor_code_is_patchable() -> bool: - constructor_code = get_accelerator_constructor_code() - constructor_code, _ = detab_code(constructor_code) - return ORIGINAL_ACCELERATE_CODE in constructor_code - - -def patch_accelerator_constructor_code_for_fp8(): - """ - Monkeypatch for Accelerator constructor so torchao fp8 training works outside of - FSDP training. - """ - try: - constructor_code = get_accelerator_constructor_code() - except OSError: - return - - Accelerator._original__init__ = constructor_code # pylint: disable=protected-access - constructor_code, _ = detab_code(constructor_code) - if ORIGINAL_ACCELERATE_CODE not in constructor_code: - return - - constructor_code = constructor_code.replace( - ORIGINAL_ACCELERATE_CODE, PATCHED_ACCELERATE_CODE - ) - constructor_code = constructor_code.replace( - "def __init__(", - "def _patched__init__(", - 1, - ) - - # load imports necessary - import accelerate.accelerator - - items_to_import = [] - for item in dir(accelerate.accelerator): - if item in constructor_code: - items_to_import.append(item) - - exec( # pylint: disable=exec-used # nosec B102 - "from accelerate.accelerator import (" - + ", ".join(x for x in items_to_import) - + ")", - globals(), - ) - exec(constructor_code, globals()) - Accelerator.__init__ = _patched__init__ # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821 - LOG.info("patched Accelerator.__init__ to fix torchao + FSDP guard") From 4d0a4feea2629b8070f375e03b8b1da8a4f1c6d9 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 17 Jul 2025 21:17:00 +0000 Subject: [PATCH 13/28] fix --- src/axolotl/monkeypatch/trainer_accelerator_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/monkeypatch/trainer_accelerator_args.py b/src/axolotl/monkeypatch/trainer_accelerator_args.py index feccecf059..92daa6b15d 100644 --- a/src/axolotl/monkeypatch/trainer_accelerator_args.py +++ b/src/axolotl/monkeypatch/trainer_accelerator_args.py @@ -18,7 +18,7 @@ PATCHED_TRAINER_CODE = """ if hasattr(self, "additional_accelerator_args"): - additional_args = self.additional_accelerator_args(fp8=True, enable_fsdp_float8_all_gather={}, **args) + additional_args = self.additional_accelerator_args(fp8=True, enable_fsdp_float8_all_gather=False, **args) if additional_args: args.update(additional_args) From 1220c0227f9bd021ddc262cf2fb45f0388afb766 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 21 Jul 2025 13:29:33 -0400 Subject: [PATCH 14/28] docs --- _quarto.yml | 2 + docs/mixed_precision.qmd | 147 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 docs/mixed_precision.qmd diff --git a/_quarto.yml b/_quarto.yml index 3e773a748f..dab1ee363b 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -268,6 +268,8 @@ website: - docs/batch_vs_grad.qmd - docs/dataset_preprocessing.qmd - docs/multipack.qmd + - docs/mixed_precision.qmd + - docs/gradient_accumulation.qmd - section: "Advanced Features" contents: diff --git a/docs/mixed_precision.qmd b/docs/mixed_precision.qmd new file mode 100644 index 0000000000..279ddbadd9 --- /dev/null +++ b/docs/mixed_precision.qmd @@ -0,0 +1,147 @@ +--- +title: "Mixed Precision Training" +format: + html: + toc: true + toc-depth: 3 + number-sections: true + code-tools: true +execute: + enabled: false +--- + +Mixed precision training uses lower precision data types to reduce memory usage and increase training speed while maintaining model quality. Axolotl supports several mixed precision formats: + +- **FP16** - Half precision 16-bit (Pascal generation+) +- **BF16** - Brain Float 16-bit (Ampere generation+) +- **FP8** - 8-bit floating point (Hopper generation+) + +## FP16 Mixed Precision {#sec-fp16} + +### Overview {#sec-fp16-overview} + +FP16 is the traditional half-precision format, supported on older GPUs but can be less numerically stable than BF16. + +### Configuration {#sec-fp16-config} + +```{.yaml} +fp16: true +``` + +### FP16 Considerations {#sec-fp16-considerations} + +- May require gradient scaling to prevent underflow +- Less numerically stable than BF16 +- Can cause training instability with some model architectures +- Consider using BF16 if your hardware supports it + +## BF16 Mixed Precision {#sec-bf16} + +### Overview {#sec-bf16-overview} + +BF16 (Brain Float 16) offers better numerical stability than FP16 and is the recommended mixed precision format for modern GPUs. It provides the same dynamic range as FP32 while using half the memory. + +### Configuration {#sec-bf16-config} + +```{.yaml} +# Automatic BF16 detection (recommended) +bf16: auto + +# Or explicitly enable +bf16: true + +# For evaluation with BF16 +bf16: full # Equivalent to bf16_full_eval in the HF trainer +``` + +## FP8 Mixed Precision {#sec-fp8} + +::: {.callout-note} +FP8 support is experimental and requires compatible hardware (H100, H200) and recent PyTorch versions with TorchAO. +::: + +### What is FP8? {#sec-fp8-overview} + +FP8 (8-bit floating point) can provide significant time savings compared to FP16/BF16 while maintaining training stability. Axolotl's implementation uses PyTorch's TorchAO library with "tensorwise" scaling strategy. + +### Requirements {#sec-fp8-software} + +- Hopper+ GPUs (H100/H200) +- PyTorch 2.7+ (+ compatible TorchAO version) +- CUDA 12.4+ + +### Configuration {#sec-fp8-config} + +Add to your YAML config: + +```{.yaml} +# Enable FP8 mixed precision +fp8: true + +# Optional: Enable FP8 for FSDP all-gather operations +fp8_enable_fsdp_float8_all_gather: true + +# Enable torch.compile (almost always necessary for FP8 speedups) +torch_compile: true +``` + +::: {.callout-important} +**torch.compile is critical for FP8 performance** + +FP8 training requires `torch_compile: true` to see meaningful speedups. Without compilation, FP8 may actually be slower and use more memory than FP16/BF16. +::: + +### Advanced FP8 Configs {#sec-fp8-advanced} + +For FSDP (Fully Sharded Data Parallel) training: + +```{.yaml} +fp8: true +fp8_enable_fsdp_float8_all_gather: true + +torch_compile: true + +# FSDP configuration +fsdp_version: 2 +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: LlamaDecoderLayer + state_dict_type: FULL_STATE_DICT + reshard_after_forward: true +``` + +## Best Practices {#sec-best-practices} + +### Choosing Precision Format {#sec-choosing-format} + +- **Start with automatic detection**: `bf16: auto` +- **For Hopper+ (H100/H200)**: Try FP8 + torch.compile for maximum speed +- **For Ampere (A100/RTX 30/40)**: Use BF16 +- **For older Pascal/Turing GPUs**: Use FP16 with caution +- **For very old or unsupported GPUs**: Use FP32 + +### Validation and Testing {#sec-validation} + +Always validate your mixed precision setup: + +- **Start with a small dataset** to verify stability +- **Monitor loss curves** for irregularities +- **Compare with FP32 baseline** when possible +- **Test evaluation metrics** match expectations + +### FP8 Particulars (#sec-fp8-details) + +- Use cases + - Single GPU training + - Multi GPU training with FSDP2 or Deepspeed +- Speedups + - Please refer to the [TorchAO FP8 training benchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling) for expected matmul speedups for different (M, K, N) settings + - Concrete number for LLaMA 3 8B can be found [here](https://github.com/pytorch/ao/tree/main/torchao/float8#training-benchmarks) +- Known issues: + - FP8 + DDP + `torch.compile` (causes [error](https://gist.github.com/djsaunde/0c1664c32e44a64d31b5e01b4aafe5c4)) + - FP8 + FSDP2 + `torch.compile` + FSDP2 activation checkpointing is _slower_ than the BF16 equivalent training + - Flash Attention 2 does not play nicely with `torch.compile` + +For more information on multi-GPU training, see our [Multi-GPU guide](multi-gpu.qmd). From de54dc747eb3738fe0b43a1afce93a2377f91c4e Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 21 Jul 2025 17:37:25 +0000 Subject: [PATCH 15/28] add fp8 + fsdp2 example --- docs/mixed_precision.qmd | 2 + examples/llama-3/3b-fp8-fsdp2.yaml | 81 ++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 examples/llama-3/3b-fp8-fsdp2.yaml diff --git a/docs/mixed_precision.qmd b/docs/mixed_precision.qmd index 279ddbadd9..063149a9b4 100644 --- a/docs/mixed_precision.qmd +++ b/docs/mixed_precision.qmd @@ -144,4 +144,6 @@ Always validate your mixed precision setup: - FP8 + FSDP2 + `torch.compile` + FSDP2 activation checkpointing is _slower_ than the BF16 equivalent training - Flash Attention 2 does not play nicely with `torch.compile` +See `examples/llama-3/3b-fp8-fsdp2.yaml` for an optimized example config. Enabling FP8 training results in ~50% faster iterations per second vs. BF16. + For more information on multi-GPU training, see our [Multi-GPU guide](multi-gpu.qmd). diff --git a/examples/llama-3/3b-fp8-fsdp2.yaml b/examples/llama-3/3b-fp8-fsdp2.yaml new file mode 100644 index 0000000000..f50208ebf9 --- /dev/null +++ b/examples/llama-3/3b-fp8-fsdp2.yaml @@ -0,0 +1,81 @@ +base_model: meta-llama/Llama-3.2-3B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true + +datasets: + - path: yahma/alpaca-cleaned + type: alpaca + +output_dir: ./outputs/fp8_out/ + +sample_packing: true +pad_to_sequence_len: true +sequence_len: 512 + +flex_attention: true +flex_attn_compile_kwargs: + dynamic: false + mode: max-autotune-no-cudagraphs + +torch_compile: true + +qat: + activation_dtype: int8 + weight_dtype: int4 + group_size: 32 + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 16 +num_epochs: 1 +optimizer: adamw_torch_fused + +cosine_constant_lr_ratio: 0 +cosine_min_lr_ratio: 1.0 +learning_rate: 2e-5 +save_only_model: true + +fp8: true +fp8_enable_fsdp_float8_all_gather: true + +resume_from_checkpoint: +logging_steps: 1 + +evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_steps: 10 +weight_decay: 0.0 + +fsdp_config: + fsdp_version: 2 + fsdp_offload_params: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD + fsdp_reshard_after_forward: true + fsdp_activation_checkpointing: true + +special_tokens: + pad_token: <|end_of_text|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config From 015d9de34efd04d62774b03a59fdcfbd281e6b6d Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 21 Jul 2025 17:38:28 +0000 Subject: [PATCH 16/28] remove unused config --- examples/llama-3/3b-fp8-fsdp2.yaml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/examples/llama-3/3b-fp8-fsdp2.yaml b/examples/llama-3/3b-fp8-fsdp2.yaml index f50208ebf9..cc62091e3f 100644 --- a/examples/llama-3/3b-fp8-fsdp2.yaml +++ b/examples/llama-3/3b-fp8-fsdp2.yaml @@ -32,11 +32,6 @@ flex_attn_compile_kwargs: torch_compile: true -qat: - activation_dtype: int8 - weight_dtype: int4 - group_size: 32 - wandb_project: wandb_entity: wandb_watch: From d33b4e45a1d176f62b116d18bce1520c2f346ac9 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 21 Jul 2025 17:45:39 +0000 Subject: [PATCH 17/28] update config --- docs/mixed_precision.qmd | 2 +- examples/llama-3/3b-fp8-fsdp2.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/mixed_precision.qmd b/docs/mixed_precision.qmd index 063149a9b4..82c6784438 100644 --- a/docs/mixed_precision.qmd +++ b/docs/mixed_precision.qmd @@ -144,6 +144,6 @@ Always validate your mixed precision setup: - FP8 + FSDP2 + `torch.compile` + FSDP2 activation checkpointing is _slower_ than the BF16 equivalent training - Flash Attention 2 does not play nicely with `torch.compile` -See `examples/llama-3/3b-fp8-fsdp2.yaml` for an optimized example config. Enabling FP8 training results in ~50% faster iterations per second vs. BF16. +See `examples/llama-3/3b-fp8-fsdp2.yaml` for an optimized example config. Enabling FP8 mixed precision + FP8 all-gather training results in ~10% faster iterations per second vs. BF16 for a relatively small (3B param) model For more information on multi-GPU training, see our [Multi-GPU guide](multi-gpu.qmd). diff --git a/examples/llama-3/3b-fp8-fsdp2.yaml b/examples/llama-3/3b-fp8-fsdp2.yaml index cc62091e3f..93e5f007a7 100644 --- a/examples/llama-3/3b-fp8-fsdp2.yaml +++ b/examples/llama-3/3b-fp8-fsdp2.yaml @@ -68,7 +68,7 @@ fsdp_config: fsdp_state_dict_type: FULL_STATE_DICT fsdp_sharding_strategy: FULL_SHARD fsdp_reshard_after_forward: true - fsdp_activation_checkpointing: true + fsdp_activation_checkpointing: false special_tokens: pad_token: <|end_of_text|> From 28f91ac5f3606b634b012eefa296847874e2035b Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 21 Jul 2025 14:19:46 -0400 Subject: [PATCH 18/28] smoke tests --- tests/e2e/integrations/test_fp8.py | 61 +++++++++ tests/e2e/multigpu/test_fp8_fsdp2.py | 193 +++++++++++++++++++++++++++ 2 files changed, 254 insertions(+) create mode 100644 tests/e2e/integrations/test_fp8.py create mode 100644 tests/e2e/multigpu/test_fp8_fsdp2.py diff --git a/tests/e2e/integrations/test_fp8.py b/tests/e2e/integrations/test_fp8.py new file mode 100644 index 0000000000..68acd5c039 --- /dev/null +++ b/tests/e2e/integrations/test_fp8.py @@ -0,0 +1,61 @@ +""" +Simple end-to-end smoke tests for FP8 mixed precision training +""" + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1 + + +class FP8IntegrationTestCase: + """ + e2e smoke tests for FP8 mixed precision training with Axolotl + """ + + @require_torch_2_4_1 + def test_fp8_single_gpu_smoke(self, temp_dir): + """Smoke test for single GPU FP8 + torch.compile training""" + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "trust_remote_code": True, + "sequence_len": 512, + "val_set_size": 0.05, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 3, # Very short smoke test + "micro_batch_size": 1, + "gradient_accumulation_steps": 2, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", # Use standard optimizer for stability + "lr_scheduler": "cosine", + "flash_attention": True, + "sample_packing": True, + "fp8": True, # Enable FP8 mixed precision + "torch_compile": True, # Essential for FP8 performance + "save_safetensors": True, + "save_first_step": False, + } + ) + + # pylint: disable=duplicate-code + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/multigpu/test_fp8_fsdp2.py b/tests/e2e/multigpu/test_fp8_fsdp2.py new file mode 100644 index 0000000000..b6dbf9a2fe --- /dev/null +++ b/tests/e2e/multigpu/test_fp8_fsdp2.py @@ -0,0 +1,193 @@ +"""Test module for FP8 mixed precision with FSDP2 multi-GPU functionality.""" + +# pylint: disable=duplicate-code + +import os +from pathlib import Path + +import torch +import yaml +from accelerate.test_utils import execute_subprocess_async +from tbparse import SummaryReader +from transformers.testing_utils import get_torch_dist_unique_port + +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import most_recent_subdir, require_torch_2_4_1 + +AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent + + +def verify_fp8_training_success(temp_dir): + """Verify that FP8 training completed successfully by checking artifacts and loss.""" + output_path = Path(temp_dir) + + model_files = list(output_path.glob("*.bin")) + list( + output_path.glob("*.safetensors") + ) + assert len(model_files) > 0, "No model files found - training may have failed" + + checkpoint_files = list(output_path.glob("checkpoint-*")) + assert ( + len(checkpoint_files) > 0 + ), "No checkpoint files found - training may have failed" + + tb_log_path = most_recent_subdir(temp_dir + "/runs") + if tb_log_path: + event_files = sorted(os.listdir(tb_log_path)) + if event_files: + event_file = os.path.join(tb_log_path, event_files[0]) + reader = SummaryReader(event_file) + df = reader.scalars + train_loss_df = df[df.tag == "train/train_loss"] + if len(train_loss_df) > 0: + final_loss = train_loss_df.value.values[-1] + assert not torch.isnan( + torch.tensor(final_loss) + ), f"Training loss is NaN: {final_loss}" + + +class TestFP8FSDP2: + """Test class for FP8 mixed precision with FSDP2 functionality.""" + + @require_torch_2_4_1 + def test_fp8_fsdp2_smoke(self, temp_dir): + """Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training""" + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "trust_remote_code": True, + "sequence_len": 512, + "val_set_size": 0.05, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 3, # Very short smoke test + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", # Use standard optimizer for stability + "lr_scheduler": "cosine", + "flash_attention": True, + "sample_packing": True, + # FP8 configuration + "fp8": True, # Enable FP8 mixed precision + "fp8_enable_fsdp_float8_all_gather": True, # Enable FP8 all-gather + "torch_compile": True, # Essential for FP8 performance + # FSDP2 configuration + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": False, + "transformer_layer_cls_to_wrap": "LlamaDecoderLayer", # SmolLM2 uses Llama architecture + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + "use_tensorboard": True, + "save_safetensors": True, + "save_first_step": False, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_fp8_training_success(temp_dir) + + @require_torch_2_4_1 + def test_fp8_fsdp2_lora_smoke(self, temp_dir): + """Smoke test for 2-GPU FP8 + torch.compile + FSDP2 + LoRA training""" + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "trust_remote_code": True, + "sequence_len": 512, + "val_set_size": 0.05, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + # LoRA configuration + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "num_epochs": 1, + "max_steps": 3, # Very short smoke test + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", # Use standard optimizer for stability + "lr_scheduler": "cosine", + "flash_attention": True, + "sample_packing": True, + # FP8 configuration + "fp8": True, # Enable FP8 mixed precision + "fp8_enable_fsdp_float8_all_gather": True, # Enable FP8 all-gather + "torch_compile": True, # Essential for FP8 performance + # FSDP2 configuration + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": False, + "transformer_layer_cls_to_wrap": "LlamaDecoderLayer", # SmolLM2 uses Llama architecture + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + "use_tensorboard": True, + "save_safetensors": True, + "save_first_step": False, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_fp8_training_success(temp_dir) From a9d8497c8ce2dff1f4e3d7b7d1005c482cd50965 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 21 Jul 2025 15:22:42 -0400 Subject: [PATCH 19/28] add validator --- docs/mixed_precision.qmd | 4 ++-- src/axolotl/utils/schemas/validation.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/docs/mixed_precision.qmd b/docs/mixed_precision.qmd index 82c6784438..b1f226545f 100644 --- a/docs/mixed_precision.qmd +++ b/docs/mixed_precision.qmd @@ -138,10 +138,10 @@ Always validate your mixed precision setup: - Multi GPU training with FSDP2 or Deepspeed - Speedups - Please refer to the [TorchAO FP8 training benchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling) for expected matmul speedups for different (M, K, N) settings - - Concrete number for LLaMA 3 8B can be found [here](https://github.com/pytorch/ao/tree/main/torchao/float8#training-benchmarks) + - Concrete number for LLaMA 3 8B training can be found [here](https://github.com/pytorch/ao/tree/main/torchao/float8#training-benchmarks) - Known issues: - FP8 + DDP + `torch.compile` (causes [error](https://gist.github.com/djsaunde/0c1664c32e44a64d31b5e01b4aafe5c4)) - - FP8 + FSDP2 + `torch.compile` + FSDP2 activation checkpointing is _slower_ than the BF16 equivalent training + - FP8 + FSDP2 + `torch.compile` + FSDP2 activation checkpointing tends to be _slower_ than the BF16 equivalent training - Flash Attention 2 does not play nicely with `torch.compile` See `examples/llama-3/3b-fp8-fsdp2.yaml` for an optimized example config. Enabling FP8 mixed precision + FP8 all-gather training results in ~10% faster iterations per second vs. BF16 for a relatively small (3B param) model diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 64dbb2529a..18100d72bc 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -360,6 +360,26 @@ def check_fft_possible_bad_config(self): # RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half return self + @model_validator(mode="before") + @classmethod + def check_fp8_config(cls, data): + if data.get("fp8") and not data.get("torch_compile"): + LOG.warning( + "torch_compile is strongly recommended for FP8 training in order to " + "see speed improvements. Please consider setting `torch_compile: " + "true` in your config." + ) + if ( + data.get("fp8") + and data.get("fsdp_config", {}).get("activation_checkpointing") is True + ): + LOG.warning( + "FSDP activation checkpointing may be slower with FP8 training. Please " + "proceed with caution." + ) + + return data + @model_validator(mode="before") @classmethod def check_use_reentrant_mismatch(cls, data): From d5be517959e94ae28294613e040c9f18d475f38e Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 22 Jul 2025 14:06:53 +0000 Subject: [PATCH 20/28] add 2.7.0 guard for fsdp2 --- tests/e2e/integrations/test_fp8.py | 4 +- tests/e2e/multigpu/test_fp8_fsdp2.py | 78 +--------------------------- 2 files changed, 4 insertions(+), 78 deletions(-) diff --git a/tests/e2e/integrations/test_fp8.py b/tests/e2e/integrations/test_fp8.py index 68acd5c039..d20566c330 100644 --- a/tests/e2e/integrations/test_fp8.py +++ b/tests/e2e/integrations/test_fp8.py @@ -7,7 +7,7 @@ from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault -from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1 +from tests.e2e.utils import check_model_output_exists, require_torch_2_7_0 class FP8IntegrationTestCase: @@ -15,7 +15,7 @@ class FP8IntegrationTestCase: e2e smoke tests for FP8 mixed precision training with Axolotl """ - @require_torch_2_4_1 + @require_torch_2_7_0 def test_fp8_single_gpu_smoke(self, temp_dir): """Smoke test for single GPU FP8 + torch.compile training""" # pylint: disable=duplicate-code diff --git a/tests/e2e/multigpu/test_fp8_fsdp2.py b/tests/e2e/multigpu/test_fp8_fsdp2.py index b6dbf9a2fe..2ebbee04e3 100644 --- a/tests/e2e/multigpu/test_fp8_fsdp2.py +++ b/tests/e2e/multigpu/test_fp8_fsdp2.py @@ -13,7 +13,7 @@ from axolotl.utils.dict import DictDefault -from tests.e2e.utils import most_recent_subdir, require_torch_2_4_1 +from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0 AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent @@ -50,7 +50,7 @@ def verify_fp8_training_success(temp_dir): class TestFP8FSDP2: """Test class for FP8 mixed precision with FSDP2 functionality.""" - @require_torch_2_4_1 + @require_torch_2_7_0 def test_fp8_fsdp2_smoke(self, temp_dir): """Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training""" cfg = DictDefault( @@ -117,77 +117,3 @@ def test_fp8_fsdp2_smoke(self, temp_dir): ) verify_fp8_training_success(temp_dir) - - @require_torch_2_4_1 - def test_fp8_fsdp2_lora_smoke(self, temp_dir): - """Smoke test for 2-GPU FP8 + torch.compile + FSDP2 + LoRA training""" - cfg = DictDefault( - { - "base_model": "HuggingFaceTB/SmolLM2-135M", - "tokenizer_type": "AutoTokenizer", - "trust_remote_code": True, - "sequence_len": 512, - "val_set_size": 0.05, - "special_tokens": { - "pad_token": "<|endoftext|>", - }, - "datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - ], - # LoRA configuration - "adapter": "lora", - "lora_r": 8, - "lora_alpha": 16, - "lora_dropout": 0.05, - "lora_target_linear": True, - "num_epochs": 1, - "max_steps": 3, # Very short smoke test - "micro_batch_size": 1, - "gradient_accumulation_steps": 1, - "output_dir": temp_dir, - "learning_rate": 0.00001, - "optimizer": "adamw_torch_fused", # Use standard optimizer for stability - "lr_scheduler": "cosine", - "flash_attention": True, - "sample_packing": True, - # FP8 configuration - "fp8": True, # Enable FP8 mixed precision - "fp8_enable_fsdp_float8_all_gather": True, # Enable FP8 all-gather - "torch_compile": True, # Essential for FP8 performance - # FSDP2 configuration - "fsdp_version": 2, - "fsdp_config": { - "offload_params": False, - "cpu_ram_efficient_loading": False, - "transformer_layer_cls_to_wrap": "LlamaDecoderLayer", # SmolLM2 uses Llama architecture - "state_dict_type": "FULL_STATE_DICT", - "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", - "reshard_after_forward": True, - }, - "use_tensorboard": True, - "save_safetensors": True, - "save_first_step": False, - } - ) - - # write cfg to yaml file - Path(temp_dir).mkdir(parents=True, exist_ok=True) - with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: - fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) - - execute_subprocess_async( - [ - "axolotl", - "train", - str(Path(temp_dir) / "config.yaml"), - "--num-processes", - "2", - "--main-process-port", - f"{get_torch_dist_unique_port()}", - ] - ) - - verify_fp8_training_success(temp_dir) From 705d171d4cb2b9235da462ab3083e051cbee3623 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 22 Jul 2025 14:50:37 +0000 Subject: [PATCH 21/28] fix --- docs/mixed_precision.qmd | 2 +- examples/llama-3/3b-fp8-fsdp2.yaml | 16 ++++++++-------- .../monkeypatch/trainer_accelerator_args.py | 6 ++++-- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/docs/mixed_precision.qmd b/docs/mixed_precision.qmd index b1f226545f..9644d15aa5 100644 --- a/docs/mixed_precision.qmd +++ b/docs/mixed_precision.qmd @@ -131,7 +131,7 @@ Always validate your mixed precision setup: - **Compare with FP32 baseline** when possible - **Test evaluation metrics** match expectations -### FP8 Particulars (#sec-fp8-details) +### FP8 Particulars {#sec-fp8-details} - Use cases - Single GPU training diff --git a/examples/llama-3/3b-fp8-fsdp2.yaml b/examples/llama-3/3b-fp8-fsdp2.yaml index 93e5f007a7..bea698c0e3 100644 --- a/examples/llama-3/3b-fp8-fsdp2.yaml +++ b/examples/llama-3/3b-fp8-fsdp2.yaml @@ -60,15 +60,15 @@ saves_per_epoch: 1 warmup_steps: 10 weight_decay: 0.0 +fsdp_version: 2 fsdp_config: - fsdp_version: 2 - fsdp_offload_params: false - fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP - fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer - fsdp_state_dict_type: FULL_STATE_DICT - fsdp_sharding_strategy: FULL_SHARD - fsdp_reshard_after_forward: true - fsdp_activation_checkpointing: false + offload_params: false + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: LlamaDecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: false special_tokens: pad_token: <|end_of_text|> diff --git a/src/axolotl/monkeypatch/trainer_accelerator_args.py b/src/axolotl/monkeypatch/trainer_accelerator_args.py index 92daa6b15d..819a66255c 100644 --- a/src/axolotl/monkeypatch/trainer_accelerator_args.py +++ b/src/axolotl/monkeypatch/trainer_accelerator_args.py @@ -18,7 +18,7 @@ PATCHED_TRAINER_CODE = """ if hasattr(self, "additional_accelerator_args"): - additional_args = self.additional_accelerator_args(fp8=True, enable_fsdp_float8_all_gather=False, **args) + additional_args = self.additional_accelerator_args(fp8=True, enable_fsdp_float8_all_gather={enable_fsdp_float8_all_gather}, **args) if additional_args: args.update(additional_args) @@ -54,7 +54,9 @@ def patch_create_accelerate_code_for_fp8(enable_fsdp_float8_all_gather: bool): if ORIGINAL_TRAINER_CODE not in create_code: return - patched_trainer_code = PATCHED_TRAINER_CODE.format(enable_fsdp_float8_all_gather) + patched_trainer_code = PATCHED_TRAINER_CODE.format( + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather + ) create_code = create_code.replace(ORIGINAL_TRAINER_CODE, patched_trainer_code) create_code = create_code.replace( "def create_accelerator_and_postprocess(", From a700c035de08408b5ef82c6f0f83b423beec3745 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 22 Jul 2025 15:01:02 +0000 Subject: [PATCH 22/28] add config descriptions --- src/axolotl/utils/schemas/config.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 04a7a1f85e..96b694043b 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -343,8 +343,20 @@ class AxolotlInputConfig( fp16: bool | None = Field( default=None, json_schema_extra={"description": "Use CUDA fp16"} ) - fp8: bool | None = None - fp8_enable_fsdp_float8_all_gather: bool | None = None + fp8: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Enable FP8 mixed precision training using TorchAO. Best " + "used in combination with torch.compile." + }, + ) + fp8_enable_fsdp_float8_all_gather: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Enable FSDP float8 all-gather optimization for FP8 training. Can " + "improve training speed by 10-15% when FSDP is enabled." + }, + ) bfloat16: bool | None = Field( default=None, json_schema_extra={ From 319e57f50550593ba6b35777a358bf43930dcab2 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 22 Jul 2025 15:02:28 +0000 Subject: [PATCH 23/28] add FSDP doc link --- docs/mixed_precision.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mixed_precision.qmd b/docs/mixed_precision.qmd index 9644d15aa5..7b77cd4bb4 100644 --- a/docs/mixed_precision.qmd +++ b/docs/mixed_precision.qmd @@ -93,7 +93,7 @@ FP8 training requires `torch_compile: true` to see meaningful speedups. Without ### Advanced FP8 Configs {#sec-fp8-advanced} -For FSDP (Fully Sharded Data Parallel) training: +For [FSDP](multi-gpu.qmd#sec-fsdp) (Fully Sharded Data Parallel) training: ```{.yaml} fp8: true From f47d847fb2a186a64dca3eb52b52a17f6a649a3f Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 22 Jul 2025 17:11:03 +0000 Subject: [PATCH 24/28] nit --- src/axolotl/utils/schemas/validation.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 18100d72bc..256618d979 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -369,13 +369,22 @@ def check_fp8_config(cls, data): "see speed improvements. Please consider setting `torch_compile: " "true` in your config." ) - if ( - data.get("fp8") - and data.get("fsdp_config", {}).get("activation_checkpointing") is True + if data.get("fp8") and ( + data.get("fsdp_config", {}).get("activation_checkpointing", False) is True + or data.get("fsdp_config", {}).get("fsdp_activation_checkpointing", False) + is True ): LOG.warning( - "FSDP activation checkpointing may be slower with FP8 training. Please " - "proceed with caution." + "FP8 + FSDP2 + activation checkpointing may be slower than BF16 " + "training. Please proceed with caution." + ) + if ( + data.get("fp8_enable_fsdp_float8_all_gather") + and not data.get("fsdp_version", None) == 2 + ): + raise ValueError( + "fp8_enable_fsdp_float8_all_gather requires FSDP2 (fsdp_version: 2) " + "to be used." ) return data From 9b477d8aaeb911c29bd95df669f99fe2a7b28d93 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 22 Jul 2025 17:36:57 +0000 Subject: [PATCH 25/28] set force_recompute_fp8_weight_in_bwd with enable_fsdp_float8_all_gather --- src/axolotl/core/trainers/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index b444120d5f..3dfaf47ce4 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -535,7 +535,8 @@ def additional_accelerator_args( # scaling strategy. See more details here: # https://github.com/pytorch/ao/tree/main/torchao/float8. config = Float8LinearConfig( - enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, + force_recompute_fp8_weight_in_bwd=enable_fsdp_float8_all_gather is True, ) ret_kwargs["mixed_precision"] = "fp8" From d0d116095c66f48b2bcb8bc06a021d011a5f2c0b Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 22 Jul 2025 17:42:15 +0000 Subject: [PATCH 26/28] better cfg for smoke tests --- tests/e2e/integrations/test_fp8.py | 9 +++++---- tests/e2e/multigpu/test_fp8_fsdp2.py | 11 ++++++----- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/e2e/integrations/test_fp8.py b/tests/e2e/integrations/test_fp8.py index d20566c330..0302b7e35b 100644 --- a/tests/e2e/integrations/test_fp8.py +++ b/tests/e2e/integrations/test_fp8.py @@ -41,12 +41,13 @@ def test_fp8_single_gpu_smoke(self, temp_dir): "gradient_accumulation_steps": 2, "output_dir": temp_dir, "learning_rate": 0.00001, - "optimizer": "adamw_torch_fused", # Use standard optimizer for stability + "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", - "flash_attention": True, + "sdp_attention": True, + "pad_to_seq_len": True, "sample_packing": True, - "fp8": True, # Enable FP8 mixed precision - "torch_compile": True, # Essential for FP8 performance + "fp8": True, + "torch_compile": True, "save_safetensors": True, "save_first_step": False, } diff --git a/tests/e2e/multigpu/test_fp8_fsdp2.py b/tests/e2e/multigpu/test_fp8_fsdp2.py index 2ebbee04e3..6423f5e2e2 100644 --- a/tests/e2e/multigpu/test_fp8_fsdp2.py +++ b/tests/e2e/multigpu/test_fp8_fsdp2.py @@ -77,18 +77,19 @@ def test_fp8_fsdp2_smoke(self, temp_dir): "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", # Use standard optimizer for stability "lr_scheduler": "cosine", - "flash_attention": True, + "sdp_attention": True, + "pad_to_seq_len": True, "sample_packing": True, # FP8 configuration - "fp8": True, # Enable FP8 mixed precision - "fp8_enable_fsdp_float8_all_gather": True, # Enable FP8 all-gather - "torch_compile": True, # Essential for FP8 performance + "fp8": True, + "fp8_enable_fsdp_float8_all_gather": True, + "torch_compile": True, # FSDP2 configuration "fsdp_version": 2, "fsdp_config": { "offload_params": False, "cpu_ram_efficient_loading": False, - "transformer_layer_cls_to_wrap": "LlamaDecoderLayer", # SmolLM2 uses Llama architecture + "transformer_layer_cls_to_wrap": "LlamaDecoderLayer", "state_dict_type": "FULL_STATE_DICT", "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", "reshard_after_forward": True, From 1f375e220d15df3aba76198c804485429c0b05ee Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 22 Jul 2025 15:53:23 -0400 Subject: [PATCH 27/28] add test for accelerate patching --- .../test_trainer_accelerator_args.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/monkeypatch/test_trainer_accelerator_args.py diff --git a/tests/monkeypatch/test_trainer_accelerator_args.py b/tests/monkeypatch/test_trainer_accelerator_args.py new file mode 100644 index 0000000000..fab2597f01 --- /dev/null +++ b/tests/monkeypatch/test_trainer_accelerator_args.py @@ -0,0 +1,26 @@ +""" +Unit tests for trainer accelerator args monkeypatch +""" + +import unittest + +from axolotl.monkeypatch.trainer_accelerator_args import ( + check_create_accelerate_code_is_patchable, +) + + +class TestTrainerAcceleratorArgs(unittest.TestCase): + """ + Unit test class for trainer accelerator args monkeypatch + """ + + def test_check_create_accelerate_code_is_patchable(self): + """ + Test that the upstream transformers code is still patchable. + This will fail if the patched code changes upstream. + """ + assert check_create_accelerate_code_is_patchable() + + +if __name__ == "__main__": + unittest.main() From 9f8cc4d872af74799e4060dce7e41354ce9ff396 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 22 Jul 2025 15:57:47 -0400 Subject: [PATCH 28/28] update fp8 validator --- src/axolotl/utils/schemas/validation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 256618d979..0c1a97fcdf 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -376,7 +376,8 @@ def check_fp8_config(cls, data): ): LOG.warning( "FP8 + FSDP2 + activation checkpointing may be slower than BF16 " - "training. Please proceed with caution." + "training. Please considering setting `activation_checkpointing: false` " + "in your FSDP config." ) if ( data.get("fp8_enable_fsdp_float8_all_gather")