Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ RUN --mount=type=cache,target=/opt/uv/cache \
# Install FlashInfer JIT cache (requires CUDA-version-specific index URL)
# https://docs.flashinfer.ai/installation.html
# From versions.json: .flashinfer.version
ARG FLASHINFER_VERSION=0.6.11.post2
ARG FLASHINFER_VERSION=0.6.11.post3
RUN --mount=type=cache,target=/opt/uv/cache \
uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \
--extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
Expand Down
2 changes: 1 addition & 1 deletion docker/versions.json
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
"default": "true"
},
"FLASHINFER_VERSION": {
"default": "0.6.11.post2"
"default": "0.6.11.post3"
},
"GDRCOPY_CUDA_VERSION": {
"default": "12.8"
Expand Down
4 changes: 2 additions & 2 deletions requirements/cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ torchaudio==2.11.0
# These must be updated alongside torch
torchvision==0.26.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.6.11.post2
flashinfer-cubin==0.6.11.post2
flashinfer-python==0.6.11.post3
flashinfer-cubin==0.6.11.post3
apache-tvm-ffi==0.1.9
tilelang==0.1.9
# Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to
Expand Down
233 changes: 207 additions & 26 deletions tests/kernels/moe/test_flashinfer_b12x_moe.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from types import SimpleNamespace

import pytest
import torch

from vllm.platforms import current_platform

if not current_platform.is_device_capability_family(120):
pytest.skip(
reason="FlashInfer CuteDSL SM12x MoE requires SM120 "
"(RTX Pro 6000 / DGX Spark).",
reason="FlashInfer B12x MoE requires SM120 (RTX Pro 6000 / DGX Spark).",
allow_module_level=True,
)

Expand All @@ -18,8 +19,8 @@
if not has_flashinfer_b12x_moe():
pytest.skip(
reason=(
"FlashInfer cute_dsl_fused_moe_nvfp4 / convert_sf_to_mma_layout "
"not available in installed FlashInfer (needs PRs #3051 and #3066)."
"FlashInfer B12xMoEWrapper not available in installed "
"FlashInfer (needs PR #3080)."
),
allow_module_level=True,
)
Expand All @@ -40,7 +41,6 @@
from vllm.model_executor.layers.fused_moe.experts.flashinfer_b12x_moe import (
FlashInferB12xExperts,
)
from vllm.utils.flashinfer import flashinfer_convert_sf_to_mma_layout
from vllm.utils.torch_utils import set_random_seed

# Dimensions chosen to satisfy FP4 alignment requirements (k multiple of 256,
Expand All @@ -59,9 +59,9 @@ def _reorder_gate_up_to_up_gate(
) -> tuple[torch.Tensor, torch.Tensor]:
"""Swap gate and up-projection halves along dim=1 to [up, gate] order.

The SM12x kernel expects weights in [up (w3), gate (w1)] order while the
The B12x kernel expects weights in [up (w3), gate (w1)] order while the
BF16 reference uses [gate (w1), up (w3)]. This replicates the reordering
done at model-load time by ``prepare_nvfp4_moe_layer_for_fi_or_cutlass``.
done at model-load time by the FP4 layer-prep helper.
"""
n = w.shape[1] // 2
return (
Expand All @@ -70,10 +70,27 @@ def _reorder_gate_up_to_up_gate(
)


def _process_b12x_weights(
experts: FlashInferB12xExperts,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
w1_scale_2: torch.Tensor,
w2_scale_2: torch.Tensor,
) -> None:
layer = SimpleNamespace(
w13_weight_scale=w1_scale,
w13_weight_scale_2=w1_scale_2,
w2_weight_scale=w2_scale,
w2_weight_scale_2=w2_scale_2,
)
experts.process_weights_after_loading(layer)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [8, 16])
@pytest.mark.parametrize("topk", [1, 2, 4])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("activation_precision", ["fp4", "bf16"])
@torch.inference_mode()
def test_flashinfer_b12x_moe(
m: int,
Expand All @@ -82,6 +99,7 @@ def test_flashinfer_b12x_moe(
e: int,
topk: int,
dtype: torch.dtype,
activation_precision: str,
workspace_init,
):
"""Test FlashInferB12xExperts against a BF16 torch reference.
Expand All @@ -91,6 +109,14 @@ def test_flashinfer_b12x_moe(
correctness against ``torch_moe`` using generous tolerances to account
for the internal FP4 quantization of activations and weights.

Two activation precisions are exercised:
* ``fp4`` (W4A4, modelopt path): activations re-quantized to FP4
before each GEMM. Set ``a1_gscale`` / ``a2_gscale`` on the quant
config (any non-None tensor enables this branch).
* ``bf16`` (W4A16, compressed-tensors `nvfp4-pack-quantized` path):
activations kept in BF16; the kernel dequantizes weights instead.
``a1_gscale`` / ``a2_gscale`` are left as ``None``.

Scale convention
----------------
The SM12x kernel uses ``w1_alpha`` as *both* the activation-quantisation
Expand Down Expand Up @@ -153,13 +179,26 @@ def test_flashinfer_b12x_moe(
# All per-expert alphas are 1.0 (global_scale = 1.0, no compensation).
ones_e = torch.ones(e, device="cuda", dtype=torch.float32)

# W4A4: pass identity activation global scales; B12x re-quantizes
# activations to FP4. W4A16: leave them as None; activations stay
# in BF16 and the kernel dequantizes weights instead.
if activation_precision == "fp4":
a1_gscale = ones_e
a2_gscale = ones_e
source_format = "modelopt"
else:
a1_gscale = None
a2_gscale = None
source_format = "compressed_tensors"

quant_config = nvfp4_moe_quant_config(
g1_alphas=ones_e,
g2_alphas=ones_e,
a1_gscale=ones_e,
a2_gscale=ones_e,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
source_format=source_format,
)

moe_config = make_dummy_moe_config(
Expand All @@ -174,22 +213,15 @@ def test_flashinfer_b12x_moe(
moe_config=moe_config,
quant_config=quant_config,
)
# In production, process_weights_after_loading computes these after
# normalizing block scales. In the test the scales are already in final
# form (global_scale=1.0), so we compute the MMA layouts directly.
num_experts_w1, m1, k1_sf = w1_blockscale.shape
experts.w1_sf_mma = flashinfer_convert_sf_to_mma_layout(
w1_blockscale.reshape(num_experts_w1 * m1, k1_sf),
m=m1,
k=k1_sf * 16,
num_groups=num_experts_w1,
)
num_experts_w2, m2, k2_sf = w2_blockscale.shape
experts.w2_sf_mma = flashinfer_convert_sf_to_mma_layout(
w2_blockscale.reshape(num_experts_w2 * m2, k2_sf),
m=m2,
k=k2_sf * 16,
num_groups=num_experts_w2,
# Cross-check the precision/source-format auto-detection.
assert experts.activation_precision == activation_precision
assert experts.source_format == source_format
_process_b12x_weights(
experts,
w1_blockscale,
w2_blockscale,
ones_e,
ones_e,
)

kernel = mk.FusedMoEKernel(
Expand Down Expand Up @@ -224,5 +256,154 @@ def test_flashinfer_b12x_moe(
torch.testing.assert_close(sm12x_output, torch_output, atol=2e-1, rtol=2e-1)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [8, 16])
@pytest.mark.parametrize("topk", [1, 2, 4])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("activation_precision", ["fp4", "bf16"])
@torch.inference_mode()
def test_flashinfer_b12x_moe_relu2(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
activation_precision: str,
workspace_init,
):
"""Test FlashInferB12xExperts with ReLU2 (non-gated) activation.

ReLU2 is used by Nemotron-H style models. Unlike the gated SiLU
path, w1 has shape [E, N, K] (not [E, 2N, K]) and the activation
is relu(x)^2 without a gate/up split.

Parametrized over ``activation_precision`` (``"fp4"`` for W4A4 /
modelopt and ``"bf16"`` for W4A16 / compressed-tensors); Nemotron-H 3.5
is the production user of the ReLU2 + W4A16 combination.
"""
set_random_seed(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10

# Non-gated: w1 shape is (e, n, k), not (e, 2n, k).
w1_bf16 = torch.randn((e, n, k), device="cuda", dtype=dtype) / 15
w2_bf16 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 15

gs = torch.ones(1, device="cuda", dtype=torch.float32)
sf_vec_size = 16

# W1: no gate/up reordering for non-gated.
w1_flat = w1_bf16.reshape(e * n, k)
w1_q_flat, w1_sf_flat = fp4_quantize(
w1_flat,
global_scale=gs,
sf_vec_size=sf_vec_size,
is_sf_swizzled_layout=True,
)
w1_q = w1_q_flat.view(e, n, k // 2)
w1_blockscale = w1_sf_flat.view(e, n, w1_sf_flat.shape[1])

w2_flat = w2_bf16.reshape(e * k, n)
w2_q_flat, w2_sf_flat = fp4_quantize(
w2_flat,
global_scale=gs,
sf_vec_size=sf_vec_size,
is_sf_swizzled_layout=True,
)
w2_q = w2_q_flat.view(e, k, n // 2)
w2_blockscale = w2_sf_flat.view(e, k, w2_sf_flat.shape[1])

ones_e = torch.ones(e, device="cuda", dtype=torch.float32)

if activation_precision == "fp4":
a1_gscale = ones_e
a2_gscale = ones_e
source_format = "modelopt"
else:
a1_gscale = None
a2_gscale = None
source_format = "compressed_tensors"

quant_config = nvfp4_moe_quant_config(
g1_alphas=ones_e,
g2_alphas=ones_e,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
source_format=source_format,
)

moe_config = make_dummy_moe_config(
num_experts=e,
experts_per_token=topk,
hidden_dim=k,
intermediate_size_per_partition=n,
in_dtype=dtype,
activation=MoEActivation.RELU2_NO_MUL,
is_act_and_mul=False,
)

experts = FlashInferB12xExperts(
moe_config=moe_config,
quant_config=quant_config,
)
assert experts.activation_precision == activation_precision
assert experts.source_format == source_format
_process_b12x_weights(
experts,
w1_blockscale,
w2_blockscale,
ones_e,
ones_e,
)

kernel = mk.FusedMoEKernel(
maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
allow_new_interface=True,
use_monolithic=False,
),
experts,
inplace=False,
)

score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

b12x_output = kernel.apply(
hidden_states=a,
w1=w1_q,
w2=w2_q,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_num_experts=e,
activation=MoEActivation.RELU2_NO_MUL,
apply_router_weight_on_input=False,
expert_map=None,
)

torch_output = torch_moe(
a,
w1_bf16,
w2_bf16,
score,
topk,
activation=MoEActivation.RELU2_NO_MUL,
)

torch.testing.assert_close(
b12x_output,
torch_output,
atol=2e-1,
rtol=2e-1,
)


if __name__ == "__main__":
test_flashinfer_b12x_moe(16, 128, 256, 8, 2, torch.bfloat16)
test_flashinfer_b12x_moe(16, 128, 256, 8, 2, torch.bfloat16, "fp4")
test_flashinfer_b12x_moe(16, 128, 256, 8, 2, torch.bfloat16, "bf16")
Loading
Loading