Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 115 additions & 53 deletions tests/pytorch/distributed/test_cast_master_weights_to_fp8.py

Large diffs are not rendered by default.

31 changes: 31 additions & 0 deletions tests/pytorch/test_multi_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,34 @@ def test_multi_tensor_compute_scale_and_scale_inv(
)
torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0)
torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0)


@pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
@pytest.mark.parametrize("applier", appliers)
@pytest.mark.parametrize("repeat", [1, 55])
def test_multi_tensor_compute_scale_inv_e8m0(input_size_pair, applier, repeat):
    """Check the multi-tensor E8M0 scale_inv kernel against a bit-level reference.

    Two non-negative bfloat16 "amax" tensors (sizes taken from
    ``input_size_pair``) are cloned ``repeat`` times to form the tensor list.
    The kernel writes one uint8 value per amax element; the reference below
    derives the same byte from the float32 bit pattern of ``amax / max_fp8``.
    """
    sizea, sizeb = input_size_pair
    device = torch.device("cuda")
    # .abs() keeps the sign bit clear, so the bit tricks below never see it.
    a = torch.randn([sizea], dtype=torch.bfloat16, device=device).abs()
    b = torch.randn([sizeb], dtype=torch.bfloat16, device=device).abs()

    amax_list = []
    for _ in range(repeat):
        amax_list += [a.clone(), b.clone()]
    # Allocate the outputs as uint8 directly instead of creating a bf16 buffer
    # and casting its uninitialized contents.
    scale_inv_list = [torch.empty_like(x, dtype=torch.uint8) for x in amax_list]

    applier(
        tex.multi_tensor_compute_scale_inv_e8m0,
        None,  # overflow_buf
        [amax_list, scale_inv_list],
    )

    max_fp8 = torch.finfo(torch.float8_e4m3fn).max
    for amax, scale_inv in zip(amax_list, scale_inv_list):
        # Reinterpret the float32 quotient as int32 to get at its raw fields.
        scale_inv_u32 = (amax.float() / max_fp8).view(torch.int)
        # Bits 23 and up: biased exponent (sign bit is zero, see above).
        exponent = scale_inv_u32 // 2**23
        # Low 23 bits: mantissa.
        mantissa = scale_inv_u32 & 0x7FFFFF
        # Round the exponent up by one when the mantissa is non-zero, except
        # when the exponent is already 0xFE, or when the exponent is 0 and the
        # mantissa does not exceed 0x400000.
        exponent += (
            ((mantissa > 0) & (exponent != 0xFE)) & ~((exponent == 0) & (mantissa <= 0x400000))
        ).to(torch.int)
        torch.testing.assert_close(exponent.to(torch.uint8), scale_inv)
135 changes: 135 additions & 0 deletions tests/pytorch/test_partial_cast.py
Comment thread
timmoon10 marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import torch

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_compute_scale_inv_e8m0
from transformer_engine.pytorch.optimizers.multi_tensor_apply import multi_tensor_applier


def compute_partial_amax_reference(inp, amax_rowwise, amax_colwise, h, w, start_offset):
    """Pure-PyTorch reference for the partial per-block absolute maximum.

    ``inp`` holds the elements of a flattened ``h * w`` matrix beginning at
    flat index ``start_offset``. Positions it does not cover are treated as
    zero. The absolute maximum of every 32-element block along a row is
    written to ``amax_rowwise[:h, :w // 32]`` and of every 32-element block
    along a column to ``amax_colwise[:h // 32, :w]``. Both outputs are
    modified in place; nothing is returned.
    """
    block = 32
    total = h * w

    flat = inp.view(-1)
    count = flat.size(0)
    if count == total:
        # The input already spans the whole matrix.
        dense = flat
    else:
        # Embed the partial slice into a zero-filled full-size buffer.
        dense = torch.zeros(total, dtype=inp.dtype, device=inp.device)
        dense[start_offset : start_offset + count].copy_(inp)
    magnitude = dense.abs()

    # Row-wise: reduce over the innermost 32 consecutive elements of each row.
    row_blocks = w // block
    row_amax = magnitude.view(h, row_blocks, block).max(dim=2).values
    amax_rowwise[:h, :row_blocks].copy_(row_amax)

    # Column-wise: reduce over groups of 32 consecutive rows for every column.
    col_blocks = h // block
    col_amax = magnitude.view(col_blocks, block, w).max(dim=1).values
    amax_colwise[:col_blocks, :w].copy_(col_amax)


