diff --git a/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-BF16-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-BF16-fi-cutlass.yaml
index fe099f9f121d..5416d9232cd2 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-BF16-fi-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-BF16-fi-cutlass.yaml
@@ -5,3 +5,5 @@ num_fewshot: 5
 server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --enable-expert-parallel"
 env:
   VLLM_USE_FLASHINFER_MOE_FP16: "1"
+  VLLM_FLASHINFER_MOE_BACKEND: "throughput"
+
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Mixtral-8x7B-BF16-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Mixtral-8x7B-BF16-fi-cutlass.yaml
index 5f4a76b0a6b2..cc8df6292cfb 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Mixtral-8x7B-BF16-fi-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Mixtral-8x7B-BF16-fi-cutlass.yaml
@@ -5,3 +5,4 @@ num_fewshot: 5
 server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --enable-expert-parallel"
 env:
   VLLM_USE_FLASHINFER_MOE_FP16: "1"
+  VLLM_FLASHINFER_MOE_BACKEND: "throughput"
diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py
index e62cf79418c2..ddcd221efc0b 100644
--- a/tests/kernels/moe/test_flashinfer.py
+++ b/tests/kernels/moe/test_flashinfer.py
@@ -318,3 +318,44 @@ def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
     torch.testing.assert_close(
         output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
     )
