From e0ca86575ea26d37d3b3186bd167ffeef2d4619d Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Wed, 17 Dec 2025 03:27:52 -0800 Subject: [PATCH 01/11] adding new nvfp4 backend Signed-off-by: LopezCastroRoberto --- tests/kernels/quantization/nvfp4_utils.py | 18 +++++- .../test_flashinfer_nvfp4_scaled_mm.py | 36 ++++++++---- tests/models/quantization/test_nvfp4.py | 16 ++++++ vllm/envs.py | 8 ++- .../schemes/compressed_tensors_w4a4_nvfp4.py | 19 +++++-- .../layers/quantization/modelopt.py | 13 ++++- vllm/utils/flashinfer.py | 56 ++++++++++++++++++- 7 files changed, 144 insertions(+), 22 deletions(-) diff --git a/tests/kernels/quantization/nvfp4_utils.py b/tests/kernels/quantization/nvfp4_utils.py index 5e6d54c42e89..b199d568c7b7 100644 --- a/tests/kernels/quantization/nvfp4_utils.py +++ b/tests/kernels/quantization/nvfp4_utils.py @@ -23,8 +23,18 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): return out[0:m, 0:k] +def convert_swizzled_8x4_layout_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): + m_tiles = (m + 8 - 1) // 8 + f = block_size * 4 + k_tiles = (k + f - 1) // f + tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 8, 4)) + tmp = torch.permute(tmp, (0, 1, 3, 2, 4)) + out = tmp.reshape(m_tiles * 8, k_tiles * f // block_size) + return out[0:m, 0:k] + + def dequantize_nvfp4_to_dtype( - tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16 + tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16, is_sf_128x4_layout=True ): """Dequantize the fp4 tensor back to high precision.""" # Two fp4 values are packed into one uint8. 
@@ -34,7 +44,11 @@ def dequantize_nvfp4_to_dtype( tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) tensor_sf = tensor_sf.view(torch.float8_e4m3fn) - tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + if is_sf_128x4_layout: + tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + else: + tensor_sf = convert_swizzled_8x4_layout_to_linear(tensor_sf, m, k, block_size) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale # scale the tensor diff --git a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py index 1e5c7dafb0f5..d3191822bb85 100644 --- a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py @@ -11,7 +11,7 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm +from vllm.utils.flashinfer import flashinfer_quant_nvfp4_8x4_sf_layout, flashinfer_scaled_fp4_mm if not current_platform.has_device_capability(100): pytest.skip( @@ -41,12 +41,19 @@ def get_ref_results( dtype, block_size, device, + is_sf_128x4_layout, ): _, m_k = a_fp4.shape _, n_k = b_fp4.shape assert m_k == n_k a_in_dtype = dequantize_nvfp4_to_dtype( - a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size + a_fp4, + a_sf, + a_global_scale, + dtype=dtype, + device=device, + block_size=block_size, + is_sf_128x4_layout=is_sf_128x4_layout, ) b_in_dtype = dequantize_nvfp4_to_dtype( b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size @@ -58,7 +65,7 @@ def get_ref_results( @pytest.mark.parametrize("shape", SHAPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("backend", ["cutlass", "trtllm"]) +@pytest.mark.parametrize("backend", ["cutlass", "trtllm", 
"trtllm_8x4_sf_layout"]) @pytest.mark.parametrize("autotune", [False, True]) @torch.inference_mode() def test_flashinfer_nvfp4_gemm( @@ -69,7 +76,7 @@ def test_flashinfer_nvfp4_gemm( backend: str, autotune: bool, ) -> None: - if backend == "trtllm" and dtype == torch.float16: + if "trtllm" in backend and dtype == torch.float16: pytest.skip("Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations") current_platform.seed_everything(seed) @@ -86,11 +93,18 @@ def test_flashinfer_nvfp4_gemm( (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1) ).to(torch.float32) alpha = 1.0 / (a_global_scale * b_global_scale) - # ops.scaled_fp4_quant returns swizzled scales, while weights - # from checkpoints are in linear scales. - # So instead of needing to swizzle for cutlass as in modelopt.py, - # we need to unswizzle for trtllm here. - a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale) + + if backend == "trtllm_8x4_sf_layout": + a_fp4, a_scale_interleaved = flashinfer_quant_nvfp4_8x4_sf_layout(a_dtype, a_global_scale) + is_sf_128x4_layout = False + else: + # ops.scaled_fp4_quant returns swizzled scales, while weights + # from checkpoints are in linear scales. + # So instead of needing to swizzle for cutlass as in modelopt.py, + # we need to unswizzle for trtllm here. + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale) + is_sf_128x4_layout = True + b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale) # get_ref_results unswizzles the scales internally. 
@@ -106,14 +120,14 @@ def test_flashinfer_nvfp4_gemm( dtype, block_size, device, + is_sf_128x4_layout, ) import flashinfer - if backend == "trtllm": + if "trtllm" in backend: epilogue_tile_m = 128 b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), epilogue_tile_m) - b_scale_interleaved = convert_swizzled_to_linear( b_scale_interleaved, n, k, block_size ) diff --git a/tests/models/quantization/test_nvfp4.py b/tests/models/quantization/test_nvfp4.py index 9f45f142d68b..0ac32b498023 100644 --- a/tests/models/quantization/test_nvfp4.py +++ b/tests/models/quantization/test_nvfp4.py @@ -83,3 +83,19 @@ def test_models(example_prompts, model_name) -> None: assert expected_str == generated_str, ( f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}" ) + + +EAGER = [True, False] + +@pytest.mark.skipif( + not is_quant_method_supported("modelopt_fp4"), + reason="modelopt_fp4 is not supported on this GPU type.", +) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("eager", EAGER) +@pytest.mark.parametrize("backend", ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass", "flashinfer-trtllm_8x4_sf_layout"]) +def test_nvfp4(vllm_runner, model, eager, backend, monkeypatch): + monkeypatch.setenv("VLLM_NVFP4_GEMM_BACKEND", backend) + with vllm_runner(model, enforce_eager=eager) as llm: + output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2) + assert output[0][1] == "1 2 3 4 5 6" diff --git a/vllm/envs.py b/vllm/envs.py index d0f279809626..a39f34c5d33b 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1408,7 +1408,13 @@ def get_vllm_port() -> int | None: "VLLM_NVFP4_GEMM_BACKEND": env_with_choices( "VLLM_NVFP4_GEMM_BACKEND", None, - ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass", "cutlass"], + [ + "flashinfer-cudnn", + "flashinfer-trtllm", + "flashinfer-cutlass", + "flashinfer-trtllm_8x4_sf_layout", + "cutlass", + ], ), # Controls garbage collection during CUDA graph capture. 
# If set to 0 (default), enables GC freezing to speed up capture time. diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index c0b1e3ceeba3..e30a3f64668c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -23,7 +23,11 @@ ModelWeightParameter, PerTensorScaleParameter, ) -from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer +from vllm.utils.flashinfer import ( + flashinfer_quant_nvfp4_8x4_sf_layout, + flashinfer_scaled_fp4_mm, + has_flashinfer, +) logger = init_logger(__name__) @@ -131,7 +135,10 @@ def process_weights_after_loading(self, layer) -> None: layer.weight_global_scale.max().to(torch.float32), requires_grad=False ) - if self.backend == "flashinfer-trtllm": + if ( + self.backend == "flashinfer-trtllm" + or self.backend == "flashinfer-trtllm_8x4_sf_layout" + ): # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. 
# FlashInfer provides nvfp4_quantize to quantize + shuffle the # layout but we use our own quantization so we have to call @@ -186,8 +193,12 @@ def apply_weights( output_dtype = x.dtype output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]] - # quantize BF16 or FP16 to (FP4 and interleaved block scale) - x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) + if self.backend == "flashinfer-trtllm_8x4_sf_layout": + x_fp4, x_blockscale = flashinfer_quant_nvfp4_8x4_sf_layout(x, layer.input_scale_inv) + x_blockscale = x_blockscale.view(torch.float8_e4m3fn) + else: + # quantize BF16 or FP16 to (FP4 and interleaved block scale) + x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) mm_args = ( x_fp4, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 030d85080a34..b641cc155a19 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -80,6 +80,7 @@ flashinfer_scaled_fp4_mm, has_flashinfer, has_flashinfer_moe, + flashinfer_quant_nvfp4_8x4_sf_layout ) from vllm.utils.math_utils import round_up @@ -1083,7 +1084,7 @@ def process_weights_after_loading(self, layer: Module) -> None: prepare_fp4_layer_for_marlin(layer) del layer.alpha del layer.input_scale - elif self.backend == "flashinfer-trtllm": + elif self.backend == "flashinfer-trtllm" or self.backend == "flashinfer-trtllm_8x4_sf_layout": # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. 
# FlashInfer provides nvfp4_quantize to quantize + shuffle the # layout but we use our own quantization so we have to call @@ -1130,8 +1131,14 @@ def apply( output_dtype = x.dtype output_shape = [x.shape[0], layer.weight.shape[0]] - # quantize BF16 or FP16 to (FP4 and interleaved block scale) - x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) + if self.backend == "flashinfer-trtllm_8x4_sf_layout": + x_fp4, x_blockscale = flashinfer_quant_nvfp4_8x4_sf_layout( + x, layer.input_scale_inv + ) + x_blockscale = x_blockscale.view(torch.float8_e4m3fn) + else: + # quantize BF16 or FP16 to (FP4 and interleaved block scale) + x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) # validate dtypes of quantized input, input block scale, # weight and weight_blockscale diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 5019b771f4a1..254880789c28 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -389,12 +389,21 @@ def flashinfer_mm_fp4( B_scale: torch.Tensor, g_scale: torch.Tensor, dtype: torch.dtype, + use_8x4_sf_layout: bool, backend: str, ) -> torch.Tensor: from flashinfer import mm_fp4 as flashinfer_mm_fp4_ return flashinfer_mm_fp4_( - A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend + A, + B, + A_scale, + B_scale, + g_scale, + dtype, + block_size=16, + use_8x4_sf_layout=use_8x4_sf_layout, + backend=backend, ) @torch.library.register_fake( @@ -407,6 +416,7 @@ def flashinfer_mm_fp4_fake( B_scale: torch.Tensor, g_scale: torch.Tensor, dtype: torch.dtype, + use_8x4_sf_layout: bool, backend: str, ) -> torch.Tensor: return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device) @@ -443,6 +453,36 @@ def bmm_fp8_fake( A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device ) + @torch.library.custom_op( + "vllm::flashinfer_nvfp4_quantize", + mutates_args=[], + device_types="cuda", + ) + def flashinfer_nvfp4_quantize( + a: torch.Tensor, + a_global_sf: torch.Tensor + ) -> 
tuple[torch.Tensor, torch.Tensor]: + from flashinfer import SfLayout + from flashinfer import nvfp4_quantize as nvfp4_quantize_ + + return nvfp4_quantize_( + a, a_global_sf, sfLayout=SfLayout.layout_8x4, do_shuffle=False + ) + + @torch.library.register_fake( + "vllm::flashinfer_nvfp4_quantize", + ) + def flashinfer_nvfp4_quantize_fake( + a: torch.Tensor, + a_global_sf: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + padded_rows = ((a.shape[0] + 7) // 8) * 8 + return torch.empty( + a.shape[0], a.shape[1] // 2, dtype=torch.uint8, device=a.device + ), torch.empty( + padded_rows, a.shape[1] // 16, dtype=torch.uint8, device=a.device + ) + def flashinfer_scaled_fp4_mm( a: torch.Tensor, @@ -462,6 +502,12 @@ def flashinfer_scaled_fp4_mm( block_scale_a = block_scale_a.view(torch.uint8) block_scale_b = block_scale_b.view(torch.uint8) + if backend == "trtllm_8x4_sf_layout": + use_8x4_sf_layout = True + backend = "trtllm" + else: + use_8x4_sf_layout = False + return flashinfer_mm_fp4( a, b.t(), @@ -469,6 +515,7 @@ def flashinfer_scaled_fp4_mm( block_scale_b.t(), alpha, out_dtype, + use_8x4_sf_layout=use_8x4_sf_layout, backend=backend, ) @@ -503,6 +550,12 @@ def flashinfer_scaled_fp8_mm( return output +def flashinfer_quant_nvfp4_8x4_sf_layout( + a: torch.Tensor, a_global_sf: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + return flashinfer_nvfp4_quantize(a, a_global_sf) + + __all__ = [ "has_flashinfer", "flashinfer_trtllm_fp8_block_scale_moe", @@ -525,4 +578,5 @@ def flashinfer_scaled_fp8_mm( "use_trtllm_attention", "flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp8_mm", + "nvfp4_quantize", ] From 45c8d87ad64983af7cbc143ff621ed032987f65b Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Wed, 17 Dec 2025 08:40:20 -0800 Subject: [PATCH 02/11] fixing typos Signed-off-by: LopezCastroRoberto --- .../schemes/compressed_tensors_w4a4_nvfp4.py | 4 ++-- vllm/utils/flashinfer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git 
a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index e30a3f64668c..bc97a698a53d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -194,11 +194,11 @@ def apply_weights( output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]] if self.backend == "flashinfer-trtllm_8x4_sf_layout": - x_fp4, x_blockscale = flashinfer_quant_nvfp4_8x4_sf_layout(x, layer.input_scale_inv) + x_fp4, x_blockscale = flashinfer_quant_nvfp4_8x4_sf_layout(x, layer.input_global_scale) x_blockscale = x_blockscale.view(torch.float8_e4m3fn) else: # quantize BF16 or FP16 to (FP4 and interleaved block scale) - x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) + x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) mm_args = ( x_fp4, diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 254880789c28..21c4a4e6f5f6 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -578,5 +578,5 @@ def flashinfer_quant_nvfp4_8x4_sf_layout( "use_trtllm_attention", "flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp8_mm", - "nvfp4_quantize", + "flashinfer_quant_nvfp4_8x4_sf_layout", ] From d3bfcec9e49d6738438764aa487f8f417f2a795f Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Wed, 17 Dec 2025 13:46:33 -0800 Subject: [PATCH 03/11] fix: apply pre-commit formatting Signed-off-by: LopezCastroRoberto --- tests/kernels/quantization/nvfp4_utils.py | 12 ++++++++++-- .../quantization/test_flashinfer_nvfp4_scaled_mm.py | 9 +++++++-- tests/models/quantization/test_nvfp4.py | 11 ++++++++++- .../schemes/compressed_tensors_w4a4_nvfp4.py | 4 +++- vllm/model_executor/layers/quantization/modelopt.py | 11 +++++++---- vllm/utils/flashinfer.py | 
6 ++---- 6 files changed, 39 insertions(+), 14 deletions(-) diff --git a/tests/kernels/quantization/nvfp4_utils.py b/tests/kernels/quantization/nvfp4_utils.py index b199d568c7b7..778895271432 100644 --- a/tests/kernels/quantization/nvfp4_utils.py +++ b/tests/kernels/quantization/nvfp4_utils.py @@ -23,7 +23,9 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): return out[0:m, 0:k] -def convert_swizzled_8x4_layout_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): +def convert_swizzled_8x4_layout_to_linear( + a_sf_swizzled: torch.Tensor, m, k, block_size +): m_tiles = (m + 8 - 1) // 8 f = block_size * 4 k_tiles = (k + f - 1) // f @@ -34,7 +36,13 @@ def convert_swizzled_8x4_layout_to_linear(a_sf_swizzled: torch.Tensor, m, k, blo def dequantize_nvfp4_to_dtype( - tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16, is_sf_128x4_layout=True + tensor_fp4, + tensor_sf, + global_scale, + dtype, + device, + block_size=16, + is_sf_128x4_layout=True, ): """Dequantize the fp4 tensor back to high precision.""" # Two fp4 values are packed into one uint8. 
diff --git a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py index d3191822bb85..e18915aefb86 100644 --- a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py @@ -11,7 +11,10 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.utils.flashinfer import flashinfer_quant_nvfp4_8x4_sf_layout, flashinfer_scaled_fp4_mm +from vllm.utils.flashinfer import ( + flashinfer_quant_nvfp4_8x4_sf_layout, + flashinfer_scaled_fp4_mm, +) if not current_platform.has_device_capability(100): pytest.skip( @@ -95,7 +98,9 @@ def test_flashinfer_nvfp4_gemm( alpha = 1.0 / (a_global_scale * b_global_scale) if backend == "trtllm_8x4_sf_layout": - a_fp4, a_scale_interleaved = flashinfer_quant_nvfp4_8x4_sf_layout(a_dtype, a_global_scale) + a_fp4, a_scale_interleaved = flashinfer_quant_nvfp4_8x4_sf_layout( + a_dtype, a_global_scale + ) is_sf_128x4_layout = False else: # ops.scaled_fp4_quant returns swizzled scales, while weights diff --git a/tests/models/quantization/test_nvfp4.py b/tests/models/quantization/test_nvfp4.py index 0ac32b498023..fbcec53f58cf 100644 --- a/tests/models/quantization/test_nvfp4.py +++ b/tests/models/quantization/test_nvfp4.py @@ -87,13 +87,22 @@ def test_models(example_prompts, model_name) -> None: EAGER = [True, False] + @pytest.mark.skipif( not is_quant_method_supported("modelopt_fp4"), reason="modelopt_fp4 is not supported on this GPU type.", ) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("eager", EAGER) -@pytest.mark.parametrize("backend", ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass", "flashinfer-trtllm_8x4_sf_layout"]) +@pytest.mark.parametrize( + "backend", + [ + "flashinfer-cudnn", + "flashinfer-trtllm", + "flashinfer-cutlass", + "flashinfer-trtllm_8x4_sf_layout", + ], +) def test_nvfp4(vllm_runner, model, eager, backend, 
monkeypatch): monkeypatch.setenv("VLLM_NVFP4_GEMM_BACKEND", backend) with vllm_runner(model, enforce_eager=eager) as llm: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index bc97a698a53d..c8c8797adf3e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -194,7 +194,9 @@ def apply_weights( output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]] if self.backend == "flashinfer-trtllm_8x4_sf_layout": - x_fp4, x_blockscale = flashinfer_quant_nvfp4_8x4_sf_layout(x, layer.input_global_scale) + x_fp4, x_blockscale = flashinfer_quant_nvfp4_8x4_sf_layout( + x, layer.input_global_scale + ) x_blockscale = x_blockscale.view(torch.float8_e4m3fn) else: # quantize BF16 or FP16 to (FP4 and interleaved block scale) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index b641cc155a19..9cec97264783 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -77,10 +77,10 @@ from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter from vllm.scalar_type import scalar_types from vllm.utils.flashinfer import ( + flashinfer_quant_nvfp4_8x4_sf_layout, flashinfer_scaled_fp4_mm, has_flashinfer, has_flashinfer_moe, - flashinfer_quant_nvfp4_8x4_sf_layout ) from vllm.utils.math_utils import round_up @@ -1084,7 +1084,10 @@ def process_weights_after_loading(self, layer: Module) -> None: prepare_fp4_layer_for_marlin(layer) del layer.alpha del layer.input_scale - elif self.backend == "flashinfer-trtllm" or self.backend == "flashinfer-trtllm_8x4_sf_layout": + elif ( + self.backend == 
"flashinfer-trtllm" + or self.backend == "flashinfer-trtllm_8x4_sf_layout" + ): # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. # FlashInfer provides nvfp4_quantize to quantize + shuffle the # layout but we use our own quantization so we have to call @@ -1131,12 +1134,12 @@ def apply( output_dtype = x.dtype output_shape = [x.shape[0], layer.weight.shape[0]] - if self.backend == "flashinfer-trtllm_8x4_sf_layout": + if self.backend == "flashinfer-trtllm_8x4_sf_layout": x_fp4, x_blockscale = flashinfer_quant_nvfp4_8x4_sf_layout( x, layer.input_scale_inv ) x_blockscale = x_blockscale.view(torch.float8_e4m3fn) - else: + else: # quantize BF16 or FP16 to (FP4 and interleaved block scale) x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 21c4a4e6f5f6..baa2cae0c4a8 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -459,8 +459,7 @@ def bmm_fp8_fake( device_types="cuda", ) def flashinfer_nvfp4_quantize( - a: torch.Tensor, - a_global_sf: torch.Tensor + a: torch.Tensor, a_global_sf: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: from flashinfer import SfLayout from flashinfer import nvfp4_quantize as nvfp4_quantize_ @@ -473,8 +472,7 @@ def flashinfer_nvfp4_quantize( "vllm::flashinfer_nvfp4_quantize", ) def flashinfer_nvfp4_quantize_fake( - a: torch.Tensor, - a_global_sf: torch.Tensor + a: torch.Tensor, a_global_sf: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: padded_rows = ((a.shape[0] + 7) // 8) * 8 return torch.empty( From c852f0408f3c9ec893f72bfb05fb93d2586fe654 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Fri, 2 Jan 2026 11:28:24 +0100 Subject: [PATCH 04/11] make 8x4 layout default for bs<=32 Signed-off-by: LopezCastroRoberto --- .../quantization/test_flashinfer_nvfp4_scaled_mm.py | 12 +++++++++--- tests/models/quantization/test_nvfp4.py | 1 - vllm/envs.py | 1 - .../schemes/compressed_tensors_w4a4_nvfp4.py | 8 +++----- 
vllm/model_executor/layers/quantization/modelopt.py | 7 ++----- vllm/utils/flashinfer.py | 3 +-- 6 files changed, 15 insertions(+), 17 deletions(-) diff --git a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py index e18915aefb86..96545d822ee1 100644 --- a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py @@ -24,7 +24,13 @@ DTYPES = [torch.float16, torch.bfloat16] # m, n, k -SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)] +SHAPES = [ + (128, 128, 64), + (128, 128, 128), + (256, 128, 64), + (128, 256, 128), + (1, 128, 128), +] PAD_SHAPES = [(150, 128, 64), (128, 128, 96)] SHAPES.extend(PAD_SHAPES) @@ -68,7 +74,7 @@ def get_ref_results( @pytest.mark.parametrize("shape", SHAPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("backend", ["cutlass", "trtllm", "trtllm_8x4_sf_layout"]) +@pytest.mark.parametrize("backend", ["cutlass", "trtllm"]) @pytest.mark.parametrize("autotune", [False, True]) @torch.inference_mode() def test_flashinfer_nvfp4_gemm( @@ -97,7 +103,7 @@ def test_flashinfer_nvfp4_gemm( ).to(torch.float32) alpha = 1.0 / (a_global_scale * b_global_scale) - if backend == "trtllm_8x4_sf_layout": + if backend == "trtllm" and m <= 32: a_fp4, a_scale_interleaved = flashinfer_quant_nvfp4_8x4_sf_layout( a_dtype, a_global_scale ) diff --git a/tests/models/quantization/test_nvfp4.py b/tests/models/quantization/test_nvfp4.py index fbcec53f58cf..5ea1acd7ca74 100644 --- a/tests/models/quantization/test_nvfp4.py +++ b/tests/models/quantization/test_nvfp4.py @@ -100,7 +100,6 @@ def test_models(example_prompts, model_name) -> None: "flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass", - "flashinfer-trtllm_8x4_sf_layout", ], ) def test_nvfp4(vllm_runner, model, eager, backend, monkeypatch): diff --git a/vllm/envs.py 
b/vllm/envs.py index 7603c8d416e7..33c2276230d5 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1414,7 +1414,6 @@ def get_vllm_port() -> int | None: "flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass", - "flashinfer-trtllm_8x4_sf_layout", "cutlass", ], ), diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index c8c8797adf3e..5a3d91d587a7 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math from collections.abc import Callable import torch @@ -135,10 +136,7 @@ def process_weights_after_loading(self, layer) -> None: layer.weight_global_scale.max().to(torch.float32), requires_grad=False ) - if ( - self.backend == "flashinfer-trtllm" - or self.backend == "flashinfer-trtllm_8x4_sf_layout" - ): + if self.backend == "flashinfer-trtllm": # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. 
# FlashInfer provides nvfp4_quantize to quantize + shuffle the # layout but we use our own quantization so we have to call @@ -193,7 +191,7 @@ def apply_weights( output_dtype = x.dtype output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]] - if self.backend == "flashinfer-trtllm_8x4_sf_layout": + if self.backend == "flashinfer-trtllm" and math.prod(x.shape[:-1]) <= 32: x_fp4, x_blockscale = flashinfer_quant_nvfp4_8x4_sf_layout( x, layer.input_global_scale ) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index a2e1f4e8b615..e358b1fd5a0b 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1084,10 +1084,7 @@ def process_weights_after_loading(self, layer: Module) -> None: prepare_fp4_layer_for_marlin(layer) del layer.alpha del layer.input_scale - elif ( - self.backend == "flashinfer-trtllm" - or self.backend == "flashinfer-trtllm_8x4_sf_layout" - ): + elif self.backend == "flashinfer-trtllm": # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. 
# FlashInfer provides nvfp4_quantize to quantize + shuffle the # layout but we use our own quantization so we have to call @@ -1134,7 +1131,7 @@ def apply( output_dtype = x.dtype output_shape = [x.shape[0], layer.weight.shape[0]] - if self.backend == "flashinfer-trtllm_8x4_sf_layout": + if self.backend == "flashinfer-trtllm" and x.shape[0] <= 32: x_fp4, x_blockscale = flashinfer_quant_nvfp4_8x4_sf_layout( x, layer.input_scale_inv ) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index e413a1cf9213..b4bbf2f7a8e5 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -537,9 +537,8 @@ def flashinfer_scaled_fp4_mm( block_scale_a = block_scale_a.view(torch.uint8) block_scale_b = block_scale_b.view(torch.uint8) - if backend == "trtllm_8x4_sf_layout": + if backend == "trtllm" and a.shape[0] <= 32: use_8x4_sf_layout = True - backend = "trtllm" else: use_8x4_sf_layout = False From f52809434972f03345dd6d70b81592e7b76427a0 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Fri, 2 Jan 2026 11:40:12 +0100 Subject: [PATCH 05/11] fix: apply pre-commit formatting Signed-off-by: LopezCastroRoberto --- vllm/utils/flashinfer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index b4bbf2f7a8e5..164b1d9bc5d8 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -537,10 +537,7 @@ def flashinfer_scaled_fp4_mm( block_scale_a = block_scale_a.view(torch.uint8) block_scale_b = block_scale_b.view(torch.uint8) - if backend == "trtllm" and a.shape[0] <= 32: - use_8x4_sf_layout = True - else: - use_8x4_sf_layout = False + use_8x4_sf_layout = bool(backend == "trtllm" and a.shape[0] <= 32) return flashinfer_mm_fp4( a, From fee91f032002976034d74ce20c069d5b81617bd5 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Thu, 8 Jan 2026 13:16:40 +0100 Subject: [PATCH 06/11] fix: apply pre-commit formatting Signed-off-by: LopezCastroRoberto --- vllm/utils/flashinfer.py | 
2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 35289d52ecd4..3cf94521c81b 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -586,7 +586,7 @@ def flashinfer_quant_nvfp4_8x4_sf_layout( ) -> tuple[torch.Tensor, torch.Tensor]: return flashinfer_nvfp4_quantize(a, a_global_sf) - + flashinfer_fp8_blockscale_gemm = _lazy_import_wrapper( "flashinfer.gemm", "fp8_blockscale_gemm_sm90" ) From 6d9b39b6be1e8701cdcdeeb532378170af201330 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Fri, 9 Jan 2026 03:08:43 -0800 Subject: [PATCH 07/11] reverting nvfp4 test Signed-off-by: LopezCastroRoberto --- tests/models/quantization/test_nvfp4.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/tests/models/quantization/test_nvfp4.py b/tests/models/quantization/test_nvfp4.py index 5ea1acd7ca74..9f45f142d68b 100644 --- a/tests/models/quantization/test_nvfp4.py +++ b/tests/models/quantization/test_nvfp4.py @@ -83,27 +83,3 @@ def test_models(example_prompts, model_name) -> None: assert expected_str == generated_str, ( f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}" ) - - -EAGER = [True, False] - - -@pytest.mark.skipif( - not is_quant_method_supported("modelopt_fp4"), - reason="modelopt_fp4 is not supported on this GPU type.", -) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("eager", EAGER) -@pytest.mark.parametrize( - "backend", - [ - "flashinfer-cudnn", - "flashinfer-trtllm", - "flashinfer-cutlass", - ], -) -def test_nvfp4(vllm_runner, model, eager, backend, monkeypatch): - monkeypatch.setenv("VLLM_NVFP4_GEMM_BACKEND", backend) - with vllm_runner(model, enforce_eager=eager) as llm: - output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2) - assert output[0][1] == "1 2 3 4 5 6" From 80958f34425df4fa15e45677f4777f6fd78d3c38 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Fri, 9 Jan 2026 03:47:32 -0800 Subject: [PATCH 
08/11] change nvfp4 model testing to smaller model to avoid OOM issues Signed-off-by: LopezCastroRoberto --- tests/models/quantization/test_nvfp4.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/models/quantization/test_nvfp4.py b/tests/models/quantization/test_nvfp4.py index 9f45f142d68b..887e681123dc 100644 --- a/tests/models/quantization/test_nvfp4.py +++ b/tests/models/quantization/test_nvfp4.py @@ -83,3 +83,27 @@ def test_models(example_prompts, model_name) -> None: assert expected_str == generated_str, ( f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}" ) + + +EAGER = [True, False] + + +@pytest.mark.skipif( + not is_quant_method_supported("modelopt_fp4"), + reason="modelopt_fp4 is not supported on this GPU type.", +) +@pytest.mark.parametrize("model", ["nvidia/Llama-3.1-8B-Instruct-NVFP4"]) +@pytest.mark.parametrize("eager", EAGER) +@pytest.mark.parametrize( + "backend", + [ + "flashinfer-cudnn", + "flashinfer-trtllm", + "flashinfer-cutlass", + ], +) +def test_nvfp4(vllm_runner, model, eager, backend, monkeypatch): + monkeypatch.setenv("VLLM_NVFP4_GEMM_BACKEND", backend) + with vllm_runner(model, enforce_eager=eager) as llm: + output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2) + assert output[0][1] == "1 2 3 4 5 6" From 5e51feeac8222c10373bc865ae37d4f815965616 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Fri, 9 Jan 2026 06:38:45 -0800 Subject: [PATCH 09/11] revert pre-commit change causing torch.compile to break Signed-off-by: LopezCastroRoberto --- .buildkite/test-pipeline.yaml | 4 +++- tests/models/quantization/test_nvfp4.py | 6 ++++-- vllm/utils/flashinfer.py | 7 +++++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index fceae96854a8..c6cf774fa889 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -939,7 +939,7 @@ steps: # Whisper needs spawn method to avoid deadlock - 
VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper -- label: Blackwell Test # 21 min +- label: Blackwell Test # 23 min timeout_in_minutes: 30 working_dir: "/vllm-workspace/" gpu: b200 @@ -979,6 +979,8 @@ steps: - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - pytest -v -s tests/kernels/moe/test_flashinfer.py - pytest -v -s tests/kernels/moe/test_cutedsl_moe.py + # e2e + - pytest -v -s tests/models/quantization/test_nvfp4.py - label: Blackwell Fusion and Compile Tests # 30 min timeout_in_minutes: 40 diff --git a/tests/models/quantization/test_nvfp4.py b/tests/models/quantization/test_nvfp4.py index 887e681123dc..a9124fbba7e4 100644 --- a/tests/models/quantization/test_nvfp4.py +++ b/tests/models/quantization/test_nvfp4.py @@ -14,6 +14,8 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams +from vllm.platforms import current_platform + os.environ["TOKENIZERS_PARALLELISM"] = "true" MAX_MODEL_LEN = 1024 @@ -89,7 +91,7 @@ def test_models(example_prompts, model_name) -> None: @pytest.mark.skipif( - not is_quant_method_supported("modelopt_fp4"), + not current_platform.has_device_capability(100), reason="modelopt_fp4 is not supported on this GPU type.", ) @pytest.mark.parametrize("model", ["nvidia/Llama-3.1-8B-Instruct-NVFP4"]) @@ -98,7 +100,7 @@ def test_models(example_prompts, model_name) -> None: "backend", [ "flashinfer-cudnn", - "flashinfer-trtllm", + "flashinfer-trtllm", # the small seq_len ensures trtllm_8x4_layout backend is used "flashinfer-cutlass", ], ) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 3cf94521c81b..ab8f485c0fdb 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -537,8 +537,11 @@ def flashinfer_scaled_fp4_mm( block_scale_a = block_scale_a.view(torch.uint8) block_scale_b = block_scale_b.view(torch.uint8) - use_8x4_sf_layout = bool(backend == "trtllm" and a.shape[0] <= 32) - + if backend == 
"trtllm" and a.shape[0] <= 32: + use_8x4_sf_layout = True + else: + use_8x4_sf_layout = False + return flashinfer_mm_fp4( a, b.t(), From 798ff530b2c568c868c1f27f0b8d4576dcd717f3 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Fri, 9 Jan 2026 07:06:31 -0800 Subject: [PATCH 10/11] disable SIM210 rule Signed-off-by: LopezCastroRoberto --- tests/models/quantization/test_nvfp4.py | 2 +- vllm/utils/flashinfer.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/models/quantization/test_nvfp4.py b/tests/models/quantization/test_nvfp4.py index a9124fbba7e4..b73462bfd198 100644 --- a/tests/models/quantization/test_nvfp4.py +++ b/tests/models/quantization/test_nvfp4.py @@ -100,7 +100,7 @@ def test_models(example_prompts, model_name) -> None: "backend", [ "flashinfer-cudnn", - "flashinfer-trtllm", # the small seq_len ensures trtllm_8x4_layout backend is used + "flashinfer-trtllm", # the small seq_len ensures trtllm_8x4_layout backend is used "flashinfer-cutlass", ], ) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index ab8f485c0fdb..2c80ec992461 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -537,11 +537,8 @@ def flashinfer_scaled_fp4_mm( block_scale_a = block_scale_a.view(torch.uint8) block_scale_b = block_scale_b.view(torch.uint8) - if backend == "trtllm" and a.shape[0] <= 32: - use_8x4_sf_layout = True - else: - use_8x4_sf_layout = False - + use_8x4_sf_layout = True if backend == "trtllm" and a.shape[0] <= 32 else False # noqa: SIM210 + return flashinfer_mm_fp4( a, b.t(), From 36e70cd6d343f4f6313e6b790d9fb5d5838ec5c5 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Mon, 12 Jan 2026 13:38:12 -0500 Subject: [PATCH 11/11] addressing comments Signed-off-by: LopezCastroRoberto --- .../test_flashinfer_nvfp4_scaled_mm.py | 21 +++------ vllm/_custom_ops.py | 46 ++++++++++++------- .../schemes/compressed_tensors_w4a4_nvfp4.py | 14 ++---- .../layers/quantization/modelopt.py | 11 +---- 
vllm/utils/flashinfer.py | 15 ++++-- 5 files changed, 53 insertions(+), 54 deletions(-) diff --git a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py index 560df5a19a28..d615bb7dc797 100644 --- a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py @@ -12,7 +12,6 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform from vllm.utils.flashinfer import ( - flashinfer_quant_nvfp4_8x4_sf_layout, flashinfer_scaled_fp4_mm, ) from vllm.utils.torch_utils import set_random_seed @@ -32,7 +31,7 @@ (128, 256, 128), (1, 128, 128), ] -PAD_SHAPES = [(150, 128, 64), (128, 128, 96)] +PAD_SHAPES = [(150, 128, 64), (128, 128, 96), (2, 128, 64), (3, 128, 96)] SHAPES.extend(PAD_SHAPES) SEEDS = [42] @@ -104,18 +103,12 @@ def test_flashinfer_nvfp4_gemm( ).to(torch.float32) alpha = 1.0 / (a_global_scale * b_global_scale) - if backend == "trtllm" and m <= 32: - a_fp4, a_scale_interleaved = flashinfer_quant_nvfp4_8x4_sf_layout( - a_dtype, a_global_scale - ) - is_sf_128x4_layout = False - else: - # ops.scaled_fp4_quant returns swizzled scales, while weights - # from checkpoints are in linear scales. - # So instead of needing to swizzle for cutlass as in modelopt.py, - # we need to unswizzle for trtllm here. - a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale) - is_sf_128x4_layout = True + # ops.scaled_fp4_quant returns swizzled scales, while weights + # from checkpoints are in linear scales. + # So instead of needing to swizzle for cutlass as in modelopt.py, + # we need to unswizzle for trtllm here. 
+ a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale, backend) + is_sf_128x4_layout = not (backend == "trtllm" and m <= 32) b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 86d6e309b1aa..7ff23e9686de 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -9,6 +9,9 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.scalar_type import ScalarType +from vllm.utils.flashinfer import ( + flashinfer_quant_nvfp4_8x4_sf_layout, +) logger = init_logger(__name__) @@ -1563,7 +1566,9 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: # fp4 def scaled_fp4_quant( - input: torch.Tensor, input_global_scale: torch.Tensor + input: torch.Tensor, + input_global_scale: torch.Tensor, + backend: str = "none", ) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP4 and return quantized tensor and scale. @@ -1577,6 +1582,7 @@ def scaled_fp4_quant( Args: input: The input tensor to be quantized to FP4 input_global_scale: A scalar scaling factor for the entire tensor. + backend: GEMM backend name; a value containing "trtllm" with m <= 32 selects the 8x4 scale-factor layout instead of 128x4 Returns: tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every @@ -1596,23 +1602,31 @@ def scaled_fp4_quant( f"input.dtype needs to be fp16 or bf16 but got {input.dtype}." ) - # Two fp4 values will be packed into an uint8. - output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + use_8x4_sf_layout = True if "trtllm" in backend and m <= 32 else False # noqa: SIM210 - # We use the rounded values to store the swizzled values. Due to the - # requirement of the Tensor Core, the minimum tile is 128x4 for the scales. - # So, we first pad the scales to multiples of 128 and 4. Then, the scales - # (in float8_e4m3fn) are packed into an int32 for every 4 values.
More: - # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x - round_up = lambda x, y: (x + y - 1) // y * y - rounded_m = round_up(m, 128) - scale_n = n // block_size - rounded_n = round_up(scale_n, 4) - output_scale = torch.empty( - (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 - ) + if use_8x4_sf_layout: + output, output_scale = flashinfer_quant_nvfp4_8x4_sf_layout( + input, input_global_scale + ) + else: + # Two fp4 values will be packed into an uint8. + output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + + # We use the rounded values to store the swizzled values. Due to the + # requirement of the Tensor Core, the minimum tile is 128x4 for the scales. + # So, we first pad the scales to multiples of 128 and 4. Then, the scales + # (in float8_e4m3fn) are packed into an int32 for every 4 values. More: + # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x + round_up = lambda x, y: (x + y - 1) // y * y + rounded_m = round_up(m, 128) + scale_n = n // block_size + rounded_n = round_up(scale_n, 4) + output_scale = torch.empty( + (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 + ) + + torch.ops._C.scaled_fp4_quant(output, input, output_scale, input_global_scale) - torch.ops._C.scaled_fp4_quant(output, input, output_scale, input_global_scale) output_scale = output_scale.view(torch.float8_e4m3fn) return output, output_scale diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 5a3d91d587a7..d7f34e4f5eca 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math from collections.abc import Callable import torch @@ -25,7 +24,6 @@ PerTensorScaleParameter, ) from vllm.utils.flashinfer import ( - flashinfer_quant_nvfp4_8x4_sf_layout, flashinfer_scaled_fp4_mm, has_flashinfer, ) @@ -191,14 +189,10 @@ def apply_weights( output_dtype = x.dtype output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]] - if self.backend == "flashinfer-trtllm" and math.prod(x.shape[:-1]) <= 32: - x_fp4, x_blockscale = flashinfer_quant_nvfp4_8x4_sf_layout( - x, layer.input_global_scale - ) - x_blockscale = x_blockscale.view(torch.float8_e4m3fn) - else: - # quantize BF16 or FP16 to (FP4 and interleaved block scale) - x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) + # quantize BF16 or FP16 to (FP4 and interleaved block scale) + x_fp4, x_blockscale = scaled_fp4_quant( + x, layer.input_global_scale, self.backend + ) mm_args = ( x_fp4, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index b11b746c7ee7..bcda7b42c2ec 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -93,7 +93,6 @@ ) from vllm.model_executor.utils import replace_parameter from vllm.utils.flashinfer import ( - flashinfer_quant_nvfp4_8x4_sf_layout, flashinfer_scaled_fp4_mm, has_flashinfer, ) @@ -1291,14 +1290,8 @@ def apply( output_dtype = x.dtype output_shape = [x.shape[0], layer.weight.shape[0]] - if self.backend == "flashinfer-trtllm" and x.shape[0] <= 32: - x_fp4, x_blockscale = flashinfer_quant_nvfp4_8x4_sf_layout( - x, layer.input_scale_inv - ) - x_blockscale = x_blockscale.view(torch.float8_e4m3fn) - else: - # quantize BF16 or FP16 to (FP4 and interleaved block scale) - x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) + # quantize BF16 or FP16 to (FP4 and interleaved block scale) + x_fp4, x_blockscale = scaled_fp4_quant(x, 
layer.input_scale_inv, self.backend) # validate dtypes of quantized input, input block scale, # weight and weight_blockscale diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 6e3e8541d860..067c6fb3e785 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -491,11 +491,16 @@ def flashinfer_nvfp4_quantize( def flashinfer_nvfp4_quantize_fake( a: torch.Tensor, a_global_sf: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: - padded_rows = ((a.shape[0] + 7) // 8) * 8 - return torch.empty( - a.shape[0], a.shape[1] // 2, dtype=torch.uint8, device=a.device - ), torch.empty( - padded_rows, a.shape[1] // 16, dtype=torch.uint8, device=a.device + m, n = a.shape + + round_up = lambda x, y: (x + y - 1) // y * y + + rounded_m = round_up(m, 8) + scale_n = n // 16 + rounded_n = round_up(scale_n, 4) + + return torch.empty(m, n // 2, dtype=torch.uint8, device=a.device), torch.empty( + rounded_m, rounded_n, dtype=torch.uint8, device=a.device )