From ed259dc077e435d9bdb6165c2680a55bf9ab2fb7 Mon Sep 17 00:00:00 2001
From: Bortlesboat
Date: Sun, 19 Apr 2026 12:35:07 -0400
Subject: [PATCH 1/3] [ROCm][Bugfix] Fall back when Quark MoE AITER dispatch is unsupported

Co-authored-by: OpenAI Codex
Signed-off-by: Bortlesboat
---
 .../layers/test_quark_ocp_mx_moe.py          | 117 ++++++++++++++++++
 .../layers/quantization/quark/quark_moe.py   |  42 ++++++-
 2 files changed, 153 insertions(+), 6 deletions(-)
 create mode 100644 tests/model_executor/layers/test_quark_ocp_mx_moe.py

diff --git a/tests/model_executor/layers/test_quark_ocp_mx_moe.py b/tests/model_executor/layers/test_quark_ocp_mx_moe.py
new file mode 100644
index 000000000000..05d81b903707
--- /dev/null
+++ b/tests/model_executor/layers/test_quark_ocp_mx_moe.py
@@ -0,0 +1,117 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.quantization.quark import quark_moe
+
+
+def _make_method() -> quark_moe.QuarkOCP_MX_MoEMethod:
+    method = object.__new__(quark_moe.QuarkOCP_MX_MoEMethod)
+    method.moe_kernel = None
+    method.emulate = False
+    method.use_rocm_aiter_moe = True
+    method.moe_quant_config = object()
+    method.moe = SimpleNamespace(disable_inplace=False)
+    method.ocp_mx_scheme = "w_mxfp4_a_fp8"
+    return method
+
+
+def _make_layer() -> SimpleNamespace:
+    return SimpleNamespace(
+        w13_weight=torch.randn(2, 4, 4),
+        w2_weight=torch.randn(2, 4, 4),
+        activation=quark_moe.MoEActivation.SILU,
+        global_num_experts=2,
+        apply_router_weight_on_input=False,
+        expert_map=None,
+        moe_config=SimpleNamespace(),
+    )
+
+
+def test_quark_ocp_mx_moe_falls_back_for_unsupported_aiter_dispatch(
+    monkeypatch: pytest.MonkeyPatch,
+):
+    method = _make_method()
+    layer = _make_layer()
+    x = torch.randn(3, 4)
+    topk_weights = torch.randn(3, 2)
+    topk_ids = torch.tensor([[0, 1], [1, 0], [0, 1]], dtype=torch.int32)
+    expected = torch.randn(3, 4)
+
+    aiter_mock = MagicMock(
+        side_effect=RuntimeError("Unsupported kernel config for moe heuristic dispatch")
+    )
+    fused_mock = MagicMock(return_value=expected)
+    warning_mock = MagicMock()
+
+    monkeypatch.setattr(
+        "vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts",  # noqa: E501
+        aiter_mock,
+    )
+    monkeypatch.setattr(
+        "vllm.model_executor.layers.fused_moe.fused_experts",
+        fused_mock,
+    )
+    monkeypatch.setattr(quark_moe.logger, "warning_once", warning_mock)
+
+    result = method.apply(layer, x, topk_weights, topk_ids, shared_experts_input=None)
+
+    assert result is expected
+    assert method.emulate is True
+    assert method.use_rocm_aiter_moe is False
+    aiter_mock.assert_called_once()
+    fused_mock.assert_called_once()
+    assert fused_mock.call_args.args == (
+        x,
+        layer.w13_weight,
+        layer.w2_weight,
+    )
+    assert fused_mock.call_args.kwargs == {
+        "topk_weights": topk_weights,
+        "topk_ids": topk_ids,
+        "inplace": True,
+        "activation": layer.activation,
+        "global_num_experts": layer.global_num_experts,
+        "apply_router_weight_on_input": layer.apply_router_weight_on_input,
+        "expert_map": layer.expert_map,
+        "quant_config": method.moe_quant_config,
+    }
+    warning_mock.assert_called_once()
+    assert "Unsupported kernel config for moe heuristic dispatch" in str(
+        warning_mock.call_args
+    )
+
+
+def test_quark_ocp_mx_moe_preserves_unrelated_aiter_runtime_errors(
+    monkeypatch: pytest.MonkeyPatch,
+):
+    method = _make_method()
+    layer = _make_layer()
+    x = torch.randn(3, 4)
+    topk_weights = torch.randn(3, 2)
+    topk_ids = torch.tensor([[0, 1], [1, 0], [0, 1]], dtype=torch.int32)
+
+    aiter_mock = MagicMock(side_effect=RuntimeError("different aiter failure"))
+    fused_mock = MagicMock()
+
+    monkeypatch.setattr(
+        "vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts",  # noqa: E501
+        aiter_mock,
+    )
+    monkeypatch.setattr(
+        "vllm.model_executor.layers.fused_moe.fused_experts",
+        fused_mock,
+    )
+
+    with pytest.raises(RuntimeError, match="different aiter failure"):
+        method.apply(layer, x, topk_weights, topk_ids, shared_experts_input=None)
+
+    assert method.emulate is False
+    assert method.use_rocm_aiter_moe is True
+    aiter_mock.assert_called_once()
+    fused_mock.assert_not_called()
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index d92acb85c265..03751ef7a372 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -1442,20 +1442,50 @@ def apply(
 
         # AITER path
         # TODO: Refactor this to use modular MOE kernel as well.
-        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            rocm_aiter_fused_experts,
-        )
+        if not self.emulate:
+            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+                rocm_aiter_fused_experts,
+            )
 
