Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 208 additions & 0 deletions tests/kernels/moe/test_cutedsl_sm12x_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.platforms import current_platform

# Module-level guards: these tests only make sense on SM120-family GPUs with a
# sufficiently new FlashInfer build, so skip collection of the whole module
# otherwise (``allow_module_level=True``).
if not current_platform.is_device_capability_family(120):
    pytest.skip(
        reason="FlashInfer CuteDSL SM12x MoE requires SM120 (RTX Pro 6000 / DGX Spark).",
        allow_module_level=True,
    )

# Imported after the capability check so the probe itself cannot crash on
# unsupported platforms.
from vllm.utils.flashinfer import has_flashinfer_cutedsl_sm12x_moe

if not has_flashinfer_cutedsl_sm12x_moe():
    pytest.skip(
        reason=(
            "FlashInfer cute_dsl_fused_moe_nvfp4 / convert_sf_to_mma_layout "
            "not available in installed FlashInfer (needs PRs #3051 and #3066)."
        ),
        allow_module_level=True,
    )

# Import fp4_quantize after the skip guard — FlashInfer must be installed.
from flashinfer.fp4_quantization import fp4_quantize

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config
from tests.kernels.utils import torch_moe
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_moe import (
FlashInferCuteDSLSM12xExperts,
)
from vllm.utils.torch_utils import set_random_seed

# Dimensions chosen to satisfy FP4 alignment requirements (k multiple of 256,
# n multiple of 128) while keeping tests fast.
# Each entry is (m, n, k): token count, intermediate size, hidden size.
MNK_FACTORS: list[tuple[int, int, int]] = [
    (2, 128, 256),
    (2, 256, 512),
    (16, 128, 256),
    (64, 256, 512),
]


