
[Bugfix] Register fp8 cutlass_group_gemm as supported for only SM90+SM100#33285

Merged
vllm-bot merged 1 commit into vllm-project:main from neuralmagic:fix-cutlass_group_gemm_supported on Jan 29, 2026

Conversation

@mgoin (Member) commented Jan 28, 2026

Purpose

FIX #32109

We only have implementations for SM90 and SM100, so we should restrict support accordingly for the FP8 backend oracle to work correctly. Without this change, users on SM120 would default to this kernel backend and hit an unsupported error, when the Triton kernel should be used instead.

csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu
csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu
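The two files above provide the only CUTLASS grouped-GEMM implementations, one per architecture. A minimal sketch of the resulting dispatch logic (function and file names here are illustrative, not the actual vLLM code):

```python
# Illustrative sketch: map a CUDA compute capability to the available
# CUTLASS grouped-GEMM kernel, or None when no implementation exists.
# This mirrors the PR's intent; it is not the actual vLLM dispatch code.

def grouped_mm_kernel(capability: int):
    if 90 <= capability < 100:
        # Hopper (SM90) implementation.
        return "grouped_mm_c3x_sm90.cu"
    if 100 <= capability < 110:
        # Blackwell data-center (SM100) implementation.
        return "grouped_mm_c3x_sm100.cu"
    # e.g. SM120: no CUTLASS grouped GEMM; the caller must fall back
    # to another backend such as Triton.
    return None
```

Under this sketch, SM120 yields `None`, which is what lets the backend oracle skip the CUTLASS path instead of failing at kernel launch.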

Test Plan

Test Result

Tested manually by locally changing the condition to disqualify cuda_device_capability == 100; the resulting kernel selection changed from

(Worker_TP0 pid=1113294) INFO 01-28 17:18:33 [fp8.py:329] Using VLLM_CUTLASS Fp8 MoE backend out of potential backends: ['AITER', 'FLASHINFER_TRTLLM', 'FLASHINFER_CUTLASS', 'DEEPGEMM', 'BATCHED_DEEPGEMM', 'VLLM_CUTLASS', 'BATCHED_VLLM_CUTLASS', 'TRITON', 'BATCHED_TRITON', 'MARLIN'].

to

(Worker_TP0 pid=1118221) INFO 01-28 17:19:55 [fp8.py:329] Using TRITON Fp8 MoE backend out of potential backends: ['AITER', 'FLASHINFER_TRTLLM', 'FLASHINFER_CUTLASS', 'DEEPGEMM', 'BATCHED_DEEPGEMM', 'VLLM_CUTLASS', 'BATCHED_VLLM_CUTLASS', 'TRITON', 'BATCHED_TRITON', 'MARLIN'].
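The log change above can be understood as an oracle that walks an ordered candidate list and picks the first backend whose support predicate passes. A minimal sketch (names and candidate list are illustrative assumptions, not the actual vLLM API):

```python
# Hypothetical sketch of FP8 MoE backend selection; not the actual vLLM code.
CANDIDATES = ["VLLM_CUTLASS", "TRITON"]

def is_supported(backend: str, capability: int) -> bool:
    if backend == "VLLM_CUTLASS":
        # The fix in this PR: only SM90 and SM100 have
        # cutlass_group_gemm implementations.
        return capability in (90, 100)
    # Triton serves as the portable fallback.
    return True

def select_backend(capability: int) -> str:
    # First supported candidate wins, matching the ordered list in the logs.
    return next(b for b in CANDIDATES if is_supported(b, capability))
```

With this predicate, SM90 still selects VLLM_CUTLASS, while SM120 falls through to TRITON rather than erroring out.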

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

…M100

Signed-off-by: mgoin <mgoin64@gmail.com>
@mergify mergify bot added nvidia bug Something isn't working labels Jan 28, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly restricts the cutlass_group_gemm to be supported only on architectures with compute capabilities 9.x (Hopper) and 10.x (Blackwell). The added check is straightforward and effectively prevents the kernel from being used on unsupported hardware, which resolves the underlying bug. The implementation is well-placed and looks good.

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 28, 2026
@vllm-bot vllm-bot merged commit 1bd47d6 into vllm-project:main Jan 29, 2026
48 of 49 checks passed
@github-project-automation github-project-automation bot moved this to Done in NVIDIA Jan 29, 2026
@mgoin mgoin added this to the v0.15.1 Hotfix milestone Jan 29, 2026
apd10 pushed a commit to apd10/vllm that referenced this pull request Jan 31, 2026
khluu pushed a commit that referenced this pull request Feb 2, 2026
…M100 (#33285)

Signed-off-by: mgoin <mgoin64@gmail.com>
(cherry picked from commit 1bd47d6)
PiratePai pushed a commit to PiratePai/epd_shm that referenced this pull request Feb 3, 2026
…M100 (vllm-project#33285)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: PiratePai <416932041@qq.com>
Signed-off-by: Pai <416932041@qq.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026

Labels

bug Something isn't working nvidia ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Bug]: Blackwell (SM120) FP8 MoE path fails for GLM-4.7 : No compiled cutlass_scaled_mm for CUDA device capability: 120 on RTX PRO 6000 Blackwell

2 participants