
[Bugfix] Fix FLA Hopper/TMA misclassification on SM12x desktop Blackwell #37700

Open
RobTand wants to merge 7 commits into vllm-project:main from
RobTand:fix/fla-sm12x-hopper-tma-misclassification

Conversation


@RobTand RobTand commented Mar 20, 2026

Summary

SM12x GPUs (RTX 5090/5080 = SM120, DGX Spark GB10 = SM121) have capability major=12, which triggers the >= 9 checks in FLA ops designed for Hopper (SM90). This causes three problems on desktop Blackwell:

  1. is_nvidia_hopper = True — restricts NUM_WARPS to [2, 4], missing 8-warp configs that give ~1.8x decode speedup
  2. is_tma_supported = True — SM12x desktop GPUs have only ~101KB shared memory per SM, which is insufficient for the TMA code paths that the Triton autotuner tries to compile, causing OOM in fla/solve_tril. (Note: SM12x may have TMA hardware — the issue is SMEM capacity, not TMA absence.)
  3. BKV_LIST misses 128 — GB10 has 101KB SMEM (just under the 102.4KB threshold in check_shared_mem()), but BKV=128 works at lower stage counts
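The misclassification in points 1–2 can be illustrated with a minimal sketch (helper names are hypothetical, not the actual vLLM code; capabilities are `(major, minor)` tuples as returned by `torch.cuda.get_device_capability()`):

```python
# Minimal sketch of the misclassification. Helper names are hypothetical;
# this is not the actual vLLM code.

def old_is_hopper(capability):
    # Pre-fix check: any major >= 9 is treated as Hopper.
    return capability[0] >= 9

def fixed_is_hopper(capability):
    # Bounded check: Hopper (SM90) through SM11x, excluding the SM12x family.
    return 9 <= capability[0] < 12

sm90 = (9, 0)    # H100, Hopper
sm120 = (12, 0)  # RTX 5090, desktop Blackwell
sm121 = (12, 1)  # DGX Spark GB10

print(old_is_hopper(sm120))    # True: misclassified, loses 8-warp configs
print(fixed_is_hopper(sm120))  # False
print(fixed_is_hopper(sm121))  # False
print(fixed_is_hopper(sm90))   # True: SM90 behavior preserved
```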

Fix

  • Use current_platform.has_device_capability() / is_device_capability_family() instead of raw torch.cuda calls for is_nvidia_hopper detection (excludes SM12x family while preserving SM9x–SM11x)
  • Replace architecture range check for is_tma_supported with a shared memory threshold (128KB via get_all_max_shared_mem()). This correctly captures the root cause — insufficient SMEM for Triton autotuner compilation — and future-proofs for SM12x variants with datacenter-class SMEM (e.g. RTX Pro 6000 Blackwell server edition)
  • Set BKV_LIST = [32, 64, 128] unconditionally in chunk_o.py — the Triton autotuner safely skips configurations that exceed available SMEM
  • Add parametrized unit tests that mock current_platform and reload the actual utils module, covering Ampere, Hopper, datacenter Blackwell, desktop Blackwell, and a hypothetical SM12x with sufficient SMEM
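The shared-memory gating described above could look roughly like this (a sketch based on the PR description; the actual vLLM signature and the illustrative SMEM values may differ):

```python
# Sketch of SMEM-threshold TMA gating, per the PR description. The
# function name and signature are assumptions, not the actual diff.

TMA_SMEM_THRESHOLD = 128 * 1024  # the 128KB threshold named above

def check_tma_supported(has_sm90_plus: bool, max_shared_mem_bytes: int) -> bool:
    # Gate TMA code paths on SMEM capacity rather than architecture range:
    # the Triton autotuner OOMs compiling TMA configs when per-SM shared
    # memory is too small, regardless of whether TMA hardware is present.
    return has_sm90_plus and max_shared_mem_bytes >= TMA_SMEM_THRESHOLD

# GB10 (DGX Spark, SM121): ~101KB SMEM -> TMA paths disabled
print(check_tma_supported(True, 101 * 1024))  # False
# Hypothetical SM12x variant with datacenter-class SMEM -> TMA enabled
print(check_tma_supported(True, 192 * 1024))  # True
```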

Why this is better than #36325

PR #36325 blanket-disables TMA on all Blackwell (major >= 12). This PR:

  • Uses an SMEM threshold for TMA, so SM12x GPUs with sufficient shared memory (e.g. server-class Blackwell) automatically get TMA support
  • Excludes SM12x from Hopper warp configs using is_device_capability_family(120) without affecting SM10x/SM11x
  • Preserves Hopper SM90 behavior
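The unconditional `BKV_LIST = [32, 64, 128]` change relies on the autotuner pruning configs that do not fit. A plain-Python illustration (the SMEM footprint model here is invented for demonstration; it is not Triton's actual cost model):

```python
# Illustrative sketch of SMEM-based config pruning. The footprint model
# is hypothetical; it only demonstrates why an unconditional BKV_LIST is
# safe: oversized configs are dropped before benchmarking.

BKV_LIST = [32, 64, 128]

def smem_bytes(bkv: int, num_stages: int, elem_size: int = 4) -> int:
    # Hypothetical footprint: one BKV x 128 tile per pipeline stage.
    return bkv * 128 * elem_size * num_stages

def viable_configs(max_smem: int):
    # Keep only (BKV, num_stages) pairs that fit in shared memory.
    return [(bkv, stages)
            for bkv in BKV_LIST
            for stages in (1, 2, 3)
            if smem_bytes(bkv, stages) <= max_smem]

# GB10 with ~101KB SMEM: BKV=128 survives at a low stage count,
# so it should not be excluded up front.
print(viable_configs(101 * 1024))
```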

Testing

Tested on DGX Spark (SM121, 128 GB unified LPDDR5X):

  • Qwen3.5-35B-A3B-NVFP4: decode TPOT 64ms → 35ms (1.8x improvement)
  • Qwen3.5-122B-A10B-NVFP4: 26 tok/s with MTP n=3
  • Nemotron-3-Super-120B-A12B-NVFP4: 24 tok/s with MTP n=3

Test plan

  • Parametrized unit test for capability classification (test_fla_sm12x_capability.py)
  • Validated on DGX Spark with GDN models (Qwen3.5, Nemotron)
  • Pre-commit checks pass

Related PRs

Fixes #31128 (partial — FLA component)

@mergify mergify bot added the bug Something isn't working label Mar 20, 2026
@RobTand RobTand force-pushed the fix/fla-sm12x-hopper-tma-misclassification branch from 336d9f5 to 013c1ce Compare March 20, 2026 16:33
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses a misclassification issue for desktop Blackwell GPUs (SM12x) by adjusting the capability checks for Hopper and TMA support. The changes in utils.py and chunk_o.py are logical and well-explained. However, the newly added unit test in test_fla_sm12x_capability.py is fundamentally flawed: it contains incorrect test data that contradicts both the implementation and the PR's description, and it reimplements logic instead of testing the module directly. This makes the test fail for valid cases and provides a false sense of correctness. I've added a critical review comment detailing the issues in the test and how to fix them.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@RobTand RobTand force-pushed the fix/fla-sm12x-hopper-tma-misclassification branch 2 times, most recently from 48f0685 to 0f5fcd1 Compare March 20, 2026 16:38
@RobTand RobTand marked this pull request as ready for review March 20, 2026 16:40
@RobTand RobTand force-pushed the fix/fla-sm12x-hopper-tma-misclassification branch from 0f5fcd1 to 9e5e0eb Compare March 20, 2026 16:43
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses the misclassification of SM12x desktop Blackwell GPUs by refining the capability checks. The change from an unbounded >= 9 to a bounded 9 <= cap < 12 for is_nvidia_hopper and is_tma_supported is a precise fix that correctly differentiates between Hopper, datacenter Blackwell, and desktop Blackwell architectures. The unconditional setting of BKV_LIST simplifies the code, and the addition of parameterized unit tests ensures the new logic is robust. The changes are well-implemented and justified. I have one minor suggestion to improve a comment for better clarity.

…ktop Blackwell

SM12x GPUs (RTX 5090/5080 = SM120, DGX Spark GB10 = SM121) have
capability major=12, which triggers `>= 9` checks in FLA ops designed
for Hopper (SM90). SM12x is NOT Hopper — it lacks TMA and needs
different NUM_WARPS tuning. This causes:

1. is_nvidia_hopper = True — restricts NUM_WARPS to [2, 4], missing
   8-warp configs that give ~1.8x decode speedup on SM12x
2. is_tma_supported = True — but SM12x desktop has NO TMA hardware,
   causing Triton autotuner OOM in fla/solve_tril

Fix: use bounded range checks (9 <= cap < 12) for both is_nvidia_hopper
and is_tma_supported. Expand BKV_LIST to [32, 64, 128] unconditionally —
the Triton autotuner skips configurations exceeding available SMEM.

Tested on DGX Spark (SM121) with Qwen3.5-35B NVFP4 (TPOT 64ms -> 35ms).

Signed-off-by: Rob Tand <robert.tand@icloud.com>
@RobTand RobTand force-pushed the fix/fla-sm12x-hopper-tma-misclassification branch from 834326a to 97bce01 Compare March 20, 2026 23:47
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
Collaborator


I am not sure why we need this test...

is_nvidia_hopper = is_nvidia and (
    "NVIDIA H" in torch.cuda.get_device_name(0)
    or torch.cuda.get_device_capability()[0] >= 9
    or 9 <= torch.cuda.get_device_capability()[0] < 12
Collaborator


It is better to use Platform.is_device_capability

Author


Good feedback. I'll revise when I get to a box.

is_nvidia = device_platform == "nvidia"
is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
# SM12x (desktop Blackwell: RTX 5090/5080, DGX Spark GB10) has capability
# major=12, which trips the >= 9 checks below. But SM12x is NOT Hopper —
Collaborator


I think it is better to write SM12x doesn't have TMA, and that it

use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
is_gather_supported = hasattr(triton.language, "gather")
is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and (
    # SM12x desktop (GB10, RTX 5090) has no TMA hardware — TMA is datacenter
Contributor


I was under the impression that SM12+ GPUs in fact have TMA hardware, e.g. NVIDIA RTX Pro 6000 Blackwell server edition, so I think this comment is wrong. However, they do have less shared memory per SM which might still cause this crash.

rob and others added 2 commits March 27, 2026 12:03
- Use current_platform.has_device_capability / is_device_capability_family
  instead of raw torch.cuda calls (vadiklyutiy feedback)
- Replace architecture range check for is_tma_supported with shared memory
  threshold (128KB). This correctly captures the root cause (insufficient
  SMEM for Triton autotuner, not TMA absence) and future-proofs for SM12x
  variants with datacenter-class SMEM (e.g. RTX Pro 6000)
- Rewrite tests to mock current_platform and reload the actual utils module
  rather than reimplementing the logic (gemini-code-assist feedback)
- Add test case for hypothetical SM12x with sufficient SMEM (TMA enabled)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…e functions

Module reload approach broke because def statements overwrite mocks during
reload. Instead, extract the classification logic into pure functions
(check_nvidia_hopper, check_tma_supported) that accept a platform object
and SMEM value, making them directly testable without reload tricks.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
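The pure-function testing approach that commit describes could look roughly like this (the mock shape, capability encoding, and function body are assumptions inferred from the commit text, not the actual test file or vLLM Platform interface):

```python
# Rough sketch of the mock-platform testing approach described above.
# MockPlatform and the major*10+minor capability encoding are assumptions;
# the real vLLM Platform interface may differ.
from dataclasses import dataclass

@dataclass
class MockPlatform:
    major: int
    minor: int = 0

    def has_device_capability(self, cap: int) -> bool:
        return self.major * 10 + self.minor >= cap

    def is_device_capability_family(self, family: int) -> bool:
        return self.major * 10 == family

def check_nvidia_hopper(platform) -> bool:
    # Hopper-style NUM_WARPS tuning: SM9x and above, excluding SM12x.
    return (platform.has_device_capability(90)
            and not platform.is_device_capability_family(120))

cases = [
    ("Ampere SM80", MockPlatform(8), False),
    ("Hopper SM90", MockPlatform(9), True),
    ("Datacenter Blackwell SM100", MockPlatform(10), True),
    ("Desktop Blackwell SM120", MockPlatform(12), False),
    ("DGX Spark SM121", MockPlatform(12, 1), False),
]
for name, platform, expected in cases:
    assert check_nvidia_hopper(platform) == expected, name
print("all capability cases pass")
```

Because the classification logic takes the platform as an argument, each case is a plain function call: no module reloads, no `torch.cuda` patching.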
@mergify

mergify bot commented Mar 31, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @RobTand.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 31, 2026

Labels

bug (Something isn't working), needs-rebase

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature]: Add support of Blackwell SM121(DGX Spark)

3 participants