Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions tests/kernels/moe/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1675,7 +1675,7 @@ def test_unquantized_bf16_flashinfer_trtllm_backend(
intermediate_size_per_partition=n,
num_local_experts=e,
num_logical_experts=e,
activation="silu",
activation=MoEActivation.SILU,
device="cuda",
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=dtype,
Expand Down Expand Up @@ -1706,13 +1706,25 @@ def test_unquantized_bf16_flashinfer_trtllm_backend(
layer.topk_group = 1
layer.intermediate_size_per_partition = n
layer.ep_rank = 0
layer.activation = "silu"
layer.activation = MoEActivation.SILU
layer.e_score_correction_bias = None
layer.routing_method_type = RoutingMethodType.Renormalize
layer.expert_map = None
layer.apply_router_weight_on_input = False
layer.routed_scaling_factor = None
layer.shared_experts = None
layer._maybe_init_expert_routing_tables = lambda: None

quant_method.process_weights_after_loading(layer)

trtllm_output = quant_method.forward_monolithic_cuda(
assert quant_method.moe_kernel is not None, (
"moe_kernel should be set after process_weights_after_loading"
)
assert quant_method.supports_internal_mk, (
"supports_internal_mk should be True after setup"
)

trtllm_output = quant_method.apply_monolithic(
layer=layer,
x=a,
router_logits=router_logits,
Expand Down
114 changes: 58 additions & 56 deletions tests/kernels/moe/test_unquantized_backend_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@
("is_rocm", UnquantizedMoeBackend.TRITON),
("is_cpu", UnquantizedMoeBackend.CPU),
("is_xpu", UnquantizedMoeBackend.XPU),
("is_tpu", UnquantizedMoeBackend.TPU),
("is_out_of_tree", UnquantizedMoeBackend.OOT),
],
)
@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
"vllm.utils.flashinfer.has_flashinfer",
return_value=False,
)
def test_select_default_backend_by_platform(
Expand All @@ -34,36 +32,34 @@ def test_select_default_backend_by_platform(
expected_backend,
):
"""Test backend selection for different platforms."""
with patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
) as mock_platform:
# Set all platform checks to False
mock_platform.is_cuda.return_value = False
mock_platform.is_rocm.return_value = False
mock_platform.is_cpu.return_value = False
mock_platform.is_xpu.return_value = False
mock_platform.is_tpu.return_value = False
mock_platform.is_out_of_tree.return_value = False

# Set only the specified platform to True
getattr(mock_platform, platform_method).return_value = True

with (
patch.object(current_platform, "is_cuda", return_value=False),
patch.object(current_platform, "is_rocm", return_value=False),
patch.object(current_platform, "is_cpu", return_value=False),
patch.object(current_platform, "is_xpu", return_value=False),
patch.object(current_platform, "is_tpu", return_value=False),
patch.object(current_platform, "is_out_of_tree", return_value=False),
patch.object(current_platform, platform_method, return_value=True),
):
moe_config = make_dummy_moe_config()
selected_backend = select_unquantized_moe_backend(
moe_config=moe_config,
use_ep=False,
use_dp=False,
selected_backend, expert_cls = select_unquantized_moe_backend(
moe_config=moe_config
)

assert selected_backend == expected_backend
if expected_backend == UnquantizedMoeBackend.CPU:
assert expert_cls is None
else:
assert expert_cls is not None


@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
"vllm.utils.flashinfer.has_flashinfer",
return_value=True,
)
@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.is_supported_config_trtllm_bf16",
"vllm.model_executor.layers.fused_moe.experts.trtllm_bf16_moe.TrtLlmBf16Experts.is_supported_config",
return_value=(True, None),
)
@pytest.mark.skipif(
Expand All @@ -73,67 +69,73 @@ def test_select_cuda_flashinfer_trtllm_backend(
mock_has_flashinfer, mock_is_supported_trtllm, monkeypatch
):
"""Test CUDA backend selection when FlashInfer TRTLLM is available and enabled."""
with patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
) as mock_platform:
# Set as CUDA platform
mock_platform.is_cuda.return_value = True
mock_platform.is_rocm.return_value = False
mock_platform.is_cpu.return_value = False
mock_platform.is_xpu.return_value = False
mock_platform.is_tpu.return_value = False
mock_platform.is_out_of_tree.return_value = False

