Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --enable-expert-parallel"
env:
VLLM_USE_FLASHINFER_MOE_FP16: "1"
VLLM_FLASHINFER_MOE_BACKEND: "throughput"

Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --enable-expert-parallel"
env:
VLLM_USE_FLASHINFER_MOE_FP16: "1"
VLLM_FLASHINFER_MOE_BACKEND: "throughput"
41 changes: 41 additions & 0 deletions tests/kernels/moe/test_flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,44 @@ def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
torch.testing.assert_close(
output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
)


@pytest.mark.parametrize(
"num_experts,intermediate,hidden",
[
(8, 2048, 1536),
(64, 4096, 4096),
],
)
def test_convert_moe_weights_to_flashinfer_trtllm_block_layout(
num_experts, intermediate, hidden
):
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
convert_moe_weights_to_flashinfer_trtllm_block_layout,
)

w13 = torch.randn(
(num_experts, 2 * intermediate, hidden), dtype=torch.bfloat16, device="cuda"
)
w2 = torch.randn(
(num_experts, hidden, intermediate), dtype=torch.bfloat16, device="cuda"
)

cache: dict[torch.Size, torch.Tensor] = {}
w13_converted, w2_converted = convert_moe_weights_to_flashinfer_trtllm_block_layout(
cache, w13, w2
)

assert w13_converted.ndim == 4, (
f"Expected 4D tensor, got shape {w13_converted.shape}"
)
assert w2_converted.ndim == 4, f"Expected 4D tensor, got shape {w2_converted.shape}"

assert w13_converted.numel() == w13.numel(), "W13 element count should be preserved"
assert w2_converted.numel() == w2.numel(), "W2 element count should be preserved"

assert w13_converted.dtype == torch.bfloat16
assert w2_converted.dtype == torch.bfloat16

assert w13_converted.shape[0] == num_experts
assert w2_converted.shape[0] == num_experts
100 changes: 100 additions & 0 deletions tests/kernels/moe/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,3 +1558,103 @@ def run(
marlin_output = br.run(a, kwargs)

torch.testing.assert_close(marlin_output, ref_marlin_output, atol=1e-3, rtol=0)


@pytest.mark.parametrize("m,n,k", [(32, 1024, 1024)])
@pytest.mark.parametrize("e,topk", [(8, 2)])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.skipif(
not current_platform.is_device_capability_family(100),
reason="TRTLLM backend test only runs on Blackwell GPUs (SM10x).",
)
def test_unquantized_bf16_flashinfer_trtllm_backend(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
monkeypatch,
workspace_init,
):
"""
Test BF16 unquantized MoE with FlashInfer TRTLLM backend.
"""
set_random_seed(7)

monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")

from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
UnquantizedMoeBackend,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)

# Setup test data
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
router_logits = torch.randn((m, e), device="cuda", dtype=dtype)

moe_config = FusedMoEConfig(
num_experts=e,
experts_per_token=topk,
hidden_dim=k,
intermediate_size_per_partition=n,
num_local_experts=e,
activation="silu",
device="cuda",
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=dtype,
is_act_and_mul=True,
routing_method=RoutingMethodType.Renormalize,
max_num_tokens=m,
)

with set_current_vllm_config(vllm_config):
quant_method = UnquantizedFusedMoEMethod(moe_config)

# Verify TRTLLM backend was selected
assert (
quant_method.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
), f"Expected FLASHINFER_TRTLLM backend, got {quant_method.unquantized_backend}"

# Verify it's using monolithic path
assert quant_method.is_monolithic, (
"FLASHINFER_TRTLLM backend should use monolithic forward"
)
layer = torch.nn.Module()
layer.w13_weight = Parameter(w1.clone(), requires_grad=False)
layer.w2_weight = Parameter(w2.clone(), requires_grad=False)
layer.global_num_experts = e
layer.local_num_experts = e
layer.top_k = topk
layer.num_expert_group = 1
layer.topk_group = 1
layer.intermediate_size_per_partition = n
layer.ep_rank = 0
layer.activation = "silu"
layer.e_score_correction_bias = None
layer.routing_method_type = RoutingMethodType.Renormalize

quant_method.process_weights_after_loading(layer)

trtllm_output = quant_method.forward_monolithic_cuda(
layer=layer,
x=a,
router_logits=router_logits,
)

# Compute torch baseline
w1_original = w1.clone()
w2_original = w2.clone()
baseline_output = torch_moe(a, w1_original, w2_original, router_logits, topk)

close = torch.isclose(trtllm_output, baseline_output, atol=1e-1, rtol=0.85)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't these tolerances a bit large for bf16?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert close.float().mean() > 0.925
132 changes: 132 additions & 0 deletions tests/kernels/moe/test_unquantized_backend_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import patch

import pytest

from tests.kernels.moe.utils import make_dummy_moe_config
from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
UnquantizedMoeBackend,
select_unquantized_moe_backend,
)


