Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions third_party/nvidia/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,26 @@ def file_hash(path):
return hashlib.sha256(f.read()).hexdigest()


def _has_tensor_memory(capability: int) -> bool:
"""Return True for architectures with tensor memory (tcgen05).

Only datacenter Blackwell (SM100/SM103, arch family 10) has tensor memory.
Consumer Blackwell (SM120/SM121, arch family 12) does NOT have tensor memory
and must not use the tensor memory pipeline or the "a" arch suffix.
"""
arch_family = capability // 10
return arch_family == 10


def sm_arch_from_capability(capability: int) -> str:
    """Return the PTX target architecture string for *capability*.

    Capabilities >= 90 generally receive the "a" (architecture-specific)
    suffix, which enables accelerator features tied to the exact SM.
    Consumer Blackwell (SM120/SM121, arch family 12) lacks the accelerator
    features the suffix would enable, so it must not receive it.

    NOTE(review): a previous version of this SM120 special case was
    reverted as incorrect — confirm the intended handling against the
    current ptxas/PTX ISA documentation before relying on it.
    """
    arch_family = capability // 10
    suffix = "a" if capability >= 90 and arch_family != 12 else ""
    return f"sm_{capability}{suffix}"


Expand Down Expand Up @@ -285,7 +302,7 @@ def make_ttgir(mod, metadata, opt, capability):
passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
passes.ttgpuir.add_schedule_loops(pm)
passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
elif capability // 10 >= 10:
elif _has_tensor_memory(capability):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is right; going through that path for sm_120 should be fine.

passes.ttgpuir.add_fuse_nested_loops(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_triton_licm(pm)
Expand Down Expand Up @@ -467,7 +484,7 @@ def make_ptx(self, src, metadata, opt, capability):
# post-process
ptx_version = f'{ptx_version//10}.{ptx_version%10}'
ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
ret = re.sub(r'\.target sm_\d+a?', f'.target sm_{capability}', ret, flags=re.MULTILINE)
if not knobs.compilation.dump_ir_extract_di_local_variables:
# Remove the debug flag that prevents ptxas from optimizing the code
# Note: if this flag is removed, the source var name and type info will be lost when ptx was compiled into cubin
Expand Down