diff --git a/docs/get_started/targets.md b/docs/get_started/targets.md new file mode 100644 index 000000000..c2b3f2fb5 --- /dev/null +++ b/docs/get_started/targets.md @@ -0,0 +1,120 @@ +# Understanding Targets + +TileLang is built on top of TVM, which relies on **targets** to describe the device you want to compile for. +The target determines which code generator is used (CUDA, HIP, Metal, LLVM, …) and allows you to pass +device-specific options such as GPU architecture flags. This page summarises how to pick and customise a target +when compiling TileLang programs. + +## Common target strings + +TileLang ships with a small set of common targets; each accepts the full range of TVM options so you can fine-tune +the generated code. The most frequent choices are listed below: + +| Base name | Description | +| --------- | ----------- | +| `auto` | Detects CUDA → HIP → Metal in that order. Useful when running the same script across machines. | +| `cuda` | NVIDIA GPUs. Supports options such as `-arch=sm_80`, `-max_num_threads=1024`, etc. | +| `hip` | AMD GPUs via ROCm. Options like `-mcpu=gfx90a` can be appended. | +| `metal` | Apple Silicon GPUs (arm64 Macs). | +| `llvm` | CPU execution; accepts the standard TVM LLVM switches. | +| `webgpu` | Browser / WebGPU runtimes. | +| `c` | Emit plain C source for inspection or custom toolchains. | + +To add options, append them after the base name, separated by spaces. For example: + +```python +target = "cuda -arch=sm_90" +kernel = tilelang.compile(func, target=target, execution_backend="cython") +# or +@tilelang.jit(target=target) +def compiled_kernel(*args): + return func(*args) +``` + +The same convention works for HIP or LLVM (e.g. `hip -mcpu=gfx940`, `llvm -mtriple=x86_64-linux-gnu`). + +### Advanced: Specify Exact Hardware + +When you already know the precise GPU model, you can encode it in the target string—either via `-arch=sm_XX` or by +using one of TVM’s pre-defined target tags such as `nvidia/nvidia-h100`. Supplying this detail is optional for +TileLang in general use, but it becomes valuable when the TVM cost model is enabled (e.g. during autotuning). The +cost model uses the extra attributes to make better scheduling predictions. If you skip this step (or do not use the +cost model), generic targets like `cuda` or `auto` are perfectly fine. + +All CUDA compute capabilities recognised by TVM’s target registry are listed below. Pick the one that matches your +GPU and append it to the target string or use the corresponding target tag—for example `nvidia/nvidia-a100`. + +| Architecture | GPUs (examples) | +| ------------ | ---------------- | +| `sm_20` | `nvidia/tesla-c2050`, `nvidia/tesla-c2070` | +| `sm_21` | `nvidia/nvs-5400m`, `nvidia/geforce-gt-520` | +| `sm_30` | `nvidia/quadro-k5000`, `nvidia/geforce-gtx-780m` | +| `sm_35` | `nvidia/tesla-k40`, `nvidia/quadro-k6000` | +| `sm_37` | `nvidia/tesla-k80` | +| `sm_50` | `nvidia/quadro-k2200`, `nvidia/geforce-gtx-950m` | +| `sm_52` | `nvidia/tesla-m40`, `nvidia/geforce-gtx-980` | +| `sm_53` | `nvidia/jetson-tx1`, `nvidia/jetson-nano` | +| `sm_60` | `nvidia/tesla-p100`, `nvidia/quadro-gp100` | +| `sm_61` | `nvidia/tesla-p4`, `nvidia/quadro-p6000`, `nvidia/geforce-gtx-1080` | +| `sm_62` | `nvidia/jetson-tx2` | +| `sm_70` | `nvidia/nvidia-v100`, `nvidia/quadro-gv100` | +| `sm_72` | `nvidia/jetson-agx-xavier` | +| `sm_75` | `nvidia/nvidia-t4`, `nvidia/quadro-rtx-8000`, `nvidia/geforce-rtx-2080` | +| `sm_80` | `nvidia/nvidia-a100`, `nvidia/nvidia-a30` | +| `sm_86` | `nvidia/nvidia-a40`, `nvidia/nvidia-a10`, `nvidia/geforce-rtx-3090` | +| `sm_87` | `nvidia/jetson-agx-orin-32gb`, `nvidia/jetson-agx-orin-64gb` | +| `sm_89` | `nvidia/geforce-rtx-4090` | +| `sm_90a` | `nvidia/nvidia-h100` (DPX profile) | +| `sm_100a` | `nvidia/nvidia-b100` | + +Refer to NVIDIA’s [CUDA GPUs](https://developer.nvidia.com/cuda-gpus) page or the TVM source +(`3rdparty/tvm/src/target/tag.cc`) for the latest mapping between devices and compute capabilities. + +## Creating targets programmatically + +If you prefer working with TVM’s `Target` objects, TileLang exposes the helper +`tilelang.utils.target.determine_target` (returns a canonical target string by default, or the `Target` +object when `return_object=True`): + +```python +from tilelang.utils.target import determine_target + +tvm_target = determine_target("cuda -arch=sm_80", return_object=True) +kernel = tilelang.compile(func, target=tvm_target) +``` + +You can also build targets directly through TVM: + +```python +from tvm.target import Target + +target = Target("cuda", host="llvm") +target = target.with_host(Target("llvm -mcpu=skylake")) +``` + +TileLang accepts either `str` or `Target` inputs; internally they are normalised and cached using the canonical +string representation. **In user code we strongly recommend passing target strings rather than +`tvm.target.Target` instances—strings keep cache keys compact and deterministic across runs, whereas constructing +fresh `Target` objects may lead to slightly higher hashing overhead or inconsistent identity semantics.** + +## Discovering supported targets in code + +Looking for a quick reminder of the built-in base names and their descriptions? Use: + +```python +from tilelang.utils.target import describe_supported_targets + +for name, doc in describe_supported_targets().items(): + print(f"{name:>6}: {doc}") +``` + +This helper mirrors the table above and is safe to call at runtime (for example when validating CLI arguments). + +## Troubleshooting tips + +- If you see `Target cuda -arch=sm_80 is not supported`, double-check the spellings and that the option is valid for + TVM. Any invalid switch will surface as a target-construction error. +- Runtime errors such as “no kernel image is available” usually mean the `-arch` flag does not match the GPU you are + running on. Try dropping the flag or switching to the correct compute capability. +- When targeting multiple environments, use `auto` for convenience and override with an explicit string only when + you need architecture-specific tuning. diff --git a/docs/index.md b/docs/index.md index 8380bb0de..5d9a158f8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -14,6 +14,7 @@ low-level optimizations necessary for state-of-the-art performance. get_started/Installation get_started/overview +get_started/targets ::: diff --git a/pyproject.toml b/pyproject.toml index 6214711a4..daa30406b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,6 @@ column_limit = 100 indent_width = 4 [tool.codespell] -builtin = "clear,rare,en-GB_to_en-US" ignore-words = "docs/spelling_wordlist.txt" skip = [ "build", diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index 647cc5bd7..64fc7bdf1 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -11,7 +11,7 @@ from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, NVRTCKernelAdapter, TorchDLPackKernelAdapter, MetalKernelAdapter) from tilelang.profiler import Profiler, TensorSupplyType -from tilelang.utils.target import AVALIABLE_TARGETS, determine_target +from tilelang.utils.target import determine_target import logging logger = logging.getLogger(__name__) @@ -90,13 +90,8 @@ def __init__( self.compile_flags = compile_flags - # If the target is specified as a string, validate it and convert it to a TVM Target. - if isinstance(target, str): - assert target in AVALIABLE_TARGETS, f"Invalid target: {target}" - target = determine_target(target) - - # Ensure the target is always a TVM Target object. - self.target = Target(target) + # Ensure the target is always a valid TVM Target object. + self.target = determine_target(target, return_object=True) # Validate the execution backend. assert execution_backend in [ diff --git a/tilelang/utils/target.py b/tilelang/utils/target.py index ee132649c..948308b81 100644 --- a/tilelang/utils/target.py +++ b/tilelang/utils/target.py @@ -1,22 +1,29 @@ from platform import mac_ver -from typing import Literal, Union +from typing import Dict, Literal, Union from tilelang import tvm as tvm from tilelang import _ffi_api from tvm.target import Target from tvm.contrib import rocm from tilelang.contrib import nvcc -AVALIABLE_TARGETS = { - "auto", - "cuda", - "hip", - "webgpu", - "c", # represent c source backend - "llvm", - "metal", +SUPPORTED_TARGETS: Dict[str, str] = { + "auto": "Auto-detect CUDA/HIP/Metal based on availability.", + "cuda": "CUDA GPU target (supports options such as `cuda -arch=sm_80`).", + "hip": "ROCm HIP target (supports options like `hip -mcpu=gfx90a`).", + "metal": "Apple Metal target for arm64 Macs.", + "llvm": "LLVM CPU target (accepts standard TVM LLVM options).", + "webgpu": "WebGPU target for browser/WebGPU runtimes.", + "c": "C source backend.", } +def describe_supported_targets() -> Dict[str, str]: + """ + Return a mapping of supported target names to usage descriptions. + """ + return dict(SUPPORTED_TARGETS) + + def check_cuda_availability() -> bool: """ Check if CUDA is available on the system by locating the CUDA path. @@ -90,11 +97,27 @@ def determine_target(target: Union[str, Target, Literal["auto"]] = "auto", raise ValueError("No CUDA or HIP or MPS available on this system.") else: # Validate the target if it's not "auto" - assert isinstance( - target, Target) or target in AVALIABLE_TARGETS, f"Target {target} is not supported" - return_var = target + if isinstance(target, Target): + return_var = target + elif isinstance(target, str): + normalized_target = target.strip() + if not normalized_target: + raise AssertionError(f"Target {target} is not supported") + try: + Target(normalized_target) + except Exception as err: + examples = ", ".join(f"`{name}`" for name in SUPPORTED_TARGETS) + raise AssertionError( + f"Target {target} is not supported. Supported targets include: {examples}. " + "Pass additional options after the base name, e.g. `cuda -arch=sm_80`." + ) from err + return_var = normalized_target + else: + raise AssertionError(f"Target {target} is not supported") if return_object: + if isinstance(return_var, Target): + return return_var return Target(return_var) return return_var