Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 115 additions & 53 deletions tests/pytorch/distributed/test_cast_master_weights_to_fp8.py

Large diffs are not rendered by default.

32 changes: 32 additions & 0 deletions tests/pytorch/test_multi_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,35 @@ def test_multi_tensor_compute_scale_and_scale_inv(
)
torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0)
torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0)


@pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
@pytest.mark.parametrize("applier", appliers)
@pytest.mark.parametrize("repeat", [1, 55])
def test_multi_tensor_compute_scale_inv_e8m0(input_size_pair, applier, repeat):
sizea, sizeb = input_size_pair
device = torch.device("cuda")
overflow_buf = torch.zeros(1, dtype=torch.int32, device=device)
a = torch.randn([sizea], dtype=torch.bfloat16, device=device).abs()
b = torch.randn([sizeb], dtype=torch.bfloat16, device=device).abs()

amax_list = []
for _ in range(repeat):
amax_list += [a.clone(), b.clone()]
scale_inv_list = [torch.empty_like(x).to(torch.uint8) for x in amax_list]

applier(
tex.multi_tensor_compute_scale_inv_e8m0,
overflow_buf,
[amax_list, scale_inv_list],
)

max_fp8 = torch.finfo(torch.float8_e4m3fn).max
for amax, scale_inv in zip(amax_list, scale_inv_list):
scale_inv_u32 = (amax.float() / max_fp8).view(torch.int)
exponent = scale_inv_u32 // 2**23
mantissa = scale_inv_u32 & 0x7FFFFF
exponent += (
((mantissa > 0) & (exponent != 0xFE)) & ~((exponent == 0) & (mantissa <= 0x400000))
).to(torch.int)
torch.testing.assert_close(exponent.to(torch.uint8), scale_inv)
136 changes: 136 additions & 0 deletions tests/pytorch/test_partial_cast.py
Comment thread
timmoon10 marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import torch

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_compute_scale_inv_e8m0
from transformer_engine.pytorch.optimizers.multi_tensor_apply import multi_tensor_applier


