diff --git a/tests/test_fp4_quantize.py b/tests/test_fp4_quantize.py index 1d1769bb61..a02ee3715c 100644 --- a/tests/test_fp4_quantize.py +++ b/tests/test_fp4_quantize.py @@ -2,7 +2,7 @@ import pytest import torch -from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_nvfp4_quant +from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_fp4_quant from flashinfer import ( e2m1_and_ufp8sf_scale_to_float, @@ -88,30 +88,47 @@ def unswizzle_sf( @pytest.mark.parametrize("shape", SHAPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("sf_use_ue8m0", [False, True]) +@pytest.mark.parametrize("is_swizzled", [False, True]) @torch.inference_mode() def test_fp4_quantization( dtype: torch.dtype, shape: tuple[int, int], seed: int, device: str, + sf_use_ue8m0: bool, + is_swizzled: bool, ) -> None: if not is_sm100a_supported(torch.device(device)): pytest.skip("Nvfp4 Requires compute capability of 10 or above") torch.set_default_device(device) torch.manual_seed(seed) m, n = shape + sf_vec_size = 32 if sf_use_ue8m0 else 16 x = torch.randn((m, n), dtype=dtype) tensor_amax = torch.abs(x).max().to(torch.float32) - global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax - out_ref, scale_ref = ref_nvfp4_quant(x, global_scale, BLOCK_SIZE) - out, out_scale = fp4_quantize(x, global_scale, BLOCK_SIZE, False) - assert n % BLOCK_SIZE == 0, f"cols needs to be {BLOCK_SIZE} divisible" - scale_ans = recover_swizzled_scales( - out_scale.reshape(-1, n // BLOCK_SIZE).view(torch.float8_e4m3fn), - m, - n, - BLOCK_SIZE, + if sf_use_ue8m0: + global_scale = torch.tensor(1.0, dtype=torch.float32) + else: + global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + out_ref, scale_ref = ref_fp4_quant(x, global_scale, sf_vec_size, sf_use_ue8m0) + out, out_scale = fp4_quantize( + x, global_scale, sf_vec_size, sf_use_ue8m0, is_swizzled ) + assert n % sf_vec_size == 0, f"cols needs to be {sf_vec_size} divisible" 
+ if sf_use_ue8m0: + out_scale = (out_scale.to(torch.int32) << 23).view(torch.float32) + else: + out_scale = out_scale.view(torch.float8_e4m3fn).to(torch.float32) + if is_swizzled: + scale_ans = recover_swizzled_scales( + out_scale.reshape(-1, n // sf_vec_size), + m, + n, + sf_vec_size, + ) + else: + scale_ans = out_scale out_ans = cast_from_fp4(out).reshape(m, n) torch.testing.assert_close(out_ans, out_ref, rtol=1e0, atol=1e-1) torch.testing.assert_close(scale_ans, scale_ref, rtol=1e-1, atol=1e-1) diff --git a/tests/test_trtllm_gen_context.py b/tests/test_trtllm_gen_context.py index 0040b26a0f..5c8107e3dc 100644 --- a/tests/test_trtllm_gen_context.py +++ b/tests/test_trtllm_gen_context.py @@ -2,7 +2,7 @@ import pytest import torch -from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_nvfp4_quant +from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_fp4_quant import flashinfer from flashinfer.utils import FP4Tensor @@ -437,7 +437,7 @@ def test_trtllm_batch_prefill( if o_dtype == "nvfp4": output = cast_from_fp4(output) - output_ref, out_scale_factor_ref = ref_nvfp4_quant(output_ref, o_sf_scale, 16) + output_ref, out_scale_factor_ref = ref_fp4_quant(output_ref, o_sf_scale, 16) out_scale_factor = recover_swizzled_scales( out_scale_factor, output.shape[0], diff --git a/tests/test_trtllm_gen_decode.py b/tests/test_trtllm_gen_decode.py index 859f4dc0fd..0cce928059 100644 --- a/tests/test_trtllm_gen_decode.py +++ b/tests/test_trtllm_gen_decode.py @@ -3,7 +3,7 @@ import pytest import torch import torch.nn.functional as F -from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_nvfp4_quant +from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_fp4_quant import flashinfer from flashinfer.utils import FP4Tensor @@ -328,7 +328,7 @@ def test_trtllm_batch_decode_fmha( if o_dtype == "nvfp4": output = cast_from_fp4(output) - output_ref, out_scale_factor_ref = ref_nvfp4_quant(output_ref, o_sf_scale, 16) + output_ref, 
out_scale_factor_ref = ref_fp4_quant(output_ref, o_sf_scale, 16) out_scale_factor = recover_swizzled_scales( out_scale_factor, output.shape[0], diff --git a/tests/test_trtllm_gen_fused_moe.py b/tests/test_trtllm_gen_fused_moe.py index f82b61ccef..877180f723 100644 --- a/tests/test_trtllm_gen_fused_moe.py +++ b/tests/test_trtllm_gen_fused_moe.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from enum import IntEnum -from typing import Dict, Literal +from typing import Dict import pytest import torch @@ -27,6 +27,8 @@ RoutingMethodType, e2m1_and_ufp8sf_scale_to_float, fp4_quantize, + mxfp8_dequantize_host, + mxfp8_quantize, next_positive_power_of_2, reorder_rows_for_gated_act_gemm, shuffle_matrix_a, @@ -168,23 +170,17 @@ def cleanup(self): def _run_moe_computation(self, runtime_args): """Run the MoE computation.""" - # Quantize hidden states to FP4 - hidden_states_fp4_bytes, hidden_states_scale_fp4_bytes, _ = quant_fp4( - self.input_tensor, self.config["hidden_states_scale_global"], False, False + input_quantized = self.moe_impl.quantize_inputs( + self.input_tensor, + self.config["hidden_states_scale_global"], + is_swizzling=False, ) - hidden_states_fp4 = hidden_states_fp4_bytes.reshape( - self.input_tensor.shape[0], self.input_tensor.shape[1] // 2 - ) - hidden_states_scale_linear_fp4 = hidden_states_scale_fp4_bytes.view( - torch.float8_e4m3fn - ).reshape(-1) - # Call MoE kernel and return output tensor output = trtllm_fp4_block_scale_moe( routing_logits=runtime_args["expert_logits"], routing_bias=runtime_args["routing_bias"], - hidden_states=hidden_states_fp4, - hidden_states_scale=hidden_states_scale_linear_fp4, + hidden_states=input_quantized["hidden_states"], + hidden_states_scale=input_quantized["hidden_states_scale"], gemm1_weights=self.static_data["gemm1_weights_fp4_shuffled"], gemm1_weights_scale=self.static_data["gemm1_scales_fp4_shuffled"], gemm1_bias=None, @@ -212,12 +208,14 @@ def _run_moe_computation(self, runtime_args): return output # Extract 
tensor from tuple -class QuantizationMode(IntEnum): +class QuantMode(IntEnum): """Supported quantization modes for MoE testing.""" - FP4_NVFP4 = 1 - FP8_BLOCK_SCALE = 2 - FP8_PER_TENSOR = 3 + FP4_NVFP4_NVFP4 = 1 + FP4_MXFP4_MXFP8 = 2 + FP4_MXFP4_Bf16 = 3 + FP8_BLOCK_SCALE = 4 + FP8_PER_TENSOR = 5 # ==================================================================================== @@ -300,26 +298,44 @@ def __str__(self): class FP4Moe(Moe): - """FP4 NvFP4 MoE implementation with block scaling.""" + """ + FP4 NvFP4 / MxFP4 MoE implementation with block scaling. + Args: + quant_mode: Quantization mode; the MxFP4 modes use MxFP4 weight quantization + For FP4_MXFP4_MXFP8 the activation is quantized to MxFP8, for FP4_NVFP4_NVFP4 to NvFP4, otherwise it stays Bf16 + """ + + def __init__(self, quant_mode: QuantMode): + super().__init__() + self.quant_mode = quant_mode + self.is_mxfp4 = ( + quant_mode == QuantMode.FP4_MXFP4_MXFP8 + or quant_mode == QuantMode.FP4_MXFP4_Bf16 + ) + self.sf_vec_size = 32 if self.is_mxfp4 else 16 def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to FP4 format and compute global scale factors.""" num_experts = gemm1_weights.shape[0] - use_ue8m0 = False - # Compute global scale factor for hidden states (offline calibration) - hidden_states_scale_global = calculate_fp4_global_scale_factor( - hidden_states_sample - ) + if self.quant_mode == QuantMode.FP4_NVFP4_NVFP4: + # nvfp4 hidden states + hidden_states_scale_global = calculate_fp4_global_scale_factor( + hidden_states_sample, + False, + ) + else: + # mxfp8 / bf16 hidden states + hidden_states_scale_global = 1.0 # Quantize the weights for FC1 gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = ( - quant_fp4_batches(gemm1_weights, num_experts, use_ue8m0, True) + quant_fp4_batches(gemm1_weights, num_experts, self.is_mxfp4, True) ) # Quantize the weights for FC2 gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = ( - quant_fp4_batches(gemm2_weights, 
num_experts, use_ue8m0, True) + quant_fp4_batches(gemm2_weights, num_experts, self.is_mxfp4, True) ) return { @@ -332,21 +348,42 @@ def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): "gemm2_scales_global": gemm2_scales_global, } - def quantize_inputs(self, hidden_states, hidden_states_scale_global): - """Quantize hidden states to FP4 format using pre-computed global scale.""" - use_ue8m0 = False - - # Quantize hidden states using pre-computed global scale factor - ( - hidden_states_fp4_bytes, - hidden_states_scale_fp4_bytes, - _, - ) = quant_fp4(hidden_states, hidden_states_scale_global, use_ue8m0, True) + def quantize_inputs( + self, hidden_states, hidden_states_scale_global, is_swizzling=True + ): + if self.quant_mode == QuantMode.FP4_MXFP4_MXFP8: + """Quantize hidden states to MxFP8 format.""" + hidden_states_quant, hidden_states_scale = mxfp8_quantize( + hidden_states, is_swizzling + ) + hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape( + -1 + ) + return { + "hidden_states": hidden_states_quant, + "hidden_states_scale": hidden_states_scale, + } + elif self.quant_mode == QuantMode.FP4_NVFP4_NVFP4: + """Quantize hidden states to NvFP4 format using pre-computed global scale.""" + ( + hidden_states_fp4_bytes, + hidden_states_scale_fp4_bytes, + _, + ) = quant_fp4( + hidden_states, hidden_states_scale_global, False, is_swizzling + ) - return { - "hidden_states": hidden_states_fp4_bytes, - "hidden_states_scale": hidden_states_scale_fp4_bytes, - } + return { + "hidden_states": hidden_states_fp4_bytes, + "hidden_states_scale": hidden_states_scale_fp4_bytes.view( + torch.float8_e4m3fn + ).reshape(-1), + } + else: # bf16 + return { + "hidden_states": 
hidden_states.to(torch.bfloat16), + "hidden_states_scale": None, + } def prepare_static_weights_for_kernel( self, @@ -360,7 +405,7 @@ def prepare_static_weights_for_kernel( weight_processing, ): """Prepare quantized weights for kernel (done offline with weights).""" - use_ue8m0 = False + use_ue8m0 = self.is_mxfp4 epilogue_tile_m = 128 # FIXME: this depends on the kernel internals # Quantize weights with linear layout for kernels @@ -378,7 +423,7 @@ def prepare_static_weights_for_kernel( gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view( torch.float8_e4m3fn ).reshape( - num_experts, 2 * intermediate_size, hidden_size // 16 + num_experts, 2 * intermediate_size, hidden_size // self.sf_vec_size ) # fp8 scaling factors gemm2_weights_fp4 = args.gemm2_weights.view(torch.float8_e4m3fn).reshape( @@ -387,7 +432,7 @@ def prepare_static_weights_for_kernel( gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view( torch.float8_e4m3fn ).reshape( - num_experts, hidden_size, intermediate_size // 16 + num_experts, hidden_size, intermediate_size // self.sf_vec_size ) # fp8 scaling factors # Using cached permute index calculation can speed up weights preprocessing @@ -459,14 +504,16 @@ def prepare_static_weights_for_kernel( gemm1_scales_fp4_shuffled = ( torch.stack(gemm1_scales_fp4_shuffled) .view(torch.float8_e4m3fn) - .reshape(num_experts, 2 * intermediate_size, hidden_size // 16) + .reshape( + num_experts, 2 * intermediate_size, hidden_size // self.sf_vec_size + ) ) gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled) gemm2_scales_fp4_shuffled = ( torch.stack(gemm2_scales_fp4_shuffled) .view(torch.float8_e4m3fn) - .reshape(num_experts, hidden_size, intermediate_size // 16) + .reshape(num_experts, hidden_size, intermediate_size // self.sf_vec_size) ) # Calculate scaling factors that depend on weights @@ -536,8 +583,7 @@ def call_moe( cuda_graph.cleanup() def compute_reference(self, args): - """FP4 reference implementation.""" - return 
run_moe_reference_fp4(args) + return run_moe_reference_fp4(args, self.quant_mode) def get_tolerances(self): """Get FP4-specific accuracy tolerances.""" @@ -885,16 +931,14 @@ def get_tolerances(self): # ==================================================================================== -def get_moe_impl(quant_mode): +def get_moe_impl(quant_mode: QuantMode): """Factory function to get the appropriate MoE implementation.""" - if quant_mode == QuantizationMode.FP4_NVFP4: - return FP4Moe() - elif quant_mode == QuantizationMode.FP8_BLOCK_SCALE: + if quant_mode == QuantMode.FP8_BLOCK_SCALE: return FP8BlockScaleMoe() - elif quant_mode == QuantizationMode.FP8_PER_TENSOR: + elif quant_mode == QuantMode.FP8_PER_TENSOR: return FP8PerTensorMoe() else: - raise NotImplementedError(f"Quantization mode {quant_mode} not implemented") + return FP4Moe(quant_mode) class moe_args: @@ -1165,7 +1209,7 @@ def check_accuracy(a, b, atol, rtol, percent): # ==================================================================================== -def calculate_fp4_global_scale_factor(tensor): +def calculate_fp4_global_scale_factor(tensor, use_ue8m0=False): """ Calculate FP4 global scale factor for a tensor. @@ -1176,7 +1220,10 @@ def calculate_fp4_global_scale_factor(tensor): This function is used here for testing/reference purposes. Formula: (448 * 6) represents max representable value in FP4 format. """ - return (448 * 6) / tensor.float().abs().nan_to_num().max() + if use_ue8m0: + return torch.tensor(1.0, dtype=torch.float32) + else: + return (448 * 6) / tensor.float().abs().nan_to_num().max() def e2m1_and_ufp8_scale_batches( @@ -1216,7 +1263,7 @@ def quant_fp4(a, a_global_sf, use_ue8m0=False, is_sf_swizzled_layout=True): Pure function - same inputs always produce same outputs. 
""" - sf_vec_size = 16 + sf_vec_size = 32 if use_ue8m0 else 16 a_fp4, a_sf = fp4_quantize( a.cuda(), a_global_sf.cuda(), sf_vec_size, use_ue8m0, is_sf_swizzled_layout @@ -1232,7 +1279,7 @@ def quant_fp4_batches(a, num_experts, use_ue8m0=False, is_sf_swizzled_layout=Tru global_sfs = [] for i in range(num_experts): # Use centralized global scale factor calculation - a_global_sf = calculate_fp4_global_scale_factor(a[i]) + a_global_sf = calculate_fp4_global_scale_factor(a[i], use_ue8m0) a_fp4, a_sf, _ = quant_fp4(a[i], a_global_sf, use_ue8m0, is_sf_swizzled_layout) quant_a.append(a_fp4) sfs.append(a_sf) @@ -1248,8 +1295,8 @@ def quant_fp4_batches(a, num_experts, use_ue8m0=False, is_sf_swizzled_layout=Tru def quant_dequant_fp4(a, use_ue8m0=False, is_sf_swizzled_layout=True): """FP4 quantize-dequantize roundtrip function with centralized global scale factor calculation.""" # Use centralized global scale factor calculation - a_global_sf = calculate_fp4_global_scale_factor(a) - sf_vec_size = 16 + a_global_sf = calculate_fp4_global_scale_factor(a, use_ue8m0) + sf_vec_size = 32 if use_ue8m0 else 16 a_fp4, a_sf = fp4_quantize( a.cuda(), a_global_sf.cuda(), sf_vec_size, use_ue8m0, is_sf_swizzled_layout @@ -1359,7 +1406,7 @@ def dequant_reference_dsfp8(input, scale, transpose_scale, block_m, block_n): # ==================================================================================== -def run_moe_dequant(args, quant_mode: Literal["fp4", "dsFp8", "perTensorFp8"]): +def run_moe_dequant(args, quant_mode: QuantMode): """Common dequantized MoE reference implementation.""" # Permute total_num_padded_tokens = args.permute_info["permutedBufferSize"] @@ -1424,19 +1471,35 @@ def run_moe_dequant(args, quant_mode: Literal["fp4", "dsFp8", "perTensorFp8"]): i += my_num_tokens i = (i + args.padding - 1) // args.padding * args.padding - if quant_mode == "fp4": + if quant_mode == QuantMode.FP4_NVFP4_NVFP4: # Use centralized function for activation quantization activation_output, c_global_sf 
= quant_dequant_fp4( activation_output.to(torch.bfloat16), False, True ) activation_output = activation_output.to(torch.float) args.c_global_sf = c_global_sf - elif quant_mode == "perTensorFp8": + elif quant_mode == QuantMode.FP8_PER_TENSOR: activation_output, c_global_sf = quant_dequant_per_tensor_fp8( activation_output.to(torch.bfloat16) ) activation_output = activation_output.to(torch.float) args.c_global_sf = c_global_sf + elif quant_mode == QuantMode.FP4_MXFP4_MXFP8: + activation_output, scale_bytes = mxfp8_quantize( + activation_output.to(torch.bfloat16), True + ) + scale_bytes = scale_bytes.view(torch.uint8).reshape(-1).cpu() + activation_output = ( + mxfp8_dequantize_host( + activation_output.cpu().view(torch.uint8), scale_bytes + ) + .cuda() + .to(torch.float) + ) + args.c_global_sf = 1.0 + else: # mxfp4Bf16 + activation_output = activation_output.to(torch.bfloat16).to(torch.float) + args.c_global_sf = 1.0 # Gemm2 gemm2_output = torch.full( @@ -1480,25 +1543,42 @@ def run_moe_dequant(args, quant_mode: Literal["fp4", "dsFp8", "perTensorFp8"]): # ==================================================================================== -def run_moe_reference_fp4(args): - """FP4 reference implementation.""" - sf_vec_size = 16 +def run_moe_reference_fp4(args, quant_mode: QuantMode): + sf_vec_size = 16 if quant_mode == QuantMode.FP4_NVFP4_NVFP4 else 32 + ufp8_type_weights = 1 if quant_mode == QuantMode.FP4_NVFP4_NVFP4 else 0 - hidden_states_dequant = e2m1_and_ufp8sf_scale_to_float( - args.hidden_states.cpu(), - args.hidden_states_scale.cpu().reshape(-1), - (1 / args.hidden_states_scale_global).cpu(), - sf_vec_size, - 1, # ufp8_type - True, # is_sf_swizzled_layout - ).cuda() + if quant_mode == QuantMode.FP4_NVFP4_NVFP4: + hidden_states_dequant = e2m1_and_ufp8sf_scale_to_float( + args.hidden_states.cpu(), + args.hidden_states_scale.cpu().view(torch.uint8).reshape(-1), + (1 / args.hidden_states_scale_global).cpu(), + sf_vec_size, + ufp8_type_weights, + True, # 
is_sf_swizzled_layout + ).cuda() + elif quant_mode == QuantMode.FP4_MXFP4_MXFP8: + hidden_states_dequant = mxfp8_dequantize_host( + args.hidden_states.cpu().view(torch.uint8), + args.hidden_states_scale.cpu().view(torch.uint8).reshape(-1), + True, # is_sf_swizzled_layout + ).cuda() + else: + hidden_states_dequant = args.hidden_states.to(torch.bfloat16).to(torch.float) gemm1_weights_dequant = e2m1_and_ufp8_scale_batches( - args.gemm1_weights, args.gemm1_scales, 1 / args.gemm1_scales_global, sf_vec_size + args.gemm1_weights, + args.gemm1_scales, + 1 / args.gemm1_scales_global, + sf_vec_size, + ufp8_type_weights, ).cuda() gemm2_weights_dequant = e2m1_and_ufp8_scale_batches( - args.gemm2_weights, args.gemm2_scales, 1 / args.gemm2_scales_global, sf_vec_size + args.gemm2_weights, + args.gemm2_scales, + 1 / args.gemm2_scales_global, + sf_vec_size, + ufp8_type_weights, ).cuda() args_dequant = moe_args_dequant( @@ -1516,7 +1596,7 @@ def run_moe_reference_fp4(args): args.use_routing_scales_on_input, ) - return run_moe_dequant(args_dequant, "fp4"), args_dequant + return run_moe_dequant(args_dequant, quant_mode), args_dequant def run_moe_reference_dsfp8(args): @@ -1558,7 +1638,7 @@ def run_moe_reference_dsfp8(args): args.use_routing_scales_on_input, ) - return run_moe_dequant(args_dequant, "dsFp8"), args_dequant + return run_moe_dequant(args_dequant, QuantMode.FP8_BLOCK_SCALE), args_dequant def run_moe_reference_per_tensor_scale_fp8(args): @@ -1594,7 +1674,7 @@ def run_moe_reference_per_tensor_scale_fp8(args): args.use_routing_scales_on_input, ) - return run_moe_dequant(args_dequant, "perTensorFp8"), args_dequant + return run_moe_dequant(args_dequant, QuantMode.FP8_PER_TENSOR), args_dequant def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): @@ -1660,7 +1740,9 @@ def cache_permute_indices(): @pytest.mark.parametrize( "moe_impl", [ - pytest.param(FP4Moe(), id="FP4"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4 x NvFP4"), + 
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4 x MxFP8"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4 x Bf16"), pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"), ], diff --git a/tests/utils_fp4.py b/tests/utils_fp4.py index 2f6b384549..f8fe65bffc 100644 --- a/tests/utils_fp4.py +++ b/tests/utils_fp4.py @@ -69,14 +69,18 @@ def get_reciprocal(x): raise TypeError("Input must be a float, int, or a torch.Tensor.") -def ref_nvfp4_quant(x, global_scale, block_size): +def ref_fp4_quant(x, global_scale, block_size, sf_use_ue8m0=False): assert isinstance(global_scale, (float, int)) or global_scale.dtype == torch.float32 sliced_shape = x.shape[:-1] + (x.shape[-1] // block_size, block_size) sliced_x = torch.reshape(x, sliced_shape) vec_max = torch.max(torch.abs(sliced_x), dim=-1, keepdim=True)[0].to(torch.float32) scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) - scale = scale.to(torch.float8_e4m3fn).to(torch.float32) + if sf_use_ue8m0: + scale = (scale.view(torch.int32) + 0x007FFFFF) & 0x7F800000 + scale = scale.view(torch.float32) + else: + scale = scale.to(torch.float8_e4m3fn).to(torch.float32) output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) scaled_x = sliced_x.to(torch.float32) * output_scale