@pytest.mark.parametrize(
"platform_method,expected_backend",
[
("is_cuda", UnquantizedMoeBackend.TRITON), # Default CUDA without FlashInfer
("is_rocm", UnquantizedMoeBackend.TRITON),
("is_cpu", UnquantizedMoeBackend.CPU),
("is_xpu", UnquantizedMoeBackend.XPU),
("is_tpu", UnquantizedMoeBackend.TPU),
("is_out_of_tree", UnquantizedMoeBackend.OOT),
],
)
@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
return_value=False,
)
def test_select_default_backend_by_platform(
mock_has_flashinfer,
monkeypatch,
platform_method,
expected_backend,
):
"""Test backend selection for different platforms."""
with patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
) as mock_platform:
# Set all platform checks to False
mock_platform.is_cuda.return_value = False
mock_platform.is_rocm.return_value = False
mock_platform.is_cpu.return_value = False
mock_platform.is_xpu.return_value = False
mock_platform.is_tpu.return_value = False
mock_platform.is_out_of_tree.return_value = False

# Set only the specified platform to True
getattr(mock_platform, platform_method).return_value = True

moe_config = make_dummy_moe_config()
selected_backend = select_unquantized_moe_backend(
moe_config=moe_config,
use_ep=False,
use_dp=False,
)

assert selected_backend == expected_backend


@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
return_value=True,
)
@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.is_supported_config_trtllm_bf16",
return_value=(True, None),
)
def test_select_cuda_flashinfer_trtllm_backend(
mock_has_flashinfer, mock_is_supported_trtllm, monkeypatch
):
"""Test CUDA backend selection when FlashInfer TRTLLM is available and enabled."""
with patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
) as mock_platform:
# Set as CUDA platform
mock_platform.is_cuda.return_value = True
mock_platform.is_rocm.return_value = False
mock_platform.is_cpu.return_value = False
mock_platform.is_xpu.return_value = False
mock_platform.is_tpu.return_value = False
mock_platform.is_out_of_tree.return_value = False

monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")

moe_config = make_dummy_moe_config()

selected_backend = select_unquantized_moe_backend(
moe_config=moe_config,
use_ep=True,
use_dp=False,
)

assert selected_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM


@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
return_value=True,
)
@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.is_supported_config_trtllm_bf16",
return_value=(False, None),
)
def test_select_cuda_flashinfer_cutlass_backend(
mock_has_flashinfer, mock_is_supported_trtllm, monkeypatch
):
"""Test CUDA backend selection when FlashInfer TRTLLM is not available
and FlashInfer CUTLASS is available."""
with patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
) as mock_platform:
# Set as CUDA platform with Hopper capability
mock_platform.is_cuda.return_value = True
mock_platform.is_rocm.return_value = False
mock_platform.is_cpu.return_value = False
mock_platform.is_xpu.return_value = False
mock_platform.is_tpu.return_value = False
mock_platform.is_out_of_tree.return_value = False
mock_platform.has_device_capability.return_value = True # SM90+

# Enable FlashInfer via env var
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")

moe_config = make_dummy_moe_config()

selected_backend = select_unquantized_moe_backend(
moe_config=moe_config,
use_ep=True, # CUTLASS requires EP
use_dp=False, # CUTLASS doesn't support DP
)

assert selected_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS
8 changes: 8 additions & 0 deletions tests/quantization/test_blackwell_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,11 @@ def test_gptoss_eager(monkeypatch: pytest.MonkeyPatch):
hf_overrides=HF_OVERRIDE_TEXT,
extra_args=["--enforce-eager"],
)


## Qwen3 Next ##


def test_qwen3_next_bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
Comment thread
pavanimajety marked this conversation as resolved.
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")
can_initialize("Qwen/Qwen3-Next-80B-A3B-Instruct", hf_overrides=HF_OVERRIDE_TEXT)
13 changes: 12 additions & 1 deletion vllm/model_executor/layers/fused_moe/oracle/unquantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ def _make_log_backend(backend: UnquantizedMoeBackend):
activation_format=activation_format,
)
flashinfer_trtllm_moe_enabled = (
has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_FP16 and trtllm_supported
has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_FP16
and trtllm_supported
and envs.VLLM_FLASHINFER_MOE_BACKEND == "latency"
)
# FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
flashinfer_cutlass_moe_enabled = (
Expand All @@ -98,11 +101,19 @@ def _make_log_backend(backend: UnquantizedMoeBackend):
backend = UnquantizedMoeBackend.FLASHINFER_TRTLLM
elif flashinfer_cutlass_moe_enabled:
backend = UnquantizedMoeBackend.FLASHINFER_CUTLASS
if trtllm_supported:
logger.info_once(
"FlashInfer TRTLLM MoE is available but not enabled, "
"consider setting VLLM_FLASHINFER_MOE_BACKEND=latency "
"to enable it for better performance.",
scope="local",
)
else:
if not envs.VLLM_USE_FLASHINFER_MOE_FP16 and trtllm_supported:
logger.info_once(
"FlashInfer TRTLLM MoE is available but not enabled, "
"consider setting VLLM_USE_FLASHINFER_MOE_FP16=1 "
"and VLLM_FLASHINFER_MOE_BACKEND=latency "
"to enable it for better performance.",
scope="local",
)
Expand Down