From fafa0eac36648db5c8ae985091eedb6193e6885b Mon Sep 17 00:00:00 2001
From: "jan.reges"
Date: Sun, 21 Dec 2025 03:56:37 +0100
Subject: [PATCH 1/2] [Quantization] enable MXFP4 Triton backend on SM120
 (Blackwell)

- Add SM120 to triton_kernels_supported condition in both backend
  selection functions (get_mxfp4_backend, get_mxfp4_backend_with_lora)
- Use StridedLayout for SM120 to avoid "Must use persistent kernel"
  error caused by unsupported cluster TMA operations
- Configure SM120-specific constraints: is_persistent=False, num_stages=1

Tested on NVIDIA RTX PRO 6000 Blackwell (compute capability 12.0).
Requires Triton fix: https://github.com/triton-lang/triton/pull/8498
---
 .../layers/quantization/mxfp4.py              | 18 ++++++++++++------
 .../layers/quantization/utils/mxfp4_utils.py  | 13 +++++++++++++
 2 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index dc0fbfa7df35..96e236136a1e 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -90,10 +90,13 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
     triton_kernels_supported = (
         has_triton_kernels()
         and is_torch_equal_or_newer("2.8.0")
-        # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
+        # NOTE: triton_kernels are confirmed to work on SM90, SM100, and SM120
         # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
-        # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
-        and (9, 0) <= current_platform.get_device_capability() < (11, 0)
+        # SM120 support added after Triton fix: https://github.com/triton-lang/triton/pull/8498
+        and (
+            (9, 0) <= current_platform.get_device_capability() < (11, 0)
+            or current_platform.is_device_capability_family(120)
+        )
     )
     if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
         logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
@@ -152,10 +155,13 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
     triton_kernels_supported = (
         has_triton_kernels()
         and is_torch_equal_or_newer("2.8.0")
-        # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
+        # NOTE: triton_kernels are confirmed to work on SM90, SM100, and SM120
         # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
-        # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
-        and (9, 0) <= current_platform.get_device_capability() < (11, 0)
+        # SM120 support added after Triton fix: https://github.com/triton-lang/triton/pull/8498
+        and (
+            (9, 0) <= current_platform.get_device_capability() < (11, 0)
+            or current_platform.is_device_capability_family(120)
+        )
     )
     if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
         logger.info_once("Using Marlin backend")
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
index e9ecf0547033..f3c1c37e5b57 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -38,6 +38,11 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
         )
         value_layout = StridedLayout
         scale_layout = StridedLayout
+    elif current_platform.is_cuda() and current_platform.is_device_capability_family(120):
+        # SM120 (Blackwell consumer) - cannot use persistent kernels due to cluster TMA
+        # Use StridedLayout to avoid "Must use persistent kernel" error in matmul_ogs.py
+        value_layout = StridedLayout
+        scale_layout = StridedLayout
     elif current_platform.is_rocm():
         from vllm.platforms.rocm import on_gfx950
 
@@ -69,6 +74,14 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
             "epilogue_subtile": 1,
         }
         opt_flags.update_opt_flags_constraints(constraints)
+    elif current_platform.is_device_capability_family(120):
+        # SM120 (Blackwell consumer) uses similar constraints to SM100
+        # Note: cluster-related TMA operations are not supported on SM120
+        constraints = {
+            "is_persistent": False,
+            "num_stages": 1,  # SM120 shared memory limit
+        }
+        opt_flags.update_opt_flags_constraints(constraints)
     # transpose the tensor so that the quantization axis is on dim1
     quant_tensor = quant_tensor.transpose(-2, -1)
     scale = scale.transpose(-2, -1)

From 85d68aedf1447e70f482ef2111552401efa0f588 Mon Sep 17 00:00:00 2001
From: "jan.reges"
Date: Sun, 21 Dec 2025 13:14:23 +0100
Subject: [PATCH 2/2] [Quantization] refactor Triton MXFP4 support detection
 into helper function

- Extract duplicated logic for checking Triton MXFP4 support on CUDA
  into new _is_triton_mxfp4_supported_on_cuda() helper function
- Fix potential TypeError when get_device_capability() returns None
- Simplify code in get_mxfp4_backend() and get_mxfp4_backend_with_lora()

Addresses PR review feedback to improve maintainability and avoid code
duplication.

Signed-off-by: jan.reges
---
 .../layers/quantization/mxfp4.py | 43 +++++++++++++++++++++----------------------
 1 file changed, 21 insertions(+), 22 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 96e236136a1e..4e00689a99e6 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -61,6 +61,25 @@
 logger = init_logger(__name__)
 
 
+def _is_triton_mxfp4_supported_on_cuda() -> bool:
+    """Checks if the Triton MXFP4 kernels are supported on CUDA."""
+    capability = current_platform.get_device_capability()
+    if capability is None:
+        return False
+
+    # NOTE: triton_kernels are confirmed to work on SM90, SM100, and SM120
+    # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
+    # SM120 support added after Triton fix: https://github.com/triton-lang/triton/pull/8498
+    is_sm90_or_sm100 = (9, 0) <= capability < (11, 0)
+    is_sm120 = current_platform.is_device_capability_family(120)
+
+    return (
+        has_triton_kernels()
+        and is_torch_equal_or_newer("2.8.0")
+        and (is_sm90_or_sm100 or is_sm120)
+    )
+
+
 # enum for mxfp4 backend
 class Mxfp4Backend(Enum):
     NONE = 0
@@ -87,17 +106,7 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
         return Mxfp4Backend.NONE
 
     # If FlashInfer is not available, try either Marlin or Triton
-    triton_kernels_supported = (
-        has_triton_kernels()
-        and is_torch_equal_or_newer("2.8.0")
-        # NOTE: triton_kernels are confirmed to work on SM90, SM100, and SM120
-        # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
-        # SM120 support added after Triton fix: https://github.com/triton-lang/triton/pull/8498
-        and (
-            (9, 0) <= current_platform.get_device_capability() < (11, 0)
-            or current_platform.is_device_capability_family(120)
-        )
-    )
+    triton_kernels_supported = _is_triton_mxfp4_supported_on_cuda()
     if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
         logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
         return Mxfp4Backend.TRITON
@@ -152,17 +161,7 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
     )
 
     # If FlashInfer is not available, try either Marlin or Triton
-    triton_kernels_supported = (
-        has_triton_kernels()
-        and is_torch_equal_or_newer("2.8.0")
-        # NOTE: triton_kernels are confirmed to work on SM90, SM100, and SM120
-        # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
-        # SM120 support added after Triton fix: https://github.com/triton-lang/triton/pull/8498
-        and (
-            (9, 0) <= current_platform.get_device_capability() < (11, 0)
-            or current_platform.is_device_capability_family(120)
-        )
-    )
+    triton_kernels_supported = _is_triton_mxfp4_supported_on_cuda()
    if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
         logger.info_once("Using Marlin backend")
         return Mxfp4Backend.MARLIN
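
Reviewer note (not part of the applied patches): below is a minimal standalone
sketch of the capability gate that _is_triton_mxfp4_supported_on_cuda()
implements, useful for sanity-checking which devices pass. The function name
supports_triton_mxfp4 and the test tuples are hypothetical stand-ins for
vLLM's current_platform.get_device_capability() and
is_device_capability_family() helpers, and it assumes "capability family 120"
means any SM 12.x device.

    def supports_triton_mxfp4(capability: "tuple[int, int] | None") -> bool:
        # Mirrors the None guard added in PATCH 2/2, which avoids a
        # TypeError when comparing a tuple range against None.
        if capability is None:
            return False
        # SM90-SM100 range check, as in the patched condition.
        is_sm90_or_sm100 = (9, 0) <= capability < (11, 0)
        # Assumed meaning of is_device_capability_family(120): major == 12.
        is_sm120 = capability[0] == 12
        return is_sm90_or_sm100 or is_sm120

    for cap in [None, (8, 9), (9, 0), (10, 0), (11, 0), (12, 0), (12, 1)]:
        print(cap, supports_triton_mxfp4(cap))
    # None, (8, 9), and (11, 0) are rejected; (9, 0), (10, 0), and 12.x pass,
    # matching the SM90/SM100/SM120 support matrix described in the commits.

The real helper additionally gates on has_triton_kernels() and
torch >= 2.8.0; this sketch covers only the capability logic.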