
Fix SM120/SM121 (consumer Blackwell) codegen: no tensor memory pipeline #9852

Closed
mihai-chiorean wants to merge 2 commits into triton-lang:main from mihai-chiorean:fix/sm121-consumer-blackwell-support

Conversation

@mihai-chiorean

Summary

Consumer Blackwell GPUs (SM120/SM121 — DGX Spark, RTX 5090) lack tensor memory (tcgen05) hardware present in datacenter Blackwell (SM100/SM103). The compiler routes SM120/SM121 through the datacenter Blackwell pipeline, generating tensor memory instructions that cause illegal instruction crashes at runtime.

Previous fix attempt (#9734) was reverted (#9755) because it incorrectly claimed sm_120a isn't a valid arch. It IS valid — the real issue is the pipeline routing, not the suffix. This PR addresses the actual root cause.

Changes

  1. _has_tensor_memory() helper — returns True only for SM100/SM103 (arch family 10). SM120/SM121 (arch family 12) return False.
  2. Pipeline routing — SM120/SM121 use the Hopper-like pipeline (MMAv2, no tensor memory passes) instead of the datacenter Blackwell pipeline (add_hoist_tmem_alloc, add_promote_lhs_to_tmem).
  3. PTX .target regex — handles optional a suffix (\.target sm_\d+a?).
  4. Arch suffix — SM120/SM121 get no a suffix since they lack the accelerator features it implies.

Test Results (DGX Spark GB10, SM121, aarch64, CUDA 13.1)

| Test | Result |
| --- | --- |
| Simple vector add kernel | PASS |
| FP16 matmul with tl.dot | PASS (max error: 0.000004) |
| Qwen3-Next fused_qkvzba_split_reshape_cat_kernel | PASS |

Simple Triton kernels already worked on SM121 (they don't trigger tensor memory codegen). Complex kernels using tl.dot or 2D tensor operations crashed because the datacenter pipeline generated tensor memory instructions.

Why the Previous Fix Was Wrong

PR #9734 removed the a suffix for SM120, claiming sm_120a isn't valid. NVIDIA engineers objected — sm_120a IS a valid arch string (CUTLASS uses it). The PR was reverted.

This PR takes a different approach: the a suffix is secondary. The primary fix is routing SM120/SM121 away from the tensor memory pipeline. Even if we kept sm_120a, the pipeline routing fix alone would prevent the illegal instruction crash.
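The regex change from item 3 above can be sketched as follows. The pattern `\.target sm_\d+a?` comes from the PR description; `parse_target` is a hypothetical helper for demonstration, and the surrounding `make_ptx` code is not reproduced here.

```python
import re

# The "a" suffix must be optional so that both plain targets
# (.target sm_120) and accelerated targets (.target sm_100a) match.
TARGET_RE = re.compile(r"\.target sm_(\d+)a?")


def parse_target(ptx_line: str):
    """Return the SM number from a PTX .target directive, or None."""
    m = TARGET_RE.search(ptx_line)
    return m.group(1) if m else None
```

For example, `parse_target(".target sm_121a")` returns `"121"`, and `parse_target(".target sm_120")` returns `"120"`.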

Impact

Unblocks ALL Triton-dependent models on DGX Spark (GB10) and RTX 5090 (SM120), including:

  • nvidia/Qwen3-Next-80B-A3B-Thinking-NVFP4
  • Any model using custom Triton kernels with tl.dot or tensor operations
  • PyTorch models compiled via torch.compile on SM121

Fixes: #9181, #8539, #8335
Related: #9734 (reverted), PyTorch #176426, vLLM #31128

Consumer Blackwell GPUs (SM120/SM121 — DGX Spark, RTX 5090) lack
tensor memory (tcgen05) hardware present in datacenter Blackwell
(SM100/SM103). The compiler was routing SM120/SM121 through the
datacenter Blackwell pipeline, generating tensor memory instructions
that cause illegal instruction crashes at runtime.

Changes:
- Add _has_tensor_memory() helper: returns True only for SM100/SM103
  (arch family 10), False for SM120/SM121 (arch family 12)
- Route SM120/SM121 to the Hopper-like pipeline (MMAv2, no tmem)
  instead of the datacenter Blackwell pipeline
- Fix .target regex in make_ptx to handle optional "a" suffix
- SM120/SM121 get no "a" suffix in arch string since they lack
  the accelerator features it implies

Tested on DGX Spark GB10 (SM121, aarch64, CUDA 13.1):
- Simple vector add kernel: PASS
- FP16 matmul with tl.dot: PASS (max error: 0.000004)
- Qwen3-Next fused_qkvzba_split_reshape_cat_kernel: PASS

Fixes: triton-lang#9181, triton-lang#8539, triton-lang#8335
Related: triton-lang#9734 (reverted — addressed suffix but not pipeline routing)

Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
Signed-off-by: Mihai Chiorean <mihai-chiorean@users.noreply.github.com>
Collaborator

@ThomasRaoux left a comment


As far as I know, people have been using Triton for sm_120 successfully, so I'm not sure what the problem is. The patch seems to be the same as what was reverted.

Comment on lines +111 to +118
# SM120/SM121 (consumer Blackwell) lack tensor memory features
# that the "a" suffix enables. Only give "a" to SM >= 90 that
# actually have the corresponding accelerator features.
arch_family = capability // 10
if capability >= 90 and arch_family != 12:
    suffix = "a"
else:
    suffix = ""
Collaborator


We reverted the patch because this was incorrect, so I'm not sure why you are adding it back.

  passes.ttgpuir.add_schedule_loops(pm)
  passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
- elif capability // 10 >= 10:
+ elif _has_tensor_memory(capability):
Collaborator


I don't think this is right; going through that path for sm_120 should be fine.

@mihai-chiorean
Author

Thomas, you're right on both points. I ran systematic tests on an actual SM121 (NVIDIA GB10 / DGX Spark) with Triton 3.5.1 and CUDA 13.1, and I can confirm:

  1. Arch suffix: sm_121a is a valid ptxas target and works correctly.
  2. Pipeline routing: The Blackwell codegen path works fine on SM121 — even when tensor memory passes emit tcgen05 PTX, ptxas correctly lowers them to HMMA in the final SASS. Zero TCGEN instructions in the machine code. Tested matmul at various tile sizes (up to 128x128x64 with 4 stages), autotuned configs, and FLA-style chunk kernels with tl.dot + tl.trans. All passed.

My original diagnosis was incorrect. Closing this PR. Thanks for the quick review.



Development

Successfully merging this pull request may close these issues.

TorchInductor / Triton fails on NVIDIA Blackwell (GB10, cc 12.1): ptxas fatal: sm_121a is not defined
