[Bugfix] Fix FLA Hopper/TMA misclassification on SM12x desktop Blackwell #37700
RobTand wants to merge 7 commits into vllm-project:main
Conversation
Force-pushed from 336d9f5 to 013c1ce
Code Review
This pull request correctly addresses a misclassification issue for desktop Blackwell GPUs (SM12x) by adjusting the capability checks for Hopper and TMA support. The changes in utils.py and chunk_o.py are logical and well-explained. However, the newly added unit test in test_fla_sm12x_capability.py is fundamentally flawed: it contains incorrect test data that contradicts both the implementation and the PR's description, and it reimplements logic instead of testing the module directly. This makes the test fail for valid cases and provides a false sense of correctness. I've added a critical review comment detailing the issues in the test and how to fix them.
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs will not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Force-pushed from 48f0685 to 0f5fcd1
Force-pushed from 0f5fcd1 to 9e5e0eb
Code Review
This pull request correctly addresses the misclassification of SM12x desktop Blackwell GPUs by refining the capability checks. The change from an unbounded >= 9 to a bounded 9 <= cap < 12 for is_nvidia_hopper and is_tma_supported is a precise fix that correctly differentiates between Hopper, datacenter Blackwell, and desktop Blackwell architectures. The unconditional setting of BKV_LIST simplifies the code, and the addition of parameterized unit tests ensures the new logic is robust. The changes are well-implemented and justified. I have one minor suggestion to improve a comment for better clarity.
…ktop Blackwell

SM12x GPUs (RTX 5090/5080 = SM120, DGX Spark GB10 = SM121) have capability major=12, which triggers `>= 9` checks in FLA ops designed for Hopper (SM90). SM12x is NOT Hopper — it lacks TMA and needs different NUM_WARPS tuning. This causes:

1. is_nvidia_hopper = True — restricts NUM_WARPS to [2, 4], missing 8-warp configs that give ~1.8x decode speedup on SM12x
2. is_tma_supported = True — but SM12x desktop has NO TMA hardware, causing Triton autotuner OOM in fla/solve_tril

Fix: use bounded range checks (9 <= cap < 12) for both is_nvidia_hopper and is_tma_supported. Expand BKV_LIST to [32, 64, 128] unconditionally — the Triton autotuner skips configurations exceeding available SMEM.

Tested on DGX Spark (SM121) with Qwen3.5-35B NVFP4 (TPOT 64ms -> 35ms).

Signed-off-by: Rob Tand <robert.tand@icloud.com>
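The bounded-range idea from the commit message can be sketched as a tiny predicate. This is illustrative only; the real check in the FLA utils also matches on the device name, and `is_hopper_class` is a made-up name for this sketch.

```python
# Minimal sketch of the bounded capability check described above.
# `major` is the CUDA compute capability major version.

def is_hopper_class(major: int) -> bool:
    # Old check: major >= 9, which wrongly includes SM12x (major == 12).
    # Bounded check: Hopper (SM90) through SM11x only.
    return 9 <= major < 12
```

Under this predicate SM90 (Hopper) and SM100 (datacenter Blackwell) qualify, while SM120/SM121 (desktop Blackwell) and SM80 (Ampere) do not.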
Force-pushed from 834326a to 97bce01
```diff
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: Apache-2.0
```
I am not sure why we need this test...
```diff
 is_nvidia_hopper = is_nvidia and (
     "NVIDIA H" in torch.cuda.get_device_name(0)
-    or torch.cuda.get_device_capability()[0] >= 9
+    or 9 <= torch.cuda.get_device_capability()[0] < 12
 )
```
It is better to use Platform.is_device_capability
Good feedback, I'll revise when I get to a box.
```python
is_nvidia = device_platform == "nvidia"
is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
# SM12x (desktop Blackwell: RTX 5090/5080, DGX Spark GB10) has capability
# major=12, which trips the >= 9 checks below. But SM12x is NOT Hopper —
```
I think it is better to write SM12x doesn't have TMA, and that it
```python
use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
is_gather_supported = hasattr(triton.language, "gather")
is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and (
    # SM12x desktop (GB10, RTX 5090) has no TMA hardware — TMA is datacenter
```
I was under the impression that SM12+ GPUs in fact have TMA hardware, e.g. NVIDIA RTX Pro 6000 Blackwell server edition, so I think this comment is wrong. However, they do have less shared memory per SM which might still cause this crash.
- Use current_platform.has_device_capability / is_device_capability_family instead of raw torch.cuda calls (vadiklyutiy feedback)
- Replace the architecture range check for is_tma_supported with a shared memory threshold (128KB). This correctly captures the root cause (insufficient SMEM for the Triton autotuner, not TMA absence) and future-proofs for SM12x variants with datacenter-class SMEM (e.g. RTX Pro 6000)
- Rewrite tests to mock current_platform and reload the actual utils module rather than reimplementing the logic (gemini-code-assist feedback)
- Add a test case for a hypothetical SM12x with sufficient SMEM (TMA enabled)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…e functions

The module reload approach broke because def statements overwrite mocks during reload. Instead, extract the classification logic into pure functions (check_nvidia_hopper, check_tma_supported) that accept a platform object and SMEM value, making them directly testable without reload tricks.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
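Based on that description, the extracted pure functions might look roughly like this. The function names and the 128 KiB threshold come from the PR text; everything else is an illustrative sketch, not the merged code.

```python
def check_nvidia_hopper(device_name: str, capability_major: int) -> bool:
    # Hopper-class: an H-series part by name, or SM9x..SM11x by capability
    return "NVIDIA H" in device_name or 9 <= capability_major < 12


def check_tma_supported(is_nvidia: bool, capability_major: int,
                        max_smem_bytes: int) -> bool:
    # TMA code paths require Hopper-or-newer plus enough shared memory for
    # the Triton autotuner's candidate configs (128 KiB per the PR text)
    return (is_nvidia and capability_major >= 9
            and max_smem_bytes >= 128 * 1024)
```

Because the functions take plain values, parameterized tests can cover a GB10-style device (~101 KB SMEM, TMA off) and an H100-class device (228 KB, TMA on) without mocking torch or reloading modules.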
This pull request has merge conflicts that must be resolved before it can be merged.
Summary

SM12x GPUs (RTX 5090/5080 = SM120, DGX Spark GB10 = SM121) have capability major=12, which triggers the `>= 9` checks in FLA ops designed for Hopper (SM90). This causes three problems on desktop Blackwell:

1. `is_nvidia_hopper = True` — restricts `NUM_WARPS` to `[2, 4]`, missing 8-warp configs that give ~1.8x decode speedup
2. `is_tma_supported = True` — SM12x desktop GPUs have only ~101KB shared memory per SM, which is insufficient for the TMA code paths that the Triton autotuner tries to compile, causing OOM in `fla/solve_tril`. (Note: SM12x may have TMA hardware — the issue is SMEM capacity, not TMA absence.)
3. `BKV_LIST` misses 128 — GB10 has 101KB SMEM (just under the 102.4KB threshold in `check_shared_mem()`), but BKV=128 works at lower stage counts

Fix
- Use `current_platform.has_device_capability()` / `is_device_capability_family()` instead of raw `torch.cuda` calls for `is_nvidia_hopper` detection (excludes the SM12x family while preserving SM9x–SM11x)
- Replace the architecture check in `is_tma_supported` with a shared memory threshold (128KB via `get_all_max_shared_mem()`). This correctly captures the root cause — insufficient SMEM for Triton autotuner compilation — and future-proofs for SM12x variants with datacenter-class SMEM (e.g. RTX Pro 6000 Blackwell server edition)
- Set `BKV_LIST = [32, 64, 128]` unconditionally in `chunk_o.py` — the Triton autotuner safely skips configurations that exceed available SMEM
- Rewrite tests to mock `current_platform` and reload the actual utils module, covering Ampere, Hopper, datacenter Blackwell, desktop Blackwell, and a hypothetical SM12x with sufficient SMEM

Why this is better than #36325
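The "autotuner safely skips oversized configurations" claim can be illustrated with a toy pruning step. The footprint formula below is made up purely for illustration; Triton's real shared-memory estimate depends on the kernel being compiled.

```python
BKV_LIST = [32, 64, 128]

def prune_configs(bkv_list, smem_limit_bytes, bytes_per_bkv=1024):
    # Keep only the block sizes whose (illustrative) shared-memory
    # footprint fits on the device; oversized candidates are dropped
    # instead of failing compilation.
    return [bkv for bkv in bkv_list
            if bkv * bytes_per_bkv <= smem_limit_bytes]
```

With a ~101 KB limit the 128 candidate is pruned, while devices with more SMEM keep the full list, which is why the unconditional list is safe.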
PR #36325 blanket-disables TMA on all Blackwell (`major >= 12`). This PR:

- uses `is_device_capability_family(120)` without affecting SM10x/SM11x

Testing
Tested on DGX Spark (SM121, 128 GB unified LPDDR5X) with Qwen3.5-35B NVFP4: TPOT 64ms -> 35ms (per the commit message).
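As a quick sanity check on the numbers quoted in the commit message (TPOT 64 ms down to 35 ms), the implied decode speedup matches the ~1.8x figure cited in the summary:

```python
# TPOT (time per output token) before and after, from the commit message
tpot_before_ms = 64
tpot_after_ms = 35

speedup = tpot_before_ms / tpot_after_ms  # roughly 1.83x
```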
Test plan
- Unit tests (`test_fla_sm12x_capability.py`)

Related PRs

- `is_blackwell_class()` platform detection (complementary)

Fixes #31128 (partial — FLA component)