Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions _quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ website:
- docs/batch_vs_grad.qmd
- docs/dataset_preprocessing.qmd
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/gradient_accumulation.qmd

- section: "Advanced Features"
contents:
Expand Down
149 changes: 149 additions & 0 deletions docs/mixed_precision.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
---
title: "Mixed Precision Training"
format:
html:
toc: true
toc-depth: 3
number-sections: true
code-tools: true
execute:
enabled: false
---

Mixed precision training uses lower precision data types to reduce memory usage and increase training speed while maintaining model quality. Axolotl supports several mixed precision formats:

- **FP16** - Half precision 16-bit (Pascal generation+)
- **BF16** - Brain Float 16-bit (Ampere generation+)
- **FP8** - 8-bit floating point (Hopper generation+)

## FP16 Mixed Precision {#sec-fp16}

### Overview {#sec-fp16-overview}

FP16 is the traditional half-precision format, supported on older GPUs but can be less numerically stable than BF16.

### Configuration {#sec-fp16-config}

```{.yaml}
fp16: true
```

### FP16 Considerations {#sec-fp16-considerations}

- May require gradient scaling to prevent underflow
- Less numerically stable than BF16
- Can cause training instability with some model architectures
- Consider using BF16 if your hardware supports it

## BF16 Mixed Precision {#sec-bf16}

### Overview {#sec-bf16-overview}

BF16 (Brain Float 16) offers better numerical stability than FP16 and is the recommended mixed precision format for modern GPUs. It provides the same dynamic range as FP32 while using half the memory.

### Configuration {#sec-bf16-config}

```{.yaml}
# Automatic BF16 detection (recommended)
bf16: auto

# Or explicitly enable
bf16: true

# For evaluation with BF16
bf16: full # Equivalent to bf16_full_eval in the HF trainer
```
Comment thread
djsaunde marked this conversation as resolved.

## FP8 Mixed Precision {#sec-fp8}

::: {.callout-note}
FP8 support is experimental and requires compatible hardware (H100, H200) and recent PyTorch versions with TorchAO.
:::

### What is FP8? {#sec-fp8-overview}

FP8 (8-bit floating point) can provide significant time savings compared to FP16/BF16 while maintaining training stability. Axolotl's implementation uses PyTorch's TorchAO library with "tensorwise" scaling strategy.

### Requirements {#sec-fp8-software}

- Hopper+ GPUs (H100/H200)
- PyTorch 2.7+ (+ compatible TorchAO version)
- CUDA 12.4+

### Configuration {#sec-fp8-config}

Add to your YAML config:

```{.yaml}
# Enable FP8 mixed precision
fp8: true

# Optional: Enable FP8 for FSDP all-gather operations
fp8_enable_fsdp_float8_all_gather: true

# Enable torch.compile (almost always necessary for FP8 speedups)
torch_compile: true
```

::: {.callout-important}
**torch.compile is critical for FP8 performance**

FP8 training requires `torch_compile: true` to see meaningful speedups. Without compilation, FP8 may actually be slower and use more memory than FP16/BF16.
:::

### Advanced FP8 Configs {#sec-fp8-advanced}