+
+
+@pytest.mark.parametrize(
+    "num_experts,intermediate,hidden",
+    [
+        (8, 2048, 1536),
+        (64, 4096, 4096),
+    ],
+)
+def test_convert_moe_weights_to_flashinfer_trtllm_block_layout(
+    num_experts, intermediate, hidden
+):
+    from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+        convert_moe_weights_to_flashinfer_trtllm_block_layout,
+    )
+
+    w13 = torch.randn(
+        (num_experts, 2 * intermediate, hidden), dtype=torch.bfloat16, device="cuda"
+    )
+    w2 = torch.randn(
+        (num_experts, hidden, intermediate), dtype=torch.bfloat16, device="cuda"
+    )
+
+    cache: dict[torch.Size, torch.Tensor] = {}
+    w13_converted, w2_converted = convert_moe_weights_to_flashinfer_trtllm_block_layout(
+        cache, w13, w2
+    )
+
+    assert w13_converted.ndim == 4, (
+        f"Expected 4D tensor, got shape {w13_converted.shape}"
+    )
+    assert w2_converted.ndim == 4, f"Expected 4D tensor, got shape {w2_converted.shape}"
+
+    assert w13_converted.numel() == w13.numel(), "W13 element count should be preserved"
+    assert w2_converted.numel() == w2.numel(), "W2 element count should be preserved"
+
+    assert w13_converted.dtype == torch.bfloat16
+    assert w2_converted.dtype == torch.bfloat16
+
+    assert w13_converted.shape[0] == num_experts
+    assert w2_converted.shape[0] == num_experts
diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index 53fb43e3c121..6a622ac8e4d5 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -1558,3 +1558,103 @@ def run(
     marlin_output = br.run(a, kwargs)
 
     torch.testing.assert_close(marlin_output, ref_marlin_output, atol=1e-3, rtol=0)
+
+
+@pytest.mark.parametrize("m,n,k", [(32, 1024, 1024)])
+@pytest.mark.parametrize("e,topk", [(8, 2)])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.skipif(
+    not current_platform.is_device_capability_family(100),
+    reason="TRTLLM backend test only runs on Blackwell GPUs (SM10x).",
+)
+def test_unquantized_bf16_flashinfer_trtllm_backend(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+    monkeypatch,
+    workspace_init,
+):
+    """
+    Test BF16 unquantized MoE with FlashInfer TRTLLM backend.
+    """
+    set_random_seed(7)
+
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")
+
+    from vllm.model_executor.layers.fused_moe.config import (
+        FusedMoEConfig,
+        FusedMoEParallelConfig,
+        RoutingMethodType,
+    )
+    from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
+        UnquantizedMoeBackend,
+    )
+    from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
+        UnquantizedFusedMoEMethod,
+    )
+
+    # Setup test data
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+    router_logits = torch.randn((m, e), device="cuda", dtype=dtype)
+
+    moe_config = FusedMoEConfig(
+        num_experts=e,
+        experts_per_token=topk,
+        hidden_dim=k,
+        intermediate_size_per_partition=n,
+        num_local_experts=e,
+        activation="silu",
+        device="cuda",
+        moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
+        in_dtype=dtype,
+        is_act_and_mul=True,
+        routing_method=RoutingMethodType.Renormalize,
+        max_num_tokens=m,
+    )
+
+    with set_current_vllm_config(vllm_config):
+        quant_method = UnquantizedFusedMoEMethod(moe_config)
+
+    # Verify TRTLLM backend was selected
+    assert (
+        quant_method.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
+    ), f"Expected FLASHINFER_TRTLLM backend, got {quant_method.unquantized_backend}"
+
+    # Verify it's using monolithic path
+    assert quant_method.is_monolithic, (
+        "FLASHINFER_TRTLLM backend should use monolithic forward"
+    )
+    layer = torch.nn.Module()
+    layer.w13_weight = Parameter(w1.clone(), requires_grad=False)
+    layer.w2_weight = Parameter(w2.clone(), requires_grad=False)
+    layer.global_num_experts = e
+    layer.local_num_experts = e
+    layer.top_k = topk
+    layer.num_expert_group = 1
+    layer.topk_group = 1
+    layer.intermediate_size_per_partition = n
+    layer.ep_rank = 0
+    layer.activation = "silu"
+    layer.e_score_correction_bias = None
+    layer.routing_method_type = RoutingMethodType.Renormalize
+
+    quant_method.process_weights_after_loading(layer)
+
+    trtllm_output = quant_method.forward_monolithic_cuda(
+        layer=layer,
+        x=a,
+        router_logits=router_logits,
+    )
+
+    # Compute torch baseline
+    w1_original = w1.clone()
+    w2_original = w2.clone()
+    baseline_output = torch_moe(a, w1_original, w2_original, router_logits, topk)
+
+    close = torch.isclose(trtllm_output, baseline_output, atol=1e-1, rtol=0.85)
+    assert close.float().mean() > 0.925
diff --git a/tests/kernels/moe/test_unquantized_backend_selection.py b/tests/kernels/moe/test_unquantized_backend_selection.py
new file mode 100644
index 000000000000..fcb79ee8f296
--- /dev/null
+++ b/tests/kernels/moe/test_unquantized_backend_selection.py
@@ -0,0 +1,132 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from unittest.mock import patch
+
+import pytest
+
+from tests.kernels.moe.utils import make_dummy_moe_config
+from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
+    UnquantizedMoeBackend,
+    select_unquantized_moe_backend,
+)
+
+
+@pytest.mark.parametrize(
+    "platform_method,expected_backend",
+    [
+        ("is_cuda", UnquantizedMoeBackend.TRITON),  # Default CUDA without FlashInfer
+        ("is_rocm", UnquantizedMoeBackend.TRITON),
+        ("is_cpu", UnquantizedMoeBackend.CPU),
+        ("is_xpu", UnquantizedMoeBackend.XPU),
+        ("is_tpu", UnquantizedMoeBackend.TPU),
+        ("is_out_of_tree", UnquantizedMoeBackend.OOT),
+    ],
+)
+@patch(
+    "vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
+    return_value=False,
+)
+def test_select_default_backend_by_platform(
+    mock_has_flashinfer,
+    monkeypatch,
+    platform_method,
+    expected_backend,
+):
+    """Test backend selection for different platforms."""
+    with patch(
+        "vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
+    ) as mock_platform:
+        # Set all platform checks to False
+        mock_platform.is_cuda.return_value = False
+        mock_platform.is_rocm.return_value = False
+        mock_platform.is_cpu.return_value = False
+        mock_platform.is_xpu.return_value = False
+        mock_platform.is_tpu.return_value = False
+        mock_platform.is_out_of_tree.return_value = False
+
+        # Set only the specified platform to True
+        getattr(mock_platform, platform_method).return_value = True
+
+        moe_config = make_dummy_moe_config()
+        selected_backend = select_unquantized_moe_backend(
+            moe_config=moe_config,
+            use_ep=False,
+            use_dp=False,
+        )
+
+        assert selected_backend == expected_backend
+
+
+@patch(
+    "vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
+    return_value=True,
+)
+@patch(
+    "vllm.model_executor.layers.fused_moe.oracle.unquantized.is_supported_config_trtllm_bf16",
+    return_value=(True, None),
+)
+def test_select_cuda_flashinfer_trtllm_backend(
+    mock_has_flashinfer, mock_is_supported_trtllm, monkeypatch
+):
+    """Test CUDA backend selection when FlashInfer TRTLLM is available and enabled."""
+    with patch(
+        "vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
+    ) as mock_platform:
+        # Set as CUDA platform
+        mock_platform.is_cuda.return_value = True
+        mock_platform.is_rocm.return_value = False
+        mock_platform.is_cpu.return_value = False
+        mock_platform.is_xpu.return_value = False
+        mock_platform.is_tpu.return_value = False
+        mock_platform.is_out_of_tree.return_value = False
+
+        monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")
+
+        moe_config = make_dummy_moe_config()
+
+        selected_backend = select_unquantized_moe_backend(
+            moe_config=moe_config,
+            use_ep=True,
+            use_dp=False,
+        )
+
+        assert selected_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
+
+
+@patch(
+    "vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
+    return_value=True,
+)
+@patch(
+    "vllm.model_executor.layers.fused_moe.oracle.unquantized.is_supported_config_trtllm_bf16",
+    return_value=(False, None),
+)
+def test_select_cuda_flashinfer_cutlass_backend(
+    mock_has_flashinfer, mock_is_supported_trtllm, monkeypatch
+):
+    """Test CUDA backend selection when FlashInfer TRTLLM is not available
+    and FlashInfer CUTLASS is available."""
+    with patch(
+        "vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
+    ) as mock_platform:
+        # Set as CUDA platform with Hopper capability
+        mock_platform.is_cuda.return_value = True
+        mock_platform.is_rocm.return_value = False
+        mock_platform.is_cpu.return_value = False
+        mock_platform.is_xpu.return_value = False
+        mock_platform.is_tpu.return_value = False
+        mock_platform.is_out_of_tree.return_value = False
+        mock_platform.has_device_capability.return_value = True  # SM90+
+
+        # Enable FlashInfer via env var
+        monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")
+
+        moe_config = make_dummy_moe_config()
+
+        selected_backend = select_unquantized_moe_backend(
+            moe_config=moe_config,
+            use_ep=True,  # CUTLASS requires EP
+            use_dp=False,  # CUTLASS doesn't support DP
+        )
+
+        assert selected_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS
diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py
index a43d2abfdd8b..07da2b454e6f 100644
--- a/tests/quantization/test_blackwell_moe.py
+++ b/tests/quantization/test_blackwell_moe.py
@@ -178,3 +178,11 @@ def test_gptoss_eager(monkeypatch: pytest.MonkeyPatch):
         hf_overrides=HF_OVERRIDE_TEXT,
         extra_args=["--enforce-eager"],
     )
