Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 158 additions & 22 deletions tests/kernels/moe/test_flashinfer_b12x_moe.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from types import SimpleNamespace

import pytest
import torch

from vllm.platforms import current_platform

if not current_platform.is_device_capability_family(120):
pytest.skip(
reason="FlashInfer CuteDSL SM12x MoE requires SM120 "
"(RTX Pro 6000 / DGX Spark).",
reason="FlashInfer B12x MoE requires SM120 (RTX Pro 6000 / DGX Spark).",
allow_module_level=True,
)

Expand All @@ -18,8 +19,8 @@
if not has_flashinfer_b12x_moe():
pytest.skip(
reason=(
"FlashInfer cute_dsl_fused_moe_nvfp4 / convert_sf_to_mma_layout "
"not available in installed FlashInfer (needs PRs #3051 and #3066)."
"FlashInfer B12xMoEWrapper not available in installed "
"FlashInfer (needs PR #3080)."
),
allow_module_level=True,
)
Expand All @@ -40,7 +41,6 @@
from vllm.model_executor.layers.fused_moe.experts.flashinfer_b12x_moe import (
FlashInferB12xExperts,
)
from vllm.utils.flashinfer import flashinfer_convert_sf_to_mma_layout
from vllm.utils.torch_utils import set_random_seed

# Dimensions chosen to satisfy FP4 alignment requirements (k multiple of 256,
Expand All @@ -59,7 +59,7 @@ def _reorder_gate_up_to_up_gate(
) -> tuple[torch.Tensor, torch.Tensor]:
"""Swap gate and up-projection halves along dim=1 to [up, gate] order.

The SM12x kernel expects weights in [up (w3), gate (w1)] order while the
The B12x kernel expects weights in [up (w3), gate (w1)] order while the
BF16 reference uses [gate (w1), up (w3)]. This replicates the reordering
done at model-load time by ``prepare_nvfp4_moe_layer_for_fi_or_cutlass``.
"""
Expand All @@ -70,6 +70,22 @@ def _reorder_gate_up_to_up_gate(
)


