diff --git a/tests/kernels/moe/test_cutedsl_sm12x_moe.py b/tests/kernels/moe/test_cutedsl_sm12x_moe.py new file mode 100644 index 000000000000..59cc8e8a12cb --- /dev/null +++ b/tests/kernels/moe/test_cutedsl_sm12x_moe.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.platforms import current_platform + +if not current_platform.is_device_capability_family(120): + pytest.skip( + reason="FlashInfer CuteDSL SM12x MoE requires SM120 (RTX Pro 6000 / DGX Spark).", + allow_module_level=True, + ) + +from vllm.utils.flashinfer import has_flashinfer_cutedsl_sm12x_moe + +if not has_flashinfer_cutedsl_sm12x_moe(): + pytest.skip( + reason=( + "FlashInfer cute_dsl_fused_moe_nvfp4 / convert_sf_to_mma_layout " + "not available in installed FlashInfer (needs PRs #3051 and #3066)." + ), + allow_module_level=True, + ) + +# Import fp4_quantize after the skip guard — FlashInfer must be installed. +from flashinfer.fp4_quantization import fp4_quantize + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from tests.kernels.moe.utils import make_dummy_moe_config +from tests.kernels.utils import torch_moe +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) +from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config +from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_moe import ( + FlashInferCuteDSLSM12xExperts, +) +from vllm.utils.torch_utils import set_random_seed + +# Dimensions chosen to satisfy FP4 alignment requirements (k multiple of 256, +# n multiple of 128) while keeping tests fast. 
+MNK_FACTORS = [ + (2, 128, 256), + (2, 256, 512), + (16, 128, 256), + (64, 256, 512), +] + + +def _reorder_gate_up_to_up_gate( + w: torch.Tensor, + w_s: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Swap gate and up-projection halves along dim=1 to [up, gate] order. + + The SM12x kernel expects weights in [up (w3), gate (w1)] order while the + BF16 reference uses [gate (w1), up (w3)]. This replicates the reordering + done at model-load time by ``prepare_nvfp4_moe_layer_for_fi_or_cutlass``. + """ + n = w.shape[1] // 2 + return ( + torch.cat([w[:, n:, :], w[:, :n, :]], dim=1), + torch.cat([w_s[:, n:, :], w_s[:, :n, :]], dim=1), + ) + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", [8, 16]) +@pytest.mark.parametrize("topk", [1, 2, 4]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.inference_mode() +def test_flashinfer_cutedsl_sm12x_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, + workspace_init, +): + """Test FlashInferCuteDSLSM12xExperts against a BF16 torch reference. + + The SM12x kernel takes BF16 hidden states directly and fuses token + dispatch, W1 GEMM, SwiGLU, and W2 GEMM into one call. We verify + correctness against ``torch_moe`` using generous tolerances to account + for the internal FP4 quantization of activations and weights. + + Scale convention + ---------------- + The SM12x kernel uses ``w1_alpha`` as *both* the activation-quantisation + global scale and the weight dequantisation factor. These two roles are + conflated into a single parameter in ``launch_sm120_moe``, so they must + equal the same value. We use ``global_scale = 1.0`` for + ``fp4_quantize`` so that ``w1_alpha = ones`` satisfies both roles + simultaneously. The alternative — vLLM's convention of baking a large + ``w_gs`` into block-scale values and compensating with + ``g1_alphas = 1/w_gs`` — is incompatible with this kernel. 
+ """ + set_random_seed(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + + # Generate BF16 reference weights in [gate, up] order. + # Shape: w1=(e, 2n, k), w2=(e, k, n). + w1_bf16 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 15 + w2_bf16 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 15 + + # ------------------------------------------------------------------ # + # Quantise weights for the SM12x kernel using FlashInfer's convention: + # global_scale = 1.0 → block_scale = max_abs_block / fp4_max + # w1_alpha = 1.0 (no extra global factor to compensate) + # + # The scale factors returned by fp4_quantize(..., is_sf_swizzled_layout=True) + # are already in the swizzled 2D layout expected by convert_sf_to_mma_layout. + # No additional swizzle_blockscale() call is needed. + # ------------------------------------------------------------------ # + gs = torch.ones(1, device="cuda", dtype=torch.float32) + sf_vec_size = 16 + + # W1: reorder BF16 from [gate, up] → [up, gate], then quantise. + w1_reordered = torch.cat( + [w1_bf16[:, n:, :], w1_bf16[:, :n, :]], dim=1 + ) # shape (e, 2n, k), [up, gate] + w1_flat = w1_reordered.reshape(e * 2 * n, k) + w1_q_flat, w1_sf_flat = fp4_quantize( + w1_flat, + global_scale=gs, + sf_vec_size=sf_vec_size, + is_sf_swizzled_layout=True, + ) + w1_q = w1_q_flat.view(e, 2 * n, k // 2) # uint8, packed FP4 + w1_blockscale = w1_sf_flat.view(e, 2 * n, w1_sf_flat.shape[1]) # float8 + + # W2: no row reordering needed for the down-projection. + w2_flat = w2_bf16.reshape(e * k, n) + w2_q_flat, w2_sf_flat = fp4_quantize( + w2_flat, + global_scale=gs, + sf_vec_size=sf_vec_size, + is_sf_swizzled_layout=True, + ) + w2_q = w2_q_flat.view(e, k, n // 2) # uint8, packed FP4 + w2_blockscale = w2_sf_flat.view(e, k, w2_sf_flat.shape[1]) # float8 + + # All per-expert alphas are 1.0 (global_scale = 1.0, no compensation). 
+ ones_e = torch.ones(e, device="cuda", dtype=torch.float32) + + quant_config = nvfp4_moe_quant_config( + g1_alphas=ones_e, + g2_alphas=ones_e, + a1_gscale=ones_e, + a2_gscale=ones_e, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + ) + + moe_config = make_dummy_moe_config( + num_experts=e, + experts_per_token=topk, + hidden_dim=k, + intermediate_size_per_partition=n, + in_dtype=dtype, + ) + + kernel = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), + FlashInferCuteDSLSM12xExperts( + moe_config=moe_config, + quant_config=quant_config, + ), + inplace=False, + ) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) + + sm12x_output = kernel.apply( + hidden_states=a, + w1=w1_q, + w2=w2_q, + topk_weights=topk_weights, + topk_ids=topk_ids, + global_num_experts=e, + activation=MoEActivation.SILU, + apply_router_weight_on_input=False, + expert_map=None, + ) + + # Reference: BF16 torch MoE using original [gate, up] BF16 weights. + # torch_moe's SiluAndMul expects [gate, up] order, matching w1_bf16. 
+ torch_output = torch_moe(a, w1_bf16, w2_bf16, score, topk) + + torch.testing.assert_close(sm12x_output, torch_output, atol=2e-1, rtol=2e-1) + + +if __name__ == "__main__": + # workspace_init is a pytest fixture (unused in the test body); pass None + # when invoking directly so the call does not raise TypeError. + test_flashinfer_cutedsl_sm12x_moe(16, 128, 256, 8, 2, torch.bfloat16, None) diff --git a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py index e414ba7d2cc3..698c679a201c 100644 --- a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py @@ -13,6 +13,7 @@ from vllm.platforms import current_platform from vllm.utils.flashinfer import ( flashinfer_scaled_fp4_mm, + has_flashinfer_b12x_gemm, ) from vllm.utils.torch_utils import set_random_seed @@ -74,7 +75,7 @@ def get_ref_results( @pytest.mark.parametrize("shape", SHAPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("backend", ["cutlass", "cudnn", "trtllm"]) +@pytest.mark.parametrize("backend", ["cutlass", "cudnn", "trtllm", "b12x"]) @pytest.mark.parametrize("autotune", [False, True]) @torch.inference_mode() def test_flashinfer_nvfp4_gemm( @@ -87,6 +88,10 @@ ) -> None: if "trtllm" in backend and dtype == torch.float16: pytest.skip("Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations") + if backend == "b12x" and not current_platform.has_device_capability(120): + pytest.skip("b12x FP4 GEMM requires SM120+ (CC 12.0+)") + if backend == "b12x" and not has_flashinfer_b12x_gemm(): + pytest.skip("b12x FP4 GEMM backend not available in installed FlashInfer") set_random_seed(seed) m, n, packed_k = shape @@ -105,8 +110,7 @@ def test_flashinfer_nvfp4_gemm( # ops.scaled_fp4_quant returns swizzled scales, while weights # from checkpoints are in linear scales. - # So instead of needing to swizzle for cutlass as in modelopt.py, - # we need to unswizzle for trtllm here. 
+ # cutlass and b12x use swizzled scales directly; trtllm needs them unswizzled. a_fp4, a_scale_interleaved = ops.scaled_fp4_quant( a_dtype, a_global_scale, is_sf_swizzled_layout=True, backend=backend ) diff --git a/vllm/config/kernel.py b/vllm/config/kernel.py index f3ffbe4e8b19..9e74515fe7e1 100644 --- a/vllm/config/kernel.py +++ b/vllm/config/kernel.py @@ -113,6 +113,7 @@ def with_default( "flashinfer_trtllm", "flashinfer_cutlass", "flashinfer_cutedsl", + "flashinfer_cutedsl_sm12x", "marlin", "aiter", ] @@ -141,6 +142,7 @@ class KernelConfig: - "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels - "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels - "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only) + - "flashinfer_cutedsl_sm12x": Use FlashInfer CuteDSL fused MoE for SM12x (RTX Pro 6000 / DGX Spark) - "marlin": Use Marlin kernels (weight-only quantization) - "aiter": Use AMD AITer kernels (ROCm only)""" diff --git a/vllm/envs.py b/vllm/envs.py index 8ed1d33434cb..0a85ac9e52dd 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1466,6 +1466,7 @@ def _get_or_set_default() -> str: "VLLM_NVFP4_GEMM_BACKEND", None, [ + "flashinfer-b12x", "flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass", diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py index 6cbb65e26d66..6f56e22bee1c 100644 --- a/vllm/model_executor/kernels/linear/__init__.py +++ b/vllm/model_executor/kernels/linear/__init__.py @@ -88,6 +88,7 @@ FbgemmNvFp4LinearKernel, ) from vllm.model_executor.kernels.linear.nvfp4.flashinfer import ( + FlashInferB12xNvFp4LinearKernel, FlashInferCudnnNvFp4LinearKernel, FlashInferCutlassNvFp4LinearKernel, FlashInferTrtllmNvFp4LinearKernel, @@ -258,6 +259,7 @@ _POSSIBLE_NVFP4_KERNELS: dict[PlatformEnum, list[type[NvFp4LinearKernel]]] = { PlatformEnum.CUDA: [ + FlashInferB12xNvFp4LinearKernel, FlashInferCutlassNvFp4LinearKernel, CutlassNvFp4LinearKernel, MarlinNvFp4LinearKernel, @@ 
-589,6 +591,7 @@ def init_wfp8_a16_linear_kernel( # Maps VLLM_NVFP4_GEMM_BACKEND env var values to kernel classes. _NVFP4_BACKEND_TO_KERNEL: dict[str, type[NvFp4LinearKernel]] = { + "flashinfer-b12x": FlashInferB12xNvFp4LinearKernel, "flashinfer-cutlass": FlashInferCutlassNvFp4LinearKernel, "cutlass": CutlassNvFp4LinearKernel, "marlin": MarlinNvFp4LinearKernel, @@ -766,6 +769,7 @@ def register_linear_kernel( "CutlassNvFp4LinearKernel", "EmulationNvFp4LinearKernel", "FbgemmNvFp4LinearKernel", + "FlashInferB12xNvFp4LinearKernel", "FlashInferCutlassNvFp4LinearKernel", "FlashInferTrtllmNvFp4LinearKernel", "FlashInferCudnnNvFp4LinearKernel", diff --git a/vllm/model_executor/kernels/linear/nvfp4/flashinfer.py b/vllm/model_executor/kernels/linear/nvfp4/flashinfer.py index 399bc3dd2785..25e1c7c91f96 100644 --- a/vllm/model_executor/kernels/linear/nvfp4/flashinfer.py +++ b/vllm/model_executor/kernels/linear/nvfp4/flashinfer.py @@ -11,7 +11,11 @@ swizzle_blockscale, ) from vllm.platforms import current_platform -from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer +from vllm.utils.flashinfer import ( + flashinfer_scaled_fp4_mm, + has_flashinfer, + has_flashinfer_b12x_gemm, +) from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig @@ -216,3 +220,66 @@ def apply_weights( if bias is not None: out = out + bias return out.view(*output_shape) + + +class FlashInferB12xNvFp4LinearKernel(NvFp4LinearKernel): + """NVFP4 GEMM via FlashInfer's b12x CuTe DSL warp-level MMA kernel (SM120+).""" + + @classmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + if current_platform.has_device_capability(120) and has_flashinfer_b12x_gemm(): + return True, None + return False, "FlashInfer b12x requires SM120+ and FlashInfer with Sm120BlockScaledDenseGemmKernel" + + @classmethod + def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]: + return True, None + + def 
process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight_scale = torch.nn.Parameter( + swizzle_blockscale(layer.weight_scale.data), requires_grad=False + ) + padded_weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass( + layer.weight.data + ) + layer.weight = torch.nn.Parameter(padded_weight, requires_grad=False) + layer.weights_padding_cols = weights_padding_cols + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + output_size = layer.output_size_per_partition + output_dtype = x.dtype + output_shape = [*x.shape[:-1], output_size] + + x_fp4, x_blockscale = scaled_fp4_quant( + x, + layer.input_global_scale_inv, + is_sf_swizzled_layout=True, + backend="b12x", + ) + + x_fp4 = pad_nvfp4_activation_for_cutlass( + x_fp4, getattr(layer, "weights_padding_cols", 0) + ) + + out = flashinfer_scaled_fp4_mm( + x_fp4, + layer.weight, + x_blockscale, + layer.weight_scale, + layer.alpha, + output_dtype, + backend="b12x", + ) + + out = slice_nvfp4_output(out, output_size) + + if bias is not None: + out = out + bias + return out.view(*output_shape) diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py index 5ce58220b073..c86cf0aedec0 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py @@ -11,6 +11,7 @@ FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -20,8 +21,14 @@ ) from vllm.platforms import current_platform from vllm.utils.flashinfer import ( + flashinfer_convert_sf_to_mma_layout, flashinfer_cute_dsl_fused_moe_nvfp4, + flashinfer_cutedsl_grouped_gemm_nt_masked, + 
has_flashinfer_cutedsl_grouped_gemm_nt_masked, has_flashinfer_cutedsl_moe_nvfp4, + has_flashinfer_cutedsl_sm12x_moe, + scaled_fp4_grouped_quantize, + silu_and_mul_scaled_nvfp4_experts_quantize, ) @@ -170,3 +177,384 @@ def apply( local_expert_offset=self.local_expert_offset, moe_output=output, ) + + +def get_cute_dtype(input: torch.Tensor) -> str: + if input.dtype == torch.bfloat16: + return "bfloat16" + elif input.dtype == torch.float16: + return "float16" + elif input.dtype == torch.float32: + return "float32" + else: + raise ValueError(f"Unsupported cute dtype {input.dtype}") + + +def flashinfer_cutedsl_moe_masked( + hidden_states: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + input_global_scale: torch.Tensor, + w1: torch.Tensor, + w1_blockscale: torch.Tensor, + w1_alpha, + w2: torch.Tensor, + a2_global_scale: torch.Tensor, + w2_blockscale: torch.Tensor, + w2_alpha, + masked_m: torch.Tensor, + workspace: torch.Tensor, + out: torch.Tensor, +): + """ + Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL + kernels. + + Args: + hidden_states: Either of the following case + * torch.Tensor: [num_experts, m, k], bf16 + * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2], + uint8, [num_experts, m, k // 16], float8_e4m3fn + input_global_scale (torch.Tensor): (l,) + w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8 + w1_blockscale (torch.Tensor): blockscale factors, e4m3, + w1_alpha (torch.Tensor): (l,) + w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8 + a2_global_scale (torch.Tensor): (l,) + w2_blockscale (torch.Tensor): blockscale factors, e4m3, + w2_alpha (torch.Tensor): (l,) + masked_m (torch.Tensor): Masked dimension indices + workspace (torch.Tensor): For gateup_output + + Notes: + - Assumes max(masked_m) <= m. 
+ """ + + # === Assertions on dtypes === + assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}" + assert w1_blockscale.dtype == torch.float8_e4m3fn, ( + f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}" + ) + assert w1_alpha.dtype == torch.float32, ( + f"w1_alpha must be float32, got {w1_alpha.dtype}" + ) + assert w2.dtype == torch.uint8, f"w2 must be uint8, got {w2.dtype}" + assert a2_global_scale.dtype == torch.float32, ( + f"a2_global_scale must be float32, got {a2_global_scale.dtype}" + ) + assert w2_blockscale.dtype == torch.float8_e4m3fn, ( + f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}" + ) + assert w2_alpha.dtype == torch.float32, ( + f"w2_alpha must be float32, got {w2_alpha.dtype}" + ) + + # === Assertions on shapes === + n = w2.shape[-1] * 2 # intermediate dimension + if isinstance(hidden_states, tuple): + assert input_global_scale is None, ( + "input_global_scale must be None when hidden_states is already quantized" + ) + + aq = hidden_states[0].view(torch.uint8) + aq_sf = hidden_states[1].view(torch.float8_e4m3fn) + # m, k_by_2, num_experts = aq.shape + num_experts, m, k_by_2 = aq.shape + k = k_by_2 * 2 + aq = aq.permute(1, 2, 0) + else: + num_experts, m, k = hidden_states.shape + + assert input_global_scale.dtype == torch.float32, ( + f"input_global_scale must be float32, got {input_global_scale.dtype}" + ) + assert input_global_scale.shape == (num_experts,), ( + f"input_global_scale must be (l,), got {input_global_scale.shape}" + ) + + aq, aq_sf = scaled_fp4_grouped_quantize( + hidden_states, + masked_m, + input_global_scale, + ) + + assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}" + assert w1.shape[-1] * 2 == k, ( + f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}" + ) + assert w2.shape[-2:] == ( + k, + n // 2, + ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}" + + assert w1_alpha.shape == (num_experts,), ( + f"w1_alpha must be (l,), got {w1_alpha.shape}" + ) + 
assert a2_global_scale.shape == (num_experts,), ( + f"a2_global_scale must be (l,), got {a2_global_scale.shape}" + ) + assert w2_alpha.shape == (num_experts,), ( + f"w2_alpha must be (l,), got {w2_alpha.shape}" + ) + + workspace = workspace.permute(1, 2, 0) # requirement of kernel + sf_vec_size = 16 + assert aq_sf.dtype == torch.float8_e4m3fn + assert aq.dtype == torch.uint8 + ab_dtype = "float4_e2m1fn" + sf_dtype = "float8_e4m3fn" + + if isinstance(hidden_states, tuple): + c_dtype = "bfloat16" + else: + c_dtype = get_cute_dtype(hidden_states) + + # Gemm1 + flashinfer_cutedsl_grouped_gemm_nt_masked( + (aq, aq_sf), + (w1.permute(1, 2, 0), w1_blockscale), + workspace, + masked_m, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + sf_vec_size=sf_vec_size, + alpha=w1_alpha.view(1, 1, num_experts), + alpha_dtype=get_cute_dtype(w1_alpha), + ) # in logical [m, n, l] + + # SILU and quantization + diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize( + workspace.permute(2, 0, 1), + masked_m, + a2_global_scale, + ) + + # Gemm2 + out = out.permute(1, 2, 0) # requirement of kernel + flashinfer_cutedsl_grouped_gemm_nt_masked( + (diq, diq_sf), + (w2.permute(1, 2, 0), w2_blockscale), + out, + masked_m, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + sf_vec_size=sf_vec_size, + alpha=w2_alpha.view(1, 1, num_experts), + alpha_dtype=get_cute_dtype(w2_alpha), + ) # in logical [m, k, l] + out = out.permute(2, 0, 1) + + +class FlashInferCuteDSLSM12xExperts(mk.FusedMoEExpertsModular): + """FlashInfer CuteDSL fused MoE expert for SM12x (SM120/SM121, RTX Pro 6000 / DGX Spark). + + Uses ``cute_dsl_fused_moe_nvfp4`` from FlashInfer PR #3066 which fuses + token dispatch, two GEMMs, SwiGLU activation, and topk-weight reduction + into a single kernel call. Input quantization (BF16→FP4) is performed + inside the kernel so BF16 hidden states are passed directly. 
+ + Weight scale factors are converted to the MMA layout produced by + ``convert_sf_to_mma_layout`` once during ``process_weights_after_loading`` + and cached as ``w1_sf_mma`` / ``w2_sf_mma``. + + Only NVFP4 (kNvfp4Static/kNvfp4Dynamic) quantization is supported. + """ + + def __init__( + self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + ): + super().__init__(moe_config=moe_config, quant_config=quant_config) + assert quant_config.quant_dtype == "nvfp4", ( + "FlashInferCuteDSLSM12xExperts only supports nvfp4 quantization." + ) + self.out_dtype = moe_config.in_dtype + self.num_local_experts = moe_config.num_local_experts + self.ep_rank = moe_config.moe_parallel_config.ep_rank + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Normalise block scales to absorb the per-expert weight global scale + # (w_gs). vLLM's NVFP4 convention stores: + # block_scale = max_abs * w_gs / fp4_max, g1_alphas = 1/w_gs + # The SM12x kernel treats w1_alpha (= g1_alphas) as a per-expert weight + # dequant multiplier separate from input_gs (activation scale). We bake + # w_gs into the block scales so that w1_alpha = 1.0 and the kernel sees + # the simpler form: + # block_scale = max_abs / fp4_max, w1_alpha = 1.0 + # The FP4-packed values and dequantised results are identical in both + # representations. We set scale_2 = 1.0 to signal that the bake-in is + # already done. + layer.w13_weight_scale.data = ( + layer.w13_weight_scale.float() + * layer.w13_weight_scale_2.view(-1, 1, 1) + ).to(layer.w13_weight_scale.dtype) + layer.w13_weight_scale_2.data.fill_(1.0) + + layer.w2_weight_scale.data = ( + layer.w2_weight_scale.float() + * layer.w2_weight_scale_2.view(-1, 1, 1) + ).to(layer.w2_weight_scale.dtype) + layer.w2_weight_scale_2.data.fill_(1.0) + + # The SM12x kernel uses dynamic per-block quantization for FC2 input + # activations (the SwiGLU output before the down projection). 
The + # calibrated a2_gscale from the modelopt checkpoint (~tens to hundreds) + # is intended for static-quantisation backends (TRTLLM/CUTLASS) and + # causes every intermediate activation to saturate at max FP4 when + # multiplied by values that large. Keep the original a2_gscale intact + # and store a separate ones tensor to pass to the SM12x kernel so it + # uses its own per-block dynamic scale. + if self.a2_gscale is not None: + self.a2_gscale_ones = torch.ones_like(self.a2_gscale) + else: + self.a2_gscale_ones = None + + # Precompute MMA-layout views of the weight scale factors once here + # rather than recomputing on every forward pass. + num_experts_w1, m1, k1_sf = self.w1_scale.shape + k1 = k1_sf * 16 + self.w1_sf_mma = flashinfer_convert_sf_to_mma_layout( + self.w1_scale.reshape(num_experts_w1 * m1, k1_sf), + m=m1, + k=k1, + num_groups=num_experts_w1, + ) + + num_experts_w2, m2, k2_sf = self.w2_scale.shape + k2 = k2_sf * 16 + self.w2_sf_mma = flashinfer_convert_sf_to_mma_layout( + self.w2_scale.reshape(num_experts_w2 * m2, k2_sf), + m=m2, + k=k2, + num_groups=num_experts_w2, + ) + + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + p = current_platform + return ( + p.is_cuda() + and p.is_device_capability_family(120) + and has_flashinfer_cutedsl_sm12x_moe() + ) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + return (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic) + + @staticmethod + def _supports_activation(activation: MoEActivation) -> bool: + return activation == MoEActivation.SILU + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True + + def supports_expert_map(self) -> bool: + return False + + 
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # cute_dsl_fused_moe_nvfp4 applies topk weights internally. + return TopKWeightAndReduceNoOP() + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: MoEActivation, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + # cute_dsl_fused_moe_nvfp4 manages its own internal workspace. + workspace1 = (1,) + workspace2 = (0,) + output_shape = (M, K) + return (workspace1, workspace2, output_shape) + + @property + def expects_unquantized_inputs(self) -> bool: + # cute_dsl_fused_moe_nvfp4 expects BF16 hidden states and performs + # its own FP4 quantization internally. Returning True prevents the + # modular kernel from pre-quantizing activations, which would produce + # an FP4-packed tensor with size(-1)=k//2 and break the scale-factor + # conversion that expects size(-1)=k. 
+ return True + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor | None, + workspace2: torch.Tensor | None, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool | None, + ): + assert self.w1_scale is not None and self.w2_scale is not None, ( + "w1_scale and w2_scale must not be None for FlashInferCuteDSLSM12xExperts" + ) + assert self.g1_alphas is not None and self.g2_alphas is not None, ( + "g1_alphas and g2_alphas must not be None for FlashInferCuteDSLSM12xExperts" + ) + assert self.a2_gscale is not None, ( + "a2_gscale must not be None for FlashInferCuteDSLSM12xExperts" + ) + + top_k = topk_ids.shape[1] + local_expert_offset = self.ep_rank * self.num_local_experts + + # x_sf is ignored by the SM12x kernel (quantization is fused + # internally), but the API requires a tensor argument. + x_sf_placeholder = ( + a1q_scale + if a1q_scale is not None + else hidden_states.new_zeros(1, dtype=torch.float8_e4m3fn) + ) + + # TODO: Use the plan/run() API from FlashInfer PR #3066 instead of + # calling cute_dsl_fused_moe_nvfp4 directly. The plan object can be + # created once in __init__ (shapes are fixed for MoE layers) and + # plan.run() called here, avoiding workspace allocation and kernel + # parameter setup overhead on every forward pass. 
+ flashinfer_cute_dsl_fused_moe_nvfp4( + x=hidden_states, + x_sf=x_sf_placeholder, + token_selected_experts=topk_ids.to(torch.int32), + token_final_scales=topk_weights.float(), + w1_weight=w1, + w1_weight_sf=self.w1_sf_mma, + w1_alpha=self.g1_alphas, + fc2_input_scale=self.a2_gscale_ones, + w2_weight=w2, + w2_weight_sf=self.w2_sf_mma, + w2_alpha=self.g2_alphas, + num_experts=global_num_experts, + top_k=top_k, + num_local_experts=self.num_local_experts, + local_expert_offset=local_expert_offset, + output_dtype=self.out_dtype, + moe_output=output, + ) diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index 597d784d3b63..2d6accb55f4e 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -43,6 +43,7 @@ class NvFp4MoeBackend(Enum): FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS" FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL" FLASHINFER_CUTEDSL_BATCHED = "FLASHINFER_CUTEDSL_BATCHED" + FLASHINFER_CUTEDSL_SM12X = "FLASHINFER_CUTEDSL_SM12X" VLLM_CUTLASS = "VLLM_CUTLASS" MARLIN = "MARLIN" @@ -52,6 +53,7 @@ class NvFp4MoeBackend(Enum): NvFp4MoeBackend.FLASHINFER_CUTLASS, NvFp4MoeBackend.FLASHINFER_CUTEDSL, NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED, + NvFp4MoeBackend.FLASHINFER_CUTEDSL_SM12X, ] fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = { @@ -105,6 +107,13 @@ def backend_to_kernel_cls( return [FlashInferCuteDSLBatchedExperts] + elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL_SM12X: + from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_moe import ( # noqa: E501 + FlashInferCuteDSLSM12xExperts, + ) + + return [FlashInferCuteDSLSM12xExperts] + elif backend == NvFp4MoeBackend.VLLM_CUTLASS: from vllm.model_executor.layers.fused_moe.cutlass_moe import ( CutlassExpertsFp4, @@ -129,6 +138,7 @@ def map_nvfp4_backend(runner_backend: MoEBackend) -> NvFp4MoeBackend: "flashinfer_trtllm": 
NvFp4MoeBackend.FLASHINFER_TRTLLM, "flashinfer_cutlass": NvFp4MoeBackend.FLASHINFER_CUTLASS, "flashinfer_cutedsl": NvFp4MoeBackend.FLASHINFER_CUTEDSL, + "flashinfer_cutedsl_sm12x": NvFp4MoeBackend.FLASHINFER_CUTEDSL_SM12X, "marlin": NvFp4MoeBackend.MARLIN, } if backend := mapping.get(runner_backend): @@ -154,6 +164,7 @@ def select_nvfp4_moe_backend( NvFp4MoeBackend.FLASHINFER_TRTLLM, NvFp4MoeBackend.FLASHINFER_CUTEDSL, NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED, + NvFp4MoeBackend.FLASHINFER_CUTEDSL_SM12X, NvFp4MoeBackend.FLASHINFER_CUTLASS, NvFp4MoeBackend.VLLM_CUTLASS, NvFp4MoeBackend.MARLIN, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index ef0bf2bf7aca..6a8791e11e98 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -316,7 +316,9 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass( NvFp4MoeBackend.VLLM_CUTLASS, NvFp4MoeBackend.FLASHINFER_CUTLASS, NvFp4MoeBackend.FLASHINFER_TRTLLM, + NvFp4MoeBackend.FLASHINFER_CUTEDSL, NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED, + NvFp4MoeBackend.FLASHINFER_CUTEDSL_SM12X, ] # Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels. 
@@ -328,6 +330,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass( in [ NvFp4MoeBackend.FLASHINFER_CUTLASS, NvFp4MoeBackend.FLASHINFER_TRTLLM, + NvFp4MoeBackend.FLASHINFER_CUTEDSL_SM12X, ] ): w13, w13_scale = reorder_w1w3_to_w3w1(w13, w13_scale) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index ed171db96e73..6833bf0ed252 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -117,6 +117,12 @@ def wrapper(*args, **kwargs): flashinfer_cutedsl_grouped_gemm_nt_masked = _lazy_import_wrapper( "flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked" ) +flashinfer_cute_dsl_fused_moe_nvfp4 = _lazy_import_wrapper( + "flashinfer.fused_moe", "cute_dsl_fused_moe_nvfp4" +) +flashinfer_convert_sf_to_mma_layout = _lazy_import_wrapper( + "flashinfer.cute_dsl.utils", "convert_sf_to_mma_layout" +) flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize") nvfp4_batched_quantize = _lazy_import_wrapper("flashinfer", "nvfp4_batched_quantize") silu_and_mul_scaled_nvfp4_experts_quantize = _lazy_import_wrapper( @@ -128,12 +134,6 @@ def wrapper(*args, **kwargs): nvfp4_block_scale_interleave = _lazy_import_wrapper( "flashinfer.fp4_quantization", "block_scale_interleave" ) -flashinfer_cute_dsl_fused_moe_nvfp4 = _lazy_import_wrapper( - "flashinfer", "cute_dsl_fused_moe_nvfp4" -) -flashinfer_convert_sf_to_mma_layout = _lazy_import_wrapper( - "flashinfer.cute_dsl.utils", "convert_sf_to_mma_layout" -) trtllm_fp4_block_scale_moe = _lazy_import_wrapper( "flashinfer", "trtllm_fp4_block_scale_moe" ) @@ -267,6 +267,33 @@ def has_flashinfer_cutedsl_moe_nvfp4() -> bool: return mod is not None and hasattr(mod, "cute_dsl_fused_moe_nvfp4") +@functools.cache +def has_flashinfer_b12x_gemm() -> bool: + """Return True if FlashInfer b12x FP4 GEMM backend is available (SM120+).""" + if not has_flashinfer_cutedsl(): + return False + mod = _get_submodule("flashinfer.gemm") + return mod is not None and hasattr(mod, "Sm120BlockScaledDenseGemmKernel") + 
+ +@functools.cache +def has_flashinfer_cutedsl_sm12x_moe() -> bool: + """Return ``True`` if FlashInfer CuteDSL SM12x fused MoE is available.""" + if not has_flashinfer_moe(): + return False + + required_functions = [ + ("flashinfer.fused_moe", "cute_dsl_fused_moe_nvfp4"), + ("flashinfer.cute_dsl.utils", "convert_sf_to_mma_layout"), + ] + + for module_name, attr_name in required_functions: + mod = _get_submodule(module_name) + if not mod or not hasattr(mod, attr_name): + return False + return True + + @functools.cache def has_nvidia_artifactory() -> bool: """Return `True` if NVIDIA's artifactory is accessible. @@ -792,6 +819,8 @@ def should_use_flashinfer_for_blockscale_fp8_gemm( "has_flashinfer_cutlass_fused_moe", "has_flashinfer_cutedsl_grouped_gemm_nt_masked", "has_flashinfer_cutedsl_moe_nvfp4", + "has_flashinfer_cutedsl_sm12x_moe", + "has_flashinfer_b12x_gemm", "has_flashinfer_fp8_blockscale_gemm", "has_nvidia_artifactory", "supports_trtllm_attention",