Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions tilelang/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,21 +405,17 @@ def get_target_compute_version(target=None):
# 2. Target.current()
target = target or Target.current()
if target and target.arch:
arch = target.arch.split("_")[1]
arch = target.arch.split("_")[1].rstrip("af")
if len(arch) == 2:
major, minor = arch
# Handle old format like sm_89
return major + "." + minor
elif len(arch) == 3:
major = int(arch[0])
if major < 2:
major = arch[0:2]
minor = arch[2]
return major + "." + minor
else:
# This is for arch like "sm_90a"
major, minor, suffix = arch
return major + "." + minor + "." + suffix
major = arch[0:2]
minor = arch[2]
return major + "." + minor
else:
raise ValueError(f"Unsupported arch: {arch}")

# 3. GPU compute version
if tvm.cuda(0).exist:
Expand Down Expand Up @@ -453,8 +449,11 @@ def parse_compute_version(compute_version) -> tuple[int, int]:
raise RuntimeError("Compute version parsing error") from err


def get_target_arch(compute_version) -> str:
major, minor = parse_compute_version(compute_version)
def get_target_arch(compute_version: str | tuple[int, int]) -> str:
if isinstance(compute_version, str):
major, minor = parse_compute_version(compute_version)
else:
major, minor = compute_version
target_arch = str(major * 10 + minor)
if major >= 9:
target_arch += "a"
Expand Down
9 changes: 8 additions & 1 deletion tilelang/utils/target.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from __future__ import annotations


import torch

from platform import mac_ver
from typing import Literal
from tilelang import tvm as tvm
Expand Down Expand Up @@ -111,7 +115,10 @@ def determine_target(target: str | Target | Literal["auto"] = "auto", return_obj

# Determine the target based on availability
if is_cuda_available:
return_var = "cuda"
if torch.cuda.is_available() and (cap := torch.cuda.get_device_capability(0)):
return_var = Target({"kind": "cuda", "arch": f"sm_{nvcc.get_target_arch(cap)}"})
else:
return_var = "cuda"
elif is_hip_available:
return_var = "hip"
elif check_metal_availability():
Expand Down
Loading