Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 115 additions & 53 deletions tests/pytorch/distributed/test_cast_master_weights_to_fp8.py

Large diffs are not rendered by default.

32 changes: 32 additions & 0 deletions tests/pytorch/test_multi_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,35 @@ def test_multi_tensor_compute_scale_and_scale_inv(
)
torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0)
torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0)


@pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
@pytest.mark.parametrize("applier", appliers)
@pytest.mark.parametrize("repeat", [1, 55])
def test_multi_tensor_compute_scale_inv_e8m0(input_size_pair, applier, repeat):
sizea, sizeb = input_size_pair
device = torch.device("cuda")
overflow_buf = torch.zeros(1, dtype=torch.int32, device=device)
a = torch.randn([sizea], dtype=torch.bfloat16, device=device).abs()
b = torch.randn([sizeb], dtype=torch.bfloat16, device=device).abs()

amax_list = []
for _ in range(repeat):
amax_list += [a.clone(), b.clone()]
scale_inv_list = [torch.empty_like(x).to(torch.uint8) for x in amax_list]

applier(
tex.multi_tensor_compute_scale_inv_e8m0,
overflow_buf,
[amax_list, scale_inv_list],
)

max_fp8 = torch.finfo(torch.float8_e4m3fn).max
for amax, scale_inv in zip(amax_list, scale_inv_list):
scale_inv_u32 = (amax.float() / max_fp8).view(torch.int)
exponent = scale_inv_u32 // 2**23
mantissa = scale_inv_u32 & 0x7FFFFF
exponent += (
((mantissa > 0) & (exponent != 0xFE)) & ~((exponent == 0) & (mantissa <= 0x400000))
).to(torch.int)
torch.testing.assert_close(exponent.to(torch.uint8), scale_inv)
136 changes: 136 additions & 0 deletions tests/pytorch/test_partial_cast.py
Comment thread
timmoon10 marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import torch

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_compute_scale_inv_e8m0
from transformer_engine.pytorch.optimizers.multi_tensor_apply import multi_tensor_applier


