diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 43bdd03cfe13..3a68ddfa50d4 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -1675,7 +1675,7 @@ def test_unquantized_bf16_flashinfer_trtllm_backend( intermediate_size_per_partition=n, num_local_experts=e, num_logical_experts=e, - activation="silu", + activation=MoEActivation.SILU, device="cuda", moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), in_dtype=dtype, @@ -1706,13 +1706,25 @@ def test_unquantized_bf16_flashinfer_trtllm_backend( layer.topk_group = 1 layer.intermediate_size_per_partition = n layer.ep_rank = 0 - layer.activation = "silu" + layer.activation = MoEActivation.SILU layer.e_score_correction_bias = None layer.routing_method_type = RoutingMethodType.Renormalize + layer.expert_map = None + layer.apply_router_weight_on_input = False + layer.routed_scaling_factor = None + layer.shared_experts = None + layer._maybe_init_expert_routing_tables = lambda: None quant_method.process_weights_after_loading(layer) - trtllm_output = quant_method.forward_monolithic_cuda( + assert quant_method.moe_kernel is not None, ( + "moe_kernel should be set after process_weights_after_loading" + ) + assert quant_method.supports_internal_mk, ( + "supports_internal_mk should be True after setup" + ) + + trtllm_output = quant_method.apply_monolithic( layer=layer, x=a, router_logits=router_logits, diff --git a/tests/kernels/moe/test_unquantized_backend_selection.py b/tests/kernels/moe/test_unquantized_backend_selection.py index bf5a547fe3df..c9aad044f4a0 100644 --- a/tests/kernels/moe/test_unquantized_backend_selection.py +++ b/tests/kernels/moe/test_unquantized_backend_selection.py @@ -19,12 +19,10 @@ ("is_rocm", UnquantizedMoeBackend.TRITON), ("is_cpu", UnquantizedMoeBackend.CPU), ("is_xpu", UnquantizedMoeBackend.XPU), - ("is_tpu", UnquantizedMoeBackend.TPU), - ("is_out_of_tree", UnquantizedMoeBackend.OOT), ], ) @patch( - "vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer", + "vllm.utils.flashinfer.has_flashinfer", return_value=False, ) def test_select_default_backend_by_platform( @@ -34,36 +32,34 @@ def test_select_default_backend_by_platform( expected_backend, ): """Test backend selection for different platforms.""" - with patch( - "vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform" - ) as mock_platform: - # Set all platform checks to False - mock_platform.is_cuda.return_value = False - mock_platform.is_rocm.return_value = False - mock_platform.is_cpu.return_value = False - mock_platform.is_xpu.return_value = False - mock_platform.is_tpu.return_value = False - mock_platform.is_out_of_tree.return_value = False - - # Set only the specified platform to True - getattr(mock_platform, platform_method).return_value = True + with ( + patch.object(current_platform, "is_cuda", return_value=False), + patch.object(current_platform, "is_rocm", return_value=False), + patch.object(current_platform, "is_cpu", return_value=False), + patch.object(current_platform, "is_xpu", return_value=False), + patch.object(current_platform, "is_tpu", return_value=False), + patch.object(current_platform, "is_out_of_tree", return_value=False), + patch.object(current_platform, platform_method, return_value=True), + ): moe_config = make_dummy_moe_config() - selected_backend = select_unquantized_moe_backend( - moe_config=moe_config, - use_ep=False, - use_dp=False, + selected_backend, expert_cls = select_unquantized_moe_backend( + moe_config=moe_config ) assert selected_backend == expected_backend + if expected_backend == UnquantizedMoeBackend.CPU: + assert expert_cls is None + else: + assert expert_cls is not None @patch( - "vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer", + "vllm.utils.flashinfer.has_flashinfer", return_value=True, ) @patch( - "vllm.model_executor.layers.fused_moe.oracle.unquantized.is_supported_config_trtllm_bf16", + "vllm.model_executor.layers.fused_moe.experts.trtllm_bf16_moe.TrtLlmBf16Experts.is_supported_config", return_value=(True, None), ) @pytest.mark.skipif( @@ -73,67 +69,73 @@ def test_select_cuda_flashinfer_trtllm_backend( mock_has_flashinfer, mock_is_supported_trtllm, monkeypatch ): """Test CUDA backend selection when FlashInfer TRTLLM is available and enabled.""" - with patch( - "vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform" - ) as mock_platform: - # Set as CUDA platform - mock_platform.is_cuda.return_value = True - mock_platform.is_rocm.return_value = False - mock_platform.is_cpu.return_value = False - mock_platform.is_xpu.return_value = False - mock_platform.is_tpu.return_value = False - mock_platform.is_out_of_tree.return_value = False - + with ( + patch.object(current_platform, "is_cuda", return_value=True), + patch.object(current_platform, "is_rocm", return_value=False), + patch.object(current_platform, "is_cpu", return_value=False), + patch.object(current_platform, "is_xpu", return_value=False), + patch.object(current_platform, "is_tpu", return_value=False), + patch.object(current_platform, "is_out_of_tree", return_value=False), + patch.object(current_platform, "has_device_capability", return_value=True), + ): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1") moe_config = make_dummy_moe_config() + # TRTLLM requires EP and does not support DP + moe_config.moe_parallel_config.use_ep = True + moe_config.moe_parallel_config.use_dp = False - selected_backend = select_unquantized_moe_backend( - moe_config=moe_config, - use_ep=True, - use_dp=False, + selected_backend, experts_cls = select_unquantized_moe_backend( + moe_config=moe_config ) assert selected_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM + assert experts_cls is not None @patch( - "vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer", + "vllm.utils.flashinfer.has_flashinfer", return_value=True, ) @patch( - "vllm.model_executor.layers.fused_moe.oracle.unquantized.is_supported_config_trtllm_bf16", + "vllm.model_executor.layers.fused_moe.experts.trtllm_bf16_moe.TrtLlmBf16Experts.is_supported_config", return_value=(False, None), ) +@patch( + "vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts.is_supported_config", + return_value=(True, None), +) @pytest.mark.skipif( not current_platform.is_cuda(), reason="Only supported on NVIDIA platforms." ) def test_select_cuda_flashinfer_cutlass_backend( - mock_has_flashinfer, mock_is_supported_trtllm, monkeypatch + mock_has_flashinfer, + mock_is_supported_trtllm, + mock_is_supported_cutlass, + monkeypatch, ): """Test CUDA backend selection when FlashInfer TRTLLM is not available and FlashInfer CUTLASS is available.""" - with patch( - "vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform" - ) as mock_platform: - # Set as CUDA platform with Hopper capability - mock_platform.is_cuda.return_value = True - mock_platform.is_rocm.return_value = False - mock_platform.is_cpu.return_value = False - mock_platform.is_xpu.return_value = False - mock_platform.is_tpu.return_value = False - mock_platform.is_out_of_tree.return_value = False - mock_platform.has_device_capability.return_value = True # SM90+ - + with ( + patch.object(current_platform, "is_cuda", return_value=True), + patch.object(current_platform, "is_rocm", return_value=False), + patch.object(current_platform, "is_cpu", return_value=False), + patch.object(current_platform, "is_xpu", return_value=False), + patch.object(current_platform, "is_tpu", return_value=False), + patch.object(current_platform, "is_out_of_tree", return_value=False), + patch.object(current_platform, "has_device_capability", return_value=True), + ): # Enable FlashInfer via env var monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1") moe_config = make_dummy_moe_config() + # CUTLASS requires EP and does not support DP + moe_config.moe_parallel_config.use_ep = True + moe_config.moe_parallel_config.use_dp = False - selected_backend = select_unquantized_moe_backend( - moe_config=moe_config, - use_ep=True, # CUTLASS requires EP - use_dp=False, # CUTLASS doesn't support DP + selected_backend, experts_cls = select_unquantized_moe_backend( + moe_config=moe_config ) assert selected_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS + assert experts_cls is not None diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index eff05b575856..38a57adefe12 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -49,6 +49,9 @@ def __init__(self, base_layer: FusedMoE) -> None: assert not self.base_layer.use_ep, ( "EP support for Fused MoE LoRA is not implemented yet." ) + assert not self.base_layer.quant_method.is_monolithic, ( + "Monolithic kernels are not supported for Fused MoE LoRA." + ) self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() self.device = _get_lora_device(base_layer) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_bf16_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_bf16_moe.py new file mode 100644 index 000000000000..461073a31da4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_bf16_moe.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, + RoutingMethodType, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) +from vllm.platforms import current_platform + + +class TrtLlmBf16Experts(mk.FusedMoEExpertsMonolithic): + """ + BF16 unquantized TRTLLM-Gen MoE kernels. Supports monolithic interface. + """ + + def __init__( + self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + ): + super().__init__(moe_config, quant_config) + self.routing_method_type = moe_config.routing_method + self.topk = moe_config.experts_per_token + self.intermediate_size_per_partition = ( + moe_config.intermediate_size_per_partition + ) + self.hidden_dim = moe_config.hidden_dim + self.local_num_experts = moe_config.num_local_experts + self.ep_rank = moe_config.moe_parallel_config.ep_rank + + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + """Supports only Blackwell-family GPUs.""" + p = current_platform + return p.is_cuda() and p.is_device_capability_family(100) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + """BF16 kernels do not support non-gated MoE""" + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """Supports only unquantized inputs.""" + return weight_key is None and activation_key is None + + @staticmethod + def _supports_activation(activation: MoEActivation) -> bool: + return activation in [MoEActivation.SILU] + + @staticmethod + def _supports_routing_method( + routing_method: RoutingMethodType, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + return routing_method in [ + RoutingMethodType.Default, + RoutingMethodType.Renormalize, + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Llama4, + RoutingMethodType.RenormalizeNaive, + ] + + @staticmethod + def _supports_parallel_config( + moe_parallel_config: FusedMoEParallelConfig, + ) -> bool: + """Monolithic kernel so only use with naive DP/EP and TP.""" + return ( + not moe_parallel_config.use_all2all_kernels + or moe_parallel_config.use_naive_all2all_kernels + ) and not moe_parallel_config.enable_eplb + + @staticmethod + def _supports_router_logits_dtype( + router_logits_dtype: torch.dtype | None, + routing_method: RoutingMethodType, + ) -> bool: + return True + + def supports_chunking(self) -> bool: + return False + + def supports_expert_map(self) -> bool: + return False + + @property + def expects_unquantized_inputs(self) -> bool: + return True + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + apply_router_weight_on_input: bool, + num_expert_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routed_scaling_factor: float | None = None, + topk_group: int | None = None, + ) -> torch.Tensor: + import flashinfer + + return flashinfer.fused_moe.trtllm_bf16_moe( + routing_logits=router_logits, + routing_bias=e_score_correction_bias, + hidden_states=hidden_states, + gemm1_weights=w1, + gemm2_weights=w2, + num_experts=global_num_experts, + top_k=self.topk, + n_group=num_expert_group, + topk_group=topk_group, + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.ep_rank * self.local_num_experts, + local_num_experts=self.local_num_experts, + routing_method_type=self.routing_method_type, + ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py deleted file mode 100644 index d04e040c8959..000000000000 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ /dev/null @@ -1,141 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.activation import MoEActivation -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, - FusedMoEParallelConfig, - RoutingMethodType, -) -from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op - -# -# Methods used by the oracle for kernel selection. -# - - -def _supports_current_device() -> bool: - """Supports only Blackwell-family GPUs.""" - p = current_platform - return p.is_cuda() and p.is_device_capability_family(100) - - -def _supports_no_act_and_mul() -> bool: - """BF16 kernels do not support non-gated MoE""" - return False - - -def _supports_activation(activation: MoEActivation) -> bool: - return activation in [MoEActivation.SILU] - - -def _supports_routing_method_bf16( - routing_method: RoutingMethodType, -) -> bool: - return routing_method in [ - RoutingMethodType.Default, - RoutingMethodType.Renormalize, - RoutingMethodType.DeepSeekV3, - RoutingMethodType.Llama4, - RoutingMethodType.RenormalizeNaive, - ] - - -def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - """Supports TRTLLM Kernel does not support EPLB.""" - return not moe_parallel_config.enable_eplb - - -def is_supported_config_trtllm_bf16( - moe_config: FusedMoEConfig, - activation_format: mk.FusedMoEActivationFormat, -) -> tuple[bool, str | None]: - """ - This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config - for BF16 unquantized kernels. - """ - - def _make_reason(reason: str) -> str: - return f"kernel does not support {reason}" - - if not _supports_current_device(): - return False, _make_reason(f"current device {current_platform.device_name}") - elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()): - return False, _make_reason("no act_and_mul MLP layer") - elif not _supports_activation(moe_config.activation): - return False, _make_reason(f"{moe_config.activation} activation") - elif not _supports_parallel_config(moe_config.moe_parallel_config): - return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}") - elif not _supports_routing_method_bf16(moe_config.routing_method): - return False, _make_reason(f"routing method {moe_config.routing_method}") - elif activation_format != mk.FusedMoEActivationFormat.Standard: - return False, _make_reason(f"activation format {activation_format}") - - return True, None - - -def flashinfer_fused_moe_bf16( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor | None, - hidden_states: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor, - num_experts: int, - top_k: int, - n_group: int | None, - topk_group: int | None, - intermediate_size: int, - local_expert_offset: int, - local_num_experts: int, - routing_method_type: int, - tune_max_num_tokens: int = 8192, -) -> torch.Tensor: - from vllm.utils.flashinfer import flashinfer_trtllm_bf16_moe - - return flashinfer_trtllm_bf16_moe( - routing_logits=routing_logits, - routing_bias=routing_bias, - hidden_states=hidden_states, - gemm1_weights=gemm1_weights, - gemm2_weights=gemm2_weights, - num_experts=num_experts, - top_k=top_k, - n_group=n_group, - topk_group=topk_group, - intermediate_size=intermediate_size, - local_expert_offset=local_expert_offset, - local_num_experts=local_num_experts, - routing_method_type=routing_method_type, - tune_max_num_tokens=tune_max_num_tokens, - ) - - -def flashinfer_fused_moe_bf16_fake( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor | None, - hidden_states: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor, - num_experts: int, - top_k: int, - n_group: int | None, - topk_group: int | None, - intermediate_size: int, - local_expert_offset: int, - local_num_experts: int, - routing_method_type: int = RoutingMethodType.Renormalize, - tune_max_num_tokens: int = 8192, -) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -direct_register_custom_op( - op_name="flashinfer_fused_moe_bf16", - op_func=flashinfer_fused_moe_bf16, - fake_impl=flashinfer_fused_moe_bf16_fake, - tags=(torch.Tag.needs_fixed_stride_order,), -) diff --git a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py index 9c31da10dd94..a84274ff42a7 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py +++ b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py @@ -11,21 +11,20 @@ from vllm._aiter_ops import rocm_aiter_ops from vllm.config.kernel import MoEBackend from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, ) -from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( - is_supported_config_trtllm_bf16, -) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoDPEPModular, -) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + FlashinferMoeBackend, + convert_moe_weights_to_flashinfer_trtllm_block_layout, + get_flashinfer_moe_backend, swap_w13_to_w31, ) from vllm.platforms import current_platform -from vllm.utils.flashinfer import has_flashinfer, has_flashinfer_cutlass_fused_moe logger = init_logger(__name__) @@ -35,31 +34,106 @@ class UnquantizedMoeBackend(Enum): FLASHINFER_CUTLASS = "FlashInfer CUTLASS" AITER = "ROCm AITER" TRITON = "TRITON" + BATCHED_TRITON = "BATCHED_TRITON" CPU = "CPU" XPU = "XPU" - TPU = "TPU" - OOT = "OOT" -# NOTE(zyongye): Unsupported backend means backend -# that is not conform with Modular kernel format. -# We will directly call the kernel for those backend -UNSUPPORTED_BACKEND = [ - UnquantizedMoeBackend.FLASHINFER_TRTLLM, - UnquantizedMoeBackend.CPU, - UnquantizedMoeBackend.TPU, - UnquantizedMoeBackend.OOT, -] +def _get_priority_backends( + moe_config: FusedMoEConfig, + activation_format: mk.FusedMoEActivationFormat, +) -> list[UnquantizedMoeBackend]: + """ + Get available backends in priority order based on platform and config. + + This function can be extended to become more complex as needed. + """ + + _AVAILABLE_BACKENDS = [] + if activation_format == mk.FusedMoEActivationFormat.Standard: + if current_platform.is_rocm(): + _AVAILABLE_BACKENDS = [ + UnquantizedMoeBackend.AITER, + UnquantizedMoeBackend.TRITON, + ] + elif current_platform.is_cuda(): + _AVAILABLE_BACKENDS = [ + UnquantizedMoeBackend.FLASHINFER_TRTLLM, + UnquantizedMoeBackend.FLASHINFER_CUTLASS, + UnquantizedMoeBackend.TRITON, + ] + elif current_platform.is_xpu(): + _AVAILABLE_BACKENDS = [UnquantizedMoeBackend.XPU] + elif current_platform.is_cpu(): + _AVAILABLE_BACKENDS = [UnquantizedMoeBackend.CPU] + else: + if current_platform.is_cuda_alike(): + _AVAILABLE_BACKENDS = [ + UnquantizedMoeBackend.BATCHED_TRITON, + ] + return _AVAILABLE_BACKENDS + + +def backend_to_kernel_cls( + backend: UnquantizedMoeBackend, +) -> type[mk.FusedMoEExperts]: + if backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM: + from vllm.model_executor.layers.fused_moe.experts.trtllm_bf16_moe import ( + TrtLlmBf16Experts, + ) + + return TrtLlmBf16Experts + + elif backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS: + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, + ) + return FlashInferExperts + + elif backend == UnquantizedMoeBackend.AITER: + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + AiterExperts, + ) -def map_unquantized_backend(runner_backend: MoEBackend) -> UnquantizedMoeBackend: + return AiterExperts + + elif backend == UnquantizedMoeBackend.TRITON: + from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts + + return TritonExperts + + elif backend == UnquantizedMoeBackend.BATCHED_TRITON: + from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts, + ) + + return BatchedTritonExperts + + elif backend == UnquantizedMoeBackend.XPU: + from vllm.model_executor.layers.fused_moe.xpu_fused_moe import XPUExperts + + return XPUExperts + + else: + raise ValueError(f"Unknown unquantized MoE backend: {backend.value}") + + +def map_unquantized_backend( + runner_backend: MoEBackend, activation_format: mk.FusedMoEActivationFormat +) -> UnquantizedMoeBackend: """Map user's MoEBackend to UnquantizedMoeBackend.""" - mapping = { - "triton": UnquantizedMoeBackend.TRITON, - "flashinfer_trtllm": UnquantizedMoeBackend.FLASHINFER_TRTLLM, - "flashinfer_cutlass": UnquantizedMoeBackend.FLASHINFER_CUTLASS, - "aiter": UnquantizedMoeBackend.AITER, - } + if activation_format == mk.FusedMoEActivationFormat.Standard: + mapping = { + "triton": UnquantizedMoeBackend.TRITON, + "flashinfer_trtllm": UnquantizedMoeBackend.FLASHINFER_TRTLLM, + "flashinfer_cutlass": UnquantizedMoeBackend.FLASHINFER_CUTLASS, + "aiter": UnquantizedMoeBackend.AITER, + } + else: + mapping = { + "triton": UnquantizedMoeBackend.BATCHED_TRITON, + } if backend := mapping.get(runner_backend): return backend raise ValueError( @@ -70,16 +144,21 @@ def map_unquantized_backend(runner_backend: MoEBackend) -> UnquantizedMoeBackend def select_unquantized_moe_backend( moe_config: FusedMoEConfig, - use_ep: bool, - use_dp: bool, -) -> UnquantizedMoeBackend: +) -> tuple[UnquantizedMoeBackend, type[mk.FusedMoEExperts] | None]: """ - Select the primary Unquantized MoE backend + Select the primary Unquantized MoE backend. Note: Shape-specific fallbacks may still occur at runtime. """ - def _make_log_backend(backend: UnquantizedMoeBackend): - return f"Using {backend.value} backend for Unquantized MoE" + if current_platform.is_cpu(): + # Escape hatch for CPU backend, which is not yet supported by the oracle + # TODO(yzong): migrate CPU backend to FusedMoEExpertsMonolithic + return UnquantizedMoeBackend.CPU, None + + if current_platform.is_tpu() or current_platform.is_out_of_tree(): + raise RuntimeError( + "Unquantized MoE oracle does not support TPU or OOT platforms." + ) activation_format = ( mk.FusedMoEActivationFormat.BatchedExperts @@ -87,104 +166,105 @@ def _make_log_backend(backend: UnquantizedMoeBackend): else mk.FusedMoEActivationFormat.Standard ) - # Check if FlashInfer TRTLLM BF16 MoE is supported - trtllm_supported, _ = is_supported_config_trtllm_bf16( - moe_config=moe_config, - activation_format=activation_format, - ) - flashinfer_trtllm_available = has_flashinfer() and trtllm_supported - # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS - flashinfer_cutlass_available = ( - has_flashinfer_cutlass_fused_moe() - and use_ep - and (not use_dp) - and current_platform.has_device_capability(90) - ) - flashinfer_trtllm_moe_enabled = ( - flashinfer_trtllm_available - and envs.VLLM_USE_FLASHINFER_MOE_FP16 - and envs.VLLM_FLASHINFER_MOE_BACKEND == "latency" - ) - flashinfer_cutlass_moe_enabled = ( - flashinfer_cutlass_available and envs.VLLM_USE_FLASHINFER_MOE_FP16 - ) - rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + AVAILABLE_BACKENDS = _get_priority_backends(moe_config, activation_format) + + def _make_log_backend(backend: UnquantizedMoeBackend) -> str: + available_strs = [b.value for b in AVAILABLE_BACKENDS] + return ( + f"Using {backend.value} Unquantized MoE backend out " + f"of potential backends: {available_strs}." + ) + + def _make_log_unsupported( + backend: UnquantizedMoeBackend, reason: str | None + ) -> str: + if reason: + return ( + f"Unquantized MoE backend {backend.value} does not support the " + f"deployment configuration since {reason}." + ) + return ( + f"Unquantized MoE backend '{backend.value}' does not support the " + "deployment configuration." + ) + + def _return_or_raise( + backend: UnquantizedMoeBackend, + config: FusedMoEConfig, + activation_format: mk.FusedMoEActivationFormat, + ) -> tuple[UnquantizedMoeBackend, type[mk.FusedMoEExperts] | None]: + k_cls = backend_to_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, config, None, None, activation_format + ) + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + raise ValueError(_make_log_unsupported(backend, reason)) - # Handle explicit moe_backend from user. runner_backend = moe_config.moe_backend if runner_backend != "auto": - requested_backend = map_unquantized_backend(runner_backend) - if requested_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM: - if not flashinfer_trtllm_available: + requested_backend = map_unquantized_backend(runner_backend, activation_format) + return _return_or_raise(requested_backend, moe_config, activation_format) + + # Handle explicit FlashInfer FP16 configuration. + if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP16"): + if not envs.VLLM_USE_FLASHINFER_MOE_FP16: + if UnquantizedMoeBackend.FLASHINFER_TRTLLM in AVAILABLE_BACKENDS: + AVAILABLE_BACKENDS.remove(UnquantizedMoeBackend.FLASHINFER_TRTLLM) + if UnquantizedMoeBackend.FLASHINFER_CUTLASS in AVAILABLE_BACKENDS: + AVAILABLE_BACKENDS.remove(UnquantizedMoeBackend.FLASHINFER_CUTLASS) + elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): + # If user is explicit about backend, validate it. + fi_backend = get_flashinfer_moe_backend() + if fi_backend == FlashinferMoeBackend.CUTLASS: + backend = UnquantizedMoeBackend.FLASHINFER_CUTLASS + elif fi_backend == FlashinferMoeBackend.TENSORRT_LLM: + backend = UnquantizedMoeBackend.FLASHINFER_TRTLLM + else: raise ValueError( - "FlashInfer TRTLLM MoE backend is not available for this " - "configuration." + f"FlashInfer MOE backend {fi_backend} " + "does not support unquantized MoE." ) - elif requested_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS: - if not flashinfer_cutlass_available: - raise ValueError( - "FlashInfer CUTLASS MoE backend is not available for this " - "configuration." + k_cls = backend_to_kernel_cls(backend) + return _return_or_raise(backend, moe_config, activation_format) + else: + # If the user is not explicit about the backend, try both. + for backend in [ + UnquantizedMoeBackend.FLASHINFER_TRTLLM, + UnquantizedMoeBackend.FLASHINFER_CUTLASS, + ]: + k_cls = backend_to_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, moe_config, None, None, activation_format ) - elif requested_backend == UnquantizedMoeBackend.AITER and not ( - current_platform.is_rocm() and rocm_aiter_moe_enabled - ): - raise ValueError( - "ROCm AITer MoE backend is not available for this configuration." + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + else: + logger.debug_once( + _make_log_unsupported(backend, reason), scope="local" + ) + + raise NotImplementedError( + "Found VLLM_USE_FLASHINFER_MOE_FP16=1, but no " + "FlashInfer unquantized MoE backend supports the configuration." ) - logger.info_once(_make_log_backend(requested_backend), scope="local") - return requested_backend - if current_platform.is_rocm(): - if rocm_aiter_moe_enabled: - backend = UnquantizedMoeBackend.AITER - else: - backend = UnquantizedMoeBackend.TRITON - if current_platform.is_cuda(): - if flashinfer_trtllm_moe_enabled: - backend = UnquantizedMoeBackend.FLASHINFER_TRTLLM - elif flashinfer_cutlass_moe_enabled: - backend = UnquantizedMoeBackend.FLASHINFER_CUTLASS - if trtllm_supported: - logger.info_once( - "FlashInfer TRTLLM MoE is available but not enabled, " - "consider setting VLLM_FLASHINFER_MOE_BACKEND=latency " - "to enable it for better performance.", - scope="local", - ) - else: - if not envs.VLLM_USE_FLASHINFER_MOE_FP16 and trtllm_supported: - logger.info_once( - "FlashInfer TRTLLM MoE is available but not enabled, " - "consider setting VLLM_USE_FLASHINFER_MOE_FP16=1 " - "and VLLM_FLASHINFER_MOE_BACKEND=latency " - "to enable it for better performance.", - scope="local", - ) - elif use_ep and (not use_dp): - logger.info_once( - "FlashInfer MoE is available for EP" - " but not enabled, consider setting" - " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.", - scope="local", - ) - elif use_dp: - logger.info_once( - "FlashInfer CUTLASS MoE is currently not available for DP.", - scope="local", - ) - backend = UnquantizedMoeBackend.TRITON - if current_platform.is_xpu(): - backend = UnquantizedMoeBackend.XPU - if current_platform.is_cpu(): - backend = UnquantizedMoeBackend.CPU - if current_platform.is_tpu(): - backend = UnquantizedMoeBackend.TPU - if current_platform.is_out_of_tree(): - backend = UnquantizedMoeBackend.OOT + for backend in AVAILABLE_BACKENDS: + k_cls = backend_to_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, moe_config, None, None, activation_format + ) + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls - logger.info_once(_make_log_backend(backend), scope="local") - return backend + logger.debug_once(_make_log_unsupported(backend, reason), scope="local") + + raise NotImplementedError( + "No unquantized MoE backend supports the deployment configuration." + ) def convert_to_unquantized_kernel_format( @@ -194,72 +274,78 @@ def convert_to_unquantized_kernel_format( w2_weight: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if unquantized_backend == UnquantizedMoeBackend.AITER: - w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights( - layer.w13_weight.data, layer.w2_weight.data - ) + w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(w13_weight, w2_weight) elif unquantized_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS: # Swap halves to arrange as [w3; w1] (kernel expectation) - w13_weight = swap_w13_to_w31(layer.w13_weight.data) + w13_weight = swap_w13_to_w31(w13_weight) + + elif unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM: + # Swap halves to arrange as [w3; w1] (kernel expectation) + w13_weight = swap_w13_to_w31(w13_weight) + _cache_permute_indices: dict[torch.Size, torch.Tensor] = {} + w13_weight, w2_weight = convert_moe_weights_to_flashinfer_trtllm_block_layout( + _cache_permute_indices, + w13_weight, + w2_weight, + ) return w13_weight, w2_weight def make_unquantized_moe_kernel( - backend: UnquantizedMoeBackend, quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, -) -> mk.FusedMoEKernel | None: - if backend in UNSUPPORTED_BACKEND: - return None + backend: UnquantizedMoeBackend, + experts_cls: type[mk.FusedMoEExperts], + routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + shared_experts: torch.nn.Module | None = None, +) -> mk.FusedMoEKernel: + # Create Prepare/Finalize + is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic) + prepare_finalize = maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + routing_tables=routing_tables, + allow_new_interface=True, + use_monolithic=is_monolithic, + ) + assert prepare_finalize is not None - if backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS: - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, - ) + logger.info_once("Using %s", prepare_finalize.__class__.__name__, scope="local") - kernel = mk.FusedMoEKernel( - MoEPrepareAndFinalizeNoDPEPModular(), - FlashInferExperts( - moe_config=moe_config, - quant_config=quant_config, - ), - inplace=False, + # Create Experts + if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: + max_num_tokens = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens is not None + experts = experts_cls( + moe_config=moe_config, + quant_config=quant_config, + max_num_tokens=max_num_tokens, + num_dispatchers=prepare_finalize.num_dispatchers(), ) - - elif backend == UnquantizedMoeBackend.AITER: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - AiterExperts, + else: + experts = experts_cls( + moe_config=moe_config, + quant_config=quant_config, ) - kernel = mk.FusedMoEKernel( - MoEPrepareAndFinalizeNoDPEPModular(), - AiterExperts( - moe_config=moe_config, - quant_config=quant_config, - ), - inplace=not moe_config.disable_inplace, - ) - elif backend == UnquantizedMoeBackend.TRITON: - from vllm.model_executor.layers.fused_moe import TritonExperts - - kernel = mk.FusedMoEKernel( - MoEPrepareAndFinalizeNoDPEPModular(), - TritonExperts( - moe_config=moe_config, - quant_config=quant_config, - ), - inplace=not moe_config.disable_inplace, - ) - elif backend == UnquantizedMoeBackend.XPU: - from vllm.model_executor.layers.fused_moe import XPUExperts - - kernel = mk.FusedMoEKernel( - MoEPrepareAndFinalizeNoDPEPModular(), - XPUExperts( - moe_config=moe_config, - quant_config=quant_config, - ), - inplace=not moe_config.disable_inplace, - ) + kernel = mk.FusedMoEKernel( + prepare_finalize, + experts, + shared_experts=( + shared_experts + if ( + moe_config.moe_parallel_config.use_deepep_ll_kernels + and not is_monolithic + ) + else None + ), + moe_parallel_config=moe_config.moe_parallel_config, + inplace=( + not moe_config.disable_inplace + and backend != UnquantizedMoeBackend.FLASHINFER_CUTLASS + ), + ) + return kernel diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index a29d8a7d8dda..42df8d6d768e 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -6,11 +6,8 @@ import torch import torch.nn.functional as F from torch.nn import Module -from torch.nn.parameter import Parameter import vllm.envs as envs -import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe.config import ( @@ -23,7 +20,6 @@ FusedMoEMethodBase, ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEActivationFormat, FusedMoEExpertsModular, FusedMoEPrepareAndFinalizeModular, ) @@ -33,20 +29,10 @@ make_unquantized_moe_kernel, select_unquantized_moe_backend, ) -from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - convert_moe_weights_to_flashinfer_trtllm_block_layout, -) from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -if current_platform.is_cuda_alike() or current_platform.is_xpu(): - from .fused_batched_moe import BatchedTritonExperts - from .fused_moe import TritonExperts -else: - TritonExperts = None # type: ignore - - logger = init_logger(__name__) @@ -59,46 +45,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) - self.unquantized_backend = select_unquantized_moe_backend( + self.unquantized_backend, self.experts_cls = select_unquantized_moe_backend( moe_config=self.moe, - use_ep=self.moe.moe_parallel_config.use_ep, - use_dp=self.moe.moe_parallel_config.dp_size > 1, - ) - - # AITER only supports gated activations (silu/gelu), so disable it - # for non-gated MoE (is_act_and_mul=False) - self.rocm_aiter_moe_enabled = ( - rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul ) - self.kernel: mk.FusedMoEKernel | None = None - self._is_monolithic = ( - current_platform.is_cpu() - or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM - ) - - if self.is_monolithic: - self.apply_monolithic: Callable = self._select_monolithic() - - def _select_monolithic(self) -> Callable: - """Select the monolithic implementation based on platform.""" - if current_platform.is_cpu(): - return self.forward_monolithic_cpu - else: - return self.forward_monolithic_cuda - - def forward_native( - self, - layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 - x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - return self.forward_cuda(layer, x, topk_weights, topk_ids, shared_experts_input) @property def is_monolithic(self) -> bool: - return self._is_monolithic + # Escape hatch for CPU, which stays on the old monolithic path. + if self.unquantized_backend == UnquantizedMoeBackend.CPU: + return True + return super().is_monolithic @property def supports_eplb(self) -> bool: @@ -107,35 +63,22 @@ def supports_eplb(self) -> bool: def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> FusedMoEPrepareAndFinalizeModular | None: - if self.unquantized_backend == UnquantizedMoeBackend.AITER: - return None - else: - return super().maybe_make_prepare_finalize(routing_tables) + ): + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic for all but the CPU backend. CPU backend is monolithic. " + "So this function should not be called." + ) def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, ) -> FusedMoEExpertsModular: - assert self.moe_quant_config is not None - if ( - prepare_finalize.activation_format - == FusedMoEActivationFormat.BatchedExperts - ): - logger.debug("BatchedTritonExperts %s", self.moe) - return BatchedTritonExperts( - moe_config=self.moe, - quant_config=self.moe_quant_config, - max_num_tokens=self.moe.max_num_tokens, - num_dispatchers=prepare_finalize.num_dispatchers(), - ) - else: - logger.debug("TritonExperts %s", self.moe) - return TritonExperts( - moe_config=self.moe, - quant_config=self.moe_quant_config, - ) + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." + ) def create_weights( self, @@ -210,6 +153,8 @@ def _setup_kernel( w13: torch.Tensor, w2: torch.Tensor, ) -> None: + assert self.unquantized_backend != UnquantizedMoeBackend.CPU + # Shuffle weights to runtime format. w13, w2 = convert_to_unquantized_kernel_format( self.unquantized_backend, @@ -220,14 +165,17 @@ def _setup_kernel( replace_parameter(layer, "w13_weight", w13) replace_parameter(layer, "w2_weight", w2) - # Setup Modular Kernel for TP Case + # Setup modular kernels self.moe_quant_config = self.get_fused_moe_quant_config(layer) assert self.moe_quant_config is not None - - self.kernel = make_unquantized_moe_kernel( - backend=self.unquantized_backend, + assert self.experts_cls is not None + self.moe_kernel = make_unquantized_moe_kernel( quant_config=self.moe_quant_config, moe_config=self.moe, + backend=self.unquantized_backend, + experts_cls=self.experts_cls, + routing_tables=layer._maybe_init_expert_routing_tables(), + shared_experts=layer.shared_experts, ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -237,22 +185,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) - if self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM: - _cache_permute_indices: dict[torch.Size, torch.Tensor] = {} - # Swap halves to arrange as [w3; w1] (kernel expectation) - w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1) - w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1) - layer.w13_weight.data = w13_weight_swapped.contiguous() - w13_weights_shuffled, w2_weights_shuffled = ( - convert_moe_weights_to_flashinfer_trtllm_block_layout( - _cache_permute_indices, - layer.w13_weight.data, - layer.w2_weight.data, - ) - ) - layer.w13_weight = Parameter(w13_weights_shuffled, requires_grad=False) - layer.w2_weight = Parameter(w2_weights_shuffled, requires_grad=False) - elif self.unquantized_backend == UnquantizedMoeBackend.CPU: + if self.unquantized_backend == UnquantizedMoeBackend.CPU: + # CPU stays on the old path — no oracle, no moe_kernel. from vllm.model_executor.layers.fused_moe import cpu_fused_moe if current_platform.get_cpu_architecture() == CpuArchEnum.X86: @@ -283,13 +217,22 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) else: self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) - elif current_platform.is_cuda_alike() or current_platform.is_xpu(): + else: self._setup_kernel( layer=layer, w13=layer.w13_weight, w2=layer.w2_weight, ) + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: + if self.moe.has_bias: + return biased_moe_quant_config( + layer.w13_bias, + layer.w2_bias, + ) + else: + return FUSED_MOE_UNQUANTIZED_CONFIG + def apply( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 @@ -306,16 +249,7 @@ def apply( shared_experts_input=shared_experts_input, ) - def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: - if self.moe.has_bias: - return biased_moe_quant_config( - layer.w13_bias, - layer.w2_bias, - ) - else: - return FUSED_MOE_UNQUANTIZED_CONFIG - - def forward_cuda( + def forward_native( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, @@ -323,9 +257,8 @@ def forward_cuda( topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.kernel is not None - - return self.kernel.apply( + assert self.moe_kernel is not None + return self.moe_kernel.apply( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -338,53 +271,58 @@ def forward_cuda( shared_experts_input=shared_experts_input, ) - def forward_monolithic_cuda( + def forward_cuda( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: F401 - - assert self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM - - return torch.ops.vllm.flashinfer_fused_moe_bf16( - routing_logits=router_logits, - routing_bias=layer.e_score_correction_bias, - hidden_states=x, - gemm1_weights=layer.w13_weight, - gemm2_weights=layer.w2_weight, - num_experts=layer.global_num_experts, - top_k=layer.top_k, - n_group=layer.num_expert_group, - topk_group=layer.topk_group, - intermediate_size=layer.intermediate_size_per_partition, - local_expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - routing_method_type=layer.routing_method_type, + return self.forward_native( + layer, x, topk_weights, topk_ids, shared_experts_input ) - def forward_monolithic_cpu( + def apply_monolithic( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - return self.cpu_fused_moe( - layer, - x, - layer.use_grouped_topk, - layer.top_k, - router_logits, - layer.renormalize, - layer.topk_group, - layer.num_expert_group, - layer.global_num_experts, - layer.expert_map, - layer.custom_routing_function, - layer.scoring_func, - layer.routed_scaling_factor, - layer.e_score_correction_bias, - layer.apply_router_weight_on_input, - layer.activation, - ) + assert self.is_monolithic + if self.unquantized_backend == UnquantizedMoeBackend.CPU: + assert self.moe_kernel is None + return self.cpu_fused_moe( + layer, + x, + layer.use_grouped_topk, + layer.top_k, + router_logits, + layer.renormalize, + layer.topk_group, + layer.num_expert_group, + layer.global_num_experts, + layer.expert_map, + layer.custom_routing_function, + layer.scoring_func, + layer.routed_scaling_factor, + layer.e_score_correction_bias, + layer.apply_router_weight_on_input, + layer.activation, + ) + else: + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, + )