def _process_b12x_weights(
experts: FlashInferB12xExperts,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
w1_scale_2: torch.Tensor,
w2_scale_2: torch.Tensor,
) -> None:
layer = SimpleNamespace(
w13_weight_scale=w1_scale,
w13_weight_scale_2=w1_scale_2,
w2_weight_scale=w2_scale,
w2_weight_scale_2=w2_scale_2,
)
experts.process_weights_after_loading(layer)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [8, 16])
@pytest.mark.parametrize("topk", [1, 2, 4])
Expand Down Expand Up @@ -174,22 +190,12 @@ def test_flashinfer_b12x_moe(
moe_config=moe_config,
quant_config=quant_config,
)
# In production, process_weights_after_loading computes these after
# normalizing block scales. In the test the scales are already in final
# form (global_scale=1.0), so we compute the MMA layouts directly.
num_experts_w1, m1, k1_sf = w1_blockscale.shape
experts.w1_sf_mma = flashinfer_convert_sf_to_mma_layout(
w1_blockscale.reshape(num_experts_w1 * m1, k1_sf),
m=m1,
k=k1_sf * 16,
num_groups=num_experts_w1,
)
num_experts_w2, m2, k2_sf = w2_blockscale.shape
experts.w2_sf_mma = flashinfer_convert_sf_to_mma_layout(
w2_blockscale.reshape(num_experts_w2 * m2, k2_sf),
m=m2,
k=k2_sf * 16,
num_groups=num_experts_w2,
_process_b12x_weights(
experts,
w1_blockscale,
w2_blockscale,
ones_e,
ones_e,
)

kernel = mk.FusedMoEKernel(
Expand Down Expand Up @@ -225,5 +231,135 @@ def test_flashinfer_b12x_moe(
torch.testing.assert_close(sm12x_output, torch_output, atol=2e-1, rtol=2e-1)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [8, 16])
@pytest.mark.parametrize("topk", [1, 2, 4])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_flashinfer_b12x_moe_relu2(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
workspace_init,
):
"""Test FlashInferB12xExperts with ReLU2 (non-gated) activation.

ReLU2 is used by Nemotron-H style models. Unlike the gated SiLU
path, w1 has shape [E, N, K] (not [E, 2N, K]) and the activation
is relu(x)^2 without a gate/up split.
"""
set_random_seed(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10

# Non-gated: w1 shape is (e, n, k), not (e, 2n, k).
w1_bf16 = torch.randn((e, n, k), device="cuda", dtype=dtype) / 15
w2_bf16 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 15

gs = torch.ones(1, device="cuda", dtype=torch.float32)
sf_vec_size = 16

# W1: no gate/up reordering for non-gated.
w1_flat = w1_bf16.reshape(e * n, k)
w1_q_flat, w1_sf_flat = fp4_quantize(
w1_flat,
global_scale=gs,
sf_vec_size=sf_vec_size,
is_sf_swizzled_layout=True,
)
w1_q = w1_q_flat.view(e, n, k // 2)
w1_blockscale = w1_sf_flat.view(e, n, w1_sf_flat.shape[1])

w2_flat = w2_bf16.reshape(e * k, n)
w2_q_flat, w2_sf_flat = fp4_quantize(
w2_flat,
global_scale=gs,
sf_vec_size=sf_vec_size,
is_sf_swizzled_layout=True,
)
w2_q = w2_q_flat.view(e, k, n // 2)
w2_blockscale = w2_sf_flat.view(e, k, w2_sf_flat.shape[1])

ones_e = torch.ones(e, device="cuda", dtype=torch.float32)

quant_config = nvfp4_moe_quant_config(
g1_alphas=ones_e,
g2_alphas=ones_e,
a1_gscale=ones_e,
a2_gscale=ones_e,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
)

moe_config = make_dummy_moe_config(
num_experts=e,
experts_per_token=topk,
hidden_dim=k,
intermediate_size_per_partition=n,
in_dtype=dtype,
activation=MoEActivation.RELU2_NO_MUL,
is_act_and_mul=False,
)

experts = FlashInferB12xExperts(
moe_config=moe_config,
quant_config=quant_config,
)
_process_b12x_weights(
experts,
w1_blockscale,
w2_blockscale,
ones_e,
ones_e,
)

kernel = mk.FusedMoEKernel(
maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
experts,
inplace=False,
)

score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

b12x_output = kernel.apply(
hidden_states=a,
w1=w1_q,
w2=w2_q,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_num_experts=e,
activation=MoEActivation.RELU2_NO_MUL,
apply_router_weight_on_input=False,
expert_map=None,
)

torch_output = torch_moe(
a,
w1_bf16,
w2_bf16,
score,
topk,
activation=MoEActivation.RELU2_NO_MUL,
)

torch.testing.assert_close(
b12x_output,
torch_output,
atol=2e-1,
rtol=2e-1,
)


if __name__ == "__main__":
test_flashinfer_b12x_moe(16, 128, 256, 8, 2, torch.bfloat16)
5 changes: 4 additions & 1 deletion tests/kernels/moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def make_dummy_moe_config(
hidden_dim: int = 1,
intermediate_size_per_partition: int = 1,
in_dtype: torch.dtype = torch.bfloat16,
activation: MoEActivation = MoEActivation.SILU,
is_act_and_mul: bool = True,
) -> FusedMoEConfig:
"""
This is a dummy config for the mk constructor interface
Expand All @@ -69,7 +71,8 @@ def make_dummy_moe_config(
num_local_experts=num_experts,
num_logical_experts=num_experts,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
activation=MoEActivation.SILU,
activation=activation,
is_act_and_mul=is_act_and_mul,
in_dtype=in_dtype,
device="cuda",
routing_method=RoutingMethodType.TopK,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
)
from vllm.platforms import current_platform

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The Any type hint is used in the __init__ method (line 82), but it has not been imported from the typing module. This will cause a NameError at runtime when the class is instantiated. Please add the missing import.

from typing import Any
from vllm.platforms import current_platform

from vllm.utils.flashinfer import (
flashinfer_b12x_fused_moe,
flashinfer_convert_sf_to_mma_layout,
has_flashinfer_b12x_moe,
)
Expand All @@ -42,6 +41,11 @@
Only NVFP4 (kNvfp4Static/kNvfp4Dynamic) quantization is supported.
"""

_ACTIVATION_MAP: dict[MoEActivation, str] = {
MoEActivation.SILU: "silu",
MoEActivation.RELU2_NO_MUL: "relu2",
}

def __init__(
self,
moe_config: FusedMoEConfig,
Expand All @@ -55,6 +59,30 @@
self.num_local_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank

# Shape params for B12xMoEWrapper construction.
self.global_num_experts = moe_config.num_experts
self.topk = moe_config.experts_per_token
self.hidden_dim = moe_config.hidden_dim
self.intermediate_size_per_partition = (
moe_config.intermediate_size_per_partition
)
self.max_num_tokens = moe_config.max_num_tokens
self.local_expert_offset = self.ep_rank * self.num_local_experts

activation = moe_config.activation
if activation not in self._ACTIVATION_MAP:
raise ValueError(
f"FlashInferB12xExperts does not support "
f"activation {activation!r}. "
f"Supported: {list(self._ACTIVATION_MAP.keys())}"
)
self._activation_str = self._ACTIVATION_MAP[activation]

# Lazily created on first apply() call.
self._wrapper: Any | None = None

Check failure on line 82 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Name "Any" is not defined [name-defined]

Check failure on line 82 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Name "Any" is not defined [name-defined]

Check failure on line 82 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Name "Any" is not defined [name-defined]

Check failure on line 82 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Name "Any" is not defined [name-defined]

Check failure on line 82 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (F821)

vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py:82:24: F821 Undefined name `Any`

Check failure on line 82 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Name "Any" is not defined [name-defined]

Check failure on line 82 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Name "Any" is not defined [name-defined]

Check failure on line 82 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Name "Any" is not defined [name-defined]

Check failure on line 82 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Name "Any" is not defined [name-defined]

Check failure on line 82 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (F821)

vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py:82:24: F821 Undefined name `Any`
self.w1_sf_mma: torch.Tensor | None = None
self.w2_sf_mma: torch.Tensor | None = None

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Normalise block scales to absorb the per-expert weight global scale
# (w_gs). vLLM's NVFP4 convention stores:
Expand Down Expand Up @@ -124,7 +152,7 @@

@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
return True

@staticmethod
def _supports_quant_scheme(
Expand All @@ -135,7 +163,7 @@

@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation == MoEActivation.SILU
return activation in (MoEActivation.SILU, MoEActivation.RELU2_NO_MUL)

@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
Expand Down Expand Up @@ -167,13 +195,29 @@

@property
def expects_unquantized_inputs(self) -> bool:
# b12x_fused_moe expects BF16 hidden states and performs its own FP4
# B12xMoEWrapper expects BF16 hidden states and performs its own FP4
# quantization internally. Returning True prevents the modular kernel
# from pre-quantizing activations, which would produce an FP4-packed
# tensor with size(-1)=k//2 and break the scale-factor conversion that
# expects size(-1)=k.
# from pre-quantizing activations.
return True

def _ensure_wrapper(self) -> None:
"""Lazily create B12xMoEWrapper on first use."""
if self._wrapper is not None:
return

from flashinfer.fused_moe import B12xMoEWrapper

self._wrapper = B12xMoEWrapper(
num_experts=self.global_num_experts,
top_k=self.topk,
hidden_size=self.hidden_dim,
intermediate_size=self.intermediate_size_per_partition,
use_cuda_graph=True,
max_num_tokens=self.max_num_tokens,
num_local_experts=self.num_local_experts,
activation=self._activation_str,
)

def apply(
self,
output: torch.Tensor,
Expand Down Expand Up @@ -201,23 +245,22 @@
assert self.a2_gscale is not None, (
"a2_gscale must not be None for FlashInferB12xExperts"
)
assert self.w1_sf_mma is not None and self.w2_sf_mma is not None, (
"process_weights_after_loading must run before FlashInferB12xExperts.apply"
)

top_k = topk_ids.shape[1]
self._ensure_wrapper()

flashinfer_b12x_fused_moe(
result = self._wrapper.run(

Check failure on line 254 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Item "None" of "Any | None" has no attribute "run" [union-attr]

Check failure on line 254 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Item "None" of "Any | None" has no attribute "run" [union-attr]

Check failure on line 254 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Item "None" of "Any | None" has no attribute "run" [union-attr]

Check failure on line 254 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Item "None" of "Any | None" has no attribute "run" [union-attr]

Check failure on line 254 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Item "None" of "Any | None" has no attribute "run" [union-attr]

Check failure on line 254 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Item "None" of "Any | None" has no attribute "run" [union-attr]

Check failure on line 254 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Item "None" of "Any | None" has no attribute "run" [union-attr]

Check failure on line 254 in vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py

View workflow job for this annotation

GitHub Actions / pre-commit

Item "None" of "Any | None" has no attribute "run" [union-attr]
x=hidden_states,
token_selected_experts=topk_ids.to(torch.int32),
token_final_scales=topk_weights,
w1_weight=w1,
w1_weight_sf=self.w1_sf_mma,
w1_alpha=self.g1_alphas,
fc2_input_scale=self.a2_gscale,
w2_weight=w2,
w2_weight_sf=self.w2_sf_mma,
w2_alpha=self.g2_alphas,
num_experts=global_num_experts,
top_k=top_k,
num_local_experts=self.num_local_experts,
output_dtype=self.out_dtype,
output=output,
token_selected_experts=topk_ids.to(torch.int32),
token_final_scales=topk_weights,
)
output.copy_(result)
Comment on lines +254 to +266

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The B12xMoEWrapper.run method supports an out parameter, which allows the kernel to write results directly into the provided buffer. Using out=output avoids an extra tensor allocation inside run and a subsequent copy_ operation, which is significantly more efficient for the inference hot path.

Suggested change
result = self._wrapper.run(
x=hidden_states,
token_selected_experts=topk_ids.to(torch.int32),
token_final_scales=topk_weights,
w1_weight=w1,
w1_weight_sf=self.w1_sf_mma,
w1_alpha=self.g1_alphas,
fc2_input_scale=self.a2_gscale,
w2_weight=w2,
w2_weight_sf=self.w2_sf_mma,
w2_alpha=self.g2_alphas,
num_experts=global_num_experts,
top_k=top_k,
num_local_experts=self.num_local_experts,
output_dtype=self.out_dtype,
output=output,
token_selected_experts=topk_ids.to(torch.int32),
token_final_scales=topk_weights,
)
output.copy_(result)
self._wrapper.run(
x=hidden_states,
w1_weight=w1,
w1_weight_sf=self.w1_sf_mma,
w1_alpha=self.g1_alphas,
fc2_input_scale=self.a2_gscale,
w2_weight=w2,
w2_weight_sf=self.w2_sf_mma,
w2_alpha=self.g2_alphas,
token_selected_experts=topk_ids.to(torch.int32),
token_final_scales=topk_weights,
out=output,
)

Loading