Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions vllm/model_executor/layers/quantization/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
CK_MXFP4_MOE_DIM_ALIGNMENT,
_can_support_mxfp4,
_swizzle_mxfp4,
get_padding_alignment,
Expand Down Expand Up @@ -259,6 +260,31 @@ def __init__(self, moe: FusedMoEConfig):
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
)

# CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension
# alignment requirements. Fall back to Triton when not met.
if (
self.mxfp4_backend == Mxfp4Backend.CK
and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0
):
if has_triton_kernels():
logger.warning_once(
"CK MXFP4 MoE GEMM does not support "
"intermediate_size_per_partition=%d (not a multiple of "
"%d). Falling back to Triton backend.",
moe.intermediate_size_per_partition,
CK_MXFP4_MOE_DIM_ALIGNMENT,
)
self.mxfp4_backend = Mxfp4Backend.TRITON
else:
raise ValueError(
f"CK MXFP4 MoE GEMM does not support "
f"intermediate_size_per_partition="
f"{moe.intermediate_size_per_partition} (not a multiple "
f"of {CK_MXFP4_MOE_DIM_ALIGNMENT}) and no Triton "
f"fallback is available. Use a compatible "
f"tensor_parallel_size."
)

assert self.mxfp4_backend != Mxfp4Backend.NONE, (
f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found"
"no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton)."
Expand Down
31 changes: 30 additions & 1 deletion vllm/model_executor/layers/quantization/quark/quark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
CK_MXFP4_MOE_DIM_ALIGNMENT,
_swizzle_mxfp4,
)
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
OCP_MX_Scheme,
Expand Down Expand Up @@ -732,6 +735,32 @@ def __init__(
or not self.ocp_mx_scheme.startswith("w_mxfp4")
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)

# CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChuanLi1101 Can we enable CK MXFP4 MoE with padding? I have a PR #34285 related to padding; I'll see if it can resolve these types of issues.

# alignment requirements. When violated (e.g. MiniMax-M2.1 with
# TP=4 yields intermediate_size_per_partition=384), AITER raises:
# "device_gemm ... does not support this GEMM problem".
# Fall back to emulation in that case.
if (
not self.emulate
and self.use_rocm_aiter_moe
and self.ocp_mx_scheme is not None
and self.ocp_mx_scheme.startswith("w_mxfp4")
and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0
):
logger.warning_once(
"AITER CK MXFP4 MoE GEMM does not support "
"intermediate_size_per_partition=%d (not a multiple of %d). "
"This typically happens when intermediate_size / "
"tensor_parallel_size produces an incompatible dimension. "
"Falling back to emulation mode. To avoid this overhead, "
"use a compatible tensor_parallel_size or set "
"VLLM_ROCM_USE_AITER_MOE=0.",
moe.intermediate_size_per_partition,
CK_MXFP4_MOE_DIM_ALIGNMENT,
)
self.use_rocm_aiter_moe = False
self.emulate = True

if self.emulate:
logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@

logger = init_logger(__name__)

# Dimension-alignment requirement of AITER's pre-compiled CK MXFP4 MoE
# GEMM kernels: the post-TP-split intermediate_size must be divisible by
# this value. The constraint comes from FP4 packing (two 4-bit values per
# byte) together with the CK kernels' fixed tile sizes; a violating shape
# makes AITER fail with "device_gemm ... does not support this GEMM
# problem", so callers check divisibility up front and fall back.
CK_MXFP4_MOE_DIM_ALIGNMENT = 256


def _swizzle_mxfp4(quant_tensor, scale, num_warps):
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
Expand Down