with (
patch.object(current_platform, "is_cuda", return_value=True),
patch.object(current_platform, "is_rocm", return_value=False),
patch.object(current_platform, "is_cpu", return_value=False),
patch.object(current_platform, "is_xpu", return_value=False),
patch.object(current_platform, "is_tpu", return_value=False),
patch.object(current_platform, "is_out_of_tree", return_value=False),
patch.object(current_platform, "has_device_capability", return_value=True),
):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")

moe_config = make_dummy_moe_config()
# TRTLLM requires EP and does not support DP
moe_config.moe_parallel_config.use_ep = True
moe_config.moe_parallel_config.use_dp = False

selected_backend = select_unquantized_moe_backend(
moe_config=moe_config,
use_ep=True,
use_dp=False,
selected_backend, experts_cls = select_unquantized_moe_backend(
moe_config=moe_config
)

assert selected_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
assert experts_cls is not None


@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
"vllm.utils.flashinfer.has_flashinfer",
return_value=True,
)
@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.is_supported_config_trtllm_bf16",
"vllm.model_executor.layers.fused_moe.experts.trtllm_bf16_moe.TrtLlmBf16Experts.is_supported_config",
return_value=(False, None),
)
@patch(
"vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts.is_supported_config",
return_value=(True, None),
)
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="Only supported on NVIDIA platforms."
)
def test_select_cuda_flashinfer_cutlass_backend(
mock_has_flashinfer, mock_is_supported_trtllm, monkeypatch
mock_has_flashinfer,
mock_is_supported_trtllm,
mock_is_supported_cutlass,
monkeypatch,
):
"""Test CUDA backend selection when FlashInfer TRTLLM is not available
and FlashInfer CUTLASS is available."""
with patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
) as mock_platform:
# Set as CUDA platform with Hopper capability
mock_platform.is_cuda.return_value = True
mock_platform.is_rocm.return_value = False
mock_platform.is_cpu.return_value = False
mock_platform.is_xpu.return_value = False
mock_platform.is_tpu.return_value = False
mock_platform.is_out_of_tree.return_value = False
mock_platform.has_device_capability.return_value = True # SM90+

with (
patch.object(current_platform, "is_cuda", return_value=True),
patch.object(current_platform, "is_rocm", return_value=False),
patch.object(current_platform, "is_cpu", return_value=False),
patch.object(current_platform, "is_xpu", return_value=False),
patch.object(current_platform, "is_tpu", return_value=False),
patch.object(current_platform, "is_out_of_tree", return_value=False),
patch.object(current_platform, "has_device_capability", return_value=True),
):
# Enable FlashInfer via env var
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")

moe_config = make_dummy_moe_config()
# CUTLASS requires EP and does not support DP
moe_config.moe_parallel_config.use_ep = True
moe_config.moe_parallel_config.use_dp = False