def partial_cast_reference(
inp, rowwise_out, colwise_out, rowwise_inv_scale, colwise_inv_scale, h, w, start_offset
):
rowwise_scale = ((254 - rowwise_inv_scale.int()) * 2**23).view(torch.float32)
colwise_scale = ((254 - colwise_inv_scale.int()) * 2**23).view(torch.float32)
n = inp.view(-1).size(0)
if n == h * w:
full = inp
else:
full = torch.empty(h * w, dtype=inp.dtype, device=inp.device)
full[start_offset : start_offset + n].copy_(inp)
full = full.float()
rowwise_scale = rowwise_scale[:h, : (w // 32)].contiguous().float()
colwise_scale = colwise_scale[: (h // 32), :w].contiguous().float()
scaled = (full.view(-1, 32) * rowwise_scale.view(-1, 1)).view(-1)
rowwise_out.copy_(
scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(rowwise_out.dtype)
)
scaled = (full.view(h // 32, 32, w) * colwise_scale.view(h // 32, 1, w)).view(-1)
colwise_out.copy_(
scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(colwise_out.dtype)
)


def run_one_case(n, h, w, start_offset):
    """Compare the MXFP8 partial amax / partial cast CUDA kernels with references.

    A random bfloat16 vector of ``n`` elements stands for the slice of a
    flattened ``h x w`` matrix that starts at flat index ``start_offset``.
    The kernels' amax buffers and cast outputs must match the PyTorch
    reference implementations exactly (``atol=0, rtol=0``).

    Args:
        n: Number of elements in the partial input.
        h: Row count of the full matrix.
        w: Column count of the full matrix.
        start_offset: Flat index of the first input element in the full matrix.
    """
    inp = torch.randn(n, dtype=torch.bfloat16, device="cuda")

    rowwise_padding = [128, 4]
    colwise_padding = [4, 128]

    def _pad(x, padding):
        # Round x up to the next multiple of padding.
        return (x + padding - 1) // padding * padding

    # Scale buffers hold one entry per 32-element block, padded per dimension.
    rowwise_shape = [_pad(h, rowwise_padding[0]), _pad(w // 32, rowwise_padding[1])]
    colwise_shape = [_pad(h // 32, colwise_padding[0]), _pad(w, colwise_padding[1])]

    # Partial amax cuda kernel
    amax_rowwise = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device)
    amax_colwise = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device)
    tex.mxfp8_scaling_compute_partial_amax(inp, amax_rowwise, amax_colwise, h, w, start_offset)

    # Partial amax pytorch reference
    amax_rowwise_ref = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device)
    amax_colwise_ref = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device)
    compute_partial_amax_reference(inp, amax_rowwise_ref, amax_colwise_ref, h, w, start_offset)

    # Check partial amax
    torch.testing.assert_close(amax_rowwise, amax_rowwise_ref, atol=0, rtol=0)
    torch.testing.assert_close(amax_colwise, amax_colwise_ref, atol=0, rtol=0)

    # Calculate scales and scale_invs.
    # Buffers are allocated as uint8 directly; going through a bf16
    # empty_like + .to() would allocate twice and cast uninitialized memory.
    scale_inv_rowwise = torch.empty_like(amax_rowwise, dtype=torch.uint8)
    scale_inv_colwise = torch.empty_like(amax_colwise, dtype=torch.uint8)
    multi_tensor_applier(
        multi_tensor_compute_scale_inv_e8m0,
        None,
        [
            [amax_rowwise, amax_colwise],
            [scale_inv_rowwise, scale_inv_colwise],
        ],
    )

    # Partial cast cuda kernel
    output_rowwise = torch.empty_like(inp, dtype=torch.uint8)
    output_colwise = torch.empty_like(inp, dtype=torch.uint8)
    tex.mxfp8_scaling_partial_cast(
        inp,
        output_rowwise,
        output_colwise,
        scale_inv_rowwise,
        scale_inv_colwise,
        h,
        w,
        start_offset,
    )

    # Partial cast pytorch reference
    output_rowwise_ref = torch.empty_like(inp, dtype=torch.uint8)
    output_colwise_ref = torch.empty_like(inp, dtype=torch.uint8)
    partial_cast_reference(
        inp,
        output_rowwise_ref,
        output_colwise_ref,
        scale_inv_rowwise,
        scale_inv_colwise,
        h,
        w,
        start_offset,
    )

    # Check partial cast results
    torch.testing.assert_close(output_rowwise, output_rowwise_ref, atol=0, rtol=0)
    torch.testing.assert_close(output_colwise, output_colwise_ref, atol=0, rtol=0)


def test_mxfp8_scaling_partial_cast():
    """Run the partial amax / partial cast comparison over several layouts."""
    # Seed inside the test so the random inputs are reproducible when the
    # test is collected by pytest, not only when the file is run as a script.
    torch.cuda.manual_seed(1234)

    # Each entry is (n, h, w, start_offset).
    cases = (
        (3, 32, 64, 31),  # three elements straddling a 32-element block edge
        (64 * 64 - 2, 64, 64, 1),  # everything except the first and last element
        (16384 * 6144, 16384, 6144, 0),  # large full matrix
        (32768, 256, 128, 0),  # full matrix
        (131072, 768, 256, 0),  # first two thirds of the matrix
        (65536, 768, 256, 131072),  # remaining last third of the matrix
        (98304, 128, 768, 0),  # full matrix, wider than tall
    )
    for n, h, w, start_offset in cases:
        run_one_case(n, h, w, start_offset)
Comment thread
timmoon10 marked this conversation as resolved.


if __name__ == "__main__":
    # Script entry point: pin the CUDA RNG seed, then run the test once.
    torch.cuda.manual_seed(1234)
    test_mxfp8_scaling_partial_cast()
15 changes: 8 additions & 7 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ list(APPEND transformer_engine_cpp_sources
list(APPEND transformer_engine_cuda_sources
common.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
Expand Down Expand Up @@ -167,16 +166,18 @@ list(APPEND transformer_engine_cuda_sources
comm_gemm_overlap/userbuffers/userbuffers.cu)

list(APPEND transformer_engine_cuda_arch_specific_sources
gemm/cutlass_grouped_gemm.cu
cast/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
hadamard_transform/hadamard_transform.cu
cast/cast.cu
gemm/cutlass_grouped_gemm.cu
hadamard_transform/group_hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu)
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu)

# Compiling the files with the worst compilation time first to hopefully overlap
# better with the faster-compiling cpp files
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,21 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETens
float max_fp8, int force_pow_2_scales,
float epsilon, cudaStream_t stream);

/*! \brief Compute E8M0 scale_inv for a list of tensors.
 *
 *  \warning   This API is **experimental** and subject to change.
 *
 *  \param[in]     chunk_size            Number of tensor elements processed by a CUDA block.
 *  \param[in,out] tensor_lists          2D array of tensors, indexed as [list][tensor]; holds
 *                                       both the inputs that are read and the outputs that
 *                                       are written.
 *                                       NOTE(review): the PyTorch test passes
 *                                       [amax_list, scale_inv_list], i.e. list 0 = amax inputs
 *                                       and list 1 = uint8 scale_inv outputs - confirm the same
 *                                       ordering is required at this entry point.
 *  \param[in]     num_tensor_lists      Size (dim0) of tensor_lists.
 *  \param[in]     num_tensors_per_list  Size (dim1) of tensor_lists.
 *  \param[in]     stream                CUDA stream used for this operation.
 */
void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor **tensor_lists,
const size_t num_tensor_lists,
const size_t num_tensors_per_list,
cudaStream_t stream);

/*! \brief Split a tensor along dimension 0 and compute the amax for each split.
*
* This function is experimental and the API is not stable.
Expand Down
Loading
Loading