-
Notifications
You must be signed in to change notification settings - Fork 681
Add primary weighs fp8 support for mxfp8 #2055
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
1b7a13e
Add primary weighs fp8 support for mxfp8
kunlunl 14a8a61
Fix unit test and add better error log to unit test
kunlunl 182dc33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] b215cd1
Move post all-gather processing out of for loop
kunlunl e3368ba
Merge branch 'main' into native-mxfp8
timmoon10 220ffa6
Add descriptions and ASCII diagrams for partial cast and partial amax…
kunlunl cec738a
Merge branch 'main' into native-mxfp8
kunlunl 560e48b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f7c7064
Minor fix based on greptile bot
kunlunl 96a5a93
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ea0d060
Fix compilation errors due to arch-specific PTX instructions
timmoon10 3f033cf
Remove unused noop flag from C API
timmoon10 435b951
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 38595fb
Expose test_partial_cast
kunlunl dd4c6c4
Skip mxfp8 partial cast test if mxfp8 is not available
kunlunl ad16561
Fix pytest error
kunlunl 4ccf5d6
pylint ignore unused manual_post_all_gather_processing
kunlunl c939c81
Fix error when using is_mxfp8_available
kunlunl File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
168 changes: 115 additions & 53 deletions
168
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,137 @@ | ||
| # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| import transformer_engine.pytorch as te | ||
| import transformer_engine_torch as tex | ||
| from transformer_engine_torch import multi_tensor_compute_scale_inv_e8m0 | ||
| from transformer_engine.pytorch import is_mxfp8_available | ||
| from transformer_engine.pytorch.optimizers.multi_tensor_apply import multi_tensor_applier | ||
|
|
||
|
|
||
| mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) | ||
|
|
||
|
|
||
| def compute_partial_amax_reference(inp, amax_rowwise, amax_colwise, h, w, start_offset): | ||
| n = inp.view(-1).size(0) | ||
| if n == h * w: | ||
| full = inp.view(-1) | ||
| else: | ||
| full = torch.zeros(h * w, dtype=inp.dtype, device=inp.device) | ||
| full[start_offset : start_offset + n].copy_(inp) | ||
| full = torch.abs(full) | ||
| _amax_rowwise, _ = torch.max(full.view(h, w // 32, 32), dim=2) | ||
| amax_rowwise[:h, : (w // 32)].copy_(_amax_rowwise) | ||
| _amax_colwise, _ = torch.max(full.view(h // 32, 32, w), dim=1) | ||
| amax_colwise[: (h // 32), :w].copy_(_amax_colwise) | ||
|
|
||
|
|
||
| def partial_cast_reference( | ||
| inp, rowwise_out, colwise_out, rowwise_inv_scale, colwise_inv_scale, h, w, start_offset | ||
| ): | ||
| rowwise_scale = ((254 - rowwise_inv_scale.int()) * 2**23).view(torch.float32) | ||
| colwise_scale = ((254 - colwise_inv_scale.int()) * 2**23).view(torch.float32) | ||
| n = inp.view(-1).size(0) | ||
| if n == h * w: | ||
| full = inp | ||
| else: | ||
| full = torch.empty(h * w, dtype=inp.dtype, device=inp.device) | ||
| full[start_offset : start_offset + n].copy_(inp) | ||
| full = full.float() | ||
| rowwise_scale = rowwise_scale[:h, : (w // 32)].contiguous().float() | ||
| colwise_scale = colwise_scale[: (h // 32), :w].contiguous().float() | ||
| scaled = (full.view(-1, 32) * rowwise_scale.view(-1, 1)).view(-1) | ||
| rowwise_out.copy_( | ||
| scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(rowwise_out.dtype) | ||
| ) | ||
| scaled = (full.view(h // 32, 32, w) * colwise_scale.view(h // 32, 1, w)).view(-1) | ||
| colwise_out.copy_( | ||
| scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(colwise_out.dtype) | ||
| ) | ||
|
|
||
|
|
||
| def run_one_case(n, h, w, start_offset): | ||
| inp = torch.randn(n, dtype=torch.bfloat16, device="cuda") | ||
|
|
||
| rowwise_padding = [128, 4] | ||
| colwise_padding = [4, 128] | ||
|
|
||
| def _pad(x, padding): | ||
| return (x + padding - 1) // padding * padding | ||
|
|
||
| rowwise_shape = [_pad(h, rowwise_padding[0]), _pad(w // 32, rowwise_padding[1])] | ||
| colwise_shape = [_pad(h // 32, colwise_padding[0]), _pad(w, colwise_padding[1])] | ||
|
|
||
| # Partial amax cuda kernel | ||
| amax_rowwise = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device) | ||
| amax_colwise = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device) | ||
| tex.mxfp8_scaling_compute_partial_amax(inp, amax_rowwise, amax_colwise, h, w, start_offset) | ||
|
|
||
| # Partial amax pytorch reference | ||
| amax_rowwise_ref = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device) | ||
| amax_colwise_ref = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device) | ||
| compute_partial_amax_reference(inp, amax_rowwise_ref, amax_colwise_ref, h, w, start_offset) | ||
|
|
||
| # Check partial amax | ||
| torch.testing.assert_close(amax_rowwise, amax_rowwise_ref, atol=0, rtol=0) | ||
| torch.testing.assert_close(amax_colwise, amax_colwise_ref, atol=0, rtol=0) | ||
|
|
||
| # Calculate scales and scale_invs | ||
| scale_inv_rowwise = torch.empty_like(amax_rowwise).to(torch.uint8) | ||
| scale_inv_colwise = torch.empty_like(amax_colwise).to(torch.uint8) | ||
| multi_tensor_applier( | ||
| multi_tensor_compute_scale_inv_e8m0, | ||
| None, | ||
| [ | ||
| [amax_rowwise, amax_colwise], | ||
| [scale_inv_rowwise, scale_inv_colwise], | ||
| ], | ||
| ) | ||
|
|
||
| # Partial cast cuda kernel | ||
| output_rowwise = torch.empty_like(inp).to(torch.uint8) | ||
| output_colwise = torch.empty_like(inp).to(torch.uint8) | ||
| tex.mxfp8_scaling_partial_cast( | ||
| inp, | ||
| output_rowwise, | ||
| output_colwise, | ||
| scale_inv_rowwise, | ||
| scale_inv_colwise, | ||
| h, | ||
| w, | ||
| start_offset, | ||
| ) | ||
|
|
||
| # Partial cast pytorch reference | ||
| output_rowwise_ref = torch.empty_like(inp).to(torch.uint8) | ||
| output_colwise_ref = torch.empty_like(inp).to(torch.uint8) | ||
| partial_cast_reference( | ||
| inp, | ||
| output_rowwise_ref, | ||
| output_colwise_ref, | ||
| scale_inv_rowwise, | ||
| scale_inv_colwise, | ||
| h, | ||
| w, | ||
| start_offset, | ||
| ) | ||
|
|
||
| # Check partial cast results | ||
| torch.testing.assert_close(output_rowwise, output_rowwise_ref, atol=0, rtol=0) | ||
| torch.testing.assert_close(output_colwise, output_colwise_ref, atol=0, rtol=0) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) | ||
| def test_mxfp8_scaling_partial_cast(): | ||
| torch.cuda.manual_seed(1234) | ||
|
|
||
| run_one_case(3, 32, 64, 31) | ||
| run_one_case(64 * 64 - 2, 64, 64, 1) | ||
| run_one_case(16384 * 6144, 16384, 6144, 0) | ||
| run_one_case(32768, 256, 128, 0) | ||
| run_one_case(131072, 768, 256, 0) | ||
| run_one_case(65536, 768, 256, 131072) | ||
| run_one_case(98304, 128, 768, 0) | ||
timmoon10 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.