diff --git a/python/test/unit/language/test_compile_only.py b/python/test/unit/language/test_compile_only.py
index 1154fb3b8968..acdee7ab6da7 100644
--- a/python/test/unit/language/test_compile_only.py
+++ b/python/test/unit/language/test_compile_only.py
@@ -220,3 +220,57 @@ def fp8_convert(src, dst):
     src = ASTSource(fn=fp8_convert, signature={"src": "*fp32", "dst": "*fp8e5"}, constexprs={})
     triton.compile(src, target=GPUTarget("cuda", 90, 32))
     triton.compile(src, target=GPUTarget("cuda", 80, 32))
+
+
+def test_sm_arch_from_capability():
+    """Verify that sm_arch_from_capability generates correct arch strings.
+
+    Consumer Blackwell (sm_120, e.g. RTX 5070 Ti) must NOT get the "a" suffix.
+    Using sm_120a causes LLVM/ptxas to generate tensor memory instructions
+    that don't exist on consumer hardware, producing runtime segfaults.
+    """
+    from triton.backends.nvidia.compiler import sm_arch_from_capability
+    # Pre-Hopper: no suffix
+    assert sm_arch_from_capability(80) == "sm_80"
+    assert sm_arch_from_capability(89) == "sm_89"
+    # Hopper datacenter: "a" suffix
+    assert sm_arch_from_capability(90) == "sm_90a"
+    # Blackwell datacenter: "a" suffix
+    assert sm_arch_from_capability(100) == "sm_100a"
+    # Consumer Blackwell: NO "a" suffix (critical for RTX 5070 Ti/5080/5090)
+    assert sm_arch_from_capability(120) == "sm_120"
+
+
+def test_compile_only_sm120() -> None:
+    """Verify that sm_120 (consumer Blackwell) compiles with correct target.
+
+    Uses a tl.dot kernel (not just elementwise) to exercise the matmul
+    pipeline and confirm that tensor memory / tcgen05 instructions are
+    NOT generated for consumer Blackwell, which lacks tensor memory.
+    """
+
+    @triton.jit
+    def simple_dot(a_base, b_base, out):
+        SIZE: tl.constexpr = 64
+        a_ptr = a_base + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :]
+        b_ptr = b_base + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :]
+        a = tl.load(a_ptr)
+        b = tl.load(b_ptr)
+        c = tl.dot(a, b)
+        out_ptr = out + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :]
+        tl.store(out_ptr, c)
+
+    k = triton.compile(
+        triton.compiler.ASTSource(fn=simple_dot, signature={"a_base": "*fp16", "b_base": "*fp16", "out": "*fp16"},
+                                  constexprs={}), target=GPUTarget("cuda", 120, 32))
+    ptx = k.asm["ptx"]
+    # Must target sm_120 (no "a" suffix)
+    assert ".target sm_120" in ptx
+    assert ".target sm_120a" not in ptx
+    # Matmul must NOT use tensor memory or tcgen05 (consumer Blackwell lacks tmem).
+    # This is the key assertion — sm_100 dot uses tcgen05/tmem, sm_120 must not.
+    assert "tcgen05" not in ptx
+    ttgir = k.asm["ttgir"]
+    assert "ttng.tmem_alloc" not in str(ttgir)
+    assert "ttng.tc_gen5_mma" not in str(ttgir)
+    assert k.asm["cubin"] != b""
diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index b34f1904befd..a657bf88bbab 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -97,8 +97,15 @@ def file_hash(path):
 
 
 def sm_arch_from_capability(capability: int):
-    # TODO: Handle non-"a" sms
-    suffix = "a" if capability >= 90 else ""
+    # The "a" suffix enables arch-accelerated features only available on
+    # specific GPU implementations:
+    #   sm_90a — Hopper datacenter (H100, H200)
+    #   sm_100a — Blackwell datacenter (B100, B200)
+    # Consumer Blackwell (sm_120, e.g. RTX 5070 Ti/5080/5090) does NOT
+    # have an "a" variant — using sm_120a causes invalid codegen (tensor
+    # memory instructions that don't exist on consumer hardware), leading
+    # to runtime segfaults.
+    suffix = "a" if capability >= 90 and capability != 120 else ""
     return f"sm_{capability}{suffix}"
 
 