Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 21 additions & 16 deletions vllm/model_executor/layers/quantization/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,25 @@
logger = init_logger(__name__)


def _is_triton_mxfp4_supported_on_cuda() -> bool:
    """Return True if the Triton MXFP4 kernels can be used on this CUDA device.

    Requires triton_kernels to be importable, torch >= 2.8.0, and a device
    whose compute capability is known to work with the kernels.
    """
    cap = current_platform.get_device_capability()
    if cap is None:
        # No detectable device capability -> cannot guarantee kernel support.
        return False

    if not has_triton_kernels() or not is_torch_equal_or_newer("2.8.0"):
        return False

    # NOTE: triton_kernels are confirmed to work on SM90, SM100, and SM120.
    # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
    # SM120 support added after Triton fix: https://github.com/triton-lang/triton/pull/8498
    supported_pre_sm110 = (9, 0) <= cap < (11, 0)
    if supported_pre_sm110:
        return True
    return current_platform.is_device_capability_family(120)


# enum for mxfp4 backend
class Mxfp4Backend(Enum):
NONE = 0
Expand All @@ -87,14 +106,7 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
return Mxfp4Backend.NONE

# If FlashInfer is not available, try either Marlin or Triton
triton_kernels_supported = (
has_triton_kernels()
and is_torch_equal_or_newer("2.8.0")
# NOTE: triton_kernels are only confirmed to work on SM90 and SM100
# SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
and (9, 0) <= current_platform.get_device_capability() < (11, 0)
)
triton_kernels_supported = _is_triton_mxfp4_supported_on_cuda()
if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
return Mxfp4Backend.TRITON
Expand Down Expand Up @@ -149,14 +161,7 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
)

# If FlashInfer is not available, try either Marlin or Triton
triton_kernels_supported = (
has_triton_kernels()
and is_torch_equal_or_newer("2.8.0")
# NOTE: triton_kernels are only confirmed to work on SM90 and SM100
# SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
and (9, 0) <= current_platform.get_device_capability() < (11, 0)
)
triton_kernels_supported = _is_triton_mxfp4_supported_on_cuda()
if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
logger.info_once("Using Marlin backend")
return Mxfp4Backend.MARLIN
Expand Down
13 changes: 13 additions & 0 deletions vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@
)
value_layout = StridedLayout
scale_layout = StridedLayout
elif current_platform.is_cuda() and current_platform.is_device_capability_family(120):

Check failure on line 41 in vllm/model_executor/layers/quantization/utils/mxfp4_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/model_executor/layers/quantization/utils/mxfp4_utils.py:41:89: E501 Line too long (90 > 88)
# SM120 (Blackwell consumer) - cannot use persistent kernels due to cluster TMA
# Use StridedLayout to avoid "Must use persistent kernel" error in matmul_ogs.py
value_layout = StridedLayout
scale_layout = StridedLayout
elif current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx950

Expand Down Expand Up @@ -69,6 +74,14 @@
"epilogue_subtile": 1,
}
opt_flags.update_opt_flags_constraints(constraints)
elif current_platform.is_device_capability_family(120):
# SM120 (Blackwell consumer) uses similar constraints to SM100
# Note: cluster-related TMA operations are not supported on SM120
constraints = {
"is_persistent": False,
"num_stages": 1, # SM120 shared memory limit
}
opt_flags.update_opt_flags_constraints(constraints)
# transpose the tensor so that the quantization axis is on dim1
quant_tensor = quant_tensor.transpose(-2, -1)
scale = scale.transpose(-2, -1)
Expand Down
Loading