def _reorder_gate_up_to_up_gate(
w: torch.Tensor,
w_s: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Swap gate and up-projection halves along dim=1 to [up, gate] order.

The SM12x kernel expects weights in [up (w3), gate (w1)] order while the
BF16 reference uses [gate (w1), up (w3)]. This replicates the reordering
done at model-load time by ``prepare_nvfp4_moe_layer_for_fi_or_cutlass``.
"""
n = w.shape[1] // 2
return (
torch.cat([w[:, n:, :], w[:, :n, :]], dim=1),
torch.cat([w_s[:, n:, :], w_s[:, :n, :]], dim=1),
)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [8, 16])
@pytest.mark.parametrize("topk", [1, 2, 4])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_flashinfer_cutedsl_sm12x_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    # pytest fixture defined elsewhere; never read in this body — presumably
    # it initializes shared kernel workspaces as a side effect. TODO confirm.
    workspace_init,
):
    """Test FlashInferCuteDSLSM12xExperts against a BF16 torch reference.

    The SM12x kernel takes BF16 hidden states directly and fuses token
    dispatch, W1 GEMM, SwiGLU, and W2 GEMM into one call. We verify
    correctness against ``torch_moe`` using generous tolerances to account
    for the internal FP4 quantization of activations and weights.

    Scale convention
    ----------------
    The SM12x kernel uses ``w1_alpha`` as *both* the activation-quantisation
    global scale and the weight dequantisation factor. These two roles are
    conflated into a single parameter in ``launch_sm120_moe``, so they must
    equal the same value. We use ``global_scale = 1.0`` for
    ``fp4_quantize`` so that ``w1_alpha = ones`` satisfies both roles
    simultaneously. The alternative — vLLM's convention of baking a large
    ``w_gs`` into block-scale values and compensating with
    ``g1_alphas = 1/w_gs`` — is incompatible with this kernel.
    """
    set_random_seed(7)
    with set_current_vllm_config(
        VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
    ):
        # Small-magnitude activations keep values well inside FP4's dynamic
        # range so quantisation error stays within the test tolerances.
        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10

        # Generate BF16 reference weights in [gate, up] order.
        # Shape: w1=(e, 2n, k), w2=(e, k, n).
        w1_bf16 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 15
        w2_bf16 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 15

        # ------------------------------------------------------------------ #
        # Quantise weights for the SM12x kernel using FlashInfer's convention:
        #   global_scale = 1.0  →  block_scale = max_abs_block / fp4_max
        #   w1_alpha = 1.0 (no extra global factor to compensate)
        #
        # The scale factors returned by fp4_quantize(..., is_sf_swizzled_layout=True)
        # are already in the swizzled 2D layout expected by convert_sf_to_mma_layout.
        # No additional swizzle_blockscale() call is needed.
        # ------------------------------------------------------------------ #
        gs = torch.ones(1, device="cuda", dtype=torch.float32)
        sf_vec_size = 16  # FP4 block-scale granularity (elements per scale)

        # W1: reorder BF16 from [gate, up] → [up, gate], then quantise.
        w1_reordered = torch.cat(
            [w1_bf16[:, n:, :], w1_bf16[:, :n, :]], dim=1
        )  # shape (e, 2n, k), [up, gate]
        # Quantise all experts in one flattened call, then view back per-expert.
        w1_flat = w1_reordered.reshape(e * 2 * n, k)
        w1_q_flat, w1_sf_flat = fp4_quantize(
            w1_flat,
            global_scale=gs,
            sf_vec_size=sf_vec_size,
            is_sf_swizzled_layout=True,
        )
        w1_q = w1_q_flat.view(e, 2 * n, k // 2)  # uint8, packed FP4
        w1_blockscale = w1_sf_flat.view(e, 2 * n, w1_sf_flat.shape[1])  # float8

        # W2: no row reordering needed for the down-projection.
        w2_flat = w2_bf16.reshape(e * k, n)
        w2_q_flat, w2_sf_flat = fp4_quantize(
            w2_flat,
            global_scale=gs,
            sf_vec_size=sf_vec_size,
            is_sf_swizzled_layout=True,
        )
        w2_q = w2_q_flat.view(e, k, n // 2)  # uint8, packed FP4
        w2_blockscale = w2_sf_flat.view(e, k, w2_sf_flat.shape[1])  # float8

        # All per-expert alphas are 1.0 (global_scale = 1.0, no compensation).
        ones_e = torch.ones(e, device="cuda", dtype=torch.float32)

        quant_config = nvfp4_moe_quant_config(
            g1_alphas=ones_e,
            g2_alphas=ones_e,
            a1_gscale=ones_e,
            a2_gscale=ones_e,
            w1_scale=w1_blockscale,
            w2_scale=w2_blockscale,
        )

        moe_config = make_dummy_moe_config(
            num_experts=e,
            experts_per_token=topk,
            hidden_dim=k,
            intermediate_size_per_partition=n,
            in_dtype=dtype,
        )

        # Assemble the modular kernel: prepare/finalize stage + SM12x experts.
        kernel = mk.FusedMoEKernel(
            maybe_make_prepare_finalize(
                moe=moe_config,
                quant_config=quant_config,
                allow_new_interface=True,
                use_monolithic=False,
            ),
            FlashInferCuteDSLSM12xExperts(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
            inplace=False,
        )

        # Router: random scores, top-k without renormalisation so the
        # reference and the kernel see identical routing weights.
        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

        sm12x_output = kernel.apply(
            hidden_states=a,
            w1=w1_q,
            w2=w2_q,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            global_num_experts=e,
            activation=MoEActivation.SILU,
            apply_router_weight_on_input=False,
            expert_map=None,
        )

        # Reference: BF16 torch MoE using original [gate, up] BF16 weights.
        # torch_moe's SiluAndMul expects [gate, up] order, matching w1_bf16.
        torch_output = torch_moe(a, w1_bf16, w2_bf16, score, topk)

        # Generous tolerances: both activations and weights pass through FP4.
        torch.testing.assert_close(sm12x_output, torch_output, atol=2e-1, rtol=2e-1)


if __name__ == "__main__":
    # Direct invocation for ad-hoc debugging outside pytest. The function's
    # final parameter, ``workspace_init``, is a pytest fixture that the test
    # body never reads, so ``None`` is a safe stand-in here; omitting it (as
    # the original did) raises TypeError before the test even starts.
    test_flashinfer_cutedsl_sm12x_moe(16, 128, 256, 8, 2, torch.bfloat16, None)
10 changes: 7 additions & 3 deletions tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
flashinfer_scaled_fp4_mm,
has_flashinfer_b12x_gemm,
)
from vllm.utils.torch_utils import set_random_seed

Expand Down Expand Up @@ -74,7 +75,7 @@ def get_ref_results(
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("backend", ["cutlass", "cudnn", "trtllm"])
@pytest.mark.parametrize("backend", ["cutlass", "cudnn", "trtllm", "b12x"])
@pytest.mark.parametrize("autotune", [False, True])
@torch.inference_mode()
def test_flashinfer_nvfp4_gemm(
Expand All @@ -87,6 +88,10 @@ def test_flashinfer_nvfp4_gemm(
) -> None:
if "trtllm" in backend and dtype == torch.float16:
pytest.skip("Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")
if backend == "b12x" and not current_platform.has_device_capability(120):
pytest.skip("b12x FP4 GEMM requires SM120+ (CC 12.0+)")
if backend == "b12x" and not has_flashinfer_b12x_gemm():
pytest.skip("b12x FP4 GEMM backend not available in installed FlashInfer")

set_random_seed(seed)
m, n, packed_k = shape
Expand All @@ -105,8 +110,7 @@ def test_flashinfer_nvfp4_gemm(

# ops.scaled_fp4_quant returns swizzled scales, while weights
# from checkpoints are in linear scales.
# So instead of needing to swizzle for cutlass as in modelopt.py,
# we need to unswizzle for trtllm here.
# cutlass and b12x use swizzled scales directly; trtllm needs them unswizzled.
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
a_dtype, a_global_scale, is_sf_swizzled_layout=True, backend=backend
)
Expand Down
2 changes: 2 additions & 0 deletions vllm/config/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def with_default(
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_cutedsl",
"flashinfer_cutedsl_sm12x",
"marlin",
"aiter",
]
Expand Down Expand Up @@ -141,6 +142,7 @@ class KernelConfig:
- "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels
- "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels
- "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)
- "flashinfer_cutedsl_sm12x": Use FlashInfer CuteDSL fused MoE for SM12x (RTX Pro 6000 / DGX Spark)
- "marlin": Use Marlin kernels (weight-only quantization)
- "aiter": Use AMD AITer kernels (ROCm only)"""

Expand Down
1 change: 1 addition & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1466,6 +1466,7 @@ def _get_or_set_default() -> str:
"VLLM_NVFP4_GEMM_BACKEND",
None,
[
"flashinfer-b12x",
"flashinfer-cudnn",
"flashinfer-trtllm",
"flashinfer-cutlass",
Expand Down
4 changes: 4 additions & 0 deletions vllm/model_executor/kernels/linear/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
FbgemmNvFp4LinearKernel,
)
from vllm.model_executor.kernels.linear.nvfp4.flashinfer import (
FlashInferB12xNvFp4LinearKernel,
FlashInferCudnnNvFp4LinearKernel,
FlashInferCutlassNvFp4LinearKernel,
FlashInferTrtllmNvFp4LinearKernel,
Expand Down Expand Up @@ -258,6 +259,7 @@

_POSSIBLE_NVFP4_KERNELS: dict[PlatformEnum, list[type[NvFp4LinearKernel]]] = {
PlatformEnum.CUDA: [
FlashInferB12xNvFp4LinearKernel,
FlashInferCutlassNvFp4LinearKernel,
CutlassNvFp4LinearKernel,
MarlinNvFp4LinearKernel,
Expand Down Expand Up @@ -589,6 +591,7 @@ def init_wfp8_a16_linear_kernel(

# Maps VLLM_NVFP4_GEMM_BACKEND env var values to kernel classes.
_NVFP4_BACKEND_TO_KERNEL: dict[str, type[NvFp4LinearKernel]] = {
"flashinfer-b12x": FlashInferB12xNvFp4LinearKernel,
"flashinfer-cutlass": FlashInferCutlassNvFp4LinearKernel,
"cutlass": CutlassNvFp4LinearKernel,
"marlin": MarlinNvFp4LinearKernel,
Expand Down Expand Up @@ -766,6 +769,7 @@ def register_linear_kernel(
"CutlassNvFp4LinearKernel",
"EmulationNvFp4LinearKernel",
"FbgemmNvFp4LinearKernel",
"FlashInferB12xNvFp4LinearKernel",
"FlashInferCutlassNvFp4LinearKernel",
"FlashInferTrtllmNvFp4LinearKernel",
"FlashInferCudnnNvFp4LinearKernel",
Expand Down
69 changes: 68 additions & 1 deletion vllm/model_executor/kernels/linear/nvfp4/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
swizzle_blockscale,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer
from vllm.utils.flashinfer import (
flashinfer_scaled_fp4_mm,
has_flashinfer,
has_flashinfer_b12x_gemm,
)

from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig

Expand Down Expand Up @@ -216,3 +220,66 @@ def apply_weights(
if bias is not None:
out = out + bias
return out.view(*output_shape)


class FlashInferB12xNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 GEMM via FlashInfer's b12x CuTe DSL warp-level MMA kernel (SM120+)."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Return (supported, reason): needs an SM120+ device and a FlashInfer
        build that ships the b12x GEMM.

        NOTE(review): the ``compute_capability`` argument is ignored and the
        current platform is queried instead — confirm this matches the
        contract used by the sibling kernels.
        """
        supported = (
            current_platform.has_device_capability(120)
            and has_flashinfer_b12x_gemm()
        )
        if not supported:
            return False, "FlashInfer b12x requires SM120+ and FlashInfer with Sm120BlockScaledDenseGemmKernel"
        return True, None

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Any NVFP4 linear-layer config is accepted by this kernel."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Swizzle the block scales and pad the packed FP4 weight in place.

        Stores the number of padding columns on the layer so apply_weights
        can pad activations to match.
        """
        swizzled_scale = swizzle_blockscale(layer.weight_scale.data)
        layer.weight_scale = torch.nn.Parameter(swizzled_scale, requires_grad=False)

        padded, pad_cols = pad_nvfp4_weight_for_cutlass(layer.weight.data)
        layer.weight = torch.nn.Parameter(padded, requires_grad=False)
        layer.weights_padding_cols = pad_cols

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantise ``x`` to FP4 and run the b12x scaled-MM, then un-pad,
        add the optional bias, and restore the input's leading shape."""
        out_features = layer.output_size_per_partition
        result_shape = [*x.shape[:-1], out_features]

        # Activation quantisation with swizzled block scales, in the layout
        # the b12x backend consumes directly.
        act_fp4, act_blockscale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="b12x",
        )
        # Pad activations by the same column count applied to the weight.
        act_fp4 = pad_nvfp4_activation_for_cutlass(
            act_fp4, getattr(layer, "weights_padding_cols", 0)
        )

        result = flashinfer_scaled_fp4_mm(
            act_fp4,
            layer.weight,
            act_blockscale,
            layer.weight_scale,
            layer.alpha,
            x.dtype,
            backend="b12x",
        )

        # Drop any padded output columns before the bias add.
        result = slice_nvfp4_output(result, out_features)

        if bias is not None:
            result = result + bias
        return result.view(*result_shape)
Loading
Loading