
[Quantization] enable MXFP4 Triton backend on SM120 (Blackwell)#31089

Closed
janreges wants to merge 3 commits into vllm-project:main from janreges:main

Conversation

@janreges

@janreges janreges commented Dec 21, 2025

Purpose

Enable MXFP4 Triton kernel backend on NVIDIA Blackwell consumer GPUs (SM120, compute capability 12.0).

Test Plan

Tested with a build compiled from the current source code on an NVIDIA RTX PRO 6000 Blackwell 96 GB:

vllm serve \
  "openai/gpt-oss-120b" \
  --async-scheduling \
  --trust-remote-code \
  --gpu-memory-utilization 0.91 \
  --enable-chunked-prefill \
  --enable-prefix-caching \
  --tensor-parallel-size 1 \
  --max-num-batched-tokens 32768 \
  --max-model-len 131072 \
  --max-num-seqs 512 \
  --disable-log-requests \
  --reasoning-parser openai_gptoss \
  --enable-auto-tool-choice \
  --tool-call-parser openai \
  --port 8000

Test Result

Tested on an NVIDIA RTX PRO 6000 Blackwell (compute capability 12.0) with an AMD EPYC 9554 processor: the model openai/gpt-oss-120b loads and runs successfully with the MXFP4 Triton backend.

However, performance is worse than with the Marlin backend: at batch size 1 this backend reaches 160 flow/s, versus 201 flow/s for Marlin with the same configuration.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing the test command.
  • The test results, such as pasting a before/after comparison or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

- Add SM120 to triton_kernels_supported condition in both backend
  selection functions (get_mxfp4_backend, get_mxfp4_backend_with_lora)
- Use StridedLayout for SM120 to avoid "Must use persistent kernel"
  error caused by unsupported cluster TMA operations
- Configure SM120-specific constraints: is_persistent=False, num_stages=1

Tested on NVIDIA RTX PRO 6000 Blackwell (compute capability 12.0).
Requires Triton fix: triton-lang/triton#8498
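The SM120-specific constraints listed in the commit message could be selected with a small capability-gated helper. The sketch below is purely illustrative: the function name and the non-SM120 default values are hypothetical and do not reflect vLLM's actual mxfp4_utils.py code.

```python
# Hypothetical sketch: choose MXFP4 Triton kernel constraints by compute
# capability. The helper name and the SM90/SM100 defaults are illustrative
# assumptions, not vLLM's real API.

def kernel_constraints_for(capability: tuple[int, int]) -> dict:
    """Return Triton kernel constraints for a (major, minor) capability."""
    if capability >= (12, 0):
        # SM120 (Blackwell consumer): cluster TMA operations are unsupported,
        # so persistent kernels fail with "Must use persistent kernel";
        # fall back to a strided layout, non-persistent, single-stage config.
        return {
            "is_persistent": False,
            "num_stages": 1,
            "layout": "StridedLayout",
        }
    # Illustrative defaults for SM90/SM100, where persistent kernels work.
    return {
        "is_persistent": True,
        "num_stages": 3,
        "layout": "default",
    }
```

A caller would then pass the returned dict through to the kernel launch configuration, keeping the SM120 special-casing in one place.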
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, a small and essential subset of CI tests that quickly catches errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request enables the MXFP4 Triton backend for NVIDIA Blackwell (SM120) GPUs. The changes involve updating the device capability checks in mxfp4.py to include SM120 and adding specific configurations for Blackwell in mxfp4_utils.py to handle its architectural differences, such as disabling persistent kernels.

My main feedback is to refactor the duplicated logic for checking Triton kernel support in mxfp4.py into a helper function. This will improve code maintainability and prevent potential inconsistencies in the future. The rest of the changes look good and are well-commented.

Comment on lines 90 to 100
triton_kernels_supported = (
    has_triton_kernels()
    and is_torch_equal_or_newer("2.8.0")
    # NOTE: triton_kernels are confirmed to work on SM90, SM100, and SM120
    # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
    # SM120 support added after Triton fix: https://github.com/triton-lang/triton/pull/8498
    and (
        (9, 0) <= current_platform.get_device_capability() < (11, 0)
        or current_platform.is_device_capability_family(120)
    )
)


Severity: high

The logic to determine if Triton kernels are supported is duplicated here and in get_mxfp4_backend (lines 155-165). This can lead to maintenance issues, as a future change might only be applied to one of the locations.

Additionally, the current implementation does not handle the case where current_platform.get_device_capability() returns None, which would cause a TypeError.

To improve maintainability, avoid code duplication, and fix the potential TypeError, I suggest extracting this logic into a new helper function.

For example, you could add the following helper function at the module level:

def _is_triton_mxfp4_supported_on_cuda() -> bool:
    """Checks if the Triton MXFP4 kernels are supported on CUDA."""
    capability = current_platform.get_device_capability()
    if capability is None:
        return False

    # NOTE: triton_kernels are confirmed to work on SM90, SM100, and SM120
    # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
    # SM120 support added after Triton fix: https://github.com/triton-lang/triton/pull/8498
    is_sm90_or_sm100 = (9, 0) <= capability < (11, 0)
    is_sm120 = current_platform.is_device_capability_family(120)

    return (has_triton_kernels() and is_torch_equal_or_newer("2.8.0")
            and (is_sm90_or_sm100 or is_sm120))

Then, you can simplify the code here and in get_mxfp4_backend by calling this new function.

    triton_kernels_supported = _is_triton_mxfp4_supported_on_cuda()

@mergify

mergify bot commented Dec 21, 2025

Hi @janreges, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint


…nction

- Extract duplicated logic for checking Triton MXFP4 support on CUDA
  into new _is_triton_mxfp4_supported_on_cuda() helper function
- Fix potential TypeError when get_device_capability() returns None
- Simplify code in get_mxfp4_backend() and get_mxfp4_backend_with_lora()

Addresses PR review feedback to improve maintainability and avoid
code duplication.

Signed-off-by: jan.reges <jan.reges@siteone.cz>

@longregen
Contributor

This would be really nice to see merged!

@geraldstanje1

geraldstanje1 commented Mar 5, 2026

Hi, can someone merge this? I would like to use FP4 on an NVIDIA RTX PRO 6000 Blackwell.

Currently I also see the following:
(EngineCore_DP0 pid=281) INFO 03-05 17:53:04 [cuda.py:367] Using TRITON_ATTN attention backend out of potential backends: ['TRITON_ATTN'].
(EngineCore_DP0 pid=281) INFO 03-05 17:53:04 [mxfp4.py:157] Using Marlin backend

(EngineCore_DP0 pid=280) WARNING 03-05 17:16:17 [marlin_utils_fp4.py:338] Your GPU does not have native support for FP4 computation but FP4 quantization is being used. Weight-only FP4 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.

cc @janreges @yewentao256 @robertgshaw2-redhat

Member

@yewentao256 yewentao256 left a comment


Please merge from main and solve the pre-commit issue and conflicts

@geraldstanje1

cc @janreges

@mgoin
Member

mgoin commented Mar 9, 2026

@geraldstanje1 Sorry for the confusion but this warning

(EngineCore_DP0 pid=280) WARNING 03-05 17:16:17 [marlin_utils_fp4.py:338] Your GPU does not have native support for FP4 computation but FP4 quantization is being used. Weight-only FP4 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.

isn't actually something to be concerned about, since GPT-OSS is already a weight-only checkpoint (i.e., MXFP4 W4A16).

From the PR description, performance is actually worse with this backend:

However, the performance is worse compared to the Marlin backend - batch 1 = 160 flow/s. Marlin with the same configuration 201 flow/s.

Can you try benchmarking locally to see if there is any reason to use this kernel?

@mgoin
Member

mgoin commented Mar 16, 2026

Closing this PR for now, as this kernel seems slower than Marlin on SM120 and achieves the same result as MXFP4 W4A16. The warning message shared was just a source of user confusion, and it has been removed on main.

@mgoin mgoin closed this Mar 16, 2026