diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index dc0fbfa7df35..4e00689a99e6 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -61,6 +61,25 @@ logger = init_logger(__name__)
 
 
+def _is_triton_mxfp4_supported_on_cuda() -> bool:
+    """Checks if the Triton MXFP4 kernels are supported on CUDA."""
+    capability = current_platform.get_device_capability()
+    if capability is None:
+        return False
+
+    # NOTE: triton_kernels are confirmed to work on SM90, SM100, and SM120
+    # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
+    # SM120 support added after Triton fix: https://github.com/triton-lang/triton/pull/8498
+    is_sm90_or_sm100 = (9, 0) <= capability < (11, 0)
+    is_sm120 = current_platform.is_device_capability_family(120)
+
+    return (
+        has_triton_kernels()
+        and is_torch_equal_or_newer("2.8.0")
+        and (is_sm90_or_sm100 or is_sm120)
+    )
+
+
 # enum for mxfp4 backend
 class Mxfp4Backend(Enum):
     NONE = 0
@@ -87,14 +106,7 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
         return Mxfp4Backend.NONE
 
     # If FlashInfer is not available, try either Marlin or Triton
-    triton_kernels_supported = (
-        has_triton_kernels()
-        and is_torch_equal_or_newer("2.8.0")
-        # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
-        # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
-        # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
-        and (9, 0) <= current_platform.get_device_capability() < (11, 0)
-    )
+    triton_kernels_supported = _is_triton_mxfp4_supported_on_cuda()
     if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
         logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
         return Mxfp4Backend.TRITON
@@ -149,14 +161,7 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
         )
 
     # If FlashInfer is not available, try either Marlin or Triton
-    triton_kernels_supported = (
-        has_triton_kernels()
-        and is_torch_equal_or_newer("2.8.0")
-        # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
-        # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
-        # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
-        and (9, 0) <= current_platform.get_device_capability() < (11, 0)
-    )
+    triton_kernels_supported = _is_triton_mxfp4_supported_on_cuda()
     if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
         logger.info_once("Using Marlin backend")
         return Mxfp4Backend.MARLIN
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
index e9ecf0547033..f3c1c37e5b57 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -38,6 +38,11 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
         )
         value_layout = StridedLayout
         scale_layout = StridedLayout
+    elif current_platform.is_cuda() and current_platform.is_device_capability_family(120):
+        # SM120 (Blackwell consumer) - cannot use persistent kernels due to cluster TMA
+        # Use StridedLayout to avoid "Must use persistent kernel" error in matmul_ogs.py
+        value_layout = StridedLayout
+        scale_layout = StridedLayout
     elif current_platform.is_rocm():
         from vllm.platforms.rocm import on_gfx950
 
@@ -69,6 +74,14 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
             "epilogue_subtile": 1,
         }
         opt_flags.update_opt_flags_constraints(constraints)
+    elif current_platform.is_device_capability_family(120):
+        # SM120 (Blackwell consumer) uses similar constraints to SM100
+        # Note: cluster-related TMA operations are not supported on SM120
+        constraints = {
+            "is_persistent": False,
+            "num_stages": 1,  # SM120 shared memory limit
+        }
+        opt_flags.update_opt_flags_constraints(constraints)
     # transpose the tensor so that the quantization axis is on dim1
     quant_tensor = quant_tensor.transpose(-2, -1)
     scale = scale.transpose(-2, -1)
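
For reference, a minimal standalone sketch of the support gate that the new _is_triton_mxfp4_supported_on_cuda() helper implements. The environment checks are passed in as plain booleans, and the is_device_capability_family(120) test is approximated by the major compute-capability number, so this snippet is illustrative only and not part of the patch:

# Illustrative sketch only (not from the patch): mirrors the gating logic of
# _is_triton_mxfp4_supported_on_cuda(), with the Triton-kernels and torch
# version checks taken as plain boolean inputs so it runs without vLLM.
def triton_mxfp4_supported(
    capability: tuple[int, int] | None,
    triton_kernels_available: bool,
    torch_is_2_8_or_newer: bool,
) -> bool:
    if capability is None:
        return False
    # SM90 (Hopper) and SM100 (Blackwell datacenter) fall inside the range
    # check; SM110 is deliberately excluded (vllm issue #29317).
    is_sm90_or_sm100 = (9, 0) <= capability < (11, 0)
    # Approximation of is_device_capability_family(120): major version 12.
    is_sm120 = capability[0] == 12
    return (
        triton_kernels_available
        and torch_is_2_8_or_newer
        and (is_sm90_or_sm100 or is_sm120)
    )


if __name__ == "__main__":
    assert triton_mxfp4_supported((9, 0), True, True)       # SM90: supported
    assert not triton_mxfp4_supported((11, 0), True, True)  # SM110: excluded
    assert triton_mxfp4_supported((12, 0), True, True)      # SM120: supported
    assert not triton_mxfp4_supported(None, True, True)     # no CUDA device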