+
+
+## Qwen3 Next ##
+
+
+def test_qwen3_next_bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")
+    can_initialize("Qwen/Qwen3-Next-80B-A3B-Instruct", hf_overrides=HF_OVERRIDE_TEXT)
diff --git a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py
index c4a19ecb61a8..61aaa6927778 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py
@@ -78,7 +78,10 @@ def _make_log_backend(backend: UnquantizedMoeBackend):
         activation_format=activation_format,
     )
     flashinfer_trtllm_moe_enabled = (
-        has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_FP16 and trtllm_supported
+        has_flashinfer()
+        and envs.VLLM_USE_FLASHINFER_MOE_FP16
+        and trtllm_supported
+        and envs.VLLM_FLASHINFER_MOE_BACKEND == "latency"
     )
     # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
     flashinfer_cutlass_moe_enabled = (
@@ -98,11 +101,19 @@ def _make_log_backend(backend: UnquantizedMoeBackend):
         backend = UnquantizedMoeBackend.FLASHINFER_TRTLLM
     elif flashinfer_cutlass_moe_enabled:
         backend = UnquantizedMoeBackend.FLASHINFER_CUTLASS
+        if trtllm_supported:
+            logger.info_once(
+                "FlashInfer TRTLLM MoE is available but not enabled, "
+                "consider setting VLLM_FLASHINFER_MOE_BACKEND=latency "
+                "to enable it for better performance.",
+                scope="local",
+            )
     else:
         if not envs.VLLM_USE_FLASHINFER_MOE_FP16 and trtllm_supported:
             logger.info_once(
                 "FlashInfer TRTLLM MoE is available but not enabled, "
                 "consider setting VLLM_USE_FLASHINFER_MOE_FP16=1 "
+                "and VLLM_FLASHINFER_MOE_BACKEND=latency "
                 "to enable it for better performance.",
                 scope="local",
             )