-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
basic torchao fp8 mixed precision training #2926
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
8e93e79
debug
djsaunde 520fa82
debug
djsaunde 77705b3
debug
djsaunde abdb95f
revert unneeded change
djsaunde 265d9cd
add accelerator config to base trainer builder
djsaunde 8afdd2f
add back accumulated_cache_size_limit setting
djsaunde 1f9f366
lint
djsaunde 950c734
accelerator constructor patch for single-GPU torch fp8
djsaunde d87b8b6
lint
djsaunde e467ace
re-using existing fp8 code
djsaunde 73d993f
lint
djsaunde d53f9a7
remove accelerate patch now fix in latest release
djsaunde 4d0a4fe
fix
djsaunde 1220c02
docs
djsaunde de54dc7
add fp8 + fsdp2 example
djsaunde 015d9de
remove unused config
djsaunde d33b4e4
update config
djsaunde 28f91ac
smoke tests
djsaunde a9d8497
add validator
djsaunde d5be517
add 2.7.0 guard for fsdp2
djsaunde 705d171
fix
djsaunde a700c03
add config descriptions
djsaunde 319e57f
add FSDP doc link
djsaunde f47d847
nit
djsaunde 9b477d8
set force_recompute_fp8_weight_in_bwd with enable_fsdp_float8_all_gather
djsaunde d0d1160
better cfg for smoke tests
djsaunde 1f375e2
add test for accelerate patching
djsaunde 9f8cc4d
update fp8 validator
djsaunde File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,149 @@ | ||
| --- | ||
| title: "Mixed Precision Training" | ||
| format: | ||
| html: | ||
| toc: true | ||
| toc-depth: 3 | ||
| number-sections: true | ||
| code-tools: true | ||
| execute: | ||
| enabled: false | ||
| --- | ||
|
|
||
| Mixed precision training uses lower precision data types to reduce memory usage and increase training speed while maintaining model quality. Axolotl supports several mixed precision formats: | ||
|
|
||
| - **FP16** - Half precision 16-bit (Pascal generation+) | ||
| - **BF16** - Brain Float 16-bit (Ampere generation+) | ||
| - **FP8** - 8-bit floating point (Hopper generation+) | ||
|
|
||
| ## FP16 Mixed Precision {#sec-fp16} | ||
|
|
||
| ### Overview {#sec-fp16-overview} | ||
|
|
||
| FP16 is the traditional half-precision format, supported on older GPUs but can be less numerically stable than BF16. | ||
|
|
||
| ### Configuration {#sec-fp16-config} | ||
|
|
||
| ```{.yaml} | ||
| fp16: true | ||
| ``` | ||
|
|
||
| ### FP16 Considerations {#sec-fp16-considerations} | ||
|
|
||
| - May require gradient scaling to prevent underflow | ||
| - Less numerically stable than BF16 | ||
| - Can cause training instability with some model architectures | ||
| - Consider using BF16 if your hardware supports it | ||
|
|
||
| ## BF16 Mixed Precision {#sec-bf16} | ||
|
|
||
| ### Overview {#sec-bf16-overview} | ||
|
|
||
| BF16 (Brain Float 16) offers better numerical stability than FP16 and is the recommended mixed precision format for modern GPUs. It provides the same dynamic range as FP32 while using half the memory. | ||
|
|
||
| ### Configuration {#sec-bf16-config} | ||
|
|
||
| ```{.yaml} | ||
| # Automatic BF16 detection (recommended) | ||
| bf16: auto | ||
|
|
||
| # Or explicitly enable | ||
| bf16: true | ||
|
|
||
| # For evaluation with BF16 | ||
| bf16: full # Equivalent to bf16_full_eval in the HF trainer | ||
| ``` | ||
|
|
||
| ## FP8 Mixed Precision {#sec-fp8} | ||
|
|
||
| ::: {.callout-note} | ||
| FP8 support is experimental and requires compatible hardware (H100, H200) and recent PyTorch versions with TorchAO. | ||
| ::: | ||
|
|
||
| ### What is FP8? {#sec-fp8-overview} | ||
|
|
||
| FP8 (8-bit floating point) can provide significant time savings compared to FP16/BF16 while maintaining training stability. Axolotl's implementation uses PyTorch's TorchAO library with "tensorwise" scaling strategy. | ||
|
|
||
| ### Requirements {#sec-fp8-software} | ||
|
|
||
| - Hopper+ GPUs (H100/H200) | ||
| - PyTorch 2.7+ (+ compatible TorchAO version) | ||
| - CUDA 12.4+ | ||
|
|
||
| ### Configuration {#sec-fp8-config} | ||
|
|
||
| Add to your YAML config: | ||
|
|
||
| ```{.yaml} | ||
| # Enable FP8 mixed precision | ||
| fp8: true | ||
|
|
||
| # Optional: Enable FP8 for FSDP all-gather operations | ||
| fp8_enable_fsdp_float8_all_gather: true | ||
|
|
||
| # Enable torch.compile (almost always necessary for FP8 speedups) | ||
| torch_compile: true | ||
| ``` | ||
|
|
||
| ::: {.callout-important} | ||
| **torch.compile is critical for FP8 performance** | ||
|
|
||
| FP8 training requires `torch_compile: true` to see meaningful speedups. Without compilation, FP8 may actually be slower and use more memory than FP16/BF16. | ||
| ::: | ||
|
|
||
| ### Advanced FP8 Configs {#sec-fp8-advanced} | ||
|
|
||
| For [FSDP](multi-gpu.qmd#sec-fsdp) (Fully Sharded Data Parallel) training: | ||
|
|
||
| ```{.yaml} | ||
| fp8: true | ||
| fp8_enable_fsdp_float8_all_gather: true | ||
|
|
||
| torch_compile: true | ||
|
|
||
| # FSDP configuration | ||
| fsdp_version: 2 | ||
| fsdp_config: | ||
| offload_params: false | ||
| cpu_ram_efficient_loading: true | ||
| auto_wrap_policy: TRANSFORMER_BASED_WRAP | ||
| transformer_layer_cls_to_wrap: LlamaDecoderLayer | ||
| state_dict_type: FULL_STATE_DICT | ||
| reshard_after_forward: true | ||
| ``` | ||
|
|
||
| ## Best Practices {#sec-best-practices} | ||
|
|
||
| ### Choosing Precision Format {#sec-choosing-format} | ||
|
|
||
| - **Start with automatic detection**: `bf16: auto` | ||
| - **For Hopper+ (H100/H200)**: Try FP8 + torch.compile for maximum speed | ||
| - **For Ampere (A100/RTX 30/40)**: Use BF16 | ||
| - **For older Pascal/Turing GPUs**: Use FP16 with caution | ||
| - **For very old or unsupported GPUs**: Use FP32 | ||
|
|
||
| ### Validation and Testing {#sec-validation} | ||
|
|
||
| Always validate your mixed precision setup: | ||
|
|
||
| - **Start with a small dataset** to verify stability | ||
| - **Monitor loss curves** for irregularities | ||
| - **Compare with FP32 baseline** when possible | ||
| - **Test evaluation metrics** match expectations | ||
|
|
||
| ### FP8 Particulars {#sec-fp8-details} | ||
|
|
||
| - Use cases | ||
| - Single GPU training | ||
| - Multi GPU training with FSDP2 or Deepspeed | ||
| - Speedups | ||
| - Please refer to the [TorchAO FP8 training benchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling) for expected matmul speedups for different (M, K, N) settings | ||
| - Concrete number for LLaMA 3 8B training can be found [here](https://github.com/pytorch/ao/tree/main/torchao/float8#training-benchmarks) | ||
| - Known issues: | ||
| - FP8 + DDP + `torch.compile` (causes [error](https://gist.github.com/djsaunde/0c1664c32e44a64d31b5e01b4aafe5c4)) | ||
| - FP8 + FSDP2 + `torch.compile` + FSDP2 activation checkpointing tends to be _slower_ than the BF16 equivalent training | ||
| - Flash Attention 2 does not play nicely with `torch.compile` | ||
|
|
||
| See `examples/llama-3/3b-fp8-fsdp2.yaml` for an optimized example config. Enabling FP8 mixed precision + FP8 all-gather training results in ~10% faster iterations per second vs. BF16 for a relatively small (3B param) model | ||
|
|
||
| For more information on multi-GPU training, see our [Multi-GPU guide](multi-gpu.qmd). | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| base_model: meta-llama/Llama-3.2-3B | ||
| # Automatically upload checkpoint and final model to HF | ||
| # hub_model_id: username/custom_model_name | ||
|
|
||
| load_in_8bit: false | ||
| load_in_4bit: false | ||
| strict: false | ||
|
|
||
| plugins: | ||
| - axolotl.integrations.liger.LigerPlugin | ||
|
|
||
| liger_rope: true | ||
| liger_rms_norm: true | ||
| liger_glu_activation: true | ||
| liger_layer_norm: true | ||
| liger_fused_linear_cross_entropy: true | ||
|
|
||
| datasets: | ||
| - path: yahma/alpaca-cleaned | ||
| type: alpaca | ||
|
|
||
| output_dir: ./outputs/fp8_out/ | ||
|
|
||
| sample_packing: true | ||
| pad_to_sequence_len: true | ||
| sequence_len: 512 | ||
|
|
||
| flex_attention: true | ||
| flex_attn_compile_kwargs: | ||
| dynamic: false | ||
| mode: max-autotune-no-cudagraphs | ||
|
|
||
| torch_compile: true | ||
|
|
||
| wandb_project: | ||
| wandb_entity: | ||
| wandb_watch: | ||
| wandb_name: | ||
| wandb_log_model: | ||
|
|
||
| gradient_accumulation_steps: 1 | ||
| micro_batch_size: 16 | ||
| num_epochs: 1 | ||
| optimizer: adamw_torch_fused | ||
|
|
||
| cosine_constant_lr_ratio: 0 | ||
| cosine_min_lr_ratio: 1.0 | ||
| learning_rate: 2e-5 | ||
| save_only_model: true | ||
|
|
||
| fp8: true | ||
| fp8_enable_fsdp_float8_all_gather: true | ||
|
|
||
| resume_from_checkpoint: | ||
| logging_steps: 1 | ||
|
|
||
| evals_per_epoch: 1 | ||
| saves_per_epoch: 1 | ||
|
|
||
| warmup_steps: 10 | ||
| weight_decay: 0.0 | ||
|
|
||
| fsdp_version: 2 | ||
| fsdp_config: | ||
| offload_params: false | ||
| auto_wrap_policy: TRANSFORMER_BASED_WRAP | ||
| transformer_layer_cls_to_wrap: LlamaDecoderLayer | ||
| state_dict_type: FULL_STATE_DICT | ||
| sharding_strategy: FULL_SHARD | ||
| reshard_after_forward: true | ||
| activation_checkpointing: false | ||
|
|
||
| special_tokens: | ||
| pad_token: <|end_of_text|> | ||
|
|
||
| # save_first_step: true # uncomment this to validate checkpoint saving works with your config |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.