-
Notifications
You must be signed in to change notification settings - Fork 755
Add primary weights fp8 support for mxfp8 #2055
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
1b7a13e
14a8a61
182dc33
b215cd1
e3368ba
220ffa6
cec738a
560e48b
f7c7064
96a5a93
ea0d060
3f033cf
435b951
38595fb
dd4c6c4
ad16561
4ccf5d6
c939c81
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,136 @@ | ||
| # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| import torch | ||
|
|
||
| import transformer_engine.pytorch as te | ||
| import transformer_engine_torch as tex | ||
| from transformer_engine_torch import multi_tensor_compute_scale_inv_e8m0 | ||
| from transformer_engine.pytorch.optimizers.multi_tensor_apply import multi_tensor_applier | ||
|
|
||
|
|
||
def compute_partial_amax_reference(inp, amax_rowwise, amax_colwise, h, w, start_offset):
    """Pure-PyTorch reference for the partial MXFP8 amax computation.

    ``inp`` holds ``n`` consecutive elements of a logical row-major ``h x w`` matrix,
    starting at flat index ``start_offset``. Elements of the matrix that ``inp`` does not
    cover are treated as zero. The absolute maximum of every 1x32 block (row-wise) and
    every 32x1 block (column-wise) is written in place into the top-left corner of
    ``amax_rowwise`` / ``amax_colwise``; any extra (padding) elements of those buffers
    are left untouched.

    Args:
        inp: Tensor with the partial (or full) matrix data. It is flattened internally,
            so any shape is accepted.
        amax_rowwise: Output buffer of shape at least ``(h, w // 32)``.
        amax_colwise: Output buffer of shape at least ``(h // 32, w)``.
        h: Number of rows of the logical matrix (must be a multiple of 32).
        w: Number of columns of the logical matrix (must be a multiple of 32).
        start_offset: Flat index of the first element of ``inp`` inside the matrix.
    """
    # Flatten once and reuse it both for counting and for copying, so multi-dimensional
    # inputs behave the same as 1-D ones.
    flat = inp.reshape(-1)
    n = flat.numel()
    if n == h * w:
        full = flat
    else:
        # Zero fill is required here: zeros never win an abs-max, so uncovered
        # elements do not influence the result.
        full = torch.zeros(h * w, dtype=inp.dtype, device=inp.device)
        full[start_offset : start_offset + n].copy_(flat)
    full = torch.abs(full)

    # Row-wise blocks: 32 consecutive elements within a row.
    row_amax = torch.max(full.view(h, w // 32, 32), dim=2).values
    amax_rowwise[:h, : (w // 32)].copy_(row_amax)

    # Column-wise blocks: 32 consecutive elements within a column.
    col_amax = torch.max(full.view(h // 32, 32, w), dim=1).values
    amax_colwise[: (h // 32), :w].copy_(col_amax)
|
|
||
|
|
||
| def partial_cast_reference( | ||
| inp, rowwise_out, colwise_out, rowwise_inv_scale, colwise_inv_scale, h, w, start_offset | ||
| ): | ||
| rowwise_scale = ((254 - rowwise_inv_scale.int()) * 2**23).view(torch.float32) | ||
| colwise_scale = ((254 - colwise_inv_scale.int()) * 2**23).view(torch.float32) | ||
| n = inp.view(-1).size(0) | ||
| if n == h * w: | ||
| full = inp | ||
| else: | ||
| full = torch.empty(h * w, dtype=inp.dtype, device=inp.device) | ||
| full[start_offset : start_offset + n].copy_(inp) | ||
| full = full.float() | ||
| rowwise_scale = rowwise_scale[:h, : (w // 32)].contiguous().float() | ||
| colwise_scale = colwise_scale[: (h // 32), :w].contiguous().float() | ||
| scaled = (full.view(-1, 32) * rowwise_scale.view(-1, 1)).view(-1) | ||
| rowwise_out.copy_( | ||
| scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(rowwise_out.dtype) | ||
| ) | ||
| scaled = (full.view(h // 32, 32, w) * colwise_scale.view(h // 32, 1, w)).view(-1) | ||
| colwise_out.copy_( | ||
| scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(colwise_out.dtype) | ||
| ) | ||
|
|
||
|
|
||
def run_one_case(n, h, w, start_offset):
    """Check the MXFP8 partial-amax and partial-cast CUDA kernels against PyTorch references.

    A random bf16 vector of ``n`` elements is treated as the slice
    ``[start_offset, start_offset + n)`` of a flattened ``h x w`` matrix. The amax
    results and the FP8 bytes produced by the kernels must match the references
    exactly (zero tolerance).

    Args:
        n: Number of elements in the partial input.
        h: Number of rows of the logical matrix.
        w: Number of columns of the logical matrix.
        start_offset: Flat index of the first input element inside the matrix.
    """
    inp = torch.randn(n, dtype=torch.bfloat16, device="cuda")

    rowwise_padding = [128, 4]
    colwise_padding = [4, 128]

    def _pad(x, padding):
        # Round x up to the next multiple of padding.
        return (x + padding - 1) // padding * padding

    # One amax / scale entry per 32-element block, padded per dimension.
    rowwise_shape = [_pad(h, rowwise_padding[0]), _pad(w // 32, rowwise_padding[1])]
    colwise_shape = [_pad(h // 32, colwise_padding[0]), _pad(w, colwise_padding[1])]

    # Partial amax cuda kernel
    amax_rowwise = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device)
    amax_colwise = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device)
    tex.mxfp8_scaling_compute_partial_amax(inp, amax_rowwise, amax_colwise, h, w, start_offset)

    # Partial amax pytorch reference
    amax_rowwise_ref = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device)
    amax_colwise_ref = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device)
    compute_partial_amax_reference(inp, amax_rowwise_ref, amax_colwise_ref, h, w, start_offset)

    # Check partial amax
    torch.testing.assert_close(amax_rowwise, amax_rowwise_ref, atol=0, rtol=0)
    torch.testing.assert_close(amax_colwise, amax_colwise_ref, atol=0, rtol=0)

    # Calculate scales and scale_invs
    # The no-op flag is zero-initialized so it never carries an accidental non-zero value.
    dummy_overflow_buf = torch.zeros(1, dtype=torch.int32, device=inp.device)
    # Allocate the uint8 buffers directly instead of converting uninitialized bf16 data.
    scale_inv_rowwise = torch.empty_like(amax_rowwise, dtype=torch.uint8)
    scale_inv_colwise = torch.empty_like(amax_colwise, dtype=torch.uint8)
    multi_tensor_applier(
        multi_tensor_compute_scale_inv_e8m0,
        dummy_overflow_buf,
        [
            [amax_rowwise, amax_colwise],
            [scale_inv_rowwise, scale_inv_colwise],
        ],
    )

    # Partial cast cuda kernel
    output_rowwise = torch.empty_like(inp, dtype=torch.uint8)
    output_colwise = torch.empty_like(inp, dtype=torch.uint8)
    tex.mxfp8_scaling_partial_cast(
        inp,
        output_rowwise,
        output_colwise,
        scale_inv_rowwise,
        scale_inv_colwise,
        h,
        w,
        start_offset,
    )

    # Partial cast pytorch reference
    output_rowwise_ref = torch.empty_like(inp, dtype=torch.uint8)
    output_colwise_ref = torch.empty_like(inp, dtype=torch.uint8)
    partial_cast_reference(
        inp,
        output_rowwise_ref,
        output_colwise_ref,
        scale_inv_rowwise,
        scale_inv_colwise,
        h,
        w,
        start_offset,
    )

    # Check partial cast results
    torch.testing.assert_close(output_rowwise, output_rowwise_ref, atol=0, rtol=0)
    torch.testing.assert_close(output_colwise, output_colwise_ref, atol=0, rtol=0)
|
|
||
|
|
||
def test_mxfp8_scaling_partial_cast():
    """Run the partial amax/cast comparison over a set of matrix shapes and input slices."""
    # Seed inside the test so that pytest runs (which never execute the ``__main__``
    # guard) draw the same random inputs; the comparisons use zero tolerance, so a
    # failure needs to be reproducible.
    torch.cuda.manual_seed(1234)

    # (n, h, w, start_offset)
    cases = [
        (3, 32, 64, 31),
        (64 * 64 - 2, 64, 64, 1),
        (16384 * 6144, 16384, 6144, 0),
        (32768, 256, 128, 0),
        (131072, 768, 256, 0),
        (65536, 768, 256, 131072),
        (98304, 128, 768, 0),
    ]
    for n, h, w, start_offset in cases:
        run_one_case(n, h, w, start_offset)
|
timmoon10 marked this conversation as resolved.
|
||
|
|
||
|
|
||
if __name__ == "__main__":
    # Allow running this file directly, without pytest. The CUDA RNG seed is fixed
    # first so the random inputs are the same on every run.
    torch.cuda.manual_seed(1234)
    test_mxfp8_scaling_partial_cast()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |
| #include <sstream> | ||
|
|
||
| #include "../recipe/recipe_common.cuh" | ||
| #include "../util/ptx.cuh" | ||
| #include "../utils.cuh" | ||
| #include "multi_tensor_apply.cuh" | ||
|
|
||
|
|
@@ -55,6 +56,32 @@ struct ComputeScaleAndScaleInvFunctor { | |
| } | ||
| }; | ||
|
|
||
| struct ComputeScaleInvE8M0Functor { | ||
| __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, | ||
| TensorListMetadata<2> &tl) { | ||
| // I'd like this kernel to propagate infs/nans. | ||
| // if(*noop_gmem == 1) | ||
| // return; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we're not using the noop flag, then we shouldn't include it in
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the current implementation of the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see that many other multi-tensor kernels include the noop flag and don't use it. They're all deceiving and should be changed. For the time being, I've just modified this PR so the new function doesn't continue this antipattern. |
||
|
|
||
| int tensor_loc = tl.block_to_tensor[blockIdx.x]; | ||
| int chunk_idx = tl.block_to_chunk[blockIdx.x]; | ||
| int n = tl.sizes[tensor_loc]; | ||
|
|
||
| bf16 *amax = reinterpret_cast<bf16 *>(tl.addresses[0][tensor_loc]); | ||
| amax += chunk_idx * chunk_size; | ||
|
|
||
| e8m0_t *scale_inv = reinterpret_cast<e8m0_t *>(tl.addresses[1][tensor_loc]); | ||
| scale_inv += chunk_idx * chunk_size; | ||
|
|
||
| n -= chunk_idx * chunk_size; | ||
|
|
||
| for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) { | ||
| scale_inv[i_start] = ptx::float_to_e8m0(static_cast<float>(amax[i_start]) * | ||
| Quantized_Limits<fp8e4m3>::max_norm_rcp); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_flag, | ||
| std::vector<std::vector<Tensor *>> tensor_lists, | ||
| float max_fp8, bool force_pow_2_scales, | ||
|
|
@@ -65,6 +92,18 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_f | |
| NVTE_CHECK_CUDA(cudaGetLastError()); | ||
| } | ||
|
|
||
| void multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, Tensor noop_flag, | ||
| std::vector<std::vector<Tensor *>> tensor_lists, | ||
| cudaStream_t stream) { | ||
| NVTE_CHECK(tensor_lists[0][0]->data.dtype == DType::kBFloat16, "amax should be bf16"); | ||
| auto scale_inv_dtype = tensor_lists[1][0]->data.dtype; | ||
| NVTE_CHECK(scale_inv_dtype == DType::kByte || scale_inv_dtype == DType::kFloat8E8M0, | ||
| "scale_inv should be e8m0/uint8"); | ||
|
Comment on lines
+94
to
+97
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: dtype check validates bf16 amax and e8m0/uint8 scale_inv but doesn't verify tensor shapes match. If amax and scale_inv have mismatched sizes, the kernel may write out of bounds or leave scale_inv partially uninitialized. |
||
| multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, | ||
| ComputeScaleInvE8M0Functor(), stream); | ||
| NVTE_CHECK_CUDA(cudaGetLastError()); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: Missing tensor shape validation: dtype is checked but tensor shapes are not. If
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| } | ||
|
|
||
| } // namespace multi_tensor_compute_scale | ||
| } // namespace transformer_engine | ||
|
|
||
|
|
@@ -82,3 +121,16 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETens | |
| convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8, | ||
| force_pow_2_scales, epsilon, stream); | ||
| } | ||
|
|
||
// C API entry point for the multi-tensor E8M0 scale-inverse computation.
//
// Passes the handles through convertNVTETensorCheck / convert_tensor_array and forwards
// the results to multi_tensor_compute_scale::multi_tensor_compute_scale_inv_e8m0_cuda
// on `stream`. That function checks that the first tensor of list 0 (amax) is bf16 and
// that the first tensor of list 1 (scale_inv) is uint8 or e8m0.
//
//   chunk_size            forwarded unchanged to the implementation
//   noop_flag             handle that is dereferenced and forwarded
//   tensor_lists          array of tensor-handle lists
//   num_tensor_lists      passed to convert_tensor_array along with tensor_lists
//   num_tensors_per_list  passed to convert_tensor_array along with tensor_lists
//   stream                CUDA stream forwarded to the implementation
void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor noop_flag,
                                                   NVTETensor **tensor_lists,
                                                   const size_t num_tensor_lists,
                                                   const size_t num_tensors_per_list,
                                                   cudaStream_t stream) {
  // NOTE(review): presumably registers/traces the API call — confirm against the macro.
  NVTE_API_CALL(nvte_multi_tensor_compute_scale_inv_e8m0_cuda);
  using namespace transformer_engine;

  multi_tensor_compute_scale::multi_tensor_compute_scale_inv_e8m0_cuda(
      chunk_size, *convertNVTETensorCheck(noop_flag),
      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), stream);
}
Uh oh!
There was an error while loading. Please reload this page.