
Commit e68c856

[Example] Update examples to use @tilelang.jit (#597)
* [Example] Update kernel compilation in examples to use @tilelang.jit
  - Refactored multiple examples to eliminate the use of `tilelang.compile` for kernel creation, invoking the decorated functions directly instead.
  - Added `@tilelang.jit` decorators with the appropriate output indices to improve performance and maintainability.
  - Improved code clarity by simplifying kernel invocation across the examples, keeping kernel definition and execution consistent.
* format
* Update example_tilelang_sparse_gqa_decode_varlen_indice.py
* Update example_dequant_gemm_fine_grained.py
* Update example_gemm_autotune.py

---------

Co-authored-by: Lei Wang <[email protected]>
1 parent e682465 · commit e68c856
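The change follows one pattern across all touched examples. A minimal sketch of the before/after, assuming a hypothetical kernel factory `my_kernel` (the name, arguments, and kernel body are placeholders, not code from this repository):

    # Before: build the program, then compile it explicitly
    program = my_kernel(M, N, K)
    kernel = tilelang.compile(program, out_idx=[-1])  # last argument is the output tensor
    out = kernel(a, b)

    # After: the factory is decorated, so calling it returns a compiled kernel directly
    @tilelang.jit(out_idx=[-1])
    def my_kernel(M, N, K):
        ...  # kernel definition unchanged; body elided here

    kernel = my_kernel(M, N, K)
    out = kernel(a, b)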

File tree: 59 files changed, +203 −291 lines


examples/blocksparse_attention/example_tilelang_block_sparse_attn.py

Lines changed: 3 additions & 3 deletions
@@ -31,6 +31,7 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
     return dense_mask


+@tilelang.jit(out_idx=[4])
 def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
     block_M = 64
     block_N = 64
@@ -193,9 +194,8 @@ def test_topk_sparse_attention():
     x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

-    # Run Triton kernel
-    program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
-    kernel = tilelang.compile(program, out_idx=[4])
+    # Run tilelang kernel
+    kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)

     tilelang_output = kernel(q, k, v, block_mask)

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py

Lines changed: 7 additions & 14 deletions
@@ -19,6 +19,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
     accum_dtype = "float"
     kv_group_num = heads // heads_kv

+    @tilelang.jit(out_idx=[-1])
     def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen,
                     max_selected_blocks):
         shape_q = [batch, heads, dim]
@@ -203,7 +204,7 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):

         self.block_H = 64

-        program = flashattn(batch, heads, heads_kv, dim, dim_v)(
+        self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
             block_N=block_size,
             block_H=self.block_H,
             num_split=T.symbolic("num_split"),
@@ -212,9 +213,6 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
             max_cache_seqlen=T.symbolic("max_cache_seqlen"),
             max_selected_blocks=T.symbolic("max_selected_blocks"))

-        self.kernel = tilelang.compile(
-            program, out_idx=-1, target='cuda', execution_backend="cython")
-
         props = torch.cuda.get_device_properties(torch.device("cuda:0"))
         self.num_sm = props.multi_processor_count

@@ -308,7 +306,11 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
         is_causal_or_local=True,
         max_splits=128)

-    program = flashattn(batch, heads, heads_kv, dim, dim_v)(
+    glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
+    Output_partial = torch.empty((batch, heads, num_split, dim_v),
+                                 dtype=torch.float32,
+                                 device='cuda')
+    kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
         block_N=block_size,
         block_H=block_H,
         num_split=T.symbolic("num_split"),
@@ -317,14 +319,6 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
         max_cache_seqlen=T.symbolic("max_cache_seqlen"),
         max_selected_blocks=T.symbolic("max_selected_blocks"))

-    glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
-    Output_partial = torch.empty((batch, heads, num_split, dim_v),
-                                 dtype=torch.float32,
-                                 device='cuda')
-    kernel = tilelang.compile(program, out_idx=-1, target='cuda', execution_backend="cython")
-    # print(kernel.get_kernel_source())
-
-    # output = kernel(query, key, value, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial)
     output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial)
     return output

@@ -458,7 +452,6 @@ def main(batch=8,
     ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks,
                             block_size)

-    # out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size)
     sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size)
     out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
     debug("output", ref, out, atol=1e-3, rtol=1e-3)

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py

Lines changed: 3 additions & 6 deletions
@@ -20,6 +20,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
     accum_dtype = "float"
     kv_group_num = heads // heads_kv

+    @tilelang.jit(out_idx=[-1])
     def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks):
         shape_q = [batch, heads, dim]
         shape_k = [batch, max_cache_seqlen, heads_kv, dim]
@@ -189,7 +190,7 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):

         self.block_H = 64

-        program = flashattn(batch, heads, heads_kv, dim, dim_v)(
+        self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
             block_N=block_size,
             block_H=self.block_H,
             num_split=T.symbolic("num_split"),
@@ -198,9 +199,6 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
             max_cache_seqlen=T.symbolic("max_cache_seqlen"),
             num_blocks=T.symbolic("num_blocks"))

-        self.kernel = tilelang.compile(
-            program, out_idx=-1, target='cuda', execution_backend="cython")
-
         props = torch.cuda.get_device_properties(torch.device("cuda:0"))
         self.num_sm = props.multi_processor_count

@@ -281,7 +279,7 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
         is_causal_or_local=True,
         max_splits=128)

-    program = flashattn(batch, heads, heads_kv, dim, dim_v)(
+    kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
         block_N=block_size,
         block_H=block_H,
         num_split=T.symbolic("num_split"),
@@ -293,7 +291,6 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
     Output_partial = torch.empty((batch, heads, num_split, dim_v),
                                  dtype=torch.float32,
                                  device='cuda')
-    kernel = tilelang.compile(program, out_idx=-1, target='cuda', execution_backend="cython")
     # print(kernel.get_kernel_source())

     output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)

examples/blocksparse_gemm/example_blocksparse_gemm.py

Lines changed: 4 additions & 4 deletions
@@ -142,6 +142,7 @@ def kernel(block_M=None,
     return autotuner.run(warmup=3, rep=20)


+@tilelang.jit(out_idx=[-1])
 def blocksparse_matmul(M,
                        N,
                        K,
@@ -211,10 +212,9 @@ def main():
         print(f"Best Kernel Latency: {best_latency:.6f} ms")
         print(f"Reference Latency: {ref_latency:.6f} ms")
     else:
-        func = blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K,
-                                  DEFAULT_NUM_STAGES, DEFAULT_THREAD_NUM,
-                                  DEFAULT_ENABLE_RASTERIZATION)
-        kernel = tilelang.compile(func, out_idx=-1)
+        kernel = blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K,
+                                    DEFAULT_NUM_STAGES, DEFAULT_THREAD_NUM,
+                                    DEFAULT_ENABLE_RASTERIZATION)
         block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
         print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")

examples/cast/example_group_per_split_token_cast_to_fp8.py

Lines changed: 2 additions & 7 deletions
@@ -12,6 +12,7 @@
 accum_dtype = "float"


+@tilelang.jit(out_idx=[2, 3])
 def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
     group_size = 128
     fp8_min = -448.0
@@ -179,13 +180,7 @@ def main():
     print("batch_sizes:", batch_sizes)
     print("M_max:", M_max)

-    program = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m)
-    kernel = tilelang.compile(
-        program,
-        out_idx=[2, 3],
-        target="cuda",
-        execution_backend="cython",
-        pass_configs={"tl.disable_tma_lower": True})
+    kernel = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m)
     print(kernel.get_kernel_source())
     # profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)

examples/cast/example_per_token_cast_to_fp8.py

Lines changed: 2 additions & 7 deletions
@@ -10,6 +10,7 @@
 tilelang.disable_cache()


+@tilelang.jit(out_idx=[1, 2])
 def per_token_cast_to_fp8(M, N, blk_m):
     dtype = "float"
     group_size = 128
@@ -83,13 +84,7 @@ def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

 def main():
     M, N, blk_m = 8192, 8192, 8
-    program = per_token_cast_to_fp8(M, N, blk_m)
-    kernel = tilelang.compile(
-        program,
-        out_idx=[1, 2],
-        target="cuda",
-        execution_backend="cython",
-        pass_configs={"tl.disable_tma_lower": True})
+    kernel = per_token_cast_to_fp8(M, N, blk_m)
     print(kernel.get_kernel_source())
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)

examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py

Lines changed: 3 additions & 3 deletions
@@ -5,13 +5,14 @@

 import torch
 import tilelang.testing
-import tilelang as TL
+import tilelang
 import tilelang.language as T
 from tilelang.utils.tensor import map_torch_type

 tilelang.testing.set_random_seed(42)


+@tilelang.jit(out_idx=[2])
 def tl_gemm(
         M,
         N,
@@ -147,8 +148,7 @@ def calc_diff(x, y):


 def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtype):
-    gemm = tl_gemm(M, N, K, block_N, in_dtype, out_dtype, accum_dtype)
-    kernel = TL.compile(gemm, out_idx=[])
+    kernel = tl_gemm(M, N, K, block_N, in_dtype, out_dtype, accum_dtype)
     src_code = kernel.get_kernel_source()

     # src_code is the generated cuda source

examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py

Lines changed: 3 additions & 3 deletions
@@ -12,6 +12,7 @@
 tilelang.disable_cache()


+@tilelang.jit(out_idx=[6])
 def flashmla_decode(batch,
                     heads,
                     kv_head_num,
@@ -290,9 +291,8 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
     BLOCK_H = 64
     num_split = 4

-    program = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H,
-                              num_split)
-    kernel = tilelang.compile(program, out_idx=[6])
+    kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H,
+                             num_split)
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
     input_tensors = profiler._get_inputs()
     tilelang_output = kernel(*input_tensors)

examples/deepseek_mla/benchmark_mla.py

Lines changed: 2 additions & 3 deletions
@@ -436,9 +436,8 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size

     out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
     glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
-    program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
-                                  num_kv_splits, block_size)
-    kernel = tilelang.compile(program, out_idx=[8])
+    kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
+                                 num_kv_splits, block_size)

     def flash_mla_tilelang():
         out = kernel(

examples/deepseek_mla/example_mla_decode.py

Lines changed: 3 additions & 3 deletions
@@ -9,6 +9,7 @@
 import argparse


+@tilelang.jit(out_idx=[6])
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
     scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
     dtype = "float16"
@@ -289,8 +290,7 @@ def main():
     BLOCK_H = 64
     num_split = 1

-    program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
-    kernel = tilelang.compile(program, out_idx=[6])
+    kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
     profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
     latency = profiler.do_bench(warmup=500)
@@ -299,4 +299,4 @@ def main():


 if __name__ == "__main__":
-    main()
+    main()
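Taken together, the migrated decode example goes from definition to benchmark without an explicit compile step. A condensed sketch of that flow, using only calls visible in the diff above (the tolerances and warmup count are the example's own):

    kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)  # validate against the torch reference
    latency = profiler.do_bench(warmup=500)                      # measured kernel latency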