-        return rocm_aiter_fused_experts(
+            try:
+                return rocm_aiter_fused_experts(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    activation=layer.activation,
+                    quant_config=self.moe_quant_config,
+                    moe_config=layer.moe_config,
+                    expert_map=layer.expert_map,
+                )
+            except RuntimeError as exc:
+                if "Unsupported kernel config for moe heuristic dispatch" not in str(
+                    exc
+                ):
+                    raise
+                logger.warning_once(
+                    "ROCm AITER fused MoE raised "
+                    f"'{exc}' for {self.ocp_mx_scheme}; "
+                    "falling back to emulated fused_experts."
+                )
+                self.use_rocm_aiter_moe = False
+                self.emulate = True
+
+        from vllm.model_executor.layers.fused_moe import fused_experts
+
+        return fused_experts(
             x,
             layer.w13_weight,
             layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
+            inplace=not self.moe.disable_inplace,
             activation=layer.activation,
-            quant_config=self.moe_quant_config,
-            moe_config=layer.moe_config,
+            global_num_experts=layer.global_num_experts,
+            apply_router_weight_on_input=layer.apply_router_weight_on_input,
             expert_map=layer.expert_map,
+            quant_config=self.moe_quant_config,
         )
 
     def apply_monolithic(

From 8de7176c19ffd9f4f889b8d001bfd911b6336d1f Mon Sep 17 00:00:00 2001
From: Bortlesboat
Date: Mon, 20 Apr 2026 00:39:18 -0400
Subject: [PATCH 2/3] chore: retrigger CI

Signed-off-by: Bortlesboat

From 8b4894573865388b10aa92c2f690b5d915392f7e Mon Sep 17 00:00:00 2001
From: Bortlesboat
Date: Wed, 29 Apr 2026 20:32:03 -0400
Subject: [PATCH 3/3] Address Quark OCP MX AITER review feedback

Signed-off-by: Bortlesboat
---
 .../layers/test_quark_ocp_mx_moe.py          | 45 ++++++-------
 .../layers/quantization/quark/quark_moe.py   | 41 +++++----------
 2 files changed, 19 insertions(+), 67 deletions(-)

diff --git a/tests/model_executor/layers/test_quark_ocp_mx_moe.py b/tests/model_executor/layers/test_quark_ocp_mx_moe.py
index 05d81b903707..9c5a23e51e9f 100644
--- a/tests/model_executor/layers/test_quark_ocp_mx_moe.py
+++ b/tests/model_executor/layers/test_quark_ocp_mx_moe.py
@@ -14,26 +14,23 @@ def _make_method() -> quark_moe.QuarkOCP_MX_MoEMethod:
     method = object.__new__(quark_moe.QuarkOCP_MX_MoEMethod)
     method.moe_kernel = None
     method.emulate = False
-    method.use_rocm_aiter_moe = True
     method.moe_quant_config = object()
-    method.moe = SimpleNamespace(disable_inplace=False)
-    method.ocp_mx_scheme = "w_mxfp4_a_fp8"
     return method
 
 
-def _make_layer() -> SimpleNamespace:
+def _make_layer(apply_router_weight_on_input: bool = True) -> SimpleNamespace:
     return SimpleNamespace(
         w13_weight=torch.randn(2, 4, 4),
        w2_weight=torch.randn(2, 4, 4),
         activation=quark_moe.MoEActivation.SILU,
         global_num_experts=2,
-        apply_router_weight_on_input=False,
+        apply_router_weight_on_input=apply_router_weight_on_input,
         expert_map=None,
         moe_config=SimpleNamespace(),
     )
 
 
-def test_quark_ocp_mx_moe_falls_back_for_unsupported_aiter_dispatch(
+def test_quark_ocp_mx_moe_aiter_apply_forwards_router_weight_flag(
     monkeypatch: pytest.MonkeyPatch,
 ):
     method = _make_method()
@@ -43,51 +40,34 @@ def test_quark_ocp_mx_moe_falls_back_for_unsupported_aiter_dispatch(
     topk_ids = torch.tensor([[0, 1], [1, 0], [0, 1]], dtype=torch.int32)
     expected = torch.randn(3, 4)
 
-    aiter_mock = MagicMock(
-        side_effect=RuntimeError("Unsupported kernel config for moe heuristic dispatch")
-    )
-    fused_mock = MagicMock(return_value=expected)
-    warning_mock = MagicMock()
+    aiter_mock = MagicMock(return_value=expected)
 
     monkeypatch.setattr(
         "vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts",  # noqa: E501
         aiter_mock,
     )
-    monkeypatch.setattr(
-        "vllm.model_executor.layers.fused_moe.fused_experts",
-        fused_mock,
-    )
-    monkeypatch.setattr(quark_moe.logger, "warning_once", warning_mock)
 
     result = method.apply(layer, x, topk_weights, topk_ids, shared_experts_input=None)
 
     assert result is expected
-    assert method.emulate is True
-    assert method.use_rocm_aiter_moe is False
     aiter_mock.assert_called_once()
-    fused_mock.assert_called_once()
-    assert fused_mock.call_args.args == (
+    assert aiter_mock.call_args.args == (
         x,
         layer.w13_weight,
         layer.w2_weight,
     )
-    assert fused_mock.call_args.kwargs == {
+    assert aiter_mock.call_args.kwargs == {
         "topk_weights": topk_weights,
         "topk_ids": topk_ids,
-        "inplace": True,
         "activation": layer.activation,
-        "global_num_experts": layer.global_num_experts,
         "apply_router_weight_on_input": layer.apply_router_weight_on_input,
-        "expert_map": layer.expert_map,
         "quant_config": method.moe_quant_config,
+        "moe_config": layer.moe_config,
+        "expert_map": layer.expert_map,
     }
-    warning_mock.assert_called_once()
-    assert "Unsupported kernel config for moe heuristic dispatch" in str(
-        warning_mock.call_args
-    )
 
 
-def test_quark_ocp_mx_moe_preserves_unrelated_aiter_runtime_errors(
+def test_quark_ocp_mx_moe_does_not_runtime_fallback_after_aiter_error(
     monkeypatch: pytest.MonkeyPatch,
 ):
     method = _make_method()
@@ -96,7 +76,9 @@ def test_quark_ocp_mx_moe_preserves_unrelated_aiter_runtime_errors(
     topk_weights = torch.randn(3, 2)
     topk_ids = torch.tensor([[0, 1], [1, 0], [0, 1]], dtype=torch.int32)
 
-    aiter_mock = MagicMock(side_effect=RuntimeError("different aiter failure"))
+    aiter_mock = MagicMock(
+        side_effect=RuntimeError("Unsupported kernel config for moe heuristic dispatch")
+    )
     fused_mock = MagicMock()
 
     monkeypatch.setattr(
@@ -108,10 +90,9 @@ def test_quark_ocp_mx_moe_preserves_unrelated_aiter_runtime_errors(
         fused_mock,
     )
 
-    with pytest.raises(RuntimeError, match="different aiter failure"):
+    with pytest.raises(RuntimeError, match="Unsupported kernel config"):
         method.apply(layer, x, topk_weights, topk_ids, shared_experts_input=None)
 
     assert method.emulate is False
-    assert method.use_rocm_aiter_moe is True
     aiter_mock.assert_called_once()
     fused_mock.assert_not_called()
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 4d011273fbfe..e5fce5bb0f36 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -1438,50 +1438,21 @@ def apply(
 
         # AITER path
         # TODO: Refactor this to use modular MOE kernel as well.
-        if not self.emulate:
-            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-                rocm_aiter_fused_experts,
-            )
-
-            try:
-                return rocm_aiter_fused_experts(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights=topk_weights,
-                    topk_ids=topk_ids,
-                    activation=layer.activation,
-                    quant_config=self.moe_quant_config,
-                    moe_config=layer.moe_config,
-                    expert_map=layer.expert_map,
-                )
-            except RuntimeError as exc:
-                if "Unsupported kernel config for moe heuristic dispatch" not in str(
-                    exc
-                ):
-                    raise
-                logger.warning_once(
-                    "ROCm AITER fused MoE raised "
-                    f"'{exc}' for {self.ocp_mx_scheme}; "
-                    "falling back to emulated fused_experts."
-                )
-                self.use_rocm_aiter_moe = False
-                self.emulate = True
-
-        from vllm.model_executor.layers.fused_moe import fused_experts
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            rocm_aiter_fused_experts,
+        )
 
-        return fused_experts(
+        return rocm_aiter_fused_experts(
             x,
             layer.w13_weight,
             layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            inplace=not self.moe.disable_inplace,
             activation=layer.activation,
-            global_num_experts=layer.global_num_experts,
             apply_router_weight_on_input=layer.apply_router_weight_on_input,
-            expert_map=layer.expert_map,
             quant_config=self.moe_quant_config,
+            moe_config=layer.moe_config,
+            expert_map=layer.expert_map,
         )
 
     def apply_monolithic(