Merged
26 commits
323c973
[Enhancement] Update examples and tests for improved type handling an…
LeiWang1999 Dec 16, 2025
7e83344
[Refactor] Update accumulation data type to float32 across examples
LeiWang1999 Dec 16, 2025
32185e6
[Refactor] Standardize data type usage across benchmark scripts
LeiWang1999 Dec 16, 2025
fed973b
[Refactor] Standardize data type usage in templates and scripts
LeiWang1999 Dec 16, 2025
32ac58d
[Refactor] Standardize data type usage in examples and benchmarks
LeiWang1999 Dec 16, 2025
415ee49
[Refactor] Import dtypes from language.v2 module
LeiWang1999 Dec 16, 2025
9fc3a7c
fix
LeiWang1999 Dec 16, 2025
6bd1656
[Refactor] Standardize data type usage across scripts
LeiWang1999 Dec 16, 2025
d6f538f
[Refactor] Update data type handling for consistency and clarity
LeiWang1999 Dec 16, 2025
a5b5660
[Enhancement] Improve data type handling and error messaging
LeiWang1999 Dec 16, 2025
7e6c1d7
Merge branch 'main' of https://github.com/tile-ai/tilelang into dtype…
LeiWang1999 Dec 16, 2025
e54cc65
[Fix] Correct boolean flag in GEMM SP test case
LeiWang1999 Dec 16, 2025
bd43896
[Refactor] Standardize data type usage across scripts
LeiWang1999 Dec 16, 2025
bc25f32
[Refactor] Standardize data type usage in various modules
LeiWang1999 Dec 16, 2025
d7e5564
[Refactor] Update argument parsing for data types in benchmarks
LeiWang1999 Dec 16, 2025
00b47e9
[Refactor] Update data type handling in benchmark and example scripts
LeiWang1999 Dec 16, 2025
f7425e9
[Refactor] Fix data type conversion in multiple scripts
LeiWang1999 Dec 16, 2025
027213d
[Refactor] Update float8 data type usage across multiple scripts
LeiWang1999 Dec 16, 2025
daf0c05
[Refactor] Enhance float8 data type handling in CUDA code generation
LeiWang1999 Dec 16, 2025
00b3c83
[Refactor] Streamline float8 data type handling in CUDA and related m…
LeiWang1999 Dec 16, 2025
565a61b
[Refactor] Remove unnecessary cache disabling in float8 example script
LeiWang1999 Dec 16, 2025
05677e6
[Refactor] Update data type usage in debug print tests
LeiWang1999 Dec 16, 2025
70cc80d
Merge branch 'main' of https://github.com/tile-ai/tilelang into dtype…
LeiWang1999 Dec 17, 2025
bdeb08d
lint fix
LeiWang1999 Dec 17, 2025
60596fd
Update function parameter types from `str` to `T.dtype` for improved …
LeiWang1999 Dec 17, 2025
7a8fb9a
Refactor `gemv_alloc_reducer` function signature for improved readabi…
LeiWang1999 Dec 17, 2025
2 changes: 1 addition & 1 deletion README.md
@@ -137,7 +137,7 @@ import tilelang.language as T
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float):

@T.prim_func
def matmul_relu_kernel(
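The README change above is representative of the whole PR: dtype arguments move from string literals to the dtype objects exposed by `tilelang.language` (imported as `T`). A minimal sketch of the before/after pattern, with the function body elided and the call site invented for illustration:

```python
import tilelang.language as T

# Old style: dtypes passed as plain strings
# def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): ...

# New style: dtypes passed as T.* objects; T.dtype(...) normalizes strings when needed
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float):
    ...

matmul(1024, 1024, 1024, 128, 128, 32, dtype=T.dtype("float16"))  # strings still accepted via T.dtype
```

Note that the README keeps `T.float` as the accumulator default, while most other files in this diff spell it out as `T.float32`.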
@@ -40,9 +40,9 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
shape = [batch, heads, seq_len, dim]
block_mask_shape = [batch, heads, downsample_len, downsample_len]

dtype = "float16"
accum_dtype = "float"
block_mask_dtype = "bool"
dtype = T.float16
accum_dtype = T.float32
block_mask_dtype = T.bool

def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
4 changes: 2 additions & 2 deletions benchmark/mamba2/benchmark_mamba_chunk_scan.py
@@ -202,8 +202,8 @@ def chunk_scan_fwd(
num_stages=2,
threads=128,
):
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
nchunks = T.ceildiv(seqlen, chunk_size)
p = 1.44269504

10 changes: 5 additions & 5 deletions benchmark/matmul/benchmark_matmul.py
@@ -62,9 +62,9 @@ def get_configs(args, kwargs):
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float",
in_dtype=T.float16,
out_dtype=T.float16,
accum_dtype=T.float32,
).with_arch(arch)

func = carve_template.equivalent_function()
@@ -155,8 +155,8 @@ def matmul(

# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32

@T.prim_func
def main(
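The comment in this file about accumulating in float for better numerical accuracy is easy to verify outside TileLang. The standalone PyTorch snippet below (not part of the PR) shows a float16 accumulator stalling once the running sum dwarfs the addends, while a float32 accumulator does not:

```python
import torch

# Accumulate 4096 copies of ~1e-4 (true sum is about 0.41).
vals = torch.full((4096,), 1e-4, dtype=torch.float16)

acc16 = torch.tensor(0.0, dtype=torch.float16)
acc32 = torch.tensor(0.0, dtype=torch.float32)
for v in vals:
    acc16 = acc16 + v          # rounded back to float16 after every add
    acc32 = acc32 + v.float()  # accumulated in float32

print(acc16.item())  # stalls around 0.25, where 1e-4 is below half a float16 ulp
print(acc32.item())  # ~0.41
```

This is why the benchmark keeps `dtype = T.float16` for the operands but `accum_dtype = T.float32` for the accumulator.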
32 changes: 16 additions & 16 deletions benchmark/matmul/benchmark_matmul_intrinsic.py
@@ -49,22 +49,22 @@ def tl_matmul(
enable_rasteration=False,
):
assert in_dtype in [
"float16",
"int8",
T.float16,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"

micro_size_x = micro_size_y = micro_size_k = 16

if out_dtype == "int32":
if out_dtype == T.int32:
micro_size_k = 32

# This is a debug config
# chunk = 32 if in_dtype == "float16" else 64
# chunk = 32 if in_dtype == T.float16 else 64
shared_scope = "shared.dyn"

block_M = block_row_warps * warp_row_tiles
@@ -194,9 +194,9 @@ def get_configs(args, kwargs):
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
in_dtype=T.float16,
out_dtype=T.float16,
accum_dtype=T.float16,
).with_arch(arch)

func = carve_template.equivalent_function()
@@ -251,9 +251,9 @@ def matmul(
M,
N,
K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
in_dtype=T.float16,
out_dtype=T.float16,
accum_dtype=T.float16,
with_roller=False,
block_row_warps=None,
block_col_warps=None,
@@ -295,9 +295,9 @@ def kernel():
args = parser.parse_args()

M, N, K = args.m, args.n, args.k
in_dtype = args.dtype
out_dtype = "float32" if in_dtype == "int8" else "float16"
accum_dtype = "float32" if in_dtype == "int8" else "float16"
in_dtype = T.dtype(args.dtype)
out_dtype = T.float32 if in_dtype == T.int8 else T.float16
accum_dtype = T.float32 if in_dtype == T.int8 else T.float16
with_roller = args.with_roller
with_roller = True
# Compute total floating-point operations
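The tail of this file shows the new CLI pattern: the dtype still arrives as a string from `argparse`, and `T.dtype(...)` normalizes it so later comparisons are made against dtype objects rather than raw strings. A trimmed sketch of that flow (the parser below is illustrative, not the benchmark's actual one):

```python
import argparse
import tilelang.language as T

parser = argparse.ArgumentParser()
parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "int8"])
args = parser.parse_args()

in_dtype = T.dtype(args.dtype)  # "float16" -> T.float16, "int8" -> T.int8
out_dtype = T.float32 if in_dtype == T.int8 else T.float16
accum_dtype = T.float32 if in_dtype == T.int8 else T.float16
print(in_dtype, out_dtype, accum_dtype)
```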
2 changes: 1 addition & 1 deletion benchmark/matmul/benchmark_matmul_sp.py
@@ -262,7 +262,7 @@ def main(
total_flops = 2 * M * N * K

# matmul(...) returns (best_latency, best_config, ref_latency)
best_result = matmul_sp(M, N, K, "float16", args.accum_dtype)
best_result = matmul_sp(M, N, K, T.float16, args.accum_dtype)
best_latency = best_result.latency
best_config = best_result.config
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
10 changes: 5 additions & 5 deletions benchmark/matmul_fp8/benchmark_matmul.py
@@ -63,9 +63,9 @@ def get_configs(args, kwargs):
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float",
in_dtype=T.float16,
out_dtype=T.float16,
accum_dtype=T.float32,
).with_arch(arch)

func = carve_template.equivalent_function()
@@ -159,8 +159,8 @@ def matmul(

# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float8_e4m3fnuz" if torch.version.hip is not None else "float8_e4m3"
accum_dtype = "float"
dtype = T.float8_e4m3fnuz if torch.version.hip is not None else T.float8_e4m3fn
accum_dtype = T.float32

@T.prim_func
def main(
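The float8 benchmark now selects the dtype object per backend: ROCm uses the `*_fnuz` variant, while CUDA uses `float8_e4m3fn`. A short sketch of that selection plus the torch-side mapping follows; whether `as_torch()` covers the float8 variants on a given PyTorch build is an assumption here, not something the diff states.

```python
import torch
from tilelang import language as T

# Pick the backend-appropriate float8 flavour, as the benchmark now does.
fp8_dtype = T.float8_e4m3fnuz if torch.version.hip is not None else T.float8_e4m3fn

# Assumed: the dtype object maps to its torch counterpart for tensor setup.
torch_fp8 = fp8_dtype.as_torch()
a = torch.randn(128, 128, device="cuda", dtype=torch.float16).to(torch_fp8)
print(fp8_dtype, torch_fp8, a.dtype)
```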
12 changes: 6 additions & 6 deletions docs/deeplearning_operators/elementwise.md
@@ -24,7 +24,7 @@ Please note that this tutorial does not delve deeply into the design principles
## Elementwise add in TileLang

```python
def elementwise_add(N, threads=256, dtype="bfloat16"):
def elementwise_add(N, threads=256, dtype=T.bfloat16):

@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
@@ -43,7 +43,7 @@ Those familiar with CUDA programming might wonder where `threadIdx` fits into th
The program can be compiled using the following code:

```python
program = elementwise_add(1024, threads=256, dtype="bfloat16")
program = elementwise_add(1024, threads=256, dtype=T.bfloat16)
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
```
Launching the kernel is straightforward, just call it directly like a function:
@@ -89,7 +89,7 @@ def elementwise_add(
In the compilation process above, a fixed shape was used. However, in practical usage, we often want the kernel to support dynamic shapes. So, how can we compile a kernel in TileLang to handle dynamic shapes? In TileLang, we can replace the target size with a dynamic symbolic value, making the dimension dynamic. The following example illustrates this:

```python
program = elementwise_add(T.dynamic("N"), threads=256, dtype="bfloat16")
program = elementwise_add(T.dynamic("N"), threads=256, dtype=T.bfloat16)
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
```

@@ -102,7 +102,7 @@ TileLang automatically incorporates boundary-checking conditions; however, this
When compiling the example below, let's set `N` to 2047:

```python
def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"):
def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16):

@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
@@ -176,7 +176,7 @@ While TileLang incorporates various optimizations for the aforementioned case, i
In such scenarios, explicitly specifying the number of elements computed per thread can help "guide" TileLang's code generation process, leading to implementations that are more closely aligned with the intended design.

```python
def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"):
def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16):

@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
@@ -212,7 +212,7 @@ Aha, this CUDA code aligns closely with conventional programming practices, maki
But what happens if we provide additional hints to TileLang? For instance, by explicitly specifying register copies using the `T.copy(...)` operation. The example below demonstrates a vector addition implementation. Unlike the previous examples, this code explicitly loads data into registers before performing computations.

```python
def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype="bfloat16"):
def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype=T.bfloat16):

@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
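Since the doc's snippets now take dtype objects, the end-to-end flow it describes looks like the sketch below. It assumes the `elementwise_add` factory defined in that document and re-uses the compile call the doc already shows; the input size and tolerances are illustrative.

```python
import torch
import tilelang
import tilelang.language as T

# Assumes elementwise_add(...) from docs/deeplearning_operators/elementwise.md.
program = elementwise_add(1024, threads=256, dtype=T.bfloat16)
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")

a = torch.randn(1024, dtype=torch.bfloat16, device="cuda")
b = torch.randn(1024, dtype=torch.bfloat16, device="cuda")
c = kernel(a, b)  # out_idx=-1: the output buffer is allocated and returned
torch.testing.assert_close(c, a + b, rtol=1e-2, atol=1e-2)
```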
18 changes: 9 additions & 9 deletions examples/amd/example_amd_flash_attn_bwd.py
@@ -87,8 +87,8 @@ def fast_flashattn(
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32

vec_size = qk_coalesced_width
v_vec_size = v_coalesced_width
@@ -109,7 +109,7 @@ def main(

num_q_blocks = T.ceildiv(seq_len, block_M)

bx_loop_var = T.alloc_var("int32")
bx_loop_var = T.alloc_var(T.int32)
bx_loop_var = b_split

with T.While(bx_loop_var < num_q_blocks):
@@ -236,8 +236,8 @@ def get_bwd_configs():

@tilelang.jit(out_idx=[2])
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim]
blk = 32

@@ -280,8 +280,8 @@ def flashattn_bwd(
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32

@T.prim_func
def flash_bwd_kernel(
@@ -368,8 +368,8 @@ def flash_bwd_kernel(

@tilelang.jit(out_idx=[1])
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim]
blk = 64

6 changes: 3 additions & 3 deletions examples/amd/example_amd_flash_attn_fwd.py
@@ -100,8 +100,8 @@ def fast_flashattn(
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32

vec_size = qk_coalesced_width
v_vec_size = v_coalesced_width
@@ -121,7 +121,7 @@ def main(

num_q_blocks = T.ceildiv(seq_len, block_M)

bx = T.alloc_var("int32")
bx = T.alloc_var(T.int32)
bx = b_split

with T.While(bx < num_q_blocks):
12 changes: 6 additions & 6 deletions examples/analyze/README.md
@@ -21,9 +21,9 @@ M = N = K = 1024

def kernel(block_M=128, block_N=128, block_K=32, num_stages=3, thread_num=128):
@T.prim_func
def main(A: T.Tensor((M, K), "float16"),
B: T.Tensor((N, K), "float16"),
C: T.Tensor((M, N), "float")):
def main(A: T.Tensor((M, K), T.float16),
B: T.Tensor((N, K), T.float16),
C: T.Tensor((M, N), T.float)):
# ... (kernel definition)
return main

@@ -40,9 +40,9 @@ from tilelang.carver.arch import CUDA

def kernel(N=64, C=256, H=512, W=512, F=512, K=3, block_M=64, block_N=128):
@T.prim_func
def main(data: T.Tensor((N, H, W, C), "float16"),
kernel: T.Tensor((K, K, C, F), "float16"),
out: T.Tensor((N, (H-K+1), (W-K+1), F), "float")):
def main(data: T.Tensor((N, H, W, C), T.float16),
kernel: T.Tensor((K, K, C, F), T.float16),
out: T.Tensor((N, (H-K+1), (W-K+1), F), T.float)):
# ... (convolution kernel definition)
return main

6 changes: 3 additions & 3 deletions examples/analyze/example_conv_analyze.py
@@ -25,12 +25,12 @@ def check_hopper():
return False


def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
is_hopper = check_hopper()

@T.prim_func
4 changes: 2 additions & 2 deletions examples/analyze/example_gemm_analyze.py
@@ -15,8 +15,8 @@ def kernel(
thread_num=None,
enable_rasteration=None,
):
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32

@T.prim_func
def matmul(
4 changes: 3 additions & 1 deletion examples/attention_sink/benchmark_gqa_sink_fwd.py
@@ -1,6 +1,7 @@
import torch
import argparse
from tilelang.profiler import do_bench
from tilelang import language as T
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
@@ -135,7 +136,8 @@ def main(
dtype: str = "float16",
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print("Using sliding window attention.")
assert window_size <= seq_q
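The two attention-sink benchmarks replace the hand-written `{"float16": torch.float16, ...}` lookup with the dtype object's own conversion. A minimal sketch of that mapping (the tensor shape is illustrative):

```python
import torch
from tilelang import language as T

dtype = T.dtype("bfloat16")     # the benchmark builds this from its CLI argument
torch_dtype = dtype.as_torch()  # replaces the manual string -> torch.dtype dict
assert torch_dtype == torch.bfloat16
q = torch.randn(2, 8, 128, 64, dtype=torch_dtype, device="cuda")
```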
4 changes: 3 additions & 1 deletion examples/attention_sink/benchmark_mha_sink_fwd.py
@@ -1,6 +1,7 @@
import torch
import argparse
from tilelang.profiler import do_bench
from tilelang import language as T
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
@@ -131,7 +132,8 @@ def main(
dtype: str = "float16",
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print("Using sliding window attention.")
assert window_size <= seq_q