def compute_partial_amax_reference(inp, amax_rowwise, amax_colwise, h, w, start_offset):
n = inp.view(-1).size(0)
if n == h * w:
full = inp.view(-1)
else:
full = torch.zeros(h * w, dtype=inp.dtype, device=inp.device)
full[start_offset : start_offset + n].copy_(inp)
full = torch.abs(full)
_amax_rowwise, _ = torch.max(full.view(h, w // 32, 32), dim=2)
amax_rowwise[:h, : (w // 32)].copy_(_amax_rowwise)
_amax_colwise, _ = torch.max(full.view(h // 32, 32, w), dim=1)
amax_colwise[: (h // 32), :w].copy_(_amax_colwise)


def partial_cast_reference(
inp, rowwise_out, colwise_out, rowwise_inv_scale, colwise_inv_scale, h, w, start_offset
):
rowwise_scale = ((254 - rowwise_inv_scale.int()) * 2**23).view(torch.float32)
colwise_scale = ((254 - colwise_inv_scale.int()) * 2**23).view(torch.float32)
n = inp.view(-1).size(0)
if n == h * w:
full = inp
else:
full = torch.empty(h * w, dtype=inp.dtype, device=inp.device)
full[start_offset : start_offset + n].copy_(inp)
full = full.float()
rowwise_scale = rowwise_scale[:h, : (w // 32)].contiguous().float()
colwise_scale = colwise_scale[: (h // 32), :w].contiguous().float()
scaled = (full.view(-1, 32) * rowwise_scale.view(-1, 1)).view(-1)
rowwise_out.copy_(
scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(rowwise_out.dtype)
)
scaled = (full.view(h // 32, 32, w) * colwise_scale.view(h // 32, 1, w)).view(-1)
colwise_out.copy_(
scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(colwise_out.dtype)
)


def run_one_case(n, h, w, start_offset):
inp = torch.randn(n, dtype=torch.bfloat16, device="cuda")

rowwise_padding = [128, 4]
colwise_padding = [4, 128]

def _pad(x, padding):
return (x + padding - 1) // padding * padding

rowwise_shape = [_pad(h, rowwise_padding[0]), _pad(w // 32, rowwise_padding[1])]
colwise_shape = [_pad(h // 32, colwise_padding[0]), _pad(w, colwise_padding[1])]

# Partial amax cuda kernel
amax_rowwise = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device)
amax_colwise = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device)
tex.mxfp8_scaling_compute_partial_amax(inp, amax_rowwise, amax_colwise, h, w, start_offset)

# Partial amax pytorch reference
amax_rowwise_ref = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device)
amax_colwise_ref = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device)
compute_partial_amax_reference(inp, amax_rowwise_ref, amax_colwise_ref, h, w, start_offset)

# Check partial amax
torch.testing.assert_close(amax_rowwise, amax_rowwise_ref, atol=0, rtol=0)
torch.testing.assert_close(amax_colwise, amax_colwise_ref, atol=0, rtol=0)

# Calculate scales and scale_invs
dummy_overflow_buf = torch.empty(1, dtype=torch.int32, device=inp.device)
scale_inv_rowwise = torch.empty_like(amax_rowwise).to(torch.uint8)
scale_inv_colwise = torch.empty_like(amax_colwise).to(torch.uint8)
multi_tensor_applier(
multi_tensor_compute_scale_inv_e8m0,
dummy_overflow_buf,
[
[amax_rowwise, amax_colwise],
[scale_inv_rowwise, scale_inv_colwise],
],
)

# Partial cast cuda kernel
output_rowwise = torch.empty_like(inp).to(torch.uint8)
output_colwise = torch.empty_like(inp).to(torch.uint8)
tex.mxfp8_scaling_partial_cast(
inp,
output_rowwise,
output_colwise,
scale_inv_rowwise,
scale_inv_colwise,
h,
w,
start_offset,
)

# Partial cast pytorch reference
output_rowwise_ref = torch.empty_like(inp).to(torch.uint8)
output_colwise_ref = torch.empty_like(inp).to(torch.uint8)
partial_cast_reference(
inp,
output_rowwise_ref,
output_colwise_ref,
scale_inv_rowwise,
scale_inv_colwise,
h,
w,
start_offset,
)

# Check partial cast results
torch.testing.assert_close(output_rowwise, output_rowwise_ref, atol=0, rtol=0)
torch.testing.assert_close(output_colwise, output_colwise_ref, atol=0, rtol=0)


def test_mxfp8_scaling_partial_cast():
run_one_case(3, 32, 64, 31)
run_one_case(64 * 64 - 2, 64, 64, 1)
run_one_case(16384 * 6144, 16384, 6144, 0)
run_one_case(32768, 256, 128, 0)
run_one_case(131072, 768, 256, 0)
run_one_case(65536, 768, 256, 131072)
run_one_case(98304, 128, 768, 0)
Comment thread
timmoon10 marked this conversation as resolved.


if __name__ == "__main__":

torch.cuda.manual_seed(1234)
test_mxfp8_scaling_partial_cast()
1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ list(APPEND transformer_engine_cuda_sources
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
recipe/mxfp8_scaling.cu
recipe/nvfp4.cu
comm_gemm_overlap/userbuffers/userbuffers.cu)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,24 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETens
float max_fp8, int force_pow_2_scales,
float epsilon, cudaStream_t stream);

/*! \brief Compute E8M0 scale_inv for a list of tensors.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists,
const size_t num_tensor_lists,
const size_t num_tensors_per_list,
cudaStream_t stream);

/*! \brief Split a tensor along dimension 0 and compute the amax for each split.
*
* This function is experimental and the API is not stable.
Expand Down
Loading
Loading