For [FSDP](multi-gpu.qmd#sec-fsdp) (Fully Sharded Data Parallel) training:

```{.yaml}
fp8: true
fp8_enable_fsdp_float8_all_gather: true

torch_compile: true

# FSDP configuration
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
state_dict_type: FULL_STATE_DICT
reshard_after_forward: true
```

## Best Practices {#sec-best-practices}

### Choosing Precision Format {#sec-choosing-format}

- **Start with automatic detection**: `bf16: auto`
- **For Hopper+ (H100/H200)**: Try FP8 + torch.compile for maximum speed
- **For Ampere (A100/RTX 30/40)**: Use BF16
- **For older Pascal/Turing GPUs**: Use FP16 with caution
- **For very old or unsupported GPUs**: Use FP32

### Validation and Testing {#sec-validation}

Always validate your mixed precision setup:

- **Start with a small dataset** to verify stability
- **Monitor loss curves** for irregularities
- **Compare with FP32 baseline** when possible
- **Test evaluation metrics** match expectations

### FP8 Particulars {#sec-fp8-details}

- Use cases
- Single GPU training
- Multi GPU training with FSDP2 or Deepspeed
- Speedups
- Please refer to the [TorchAO FP8 training benchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling) for expected matmul speedups for different (M, K, N) settings
- Concrete number for LLaMA 3 8B training can be found [here](https://github.com/pytorch/ao/tree/main/torchao/float8#training-benchmarks)
- Known issues:
- FP8 + DDP + `torch.compile` (causes [error](https://gist.github.com/djsaunde/0c1664c32e44a64d31b5e01b4aafe5c4))
- FP8 + FSDP2 + `torch.compile` + FSDP2 activation checkpointing tends to be _slower_ than the BF16 equivalent training
- Flash Attention 2 does not play nicely with `torch.compile`

See `examples/llama-3/3b-fp8-fsdp2.yaml` for an optimized example config. Enabling FP8 mixed precision + FP8 all-gather training results in ~10% faster iterations per second vs. BF16 for a relatively small (3B param) model

For more information on multi-GPU training, see our [Multi-GPU guide](multi-gpu.qmd).
76 changes: 76 additions & 0 deletions examples/llama-3/3b-fp8-fsdp2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
- axolotl.integrations.liger.LigerPlugin

liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true

datasets:
- path: yahma/alpaca-cleaned
type: alpaca

output_dir: ./outputs/fp8_out/

sample_packing: true
pad_to_sequence_len: true
sequence_len: 512

flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs

torch_compile: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused

cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true

fp8: true
fp8_enable_fsdp_float8_all_gather: true

resume_from_checkpoint:
logging_steps: 1

evals_per_epoch: 1
saves_per_epoch: 1

warmup_steps: 10
weight_decay: 0.0

fsdp_version: 2
fsdp_config:
offload_params: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: false

special_tokens:
pad_token: <|end_of_text|>

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
18 changes: 14 additions & 4 deletions src/axolotl/core/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
from collections import defaultdict
from functools import partial, wraps
from typing import Callable, Literal, Optional
from typing import Any, Callable, Literal, Optional

import datasets
import torch
Expand Down Expand Up @@ -522,15 +522,25 @@ def create_accelerator_and_postprocess(self):

return res

# pylint: disable=unused-argument
def additional_accelerator_args(
self, fp8=None, **kwargs
): # pylint: disable=unused-argument
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
) -> dict[str, Any]:
ret_kwargs = {}
if fp8:
from accelerate.utils import AORecipeKwargs
from torchao.float8 import Float8LinearConfig

# By default, Float8LinearConfig is instantiated using the "tensorwise"
# scaling strategy. See more details here:
# https://github.com/pytorch/ao/tree/main/torchao/float8.
config = Float8LinearConfig(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
force_recompute_fp8_weight_in_bwd=enable_fsdp_float8_all_gather is True,
)

ret_kwargs["mixed_precision"] = "fp8"
ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()]
ret_kwargs["kwargs_handlers"] = [AORecipeKwargs(config=config)] # type: ignore
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"

return ret_kwargs
Expand Down
4 changes: 3 additions & 1 deletion src/axolotl/loaders/patch_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ def _apply_fp8_patches(self):
patch_create_accelerate_code_for_fp8,
)

patch_create_accelerate_code_for_fp8()
patch_create_accelerate_code_for_fp8(
self.cfg.fp8_enable_fsdp_float8_all_gather
)

def _apply_flash_attention_peft_patches(self):
"""Apply patches for Flash Attention with PEFT."""
Expand Down
11 changes: 7 additions & 4 deletions src/axolotl/monkeypatch/trainer_accelerator_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

PATCHED_TRAINER_CODE = """
if hasattr(self, "additional_accelerator_args"):
additional_args = self.additional_accelerator_args(fp8=True, **args)
additional_args = self.additional_accelerator_args(fp8=True, enable_fsdp_float8_all_gather={enable_fsdp_float8_all_gather}, **args)
if additional_args:
args.update(additional_args)

Expand All @@ -38,9 +38,9 @@ def check_create_accelerate_code_is_patchable() -> bool:
return ORIGINAL_TRAINER_CODE in create_code


def patch_create_accelerate_code_for_fp8():
def patch_create_accelerate_code_for_fp8(enable_fsdp_float8_all_gather: bool):
"""
monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs
Monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs.
"""

try:
Expand All @@ -54,7 +54,10 @@ def patch_create_accelerate_code_for_fp8():
if ORIGINAL_TRAINER_CODE not in create_code:
return

create_code = create_code.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
patched_trainer_code = PATCHED_TRAINER_CODE.format(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather
)
create_code = create_code.replace(ORIGINAL_TRAINER_CODE, patched_trainer_code)
create_code = create_code.replace(
"def create_accelerator_and_postprocess(",
"def fixed_create_accelerator_and_postprocess(",
Expand Down
15 changes: 14 additions & 1 deletion src/axolotl/utils/schemas/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,20 @@ class AxolotlInputConfig(
fp16: bool | None = Field(
default=None, json_schema_extra={"description": "Use CUDA fp16"}
)
fp8: bool | None = None
fp8: bool | None = Field(
default=None,
json_schema_extra={
"description": "Enable FP8 mixed precision training using TorchAO. Best "
"used in combination with torch.compile."
},
)
fp8_enable_fsdp_float8_all_gather: bool | None = Field(
default=None,
json_schema_extra={
"description": "Enable FSDP float8 all-gather optimization for FP8 training. Can "
"improve training speed by 10-15% when FSDP is enabled."
},
)
bfloat16: bool | None = Field(
default=None,
json_schema_extra={
Expand Down
30 changes: 30 additions & 0 deletions src/axolotl/utils/schemas/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,36 @@ def check_fft_possible_bad_config(self):
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
return self

@model_validator(mode="before")
@classmethod
def check_fp8_config(cls, data):
if data.get("fp8") and not data.get("torch_compile"):
LOG.warning(
"torch_compile is strongly recommended for FP8 training in order to "
"see speed improvements. Please consider setting `torch_compile: "
"true` in your config."
)
if data.get("fp8") and (
data.get("fsdp_config", {}).get("activation_checkpointing", False) is True
or data.get("fsdp_config", {}).get("fsdp_activation_checkpointing", False)
is True
):
LOG.warning(
"FP8 + FSDP2 + activation checkpointing may be slower than BF16 "
Comment thread
djsaunde marked this conversation as resolved.
"training. Please considering setting `activation_checkpointing: false` "
"in your FSDP config."
)
if (
data.get("fp8_enable_fsdp_float8_all_gather")
and not data.get("fsdp_version", None) == 2
):
raise ValueError(
"fp8_enable_fsdp_float8_all_gather requires FSDP2 (fsdp_version: 2) "
"to be used."
)

return data

@model_validator(mode="before")
@classmethod
def check_use_reentrant_mismatch(cls, data):
Expand Down
Loading
Loading