diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index e1ce680094..f1a48e421b 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -49,6 +49,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index 0ff98e6cb7..51b920eab5 100644 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -18,6 +18,7 @@ DelayedScaling, Float8CurrentScaling, Float8BlockScaling, + MXFP8BlockScaling, Format, Recipe, ) @@ -25,9 +26,11 @@ from transformer_engine.pytorch import ( is_fp8_available, is_fp8_block_scaling_available, + is_mxfp8_available, QuantizedTensor, Float8Tensor, Float8BlockwiseQTensor, + MXFP8Tensor, ) from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8 from transformer_engine.pytorch.tensor.utils import post_all_gather_processing, replace_raw_data @@ -42,17 +45,21 @@ def _get_quantization_recipe(quantization) -> Recipe: return 
Float8CurrentScaling(fp8_format=fp8_format) elif quantization == "fp8_block": return Float8BlockScaling(fp8_format=fp8_format) + elif quantization == "mxfp8": + return MXFP8BlockScaling() else: raise ValueError(f"Unsupported quantization: {quantization}") -def _get_raw_data(quantized_tensor): +def _get_raw_data(quantized_tensor, colwise=False): """Get the underlying data of a quantized tensor, used in zero-1 optimizer""" if isinstance(quantized_tensor, Float8Tensor): + assert not colwise, "Float8Tensor does not support get colwise data" assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute" assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8" return quantized_tensor._data elif isinstance(quantized_tensor, Float8BlockwiseQTensor): + assert not colwise, "Float8BlockwiseQTensor does not support get colwise data" assert hasattr( quantized_tensor, "_rowwise_data" ), "Float8BlockwiseQTensor does not have _rowwise_data attribute" @@ -60,6 +67,23 @@ def _get_raw_data(quantized_tensor): quantized_tensor._rowwise_data.dtype == torch.uint8 ), "Float8BlockwiseQTensor _rowwise_data must be uint8" return quantized_tensor._rowwise_data + elif isinstance(quantized_tensor, MXFP8Tensor): + if colwise: + assert hasattr( + quantized_tensor, "_columnwise_data" + ), "MXFP8Tensor does not have columnwise_data attribute" + assert ( + quantized_tensor._columnwise_data.dtype == torch.uint8 + ), "MXFP8Tensor columnwise_data must be uint8" + return quantized_tensor._columnwise_data + else: + assert hasattr( + quantized_tensor, "_rowwise_data" + ), "MXFP8Tensor does not have rowwise_data attribute" + assert ( + quantized_tensor._rowwise_data.dtype == torch.uint8 + ), "MXFP8Tensor rowwise_data must be uint8" + return quantized_tensor._rowwise_data else: raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}") @@ -229,38 +253,43 @@ def step(self): end = start_offset + master_weight.numel() 
weight.data.view(-1)[start:end].copy_(master_weight) - # ----------------------------------------------------------------------------------------- - # Step 5: Copy the updated weights (not all weights) to the weight buffer - # ----------------------------------------------------------------------------------------- - for i in range(len(self.weights)): - master_weight = self.master_weights[i] - if master_weight is None: - continue - start_offset = self.start_offsets[i] - if isinstance(self.weights[i], QuantizedTensor): - weight = _get_raw_data(self.weights[i]) - else: - weight = self.weights[i] - weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()] - overlapping_start, overlapping_end = self.overlapping_areas[i] - self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice) + colwise_list = [False] + if isinstance(self.weights[0], MXFP8Tensor): + colwise_list.append(True) - # ----------------------------------------------------------------------------------------- - # Step 6: Weight all-gather (FP8 or BF16) - # ----------------------------------------------------------------------------------------- - dist.all_gather_into_tensor( - self.weight_buffer, self.weight_buffer_slice, group=self.dp_group - ) + for colwise in colwise_list: + # ------------------------------------------------------------------------------------- + # Step 5: Copy the updated weights (not all weights) to the weight buffer + # ------------------------------------------------------------------------------------- + for i in range(len(self.weights)): + master_weight = self.master_weights[i] + if master_weight is None: + continue + start_offset = self.start_offsets[i] + if isinstance(self.weights[i], QuantizedTensor): + weight = _get_raw_data(self.weights[i], colwise) + else: + weight = self.weights[i] + weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()] + overlapping_start, overlapping_end = self.overlapping_areas[i] + 
self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice) + + # ------------------------------------------------------------------------------------- + # Step 6: Weight all-gather (FP8 or BF16) + # ------------------------------------------------------------------------------------- + dist.all_gather_into_tensor( + self.weight_buffer, self.weight_buffer_slice, group=self.dp_group + ) - # ----------------------------------------------------------------------------------------- - # Step 7: Copy the gathered weights from weight buffer to the actual weights - # ----------------------------------------------------------------------------------------- - for weight, offset in zip(self.weights, self.offsets[:-1]): - start = offset - end = offset + weight.numel() - if isinstance(weight, QuantizedTensor): - weight = _get_raw_data(weight) - weight.view(-1).data.copy_(self.weight_buffer[start:end]) + # ------------------------------------------------------------------------------------- + # Step 7: Copy the gathered weights from weight buffer to the actual weights + # ------------------------------------------------------------------------------------- + for weight, offset in zip(self.weights, self.offsets[:-1]): + start = offset + end = offset + weight.numel() + if isinstance(weight, QuantizedTensor): + weight = _get_raw_data(weight, colwise) + weight.view(-1).data.copy_(self.weight_buffer[start:end]) if self.manual_post_all_gather_processing: quantized_weights = [ @@ -285,9 +314,15 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals else: raw_data_list = [w.view(-1) for w in weights] self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list) + if isinstance(weights[0], MXFP8Tensor): + self.flatten_columnwise = self.flatten_weight.clone() + else: + self.flatten_columnwise = None # Split flattened weights into shards self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank] + if 
self.flatten_columnwise is not None: + self.local_columnwise_shard = torch.chunk(self.flatten_columnwise, world_size)[rank] self.local_main_grad_shard = torch.zeros_like( self.local_weight_shard, dtype=torch.float32, device="cuda" ) @@ -319,14 +354,25 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals self.shard_indices.append((None, None)) if isinstance(weights[idx], QuantizedTensor): - replace_raw_data( - weights[idx], self.flatten_weight[start:end].view(weights[idx].shape) - ) + if self.flatten_columnwise is not None: + new_rowwise_data = self.flatten_weight[start:end].view(weights[idx].shape) + new_rowwise_data.copy_(weights[idx]._rowwise_data) + weights[idx]._rowwise_data = new_rowwise_data + new_columnwise_data = self.flatten_columnwise[start:end].view( + weights[idx].shape + ) + new_columnwise_data.copy_(weights[idx]._columnwise_data) + weights[idx]._columnwise_data = new_columnwise_data + else: + replace_raw_data( + weights[idx], self.flatten_weight[start:end].view(weights[idx].shape) + ) else: weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape) # Initialize local model weights and high-precision master weights self.local_weights = [] + self.local_columnwise = [] self.master_weights = [] for i, weight in enumerate(self.weights): weight_start, weight_end = self.weight_indices[i] @@ -334,6 +380,11 @@ def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=Fals if shard_start is not None and shard_end is not None: local_weight_shard = self.local_weight_shard[shard_start:shard_end] self.local_weights.append(local_weight_shard) + if self.flatten_columnwise is not None: + local_columnwise_shard = self.local_columnwise_shard[shard_start:shard_end] + else: + local_columnwise_shard = None + self.local_columnwise.append(local_columnwise_shard) if isinstance(weight, QuantizedTensor): high_precision_init_val = weight.get_high_precision_init_val().view(-1) @@ -345,6 +396,7 @@ def __init__(self, 
weights, lr, dp_group, manual_post_all_gather_processing=Fals self.master_weights.append(master_weight_shard) else: self.local_weights.append(None) + self.local_columnwise.append(None) self.master_weights.append(None) setattr( weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda") @@ -415,12 +467,12 @@ def step(self): # Step 3: Cast master weights to FP8 or BF16 precision if isinstance(self.weights[0], QuantizedTensor): local_weights = [] - for local_weight in self.local_weights: - if local_weight is None: - local_weights.append(None) - continue - - local_weights.append(local_weight) + for i, local_weight in enumerate(self.local_weights): + if self.flatten_columnwise is not None: + local_columnwise = self.local_columnwise[i] + local_weights.append((local_weight, local_columnwise)) + else: + local_weights.append(local_weight) cast_master_weights_to_fp8( self.weights, @@ -442,6 +494,10 @@ def step(self): dist.all_gather_into_tensor( self.flatten_weight, self.local_weight_shard, group=self.dp_group ) + if self.flatten_columnwise is not None: + dist.all_gather_into_tensor( + self.flatten_columnwise, self.local_columnwise_shard, group=self.dp_group + ) if self.manual_post_all_gather_processing: quantized_weights = [ @@ -513,15 +569,15 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat preserve_high_precision_init_val=True, ): model_fp8 = nn.Sequential( - te.Linear(128, 256 + 16, **linear_kwargs), - te.Linear(256 + 16, 256 * 3, **linear_kwargs), + te.Linear(128, 256 + 32, **linear_kwargs), + te.Linear(256 + 32, 256 * 3, **linear_kwargs), te.Linear(256 * 3, 128, **linear_kwargs), ) # Create model with BF16 weights model = nn.Sequential( - te.Linear(128, 256 + 16, **linear_kwargs), - te.Linear(256 + 16, 256 * 3, **linear_kwargs), + te.Linear(128, 256 + 32, **linear_kwargs), + te.Linear(256 + 32, 256 * 3, **linear_kwargs), te.Linear(256 * 3, 128, **linear_kwargs), ) @@ -546,7 +602,7 @@ def 
_test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat w.main_grad.zero_() inputs = [ - torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + torch.randn(32, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) ] # Choose based on rank to make sure the inputs of different ranks are different. x = inputs[rank] @@ -577,7 +633,9 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat optimizer_fp8.step() optimizer.step() - torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) + assert torch.allclose( + loss_fp8, loss, atol=0, rtol=0 + ), f"Loss mismatch at rank {rank}, step {i} for {quantization}" def _test_fsdp_cast_master_weights_to_fp8( @@ -609,15 +667,15 @@ def _test_fsdp_cast_master_weights_to_fp8( preserve_high_precision_init_val=True, ): model_fp8 = nn.Sequential( - te.Linear(128, 256 + 16, **linear_kwargs), - te.Linear(256 + 16, 256 * 3, **linear_kwargs), + te.Linear(128, 256 + 32, **linear_kwargs), + te.Linear(256 + 32, 256 * 3, **linear_kwargs), te.Linear(256 * 3, 128, **linear_kwargs), ) # Create model with BF16 weights model = nn.Sequential( - te.Linear(128, 256 + 16, **linear_kwargs), - te.Linear(256 + 16, 256 * 3, **linear_kwargs), + te.Linear(128, 256 + 32, **linear_kwargs), + te.Linear(256 + 32, 256 * 3, **linear_kwargs), te.Linear(256 * 3, 128, **linear_kwargs), ) @@ -631,12 +689,12 @@ def _test_fsdp_cast_master_weights_to_fp8( ) optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group) - for _ in range(100): + for i in range(100): optimizer_fp8.zero_grad() optimizer.zero_grad() inputs = [ - torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + torch.randn(32, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) ] # Choose based on rank to make sure the inputs of different ranks are different. 
x = inputs[rank] @@ -667,7 +725,9 @@ def _test_fsdp_cast_master_weights_to_fp8( optimizer_fp8.step() optimizer.step() - torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) + assert torch.allclose( + loss_fp8, loss, atol=0, rtol=0 + ), f"Loss mismatch at rank {rank}, step {i} for {quantization} (FSDP)" def run_parallel_tests() -> None: @@ -698,6 +758,8 @@ def run_parallel_tests() -> None: quantizations.extend(["fp8", "fp8_cs"]) if is_fp8_block_scaling_available(): quantizations.append("fp8_block") + if is_mxfp8_available(): + quantizations.append("mxfp8") manual_post_all_gather_processings = [False, True] diff --git a/tests/pytorch/test_multi_tensor.py b/tests/pytorch/test_multi_tensor.py index 46ba821879..94012354db 100644 --- a/tests/pytorch/test_multi_tensor.py +++ b/tests/pytorch/test_multi_tensor.py @@ -7,6 +7,7 @@ import transformer_engine.pytorch import transformer_engine_torch as tex +from transformer_engine.pytorch import is_mxfp8_available from transformer_engine.pytorch.optimizers import MultiTensorApply from references.quantize_scale_calc import scale_from_amax_tensor @@ -23,6 +24,7 @@ (555, 33333), ] appliers = [MultiTensorApply(2048 * 32), MultiTensorApply(333), MultiTensorApply(33333)] +mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) @pytest.mark.parametrize("input_size_pair", input_size_pairs) @@ -259,3 +261,35 @@ def test_multi_tensor_compute_scale_and_scale_inv( ) torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0) torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0) + + +@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) +@pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)]) +@pytest.mark.parametrize("applier", appliers) +@pytest.mark.parametrize("repeat", [1, 55]) +def test_multi_tensor_compute_scale_inv_e8m0(input_size_pair, applier, repeat): + sizea, sizeb = input_size_pair + device = torch.device("cuda") + a = torch.randn([sizea], 
dtype=torch.bfloat16, device=device).abs() + b = torch.randn([sizeb], dtype=torch.bfloat16, device=device).abs() + + amax_list = [] + for _ in range(repeat): + amax_list += [a.clone(), b.clone()] + scale_inv_list = [torch.empty_like(x).to(torch.uint8) for x in amax_list] + + applier( + tex.multi_tensor_compute_scale_inv_e8m0, + None, # overflow_buf + [amax_list, scale_inv_list], + ) + + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + for amax, scale_inv in zip(amax_list, scale_inv_list): + scale_inv_u32 = (amax.float() / max_fp8).view(torch.int) + exponent = scale_inv_u32 // 2**23 + mantissa = scale_inv_u32 & 0x7FFFFF + exponent += ( + ((mantissa > 0) & (exponent != 0xFE)) & ~((exponent == 0) & (mantissa <= 0x400000)) + ).to(torch.int) + torch.testing.assert_close(exponent.to(torch.uint8), scale_inv) diff --git a/tests/pytorch/test_partial_cast.py b/tests/pytorch/test_partial_cast.py new file mode 100644 index 0000000000..cb0c4d75bd --- /dev/null +++ b/tests/pytorch/test_partial_cast.py @@ -0,0 +1,137 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine_torch import multi_tensor_compute_scale_inv_e8m0 +from transformer_engine.pytorch import is_mxfp8_available +from transformer_engine.pytorch.optimizers.multi_tensor_apply import multi_tensor_applier + + +mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) + + +def compute_partial_amax_reference(inp, amax_rowwise, amax_colwise, h, w, start_offset): + n = inp.view(-1).size(0) + if n == h * w: + full = inp.view(-1) + else: + full = torch.zeros(h * w, dtype=inp.dtype, device=inp.device) + full[start_offset : start_offset + n].copy_(inp) + full = torch.abs(full) + _amax_rowwise, _ = torch.max(full.view(h, w // 32, 32), dim=2) + amax_rowwise[:h, : (w // 32)].copy_(_amax_rowwise) + _amax_colwise, _ = torch.max(full.view(h // 32, 32, w), dim=1) + amax_colwise[: (h // 32), :w].copy_(_amax_colwise) + + +def partial_cast_reference( + inp, rowwise_out, colwise_out, rowwise_inv_scale, colwise_inv_scale, h, w, start_offset +): + rowwise_scale = ((254 - rowwise_inv_scale.int()) * 2**23).view(torch.float32) + colwise_scale = ((254 - colwise_inv_scale.int()) * 2**23).view(torch.float32) + n = inp.view(-1).size(0) + if n == h * w: + full = inp + else: + full = torch.empty(h * w, dtype=inp.dtype, device=inp.device) + full[start_offset : start_offset + n].copy_(inp) + full = full.float() + rowwise_scale = rowwise_scale[:h, : (w // 32)].contiguous().float() + colwise_scale = colwise_scale[: (h // 32), :w].contiguous().float() + scaled = (full.view(-1, 32) * rowwise_scale.view(-1, 1)).view(-1) + rowwise_out.copy_( + scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(rowwise_out.dtype) + ) + scaled = (full.view(h // 32, 32, w) * colwise_scale.view(h // 32, 1, w)).view(-1) + colwise_out.copy_( + scaled[start_offset : start_offset + n].to(torch.float8_e4m3fn).view(colwise_out.dtype) + ) + + +def 
run_one_case(n, h, w, start_offset): + inp = torch.randn(n, dtype=torch.bfloat16, device="cuda") + + rowwise_padding = [128, 4] + colwise_padding = [4, 128] + + def _pad(x, padding): + return (x + padding - 1) // padding * padding + + rowwise_shape = [_pad(h, rowwise_padding[0]), _pad(w // 32, rowwise_padding[1])] + colwise_shape = [_pad(h // 32, colwise_padding[0]), _pad(w, colwise_padding[1])] + + # Partial amax cuda kernel + amax_rowwise = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device) + amax_colwise = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device) + tex.mxfp8_scaling_compute_partial_amax(inp, amax_rowwise, amax_colwise, h, w, start_offset) + + # Partial amax pytorch reference + amax_rowwise_ref = torch.zeros(*rowwise_shape, dtype=inp.dtype, device=inp.device) + amax_colwise_ref = torch.zeros(*colwise_shape, dtype=inp.dtype, device=inp.device) + compute_partial_amax_reference(inp, amax_rowwise_ref, amax_colwise_ref, h, w, start_offset) + + # Check partial amax + torch.testing.assert_close(amax_rowwise, amax_rowwise_ref, atol=0, rtol=0) + torch.testing.assert_close(amax_colwise, amax_colwise_ref, atol=0, rtol=0) + + # Calculate scales and scale_invs + scale_inv_rowwise = torch.empty_like(amax_rowwise).to(torch.uint8) + scale_inv_colwise = torch.empty_like(amax_colwise).to(torch.uint8) + multi_tensor_applier( + multi_tensor_compute_scale_inv_e8m0, + None, + [ + [amax_rowwise, amax_colwise], + [scale_inv_rowwise, scale_inv_colwise], + ], + ) + + # Partial cast cuda kernel + output_rowwise = torch.empty_like(inp).to(torch.uint8) + output_colwise = torch.empty_like(inp).to(torch.uint8) + tex.mxfp8_scaling_partial_cast( + inp, + output_rowwise, + output_colwise, + scale_inv_rowwise, + scale_inv_colwise, + h, + w, + start_offset, + ) + + # Partial cast pytorch reference + output_rowwise_ref = torch.empty_like(inp).to(torch.uint8) + output_colwise_ref = torch.empty_like(inp).to(torch.uint8) + partial_cast_reference( + inp, + 
output_rowwise_ref, + output_colwise_ref, + scale_inv_rowwise, + scale_inv_colwise, + h, + w, + start_offset, + ) + + # Check partial cast results + torch.testing.assert_close(output_rowwise, output_rowwise_ref, atol=0, rtol=0) + torch.testing.assert_close(output_colwise, output_colwise_ref, atol=0, rtol=0) + + +@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) +def test_mxfp8_scaling_partial_cast(): + torch.cuda.manual_seed(1234) + + run_one_case(3, 32, 64, 31) + run_one_case(64 * 64 - 2, 64, 64, 1) + run_one_case(16384 * 6144, 16384, 6144, 0) + run_one_case(32768, 256, 128, 0) + run_one_case(131072, 768, 256, 0) + run_one_case(65536, 768, 256, 131072) + run_one_case(98304, 128, 768, 0) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index d3532b8c45..264f7f9a78 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -125,7 +125,6 @@ list(APPEND transformer_engine_cpp_sources list(APPEND transformer_engine_cuda_sources common.cu multi_tensor/adam.cu - multi_tensor/compute_scale.cu multi_tensor/l2norm.cu multi_tensor/scale.cu multi_tensor/sgd.cu @@ -167,16 +166,18 @@ list(APPEND transformer_engine_cuda_sources comm_gemm_overlap/userbuffers/userbuffers.cu) list(APPEND transformer_engine_cuda_arch_specific_sources - gemm/cutlass_grouped_gemm.cu - cast/cast.cu activation/gelu.cu activation/relu.cu activation/swiglu.cu - transpose/quantize_transpose_square_blockwise.cu - transpose/quantize_transpose_vector_blockwise_fp4.cu - hadamard_transform/hadamard_transform.cu + cast/cast.cu + gemm/cutlass_grouped_gemm.cu hadamard_transform/group_hadamard_transform.cu - hadamard_transform/hadamard_transform_cast_fusion.cu) + hadamard_transform/hadamard_transform.cu + hadamard_transform/hadamard_transform_cast_fusion.cu + multi_tensor/compute_scale.cu + recipe/mxfp8_scaling.cu + transpose/quantize_transpose_square_blockwise.cu + 
transpose/quantize_transpose_vector_blockwise_fp4.cu) # Compiling the files with the worst compilation time first to hopefully overlap # better with the faster-compiling cpp files diff --git a/transformer_engine/common/include/transformer_engine/multi_tensor.h b/transformer_engine/common/include/transformer_engine/multi_tensor.h index af3f51d46f..03d35dc2ed 100644 --- a/transformer_engine/common/include/transformer_engine/multi_tensor.h +++ b/transformer_engine/common/include/transformer_engine/multi_tensor.h @@ -265,6 +265,21 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETens float max_fp8, int force_pow_2_scales, float epsilon, cudaStream_t stream); +/*! \brief Compute E8M0 scale_inv for a list of tensors. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in,out] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor **tensor_lists, + const size_t num_tensor_lists, + const size_t num_tensors_per_list, + cudaStream_t stream); + /*! \brief Split a tensor along dimension 0 and compute the amax for each split. * * This function is experimental and the API is not stable. 
diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 6e1e9dd7ac..b1773a8db3 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -111,17 +111,200 @@ void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output, void nvte_compute_scale_from_amax(NVTETensor output, const NVTEQuantizationConfig config, cudaStream_t stream); +/*! \brief Compute partial amax for FP8 blockwise scaling. + * + * This function computes the maximum absolute values for each block of the original tensor. + * `inp` contains a continuous segment from the flattened original tensor. For each block, + * if it overlaps with the range [start_offset, start_offset+inp.length), the amax is + * computed from inp; otherwise, the amax is set to 0. + * + * Example: Original tensor (logically 512x512) divided into 16 blocks of size 128x128. + * `inp` contains continuous elements starting from position start_offset + * in the flattened original tensor. 
+ * + * Logical view - Original Tensor (e.g., 512x512) divided into 16 blocks of size 128x128: + * ┌─────────┬─────────┬─────────┬─────────┐ + * │ Block0 │ Block1 │ Block2 │ Block3 │ Each block: 128x128 + * │ 128x128 │ 128x128 │ 128x128 │ 128x128 │ + * ├─────────┼─────────┼─────────┼─────────┤ + * │ Block4 │ Block5 │ Block6 │ Block7 │ + * ├─────────┼─────────┼─────────┼─────────┤ + * │ Block8 │ Block9 │ Block10 │ Block11 │ + * ├─────────┼─────────┼─────────┼─────────┤ + * │ Block12 │ Block13 │ Block14 │ Block15 │ + * └─────────┴─────────┴─────────┴─────────┘ + * + * Physical view - Flattened in row-major order: + * ┌────────────────────────────────────────────────────────────────┐ + * │[0...128][128...256][256...384][384...512]...[261632...262143] │ + * └────────────────────────────────────────────────────────────────┘ + * ^ ^ + * start_offset start_offset + inp.length + * + * For each 128x128 block, compute amax: + * - If the block overlaps with [start_offset, start_offset+inp.length), compute amax + * - If the block is completely outside this range, set amax = 0 + * + * amax output (one value per 128x128 block), block 1 and block 2 are non-zero because they + * overlap with the [start_offset, start_offset+inp.length) range: + * ┌───────┬───────┬───────┬───────┐ + * │ 0 │ amax │ amax │ 0 │ Block0-3 + * ├───────┼───────┼───────┼───────┤ + * │ 0 │ 0 │ 0 │ 0 │ Block4-7 + * ├───────┼───────┼───────┼───────┤ + * │ 0 │ 0 │ 0 │ 0 │ Block8-11 + * ├───────┼───────┼───────┼───────┤ + * │ 0 │ 0 │ 0 │ 0 │ Block12-15 + * └───────┴───────┴───────┴───────┘ + * + * \param[in] inp Input tensor (continuous slice of flattened original tensor). + * \param[in,out] amax Output tensor for maximum absolute values per block. + * \param[in] h Height dimension of the logical tensor. + * \param[in] w Width dimension of the logical tensor. + * \param[in] amax_stride_h Stride in height dimension for amax tensor. + * \param[in] amax_stride_w Stride in width dimension for amax tensor. 
+ * \param[in] start_offset Starting offset in the flattened tensor. + * \param[in] block_len Length of a quantization block to process. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h, size_t w, size_t amax_stride_h, size_t amax_stride_w, size_t start_offset, size_t block_len, cudaStream_t stream); +/*! \brief Perform partial FP8 casting with blockwise scaling. + * + * This function casts the input tensor to FP8 format using blockwise scaling factors. + * `inp` contains a continuous segment from the flattened original tensor. + * + * \param[in] inp Input tensor. + * \param[out] out Output tensor in FP8 format. + * \param[in] scale Scaling factors per block. + * \param[in] h Height dimension of the tensor. + * \param[in] w Width dimension of the tensor. + * \param[in] scale_stride_h Stride in height dimension for scale tensor. + * \param[in] scale_stride_w Stride in width dimension for scale tensor. + * \param[in] start_offset Starting offset for partial computation. + * \param[in] block_len Length of the block to process. + * \param[in] out_dtype Output FP8 datatype. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out, const NVTETensor scale, size_t h, size_t w, size_t scale_stride_h, size_t scale_stride_w, size_t start_offset, size_t block_len, const NVTEDType out_dtype, cudaStream_t stream); +/*! \brief Compute partial amax for MXFP8 scaling. + * + * This function computes the maximum absolute values along both row and column dimensions. + * input contains a continuous segment from the flattened original tensor. For each row/column + * block, if it overlaps with the range starting from start_offset, the amax is computed from + * `input`; otherwise, the amax is set to 0. + * + * Example: Original tensor (64 rows x 64 cols). 
+ * Rowwise amax granularity: 1x32 (each row divided into 2 blocks) + * Columnwise amax granularity: 32x1 (each column divided into 2 blocks) + * input contains a continuous segment starting from start_offset. + * + * Logical view - Original Tensor (64x64) with 1x32 and 32x1 blocks: + * + * Rowwise blocks (1x32): Each row has 2 blocks + * ┌──────────────┬──────────────┐ + * row0 │ Block_r0_0 │ Block_r0_1 │ (cols 0-31, 32-63) + * ├──────────────┼──────────────┤ + * row1 │ Block_r1_0 │ Block_r1_1 │ + * ├──────────────┼──────────────┤ + * ... │ ... │ ... │ + * ├──────────────┼──────────────┤ + * row63│ Block_r63_0 │ Block_r63_1 │ + * └──────────────┴──────────────┘ + * + * Columnwise blocks (32x1): Each column has 2 blocks + * ┌───┬───┬─────┬───┬───┐ + * │c0 │c1 │ ... │c62│c63│ + * ┌────┼───┼───┼─────┼───┼───┤ + * │Blk0│ │ │ │ │ │ rows 0-31 + * ├────┼───┼───┼─────┼───┼───┤ + * │Blk1│ │ │ │ │ │ rows 32-63 + * └────┴───┴───┴─────┴───┴───┘ + * + * Physical view - Flattened in row-major order: + * Total elements: 64*64 = 4096 + * ┌──────────────────────────────────────────────────────┐ + * │[0...63][64...127][128...191]...[4032...4095] │ + * └──────────────────────────────────────────────────────┘ + * ^ ^ + * start_offset=60 start_offset + input.length=130 + * + * Row-wise amax output (one value per 1x32 block): + * ┌────────┬────────┐ + * │ amax │ amax │ row0 (block0 and block1 partially covered) + * ├────────┼────────┤ + * │ 0 │ 0 │ row1 (not covered) + * ├────────┼────────┤ + * │ ... │ ... │ + * ├────────┼────────┤ + * │ 0 │ 0 │ row63 (not covered) + * └────────┴────────┘ + * + * Column-wise amax output (one value per 32x1 block): + * ┌────────┬────────┬────────┬────────┬────────┬────────┬────────┐ + * │ amax │ amax │ amax │ amax │ amax │ amax │ amax │ ... row 0-31 + * ├────────┼────────┼────────┼────────┼────────┼────────┼────────┤ + * │ amax=0 │ amax=0 │ amax=0 │ amax=0 │ amax=0 │ amax=0 │ amax=0 │ ... 
row 32-62 + * └────────┴────────┴────────┴────────┴────────┴────────┴────────┘ + * col0 col1 col2 col3 col4 col5 col6 + * + * For each 1x32 or 32x1 block, if it overlaps with [start_offset, start_offset+input.length), + * compute amax; otherwise set to 0. + * + * \param[in] input Input tensor (continuous segment of flattened original tensor). + * \param[in,out] amax_rowwise Output tensor for row-wise maximum absolute values. + * \param[in,out] amax_colwise Output tensor for column-wise maximum absolute values. + * \param[in] rows Number of rows in the logical tensor. + * \param[in] cols Number of columns in the logical tensor. + * \param[in] start_offset Starting offset in the flattened tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_mxfp8_scaling_compute_partial_amax(const NVTETensor input, NVTETensor amax_rowwise, + NVTETensor amax_colwise, int rows, int cols, + size_t start_offset, cudaStream_t stream); + +/*! \brief Perform partial MXFP8 casting. + * + * This function casts the input tensor to MXFP8 format, producing both row-wise and + * column-wise scaled outputs. input contains a continuous segment from the flattened + * original tensor. + * + * \param[in] input Input (continuous segment of flattened original tensor). + * \param[out] output_rowwise Output tensor with row-wise scaling (MXFP8 format). + * \param[out] output_colwise Output tensor with column-wise scaling (MXFP8 format). + * \param[in] scale_inv_rowwise Inverse scaling factors for row-wise scaling. + * \param[in] scale_inv_colwise Inverse scaling factors for column-wise scaling. + * \param[in] rows Number of rows in the logical tensor. + * \param[in] cols Number of columns in the logical tensor. + * \param[in] start_offset Starting offset in the flattened tensor. + * \param[in] stream CUDA stream used for the operation. 
+ */ +void nvte_mxfp8_scaling_partial_cast(const NVTETensor input, NVTETensor output_rowwise, + NVTETensor output_colwise, const NVTETensor scale_inv_rowwise, + const NVTETensor scale_inv_colwise, int rows, int cols, + size_t start_offset, cudaStream_t stream); + +/*! \brief Compute per-tensor scaling factor for NVFP4 format. + * + * This function computes the scaling factor (alpha) for NVFP4 quantization based + * on the input tensors A and B, with options for using row-wise amax values. + * + * \param[in] inpA Input tensor A. + * \param[in] use_rowwise_amax_A Whether to use row-wise amax for tensor A. + * \param[in] inpB Input tensor B. + * \param[in] use_rowwise_amax_B Whether to use row-wise amax for tensor B. + * \param[in] alpha_in Input scaling factor. + * \param[out] alpha_out Output scaling factor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, const NVTETensor inpB, const bool use_rowwise_amax_B, float alpha_in, NVTETensor alpha_out, cudaStream_t stream); diff --git a/transformer_engine/common/multi_tensor/compute_scale.cu b/transformer_engine/common/multi_tensor/compute_scale.cu index dc4eb87145..0ac9ab7371 100644 --- a/transformer_engine/common/multi_tensor/compute_scale.cu +++ b/transformer_engine/common/multi_tensor/compute_scale.cu @@ -14,6 +14,7 @@ #include #include "../recipe/recipe_common.cuh" +#include "../util/ptx.cuh" #include "../utils.cuh" #include "multi_tensor_apply.cuh" @@ -55,6 +56,28 @@ struct ComputeScaleAndScaleInvFunctor { } }; +struct ComputeScaleInvE8M0Functor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int *unused, + TensorListMetadata<2> &tl) { + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + bf16 *amax = reinterpret_cast(tl.addresses[0][tensor_loc]); + amax += chunk_idx * chunk_size; + + e8m0_t *scale_inv = 
reinterpret_cast(tl.addresses[1][tensor_loc]); + scale_inv += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) { + scale_inv[i_start] = ptx::float_to_e8m0(static_cast(amax[i_start]) * + Quantized_Limits::max_norm_rcp); + } + } +}; + void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, float max_fp8, bool force_pow_2_scales, @@ -65,6 +88,19 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_f NVTE_CHECK_CUDA(cudaGetLastError()); } +void multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, + std::vector> tensor_lists, + cudaStream_t stream) { + NVTE_CHECK(tensor_lists[0][0]->data.dtype == DType::kBFloat16, "amax should be bf16"); + auto scale_inv_dtype = tensor_lists[1][0]->data.dtype; + NVTE_CHECK(scale_inv_dtype == DType::kByte || scale_inv_dtype == DType::kFloat8E8M0, + "scale_inv should be e8m0/uint8"); + Tensor dummy; + multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, dummy, tensor_lists, ComputeScaleInvE8M0Functor(), + stream); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + } // namespace multi_tensor_compute_scale } // namespace transformer_engine @@ -82,3 +118,15 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETens convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8, force_pow_2_scales, epsilon, stream); } + +void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor **tensor_lists, + const size_t num_tensor_lists, + const size_t num_tensors_per_list, + cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_compute_scale_inv_e8m0_cuda); + using namespace transformer_engine; + + multi_tensor_compute_scale::multi_tensor_compute_scale_inv_e8m0_cuda( + chunk_size, convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), + stream); +} diff --git 
a/transformer_engine/common/recipe/mxfp8_scaling.cu b/transformer_engine/common/recipe/mxfp8_scaling.cu new file mode 100644 index 0000000000..8a7ecc6b01 --- /dev/null +++ b/transformer_engine/common/recipe/mxfp8_scaling.cu @@ -0,0 +1,253 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "../common.h" +#include "../util/ptx.cuh" +#include "../utils.cuh" + +namespace transformer_engine { +namespace mxfp8_scaling_recipe { + +constexpr int rowwise_row_padding = 128; // Row padding of rowwise_scale and rowwise_amax +constexpr int rowwise_col_padding = 4; // Column padding of rowwise_scale and rowwise_amax +constexpr int colwise_row_padding = 4; // Row padding of colwise_scale and colwise_amax +constexpr int colwise_col_padding = 128; // Column padding of colwise_scale and colwise_amax + +constexpr int kRowsPerTile = 32; // Rows each block processes +constexpr int kColsPerTile = 128; // Columns each block processes + +constexpr int kThreadsPerBlock = 128; + +template +__global__ void __launch_bounds__(kThreadsPerBlock) + mxfp8_scaling_compute_partial_amax_kernel(const IType *input, IType *amax_rowwise, + IType *amax_colwise, int amax_rowwise_stride, + int amax_colwise_stride, int rows, int cols, + size_t start_offset, size_t len) { + __shared__ float smem_amax_rowwise[kRowsPerTile][kColsPerTile / 32]; + + size_t end_offset = start_offset + len; + const IType *input_minus_offset = input - start_offset; + int warp_idx = threadIdx.x / 32; + int lane_idx = threadIdx.x % 32; + int c = blockIdx.x * kColsPerTile + threadIdx.x; + int r = blockIdx.y * kRowsPerTile; + + float col_amax = 0.0f; +#pragma unroll + for (int i = 0; i < kRowsPerTile; i++) { + size_t idx = r * cols + c; + float row_amax = 0.0f; + + if (r < rows && c < 
cols && idx >= start_offset && idx < end_offset) { + float abs_input = fabs(static_cast(input_minus_offset[idx])); + row_amax = fmaxf(row_amax, abs_input); + col_amax = fmaxf(col_amax, abs_input); + } + +#pragma unroll + for (int delta = 16; delta > 0; delta /= 2) { + float other_row_amax = __shfl_down_sync(0xFFFFFFFF, row_amax, delta); + row_amax = fmaxf(row_amax, other_row_amax); + } + + if (lane_idx == 0) { + smem_amax_rowwise[i][warp_idx] = row_amax; + } + + r++; + } + + amax_colwise[blockIdx.y * amax_colwise_stride + c] = static_cast(col_amax); + + __syncthreads(); + + int r_ = threadIdx.x / (kColsPerTile / 32); // rows in shared memory + int c_ = threadIdx.x % (kColsPerTile / 32); // cols in shared memory + r = blockIdx.y * kRowsPerTile + r_; + c = blockIdx.x * kColsPerTile / 32 + c_; + amax_rowwise[r * amax_rowwise_stride + c] = static_cast(smem_amax_rowwise[r_][c_]); +} + +template +__global__ void __launch_bounds__(kThreadsPerBlock) + mxfp8_scaling_partial_cast_kernel(const IType *input, OType *output_rowwise, + OType *output_colwise, const e8m0_t *scale_inv_rowwise, + const e8m0_t *scale_inv_colwise, int scale_inv_rowwise_stride, + int scale_inv_colwise_stride, int rows, int cols, + size_t start_offset, size_t len) { + __shared__ float smem_scales_rowwise[kRowsPerTile][kColsPerTile / 32]; + __shared__ float smem_scales_colwise[kColsPerTile]; + + // Load scales_rowwise + { + int r_ = threadIdx.x / (kColsPerTile / 32); // rows in shared memory + int c_ = threadIdx.x % (kColsPerTile / 32); // cols in shared memory + int r = blockIdx.y * kRowsPerTile + r_; + int c = blockIdx.x * kColsPerTile / 32 + c_; + size_t idx = r * scale_inv_rowwise_stride + c; + smem_scales_rowwise[r_][c_] = ptx::exp2f_rcp(scale_inv_rowwise[idx]); + } + + // Load scales_colwise + { + int c_ = threadIdx.x; + int r = blockIdx.y * kRowsPerTile / 32; + int c = blockIdx.x * kColsPerTile + c_; + size_t idx = r * scale_inv_colwise_stride + c; + smem_scales_colwise[c_] = 
ptx::exp2f_rcp(scale_inv_colwise[idx]); + } + + __syncthreads(); + + size_t end_offset = start_offset + len; + const IType *input_minus_offset = input - start_offset; + OType *output_rowwise_minus_offset = output_rowwise - start_offset; + OType *output_colwise_minus_offset = output_colwise - start_offset; + int warp_idx = threadIdx.x / 32; + int lane_idx = threadIdx.x % 32; + int c = blockIdx.x * kColsPerTile + threadIdx.x; + int r = blockIdx.y * kRowsPerTile; + +#pragma unroll + for (int i = 0; i < kRowsPerTile; i++) { + size_t idx = r * cols + c; + + if (r < rows && c < cols && idx >= start_offset && idx < end_offset) { + float inp = static_cast(input_minus_offset[idx]); + OType out_rowwise = static_cast(inp * smem_scales_rowwise[i][warp_idx]); + OType out_colwise = static_cast(inp * smem_scales_colwise[threadIdx.x]); + output_rowwise_minus_offset[idx] = out_rowwise; + output_colwise_minus_offset[idx] = out_colwise; + } + + r++; + } +} + +void mxfp8_scaling_compute_partial_amax(const Tensor input, Tensor amax_rowwise, + Tensor amax_colwise, int rows, int cols, + size_t start_offset, cudaStream_t stream) { + NVTE_CHECK(rows % 32 == 0, "rows must be divisible by 32"); + NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32"); + + NVTE_CHECK(input.data.shape.size() == 1, "input must be a 1D tensor"); + NVTE_CHECK(start_offset + input.data.shape[0] <= static_cast(rows) * cols, + "Invalid start_offset"); + + NVTE_CHECK(amax_rowwise.data.shape.size() == 2, "amax_rowwise must be a 2D tensor"); + NVTE_CHECK(amax_rowwise.data.shape[0] % rowwise_row_padding == 0, + "Wrong padding of amax_rowwise's rows"); + NVTE_CHECK(amax_rowwise.data.shape[0] >= rows, "Invalid rows"); + NVTE_CHECK(amax_rowwise.data.shape[1] % rowwise_col_padding == 0, + "Wrong padding of amax_rowwise's cols"); + NVTE_CHECK(amax_rowwise.data.shape[1] >= cols / 32, "Invalid cols"); + NVTE_CHECK(amax_rowwise.dtype() == input.dtype(), "Wrong dtype of amax_rowwise"); + + 
NVTE_CHECK(amax_colwise.data.shape.size() == 2, "amax_colwise must be a 2D tensor"); + NVTE_CHECK(amax_colwise.data.shape[0] % colwise_row_padding == 0, + "Wrong padding of amax_colwise's rows"); + NVTE_CHECK(amax_colwise.data.shape[0] >= rows / 32, "Invalid rows"); + NVTE_CHECK(amax_colwise.data.shape[1] % colwise_col_padding == 0, + "Wrong padding of amax_colwise's cols"); + NVTE_CHECK(amax_colwise.data.shape[1] >= cols, "Invalid cols"); + NVTE_CHECK(amax_colwise.dtype() == input.dtype(), "Wrong dtype of amax_colwise"); + + int blocks_x = (cols + kColsPerTile - 1) / kColsPerTile; + int blocks_y = (rows + kRowsPerTile - 1) / kRowsPerTile; + dim3 grid(blocks_x, blocks_y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + mxfp8_scaling_compute_partial_amax_kernel<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(amax_rowwise.data.dptr), + reinterpret_cast(amax_colwise.data.dptr), amax_rowwise.data.shape[1], + amax_colwise.data.shape[1], rows, cols, start_offset, input.data.shape[0]);) +} + +void mxfp8_scaling_partial_cast(const Tensor input, Tensor output_rowwise, Tensor output_colwise, + const Tensor scale_inv_rowwise, const Tensor scale_inv_colwise, + int rows, int cols, size_t start_offset, cudaStream_t stream) { + NVTE_CHECK(rows % 32 == 0, "rows must be divisible by 32"); + NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32"); + + NVTE_CHECK(input.data.shape.size() == 1, "input must be a 1D tensor"); + NVTE_CHECK(start_offset + input.data.shape[0] <= static_cast(rows) * cols, + "Invalid start_offset"); + + NVTE_CHECK(output_rowwise.data.shape.size() == 1, "output_rowwise must be a 1D tensor"); + NVTE_CHECK(output_colwise.data.shape.size() == 1, "output_colwise must be a 1D tensor"); + NVTE_CHECK(output_rowwise.data.shape[0] == input.data.shape[0], + "Size of input and output_rowwise mismatch"); + NVTE_CHECK(output_colwise.data.shape[0] == input.data.shape[0], + "Size of input and output_colwise mismatch"); + + 
NVTE_CHECK(output_rowwise.dtype() == DType::kFloat8E4M3 || output_rowwise.dtype() == DType::kByte, + "output_rowwise should be e4m3 or uint8"); + NVTE_CHECK(output_colwise.dtype() == DType::kFloat8E4M3 || output_colwise.dtype() == DType::kByte, + "output_colwise should be e4m3 or uint8"); + + NVTE_CHECK(scale_inv_rowwise.data.shape.size() == 2, "scale_inv_rowwise must be a 2D tensor"); + NVTE_CHECK(scale_inv_rowwise.data.shape[0] % rowwise_row_padding == 0, + "Wrong padding of scale_inv_rowwise's rows"); + NVTE_CHECK(scale_inv_rowwise.data.shape[0] >= rows, "Invalid rows"); + NVTE_CHECK(scale_inv_rowwise.data.shape[1] % rowwise_col_padding == 0, + "Wrong padding of scale_inv_rowwise's cols"); + NVTE_CHECK(scale_inv_rowwise.data.shape[1] >= cols / 32, "Invalid cols"); + NVTE_CHECK(scale_inv_rowwise.dtype() == DType::kByte, "Wrong dtype of scale_inv_rowwise"); + + NVTE_CHECK(scale_inv_colwise.data.shape.size() == 2, "scale_inv_colwise must be a 2D tensor"); + NVTE_CHECK(scale_inv_colwise.data.shape[0] % colwise_row_padding == 0, + "Wrong padding of scale_inv_colwise's rows"); + NVTE_CHECK(scale_inv_colwise.data.shape[0] >= rows / 32, "Invalid rows"); + NVTE_CHECK(scale_inv_colwise.data.shape[1] % colwise_col_padding == 0, + "Wrong padding of scale_inv_colwise's cols"); + NVTE_CHECK(scale_inv_colwise.data.shape[1] >= cols, "Invalid cols"); + NVTE_CHECK(scale_inv_colwise.dtype() == DType::kByte, "Wrong dtype of scale_inv_colwise"); + + int blocks_x = (cols + kColsPerTile - 1) / kColsPerTile; + int blocks_y = (rows + kRowsPerTile - 1) / kRowsPerTile; + dim3 grid(blocks_x, blocks_y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + mxfp8_scaling_partial_cast_kernel<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output_rowwise.data.dptr), + reinterpret_cast(output_colwise.data.dptr), + reinterpret_cast(scale_inv_rowwise.data.dptr), + reinterpret_cast(scale_inv_colwise.data.dptr), + scale_inv_rowwise.data.shape[1], 
scale_inv_colwise.data.shape[1], rows, cols, + start_offset, input.data.shape[0]);) +} + +} // namespace mxfp8_scaling_recipe +} // namespace transformer_engine + +void nvte_mxfp8_scaling_compute_partial_amax(const NVTETensor input, NVTETensor amax_rowwise, + NVTETensor amax_colwise, int rows, int cols, + size_t start_offset, cudaStream_t stream) { + NVTE_API_CALL(nvte_mxfp8_scaling_compute_partial_amax); + using namespace transformer_engine; + mxfp8_scaling_recipe::mxfp8_scaling_compute_partial_amax( + *convertNVTETensorCheck(input), *convertNVTETensorCheck(amax_rowwise), + *convertNVTETensorCheck(amax_colwise), rows, cols, start_offset, stream); +} + +void nvte_mxfp8_scaling_partial_cast(const NVTETensor input, NVTETensor output_rowwise, + NVTETensor output_colwise, const NVTETensor scale_inv_rowwise, + const NVTETensor scale_inv_colwise, int rows, int cols, + size_t start_offset, cudaStream_t stream) { + NVTE_API_CALL(nvte_mxfp8_scaling_partial_cast); + using namespace transformer_engine; + mxfp8_scaling_recipe::mxfp8_scaling_partial_cast( + *convertNVTETensorCheck(input), *convertNVTETensorCheck(output_rowwise), + *convertNVTETensorCheck(output_colwise), *convertNVTETensorCheck(scale_inv_rowwise), + *convertNVTETensorCheck(scale_inv_colwise), rows, cols, start_offset, stream); +} diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 44c49b20bc..80479dccf4 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -335,6 +335,15 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const size_t h, size_t w, size_t start_offset, size_t block_len, const DType out_dtype); +void mxfp8_scaling_compute_partial_amax(const at::Tensor &input, at::Tensor amax_rowwise, + at::Tensor amax_colwise, int rows, int cols, + size_t start_offset); + +void mxfp8_scaling_partial_cast(const at::Tensor &input, at::Tensor output_rowwise, + at::Tensor 
output_colwise, const at::Tensor &scale_inv_rowwise, + const at::Tensor &scale_inv_colwise, int rows, int cols, + size_t start_offset); + /*************************************************************************************************** * Rotary positional embedding **************************************************************************************************/ @@ -451,6 +460,9 @@ void multi_tensor_compute_scale_and_scale_inv_cuda( int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, float max_fp8, bool force_pow_2_scales, float epsilon); +void multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, const py::object &dummy, + std::vector> tensor_lists); + /*************************************************************************************************** * padding **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp b/transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp similarity index 53% rename from transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp rename to transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp index bea6f8c907..3be2ca9396 100644 --- a/transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp @@ -48,4 +48,42 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const start_offset, block_len, static_cast(out_dtype), at::cuda::getCurrentCUDAStream()); } +void mxfp8_scaling_compute_partial_amax(const at::Tensor &input, at::Tensor amax_rowwise, + at::Tensor amax_colwise, int rows, int cols, + size_t start_offset) { + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + TORCH_CHECK(amax_rowwise.is_contiguous(), "amax_rowwise must be contiguous"); + TORCH_CHECK(amax_colwise.is_contiguous(), "amax_colwise must be contiguous"); 
+ + const TensorWrapper input_cu = makeTransformerEngineTensor(input); + TensorWrapper amax_rowwise_cu = makeTransformerEngineTensor(amax_rowwise); + TensorWrapper amax_colwise_cu = makeTransformerEngineTensor(amax_colwise); + + nvte_mxfp8_scaling_compute_partial_amax(input_cu.data(), amax_rowwise_cu.data(), + amax_colwise_cu.data(), rows, cols, start_offset, + at::cuda::getCurrentCUDAStream()); +} + +void mxfp8_scaling_partial_cast(const at::Tensor &input, at::Tensor output_rowwise, + at::Tensor output_colwise, const at::Tensor &scale_inv_rowwise, + const at::Tensor &scale_inv_colwise, int rows, int cols, + size_t start_offset) { + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + TORCH_CHECK(output_rowwise.is_contiguous(), "output_rowwise must be contiguous"); + TORCH_CHECK(output_colwise.is_contiguous(), "output_colwise must be contiguous"); + TORCH_CHECK(scale_inv_rowwise.is_contiguous(), "scale_inv_rowwise must be contiguous"); + TORCH_CHECK(scale_inv_colwise.is_contiguous(), "scale_inv_colwise must be contiguous"); + + const TensorWrapper input_cu = makeTransformerEngineTensor(input); + TensorWrapper output_rowwise_cu = makeTransformerEngineTensor(output_rowwise); + TensorWrapper output_colwise_cu = makeTransformerEngineTensor(output_colwise); + const TensorWrapper scale_inv_rowwise_cu = makeTransformerEngineTensor(scale_inv_rowwise); + const TensorWrapper scale_inv_colwise_cu = makeTransformerEngineTensor(scale_inv_colwise); + + nvte_mxfp8_scaling_partial_cast(input_cu.data(), output_rowwise_cu.data(), + output_colwise_cu.data(), scale_inv_rowwise_cu.data(), + scale_inv_colwise_cu.data(), rows, cols, start_offset, + at::cuda::getCurrentCUDAStream()); +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp index 8a1a34698b..e60b001f6f 100644 --- 
a/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp @@ -20,4 +20,14 @@ void multi_tensor_compute_scale_and_scale_inv_cuda( force_pow_2_scales, epsilon, at::cuda::getCurrentCUDAStream()); } +void multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, const py::object &dummy, + std::vector> tensor_lists) { + NVTE_CHECK(dummy.is_none(), "No-op flag is not supported."); + auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = + makeTransformerEngineTensorList(tensor_lists); + + nvte_multi_tensor_compute_scale_inv_e8m0_cuda(chunk_size, tensor_lists_ptr.data(), num_lists, + num_tensors, at::cuda::getCurrentCUDAStream()); +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 3b81393dbd..d0f450bc71 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -276,6 +276,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"), py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"), py::arg("out_dtype"), py::call_guard()); + m.def("mxfp8_scaling_compute_partial_amax", + &transformer_engine::pytorch::mxfp8_scaling_compute_partial_amax, + "Compute partial amax from master weights for fp8 mxfp8 scaling", py::arg("input"), + py::arg("amax_rowwise"), py::arg("amax_colwise"), py::arg("rows"), py::arg("cols"), + py::arg("start_offset"), py::call_guard()); + m.def("mxfp8_scaling_partial_cast", &transformer_engine::pytorch::mxfp8_scaling_partial_cast, + "Partial cast from master weights for fp8 mxfp8 scaling", py::arg("input"), + py::arg("output_rowwise"), py::arg("output_colwise"), py::arg("scale_inv_rowwise"), + py::arg("scale_inv_colwise"), py::arg("rows"), 
py::arg("cols"), py::arg("start_offset"), + py::call_guard()); m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding, "Fused Multi-tensor padding", py::call_guard()); m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding, @@ -427,6 +437,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_compute_scale_and_scale_inv", &transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda, "Fused compute scale and scale_inv from amax", py::call_guard()); + m.def("multi_tensor_compute_scale_inv_e8m0", + &transformer_engine::pytorch::multi_tensor_compute_scale_inv_e8m0_cuda, + "Fused compute E8M0 scale_inv from amax", py::call_guard()); // Comm+GEMM Overlap m.def("bulk_overlap_ag_with_external_gemm", diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 9773e17e64..94f761f2b0 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -8,7 +8,11 @@ import torch import transformer_engine_torch as tex -from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv +from transformer_engine_torch import ( + multi_tensor_scale, + multi_tensor_compute_scale_and_scale_inv, + multi_tensor_compute_scale_inv_e8m0, +) from ..quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer @@ -85,6 +89,7 @@ def cast_master_weights_to_fp8( delayed_scaling_params = [] current_scaling_params = [] blockwise_scaling_params = [] + mxfp8_scaling_params = [] if fsdp_shard_model_weights is None: use_fsdp_shard_model_weights = False @@ -131,8 +136,8 @@ def cast_master_weights_to_fp8( (model_weight, master_weight, start_offset, fsdp_shard_model_weight) ) elif isinstance(quantizer, MXFP8Quantizer): - raise NotImplementedError( - "cast_master_weights_to_fp8 for 
MXFP8BlockScaling is not supported yet" + mxfp8_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) ) else: raise ValueError( @@ -146,6 +151,8 @@ def cast_master_weights_to_fp8( _cast_master_weights_to_fp8_current_scaling(current_scaling_params, *extra_args) if len(blockwise_scaling_params) > 0: _cast_master_weights_to_fp8_blockwise_scaling(blockwise_scaling_params, *extra_args) + if len(mxfp8_scaling_params) > 0: + _cast_master_weights_to_fp8_mxfp8_scaling(mxfp8_scaling_params, *extra_args) def _cast_master_weights_to_fp8_delayed_scaling( @@ -467,6 +474,131 @@ def _cast_master_weights_to_fp8_blockwise_scaling( ) +def _cast_master_weights_to_fp8_mxfp8_scaling( + params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False +): # pylint: disable=unused-argument + r"""Helper function to cast master weights to FP8 primary weights for mxfp8 scaling. + + Parameters + ---------- + params : List of tuple, each tuple contains a model weight, a master weight, and an offset + indicating the starting index of the master weight in the model weight. + group : The distributed group to do amax reduction. Typically it's the data parallel + group. + use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded. + """ + + # Parameter attributes + device = params[0][0].device + for _, master_weight, _, _ in params: + if master_weight is not None: + master_weight_dtype = master_weight.dtype + break + + # Get the total number of amax elements in all the model weights. 
+ cu_rowwise_amax_sizes = [0] + cu_colwise_amax_sizes = [0] + for model_weight, _, _, _ in params: + rowwise_shape = model_weight._rowwise_scale_inv.shape + assert len(rowwise_shape) == 2 + colwise_shape = model_weight._columnwise_scale_inv.shape + assert len(colwise_shape) == 2 + cu_rowwise_amax_sizes.append( + cu_rowwise_amax_sizes[-1] + rowwise_shape[0] * rowwise_shape[1] + ) + cu_colwise_amax_sizes.append( + cu_colwise_amax_sizes[-1] + colwise_shape[0] * colwise_shape[1] + ) + + # Create a contiguous buffer to store amaxes temporarily, so we can perform all all-reduce + # NCCL kernels at once. + packed_amaxes = torch.zeros( + cu_rowwise_amax_sizes[-1] + cu_colwise_amax_sizes[-1], + dtype=master_weight_dtype, + device=device, + ) + + # --------------------------------------------------------------------------------------------- + # Step 1: Iterate through all the non-empty master weights and compute amax of them. Store the + # amaxes in a contiguous buffer. If a block of a master weight is empty, the + # corresponding amax will be set to 0. 
+ # --------------------------------------------------------------------------------------------- + amaxes_rowwise, scale_invs_rowwise = [], [] + amaxes_colwise, scale_invs_colwise = [], [] + for i, (model_weight, master_weight, start_offset, _) in enumerate(params): + rowwise_shape = model_weight._rowwise_scale_inv.shape + colwise_shape = model_weight._columnwise_scale_inv.shape + rowwise_start = cu_rowwise_amax_sizes[i] + rowwise_end = cu_rowwise_amax_sizes[i + 1] + colwise_start = cu_rowwise_amax_sizes[-1] + cu_colwise_amax_sizes[i] + colwise_end = cu_rowwise_amax_sizes[-1] + cu_colwise_amax_sizes[i + 1] + amax_rowwise = packed_amaxes[rowwise_start:rowwise_end].reshape(rowwise_shape) + amax_colwise = packed_amaxes[colwise_start:colwise_end].reshape(colwise_shape) + amaxes_rowwise.append(amax_rowwise) + amaxes_colwise.append(amax_colwise) + scale_invs_rowwise.append(model_weight._rowwise_scale_inv) + scale_invs_colwise.append(model_weight._columnwise_scale_inv) + + # Compute amax of the master weight and store it in packed_amaxes. + if master_weight is not None: + assert len(model_weight.shape) == 2 + h, w = model_weight.shape + tex.mxfp8_scaling_compute_partial_amax( + master_weight, amax_rowwise, amax_colwise, h, w, start_offset + ) + + # --------------------------------------------------------------------------------------------- + # Step 2: Perform all-reduce on packed_amaxes to get the global amax. + # --------------------------------------------------------------------------------------------- + torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group) + + # --------------------------------------------------------------------------------------------- + # Step 3: Update scales and scale_invs. 
+ # --------------------------------------------------------------------------------------------- + multi_tensor_applier( + multi_tensor_compute_scale_inv_e8m0, + None, # dummy_overflow_buf + [ + amaxes_rowwise + amaxes_colwise, + scale_invs_rowwise + scale_invs_colwise, + ], + ) + + # --------------------------------------------------------------------------------------------- + # Step 4: Cast master weights to FP8. + # --------------------------------------------------------------------------------------------- + for ( + (model_weight, master_weight, start_offset, model_weight_fragment), + scale_inv_rowwise, + scale_inv_colwise, + ) in zip(params, scale_invs_rowwise, scale_invs_colwise): + # If master weight is None, it means that the master weight of the current model weight + # is in other DP ranks. + if master_weight is None: + continue + + # Cast master weight to FP8 + end_offset = start_offset + master_weight.numel() + if use_fsdp_shard_model_weights: + rowwise_fragment = model_weight_fragment[0] + colwise_fragment = model_weight_fragment[1] + else: + rowwise_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset] + colwise_fragment = model_weight._columnwise_data.reshape(-1)[start_offset:end_offset] + assert len(model_weight.shape) == 2 + h, w = model_weight.shape + tex.mxfp8_scaling_partial_cast( + master_weight, + rowwise_fragment, + colwise_fragment, + scale_inv_rowwise, + scale_inv_colwise, + h, + w, + start_offset, + ) + + def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Tensor]]): """ Post-processing after all-gather for weights in distributed optimizer. @@ -485,6 +617,9 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten elif isinstance(model_weight, Float8BlockwiseQTensor): # Blockwise scaling: create column-wise storage. model_weight._create_columnwise() + elif isinstance(model_weight, MXFP8Tensor): + # MXFP8 scaling: no need to do anything. 
+ pass elif isinstance(model_weight, QuantizedTensor): raise ValueError(f"post_processing for {type(model_weight)} is not supported")