Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 158 additions & 22 deletions tests/kernels/moe/test_flashinfer_b12x_moe.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from types import SimpleNamespace

import pytest
import torch

from vllm.platforms import current_platform

if not current_platform.is_device_capability_family(120):
pytest.skip(
reason="FlashInfer CuteDSL SM12x MoE requires SM120 "
"(RTX Pro 6000 / DGX Spark).",
reason="FlashInfer B12x MoE requires SM120 (RTX Pro 6000 / DGX Spark).",
allow_module_level=True,
)

Expand All @@ -18,8 +19,8 @@
if not has_flashinfer_b12x_moe():
pytest.skip(
reason=(
"FlashInfer cute_dsl_fused_moe_nvfp4 / convert_sf_to_mma_layout "
"not available in installed FlashInfer (needs PRs #3051 and #3066)."
"FlashInfer B12xMoEWrapper not available in installed "
"FlashInfer (needs PR #3080)."
),
allow_module_level=True,
)
Expand All @@ -40,7 +41,6 @@
from vllm.model_executor.layers.fused_moe.experts.flashinfer_b12x_moe import (
FlashInferB12xExperts,
)
from vllm.utils.flashinfer import flashinfer_convert_sf_to_mma_layout
from vllm.utils.torch_utils import set_random_seed

# Dimensions chosen to satisfy FP4 alignment requirements (k multiple of 256,
Expand All @@ -59,7 +59,7 @@ def _reorder_gate_up_to_up_gate(
) -> tuple[torch.Tensor, torch.Tensor]:
"""Swap gate and up-projection halves along dim=1 to [up, gate] order.

The SM12x kernel expects weights in [up (w3), gate (w1)] order while the
The B12x kernel expects weights in [up (w3), gate (w1)] order while the
BF16 reference uses [gate (w1), up (w3)]. This replicates the reordering
done at model-load time by ``prepare_nvfp4_moe_layer_for_fi_or_cutlass``.
"""
Expand All @@ -70,6 +70,22 @@ def _reorder_gate_up_to_up_gate(
)


