diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index 74bd76add920..b1be876827e8 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -96,9 +96,26 @@ def file_hash(path):
         return hashlib.sha256(f.read()).hexdigest()
 
 
+def _has_tensor_memory(capability: int) -> bool:
+    """Return True for architectures with tensor memory (tcgen05).
+
+    Only datacenter Blackwell (SM100/SM103, arch family 10) has tensor memory.
+    Consumer Blackwell (SM120/SM121, arch family 12) does NOT have tensor memory
+    and must not use the tensor memory pipeline or the "a" arch suffix.
+    """
+    arch_family = capability // 10
+    return arch_family == 10
+
+
 def sm_arch_from_capability(capability: int):
-    # TODO: Handle non-"a" sms
-    suffix = "a" if capability >= 90 else ""
+    # SM120/SM121 (consumer Blackwell) lack tensor memory features
+    # that the "a" suffix enables. Only give "a" to SM >= 90 that
+    # actually have the corresponding accelerator features.
+    arch_family = capability // 10
+    if capability >= 90 and arch_family != 12:
+        suffix = "a"
+    else:
+        suffix = ""
     return f"sm_{capability}{suffix}"
 
 
@@ -285,7 +302,7 @@ def make_ttgir(mod, metadata, opt, capability):
             passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
             passes.ttgpuir.add_schedule_loops(pm)
             passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
-        elif capability // 10 >= 10:
+        elif _has_tensor_memory(capability):
             passes.ttgpuir.add_fuse_nested_loops(pm)
             passes.common.add_canonicalizer(pm)
             passes.ttir.add_triton_licm(pm)
@@ -467,7 +484,9 @@ def make_ptx(self, src, metadata, opt, capability):
         # post-process
         ptx_version = f'{ptx_version//10}.{ptx_version%10}'
         ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
-        ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
+        # Re-emit the target through sm_arch_from_capability so that any arch
+        # suffix consumed by the pattern (e.g. the "a" in sm_90a) is restored.
+        ret = re.sub(r'\.target sm_\d+a?', f'.target {sm_arch_from_capability(capability)}', ret, flags=re.MULTILINE)
         if not knobs.compilation.dump_ir_extract_di_local_variables:
             # Remove the debug flag that prevents ptxas from optimizing the code
             # Note: if this flag is removed, the source var name and type info will be lost when ptx was compiled into cubin