selected_backend = select_unquantized_moe_backend(
moe_config=moe_config,
use_ep=True, # CUTLASS requires EP
use_dp=False, # CUTLASS doesn't support DP
selected_backend, experts_cls = select_unquantized_moe_backend(
moe_config=moe_config
)

assert selected_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS
assert experts_cls is not None
3 changes: 3 additions & 0 deletions vllm/lora/layers/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def __init__(self, base_layer: FusedMoE) -> None:
assert not self.base_layer.use_ep, (
"EP support for Fused MoE LoRA is not implemented yet."
)
assert not self.base_layer.quant_method.is_monolithic, (
"Monolithic kernels are not supported for Fused MoE LoRA."
)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.device = _get_lora_device(base_layer)
Expand Down
140 changes: 140 additions & 0 deletions vllm/model_executor/layers/fused_moe/experts/trtllm_bf16_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
from vllm.platforms import current_platform


class TrtLlmBf16Experts(mk.FusedMoEExpertsMonolithic):
    """Monolithic MoE experts backed by FlashInfer's TRTLLM-Gen BF16 kernel.

    Unquantized (BF16) weights/activations only; dispatches the whole
    routed-expert computation to ``flashinfer.fused_moe.trtllm_bf16_moe``.
    """

    # Routing methods the TRTLLM-Gen BF16 kernel accepts.
    _SUPPORTED_ROUTING_METHODS = (
        RoutingMethodType.Default,
        RoutingMethodType.Renormalize,
        RoutingMethodType.DeepSeekV3,
        RoutingMethodType.Llama4,
        RoutingMethodType.RenormalizeNaive,
    )

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(moe_config, quant_config)
        # Cache per-layer geometry and routing parameters for apply().
        parallel_config = moe_config.moe_parallel_config
        self.routing_method_type = moe_config.routing_method
        self.topk = moe_config.experts_per_token
        self.intermediate_size_per_partition = (
            moe_config.intermediate_size_per_partition
        )
        self.hidden_dim = moe_config.hidden_dim
        self.local_num_experts = moe_config.num_local_experts
        self.ep_rank = parallel_config.ep_rank

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        # Kernel consumes plain (non-batched) activations.
        return mk.FusedMoEActivationFormat.Standard

    @staticmethod
    def _supports_current_device() -> bool:
        """Only CUDA devices of the Blackwell capability family (10.x)."""
        if not current_platform.is_cuda():
            return False
        return current_platform.is_device_capability_family(100)

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        # The BF16 kernel implements only gated (act-and-mul) MoE.
        return False

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Accept only fully unquantized weights and activations."""
        return weight_key is None and activation_key is None

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        # SiLU is the only activation this kernel path supports.
        return activation == MoEActivation.SILU

    @staticmethod
    def _supports_routing_method(
        routing_method: RoutingMethodType,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        # Quant keys are irrelevant here; routing method alone decides.
        return routing_method in TrtLlmBf16Experts._SUPPORTED_ROUTING_METHODS

    @staticmethod
    def _supports_parallel_config(
        moe_parallel_config: FusedMoEParallelConfig,
    ) -> bool:
        """Monolithic kernel: allow only naive DP/EP plus TP, no EPLB."""
        if moe_parallel_config.enable_eplb:
            return False
        if not moe_parallel_config.use_all2all_kernels:
            return True
        return moe_parallel_config.use_naive_all2all_kernels

    @staticmethod
    def _supports_router_logits_dtype(
        router_logits_dtype: torch.dtype | None,
        routing_method: RoutingMethodType,
    ) -> bool:
        # No restriction on router-logits dtype.
        return True

    def supports_chunking(self) -> bool:
        return False

    def supports_expert_map(self) -> bool:
        return False

    @property
    def expects_unquantized_inputs(self) -> bool:
        return True

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        # Imported lazily: flashinfer is an optional dependency and this
        # class is only instantiated once the backend oracle selected it.
        import flashinfer

        # This rank owns a contiguous slice of experts starting here.
        local_expert_offset = self.ep_rank * self.local_num_experts

        return flashinfer.fused_moe.trtllm_bf16_moe(
            routing_logits=router_logits,
            routing_bias=e_score_correction_bias,
            hidden_states=hidden_states,
            gemm1_weights=w1,
            gemm2_weights=w2,
            num_experts=global_num_experts,
            top_k=self.topk,
            n_group=num_expert_group,
            topk_group=topk_group,
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=local_expert_offset,
            local_num_experts=self.local_num_experts,
            routing_method_type=self.routing_method_type,
        )
Loading