def _process_b12x_weights(
    experts: FlashInferB12xExperts,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w1_scale_2: torch.Tensor,
    w2_scale_2: torch.Tensor,
) -> None:
    """Run ``experts.process_weights_after_loading`` on a stand-in layer.

    A ``SimpleNamespace`` is used in place of a real layer object. It
    carries exactly four attributes, filled from the arguments:

    * ``w13_weight_scale``   <- ``w1_scale``
    * ``w13_weight_scale_2`` <- ``w1_scale_2``
    * ``w2_weight_scale``    <- ``w2_scale``
    * ``w2_weight_scale_2``  <- ``w2_scale_2``

    Nothing is returned; any effect happens inside
    ``process_weights_after_loading``.
    """
    # Attribute name -> value for the fake layer handed to the experts object.
    scale_attrs = {
        "w13_weight_scale": w1_scale,
        "w13_weight_scale_2": w1_scale_2,
        "w2_weight_scale": w2_scale,
        "w2_weight_scale_2": w2_scale_2,
    }
    experts.process_weights_after_loading(SimpleNamespace(**scale_attrs))


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [8, 16])
@pytest.mark.parametrize("topk", [1, 2, 4])
Expand Down Expand Up @@ -174,22 +190,12 @@ def test_flashinfer_b12x_moe(
moe_config=moe_config,
quant_config=quant_config,
)
# In production, process_weights_after_loading computes these after
# normalizing block scales. In the test the scales are already in final
# form (global_scale=1.0), so we compute the MMA layouts directly.
num_experts_w1, m1, k1_sf = w1_blockscale.shape
experts.w1_sf_mma = flashinfer_convert_sf_to_mma_layout(
w1_blockscale.reshape(num_experts_w1 * m1, k1_sf),
m=m1,
k=k1_sf * 16,
num_groups=num_experts_w1,
)
num_experts_w2, m2, k2_sf = w2_blockscale.shape
experts.w2_sf_mma = flashinfer_convert_sf_to_mma_layout(
w2_blockscale.reshape(num_experts_w2 * m2, k2_sf),
m=m2,
k=k2_sf * 16,
num_groups=num_experts_w2,
_process_b12x_weights(
experts,
w1_blockscale,
w2_blockscale,
ones_e,
ones_e,
)

kernel = mk.FusedMoEKernel(
Expand Down Expand Up @@ -225,5 +231,135 @@ def test_flashinfer_b12x_moe(
torch.testing.assert_close(sm12x_output, torch_output, atol=2e-1, rtol=2e-1)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [8, 16])
@pytest.mark.parametrize("topk", [1, 2, 4])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_flashinfer_b12x_moe_relu2(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
workspace_init,
):
"""Test FlashInferB12xExperts with ReLU2 (non-gated) activation.

ReLU2 is used by Nemotron-H style models. Unlike the gated SiLU
path, w1 has shape [E, N, K] (not [E, 2N, K]) and the activation
is relu(x)^2 without a gate/up split.

Flow: random BF16 activations and weights are created, the weights are
quantized with ``fp4_quantize``, the quantized weights are run through an
``mk.FusedMoEKernel`` wrapping ``FlashInferB12xExperts``, and the result is
compared against ``torch_moe`` evaluated on the original BF16 weights with
``atol = rtol = 2e-1``.

``workspace_init`` is not referenced in the body; presumably it is a pytest
fixture requested for its setup side effect -- TODO confirm.
"""
# Fixed seed so the random activations, weights and router scores repeat.
set_random_seed(7)
# Run under an explicit VllmConfig with pipeline_parallel_size=1.
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
# Input activations: m tokens with hidden size k, scaled down by 10.
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10

# Non-gated: w1 shape is (e, n, k), not (e, 2n, k).
w1_bf16 = torch.randn((e, n, k), device="cuda", dtype=dtype) / 15
w2_bf16 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 15

# Global scale of 1.0 and a scale-factor vector size of 16, both passed to
# fp4_quantize for w1 and w2 below.
gs = torch.ones(1, device="cuda", dtype=torch.float32)
sf_vec_size = 16

# W1: no gate/up reordering for non-gated.
# All experts are flattened into one 2-D matrix, quantized in a single call
# (requesting the swizzled scale-factor layout), then viewed back per expert.
w1_flat = w1_bf16.reshape(e * n, k)
w1_q_flat, w1_sf_flat = fp4_quantize(
w1_flat,
global_scale=gs,
sf_vec_size=sf_vec_size,
is_sf_swizzled_layout=True,
)
# Quantized weights are viewed with a last dimension of k // 2; the block
# scales keep whatever column count fp4_quantize returned.
w1_q = w1_q_flat.view(e, n, k // 2)
w1_blockscale = w1_sf_flat.view(e, n, w1_sf_flat.shape[1])

# W2: same flatten / quantize / view sequence with the roles of n and k swapped.
w2_flat = w2_bf16.reshape(e * k, n)
w2_q_flat, w2_sf_flat = fp4_quantize(
w2_flat,
global_scale=gs,
sf_vec_size=sf_vec_size,
is_sf_swizzled_layout=True,
)
w2_q = w2_q_flat.view(e, k, n // 2)
w2_blockscale = w2_sf_flat.view(e, k, w2_sf_flat.shape[1])

# One float32 value of 1.0 per expert, reused for every alpha / global-scale
# argument below.
ones_e = torch.ones(e, device="cuda", dtype=torch.float32)

# Quantization config: all alphas and activation global scales are ones; the
# block scales come from the fp4_quantize calls above.
quant_config = nvfp4_moe_quant_config(
g1_alphas=ones_e,
g2_alphas=ones_e,
a1_gscale=ones_e,
a2_gscale=ones_e,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
)

# MoE config for the non-gated path: RELU2_NO_MUL with is_act_and_mul=False.
moe_config = make_dummy_moe_config(
num_experts=e,
experts_per_token=topk,
hidden_dim=k,
intermediate_size_per_partition=n,
in_dtype=dtype,
activation=MoEActivation.RELU2_NO_MUL,
is_act_and_mul=False,
)

experts = FlashInferB12xExperts(
moe_config=moe_config,
quant_config=quant_config,
)
# Pass the block scales (and ones as the secondary scales) to
# experts.process_weights_after_loading through a stand-in layer object.
_process_b12x_weights(
experts,
w1_blockscale,
w2_blockscale,
ones_e,
ones_e,
)

# Combine the prepare/finalize stage with the experts implementation; the
# kernel is built with inplace=False.
kernel = mk.FusedMoEKernel(
maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
experts,
inplace=False,
)

# Random router scores; top-k selection without renormalizing the weights.
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

# Kernel under test, fed the quantized weights.
b12x_output = kernel.apply(
hidden_states=a,
w1=w1_q,
w2=w2_q,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_num_experts=e,
activation=MoEActivation.RELU2_NO_MUL,
apply_router_weight_on_input=False,
expert_map=None,
)

# Reference result computed from the unquantized BF16 weights.
torch_output = torch_moe(
a,
w1_bf16,
w2_bf16,
score,
topk,
activation=MoEActivation.RELU2_NO_MUL,
)

# Loose tolerances -- presumably to absorb FP4 quantization error (confirm).
torch.testing.assert_close(
b12x_output,
torch_output,
atol=2e-1,
rtol=2e-1,
)


# Script entry point: run one small configuration of test_flashinfer_b12x_moe
# directly, without going through pytest.
if __name__ == "__main__":
test_flashinfer_b12x_moe(16, 128, 256, 8, 2, torch.bfloat16)
5 changes: 4 additions & 1 deletion tests/kernels/moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def make_dummy_moe_config(
hidden_dim: int = 1,
intermediate_size_per_partition: int = 1,
in_dtype: torch.dtype = torch.bfloat16,
activation: MoEActivation = MoEActivation.SILU,
is_act_and_mul: bool = True,
) -> FusedMoEConfig:
"""
This is a dummy config for the mk constructor interface
Expand All @@ -69,7 +71,8 @@ def make_dummy_moe_config(
num_local_experts=num_experts,
num_logical_experts=num_experts,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
activation=MoEActivation.SILU,
activation=activation,
is_act_and_mul=is_act_and_mul,
in_dtype=in_dtype,
device="cuda",
routing_method=RoutingMethodType.TopK,
Expand Down
8 changes: 8 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@
VLLM_FLASHINFER_AUTOTUNE_CACHE_DIR: str | None = None
VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "auto"
VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD: int = 0
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
Expand Down Expand Up @@ -1499,6 +1500,13 @@ def _resolve_rust_frontend_path() -> str | None:
"VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int(
os.getenv("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", str(394 * 1024 * 1024))
),
# When >0 and num_tokens >= threshold, the B12x SM12x MoE wrapper routes
# to cutlass_fused_moe (prefill) instead of the b12x kernels (decode).
# 0 (default) keeps pure b12x dispatch. Requires a FlashInfer build that
# exposes the `cutlass_prefill_threshold` kwarg on B12xMoEWrapper.
"VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD": lambda: int(
os.getenv("VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD", "0")
),
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
# the blockscale tensor of activations NVFP4 Quantization.
Expand Down
Loading
Loading