From 3905f42a12e2ddd93d26a369acf16cb5c14ceea4 Mon Sep 17 00:00:00 2001 From: Li Date: Tue, 3 Mar 2026 10:15:05 -0800 Subject: [PATCH] [ROCm][Bugfix] Fall back from CK MXFP4 MoE when GEMM dimensions are unsupported CK's pre-compiled MXFP4 MoE GEMM kernel instances require the intermediate_size (after TP split) to be a multiple of 256. When this constraint is not met (e.g. MiniMax-M2.1 with TP=4 yields intermediate_size_per_partition=384), AITER raises: "device_gemm with the specified compilation parameters does not support this GEMM problem". Add dimension validation in both the Quark OCP_MX MoE method and the Mxfp4MoEMethod to detect incompatible dimensions at initialization time and fall back gracefully: - Quark path: falls back to emulation mode (weight dequantization + bf16) - Mxfp4 path: falls back from CK to Triton backend Fixes #35637 Signed-off-by: Li Made-with: Cursor --- .../layers/quantization/mxfp4.py | 26 ++++++++++++++++ .../layers/quantization/quark/quark_moe.py | 31 ++++++++++++++++++- .../layers/quantization/utils/mxfp4_utils.py | 7 +++++ 3 files changed, 63 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 8856eb1e2e49..12dedcddc710 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -48,6 +48,7 @@ prepare_moe_fp4_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + CK_MXFP4_MOE_DIM_ALIGNMENT, _can_support_mxfp4, _swizzle_mxfp4, get_padding_alignment, @@ -259,6 +260,31 @@ def __init__(self, moe: FusedMoEConfig): get_current_vllm_config().compilation_config.max_cudagraph_capture_size ) + # CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension + # alignment requirements. Fall back to Triton when not met. + if ( + self.mxfp4_backend == Mxfp4Backend.CK + and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0 + ): + if has_triton_kernels(): + logger.warning_once( + "CK MXFP4 MoE GEMM does not support " + "intermediate_size_per_partition=%d (not a multiple of " + "%d). Falling back to Triton backend.", + moe.intermediate_size_per_partition, + CK_MXFP4_MOE_DIM_ALIGNMENT, + ) + self.mxfp4_backend = Mxfp4Backend.TRITON + else: + raise ValueError( + f"CK MXFP4 MoE GEMM does not support " + f"intermediate_size_per_partition=" + f"{moe.intermediate_size_per_partition} (not a multiple " + f"of {CK_MXFP4_MOE_DIM_ALIGNMENT}) and no Triton " + f"fallback is available. Use a compatible " + f"tensor_parallel_size." + ) + assert self.mxfp4_backend != Mxfp4Backend.NONE, ( f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found" "no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton)." diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index b2abbce1aa1e..b7cb84e8ff3d 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -32,7 +32,10 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_fp8_moe_layer_for_marlin, ) -from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4 +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + CK_MXFP4_MOE_DIM_ALIGNMENT, + _swizzle_mxfp4, +) from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( OCP_MX_BLOCK_SIZE, OCP_MX_Scheme, @@ -732,6 +735,32 @@ def __init__( or not self.ocp_mx_scheme.startswith("w_mxfp4") ) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe) + # CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension + # alignment requirements. When violated (e.g. MiniMax-M2.1 with + # TP=4 yields intermediate_size_per_partition=384), AITER raises: + # "device_gemm ... does not support this GEMM problem". + # Fall back to emulation in that case. + if ( + not self.emulate + and self.use_rocm_aiter_moe + and self.ocp_mx_scheme is not None + and self.ocp_mx_scheme.startswith("w_mxfp4") + and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0 + ): + logger.warning_once( + "AITER CK MXFP4 MoE GEMM does not support " + "intermediate_size_per_partition=%d (not a multiple of %d). " + "This typically happens when intermediate_size / " + "tensor_parallel_size produces an incompatible dimension. " + "Falling back to emulation mode. To avoid this overhead, " + "use a compatible tensor_parallel_size or set " + "VLLM_ROCM_USE_AITER_MOE=0.", + moe.intermediate_size_per_partition, + CK_MXFP4_MOE_DIM_ALIGNMENT, + ) + self.use_rocm_aiter_moe = False + self.emulate = True + if self.emulate: logger.warning_once( f"The current mode (supports_mx={current_platform.supports_mx()}, " diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 9dbfc6ecad7b..23d7cf55474a 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -14,6 +14,13 @@ logger = init_logger(__name__) +# CK's pre-compiled MXFP4 MoE GEMM kernel instances require the +# intermediate_size (after TP split) to be a multiple of this value. +# This arises from FP4 packing (2 values per byte) combined with CK +# tile size constraints. When violated, AITER raises: +# "device_gemm ... does not support this GEMM problem". +CK_MXFP4_MOE_DIM_ALIGNMENT = 256 + def _swizzle_mxfp4(quant_tensor, scale, num_warps): """weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""