diff --git a/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml b/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml
new file mode 100644
index 000000000000..3b08e103af08
--- /dev/null
+++ b/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml
@@ -0,0 +1,11 @@
+model_name: "amd/Qwen3.5-35B-A3B-MXFP4"
+accuracy_threshold: 0.89
+tolerance: 0.03
+num_questions: 1319
+num_fewshot: 5
+server_args: >-
+  --max-model-len 4096
+  --tensor-parallel-size 2
+  --gpu-memory-utilization 0.35
+env:
+  VLLM_ROCM_USE_AITER: "1"
diff --git a/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-TP2.yaml b/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml
similarity index 65%
rename from tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-TP2.yaml
rename to tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml
index 67eda1141559..ad5ca701258e 100644
--- a/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-TP2.yaml
+++ b/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml
@@ -1,8 +1,10 @@
 model_name: "amd/Qwen3.5-35B-A3B-MXFP4"
-accuracy_threshold: 0.82
+accuracy_threshold: 0.89
 tolerance: 0.03
 num_questions: 1319
 num_fewshot: 5
 server_args: >-
   --max-model-len 4096
   --tensor-parallel-size 2
+  --moe-backend emulation
+  --gpu-memory-utilization 0.35
diff --git a/tests/evals/gsm8k/configs/models-mi3xx.txt b/tests/evals/gsm8k/configs/models-mi3xx.txt
index dfa4bc8eb53f..e8759d7d02b1 100644
--- a/tests/evals/gsm8k/configs/models-mi3xx.txt
+++ b/tests/evals/gsm8k/configs/models-mi3xx.txt
@@ -3,4 +3,5 @@ DeepSeek-R1-DP_MI325.yaml
 DeepSeek-V3.2-TP_MI325.yaml
 DeepSeek-V3.2-DP_MI325.yaml
 Qwen3-30B-A3B-NVFP4.yaml
-Qwen3.5-35B-A3B-MXFP4-TP2.yaml
\ No newline at end of file
+Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml
+Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml
\ No newline at end of file
diff --git a/tests/evals/gsm8k/configs/models-qwen35-mi355.txt b/tests/evals/gsm8k/configs/models-qwen35-mi355.txt
index db8e88e27356..49925c827e3c 100644
--- a/tests/evals/gsm8k/configs/models-qwen35-mi355.txt
+++ b/tests/evals/gsm8k/configs/models-qwen35-mi355.txt
@@ -1,2 +1,3 @@
 Qwen3.5-35B-A3B-DEP2.yaml
-Qwen3.5-35B-A3B-MXFP4-TP2.yaml
+Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml
+Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml
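The two configs above exercise the same MXFP4 checkpoint through both code paths: the AITER CK kernel (enabled via the VLLM_ROCM_USE_AITER env entry) and software emulation (via --moe-backend emulation). A minimal sketch of how such a config can drive a server launch, assuming PyYAML and the vllm serve CLI; the actual gsm8k harness is not part of this diff:

    # Sketch only: field names come from the YAML above, the runner is hypothetical.
    import os
    import shlex
    import subprocess

    import yaml

    with open("tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml") as f:
        cfg = yaml.safe_load(f)

    # env entries from the config are layered on top of the current environment
    env = {**os.environ, **{k: str(v) for k, v in cfg.get("env", {}).items()}}
    cmd = ["vllm", "serve", cfg["model_name"], *shlex.split(cfg["server_args"])]
    subprocess.Popen(cmd, env=env)  # the harness then scores GSM8K over the API
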
+""" -def test_mi355_moe(): - print("TODO: add tests for Mi355 MoE quantization") +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + RoutingMethodType, +) +from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( + Mxfp4MoeBackend, + select_mxfp4_moe_backend, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kMxfp4Dynamic, +) +from vllm.platforms import current_platform + +ROCM_AVAILABLE = current_platform.is_rocm() +ROCM_GFX950 = False +ROCM_AITER_AVAILABLE = False + +if ROCM_AVAILABLE: + from vllm._aiter_ops import rocm_aiter_ops + from vllm.platforms.rocm import on_gfx950 + + ROCM_GFX950 = on_gfx950() + ROCM_AITER_AVAILABLE = rocm_aiter_ops.is_fused_moe_enabled() + + +def _make_w4a4_moe_config(moe_backend: str = "auto") -> FusedMoEConfig: + from vllm.model_executor.layers.fused_moe.activation import MoEActivation + + return FusedMoEConfig( + num_experts=8, + experts_per_token=2, + hidden_dim=256, + intermediate_size_per_partition=256, + num_local_experts=8, + num_logical_experts=8, + moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), + activation=MoEActivation.SILU, + in_dtype=torch.bfloat16, + device="cuda", + routing_method=RoutingMethodType.Renormalize, + moe_backend=moe_backend, + ) + + +@pytest.mark.skipif(not ROCM_GFX950, reason="Requires GFX950 (mi355x)") +@pytest.mark.skipif(not ROCM_AITER_AVAILABLE, reason="Requires AITER enabled") +def test_w4a4_dispatches_to_aiter(): + """With AITER enabled + GFX950, W4A4 selects AITER_MXFP4_MXFP4.""" + config = _make_w4a4_moe_config() + backend, experts_cls = select_mxfp4_moe_backend( + config, activation_key=kMxfp4Dynamic + ) + assert backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4 + assert experts_cls is not None + + +@pytest.mark.skipif(not ROCM_GFX950, reason="Requires GFX950 (mi355x)") +@pytest.mark.skipif( + ROCM_AITER_AVAILABLE, + reason="Test requires AITER disabled (unset VLLM_ROCM_USE_AITER)", +) +def test_w4a4_raises_without_aiter_and_no_moe_backend(): + """Without AITER and no --moe-backend, raises NotImplementedError + with hint to use --moe-backend emulation.""" + config = _make_w4a4_moe_config() + with pytest.raises(NotImplementedError, match="--moe-backend emulation"): + select_mxfp4_moe_backend(config, activation_key=kMxfp4Dynamic) + + +@pytest.mark.skipif(not ROCM_GFX950, reason="Requires GFX950 (mi355x)") +def test_w4a4_dispatches_to_emulation_with_moe_backend(): + """With --moe-backend emulation, W4A4 selects EMULATION.""" + config = _make_w4a4_moe_config(moe_backend="emulation") + backend, experts_cls = select_mxfp4_moe_backend( + config, activation_key=kMxfp4Dynamic + ) + assert backend == Mxfp4MoeBackend.EMULATION + assert experts_cls is not None diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index 7c596d52a653..d9188a040d01 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -29,6 +29,7 @@ QuantKey, kFp8Dynamic128Sym, kFp8StaticTensorSym, + kMxfp4Dynamic, kMxfp4Static, kMxfp8Dynamic, ) @@ -68,6 +69,7 @@ class Mxfp4MoeBackend(Enum): # Keep the legacy name as an alias while the ROCm split backend rename settles. 
AITER = "AITER_MXFP4_BF16" AITER_MXFP4_FP8 = "AITER_MXFP4_FP8" # W4A8: triton kernel + AITER_MXFP4_MXFP4 = "AITER_MXFP4_MXFP4" # W4A4: CK kernel # Triton TRITON = "TRITON" TRITON_UNFUSED = "TRITON_UNFUSED" @@ -83,6 +85,7 @@ class Mxfp4MoeBackend(Enum): AITER_BACKENDS = ( Mxfp4MoeBackend.AITER_MXFP4_BF16, Mxfp4MoeBackend.AITER_MXFP4_FP8, + Mxfp4MoeBackend.AITER_MXFP4_MXFP4, ) @@ -187,6 +190,13 @@ def backend_to_kernel_cls( return [AiterW4A8ExpertsMonolithic] + elif backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4: + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + AiterMxfp4Experts, + ) + + return [AiterMxfp4Experts] + elif backend == Mxfp4MoeBackend.XPU: from vllm.model_executor.layers.fused_moe.experts.xpu_moe import XPUExpertsMXFp4 @@ -217,6 +227,7 @@ def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend: "marlin": Mxfp4MoeBackend.MARLIN, "aiter": Mxfp4MoeBackend.AITER_MXFP4_BF16, "aiter_mxfp4_fp8": Mxfp4MoeBackend.AITER_MXFP4_FP8, + "aiter_mxfp4_mxfp4": Mxfp4MoeBackend.AITER_MXFP4_MXFP4, "xpu": Mxfp4MoeBackend.XPU, "emulation": Mxfp4MoeBackend.EMULATION, } @@ -237,6 +248,7 @@ def _get_priority_backends_for_gpt_oss() -> list[Mxfp4MoeBackend]: Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, Mxfp4MoeBackend.AITER_MXFP4_BF16, Mxfp4MoeBackend.AITER_MXFP4_FP8, + Mxfp4MoeBackend.AITER_MXFP4_MXFP4, Mxfp4MoeBackend.TRITON, Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, # TRITON_UNFUSED has bug with MTP support @@ -245,7 +257,6 @@ def _get_priority_backends_for_gpt_oss() -> list[Mxfp4MoeBackend]: Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN, Mxfp4MoeBackend.XPU, - Mxfp4MoeBackend.EMULATION, ] return _AVAILABLE_BACKENDS @@ -281,6 +292,8 @@ def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None: return kMxfp8Dynamic if backend == Mxfp4MoeBackend.AITER_MXFP4_FP8: return kFp8StaticTensorSym + if backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4: + return kMxfp4Dynamic return None # BF16 activation @@ -480,7 +493,12 @@ def _return_or_raise( if current_platform.is_cuda() or current_platform.is_rocm(): raise NotImplementedError( - "No MXFP4 MoE backend supports the deployment configuration." + "No MXFP4 MoE backend supports the deployment configuration. " + f"weight_key=kMxfp4Static, activation_key={activation_key}. " + "Native backends require specific hardware. " + "Set `VLLM_LOGGING_LEVEL=DEBUG` to see detailed unsupported reasons. " + "To use the emulation backend for research/debugging, pass " + "--moe-backend emulation." 
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index d9d888296b74..a045dfe05a54 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -25,6 +25,7 @@
     kFp8Static128BlockSym,
     kFp8StaticChannelSym,
     kFp8StaticTensorSym,
+    kMxfp4Dynamic,
     kMxfp4Static,
 )
 
@@ -327,6 +328,21 @@ def expects_unquantized_inputs(self) -> bool:
     def activation_format() -> mk.FusedMoEActivationFormat:
         return mk.FusedMoEActivationFormat.Standard
 
+    @classmethod
+    def is_supported_config(
+        cls, moe_config, weight_key, activation_key, activation_format
+    ):
+        is_supported, reason = super().is_supported_config(
+            moe_config, weight_key, activation_key, activation_format
+        )
+        if not is_supported and not rocm_aiter_ops.is_fused_moe_enabled():
+            reason = (
+                f"{reason}. AITER MoE is not enabled; "
+                "set VLLM_ROCM_USE_AITER=1 and VLLM_ROCM_USE_AITER_MOE=1 "
+                "to enable it"
+            )
+        return is_supported, reason
+
     @staticmethod
     def _supports_current_device() -> bool:
         return rocm_aiter_ops.is_fused_moe_enabled()
@@ -439,3 +455,17 @@ def apply(
             output_dtype=output.dtype,
         )
         output.copy_(result)
+
+
+class AiterMxfp4Experts(AiterExperts):
+    """MXFP4 W4A4 variant: MXFP4 weights + dynamic MXFP4 activations."""
+
+    @staticmethod
+    def _supports_quant_scheme(
+        weight_key: QuantKey | None,
+        activation_key: QuantKey | None,
+    ) -> bool:
+        return (weight_key, activation_key) == (
+            kMxfp4Static,
+            kMxfp4Dynamic,
+        )
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index a14bfbc9c19b..6fdadd8ef733 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -47,6 +47,7 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
     kFp8StaticTensorSym,
+    kMxfp4Dynamic,
 )
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
@@ -997,6 +998,11 @@ def __init__(
             self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(
                 moe, activation_key=kFp8StaticTensorSym
             )
+        elif self.ocp_mx_scheme == "w_mxfp4_a_mxfp4":
+            # W4A4: MXFP4 weights + MXFP4 activations
+            self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(
+                moe, activation_key=kMxfp4Dynamic
+            )
 
         # Validation for unsupported schemes
         if any(
@@ -1016,45 +1022,19 @@
                 "Please open an issue."
             )
 
-        self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled()
-
         self.model_type = getattr(
             get_current_vllm_config().model_config.hf_config, "model_type", None
         )
 
-        # TODO: Remove once all OCP MX schemes use the kernel abstraction
-        _AITER_NATIVE_OCP_MX_SCHEMES = ("w_mxfp4", "w_mxfp4_a_mxfp4", "w_mxfp4_a_fp8")
-        self.emulate = (
-            not current_platform.supports_mx()
-            or self.ocp_mx_scheme not in _AITER_NATIVE_OCP_MX_SCHEMES
-        ) and (
-            self.mxfp4_backend is Mxfp4MoeBackend.NONE or not self.use_rocm_aiter_moe
-        )
-
-        if self.emulate:
-            # We use the same code path between MXFP4/MXFP6 emulation.
+        # If no native backend is available, fall back to emulation.
+        if self.mxfp4_backend is Mxfp4MoeBackend.NONE:
             self.mxfp4_backend = Mxfp4MoeBackend.EMULATION
 
-        # TODO: Remove `self.mxfp4_backend != Mxfp4MoeBackend.NONE` and make it so that
-        # all MXFP4 backends use the kernel abstraction.
-        if self.mxfp4_backend != Mxfp4MoeBackend.NONE:
-            self.experts_cls = backend_to_kernel_cls(self.mxfp4_backend)[0]
+        self.experts_cls = backend_to_kernel_cls(self.mxfp4_backend)[0]
 
-        # Log backend selection
-        if self.mxfp4_backend != Mxfp4MoeBackend.NONE:
-            logger.info_once(
-                f"Using {self.mxfp4_backend.value} backend for {self.ocp_mx_scheme}"
-            )
-        elif self.emulate:
-            logger.warning_once(
-                f"The current mode (supports_mx={current_platform.supports_mx()}, "
-                f"use_rocm_aiter_moe={self.use_rocm_aiter_moe}, "
-                f"ocp_mx_scheme={self.ocp_mx_scheme}) "
-                "does not support native MXFP4/MXFP6 "
-                "computation. Simulated weight dequantization and activation "
-                "QDQ (quantize and dequantize) will be used, with the linear "
-                "layers computed in high precision."
-            )
+        logger.info_once(
+            f"Using {self.mxfp4_backend.value} backend for {self.ocp_mx_scheme}"
+        )
 
     def maybe_roundup_sizes(
         self,
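AiterMxfp4Experts inherits the AITER plumbing and narrows the claimed quantization scheme to exactly the W4A4 pairing; anything else falls through to other backends. An illustrative check of that gate, mirroring _supports_quant_scheme above:

    # Sketch: imports match the diff; runs anywhere vLLM is installed.
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        AiterMxfp4Experts,
    )
    from vllm.model_executor.layers.quantization.utils.quant_utils import (
        kMxfp4Dynamic,
        kMxfp4Static,
    )

    assert AiterMxfp4Experts._supports_quant_scheme(kMxfp4Static, kMxfp4Dynamic)
    assert not AiterMxfp4Experts._supports_quant_scheme(kMxfp4Static, None)
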
@@ -1199,96 +1179,7 @@ def create_weights(
         layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer):
-        # For MXFP4 schemes with native backend, use oracle
-        if self.mxfp4_backend != Mxfp4MoeBackend.NONE:
-            self._setup_kernel(layer)
-            return
-
-        if self.static_input_scales and self.input_dtype == "fp8":
-            # firstly, process activations if fp8 static input
-            if layer.w13_input_scale is None or layer.w2_input_scale is None:
-                raise ValueError(
-                    "QuantConfig has static quantization, but found "
-                    "activation scales are None."
-                )
-            if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
-                layer.w2_input_scale
-            ):
-                logger.warning_once(
-                    "Found input_scales that are not equal for "
-                    "fp8 MoE layer. Using the maximum across experts "
-                    "for each layer. "
-                )
-                layer.w13_input_scale = torch.nn.Parameter(
-                    layer.w13_input_scale.max(), requires_grad=False
-                )
-                layer.w2_input_scale = torch.nn.Parameter(
-                    layer.w2_input_scale.max(), requires_grad=False
-                )
-
-        if current_platform.is_fp8_fnuz():
-            # Normalize the weights and scales
-            _, _, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
-                torch.empty_like(layer.w13_weight, dtype=torch.float8_e4m3fn),
-                torch.empty_like(
-                    layer.w13_weight_scale, dtype=layer.w13_weight_scale.dtype
-                ),
-                layer.w13_input_scale,
-            )
-            _, _, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
-                torch.empty_like(layer.w2_weight, dtype=torch.float8_e4m3fn),
-                torch.empty_like(
-                    layer.w2_weight_scale, dtype=layer.w13_weight_scale.dtype
-                ),
-                layer.w2_input_scale,
-            )
-            # Reset the parameter
-            if w13_input_scale is not None:
-                layer.w13_input_scale = torch.nn.Parameter(
-                    w13_input_scale, requires_grad=False
-                )
-            if w2_input_scale is not None:
-                layer.w2_input_scale = torch.nn.Parameter(
-                    w2_input_scale, requires_grad=False
-                )
-
-        # TODO(bowenbao): gradually migrate to oracles.
-        # Existing AITER path for w_mxfp4_a_mxfp4 and other schemes
-        from aiter.utility.fp4_utils import e8m0_shuffle
-
-        # Pre-shuffle weight scales
-        s0, s1, _ = layer.w13_weight_scale.shape
-        w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
-        w13_weight_scale = e8m0_shuffle(w13_weight_scale)
-        layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
-
-        s0, s1, _ = layer.w2_weight_scale.shape
-        w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
-        w2_weight_scale = e8m0_shuffle(w2_weight_scale)
-        layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
-
-        if self.fp4_dtype is not None:
-            layer.w13_weight = torch.nn.Parameter(
-                layer.w13_weight.view(self.fp4_dtype),
-                requires_grad=layer.w13_weight.requires_grad,
-            )
-            layer.w2_weight = torch.nn.Parameter(
-                layer.w2_weight.view(self.fp4_dtype),
-                requires_grad=layer.w2_weight.requires_grad,
-            )
-        # Pre-shuffle weight
-        shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
-            layer.w13_weight.data, layer.w2_weight.data
-        )
-
-        layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
-        layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
-        layer.w13_weight.is_shuffled = True
-        layer.w2_weight.is_shuffled = True
-
-        # Build quant config for AITER path
-        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
-        torch.accelerator.empty_cache()
+        self._setup_kernel(layer)
 
     def _setup_kernel(self, layer: FusedMoE):
         """Setup kernel using oracle functions for MXFP4 schemes (W4A16, W4A8)."""
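With the legacy inline path gone, every OCP MX scheme reaches its kernel through _setup_kernel; the W4A4 scale and weight shuffling now happens inside the oracle's weight processing (see the mxfp4.py hunk above), so the hunk below only tags the already-shuffled parameters. A sketch of the resulting invariant, assuming layer is a loaded FusedMoE module (illustrative helper, not a real vLLM API):

    def check_w4a4_shuffle_invariant(layer) -> None:
        # after _setup_kernel, the AITER CK kernel expects pre-shuffled weights
        assert getattr(layer.w13_weight, "is_shuffled", False)
        assert getattr(layer.w2_weight, "is_shuffled", False)
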
@@ -1330,6 +1221,10 @@ def _setup_kernel(self, layer: FusedMoE):
         replace_parameter(layer, "w13_bias", w13_bias)
         replace_parameter(layer, "w2_bias", w2_bias)
 
+        if self.mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4:
+            layer.w13_weight.is_shuffled = True
+            layer.w2_weight.is_shuffled = True
+
         torch.accelerator.empty_cache()
 
         # Build quant config and kernel
@@ -1419,37 +1314,18 @@ def apply(
         topk_ids: torch.Tensor,
         shared_experts_input: torch.Tensor | None,
     ) -> torch.Tensor:
-        # For oracle-based kernels (W4A16, W4A8) or emulation kernel
-        if self.moe_kernel is not None:
-            return self.moe_kernel.apply(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                activation=layer.activation,
-                global_num_experts=layer.global_num_experts,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-                expert_map=layer.expert_map,
-                shared_experts_input=shared_experts_input,
-            )
-
-        # AITER path
-        # TODO: Refactor this to use modular MOE kernel as well.
-        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            rocm_aiter_fused_experts,
-        )
-
-        return rocm_aiter_fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
+        assert self.moe_kernel is not None
+        return self.moe_kernel.apply(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             activation=layer.activation,
-            quant_config=self.moe_quant_config,
-            moe_config=layer.moe_config,
+            global_num_experts=layer.global_num_experts,
+            apply_router_weight_on_input=layer.apply_router_weight_on_input,
             expert_map=layer.expert_map,
+            shared_experts_input=shared_experts_input,
         )
 
     def apply_monolithic(