diff --git a/flashinfer/fp4_quantization.py b/flashinfer/fp4_quantization.py index 38948da637..c5fab606a9 100644 --- a/flashinfer/fp4_quantization.py +++ b/flashinfer/fp4_quantization.py @@ -37,17 +37,18 @@ from .jit.cpp_ext import is_cuda_version_at_least from .utils import ( device_support_pdl, + get_compute_capability, get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices, register_custom_op, register_fake_op, - get_compute_capability, + round_up, ) def _compute_swizzled_layout_sf_size(total_row, total_column, row_size=128): - padded_row = (total_row + row_size - 1) // row_size * row_size - padded_column = (total_column + 3) // 4 * 4 + padded_row = round_up(total_row, row_size) + padded_column = round_up(total_column, 4) return padded_row * padded_column @@ -66,8 +67,8 @@ def _pad_scale_factors( torch.Tensor: Padded scale factors tensor. """ factor = sf_vec_size * 4 - padded_row = ((m + 128 - 1) // 128) * 128 # Next multiple of 128 - padded_col = ((n + factor - 1) // factor) * factor # Next multiple of 64 + padded_row = round_up(m, 128) + padded_col = round_up(n, factor) # Pad the input tensor to [padded_row, padded_col // scaling_vector_size] pad_rows = padded_row - m @@ -209,9 +210,13 @@ def fp4_quantize_sm100( out_sf_size = _compute_swizzled_layout_sf_size( m, k // sf_vec_size, 8 if is_sf_8x4_layout else 128 ) + out_sf_size_padded = out_sf_size else: out_sf_size = m * k // sf_vec_size - out_sf = torch.empty((out_sf_size,), dtype=torch.uint8, device=input.device) + out_sf_size_padded = round_up(m, 16) * k // sf_vec_size + out_sf = torch.empty( + (out_sf_size_padded,), dtype=torch.uint8, device=input.device + ) module.fp4_quantize( input, global_scale, @@ -223,7 +228,7 @@ def fp4_quantize_sm100( is_sf_8x4_layout, enable_pdl, ) - return out_val, out_sf + return out_val, out_sf[:out_sf_size] @register_fake_op("flashinfer::fp4_quantize_sm100") def _fake_fp4_quantize_sm100( @@ -433,9 +438,9 @@ def silu_and_mul_scaled_nvfp4_experts_quantize_sm100( assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}." scale_k = k // sf_vec_size - padded_k = (scale_k + (4 - 1)) // 4 * 4 + padded_k = round_up(scale_k, 4) padded_k_int32 = padded_k // 4 - padded_m = (m + (128 - 1)) // 128 * 128 + padded_m = round_up(m, 128) output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8) output_scales = torch.empty( l, padded_m, padded_k_int32, device=device, dtype=torch.int32 @@ -469,9 +474,9 @@ def _fake_silu_and_mul_scaled_nvfp4_experts_quantize_sm100( assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}." scale_k = k // sf_vec_size - padded_k = (scale_k + (4 - 1)) // 4 * 4 + padded_k = round_up(scale_k, 4) padded_k_int32 = padded_k // 4 - padded_m = (m + (128 - 1)) // 128 * 128 + padded_m = round_up(m, 128) output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8) output_scales = torch.empty( l, padded_m, padded_k_int32, device=device, dtype=torch.int32 @@ -517,9 +522,9 @@ def scaled_fp4_grouped_quant_sm100( assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}." scale_k = k // sf_vec_size - padded_k = (scale_k + (4 - 1)) // 4 * 4 + padded_k = round_up(scale_k, 4) padded_k_int32 = padded_k // 4 - padded_m = (m + (128 - 1)) // 128 * 128 + padded_m = round_up(m, 128) output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8) output_scales = torch.empty( l, padded_m, padded_k_int32, device=device, dtype=torch.int32 @@ -557,9 +562,9 @@ def _fake_scaled_fp4_grouped_quant_sm100( assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}." scale_k = k // sf_vec_size - padded_k = (scale_k + (4 - 1)) // 4 * 4 + padded_k = round_up(scale_k, 4) padded_k_int32 = padded_k // 4 - padded_m = (m + (128 - 1)) // 128 * 128 + padded_m = round_up(m, 128) output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8) output_scales = torch.empty( l, padded_m, padded_k_int32, device=device, dtype=torch.int32 diff --git a/tests/utils/test_fp4_quantize_padding.py b/tests/utils/test_fp4_quantize_padding.py new file mode 100644 index 0000000000..bd60b031cf --- /dev/null +++ b/tests/utils/test_fp4_quantize_padding.py @@ -0,0 +1,82 @@ +import os + +# Disable CUDA memory caching so out-of-bounds writes surface as immediate errors +# instead of silently corrupting adjacent cached allocations. +os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1" + +import pytest +import torch +from tests.test_helpers.utils_fp4 import cast_from_fp4, ref_fp4_quant + +from flashinfer import fp4_quantize +from flashinfer.utils import ( + is_sm100a_supported, + is_sm110a_supported, + is_sm12x_supported, +) + +DTYPES = [torch.float16, torch.bfloat16] +UNALIGNED_M_SHAPES = [ + (17, 512), + (33, 1024), + (1025, 1024), + (1025, 6144), +] +SEEDS = [42] +CUDA_DEVICES = ["cuda:0"] + +FLOAT4_E2M1_MAX = 6.0 +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +BLOCK_SIZE = 16 + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", UNALIGNED_M_SHAPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_fp4_quantize_unaligned_m_non_swizzled( + dtype: torch.dtype, + shape: tuple[int, int], + seed: int, + device: str, +) -> None: + """Regression test: fp4_quantize with M not a multiple of 16 for linear SF.""" + if not ( + is_sm100a_supported(torch.device(device)) + or is_sm110a_supported(torch.device(device)) + or is_sm12x_supported(torch.device(device)) + ): + pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8") + torch.set_default_device(device) + torch.manual_seed(seed) + + m, n = shape + sf_vec_size = BLOCK_SIZE + assert n % sf_vec_size == 0, f"cols needs to be {sf_vec_size} divisible" + + x = torch.randn((m, n), dtype=dtype) + tensor_amax = torch.abs(x).max().to(torch.float32) + global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + + out_val, out_sf = fp4_quantize(x, global_scale, sf_vec_size, False, False) + + assert out_val.shape == (m, n // 2), ( + f"Expected val shape {(m, n // 2)}, got {out_val.shape}" + ) + expected_sf_size = m * n // sf_vec_size + assert out_sf.numel() == expected_sf_size, ( + f"Expected sf numel {expected_sf_size}, got {out_sf.numel()}" + ) + + out_ref, scale_ref = ref_fp4_quant(x, global_scale, sf_vec_size) + out_ans = cast_from_fp4(out_val).reshape(m, n) + out_scale = out_sf.view(torch.float8_e4m3fn).to(torch.float32) + # atol=0.5 accounts for FP4 E2M1 rounding at the 0/0.5 boundary + torch.testing.assert_close(out_ans, out_ref, rtol=1e0, atol=5e-1) + torch.testing.assert_close(out_scale, scale_ref, rtol=1e-1, atol=1e-1) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])