def compute_partial_amax_reference(inp, amax_rowwise, amax_colwise, h, w, start_offset):
n = inp.view(-1).size(0)
if n == h * w:
full = inp.view(-1)
else:
full = torch.zeros(h * w, dtype=inp.dtype, device=inp.device)
full[start_offset : start_offset + n].copy_(inp)
full = torch.abs(full)
_amax_rowwise, _ = torch.max(full.view(h, w // 32, 32), dim=2)
amax_rowwise[:h, : (w // 32)].copy_(_amax_rowwise)
_amax_colwise, _ = torch.max(full.view(h // 32, 32, w), dim=1)
amax_colwise[: (h // 32), :w].copy_(_amax_colwise)


def partial_cast_reference(
inp, rowwise_out, colwise_out, rowwise_inv_scale, colwise_inv_scale, h, w, start_offset
):
rowwise_scale = ((254 - rowwise_inv_scale.int()) * 2**23).view(torch.float32)
colwise_scale = ((254 - colwise_inv_scale.int()) * 2**23).view(torch.float32)
n = inp.view(-1).size(0)
if n == h * w:
full = inp
else:
full = torch.empty(h * w, dtype=inp.dtype, device=inp.device)
full[start_offset : start_offset + n].copy_(inp)
full = full.float()
rowwise_scale = rowwise_scale[:h, : (w // 32)].contiguous().float()
colwise_scale = colwise_scale[: (h // 32), :w].contiguous().float()
scaled = (full.view(-1, 32) * rowwise_scale.view(-1, 1)).view(-1)
rowwise_out.copy_(
scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(rowwise_out.dtype)
)
scaled = (full.view(h // 32, 32, w) * colwise_scale.view(h // 32, 1, w)).view(-1)
colwise_out.copy_(
scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(colwise_out.dtype)
)


def run_one_case(n, h, w, start_offset):
inp = torch.randn(n, dtype=torch.bfloat16, device="cuda")

rowwise_padding = [128, 4]
colwise_padding = [4, 128]

def _pad(x, padding):
return (x + padding - 1) // padding * padding

rowwise_shape = [_pad(h, rowwise_padding[0]), _pad(w // 32, rowwise_padding[1])]
colwise_shape = [_pad(h // 32, colwise_padding[0]), _pad(w, colwise_padding[1])]

# Partial amax cuda kernel
amax_rowwise = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device)
amax_colwise = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device)
tex.mxfp8_scaling_compute_partial_amax(inp, amax_rowwise, amax_colwise, h, w, start_offset)

# Partial amax pytorch reference
amax_rowwise_ref = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device)
amax_colwise_ref = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device)
compute_partial_amax_reference(inp, amax_rowwise_ref, amax_colwise_ref, h, w, start_offset)

# Check partial amax
torch.testing.assert_close(amax_rowwise, amax_rowwise_ref, atol=0, rtol=0)
torch.testing.assert_close(amax_colwise, amax_colwise_ref, atol=0, rtol=0)

# Calculate scales and scale_invs
dummy_overflow_buf = torch.empty(1, dtype=torch.int32, device=inp.device)
scale_inv_rowwise = torch.empty_like(amax_rowwise).to(torch.uint8)
scale_inv_colwise = torch.empty_like(amax_colwise).to(torch.uint8)
multi_tensor_applier(
multi_tensor_compute_scale_inv_e8m0,
dummy_overflow_buf,
[
[amax_rowwise, amax_colwise],
[scale_inv_rowwise, scale_inv_colwise],
],
)

# Partial cast cuda kernel
output_rowwise = torch.empty_like(inp).to(torch.uint8)
output_colwise = torch.empty_like(inp).to(torch.uint8)
tex.mxfp8_scaling_partial_cast(
inp,
output_rowwise,
output_colwise,
scale_inv_rowwise,
scale_inv_colwise,
h,
w,
start_offset,
)

# Partial cast pytorch reference
output_rowwise_ref = torch.empty_like(inp).to(torch.uint8)
output_colwise_ref = torch.empty_like(inp).to(torch.uint8)
partial_cast_reference(
inp,
output_rowwise_ref,
output_colwise_ref,
scale_inv_rowwise,
scale_inv_colwise,
h,
w,
start_offset,
)

# Check partial cast results
torch.testing.assert_close(output_rowwise, output_rowwise_ref, atol=0, rtol=0)
torch.testing.assert_close(output_colwise, output_colwise_ref, atol=0, rtol=0)


def test_mxfp8_scaling_partial_cast():
run_one_case(3, 32, 64, 31)
run_one_case(64 * 64 - 2, 64, 64, 1)
run_one_case(16384 * 6144, 16384, 6144, 0)
run_one_case(32768, 256, 128, 0)
run_one_case(131072, 768, 256, 0)
run_one_case(65536, 768, 256, 131072)
run_one_case(98304, 128, 768, 0)
Comment thread
timmoon10 marked this conversation as resolved.


if __name__ == "__main__":

torch.cuda.manual_seed(1234)
test_mxfp8_scaling_partial_cast()
1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ list(APPEND transformer_engine_cuda_sources
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
recipe/mxfp8_scaling.cu
recipe/nvfp4.cu
comm_gemm_overlap/userbuffers/userbuffers.cu)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,24 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETens
float max_fp8, int force_pow_2_scales,
float epsilon, cudaStream_t stream);

/*! \brief Compute E8M0 scale_inv for a list of tensors.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists,
const size_t num_tensor_lists,
const size_t num_tensors_per_list,
cudaStream_t stream);

#ifdef __cplusplus
} // extern "C"
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,15 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
size_t start_offset, size_t block_len,
const NVTEDType out_dtype, cudaStream_t stream);

void nvte_mxfp8_scaling_compute_partial_amax(const NVTETensor input, NVTETensor amax_rowwise,
NVTETensor amax_colwise, int rows, int cols,
size_t start_offset, cudaStream_t stream);

void nvte_mxfp8_scaling_partial_cast(const NVTETensor input, NVTETensor output_rowwise,
NVTETensor output_colwise, const NVTETensor scale_inv_rowwise,
const NVTETensor scale_inv_colwise, int rows, int cols,
size_t start_offset, cudaStream_t stream);
Comment thread
timmoon10 marked this conversation as resolved.

void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A,
const NVTETensor inpB, const bool use_rowwise_amax_B,
float alpha_in, NVTETensor alpha_out, cudaStream_t stream);
Expand Down
52 changes: 52 additions & 0 deletions transformer_engine/common/multi_tensor/compute_scale.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <sstream>

#include "../recipe/recipe_common.cuh"
#include "../util/ptx.cuh"
#include "../utils.cuh"
#include "multi_tensor_apply.cuh"

Expand Down Expand Up @@ -55,6 +56,32 @@ struct ComputeScaleAndScaleInvFunctor {
}
};

struct ComputeScaleInvE8M0Functor {
__device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
TensorListMetadata<2> &tl) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're not using the noop flag, then we shouldn't include it in nvte_multi_tensor_compute_scale_inv_e8m0_cuda.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the current implementation of the MultiTensorApplier https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/optimizers/multi_tensor_apply.py#L21, it is hard-coded to always pass a noop flag in to the multi tensor functor, so we need it to accept that parameter.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that many other multi-tensor kernels include the noop flag and don't use it. They're all deceiving and should be changed. For the time being, I've just modified this PR so the new function doesn't continue this antipattern.


int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];

bf16 *amax = reinterpret_cast<bf16 *>(tl.addresses[0][tensor_loc]);
amax += chunk_idx * chunk_size;

e8m0_t *scale_inv = reinterpret_cast<e8m0_t *>(tl.addresses[1][tensor_loc]);
scale_inv += chunk_idx * chunk_size;

n -= chunk_idx * chunk_size;

for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) {
scale_inv[i_start] = ptx::float_to_e8m0(static_cast<float>(amax[i_start]) *
Quantized_Limits<fp8e4m3>::max_norm_rcp);
}
}
};

void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists,
float max_fp8, bool force_pow_2_scales,
Expand All @@ -65,6 +92,18 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_f
NVTE_CHECK_CUDA(cudaGetLastError());
}

void multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists,
cudaStream_t stream) {
NVTE_CHECK(tensor_lists[0][0]->data.dtype == DType::kBFloat16, "amax should be bf16");
auto scale_inv_dtype = tensor_lists[1][0]->data.dtype;
NVTE_CHECK(scale_inv_dtype == DType::kByte || scale_inv_dtype == DType::kFloat8E8M0,
"scale_inv should be e8m0/uint8");
Comment on lines +94 to +97

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: dtype check validates bf16 amax and e8m0/uint8 scale_inv but doesn't verify tensor shapes match. If amax and scale_inv have mismatched sizes, the kernel may write out of bounds or leave scale_inv partially uninitialized.

multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
ComputeScaleInvE8M0Functor(), stream);
NVTE_CHECK_CUDA(cudaGetLastError());

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: Missing tensor shape validation: dtype is checked but tensor shapes are not. If amax and scale_inv have mismatched sizes, the kernel could write out-of-bounds or leave scale_inv partially uninitialized. Add shape validation before kernel launch.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MultiTensorApplier handles the shape.

}

} // namespace multi_tensor_compute_scale
} // namespace transformer_engine

Expand All @@ -82,3 +121,16 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETens
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8,
force_pow_2_scales, epsilon, stream);
}

void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists,
const size_t num_tensor_lists,
const size_t num_tensors_per_list,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_compute_scale_inv_e8m0_cuda);
using namespace transformer_engine;

multi_tensor_compute_scale::multi_tensor_compute_scale_inv_e8m0_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), stream);
}
Loading
Loading