diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 1c7a5ca36886..db8c0348c2bc 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -951,7 +951,7 @@ steps: # Whisper needs spawn method to avoid deadlock - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper -- label: Blackwell Test # 21 min +- label: Blackwell Test # 23 min timeout_in_minutes: 30 working_dir: "/vllm-workspace/" gpu: b200 @@ -991,6 +991,8 @@ steps: - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - pytest -v -s tests/kernels/moe/test_flashinfer.py - pytest -v -s tests/kernels/moe/test_cutedsl_moe.py + # e2e + - pytest -v -s tests/models/quantization/test_nvfp4.py - label: Blackwell Fusion and Compile Tests # 30 min timeout_in_minutes: 40 diff --git a/tests/kernels/quantization/nvfp4_utils.py b/tests/kernels/quantization/nvfp4_utils.py index 5e6d54c42e89..778895271432 100644 --- a/tests/kernels/quantization/nvfp4_utils.py +++ b/tests/kernels/quantization/nvfp4_utils.py @@ -23,8 +23,26 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): return out[0:m, 0:k] +def convert_swizzled_8x4_layout_to_linear( + a_sf_swizzled: torch.Tensor, m, k, block_size +): + m_tiles = (m + 8 - 1) // 8 + f = block_size * 4 + k_tiles = (k + f - 1) // f + tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 8, 4)) + tmp = torch.permute(tmp, (0, 1, 3, 2, 4)) + out = tmp.reshape(m_tiles * 8, k_tiles * f // block_size) + return out[0:m, 0:k] + + def dequantize_nvfp4_to_dtype( - tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16 + tensor_fp4, + tensor_sf, + global_scale, + dtype, + device, + block_size=16, + is_sf_128x4_layout=True, ): """Dequantize the fp4 tensor back to high precision.""" # Two fp4 values are packed into one uint8. 
@@ -34,7 +52,11 @@ def dequantize_nvfp4_to_dtype( tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) tensor_sf = tensor_sf.view(torch.float8_e4m3fn) - tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + if is_sf_128x4_layout: + tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + else: + tensor_sf = convert_swizzled_8x4_layout_to_linear(tensor_sf, m, k, block_size) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale # scale the tensor diff --git a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py index 94fa38b5aae4..d615bb7dc797 100644 --- a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py @@ -11,7 +11,9 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm +from vllm.utils.flashinfer import ( + flashinfer_scaled_fp4_mm, +) from vllm.utils.torch_utils import set_random_seed if not current_platform.has_device_capability(100): @@ -22,8 +24,14 @@ DTYPES = [torch.float16, torch.bfloat16] # m, n, k -SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)] -PAD_SHAPES = [(150, 128, 64), (128, 128, 96)] +SHAPES = [ + (128, 128, 64), + (128, 128, 128), + (256, 128, 64), + (128, 256, 128), + (1, 128, 128), +] +PAD_SHAPES = [(150, 128, 64), (128, 128, 96), (2, 128, 64), (3, 128, 96)] SHAPES.extend(PAD_SHAPES) SEEDS = [42] @@ -42,12 +50,19 @@ def get_ref_results( dtype, block_size, device, + is_sf_128x4_layout, ): _, m_k = a_fp4.shape _, n_k = b_fp4.shape assert m_k == n_k a_in_dtype = dequantize_nvfp4_to_dtype( - a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size + a_fp4, + a_sf, + a_global_scale, + dtype=dtype, + device=device, + block_size=block_size, + 
is_sf_128x4_layout=is_sf_128x4_layout, ) b_in_dtype = dequantize_nvfp4_to_dtype( b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size @@ -70,7 +85,7 @@ def test_flashinfer_nvfp4_gemm( backend: str, autotune: bool, ) -> None: - if backend == "trtllm" and dtype == torch.float16: + if "trtllm" in backend and dtype == torch.float16: pytest.skip("Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations") set_random_seed(seed) @@ -87,11 +102,14 @@ def test_flashinfer_nvfp4_gemm( (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1) ).to(torch.float32) alpha = 1.0 / (a_global_scale * b_global_scale) + # ops.scaled_fp4_quant returns swizzled scales, while weights # from checkpoints are in linear scales. # So instead of needing to swizzle for cutlass as in modelopt.py, # we need to unswizzle for trtllm here. - a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale) + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale, backend) + is_sf_128x4_layout = not (backend == "trtllm" and m <= 32) + b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale) # get_ref_results unswizzles the scales internally. 
@@ -107,14 +125,14 @@ def test_flashinfer_nvfp4_gemm( dtype, block_size, device, + is_sf_128x4_layout, ) import flashinfer - if backend == "trtllm": + if "trtllm" in backend: epilogue_tile_m = 128 b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), epilogue_tile_m) - b_scale_interleaved = convert_swizzled_to_linear( b_scale_interleaved, n, k, block_size ) diff --git a/tests/models/quantization/test_nvfp4.py b/tests/models/quantization/test_nvfp4.py index 9f45f142d68b..b73462bfd198 100644 --- a/tests/models/quantization/test_nvfp4.py +++ b/tests/models/quantization/test_nvfp4.py @@ -14,6 +14,8 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams +from vllm.platforms import current_platform + os.environ["TOKENIZERS_PARALLELISM"] = "true" MAX_MODEL_LEN = 1024 @@ -83,3 +85,27 @@ def test_models(example_prompts, model_name) -> None: assert expected_str == generated_str, ( f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}" ) + + +EAGER = [True, False] + + +@pytest.mark.skipif( + not current_platform.has_device_capability(100), + reason="modelopt_fp4 is not supported on this GPU type.", +) +@pytest.mark.parametrize("model", ["nvidia/Llama-3.1-8B-Instruct-NVFP4"]) +@pytest.mark.parametrize("eager", EAGER) +@pytest.mark.parametrize( + "backend", + [ + "flashinfer-cudnn", + "flashinfer-trtllm", # the small seq_len ensures trtllm_8x4_layout backend is used + "flashinfer-cutlass", + ], +) +def test_nvfp4(vllm_runner, model, eager, backend, monkeypatch): + monkeypatch.setenv("VLLM_NVFP4_GEMM_BACKEND", backend) + with vllm_runner(model, enforce_eager=eager) as llm: + output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2) + assert output[0][1] == "1 2 3 4 5 6" diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 86d6e309b1aa..7ff23e9686de 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -9,6 +9,9 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from 
vllm.scalar_type import ScalarType +from vllm.utils.flashinfer import ( +    flashinfer_quant_nvfp4_8x4_sf_layout, +) logger = init_logger(__name__) @@ -1563,7 +1566,9 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: # fp4 def scaled_fp4_quant( -    input: torch.Tensor, input_global_scale: torch.Tensor +    input: torch.Tensor, +    input_global_scale: torch.Tensor, +    backend: str = "none", ) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP4 and return quantized tensor and scale. @@ -1577,6 +1582,7 @@ def scaled_fp4_quant( Args: input: The input tensor to be quantized to FP4 input_global_scale: A scalar scaling factor for the entire tensor. +        backend: The nvfp4 GEMM backend name; for "trtllm" backends with m <= 32 the 8x4 scale-factor layout is used instead of 128x4. Returns: tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every @@ -1596,23 +1602,31 @@ def scaled_fp4_quant( f"input.dtype needs to be fp16 or bf16 but got {input.dtype}." ) -    # Two fp4 values will be packed into an uint8. -    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) +    use_8x4_sf_layout = True if "trtllm" in backend and m <= 32 else False  # noqa: SIM210 -    # We use the rounded values to store the swizzled values. Due to the -    # requirement of the Tensor Core, the minimum tile is 128x4 for the scales. -    # So, we first pad the scales to multiples of 128 and 4. Then, the scales -    # (in float8_e4m3fn) are packed into an int32 for every 4 values. More: -    # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x -    round_up = lambda x, y: (x + y - 1) // y * y -    rounded_m = round_up(m, 128) -    scale_n = n // block_size -    rounded_n = round_up(scale_n, 4) -    output_scale = torch.empty( -        (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 -    ) +    if use_8x4_sf_layout: +        output, output_scale = flashinfer_quant_nvfp4_8x4_sf_layout( +            input, input_global_scale +        ) +    else: +        # Two fp4 values will be packed into an uint8. 
+ output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + + # We use the rounded values to store the swizzled values. Due to the + # requirement of the Tensor Core, the minimum tile is 128x4 for the scales. + # So, we first pad the scales to multiples of 128 and 4. Then, the scales + # (in float8_e4m3fn) are packed into an int32 for every 4 values. More: + # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x + round_up = lambda x, y: (x + y - 1) // y * y + rounded_m = round_up(m, 128) + scale_n = n // block_size + rounded_n = round_up(scale_n, 4) + output_scale = torch.empty( + (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 + ) + + torch.ops._C.scaled_fp4_quant(output, input, output_scale, input_global_scale) - torch.ops._C.scaled_fp4_quant(output, input, output_scale, input_global_scale) output_scale = output_scale.view(torch.float8_e4m3fn) return output, output_scale diff --git a/vllm/envs.py b/vllm/envs.py index d77c1e9d95e2..ca4bda46fe29 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1444,7 +1444,12 @@ def get_vllm_port() -> int | None: "VLLM_NVFP4_GEMM_BACKEND": env_with_choices( "VLLM_NVFP4_GEMM_BACKEND", None, - ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass", "cutlass"], + [ + "flashinfer-cudnn", + "flashinfer-trtllm", + "flashinfer-cutlass", + "cutlass", + ], ), # Controls garbage collection during CUDA graph capture. # If set to 0 (default), enables GC freezing to speed up capture time. 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index c0b1e3ceeba3..d7f34e4f5eca 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -23,7 +23,10 @@ ModelWeightParameter, PerTensorScaleParameter, ) -from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer +from vllm.utils.flashinfer import ( + flashinfer_scaled_fp4_mm, + has_flashinfer, +) logger = init_logger(__name__) @@ -187,7 +190,9 @@ def apply_weights( output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]] # quantize BF16 or FP16 to (FP4 and interleaved block scale) - x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) + x_fp4, x_blockscale = scaled_fp4_quant( + x, layer.input_global_scale, self.backend + ) mm_args = ( x_fp4, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index a646012ddd3a..bcda7b42c2ec 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1291,7 +1291,7 @@ def apply( output_shape = [x.shape[0], layer.weight.shape[0]] # quantize BF16 or FP16 to (FP4 and interleaved block scale) - x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) + x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv, self.backend) # validate dtypes of quantized input, input block scale, # weight and weight_blockscale diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 3da8be098fbd..067c6fb3e785 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -406,12 +406,21 @@ def flashinfer_mm_fp4( B_scale: torch.Tensor, g_scale: torch.Tensor, dtype: 
torch.dtype, + use_8x4_sf_layout: bool, backend: str, ) -> torch.Tensor: from flashinfer import mm_fp4 as flashinfer_mm_fp4_ return flashinfer_mm_fp4_( - A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend + A, + B, + A_scale, + B_scale, + g_scale, + dtype, + block_size=16, + use_8x4_sf_layout=use_8x4_sf_layout, + backend=backend, ) @torch.library.register_fake( @@ -424,6 +433,7 @@ def flashinfer_mm_fp4_fake( B_scale: torch.Tensor, g_scale: torch.Tensor, dtype: torch.dtype, + use_8x4_sf_layout: bool, backend: str, ) -> torch.Tensor: return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device) @@ -460,6 +470,39 @@ def bmm_fp8_fake( A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device ) + @torch.library.custom_op( + "vllm::flashinfer_nvfp4_quantize", + mutates_args=[], + device_types="cuda", + ) + def flashinfer_nvfp4_quantize( + a: torch.Tensor, a_global_sf: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + from flashinfer import SfLayout + from flashinfer import nvfp4_quantize as nvfp4_quantize_ + + return nvfp4_quantize_( + a, a_global_sf, sfLayout=SfLayout.layout_8x4, do_shuffle=False + ) + + @torch.library.register_fake( + "vllm::flashinfer_nvfp4_quantize", + ) + def flashinfer_nvfp4_quantize_fake( + a: torch.Tensor, a_global_sf: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + m, n = a.shape + + round_up = lambda x, y: (x + y - 1) // y * y + + rounded_m = round_up(m, 8) + scale_n = n // 16 + rounded_n = round_up(scale_n, 4) + + return torch.empty(m, n // 2, dtype=torch.uint8, device=a.device), torch.empty( + rounded_m, rounded_n, dtype=torch.uint8, device=a.device + ) + def flashinfer_scaled_fp4_mm( a: torch.Tensor, @@ -479,6 +522,8 @@ def flashinfer_scaled_fp4_mm( block_scale_a = block_scale_a.view(torch.uint8) block_scale_b = block_scale_b.view(torch.uint8) + use_8x4_sf_layout = True if backend == "trtllm" and a.shape[0] <= 32 else False # noqa: SIM210 + return flashinfer_mm_fp4( a, b.t(), @@ -486,6 
+531,7 @@ def flashinfer_scaled_fp4_mm( block_scale_b.t(), alpha, out_dtype, + use_8x4_sf_layout=use_8x4_sf_layout, backend=backend, ) @@ -520,6 +566,12 @@ def flashinfer_scaled_fp8_mm( return output +def flashinfer_quant_nvfp4_8x4_sf_layout( + a: torch.Tensor, a_global_sf: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + return flashinfer_nvfp4_quantize(a, a_global_sf) + + flashinfer_fp8_blockscale_gemm = _lazy_import_wrapper( "flashinfer.gemm", "fp8_blockscale_gemm_sm90" ) @@ -596,6 +648,7 @@ def should_use_flashinfer_for_blockscale_fp8_gemm( "use_trtllm_attention", "flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp8_mm", + "flashinfer_quant_nvfp4_8x4_sf_layout", "flashinfer_fp8_blockscale_gemm", "should_use_flashinfer_for_blockscale_fp8_gemm", "is_flashinfer_fp8_blockscale_gemm_supported",