From 456da33e3c62530fb138ca094762833481104a6e Mon Sep 17 00:00:00 2001 From: Pat Date: Mon, 16 Mar 2026 11:21:30 -0500 Subject: [PATCH 1/3] [NVIDIA] Fix PTX codegen segfaults on consumer Blackwell (sm_120) Fix three bugs causing non-deterministic SIGSEGV on RTX 5070 Ti / 5080 / 5090 GPUs (sm_120, compute capability 12.0) when using torch.compile or any Triton-compiled kernel. Root cause: sm_arch_from_capability(120) returned "sm_120a", but consumer Blackwell has no "a" variant. The "a" suffix is only valid for datacenter GPUs (sm_90a = H100, sm_100a = B100/B200). Passing "sm_120a" to LLVM and ptxas caused instruction selection for features (tensor memory, tcgen05) that do not exist on consumer hardware, producing invalid machine code that segfaults at runtime. Changes: 1. sm_arch_from_capability: Only add "a" suffix for 90 <= cap < 120, not for all cap >= 90. Resolves the TODO comment. 2. make_ptx: Fix .target regex (sm_\d+ -> sm_\d+a?) so an existing "a" suffix is consumed by the match, and substitute the computed arch string ({proc}) instead of sm_{capability}, so the "a" suffix is correctly handled in PTX post-processing. 3. make_ttgir: Route sm_120 and newer (capability >= 120) through the pipeline shared by Ampere and Hopper (capability // 10 in [8, 9]) instead of the datacenter Blackwell pipeline, which is now limited to 100 <= capability < 120. Consumer Blackwell uses MMAv2 (no tensor memory, no MMAv5), so it fits the Ampere/Hopper pipeline rather than the datacenter Blackwell pipeline, which assumes tensor memory support.
Tested on RTX 5070 Ti (SM 12.0) with PyTorch 2.9.1 + Triton 3.5.1: - 700+ torch.compile training steps with zero segfaults - Triton elementwise and matmul kernels produce correct results - Previously segfaulted within ~100 steps non-deterministically Fixes: pytorch/pytorch#176426 Co-Authored-By: Claude Opus 4.6 (1M context) --- .../test/unit/language/test_compile_only.py | 44 +++++++++++++++++++ third_party/nvidia/backend/compiler.py | 22 +++++++--- 2 files changed, 61 insertions(+), 5 deletions(-) diff --git a/python/test/unit/language/test_compile_only.py b/python/test/unit/language/test_compile_only.py index 1154fb3b8968..7e7f1867de01 100644 --- a/python/test/unit/language/test_compile_only.py +++ b/python/test/unit/language/test_compile_only.py @@ -220,3 +220,47 @@ def fp8_convert(src, dst): src = ASTSource(fn=fp8_convert, signature={"src": "*fp32", "dst": "*fp8e5"}, constexprs={}) triton.compile(src, target=GPUTarget("cuda", 90, 32)) triton.compile(src, target=GPUTarget("cuda", 80, 32)) + + +def test_sm_arch_from_capability(): + """Verify that sm_arch_from_capability generates correct arch strings. + + Consumer Blackwell (sm_120, e.g. RTX 5070 Ti) must NOT get the "a" suffix. + Using sm_120a causes LLVM/ptxas to generate tensor memory instructions + that don't exist on consumer hardware, producing runtime segfaults. + """ + from triton.backends.nvidia.compiler import sm_arch_from_capability + # Pre-Hopper: no suffix + assert sm_arch_from_capability(80) == "sm_80" + assert sm_arch_from_capability(89) == "sm_89" + # Hopper datacenter: "a" suffix + assert sm_arch_from_capability(90) == "sm_90a" + # Blackwell datacenter: "a" suffix + assert sm_arch_from_capability(100) == "sm_100a" + # Consumer Blackwell: NO "a" suffix (critical for RTX 5070 Ti/5080/5090) + assert sm_arch_from_capability(120) == "sm_120" + + +def test_compile_only_sm120() -> None: + """Verify that sm_120 (consumer Blackwell) compiles with correct target. 
+ + sm_120 must use '.target sm_120' (no 'a' suffix) and must NOT + contain tensor memory instructions (tcgen05) since consumer + Blackwell hardware lacks tensor memory. + """ + + @triton.jit + def kernel_add(a, b, c): + idx = tl.arange(0, 32) + tl.store(c + idx, tl.load(a + idx) + tl.load(b + idx)) + + k = triton.compile( + triton.compiler.ASTSource(fn=kernel_add, signature={"a": "*fp32", "b": "*fp32", "c": "*fp32"}, constexprs={}), + target=GPUTarget("cuda", 120, 32)) + ptx = k.asm["ptx"] + # Must target sm_120 (no "a" suffix) + assert ".target sm_120" in ptx + assert ".target sm_120a" not in ptx + # Must not contain tensor memory instructions (consumer Blackwell lacks tmem) + assert "tcgen05" not in ptx + assert k.asm["cubin"] != b"" diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index b34f1904befd..993f4c91bbb8 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -97,8 +97,15 @@ def file_hash(path): def sm_arch_from_capability(capability: int): - # TODO: Handle non-"a" sms - suffix = "a" if capability >= 90 else "" + # The "a" suffix enables arch-accelerated features only available on + # specific GPU implementations: + # sm_90a — Hopper datacenter (H100, H200) + # sm_100a — Blackwell datacenter (B100, B200) + # Consumer Blackwell (sm_120, e.g. RTX 5070 Ti/5080/5090) does NOT + # have an "a" variant — using sm_120a causes invalid codegen (tensor + # memory instructions that don't exist on consumer hardware), leading + # to runtime segfaults. 
+ suffix = "a" if 90 <= capability < 120 else "" return f"sm_{capability}{suffix}" @@ -269,7 +276,11 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm) passes.ttir.add_loop_aware_cse(pm) - if capability // 10 in [8, 9]: + if capability // 10 in [8, 9] or capability >= 120: + # Ampere (sm_8x), Hopper (sm_9x), consumer Blackwell (sm_12x+). + # Consumer Blackwell uses MMAv2 (no tensor memory, no MMAv5), + # so it shares the Hopper pipeline rather than the datacenter + # Blackwell pipeline which assumes tensor memory support. passes.ttgpuir.add_fuse_nested_loops(pm) passes.common.add_canonicalizer(pm) passes.ttir.add_triton_licm(pm) @@ -279,7 +290,8 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttgpuir.add_assign_latencies(pm, opt.num_stages) passes.ttgpuir.add_schedule_loops(pm) passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled) - elif capability // 10 >= 10: + elif 100 <= capability < 120: + # Datacenter Blackwell (sm_10x, sm_11x) — has tensor memory. 
passes.ttgpuir.add_fuse_nested_loops(pm) passes.common.add_canonicalizer(pm) passes.ttir.add_triton_licm(pm) @@ -458,7 +470,7 @@ def make_ptx(self, src, metadata, opt, capability): # post-process ptx_version = f'{ptx_version//10}.{ptx_version%10}' ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE) - ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE) + ret = re.sub(r'\.target sm_\d+a?', f'.target {proc}', ret, flags=re.MULTILINE) if not knobs.compilation.dump_ir_extract_di_local_variables: # Remove the debug flag that prevents ptxas from optimizing the code # Note: if this flag is removed, the source var name and type info will be lost when ptx was compiled into cubin From 7585f97f5a0d1c38e8d19d53d5e458b230e8909a Mon Sep 17 00:00:00 2001 From: Pat Date: Mon, 16 Mar 2026 11:31:12 -0500 Subject: [PATCH 2/3] [NVIDIA] Use tl.dot kernel in sm_120 test to exercise matmul pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous test used an elementwise add kernel, which never emits tcgen05 instructions regardless of target — making the assertion vacuous. Replace with a tl.dot matmul kernel (same as the sm_100 test) and also verify the TTGIR has no tmem_alloc or tc_gen5_mma ops. This ensures a regression that routes sm_120 dot operations through the datacenter Blackwell tensor memory pipeline will be caught. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- .../test/unit/language/test_compile_only.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/python/test/unit/language/test_compile_only.py b/python/test/unit/language/test_compile_only.py index 7e7f1867de01..acdee7ab6da7 100644 --- a/python/test/unit/language/test_compile_only.py +++ b/python/test/unit/language/test_compile_only.py @@ -244,23 +244,33 @@ def test_sm_arch_from_capability(): def test_compile_only_sm120() -> None: """Verify that sm_120 (consumer Blackwell) compiles with correct target. - sm_120 must use '.target sm_120' (no 'a' suffix) and must NOT - contain tensor memory instructions (tcgen05) since consumer - Blackwell hardware lacks tensor memory. + Uses a tl.dot kernel (not just elementwise) to exercise the matmul + pipeline and confirm that tensor memory / tcgen05 instructions are + NOT generated for consumer Blackwell, which lacks tensor memory. """ @triton.jit - def kernel_add(a, b, c): - idx = tl.arange(0, 32) - tl.store(c + idx, tl.load(a + idx) + tl.load(b + idx)) + def simple_dot(a_base, b_base, out): + SIZE: tl.constexpr = 64 + a_ptr = a_base + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :] + b_ptr = b_base + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :] + a = tl.load(a_ptr) + b = tl.load(b_ptr) + c = tl.dot(a, b) + out_ptr = out + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :] + tl.store(out_ptr, c) k = triton.compile( - triton.compiler.ASTSource(fn=kernel_add, signature={"a": "*fp32", "b": "*fp32", "c": "*fp32"}, constexprs={}), - target=GPUTarget("cuda", 120, 32)) + triton.compiler.ASTSource(fn=simple_dot, signature={"a_base": "*fp16", "b_base": "*fp16", "out": "*fp16"}, + constexprs={}), target=GPUTarget("cuda", 120, 32)) ptx = k.asm["ptx"] # Must target sm_120 (no "a" suffix) assert ".target sm_120" in ptx assert ".target sm_120a" not in ptx - # Must not contain tensor memory instructions 
(consumer Blackwell lacks tmem) + # Matmul must NOT use tensor memory or tcgen05 (consumer Blackwell lacks tmem). + # This is the key assertion — sm_100 dot uses tcgen05/tmem, sm_120 must not. assert "tcgen05" not in ptx + ttgir = k.asm["ttgir"] + assert "ttng.tmem_alloc" not in str(ttgir) + assert "ttng.tc_gen5_mma" not in str(ttgir) assert k.asm["cubin"] != b"" From f63597e835b83b723a5404ecbda09902b91d3a1b Mon Sep 17 00:00:00 2001 From: Pat Date: Mon, 16 Mar 2026 12:28:27 -0500 Subject: [PATCH 3/3] Address review: keep pipeline/regex unchanged, exclude 120 explicitly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per ThomasRaoux's review: - Use `capability != 120` instead of `< 120` so future architectures still get the "a" suffix by default. - Revert pipeline routing change — tmem passes are no-ops for sm_120 since AccelerateMatmul already selects MMAv2. - Revert regex change — unnecessary with the arch string fix in place. Co-Authored-By: Claude Opus 4.6 (1M context) --- third_party/nvidia/backend/compiler.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 993f4c91bbb8..a657bf88bbab 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -105,7 +105,7 @@ def sm_arch_from_capability(capability: int): # have an "a" variant — using sm_120a causes invalid codegen (tensor # memory instructions that don't exist on consumer hardware), leading # to runtime segfaults. 
- suffix = "a" if 90 <= capability < 120 else "" + suffix = "a" if capability >= 90 and capability != 120 else "" return f"sm_{capability}{suffix}" @@ -276,11 +276,7 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm) passes.ttir.add_loop_aware_cse(pm) - if capability // 10 in [8, 9] or capability >= 120: - # Ampere (sm_8x), Hopper (sm_9x), consumer Blackwell (sm_12x+). - # Consumer Blackwell uses MMAv2 (no tensor memory, no MMAv5), - # so it shares the Hopper pipeline rather than the datacenter - # Blackwell pipeline which assumes tensor memory support. + if capability // 10 in [8, 9]: passes.ttgpuir.add_fuse_nested_loops(pm) passes.common.add_canonicalizer(pm) passes.ttir.add_triton_licm(pm) @@ -290,8 +286,7 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttgpuir.add_assign_latencies(pm, opt.num_stages) passes.ttgpuir.add_schedule_loops(pm) passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled) - elif 100 <= capability < 120: - # Datacenter Blackwell (sm_10x, sm_11x) — has tensor memory. + elif capability // 10 >= 10: passes.ttgpuir.add_fuse_nested_loops(pm) passes.common.add_canonicalizer(pm) passes.ttir.add_triton_licm(pm) @@ -470,7 +465,7 @@ def make_ptx(self, src, metadata, opt, capability): # post-process ptx_version = f'{ptx_version//10}.{ptx_version%10}' ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE) - ret = re.sub(r'\.target sm_\d+a?', f'.target {proc}', ret, flags=re.MULTILINE) + ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE) if not knobs.compilation.dump_ir_extract_di_local_variables: # Remove the debug flag that prevents ptxas from optimizing the code # Note: if this flag is removed, the source var name and type info will be lost when ptx was compiled into cubin